mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-03 20:27:35 +03:00
attach the QUIC version to context returned by ClientHelloInfo.Context (#3721)
This commit is contained in:
parent
11f493381f
commit
41ddaa0262
5 changed files with 106 additions and 176 deletions
|
@ -14,7 +14,6 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
"github.com/quic-go/quic-go/internal/handshake"
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
"github.com/quic-go/quic-go/quicvarint"
|
||||
|
@ -66,8 +65,9 @@ func ConfigureTLSConfig(tlsConf *tls.Config) *tls.Config {
|
|||
GetConfigForClient: func(ch *tls.ClientHelloInfo) (*tls.Config, error) {
|
||||
// determine the ALPN from the QUIC version used
|
||||
proto := NextProtoH3
|
||||
if qconn, ok := ch.Conn.(handshake.ConnWithVersion); ok {
|
||||
proto = versionToALPN(qconn.GetQUICVersion())
|
||||
val := ch.Context().Value(quic.QUICVersionContextKey)
|
||||
if v, ok := val.(quic.VersionNumber); ok {
|
||||
proto = versionToALPN(v)
|
||||
}
|
||||
config := tlsConf
|
||||
if tlsConf.GetConfigForClient != nil {
|
||||
|
|
|
@ -28,19 +28,6 @@ import (
|
|||
gmtypes "github.com/onsi/gomega/types"
|
||||
)
|
||||
|
||||
type mockConn struct {
|
||||
net.Conn
|
||||
version protocol.VersionNumber
|
||||
}
|
||||
|
||||
func newMockConn(version protocol.VersionNumber) net.Conn {
|
||||
return &mockConn{version: version}
|
||||
}
|
||||
|
||||
func (c *mockConn) GetQUICVersion() protocol.VersionNumber {
|
||||
return c.version
|
||||
}
|
||||
|
||||
type mockAddr struct {
|
||||
addr string
|
||||
}
|
||||
|
@ -940,31 +927,87 @@ var _ = Describe("Server", func() {
|
|||
})
|
||||
|
||||
Context("ConfigureTLSConfig", func() {
|
||||
var tlsConf *tls.Config
|
||||
var ch *tls.ClientHelloInfo
|
||||
|
||||
BeforeEach(func() {
|
||||
tlsConf = &tls.Config{}
|
||||
ch = &tls.ClientHelloInfo{}
|
||||
})
|
||||
|
||||
It("advertises v1 by default", func() {
|
||||
tlsConf = ConfigureTLSConfig(tlsConf)
|
||||
Expect(tlsConf.GetConfigForClient).NotTo(BeNil())
|
||||
|
||||
config, err := tlsConf.GetConfigForClient(ch)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(config.NextProtos).To(Equal([]string{NextProtoH3}))
|
||||
conf := ConfigureTLSConfig(testdata.GetTLSConfig())
|
||||
ln, err := quic.ListenAddr("localhost:0", conf, &quic.Config{Versions: []quic.VersionNumber{quic.Version1}})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer ln.Close()
|
||||
c, err := quic.DialAddr(ln.Addr().String(), &tls.Config{InsecureSkipVerify: true, NextProtos: []string{NextProtoH3}}, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer c.CloseWithError(0, "")
|
||||
Expect(c.ConnectionState().TLS.ConnectionState.NegotiatedProtocol).To(Equal(NextProtoH3))
|
||||
})
|
||||
|
||||
It("advertises h3-29 for draft-29", func() {
|
||||
tlsConf = ConfigureTLSConfig(tlsConf)
|
||||
Expect(tlsConf.GetConfigForClient).NotTo(BeNil())
|
||||
conf := ConfigureTLSConfig(testdata.GetTLSConfig())
|
||||
ln, err := quic.ListenAddr("localhost:0", conf, &quic.Config{Versions: []quic.VersionNumber{quic.VersionDraft29}})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer ln.Close()
|
||||
c, err := quic.DialAddr(ln.Addr().String(), &tls.Config{InsecureSkipVerify: true, NextProtos: []string{NextProtoH3Draft29}}, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer c.CloseWithError(0, "")
|
||||
Expect(c.ConnectionState().TLS.ConnectionState.NegotiatedProtocol).To(Equal(NextProtoH3Draft29))
|
||||
})
|
||||
|
||||
ch.Conn = newMockConn(protocol.VersionDraft29)
|
||||
config, err := tlsConf.GetConfigForClient(ch)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(config.NextProtos).To(Equal([]string{NextProtoH3Draft29}))
|
||||
It("sets the GetConfigForClient callback if no tls.Config is given", func() {
|
||||
var receivedConf *tls.Config
|
||||
quicListenAddr = func(addr string, tlsConf *tls.Config, _ *quic.Config) (quic.EarlyListener, error) {
|
||||
receivedConf = tlsConf
|
||||
return nil, errors.New("listen err")
|
||||
}
|
||||
Expect(s.ListenAndServe()).To(HaveOccurred())
|
||||
Expect(receivedConf).ToNot(BeNil())
|
||||
})
|
||||
|
||||
It("sets the ALPN for tls.Configs returned by the tls.GetConfigForClient", func() {
|
||||
tlsConf := &tls.Config{
|
||||
GetConfigForClient: func(ch *tls.ClientHelloInfo) (*tls.Config, error) {
|
||||
c := testdata.GetTLSConfig()
|
||||
c.NextProtos = []string{"foo", "bar"}
|
||||
return c, nil
|
||||
},
|
||||
}
|
||||
|
||||
ln, err := quic.ListenAddr("localhost:0", ConfigureTLSConfig(tlsConf), &quic.Config{Versions: []quic.VersionNumber{quic.Version1}})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer ln.Close()
|
||||
c, err := quic.DialAddr(ln.Addr().String(), &tls.Config{InsecureSkipVerify: true, NextProtos: []string{NextProtoH3}}, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer c.CloseWithError(0, "")
|
||||
Expect(c.ConnectionState().TLS.ConnectionState.NegotiatedProtocol).To(Equal(NextProtoH3))
|
||||
})
|
||||
|
||||
It("works if GetConfigForClient returns a nil tls.Config", func() {
|
||||
tlsConf := testdata.GetTLSConfig()
|
||||
tlsConf.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) { return nil, nil }
|
||||
|
||||
ln, err := quic.ListenAddr("localhost:0", ConfigureTLSConfig(tlsConf), &quic.Config{Versions: []quic.VersionNumber{quic.Version1}})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer ln.Close()
|
||||
c, err := quic.DialAddr(ln.Addr().String(), &tls.Config{InsecureSkipVerify: true, NextProtos: []string{NextProtoH3}}, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer c.CloseWithError(0, "")
|
||||
Expect(c.ConnectionState().TLS.ConnectionState.NegotiatedProtocol).To(Equal(NextProtoH3))
|
||||
})
|
||||
|
||||
It("sets the ALPN for tls.Configs returned by the tls.GetConfigForClient, if it returns a static tls.Config", func() {
|
||||
tlsClientConf := testdata.GetTLSConfig()
|
||||
tlsClientConf.NextProtos = []string{"foo", "bar"}
|
||||
tlsConf := &tls.Config{
|
||||
GetConfigForClient: func(ch *tls.ClientHelloInfo) (*tls.Config, error) {
|
||||
return tlsClientConf, nil
|
||||
},
|
||||
}
|
||||
|
||||
ln, err := quic.ListenAddr("localhost:0", ConfigureTLSConfig(tlsConf), &quic.Config{Versions: []quic.VersionNumber{quic.Version1}})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer ln.Close()
|
||||
c, err := quic.DialAddr(ln.Addr().String(), &tls.Config{InsecureSkipVerify: true, NextProtos: []string{NextProtoH3}}, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer c.CloseWithError(0, "")
|
||||
Expect(c.ConnectionState().TLS.ConnectionState.NegotiatedProtocol).To(Equal(NextProtoH3))
|
||||
// check that the original config was not modified
|
||||
Expect(tlsClientConf.NextProtos).To(Equal([]string{"foo", "bar"}))
|
||||
})
|
||||
})
|
||||
|
||||
|
@ -1179,15 +1222,6 @@ var _ = Describe("Server", func() {
|
|||
Expect(s.Close()).To(Succeed())
|
||||
})
|
||||
|
||||
checkGetConfigForClientVersions := func(conf *tls.Config) {
|
||||
c, err := conf.GetConfigForClient(&tls.ClientHelloInfo{Conn: newMockConn(protocol.VersionDraft29)})
|
||||
ExpectWithOffset(1, err).ToNot(HaveOccurred())
|
||||
ExpectWithOffset(1, c.NextProtos).To(Equal([]string{NextProtoH3Draft29}))
|
||||
c, err = conf.GetConfigForClient(&tls.ClientHelloInfo{Conn: newMockConn(protocol.Version1)})
|
||||
ExpectWithOffset(1, err).ToNot(HaveOccurred())
|
||||
ExpectWithOffset(1, c.NextProtos).To(Equal([]string{NextProtoH3}))
|
||||
}
|
||||
|
||||
It("uses the quic.Config to start the QUIC server", func() {
|
||||
conf := &quic.Config{HandshakeIdleTimeout: time.Nanosecond}
|
||||
var receivedConf *quic.Config
|
||||
|
@ -1199,106 +1233,6 @@ var _ = Describe("Server", func() {
|
|||
Expect(s.ListenAndServe()).To(HaveOccurred())
|
||||
Expect(receivedConf).To(Equal(conf))
|
||||
})
|
||||
|
||||
It("sets the GetConfigForClient and replaces the ALPN token to the tls.Config, if the GetConfigForClient callback is not set", func() {
|
||||
tlsConf := &tls.Config{
|
||||
ClientAuth: tls.RequireAndVerifyClientCert,
|
||||
NextProtos: []string{"foo", "bar"},
|
||||
}
|
||||
var receivedConf *tls.Config
|
||||
quicListenAddr = func(addr string, tlsConf *tls.Config, _ *quic.Config) (quic.EarlyListener, error) {
|
||||
receivedConf = tlsConf
|
||||
return nil, errors.New("listen err")
|
||||
}
|
||||
s.TLSConfig = tlsConf
|
||||
Expect(s.ListenAndServe()).To(HaveOccurred())
|
||||
Expect(receivedConf.NextProtos).To(BeEmpty())
|
||||
Expect(receivedConf.ClientAuth).To(BeZero())
|
||||
// make sure the original tls.Config was not modified
|
||||
Expect(tlsConf.NextProtos).To(Equal([]string{"foo", "bar"}))
|
||||
// make sure that the config returned from the GetConfigForClient callback sets the fields of the original config
|
||||
conf, err := receivedConf.GetConfigForClient(&tls.ClientHelloInfo{})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(conf.ClientAuth).To(Equal(tls.RequireAndVerifyClientCert))
|
||||
checkGetConfigForClientVersions(receivedConf)
|
||||
})
|
||||
|
||||
It("sets the GetConfigForClient callback if no tls.Config is given", func() {
|
||||
var receivedConf *tls.Config
|
||||
quicListenAddr = func(addr string, tlsConf *tls.Config, _ *quic.Config) (quic.EarlyListener, error) {
|
||||
receivedConf = tlsConf
|
||||
return nil, errors.New("listen err")
|
||||
}
|
||||
Expect(s.ListenAndServe()).To(HaveOccurred())
|
||||
Expect(receivedConf).ToNot(BeNil())
|
||||
checkGetConfigForClientVersions(receivedConf)
|
||||
})
|
||||
|
||||
It("sets the ALPN for tls.Configs returned by the tls.GetConfigForClient", func() {
|
||||
tlsConf := &tls.Config{
|
||||
GetConfigForClient: func(ch *tls.ClientHelloInfo) (*tls.Config, error) {
|
||||
return &tls.Config{
|
||||
ClientAuth: tls.RequireAndVerifyClientCert,
|
||||
NextProtos: []string{"foo", "bar"},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
var receivedConf *tls.Config
|
||||
quicListenAddr = func(addr string, conf *tls.Config, _ *quic.Config) (quic.EarlyListener, error) {
|
||||
receivedConf = conf
|
||||
return nil, errors.New("listen err")
|
||||
}
|
||||
s.TLSConfig = tlsConf
|
||||
Expect(s.ListenAndServe()).To(HaveOccurred())
|
||||
// check that the original config was not modified
|
||||
conf, err := tlsConf.GetConfigForClient(&tls.ClientHelloInfo{})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(conf.NextProtos).To(Equal([]string{"foo", "bar"}))
|
||||
// check that the config returned by the GetConfigForClient callback uses the returned config
|
||||
conf, err = receivedConf.GetConfigForClient(&tls.ClientHelloInfo{})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(conf.ClientAuth).To(Equal(tls.RequireAndVerifyClientCert))
|
||||
checkGetConfigForClientVersions(receivedConf)
|
||||
})
|
||||
|
||||
It("sets the ALPN for tls.Configs returned by the tls.GetConfigForClient, if it returns a static tls.Config", func() {
|
||||
tlsClientConf := &tls.Config{NextProtos: []string{"foo", "bar"}}
|
||||
tlsConf := &tls.Config{
|
||||
GetConfigForClient: func(ch *tls.ClientHelloInfo) (*tls.Config, error) {
|
||||
return tlsClientConf, nil
|
||||
},
|
||||
}
|
||||
|
||||
var receivedConf *tls.Config
|
||||
quicListenAddr = func(addr string, conf *tls.Config, _ *quic.Config) (quic.EarlyListener, error) {
|
||||
receivedConf = conf
|
||||
return nil, errors.New("listen err")
|
||||
}
|
||||
s.TLSConfig = tlsConf
|
||||
Expect(s.ListenAndServe()).To(HaveOccurred())
|
||||
// check that the original config was not modified
|
||||
conf, err := tlsConf.GetConfigForClient(&tls.ClientHelloInfo{})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(conf.NextProtos).To(Equal([]string{"foo", "bar"}))
|
||||
checkGetConfigForClientVersions(receivedConf)
|
||||
})
|
||||
|
||||
It("works if GetConfigForClient returns a nil tls.Config", func() {
|
||||
tlsConf := &tls.Config{GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) { return nil, nil }}
|
||||
|
||||
var receivedConf *tls.Config
|
||||
quicListenAddr = func(addr string, conf *tls.Config, _ *quic.Config) (quic.EarlyListener, error) {
|
||||
receivedConf = conf
|
||||
return nil, errors.New("listen err")
|
||||
}
|
||||
s.TLSConfig = tlsConf
|
||||
Expect(s.ListenAndServe()).To(HaveOccurred())
|
||||
conf, err := receivedConf.GetConfigForClient(&tls.ClientHelloInfo{})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(conf).ToNot(BeNil())
|
||||
checkGetConfigForClientVersions(receivedConf)
|
||||
})
|
||||
})
|
||||
|
||||
It("closes gracefully", func() {
|
||||
|
|
|
@ -57,6 +57,10 @@ var ConnectionTracingKey = connTracingCtxKey{}
|
|||
|
||||
type connTracingCtxKey struct{}
|
||||
|
||||
// QUICVersionContextKey can be used to find out the QUIC version of a TLS handshake from the
|
||||
// context returned by tls.Config.ClientHelloInfo.Context.
|
||||
var QUICVersionContextKey = handshake.QUICVersionContextKey
|
||||
|
||||
// Stream is the interface implemented by QUIC streams
|
||||
// In addition to the errors listed on the Connection,
|
||||
// calls to stream functions can return a StreamError if the stream is canceled.
|
||||
|
|
|
@ -2,6 +2,7 @@ package handshake
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
@ -20,6 +21,10 @@ import (
|
|||
"github.com/quic-go/quic-go/quicvarint"
|
||||
)
|
||||
|
||||
type quicVersionContextKey struct{}
|
||||
|
||||
var QUICVersionContextKey = &quicVersionContextKey{}
|
||||
|
||||
// TLS unexpected_message alert
|
||||
const alertUnexpectedMessage uint8 = 10
|
||||
|
||||
|
@ -64,30 +69,25 @@ const clientSessionStateRevision = 3
|
|||
|
||||
type conn struct {
|
||||
localAddr, remoteAddr net.Addr
|
||||
version protocol.VersionNumber
|
||||
}
|
||||
|
||||
var _ ConnWithVersion = &conn{}
|
||||
|
||||
func newConn(local, remote net.Addr, version protocol.VersionNumber) ConnWithVersion {
|
||||
return &conn{
|
||||
localAddr: local,
|
||||
remoteAddr: remote,
|
||||
version: version,
|
||||
}
|
||||
}
|
||||
|
||||
var _ net.Conn = &conn{}
|
||||
|
||||
func (c *conn) Read([]byte) (int, error) { return 0, nil }
|
||||
func (c *conn) Write([]byte) (int, error) { return 0, nil }
|
||||
func (c *conn) Close() error { return nil }
|
||||
func (c *conn) RemoteAddr() net.Addr { return c.remoteAddr }
|
||||
func (c *conn) LocalAddr() net.Addr { return c.localAddr }
|
||||
func (c *conn) SetReadDeadline(time.Time) error { return nil }
|
||||
func (c *conn) SetWriteDeadline(time.Time) error { return nil }
|
||||
func (c *conn) SetDeadline(time.Time) error { return nil }
|
||||
func (c *conn) GetQUICVersion() protocol.VersionNumber { return c.version }
|
||||
func newConn(local, remote net.Addr) net.Conn {
|
||||
return &conn{
|
||||
localAddr: local,
|
||||
remoteAddr: remote,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *conn) Read([]byte) (int, error) { return 0, nil }
|
||||
func (c *conn) Write([]byte) (int, error) { return 0, nil }
|
||||
func (c *conn) Close() error { return nil }
|
||||
func (c *conn) RemoteAddr() net.Addr { return c.remoteAddr }
|
||||
func (c *conn) LocalAddr() net.Addr { return c.localAddr }
|
||||
func (c *conn) SetReadDeadline(time.Time) error { return nil }
|
||||
func (c *conn) SetWriteDeadline(time.Time) error { return nil }
|
||||
func (c *conn) SetDeadline(time.Time) error { return nil }
|
||||
|
||||
type cryptoSetup struct {
|
||||
tlsConf *tls.Config
|
||||
|
@ -183,7 +183,7 @@ func NewCryptoSetupClient(
|
|||
protocol.PerspectiveClient,
|
||||
version,
|
||||
)
|
||||
cs.conn = qtls.Client(newConn(localAddr, remoteAddr, version), cs.tlsConf, cs.extraConf)
|
||||
cs.conn = qtls.Client(newConn(localAddr, remoteAddr), cs.tlsConf, cs.extraConf)
|
||||
return cs, clientHelloWritten
|
||||
}
|
||||
|
||||
|
@ -218,7 +218,7 @@ func NewCryptoSetupServer(
|
|||
version,
|
||||
)
|
||||
cs.allow0RTT = allow0RTT
|
||||
cs.conn = qtls.Server(newConn(localAddr, remoteAddr, version), cs.tlsConf, cs.extraConf)
|
||||
cs.conn = qtls.Server(newConn(localAddr, remoteAddr), cs.tlsConf, cs.extraConf)
|
||||
return cs
|
||||
}
|
||||
|
||||
|
@ -307,7 +307,7 @@ func (h *cryptoSetup) RunHandshake() {
|
|||
handshakeErrChan := make(chan error, 1)
|
||||
go func() {
|
||||
defer close(h.handshakeDone)
|
||||
if err := h.conn.Handshake(); err != nil {
|
||||
if err := h.conn.HandshakeContext(context.WithValue(context.Background(), QUICVersionContextKey, h.version)); err != nil {
|
||||
handshakeErrChan <- err
|
||||
return
|
||||
}
|
||||
|
|
|
@ -3,7 +3,6 @@ package handshake
|
|||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
|
@ -93,10 +92,3 @@ type CryptoSetup interface {
|
|||
Get0RTTSealer() (LongHeaderSealer, error)
|
||||
Get1RTTSealer() (ShortHeaderSealer, error)
|
||||
}
|
||||
|
||||
// ConnWithVersion is the connection used in the ClientHelloInfo.
|
||||
// It can be used to determine the QUIC version in use.
|
||||
type ConnWithVersion interface {
|
||||
net.Conn
|
||||
GetQUICVersion() protocol.VersionNumber
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue