diff --git a/internal/handshake/aead.go b/internal/handshake/aead.go index 0eaffd57..13e18d26 100644 --- a/internal/handshake/aead.go +++ b/internal/handshake/aead.go @@ -5,7 +5,7 @@ import ( "encoding/binary" "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/marten-seemann/qtls" + "github.com/lucas-clemente/quic-go/internal/qtls" ) func createAEAD(suite *qtls.CipherSuiteTLS13, trafficSecret []byte) cipher.AEAD { diff --git a/internal/handshake/aead_test.go b/internal/handshake/aead_test.go index 180a76a2..aca94d6a 100644 --- a/internal/handshake/aead_test.go +++ b/internal/handshake/aead_test.go @@ -8,7 +8,7 @@ import ( "fmt" "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/marten-seemann/qtls" + "github.com/lucas-clemente/quic-go/internal/qtls" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" diff --git a/internal/handshake/crypto_setup.go b/internal/handshake/crypto_setup.go index 04cab43c..807ebd18 100644 --- a/internal/handshake/crypto_setup.go +++ b/internal/handshake/crypto_setup.go @@ -12,10 +12,10 @@ import ( "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/qerr" + "github.com/lucas-clemente/quic-go/internal/qtls" "github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/internal/wire" "github.com/lucas-clemente/quic-go/logging" - "github.com/marten-seemann/qtls" ) // TLS unexpected_message alert @@ -60,9 +60,32 @@ func (m messageType) String() string { const clientSessionStateRevision = 3 +type conn struct { + localAddr, remoteAddr net.Addr +} + +func newConn(local, remote net.Addr) net.Conn { + return &conn{ + localAddr: local, + remoteAddr: remote, + } +} + +var _ net.Conn = &conn{} + +func (c *conn) Read([]byte) (int, error) { return 0, nil } +func (c *conn) Write([]byte) (int, error) { return 0, nil } +func (c *conn) Close() error { return nil } +func (c *conn) RemoteAddr() net.Addr { return c.remoteAddr } +func (c *conn) LocalAddr() net.Addr { return c.localAddr } +func (c *conn) SetReadDeadline(time.Time) error { return nil } +func (c *conn) SetWriteDeadline(time.Time) error { return nil } +func (c *conn) SetDeadline(time.Time) error { return nil } + type cryptoSetup struct { - tlsConf *qtls.Config - conn *qtls.Conn + tlsConf *tls.Config + extraConf *qtls.ExtraConfig + conn *qtls.Conn messageChan chan []byte @@ -152,7 +175,7 @@ func NewCryptoSetupClient( logger, protocol.PerspectiveClient, ) - cs.conn = qtls.Client(newConn(localAddr, remoteAddr), cs.tlsConf) + cs.conn = qtls.Client(newConn(localAddr, remoteAddr), cs.tlsConf, cs.extraConf) return cs, clientHelloWritten } @@ -184,7 +207,7 @@ func NewCryptoSetupServer( logger, protocol.PerspectiveServer, ) - cs.conn = qtls.Server(newConn(localAddr, remoteAddr), cs.tlsConf) + cs.conn = qtls.Server(newConn(localAddr, remoteAddr), cs.tlsConf, cs.extraConf) return cs } @@ -208,6 +231,7 @@ func newCryptoSetup( } extHandler := newExtensionHandler(tp.Marshal(perspective), perspective) cs := &cryptoSetup{ + tlsConf: tlsConf, initialStream: initialStream, initialSealer: initialSealer, initialOpener: initialOpener, @@ -231,8 +255,22 @@ func newCryptoSetup( writeRecord: make(chan struct{}, 1), closeChan: make(chan struct{}), } - qtlsConf := tlsConfigToQtlsConfig(tlsConf, cs, extHandler, rttStats, cs.marshalDataForSessionState, cs.handleDataFromSessionState, cs.accept0RTT, cs.rejected0RTT, enable0RTT) - cs.tlsConf = qtlsConf + var maxEarlyData uint32 + if enable0RTT { + maxEarlyData = 0xffffffff + } + cs.extraConf = &qtls.ExtraConfig{ + GetExtensions: extHandler.GetExtensions, + ReceivedExtensions: extHandler.ReceivedExtensions, + AlternativeRecordLayer: cs, + EnforceNextProtoSelection: true, + MaxEarlyData: maxEarlyData, + Accept0RTT: cs.accept0RTT, + Rejected0RTT: cs.rejected0RTT, + Enable0RTT: enable0RTT, + GetAppDataForSessionState: cs.marshalDataForSessionState, + SetAppDataFromSessionState: cs.handleDataFromSessionState, + } return cs, cs.clientHelloWrittenChan } @@ -499,7 +537,7 @@ func (h *cryptoSetup) handleDataFromSessionStateImpl(data []byte) (*wire.Transpo func (h *cryptoSetup) GetSessionTicket() ([]byte, error) { var appData []byte // Save transport parameters to the session ticket if we're allowing 0-RTT. - if h.tlsConf.MaxEarlyData > 0 { + if h.extraConf.MaxEarlyData > 0 { appData = (&sessionTicket{ Parameters: h.ourParams, RTT: h.rttStats.SmoothedRTT(), @@ -819,5 +857,5 @@ func (h *cryptoSetup) Get1RTTOpener() (ShortHeaderOpener, error) { } func (h *cryptoSetup) ConnectionState() ConnectionState { - return h.conn.ConnectionState() + return qtls.GetConnectionState(h.conn) } diff --git a/internal/handshake/crypto_setup_test.go b/internal/handshake/crypto_setup_test.go index 2b085c67..df06f007 100644 --- a/internal/handshake/crypto_setup_test.go +++ b/internal/handshake/crypto_setup_test.go @@ -1,22 +1,20 @@ package handshake import ( - "bytes" "crypto/rand" "crypto/rsa" "crypto/tls" "crypto/x509" "crypto/x509/pkix" - "errors" "math/big" "time" + mocktls "github.com/lucas-clemente/quic-go/internal/mocks/tls" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/qerr" "github.com/lucas-clemente/quic-go/internal/testdata" "github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/internal/wire" - "github.com/marten-seemann/qtls" "github.com/golang/mock/gomock" @@ -74,48 +72,6 @@ var _ = Describe("Crypto Setup TLS", func() { } }) - It("creates a qtls.Config", func() { - tlsConf := &tls.Config{ - ServerName: "quic.clemente.io", - GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) { - return nil, errors.New("GetCertificate") - }, - GetClientCertificate: func(*tls.CertificateRequestInfo) (*tls.Certificate, error) { - return nil, errors.New("GetClientCertificate") - }, - GetConfigForClient: func(ch *tls.ClientHelloInfo) (*tls.Config, error) { - return &tls.Config{ServerName: ch.ServerName}, nil - }, - } - var token protocol.StatelessResetToken - server := NewCryptoSetupServer( - &bytes.Buffer{}, - &bytes.Buffer{}, - protocol.ConnectionID{}, - nil, - nil, - &wire.TransportParameters{StatelessResetToken: &token}, - NewMockHandshakeRunner(mockCtrl), - tlsConf, - false, - &utils.RTTStats{}, - nil, - utils.DefaultLogger.WithPrefix("server"), - ) - qtlsConf := server.(*cryptoSetup).tlsConf - Expect(qtlsConf.ServerName).To(Equal(tlsConf.ServerName)) - _, getCertificateErr := qtlsConf.GetCertificate(nil) - Expect(getCertificateErr).To(MatchError("GetCertificate")) - _, getClientCertificateErr := qtlsConf.GetClientCertificate(nil) - Expect(getClientCertificateErr).To(MatchError("GetClientCertificate")) - cconf, err := qtlsConf.GetConfigForClient(&qtls.ClientHelloInfo{ServerName: "foo.bar"}) - Expect(err).ToNot(HaveOccurred()) - Expect(cconf.ServerName).To(Equal("foo.bar")) - Expect(cconf.AlternativeRecordLayer).ToNot(BeNil()) - Expect(cconf.GetExtensions).ToNot(BeNil()) - Expect(cconf.ReceivedExtensions).ToNot(BeNil()) - }) - It("returns Handshake() when an error occurs in qtls", func() { sErrChan := make(chan error, 1) runner := NewMockHandshakeRunner(mockCtrl) @@ -420,7 +376,7 @@ var _ = Describe("Crypto Setup TLS", func() { It("handshakes with client auth", func() { clientConf.Certificates = []tls.Certificate{generateCert()} - serverConf.ClientAuth = qtls.RequireAnyClientCert + serverConf.ClientAuth = tls.RequireAnyClientCert _, _, clientErr, _, serverErr := handshakeWithTLSConf( clientConf, serverConf, &utils.RTTStats{}, &utils.RTTStats{}, @@ -647,7 +603,7 @@ var _ = Describe("Crypto Setup TLS", func() { }) It("uses session resumption", func() { - csc := NewMockClientSessionCache(mockCtrl) + csc := mocktls.NewMockClientSessionCache(mockCtrl) var state *tls.ClientSessionState receivedSessionTicket := make(chan struct{}) csc.EXPECT().Get(gomock.Any()) @@ -690,7 +646,7 @@ var _ = Describe("Crypto Setup TLS", func() { }) It("doesn't use session resumption if the server disabled it", func() { - csc := NewMockClientSessionCache(mockCtrl) + csc := mocktls.NewMockClientSessionCache(mockCtrl) var state *tls.ClientSessionState receivedSessionTicket := make(chan struct{}) csc.EXPECT().Get(gomock.Any()) @@ -727,7 +683,7 @@ var _ = Describe("Crypto Setup TLS", func() { }) It("uses 0-RTT", func() { - csc := NewMockClientSessionCache(mockCtrl) + csc := mocktls.NewMockClientSessionCache(mockCtrl) var state *tls.ClientSessionState receivedSessionTicket := make(chan struct{}) csc.EXPECT().Get(gomock.Any()) @@ -782,7 +738,7 @@ var _ = Describe("Crypto Setup TLS", func() { }) It("rejects 0-RTT, whent the transport parameters changed", func() { - csc := NewMockClientSessionCache(mockCtrl) + csc := mocktls.NewMockClientSessionCache(mockCtrl) var state *tls.ClientSessionState receivedSessionTicket := make(chan struct{}) csc.EXPECT().Get(gomock.Any()) diff --git a/internal/handshake/handshake_suite_test.go b/internal/handshake/handshake_suite_test.go index b62f205c..807b4afe 100644 --- a/internal/handshake/handshake_suite_test.go +++ b/internal/handshake/handshake_suite_test.go @@ -1,14 +1,13 @@ package handshake import ( - "crypto" - "crypto/cipher" + "crypto/tls" "encoding/hex" "strings" - "unsafe" + + "github.com/lucas-clemente/quic-go/internal/qtls" "github.com/golang/mock/gomock" - "github.com/marten-seemann/qtls" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" @@ -31,27 +30,6 @@ var _ = AfterEach(func() { mockCtrl.Finish() }) -var cipherSuites = []*qtls.CipherSuiteTLS13{ - &qtls.CipherSuiteTLS13{ - ID: qtls.TLS_AES_128_GCM_SHA256, - KeyLen: 16, - AEAD: qtls.AEADAESGCMTLS13, - Hash: crypto.SHA256, - }, - &qtls.CipherSuiteTLS13{ - ID: qtls.TLS_AES_256_GCM_SHA384, - KeyLen: 32, - AEAD: qtls.AEADAESGCMTLS13, - Hash: crypto.SHA384, - }, - &qtls.CipherSuiteTLS13{ - ID: qtls.TLS_CHACHA20_POLY1305_SHA256, - KeyLen: 32, - AEAD: nil, // will be set by init - Hash: crypto.SHA256, - }, -} - func splitHexString(s string) (slice []byte) { for _, ss := range strings.Split(s, " ") { if ss[0:2] == "0x" { @@ -64,25 +42,8 @@ func splitHexString(s string) (slice []byte) { return } -type cipherSuiteTLS13 struct { - ID uint16 - KeyLen int - AEAD func(key, fixedNonce []byte) cipher.AEAD - Hash crypto.Hash -} - -//go:linkname cipherSuiteTLS13ByID github.com/marten-seemann/qtls.cipherSuiteTLS13ByID -func cipherSuiteTLS13ByID(id uint16) *cipherSuiteTLS13 - -func init() { - val := cipherSuiteTLS13ByID(qtls.TLS_CHACHA20_POLY1305_SHA256) - chacha := (*cipherSuiteTLS13)(unsafe.Pointer(val)) - for _, s := range cipherSuites { - if s.ID == qtls.TLS_CHACHA20_POLY1305_SHA256 { - if s.KeyLen != chacha.KeyLen || s.Hash != chacha.Hash { - panic("invalid parameters for ChaCha20") - } - s.AEAD = chacha.AEAD - } - } +var cipherSuites = []*qtls.CipherSuiteTLS13{ + qtls.CipherSuiteTLS13ByID(tls.TLS_AES_128_GCM_SHA256), + qtls.CipherSuiteTLS13ByID(tls.TLS_AES_256_GCM_SHA384), + qtls.CipherSuiteTLS13ByID(tls.TLS_CHACHA20_POLY1305_SHA256), } diff --git a/internal/handshake/header_protector.go b/internal/handshake/header_protector.go index dad59697..309af6b1 100644 --- a/internal/handshake/header_protector.go +++ b/internal/handshake/header_protector.go @@ -3,12 +3,13 @@ package handshake import ( "crypto/aes" "crypto/cipher" + "crypto/tls" "encoding/binary" "fmt" "golang.org/x/crypto/chacha20" - "github.com/marten-seemann/qtls" + "github.com/lucas-clemente/quic-go/internal/qtls" ) type headerProtector interface { @@ -18,9 +19,9 @@ type headerProtector interface { func newHeaderProtector(suite *qtls.CipherSuiteTLS13, trafficSecret []byte, isLongHeader bool) headerProtector { switch suite.ID { - case qtls.TLS_AES_128_GCM_SHA256, qtls.TLS_AES_256_GCM_SHA384: + case tls.TLS_AES_128_GCM_SHA256, tls.TLS_AES_256_GCM_SHA384: return newAESHeaderProtector(suite, trafficSecret, isLongHeader) - case qtls.TLS_CHACHA20_POLY1305_SHA256: + case tls.TLS_CHACHA20_POLY1305_SHA256: return newChaChaHeaderProtector(suite, trafficSecret, isLongHeader) default: panic(fmt.Sprintf("Invalid cipher suite id: %d", suite.ID)) diff --git a/internal/handshake/hkdf_test.go b/internal/handshake/hkdf_test.go index 3fe538fa..ca37d209 100644 --- a/internal/handshake/hkdf_test.go +++ b/internal/handshake/hkdf_test.go @@ -5,7 +5,7 @@ import ( "crypto/rand" mrand "math/rand" - "github.com/marten-seemann/qtls" + "github.com/lucas-clemente/quic-go/internal/qtls" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" diff --git a/internal/handshake/initial_aead.go b/internal/handshake/initial_aead.go index a95d5cfd..85522cde 100644 --- a/internal/handshake/initial_aead.go +++ b/internal/handshake/initial_aead.go @@ -2,15 +2,16 @@ package handshake import ( "crypto" + "crypto/tls" "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/marten-seemann/qtls" + "github.com/lucas-clemente/quic-go/internal/qtls" ) var quicVersion1Salt = []byte{0xaf, 0xbf, 0xec, 0x28, 0x99, 0x93, 0xd2, 0x4c, 0x9e, 0x97, 0x86, 0xf1, 0x9c, 0x61, 0x11, 0xe0, 0x43, 0x90, 0xa8, 0x99} var initialSuite = &qtls.CipherSuiteTLS13{ - ID: qtls.TLS_AES_128_GCM_SHA256, + ID: tls.TLS_AES_128_GCM_SHA256, KeyLen: 16, AEAD: qtls.AEADAESGCMTLS13, Hash: crypto.SHA256, diff --git a/internal/handshake/interface.go b/internal/handshake/interface.go index 651fca31..4c72281d 100644 --- a/internal/handshake/interface.go +++ b/internal/handshake/interface.go @@ -6,8 +6,8 @@ import ( "time" "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/qtls" "github.com/lucas-clemente/quic-go/internal/wire" - "github.com/marten-seemann/qtls" ) var ( diff --git a/internal/handshake/mockgen.go b/internal/handshake/mockgen.go index 96e92646..c7a8d13e 100644 --- a/internal/handshake/mockgen.go +++ b/internal/handshake/mockgen.go @@ -1,7 +1,3 @@ package handshake //go:generate sh -c "../../mockgen_private.sh handshake mock_handshake_runner_test.go github.com/lucas-clemente/quic-go/internal/handshake handshakeRunner" - -// The following command produces a warning message on OSX, however, it still generates the correct mock file. -// See https://github.com/golang/mock/issues/339 for details. -//go:generate sh -c "mockgen -package handshake -destination mock_client_session_cache_test.go crypto/tls ClientSessionCache && goimports -w mock_client_session_cache_test.go" diff --git a/internal/handshake/qtls.go b/internal/handshake/qtls.go deleted file mode 100644 index 1aad299d..00000000 --- a/internal/handshake/qtls.go +++ /dev/null @@ -1,233 +0,0 @@ -package handshake - -import ( - "crypto/tls" - "net" - "time" - "unsafe" - - "github.com/lucas-clemente/quic-go/internal/utils" - - "github.com/marten-seemann/qtls" -) - -func init() { - if !structsEqual(&tls.ClientHelloInfo{}, &clientHelloInfo{}) { - panic("clientHelloInfo not compatible with tls.ClientHelloInfo") - } - if !structsEqual(&qtls.ClientHelloInfo{}, &qtlsClientHelloInfo{}) { - panic("qtlsClientHelloInfo not compatible with qtls.ClientHelloInfo") - } -} - -type conn struct { - localAddr, remoteAddr net.Addr -} - -func newConn(local, remote net.Addr) net.Conn { - return &conn{ - localAddr: local, - remoteAddr: remote, - } -} - -var _ net.Conn = &conn{} - -func (c *conn) Read([]byte) (int, error) { return 0, nil } -func (c *conn) Write([]byte) (int, error) { return 0, nil } -func (c *conn) Close() error { return nil } -func (c *conn) RemoteAddr() net.Addr { return c.remoteAddr } -func (c *conn) LocalAddr() net.Addr { return c.localAddr } -func (c *conn) SetReadDeadline(time.Time) error { return nil } -func (c *conn) SetWriteDeadline(time.Time) error { return nil } -func (c *conn) SetDeadline(time.Time) error { return nil } - -func tlsConfigToQtlsConfig( - c *tls.Config, - recordLayer qtls.RecordLayer, - extHandler tlsExtensionHandler, - rttStats *utils.RTTStats, - getDataForSessionState func() []byte, - setDataFromSessionState func([]byte), - accept0RTT func([]byte) bool, - rejected0RTT func(), - enable0RTT bool, -) *qtls.Config { - if c == nil { - c = &tls.Config{} - } - // Clone the config first. This executes the tls.Config.serverInit(). - // This sets the SessionTicketKey, if the user didn't supply one. - c = c.Clone() - // QUIC requires TLS 1.3 or newer - minVersion := c.MinVersion - if minVersion < qtls.VersionTLS13 { - minVersion = qtls.VersionTLS13 - } - maxVersion := c.MaxVersion - if maxVersion < qtls.VersionTLS13 { - maxVersion = qtls.VersionTLS13 - } - var getConfigForClient func(ch *qtls.ClientHelloInfo) (*qtls.Config, error) - if c.GetConfigForClient != nil { - getConfigForClient = func(ch *qtls.ClientHelloInfo) (*qtls.Config, error) { - tlsConf, err := c.GetConfigForClient(toTLSClientHelloInfo(ch)) - if err != nil { - return nil, err - } - if tlsConf == nil { - return nil, nil - } - return tlsConfigToQtlsConfig(tlsConf, recordLayer, extHandler, rttStats, getDataForSessionState, setDataFromSessionState, accept0RTT, rejected0RTT, enable0RTT), nil - } - } - var getCertificate func(ch *qtls.ClientHelloInfo) (*qtls.Certificate, error) - if c.GetCertificate != nil { - getCertificate = func(ch *qtls.ClientHelloInfo) (*qtls.Certificate, error) { - cert, err := c.GetCertificate(toTLSClientHelloInfo(ch)) - if err != nil { - return nil, err - } - if cert == nil { - return nil, nil - } - return (*qtls.Certificate)(cert), nil - } - } - var csc qtls.ClientSessionCache - if c.ClientSessionCache != nil { - csc = &clientSessionCache{c.ClientSessionCache} - } - conf := &qtls.Config{ - Rand: c.Rand, - Time: c.Time, - Certificates: *(*[]qtls.Certificate)(unsafe.Pointer(&c.Certificates)), - // NameToCertificate is deprecated, but we still need to copy it if the user sets it. - //nolint:staticcheck - NameToCertificate: *(*map[string]*qtls.Certificate)(unsafe.Pointer(&c.NameToCertificate)), - GetCertificate: getCertificate, - GetClientCertificate: *(*func(*qtls.CertificateRequestInfo) (*qtls.Certificate, error))(unsafe.Pointer(&c.GetClientCertificate)), - GetConfigForClient: getConfigForClient, - VerifyPeerCertificate: c.VerifyPeerCertificate, - RootCAs: c.RootCAs, - NextProtos: c.NextProtos, - EnforceNextProtoSelection: true, - ServerName: c.ServerName, - ClientAuth: c.ClientAuth, - ClientCAs: c.ClientCAs, - InsecureSkipVerify: c.InsecureSkipVerify, - CipherSuites: c.CipherSuites, - PreferServerCipherSuites: c.PreferServerCipherSuites, - SessionTicketsDisabled: c.SessionTicketsDisabled, - SessionTicketKey: c.SessionTicketKey, - ClientSessionCache: csc, - MinVersion: minVersion, - MaxVersion: maxVersion, - CurvePreferences: c.CurvePreferences, - DynamicRecordSizingDisabled: c.DynamicRecordSizingDisabled, - // no need to copy Renegotiation, it's not supported by TLS 1.3 - KeyLogWriter: c.KeyLogWriter, - AlternativeRecordLayer: recordLayer, - GetExtensions: extHandler.GetExtensions, - ReceivedExtensions: extHandler.ReceivedExtensions, - Accept0RTT: accept0RTT, - Rejected0RTT: rejected0RTT, - GetAppDataForSessionState: getDataForSessionState, - SetAppDataFromSessionState: setDataFromSessionState, - } - if enable0RTT { - conf.Enable0RTT = true - conf.MaxEarlyData = 0xffffffff - } - return conf -} - -type clientSessionCache struct { - tls.ClientSessionCache -} - -var _ qtls.ClientSessionCache = &clientSessionCache{} - -func (c *clientSessionCache) Get(sessionKey string) (*qtls.ClientSessionState, bool) { - sess, ok := c.ClientSessionCache.Get(sessionKey) - if sess == nil { - return nil, ok - } - // qtls.ClientSessionState is identical to the tls.ClientSessionState. - // In order to allow users of quic-go to use a tls.Config, - // we need this workaround to use the ClientSessionCache. - // In unsafe.go we check that the two structs are actually identical. - return (*qtls.ClientSessionState)(unsafe.Pointer(sess)), ok -} - -func (c *clientSessionCache) Put(sessionKey string, cs *qtls.ClientSessionState) { - if cs == nil { - c.ClientSessionCache.Put(sessionKey, nil) - return - } - // qtls.ClientSessionState is identical to the tls.ClientSessionState. - // In order to allow users of quic-go to use a tls.Config, - // we need this workaround to use the ClientSessionCache. - // In unsafe.go we check that the two structs are actually identical. - c.ClientSessionCache.Put(sessionKey, (*tls.ClientSessionState)(unsafe.Pointer(cs))) -} - -type clientHelloInfo struct { - CipherSuites []uint16 - ServerName string - SupportedCurves []tls.CurveID - SupportedPoints []uint8 - SignatureSchemes []tls.SignatureScheme - SupportedProtos []string - SupportedVersions []uint16 - Conn net.Conn - - config *tls.Config -} - -type qtlsClientHelloInfo struct { - CipherSuites []uint16 - ServerName string - SupportedCurves []tls.CurveID - SupportedPoints []uint8 - SignatureSchemes []tls.SignatureScheme - SupportedProtos []string - SupportedVersions []uint16 - Conn net.Conn - - config *qtls.Config -} - -func toTLSClientHelloInfo(chi *qtls.ClientHelloInfo) *tls.ClientHelloInfo { - if chi == nil { - return nil - } - qtlsCHI := (*qtlsClientHelloInfo)(unsafe.Pointer(chi)) - var config *tls.Config - if qtlsCHI.config != nil { - config = qtlsConfigToTLSConfig(qtlsCHI.config) - } - return (*tls.ClientHelloInfo)(unsafe.Pointer(&clientHelloInfo{ - CipherSuites: chi.CipherSuites, - ServerName: chi.ServerName, - SupportedCurves: chi.SupportedCurves, - SupportedPoints: chi.SupportedPoints, - SignatureSchemes: chi.SignatureSchemes, - SupportedProtos: chi.SupportedProtos, - SupportedVersions: chi.SupportedVersions, - Conn: chi.Conn, - config: config, - })) -} - -// qtlsConfigToTLSConfig is used to transform a qtls.Config to a tls.Config. -// It is used to create the tls.Config in the ClientHelloInfo. -// It doesn't copy all values, but only those used by ClientHelloInfo.SupportsCertificate. -func qtlsConfigToTLSConfig(config *qtls.Config) *tls.Config { - return &tls.Config{ - MinVersion: config.MinVersion, - MaxVersion: config.MaxVersion, - CipherSuites: config.CipherSuites, - CurvePreferences: config.CurvePreferences, - } -} diff --git a/internal/handshake/tls_extension_handler.go b/internal/handshake/tls_extension_handler.go index 590aafd1..33409b8e 100644 --- a/internal/handshake/tls_extension_handler.go +++ b/internal/handshake/tls_extension_handler.go @@ -2,7 +2,7 @@ package handshake import ( "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/marten-seemann/qtls" + "github.com/lucas-clemente/quic-go/internal/qtls" ) const quicTLSExtensionType = 0xffa5 diff --git a/internal/handshake/tls_extension_handler_test.go b/internal/handshake/tls_extension_handler_test.go index 819497d0..453c5101 100644 --- a/internal/handshake/tls_extension_handler_test.go +++ b/internal/handshake/tls_extension_handler_test.go @@ -2,7 +2,7 @@ package handshake import ( "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/marten-seemann/qtls" + "github.com/lucas-clemente/quic-go/internal/qtls" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" diff --git a/internal/handshake/unsafe.go b/internal/handshake/unsafe.go deleted file mode 100644 index 1bb31f8e..00000000 --- a/internal/handshake/unsafe.go +++ /dev/null @@ -1,44 +0,0 @@ -package handshake - -// This package uses unsafe to convert between: -// * qtls.Certificate and tls.Certificate -// * qtls.CertificateRequestInfo and tls.CertificateRequestInfo -// * qtls.ClientHelloInfo and tls.ClientHelloInfo -// * qtls.ConnectionState and tls.ConnectionState -// * qtls.ClientSessionState and tls.ClientSessionState -// We check in init() that this conversion actually is safe. - -import ( - "crypto/tls" - "reflect" - - "github.com/marten-seemann/qtls" -) - -func init() { - if !structsEqual(&tls.Certificate{}, &qtls.Certificate{}) { - panic("qtls.Certificate not compatible with tls.Certificate") - } - if !structsEqual(&tls.CertificateRequestInfo{}, &qtls.CertificateRequestInfo{}) { - panic("qtls.CertificateRequestInfo not compatible with tls.CertificateRequestInfo") - } - if !structsEqual(&tls.ClientSessionState{}, &qtls.ClientSessionState{}) { - panic("qtls.ClientSessionState not compatible with tls.ClientSessionState") - } -} - -func structsEqual(a, b interface{}) bool { - sa := reflect.ValueOf(a).Elem() - sb := reflect.ValueOf(b).Elem() - if sa.NumField() != sb.NumField() { - return false - } - for i := 0; i < sa.NumField(); i++ { - fa := sa.Type().Field(i) - fb := sb.Type().Field(i) - if !reflect.DeepEqual(fa.Index, fb.Index) || fa.Name != fb.Name || fa.Anonymous != fb.Anonymous || fa.Offset != fb.Offset || !reflect.DeepEqual(fa.Type, fb.Type) { - return false - } - } - return true -} diff --git a/internal/handshake/updatable_aead.go b/internal/handshake/updatable_aead.go index 11c0accb..bf45394a 100644 --- a/internal/handshake/updatable_aead.go +++ b/internal/handshake/updatable_aead.go @@ -9,12 +9,11 @@ import ( "strconv" "time" + "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/qerr" + "github.com/lucas-clemente/quic-go/internal/qtls" "github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/logging" - - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/marten-seemann/qtls" ) // By setting this environment variable, the key update interval can be adjusted. diff --git a/internal/handshake/updatable_aead_test.go b/internal/handshake/updatable_aead_test.go index 013b9ff6..7ec8cb61 100644 --- a/internal/handshake/updatable_aead_test.go +++ b/internal/handshake/updatable_aead_test.go @@ -2,13 +2,14 @@ package handshake import ( "crypto/rand" + "crypto/tls" "fmt" "os" "time" "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/qtls" "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/marten-seemann/qtls" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" @@ -19,7 +20,7 @@ var _ = Describe("Updatable AEAD", func() { secret := splitHexString("9ac312a7f877468ebe69422748ad00a1 5443f18203a07d6060f688f30f21632b") aead := newUpdatableAEAD(&utils.RTTStats{}, nil, nil) chacha := cipherSuites[2] - Expect(chacha.ID).To(Equal(qtls.TLS_CHACHA20_POLY1305_SHA256)) + Expect(chacha.ID).To(Equal(tls.TLS_CHACHA20_POLY1305_SHA256)) aead.SetWriteKey(chacha, secret) header := splitHexString("4200bff4") const pnOffset = 1 diff --git a/internal/mocks/crypto_setup.go b/internal/mocks/crypto_setup.go index ed03c433..193756e3 100644 --- a/internal/mocks/crypto_setup.go +++ b/internal/mocks/crypto_setup.go @@ -10,7 +10,7 @@ import ( gomock "github.com/golang/mock/gomock" handshake "github.com/lucas-clemente/quic-go/internal/handshake" protocol "github.com/lucas-clemente/quic-go/internal/protocol" - qtls "github.com/marten-seemann/qtls" + qtls "github.com/lucas-clemente/quic-go/internal/qtls" ) // MockCryptoSetup is a mock of CryptoSetup interface diff --git a/internal/mocks/mockgen.go b/internal/mocks/mockgen.go index 9d84e3d2..c2128c2d 100644 --- a/internal/mocks/mockgen.go +++ b/internal/mocks/mockgen.go @@ -1,16 +1,20 @@ package mocks //go:generate sh -c "mockgen -package mockquic -destination quic/stream.go github.com/lucas-clemente/quic-go Stream && goimports -w quic/stream.go" -//go:generate sh -c "mockgen -package mockquic -destination quic/early_session.go github.com/lucas-clemente/quic-go EarlySession && goimports -w quic/early_session.go" +//go:generate sh -c "mockgen -package mockquic -destination quic/early_session_tmp.go github.com/lucas-clemente/quic-go EarlySession && sed 's/qtls.ConnectionState/quic.ConnectionState/g' quic/early_session_tmp.go > quic/early_session.go && rm quic/early_session_tmp.go && goimports -w quic/early_session.go" //go:generate sh -c "mockgen -package mockquic -destination quic/early_listener.go github.com/lucas-clemente/quic-go EarlyListener && goimports -w quic/early_listener.go" //go:generate sh -c "mockgen -package mocks -destination tracer.go github.com/lucas-clemente/quic-go/logging Tracer && goimports -w tracer.go" //go:generate sh -c "mockgen -package mocks -destination connection_tracer.go github.com/lucas-clemente/quic-go/logging ConnectionTracer && goimports -w connection_tracer.go" //go:generate sh -c "mockgen -package mocks -destination short_header_sealer.go github.com/lucas-clemente/quic-go/internal/handshake ShortHeaderSealer && goimports -w short_header_sealer.go" //go:generate sh -c "mockgen -package mocks -destination short_header_opener.go github.com/lucas-clemente/quic-go/internal/handshake ShortHeaderOpener && goimports -w short_header_opener.go" //go:generate sh -c "mockgen -package mocks -destination long_header_opener.go github.com/lucas-clemente/quic-go/internal/handshake LongHeaderOpener && goimports -w long_header_opener.go" -//go:generate sh -c "mockgen -package mocks -destination crypto_setup.go github.com/lucas-clemente/quic-go/internal/handshake CryptoSetup && goimports -w crypto_setup.go" +//go:generate sh -c "mockgen -package mocks -destination crypto_setup_tmp.go github.com/lucas-clemente/quic-go/internal/handshake CryptoSetup && sed 's/github.com\\/marten-seemann\\/qtls/github.com\\/lucas-clemente\\/quic-go\\/internal\\/qtls/g' crypto_setup_tmp.go > crypto_setup.go && rm crypto_setup_tmp.go && goimports -w crypto_setup.go" //go:generate sh -c "mockgen -package mocks -destination stream_flow_controller.go github.com/lucas-clemente/quic-go/internal/flowcontrol StreamFlowController && goimports -w stream_flow_controller.go" //go:generate sh -c "mockgen -package mocks -destination congestion.go github.com/lucas-clemente/quic-go/internal/congestion SendAlgorithmWithDebugInfos && goimports -w congestion.go" //go:generate sh -c "mockgen -package mocks -destination connection_flow_controller.go github.com/lucas-clemente/quic-go/internal/flowcontrol ConnectionFlowController && goimports -w connection_flow_controller.go" //go:generate sh -c "mockgen -package mockackhandler -destination ackhandler/sent_packet_handler.go github.com/lucas-clemente/quic-go/internal/ackhandler SentPacketHandler && goimports -w ackhandler/sent_packet_handler.go" //go:generate sh -c "mockgen -package mockackhandler -destination ackhandler/received_packet_handler.go github.com/lucas-clemente/quic-go/internal/ackhandler ReceivedPacketHandler && goimports -w ackhandler/received_packet_handler.go" + +// The following command produces a warning message on OSX, however, it still generates the correct mock file. +// See https://github.com/golang/mock/issues/339 for details. +//go:generate sh -c "mockgen -package mocktls -destination tls/client_session_cache.go crypto/tls ClientSessionCache && goimports -w tls/client_session_cache.go" diff --git a/internal/mocks/quic/early_session.go b/internal/mocks/quic/early_session.go index e5fcf0ae..78d44ea1 100644 --- a/internal/mocks/quic/early_session.go +++ b/internal/mocks/quic/early_session.go @@ -12,7 +12,6 @@ import ( gomock "github.com/golang/mock/gomock" quic "github.com/lucas-clemente/quic-go" protocol "github.com/lucas-clemente/quic-go/internal/protocol" - qtls "github.com/marten-seemann/qtls" ) // MockEarlySession is a mock of EarlySession interface @@ -83,10 +82,10 @@ func (mr *MockEarlySessionMockRecorder) CloseWithError(arg0, arg1 interface{}) * } // ConnectionState mocks base method -func (m *MockEarlySession) ConnectionState() qtls.ConnectionState { +func (m *MockEarlySession) ConnectionState() quic.ConnectionState { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "ConnectionState") - ret0, _ := ret[0].(qtls.ConnectionState) + ret0, _ := ret[0].(quic.ConnectionState) return ret0 } diff --git a/internal/handshake/mock_client_session_cache_test.go b/internal/mocks/tls/client_session_cache.go similarity index 96% rename from internal/handshake/mock_client_session_cache_test.go rename to internal/mocks/tls/client_session_cache.go index 04da10eb..0e43c552 100644 --- a/internal/handshake/mock_client_session_cache_test.go +++ b/internal/mocks/tls/client_session_cache.go @@ -1,8 +1,8 @@ // Code generated by MockGen. DO NOT EDIT. // Source: crypto/tls (interfaces: ClientSessionCache) -// Package handshake is a generated GoMock package. -package handshake +// Package mocktls is a generated GoMock package. +package mocktls import ( tls "crypto/tls" diff --git a/internal/qerr/error_codes.go b/internal/qerr/error_codes.go index 1e1cf7ba..188219f7 100644 --- a/internal/qerr/error_codes.go +++ b/internal/qerr/error_codes.go @@ -3,7 +3,7 @@ package qerr import ( "fmt" - "github.com/marten-seemann/qtls" + "github.com/lucas-clemente/quic-go/internal/qtls" ) // ErrorCode can be used as a normal error without reason. diff --git a/internal/qtls/interface.go b/internal/qtls/interface.go new file mode 100644 index 00000000..e0f3747f --- /dev/null +++ b/internal/qtls/interface.go @@ -0,0 +1,164 @@ +package qtls + +import ( + "crypto" + "crypto/cipher" + "crypto/tls" + "net" + "unsafe" + + "github.com/marten-seemann/qtls" +) + +type ( + // Alert is a TLS alert + Alert = qtls.Alert + // A Certificate is qtls.Certificate. + Certificate = qtls.Certificate + // CertificateRequestInfo contains inforamtion about a certificate request. + CertificateRequestInfo = qtls.CertificateRequestInfo + // A CipherSuiteTLS13 is a cipher suite for TLS 1.3 + CipherSuiteTLS13 = qtls.CipherSuiteTLS13 + // ClientHelloInfo contains information about a ClientHello. + ClientHelloInfo = qtls.ClientHelloInfo + // ClientSessionCache is a cache used for session resumption. + ClientSessionCache = qtls.ClientSessionCache + // ClientSessionState is a state needed for session resumption. + ClientSessionState = qtls.ClientSessionState + // A Config is a qtls.Config. + Config = qtls.Config + // A Conn is a qtls.Conn. + Conn = qtls.Conn + // ConnectionState contains information about the state of the connection. + ConnectionState = qtls.ConnectionState + // EncryptionLevel is the encryption level of a message. + EncryptionLevel = qtls.EncryptionLevel + // Extension is a TLS extension + Extension = qtls.Extension + // RecordLayer is a qtls RecordLayer. + RecordLayer = qtls.RecordLayer +) + +type ExtraConfig struct { + // GetExtensions, if not nil, is called before a message that allows + // sending of extensions is sent. + // Currently only implemented for the ClientHello message (for the client) + // and for the EncryptedExtensions message (for the server). + // Only valid for TLS 1.3. + GetExtensions func(handshakeMessageType uint8) []Extension + + // ReceivedExtensions, if not nil, is called when a message that allows the + // inclusion of extensions is received. + // It is called with an empty slice of extensions, if the message didn't + // contain any extensions. + // Currently only implemented for the ClientHello message (sent by the + // client) and for the EncryptedExtensions message (sent by the server). + // Only valid for TLS 1.3. + ReceivedExtensions func(handshakeMessageType uint8, exts []Extension) + + // AlternativeRecordLayer is used by QUIC + AlternativeRecordLayer RecordLayer + + // Enforce the selection of a supported application protocol. + // Only works for TLS 1.3. + // If enabled, client and server have to agree on an application protocol. + // Otherwise, connection establishment fails. + EnforceNextProtoSelection bool + + // If MaxEarlyData is greater than 0, the client will be allowed to send early + // data when resuming a session. + // Requires the AlternativeRecordLayer to be set. + // + // It has no meaning on the client. + MaxEarlyData uint32 + + // The Accept0RTT callback is called when the client offers 0-RTT. + // The server then has to decide if it wants to accept or reject 0-RTT. + // It is only used for servers. + Accept0RTT func(appData []byte) bool + + // 0RTTRejected is called when the server rejectes 0-RTT. + // It is only used for clients. + Rejected0RTT func() + + // If set, the client will export the 0-RTT key when resuming a session that + // allows sending of early data. + // Requires the AlternativeRecordLayer to be set. + // + // It has no meaning to the server. + Enable0RTT bool + + // Is called when the client saves a session ticket to the session ticket. + // This gives the application the opportunity to save some data along with the ticket, + // which can be restored when the session ticket is used. + GetAppDataForSessionState func() []byte + + // Is called when the client uses a session ticket. + // Restores the application data that was saved earlier on GetAppDataForSessionTicket. + SetAppDataFromSessionState func([]byte) +} + +const ( + // EncryptionHandshake is the Handshake encryption level + EncryptionHandshake = qtls.EncryptionHandshake + // Encryption0RTT is the 0-RTT encryption level + Encryption0RTT = qtls.Encryption0RTT + // EncryptionApplication is the application data encryption level + EncryptionApplication = qtls.EncryptionApplication +) + +// CipherSuiteName gets the name of a cipher suite. +func CipherSuiteName(id uint16) string { + return qtls.CipherSuiteName(id) +} + +// HkdfExtract generates a pseudorandom key for use with Expand from an input secret and an optional independent salt. +func HkdfExtract(hash crypto.Hash, newSecret, currentSecret []byte) []byte { + return qtls.HkdfExtract(hash, newSecret, currentSecret) +} + +// HkdfExpandLabel HKDF expands a label +func HkdfExpandLabel(hash crypto.Hash, secret, hashValue []byte, label string, L int) []byte { + return qtls.HkdfExpandLabel(hash, secret, hashValue, label, L) +} + +// AEADAESGCMTLS13 creates a new AES-GCM AEAD for TLS 1.3 +func AEADAESGCMTLS13(key, fixedNonce []byte) cipher.AEAD { + return qtls.AEADAESGCMTLS13(key, fixedNonce) +} + +// Client returns a new TLS client side connection. +func Client(conn net.Conn, config *tls.Config, extraConfig *ExtraConfig) *Conn { + return qtls.Client(conn, tlsConfigToQtlsConfig(config, extraConfig)) +} + +// Server returns a new TLS server side connection. +func Server(conn net.Conn, config *tls.Config, extraConfig *ExtraConfig) *Conn { + return qtls.Server(conn, tlsConfigToQtlsConfig(config, extraConfig)) +} + +func GetConnectionState(conn *Conn) ConnectionState { + return conn.ConnectionState() +} + +type cipherSuiteTLS13 struct { + ID uint16 + KeyLen int + AEAD func(key, fixedNonce []byte) cipher.AEAD + Hash crypto.Hash +} + +//go:linkname cipherSuiteTLS13ByID github.com/marten-seemann/qtls.cipherSuiteTLS13ByID +func cipherSuiteTLS13ByID(id uint16) *cipherSuiteTLS13 + +// CipherSuiteTLS13ByID gets a TLS 1.3 cipher suite. +func CipherSuiteTLS13ByID(id uint16) *CipherSuiteTLS13 { + val := cipherSuiteTLS13ByID(id) + cs := (*cipherSuiteTLS13)(unsafe.Pointer(val)) + return &qtls.CipherSuiteTLS13{ + ID: cs.ID, + KeyLen: cs.KeyLen, + AEAD: cs.AEAD, + Hash: cs.Hash, + } +} diff --git a/internal/qtls/qtls.go b/internal/qtls/qtls.go new file mode 100644 index 00000000..c281c37e --- /dev/null +++ b/internal/qtls/qtls.go @@ -0,0 +1,214 @@ +package qtls + +// This package uses unsafe to convert between: +// * Certificate and tls.Certificate +// * CertificateRequestInfo and tls.CertificateRequestInfo +// * ClientHelloInfo and tls.ClientHelloInfo +// * ConnectionState and tls.ConnectionState +// * ClientSessionState and tls.ClientSessionState +// We check in init() that this conversion actually is safe. + +import ( + "crypto/tls" + "net" + "unsafe" +) + +func init() { + if !structsEqual(&tls.Certificate{}, &Certificate{}) { + panic("Certificate not compatible with tls.Certificate") + } + if !structsEqual(&tls.CertificateRequestInfo{}, &CertificateRequestInfo{}) { + panic("CertificateRequestInfo not compatible with tls.CertificateRequestInfo") + } + if !structsEqual(&tls.ClientSessionState{}, &ClientSessionState{}) { + panic("ClientSessionState not compatible with tls.ClientSessionState") + } + if !structsEqual(&tls.ClientHelloInfo{}, &clientHelloInfo{}) { + panic("clientHelloInfo not compatible with tls.ClientHelloInfo") + } + if !structsEqual(&ClientHelloInfo{}, &qtlsClientHelloInfo{}) { + panic("qtlsClientHelloInfo not compatible with ClientHelloInfo") + } +} + +func tlsConfigToQtlsConfig(c *tls.Config, ec *ExtraConfig) *Config { + if c == nil { + c = &tls.Config{} + } + if ec == nil { + ec = &ExtraConfig{} + } + // Clone the config first. This executes the tls.Config.serverInit(). + // This sets the SessionTicketKey, if the user didn't supply one. + c = c.Clone() + // QUIC requires TLS 1.3 or newer + minVersion := c.MinVersion + if minVersion < tls.VersionTLS13 { + minVersion = tls.VersionTLS13 + } + maxVersion := c.MaxVersion + if maxVersion < tls.VersionTLS13 { + maxVersion = tls.VersionTLS13 + } + var getConfigForClient func(ch *ClientHelloInfo) (*Config, error) + if c.GetConfigForClient != nil { + getConfigForClient = func(ch *ClientHelloInfo) (*Config, error) { + tlsConf, err := c.GetConfigForClient(toTLSClientHelloInfo(ch)) + if err != nil { + return nil, err + } + if tlsConf == nil { + return nil, nil + } + return tlsConfigToQtlsConfig(tlsConf, ec), nil + } + } + var getCertificate func(ch *ClientHelloInfo) (*Certificate, error) + if c.GetCertificate != nil { + getCertificate = func(ch *ClientHelloInfo) (*Certificate, error) { + cert, err := c.GetCertificate(toTLSClientHelloInfo(ch)) + if err != nil { + return nil, err + } + if cert == nil { + return nil, nil + } + return (*Certificate)(cert), nil + } + } + var csc ClientSessionCache + if c.ClientSessionCache != nil { + csc = &clientSessionCache{c.ClientSessionCache} + } + conf := &Config{ + Rand: c.Rand, + Time: c.Time, + Certificates: *(*[]Certificate)(unsafe.Pointer(&c.Certificates)), + //nolint:staticcheck // NameToCertificate is deprecated, but we still need to copy it if the user sets it. + NameToCertificate: *(*map[string]*Certificate)(unsafe.Pointer(&c.NameToCertificate)), + GetCertificate: getCertificate, + GetClientCertificate: *(*func(*CertificateRequestInfo) (*Certificate, error))(unsafe.Pointer(&c.GetClientCertificate)), + GetConfigForClient: getConfigForClient, + VerifyPeerCertificate: c.VerifyPeerCertificate, + RootCAs: c.RootCAs, + NextProtos: c.NextProtos, + EnforceNextProtoSelection: true, + ServerName: c.ServerName, + ClientAuth: c.ClientAuth, + ClientCAs: c.ClientCAs, + InsecureSkipVerify: c.InsecureSkipVerify, + CipherSuites: c.CipherSuites, + PreferServerCipherSuites: c.PreferServerCipherSuites, + SessionTicketsDisabled: c.SessionTicketsDisabled, + //nolint:staticcheck // SessionTicketKey is deprecated, but we still need to copy it if the user sets it. + SessionTicketKey: c.SessionTicketKey, + ClientSessionCache: csc, + MinVersion: minVersion, + MaxVersion: maxVersion, + CurvePreferences: c.CurvePreferences, + DynamicRecordSizingDisabled: c.DynamicRecordSizingDisabled, + // no need to copy Renegotiation, it's not supported by TLS 1.3 + KeyLogWriter: c.KeyLogWriter, + AlternativeRecordLayer: ec.AlternativeRecordLayer, + GetExtensions: ec.GetExtensions, + ReceivedExtensions: ec.ReceivedExtensions, + Accept0RTT: ec.Accept0RTT, + Rejected0RTT: ec.Rejected0RTT, + GetAppDataForSessionState: ec.GetAppDataForSessionState, + SetAppDataFromSessionState: ec.SetAppDataFromSessionState, + Enable0RTT: ec.Enable0RTT, + MaxEarlyData: ec.MaxEarlyData, + } + return conf +} + +type clientSessionCache struct { + tls.ClientSessionCache +} + +var _ ClientSessionCache = &clientSessionCache{} + +func (c *clientSessionCache) Get(sessionKey string) (*ClientSessionState, bool) { + sess, ok := c.ClientSessionCache.Get(sessionKey) + if sess == nil { + return nil, ok + } + // ClientSessionState is identical to the tls.ClientSessionState. + // In order to allow users of quic-go to use a tls.Config, + // we need this workaround to use the ClientSessionCache. + // In unsafe.go we check that the two structs are actually identical. + return (*ClientSessionState)(unsafe.Pointer(sess)), ok +} + +func (c *clientSessionCache) Put(sessionKey string, cs *ClientSessionState) { + if cs == nil { + c.ClientSessionCache.Put(sessionKey, nil) + return + } + // ClientSessionState is identical to the tls.ClientSessionState. + // In order to allow users of quic-go to use a tls.Config, + // we need this workaround to use the ClientSessionCache. + // In unsafe.go we check that the two structs are actually identical. + c.ClientSessionCache.Put(sessionKey, (*tls.ClientSessionState)(unsafe.Pointer(cs))) +} + +type clientHelloInfo struct { + CipherSuites []uint16 + ServerName string + SupportedCurves []tls.CurveID + SupportedPoints []uint8 + SignatureSchemes []tls.SignatureScheme + SupportedProtos []string + SupportedVersions []uint16 + Conn net.Conn + + config *tls.Config +} + +type qtlsClientHelloInfo struct { + CipherSuites []uint16 + ServerName string + SupportedCurves []tls.CurveID + SupportedPoints []uint8 + SignatureSchemes []tls.SignatureScheme + SupportedProtos []string + SupportedVersions []uint16 + Conn net.Conn + + config *Config +} + +func toTLSClientHelloInfo(chi *ClientHelloInfo) *tls.ClientHelloInfo { + if chi == nil { + return nil + } + qtlsCHI := (*qtlsClientHelloInfo)(unsafe.Pointer(chi)) + var config *tls.Config + if qtlsCHI.config != nil { + config = qtlsConfigToTLSConfig(qtlsCHI.config) + } + return (*tls.ClientHelloInfo)(unsafe.Pointer(&clientHelloInfo{ + CipherSuites: chi.CipherSuites, + ServerName: chi.ServerName, + SupportedCurves: chi.SupportedCurves, + SupportedPoints: chi.SupportedPoints, + SignatureSchemes: chi.SignatureSchemes, + SupportedProtos: chi.SupportedProtos, + SupportedVersions: chi.SupportedVersions, + Conn: chi.Conn, + config: config, + })) +} + +// qtlsConfigToTLSConfig is used to transform a Config to a tls.Config. +// It is used to create the tls.Config in the ClientHelloInfo. +// It doesn't copy all values, but only those used by ClientHelloInfo.SupportsCertificate. +func qtlsConfigToTLSConfig(config *Config) *tls.Config { + return &tls.Config{ + MinVersion: config.MinVersion, + MaxVersion: config.MaxVersion, + CipherSuites: config.CipherSuites, + CurvePreferences: config.CurvePreferences, + } +} diff --git a/internal/qtls/qtls_suite_test.go b/internal/qtls/qtls_suite_test.go new file mode 100644 index 00000000..24b143b2 --- /dev/null +++ b/internal/qtls/qtls_suite_test.go @@ -0,0 +1,25 @@ +package qtls + +import ( + "testing" + + gomock "github.com/golang/mock/gomock" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +func TestQTLS(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "qtls Suite") +} + +var mockCtrl *gomock.Controller + +var _ = BeforeEach(func() { + mockCtrl = gomock.NewController(GinkgoT()) +}) + +var _ = AfterEach(func() { + mockCtrl.Finish() +}) diff --git a/internal/handshake/qtls_test.go b/internal/qtls/qtls_test.go similarity index 57% rename from internal/handshake/qtls_test.go rename to internal/qtls/qtls_test.go index fe83dc57..0bdd4052 100644 --- a/internal/handshake/qtls_test.go +++ b/internal/qtls/qtls_test.go @@ -1,4 +1,4 @@ -package handshake +package qtls import ( "crypto/tls" @@ -6,99 +6,91 @@ import ( "net" "unsafe" - "github.com/lucas-clemente/quic-go/internal/utils" + mocktls "github.com/lucas-clemente/quic-go/internal/mocks/tls" "github.com/marten-seemann/qtls" + . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) -type mockExtensionHandler struct { - get, received bool -} - -var _ tlsExtensionHandler = &mockExtensionHandler{} - -func (h *mockExtensionHandler) GetExtensions(msgType uint8) []qtls.Extension { - h.get = true - return nil -} -func (h *mockExtensionHandler) ReceivedExtensions(msgType uint8, exts []qtls.Extension) { - h.received = true -} -func (*mockExtensionHandler) TransportParameters() <-chan []byte { panic("not implemented") } - -var _ = Describe("qtls.Config", func() { +var _ = Describe("Config", func() { It("sets MinVersion and MaxVersion", func() { tlsConf := &tls.Config{MinVersion: tls.VersionTLS11, MaxVersion: tls.VersionTLS12} - qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, nil, nil, false) + qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil) Expect(qtlsConf.MinVersion).To(BeEquivalentTo(tls.VersionTLS13)) Expect(qtlsConf.MaxVersion).To(BeEquivalentTo(tls.VersionTLS13)) }) It("works when called with a nil config", func() { - qtlsConf := tlsConfigToQtlsConfig(nil, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, nil, nil, false) + qtlsConf := tlsConfigToQtlsConfig(nil, nil) Expect(qtlsConf).ToNot(BeNil()) }) It("sets the setter and getter function for TLS extensions", func() { - extHandler := &mockExtensionHandler{} - qtlsConf := tlsConfigToQtlsConfig(&tls.Config{}, nil, extHandler, utils.NewRTTStats(), nil, nil, nil, nil, false) - Expect(extHandler.get).To(BeFalse()) + var get, received bool + extraConfig := &ExtraConfig{ + GetExtensions: func(handshakeMessageType uint8) []Extension { get = true; return nil }, + ReceivedExtensions: func(handshakeMessageType uint8, exts []qtls.Extension) { received = true }, + } + qtlsConf := tlsConfigToQtlsConfig(&tls.Config{}, extraConfig) qtlsConf.GetExtensions(10) - Expect(extHandler.get).To(BeTrue()) - Expect(extHandler.received).To(BeFalse()) + Expect(get).To(BeTrue()) + Expect(received).To(BeFalse()) qtlsConf.ReceivedExtensions(10, nil) - Expect(extHandler.received).To(BeTrue()) + Expect(received).To(BeTrue()) }) It("sets the Accept0RTT callback", func() { - accept0RTT := func([]byte) bool { return true } - qtlsConf := tlsConfigToQtlsConfig(nil, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, accept0RTT, nil, false) + qtlsConf := tlsConfigToQtlsConfig(nil, &ExtraConfig{Accept0RTT: func([]byte) bool { return true }}) Expect(qtlsConf.Accept0RTT).ToNot(BeNil()) Expect(qtlsConf.Accept0RTT(nil)).To(BeTrue()) }) - It("sets the Accept0RTT callback", func() { + It("sets the Rejected0RTT callback", func() { var called bool - rejected0RTT := func() { called = true } - qtlsConf := tlsConfigToQtlsConfig(nil, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, nil, rejected0RTT, false) + qtlsConf := tlsConfigToQtlsConfig(nil, &ExtraConfig{Rejected0RTT: func() { called = true }}) Expect(qtlsConf.Rejected0RTT).ToNot(BeNil()) qtlsConf.Rejected0RTT() Expect(called).To(BeTrue()) }) - It("enables 0-RTT", func() { - qtlsConf := tlsConfigToQtlsConfig(nil, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, nil, nil, false) - Expect(qtlsConf.Enable0RTT).To(BeFalse()) + It("sets MaxEarlyData", func() { + qtlsConf := tlsConfigToQtlsConfig(nil, nil) Expect(qtlsConf.MaxEarlyData).To(BeZero()) - qtlsConf = tlsConfigToQtlsConfig(nil, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, nil, nil, true) + qtlsConf = tlsConfigToQtlsConfig(nil, &ExtraConfig{MaxEarlyData: 1337}) + Expect(qtlsConf.MaxEarlyData).To(Equal(uint32(1337))) + }) + + It("enables 0-RTT", func() { + qtlsConf := tlsConfigToQtlsConfig(nil, nil) + Expect(qtlsConf.Enable0RTT).To(BeFalse()) + qtlsConf = tlsConfigToQtlsConfig(nil, &ExtraConfig{Enable0RTT: true}) Expect(qtlsConf.Enable0RTT).To(BeTrue()) - Expect(qtlsConf.MaxEarlyData).To(Equal(uint32(0xffffffff))) }) It("initializes such that the session ticket key remains constant", func() { tlsConf := &tls.Config{} - qtlsConf1 := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, nil, nil, false) - qtlsConf2 := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, nil, nil, false) + qtlsConf1 := tlsConfigToQtlsConfig(tlsConf, nil) + qtlsConf2 := tlsConfigToQtlsConfig(tlsConf, nil) Expect(qtlsConf1.SessionTicketKey).ToNot(BeZero()) // should now contain a random value Expect(qtlsConf1.SessionTicketKey).To(Equal(qtlsConf2.SessionTicketKey)) }) Context("GetConfigForClient callback", func() { It("doesn't set it if absent", func() { - qtlsConf := tlsConfigToQtlsConfig(&tls.Config{}, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, nil, nil, false) + qtlsConf := tlsConfigToQtlsConfig(nil, nil) Expect(qtlsConf.GetConfigForClient).To(BeNil()) }) - It("returns a qtls.Config", func() { + It("returns a Config", func() { tlsConf := &tls.Config{ GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) { return &tls.Config{ServerName: "foo.bar"}, nil }, } - extHandler := &mockExtensionHandler{} - qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, extHandler, utils.NewRTTStats(), nil, nil, nil, nil, false) + var received bool + qtlsConf := tlsConfigToQtlsConfig(tlsConf, &ExtraConfig{ReceivedExtensions: func(uint8, []Extension) { received = true }}) Expect(qtlsConf.GetConfigForClient).ToNot(BeNil()) confForClient, err := qtlsConf.GetConfigForClient(nil) Expect(err).ToNot(HaveOccurred()) @@ -106,9 +98,10 @@ var _ = Describe("qtls.Config", func() { Expect(confForClient).ToNot(BeNil()) Expect(confForClient.MinVersion).To(BeEquivalentTo(tls.VersionTLS13)) Expect(confForClient.MaxVersion).To(BeEquivalentTo(tls.VersionTLS13)) - Expect(extHandler.get).To(BeFalse()) - confForClient.GetExtensions(10) - Expect(extHandler.get).To(BeTrue()) + Expect(received).To(BeFalse()) + Expect(confForClient.ReceivedExtensions).ToNot(BeNil()) + confForClient.ReceivedExtensions(10, nil) + Expect(received).To(BeTrue()) }) It("returns errors", func() { @@ -118,7 +111,7 @@ var _ = Describe("qtls.Config", func() { return nil, testErr }, } - qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, nil, nil, false) + qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil) _, err := qtlsConf.GetConfigForClient(nil) Expect(err).To(MatchError(testErr)) }) @@ -129,7 +122,7 @@ var _ = Describe("qtls.Config", func() { return nil, nil }, } - qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, nil, nil, false) + qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil) Expect(qtlsConf.GetConfigForClient(nil)).To(BeNil()) }) }) @@ -141,7 +134,7 @@ var _ = Describe("qtls.Config", func() { return &tls.Certificate{Certificate: [][]byte{[]byte("foo"), []byte("bar")}}, nil }, } - qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, nil, nil, false) + qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil) qtlsCert, err := qtlsConf.GetCertificate(nil) Expect(err).ToNot(HaveOccurred()) Expect(qtlsCert).ToNot(BeNil()) @@ -149,7 +142,7 @@ var _ = Describe("qtls.Config", func() { }) It("doesn't set it if absent", func() { - qtlsConf := tlsConfigToQtlsConfig(&tls.Config{}, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, nil, nil, false) + qtlsConf := tlsConfigToQtlsConfig(&tls.Config{}, nil) Expect(qtlsConf.GetCertificate).To(BeNil()) }) @@ -159,7 +152,7 @@ var _ = Describe("qtls.Config", func() { return nil, errors.New("test") }, } - qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, nil, nil, false) + qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil) _, err := qtlsConf.GetCertificate(nil) Expect(err).To(MatchError("test")) }) @@ -170,21 +163,21 @@ var _ = Describe("qtls.Config", func() { return nil, nil }, } - qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, nil, nil, false) + qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil) Expect(qtlsConf.GetCertificate(nil)).To(BeNil()) }) }) Context("ClientSessionCache", func() { It("doesn't set if absent", func() { - qtlsConf := tlsConfigToQtlsConfig(&tls.Config{}, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, nil, nil, false) + qtlsConf := tlsConfigToQtlsConfig(&tls.Config{}, nil) Expect(qtlsConf.ClientSessionCache).To(BeNil()) }) It("puts a nil session state", func() { - csc := NewMockClientSessionCache(mockCtrl) + csc := mocktls.NewMockClientSessionCache(mockCtrl) tlsConf := &tls.Config{ClientSessionCache: csc} - qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, nil, nil, false) + qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil) // put something csc.EXPECT().Put("foobar", nil) qtlsConf.ClientSessionCache.Put("foobar", nil) @@ -192,25 +185,25 @@ var _ = Describe("qtls.Config", func() { }) }) -var _ = Describe("qtls.Config generation", func() { - It("converts a qtls.ClientHelloInfo to a tls.ClientHelloInfo", func() { +var _ = Describe("Config generation", func() { + It("converts a ClientHelloInfo to a tls.ClientHelloInfo", func() { chi := &qtlsClientHelloInfo{ CipherSuites: []uint16{1, 2, 3}, ServerName: "foo.bar", - SupportedCurves: []qtls.CurveID{4, 5, 6}, + SupportedCurves: []tls.CurveID{4, 5, 6}, SupportedPoints: []uint8{7, 8, 9}, - SignatureSchemes: []qtls.SignatureScheme{10, 11, 12}, + SignatureSchemes: []tls.SignatureScheme{10, 11, 12}, SupportedProtos: []string{"foo", "bar"}, SupportedVersions: []uint16{13, 14, 15}, Conn: &net.UDPConn{}, - config: &qtls.Config{ + config: &Config{ MinVersion: tls.VersionTLS10, MaxVersion: tls.VersionTLS12, CipherSuites: []uint16{16, 17, 18}, - CurvePreferences: []qtls.CurveID{19, 20, 21}, + CurvePreferences: []tls.CurveID{19, 20, 21}, }, } - tlsCHI := toTLSClientHelloInfo((*qtls.ClientHelloInfo)(unsafe.Pointer(chi))) + tlsCHI := toTLSClientHelloInfo((*ClientHelloInfo)(unsafe.Pointer(chi))) Expect(tlsCHI.CipherSuites).To(Equal([]uint16{1, 2, 3})) Expect(tlsCHI.ServerName).To(Equal("foo.bar")) Expect(tlsCHI.SupportedCurves).To(Equal([]tls.CurveID{4, 5, 6})) @@ -226,9 +219,9 @@ var _ = Describe("qtls.Config generation", func() { Expect(c.config.CurvePreferences).To(Equal([]tls.CurveID{19, 20, 21})) }) - It("converts a qtls.ClientHelloInfo to a tls.ClientHelloInfo, if no config is set", func() { + It("converts a ClientHelloInfo to a tls.ClientHelloInfo, if no config is set", func() { chi := &qtlsClientHelloInfo{CipherSuites: []uint16{13, 37}} - tlsCHI := toTLSClientHelloInfo((*qtls.ClientHelloInfo)(unsafe.Pointer(chi))) + tlsCHI := toTLSClientHelloInfo((*ClientHelloInfo)(unsafe.Pointer(chi))) Expect(tlsCHI.CipherSuites).To(Equal([]uint16{13, 37})) }) }) diff --git a/internal/qtls/structs_equal.go b/internal/qtls/structs_equal.go new file mode 100644 index 00000000..7df0423a --- /dev/null +++ b/internal/qtls/structs_equal.go @@ -0,0 +1,19 @@ +package qtls + +import "reflect" + +func structsEqual(a, b interface{}) bool { + sa := reflect.ValueOf(a).Elem() + sb := reflect.ValueOf(b).Elem() + if sa.NumField() != sb.NumField() { + return false + } + for i := 0; i < sa.NumField(); i++ { + fa := sa.Type().Field(i) + fb := sb.Type().Field(i) + if !reflect.DeepEqual(fa.Index, fb.Index) || fa.Name != fb.Name || fa.Anonymous != fb.Anonymous || fa.Offset != fb.Offset || !reflect.DeepEqual(fa.Type, fb.Type) { + return false + } + } + return true +} diff --git a/internal/handshake/unsafe_test.go b/internal/qtls/structs_equal_test.go similarity index 98% rename from internal/handshake/unsafe_test.go rename to internal/qtls/structs_equal_test.go index 360a07bb..2a2f46c2 100644 --- a/internal/handshake/unsafe_test.go +++ b/internal/qtls/structs_equal_test.go @@ -1,4 +1,4 @@ -package handshake +package qtls import ( . "github.com/onsi/ginkgo" diff --git a/mock_quic_session_test.go b/mock_quic_session_test.go index e92cbc75..656f77a2 100644 --- a/mock_quic_session_test.go +++ b/mock_quic_session_test.go @@ -11,7 +11,6 @@ import ( gomock "github.com/golang/mock/gomock" protocol "github.com/lucas-clemente/quic-go/internal/protocol" - qtls "github.com/marten-seemann/qtls" ) // MockQuicSession is a mock of QuicSession interface @@ -82,10 +81,10 @@ func (mr *MockQuicSessionMockRecorder) CloseWithError(arg0, arg1 interface{}) *g } // ConnectionState mocks base method -func (m *MockQuicSession) ConnectionState() qtls.ConnectionState { +func (m *MockQuicSession) ConnectionState() ConnectionState { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "ConnectionState") - ret0, _ := ret[0].(qtls.ConnectionState) + ret0, _ := ret[0].(ConnectionState) return ret0 } diff --git a/mockgen_private.sh b/mockgen_private.sh index e83b6a5e..f0855bbf 100755 --- a/mockgen_private.sh +++ b/mockgen_private.sh @@ -10,6 +10,7 @@ echo -e "package $1\n" > $TMPFILE echo "type $INTERFACE_NAME = $4" >> $TMPFILE mockgen -package $1 -self_package $PACKAGE -destination $2 $PACKAGE $INTERFACE_NAME +mv $2 $TMPFILE && sed 's/qtls.ConnectionState/ConnectionState/g' $TMPFILE > $2 goimports -w $2 rm $TMPFILE