diff --git a/internal/handshake/crypto_setup.go b/internal/handshake/crypto_setup.go index 5b492e5a..3dd7d387 100644 --- a/internal/handshake/crypto_setup.go +++ b/internal/handshake/crypto_setup.go @@ -237,7 +237,7 @@ func newCryptoSetup( tracer.UpdatedKeyFromTLS(protocol.EncryptionInitial, protocol.PerspectiveClient) tracer.UpdatedKeyFromTLS(protocol.EncryptionInitial, protocol.PerspectiveServer) } - extHandler := newExtensionHandler(tp.Marshal(perspective), perspective) + extHandler := newExtensionHandler(tp.Marshal(perspective), perspective, version) cs := &cryptoSetup{ tlsConf: tlsConf, initialStream: initialStream, diff --git a/internal/handshake/tls_extension_handler.go b/internal/handshake/tls_extension_handler.go index 33409b8e..ece661fc 100644 --- a/internal/handshake/tls_extension_handler.go +++ b/internal/handshake/tls_extension_handler.go @@ -5,23 +5,33 @@ import ( "github.com/lucas-clemente/quic-go/internal/qtls" ) -const quicTLSExtensionType = 0xffa5 +const ( + quicTLSExtensionTypeOldDrafts = 0xffa5 + quicTLSExtensionType = 0x39 +) type extensionHandler struct { ourParams []byte paramsChan chan []byte + extensionType uint16 + perspective protocol.Perspective } var _ tlsExtensionHandler = &extensionHandler{} // newExtensionHandler creates a new extension handler -func newExtensionHandler(params []byte, pers protocol.Perspective) tlsExtensionHandler { +func newExtensionHandler(params []byte, pers protocol.Perspective, v protocol.VersionNumber) tlsExtensionHandler { + et := uint16(quicTLSExtensionType) + if v != protocol.VersionDraft34 { + et = quicTLSExtensionTypeOldDrafts + } return &extensionHandler{ - ourParams: params, - paramsChan: make(chan []byte), - perspective: pers, + ourParams: params, + paramsChan: make(chan []byte), + perspective: pers, + extensionType: et, } } @@ -31,7 +41,7 @@ func (h *extensionHandler) GetExtensions(msgType uint8) []qtls.Extension { return nil } return []qtls.Extension{{ - Type: quicTLSExtensionType, + Type: h.extensionType, Data: h.ourParams, }} } @@ -44,7 +54,7 @@ func (h *extensionHandler) ReceivedExtensions(msgType uint8, exts []qtls.Extensi var data []byte for _, ext := range exts { - if ext.Type == quicTLSExtensionType { + if ext.Type == h.extensionType { data = ext.Data break } diff --git a/internal/handshake/tls_extension_handler_test.go b/internal/handshake/tls_extension_handler_test.go index 453c5101..33e9a4c8 100644 --- a/internal/handshake/tls_extension_handler_test.go +++ b/internal/handshake/tls_extension_handler_test.go @@ -1,6 +1,8 @@ package handshake import ( + "fmt" + "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/qtls" @@ -12,39 +14,61 @@ var _ = Describe("TLS Extension Handler, for the server", func() { var ( handlerServer tlsExtensionHandler handlerClient tlsExtensionHandler + version protocol.VersionNumber ) BeforeEach(func() { + version = protocol.VersionDraft29 + }) + + JustBeforeEach(func() { handlerServer = newExtensionHandler( []byte("foobar"), protocol.PerspectiveServer, + version, ) handlerClient = newExtensionHandler( []byte("raboof"), protocol.PerspectiveClient, + version, ) }) Context("for the server", func() { - Context("sending", func() { - It("only adds TransportParameters for the Encrypted Extensions", func() { - // test 2 other handshake types - Expect(handlerServer.GetExtensions(uint8(typeCertificate))).To(BeEmpty()) - Expect(handlerServer.GetExtensions(uint8(typeFinished))).To(BeEmpty()) - }) + for _, ver := range []protocol.VersionNumber{protocol.VersionDraft29, protocol.VersionDraft34} { + v := ver - It("adds TransportParameters to the EncryptedExtensions message", func() { - exts := handlerServer.GetExtensions(uint8(typeEncryptedExtensions)) - Expect(exts).To(HaveLen(1)) - Expect(exts[0].Type).To(BeEquivalentTo(quicTLSExtensionType)) - Expect(exts[0].Data).To(Equal([]byte("foobar"))) + Context(fmt.Sprintf("sending, for version %s", v), func() { + var extensionType uint16 + + BeforeEach(func() { + version = v + if v == protocol.VersionDraft29 { + extensionType = quicTLSExtensionTypeOldDrafts + } else { + extensionType = quicTLSExtensionType + } + }) + + It("only adds TransportParameters for the Encrypted Extensions", func() { + // test 2 other handshake types + Expect(handlerServer.GetExtensions(uint8(typeCertificate))).To(BeEmpty()) + Expect(handlerServer.GetExtensions(uint8(typeFinished))).To(BeEmpty()) + }) + + It("adds TransportParameters to the EncryptedExtensions message", func() { + exts := handlerServer.GetExtensions(uint8(typeEncryptedExtensions)) + Expect(exts).To(HaveLen(1)) + Expect(exts[0].Type).To(BeEquivalentTo(extensionType)) + Expect(exts[0].Data).To(Equal([]byte("foobar"))) + }) }) - }) + } Context("receiving", func() { var chExts []qtls.Extension - BeforeEach(func() { + JustBeforeEach(func() { chExts = handlerClient.GetExtensions(uint8(typeClientHello)) Expect(chExts).To(HaveLen(1)) }) @@ -98,25 +122,40 @@ var _ = Describe("TLS Extension Handler, for the server", func() { }) Context("for the client", func() { - Context("sending", func() { - It("only adds TransportParameters for the Encrypted Extensions", func() { - // test 2 other handshake types - Expect(handlerClient.GetExtensions(uint8(typeCertificate))).To(BeEmpty()) - Expect(handlerClient.GetExtensions(uint8(typeFinished))).To(BeEmpty()) - }) + for _, ver := range []protocol.VersionNumber{protocol.VersionDraft29, protocol.VersionDraft34} { + v := ver - It("adds TransportParameters to the ClientHello message", func() { - exts := handlerClient.GetExtensions(uint8(typeClientHello)) - Expect(exts).To(HaveLen(1)) - Expect(exts[0].Type).To(BeEquivalentTo(quicTLSExtensionType)) - Expect(exts[0].Data).To(Equal([]byte("raboof"))) + Context(fmt.Sprintf("sending, for version %s", v), func() { + var extensionType uint16 + + BeforeEach(func() { + version = v + if v == protocol.VersionDraft29 { + extensionType = quicTLSExtensionTypeOldDrafts + } else { + extensionType = quicTLSExtensionType + } + }) + + It("only adds TransportParameters for the Encrypted Extensions", func() { + // test 2 other handshake types + Expect(handlerClient.GetExtensions(uint8(typeCertificate))).To(BeEmpty()) + Expect(handlerClient.GetExtensions(uint8(typeFinished))).To(BeEmpty()) + }) + + It("adds TransportParameters to the ClientHello message", func() { + exts := handlerClient.GetExtensions(uint8(typeClientHello)) + Expect(exts).To(HaveLen(1)) + Expect(exts[0].Type).To(BeEquivalentTo(extensionType)) + Expect(exts[0].Data).To(Equal([]byte("raboof"))) + }) }) - }) + } Context("receiving", func() { var chExts []qtls.Extension - BeforeEach(func() { + JustBeforeEach(func() { chExts = handlerServer.GetExtensions(uint8(typeEncryptedExtensions)) Expect(chExts).To(HaveLen(1)) })