From c968b18a21f1d77872db1b3f6ad2d65e67b25949 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Thu, 29 Oct 2020 13:08:16 +0700 Subject: [PATCH] select the H3 ALPN based on the QUIC version in use (for the H3 server) --- fuzzing/handshake/cmd/corpus.go | 2 + fuzzing/handshake/fuzz.go | 2 + http3/client.go | 2 +- http3/client_test.go | 4 +- http3/server.go | 54 ++++++++++------- http3/server_test.go | 80 ++++++++++++++++--------- internal/handshake/crypto_setup.go | 29 +++++---- internal/handshake/crypto_setup_test.go | 13 ++++ internal/handshake/interface.go | 8 +++ session.go | 2 + 10 files changed, 134 insertions(+), 62 deletions(-) diff --git a/fuzzing/handshake/cmd/corpus.go b/fuzzing/handshake/cmd/corpus.go index 68e46084..7150d094 100644 --- a/fuzzing/handshake/cmd/corpus.go +++ b/fuzzing/handshake/cmd/corpus.go @@ -90,6 +90,7 @@ func main() { utils.NewRTTStats(), nil, utils.DefaultLogger.WithPrefix("client"), + protocol.VersionTLS, ) sChunkChan, sInitialStream, sHandshakeStream := initStreams() @@ -108,6 +109,7 @@ func main() { utils.NewRTTStats(), nil, utils.DefaultLogger.WithPrefix("server"), + protocol.VersionTLS, ) serverHandshakeCompleted := make(chan struct{}) diff --git a/fuzzing/handshake/fuzz.go b/fuzzing/handshake/fuzz.go index 0d49b813..8c0d0ce4 100644 --- a/fuzzing/handshake/fuzz.go +++ b/fuzzing/handshake/fuzz.go @@ -387,6 +387,7 @@ func runHandshake(runConfig [confLen]byte, messageConfig uint8, clientConf *tls. utils.NewRTTStats(), nil, utils.DefaultLogger.WithPrefix("client"), + protocol.VersionTLS, ) sChunkChan, sInitialStream, sHandshakeStream := initStreams() @@ -403,6 +404,7 @@ func runHandshake(runConfig [confLen]byte, messageConfig uint8, clientConf *tls. utils.NewRTTStats(), nil, utils.DefaultLogger.WithPrefix("server"), + protocol.VersionTLS, ) if len(data) == 0 { diff --git a/http3/client.go b/http3/client.go index bf850c0e..e05ca2b4 100644 --- a/http3/client.go +++ b/http3/client.go @@ -69,7 +69,7 @@ func newClient( tlsConf = tlsConf.Clone() } // Replace existing ALPNs by H3 - tlsConf.NextProtos = []string{nextProtoH3} + tlsConf.NextProtos = []string{nextProtoH3Draft29} if quicConfig == nil { quicConfig = defaultQuicConfig } diff --git a/http3/client_test.go b/http3/client_test.go index c0cc15b9..0a2474f1 100644 --- a/http3/client_test.go +++ b/http3/client_test.go @@ -53,7 +53,7 @@ var _ = Describe("Client", func() { var dialAddrCalled bool dialAddr = func(_ string, tlsConf *tls.Config, quicConf *quic.Config) (quic.EarlySession, error) { Expect(quicConf).To(Equal(defaultQuicConfig)) - Expect(tlsConf.NextProtos).To(Equal([]string{nextProtoH3})) + Expect(tlsConf.NextProtos).To(Equal([]string{nextProtoH3Draft29})) dialAddrCalled = true return nil, errors.New("test done") } @@ -90,7 +90,7 @@ var _ = Describe("Client", func() { ) (quic.EarlySession, error) { Expect(hostname).To(Equal("localhost:1337")) Expect(tlsConfP.ServerName).To(Equal(tlsConf.ServerName)) - Expect(tlsConfP.NextProtos).To(Equal([]string{nextProtoH3})) + Expect(tlsConfP.NextProtos).To(Equal([]string{nextProtoH3Draft29})) Expect(quicConfP.MaxIdleTimeout).To(Equal(quicConf.MaxIdleTimeout)) dialAddrCalled = true return nil, errors.New("test done") diff --git a/http3/server.go b/http3/server.go index 5e755297..e90b3084 100644 --- a/http3/server.go +++ b/http3/server.go @@ -10,11 +10,13 @@ import ( "net" "net/http" "runtime" + "strings" "sync" "sync/atomic" "time" "github.com/lucas-clemente/quic-go" + "github.com/lucas-clemente/quic-go/internal/handshake" "github.com/lucas-clemente/quic-go/internal/utils" "github.com/marten-seemann/qpack" ) @@ -25,7 +27,12 @@ var ( quicListenAddr = quic.ListenAddrEarly ) -const nextProtoH3 = "h3-29" +const ( + nextProtoH3Draft29 = "h3-29" + nextProtoH3Draft32 = "h3-32" +) + +var supportedVersions = []string{nextProtoH3Draft29, nextProtoH3Draft32} // contextKey is a value for use with context.WithValue. It's used as // a pointer so it fits in an interface{} without allocation. @@ -115,32 +122,36 @@ func (s *Server) serveImpl(tlsConf *tls.Config, conn net.PacketConn) error { s.logger = utils.DefaultLogger.WithPrefix("server") }) - if tlsConf == nil { - tlsConf = &tls.Config{} - } else { - tlsConf = tlsConf.Clone() - } - // Replace existing ALPNs by H3 - tlsConf.NextProtos = []string{nextProtoH3} - if tlsConf.GetConfigForClient != nil { - getConfigForClient := tlsConf.GetConfigForClient - tlsConf.GetConfigForClient = func(ch *tls.ClientHelloInfo) (*tls.Config, error) { - conf, err := getConfigForClient(ch) - if err != nil || conf == nil { - return conf, err + // The tls.Config we pass to Listen needs to have the GetConfigForClient callback set. + // That way, we can get the QUIC version and set the correct ALPN value. + baseConf := &tls.Config{ + GetConfigForClient: func(ch *tls.ClientHelloInfo) (*tls.Config, error) { + // determine the ALPN from the QUIC version used + proto := nextProtoH3Draft29 + if qconn, ok := ch.Conn.(handshake.ConnWithVersion); ok && qconn.GetQUICVersion() == quic.VersionDraft32 { + proto = nextProtoH3Draft32 + } + conf := tlsConf + if tlsConf.GetConfigForClient != nil { + getConfigForClient := tlsConf.GetConfigForClient + var err error + conf, err = getConfigForClient(ch) + if err != nil { + return nil, err + } } conf = conf.Clone() - conf.NextProtos = []string{nextProtoH3} + conf.NextProtos = []string{proto} return conf, nil - } + }, } var ln quic.EarlyListener var err error if conn == nil { - ln, err = quicListenAddr(s.Addr, tlsConf, s.QuicConfig) + ln, err = quicListenAddr(s.Addr, baseConf, s.QuicConfig) } else { - ln, err = quicListen(conn, tlsConf, s.QuicConfig) + ln, err = quicListen(conn, baseConf, s.QuicConfig) } if err != nil { return err @@ -344,8 +355,11 @@ func (s *Server) SetQuicHeaders(hdr http.Header) error { atomic.StoreUint32(&s.port, port) } - hdr.Add("Alt-Svc", fmt.Sprintf(`%s=":%d"; ma=2592000`, nextProtoH3, port)) - + altSvc := make([]string, len(supportedVersions)) + for i, v := range supportedVersions { + altSvc[i] = fmt.Sprintf(`%s=":%d"; ma=2592000`, v, port) + } + hdr.Add("Alt-Svc", strings.Join(altSvc, ",")) return nil } diff --git a/http3/server_test.go b/http3/server_test.go index b4c42e35..6815b22f 100644 --- a/http3/server_test.go +++ b/http3/server_test.go @@ -5,7 +5,6 @@ import ( "context" "crypto/tls" "errors" - "fmt" "io" "net" "net/http" @@ -14,6 +13,7 @@ import ( "github.com/golang/mock/gomock" "github.com/lucas-clemente/quic-go" mockquic "github.com/lucas-clemente/quic-go/internal/mocks/quic" + "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/testdata" "github.com/lucas-clemente/quic-go/internal/utils" "github.com/marten-seemann/qpack" @@ -22,6 +22,19 @@ import ( . "github.com/onsi/gomega" ) +type mockConn struct { + net.Conn + version protocol.VersionNumber +} + +func newMockConn(version protocol.VersionNumber) net.Conn { + return &mockConn{version: version} +} + +func (c *mockConn) GetQUICVersion() protocol.VersionNumber { + return c.version +} + var _ = Describe("Server", func() { var ( s *Server @@ -339,19 +352,10 @@ var _ = Describe("Server", func() { }) Context("setting http headers", func() { - var expected http.Header - - getExpectedHeader := func() http.Header { - return http.Header{ - "Alt-Svc": {fmt.Sprintf(`%s=":443"; ma=2592000`, nextProtoH3)}, - } + expected := http.Header{ + "Alt-Svc": {`h3-29=":443"; ma=2592000,h3-32=":443"; ma=2592000`}, } - BeforeEach(func() { - Expect(getExpectedHeader()).To(Equal(http.Header{"Alt-Svc": {nextProtoH3 + `=":443"; ma=2592000`}})) - expected = getExpectedHeader() - }) - It("sets proper headers with numeric port", func() { s.Server.Addr = ":443" hdr := http.Header{} @@ -496,6 +500,15 @@ var _ = Describe("Server", func() { Expect(s.Close()).To(Succeed()) }) + checkGetConfigForClientVersions := func(conf *tls.Config) { + c, err := conf.GetConfigForClient(&tls.ClientHelloInfo{Conn: newMockConn(protocol.VersionDraft29)}) + ExpectWithOffset(1, err).ToNot(HaveOccurred()) + ExpectWithOffset(1, c.NextProtos).To(Equal([]string{nextProtoH3Draft29})) + c, err = conf.GetConfigForClient(&tls.ClientHelloInfo{Conn: newMockConn(protocol.VersionDraft32)}) + ExpectWithOffset(1, err).ToNot(HaveOccurred()) + ExpectWithOffset(1, c.NextProtos).To(Equal([]string{nextProtoH3Draft32})) + } + It("uses the quic.Config to start the QUIC server", func() { conf := &quic.Config{HandshakeTimeout: time.Nanosecond} var receivedConf *quic.Config @@ -508,8 +521,11 @@ var _ = Describe("Server", func() { Expect(receivedConf).To(Equal(conf)) }) - It("replaces the ALPN token to the tls.Config", func() { - tlsConf := &tls.Config{NextProtos: []string{"foo", "bar"}} + It("sets the GetConfigForClient and replaces the ALPN token to the tls.Config, if the GetConfigForClient callback is not set", func() { + tlsConf := &tls.Config{ + ClientAuth: tls.RequireAndVerifyClientCert, + NextProtos: []string{"foo", "bar"}, + } var receivedConf *tls.Config quicListenAddr = func(addr string, tlsConf *tls.Config, _ *quic.Config) (quic.EarlyListener, error) { receivedConf = tlsConf @@ -517,25 +533,35 @@ var _ = Describe("Server", func() { } s.TLSConfig = tlsConf Expect(s.ListenAndServe()).To(HaveOccurred()) - Expect(receivedConf.NextProtos).To(Equal([]string{nextProtoH3})) + Expect(receivedConf.NextProtos).To(BeEmpty()) + Expect(receivedConf.ClientAuth).To(BeZero()) // make sure the original tls.Config was not modified Expect(tlsConf.NextProtos).To(Equal([]string{"foo", "bar"})) + // make sure that the config returned from the GetConfigForClient callback sets the fields of the original config + conf, err := receivedConf.GetConfigForClient(&tls.ClientHelloInfo{}) + Expect(err).ToNot(HaveOccurred()) + Expect(conf.ClientAuth).To(Equal(tls.RequireAndVerifyClientCert)) + checkGetConfigForClientVersions(receivedConf) }) - It("uses the ALPN token if no tls.Config is given", func() { + It("sets the GetConfigForClient callback if no tls.Config is given", func() { var receivedConf *tls.Config quicListenAddr = func(addr string, tlsConf *tls.Config, _ *quic.Config) (quic.EarlyListener, error) { receivedConf = tlsConf return nil, errors.New("listen err") } Expect(s.ListenAndServe()).To(HaveOccurred()) - Expect(receivedConf.NextProtos).To(Equal([]string{nextProtoH3})) + Expect(receivedConf).ToNot(BeNil()) + checkGetConfigForClientVersions(receivedConf) }) It("sets the ALPN for tls.Configs returned by the tls.GetConfigForClient", func() { tlsConf := &tls.Config{ GetConfigForClient: func(ch *tls.ClientHelloInfo) (*tls.Config, error) { - return &tls.Config{NextProtos: []string{"foo", "bar"}}, nil + return &tls.Config{ + ClientAuth: tls.RequireAndVerifyClientCert, + NextProtos: []string{"foo", "bar"}, + }, nil }, } @@ -546,14 +572,15 @@ var _ = Describe("Server", func() { } s.TLSConfig = tlsConf Expect(s.ListenAndServe()).To(HaveOccurred()) - // check that the config used by QUIC uses the h3 ALPN - conf, err := receivedConf.GetConfigForClient(&tls.ClientHelloInfo{}) - Expect(err).ToNot(HaveOccurred()) - Expect(conf.NextProtos).To(Equal([]string{nextProtoH3})) // check that the original config was not modified - conf, err = tlsConf.GetConfigForClient(&tls.ClientHelloInfo{}) + conf, err := tlsConf.GetConfigForClient(&tls.ClientHelloInfo{}) Expect(err).ToNot(HaveOccurred()) Expect(conf.NextProtos).To(Equal([]string{"foo", "bar"})) + // check that the config returned by the GetConfigForClient callback uses the returned config + conf, err = receivedConf.GetConfigForClient(&tls.ClientHelloInfo{}) + Expect(err).ToNot(HaveOccurred()) + Expect(conf.ClientAuth).To(Equal(tls.RequireAndVerifyClientCert)) + checkGetConfigForClientVersions(receivedConf) }) It("sets the ALPN for tls.Configs returned by the tls.GetConfigForClient, if it returns a static tls.Config", func() { @@ -571,14 +598,11 @@ var _ = Describe("Server", func() { } s.TLSConfig = tlsConf Expect(s.ListenAndServe()).To(HaveOccurred()) - // check that the config used by QUIC uses the h3 ALPN - conf, err := receivedConf.GetConfigForClient(&tls.ClientHelloInfo{}) - Expect(err).ToNot(HaveOccurred()) - Expect(conf.NextProtos).To(Equal([]string{nextProtoH3})) // check that the original config was not modified - conf, err = tlsConf.GetConfigForClient(&tls.ClientHelloInfo{}) + conf, err := tlsConf.GetConfigForClient(&tls.ClientHelloInfo{}) Expect(err).ToNot(HaveOccurred()) Expect(conf.NextProtos).To(Equal([]string{"foo", "bar"})) + checkGetConfigForClientVersions(receivedConf) }) }) diff --git a/internal/handshake/crypto_setup.go b/internal/handshake/crypto_setup.go index 3107fffb..91f77a78 100644 --- a/internal/handshake/crypto_setup.go +++ b/internal/handshake/crypto_setup.go @@ -62,25 +62,30 @@ const clientSessionStateRevision = 3 type conn struct { localAddr, remoteAddr net.Addr + version protocol.VersionNumber } -func newConn(local, remote net.Addr) net.Conn { +var _ ConnWithVersion = &conn{} + +func newConn(local, remote net.Addr, version protocol.VersionNumber) ConnWithVersion { return &conn{ localAddr: local, remoteAddr: remote, + version: version, } } var _ net.Conn = &conn{} -func (c *conn) Read([]byte) (int, error) { return 0, nil } -func (c *conn) Write([]byte) (int, error) { return 0, nil } -func (c *conn) Close() error { return nil } -func (c *conn) RemoteAddr() net.Addr { return c.remoteAddr } -func (c *conn) LocalAddr() net.Addr { return c.localAddr } -func (c *conn) SetReadDeadline(time.Time) error { return nil } -func (c *conn) SetWriteDeadline(time.Time) error { return nil } -func (c *conn) SetDeadline(time.Time) error { return nil } +func (c *conn) Read([]byte) (int, error) { return 0, nil } +func (c *conn) Write([]byte) (int, error) { return 0, nil } +func (c *conn) Close() error { return nil } +func (c *conn) RemoteAddr() net.Addr { return c.remoteAddr } +func (c *conn) LocalAddr() net.Addr { return c.localAddr } +func (c *conn) SetReadDeadline(time.Time) error { return nil } +func (c *conn) SetWriteDeadline(time.Time) error { return nil } +func (c *conn) SetDeadline(time.Time) error { return nil } +func (c *conn) GetQUICVersion() protocol.VersionNumber { return c.version } type cryptoSetup struct { tlsConf *tls.Config @@ -156,6 +161,7 @@ func NewCryptoSetupClient( rttStats *utils.RTTStats, tracer logging.ConnectionTracer, logger utils.Logger, + version protocol.VersionNumber, ) (CryptoSetup, <-chan *wire.TransportParameters /* ClientHello written. Receive nil for non-0-RTT */) { cs, clientHelloWritten := newCryptoSetup( initialStream, @@ -170,7 +176,7 @@ func NewCryptoSetupClient( logger, protocol.PerspectiveClient, ) - cs.conn = qtls.Client(newConn(localAddr, remoteAddr), cs.tlsConf, cs.extraConf) + cs.conn = qtls.Client(newConn(localAddr, remoteAddr, version), cs.tlsConf, cs.extraConf) return cs, clientHelloWritten } @@ -188,6 +194,7 @@ func NewCryptoSetupServer( rttStats *utils.RTTStats, tracer logging.ConnectionTracer, logger utils.Logger, + version protocol.VersionNumber, ) CryptoSetup { cs, _ := newCryptoSetup( initialStream, @@ -202,7 +209,7 @@ func NewCryptoSetupServer( logger, protocol.PerspectiveServer, ) - cs.conn = qtls.Server(newConn(localAddr, remoteAddr), cs.tlsConf, cs.extraConf) + cs.conn = qtls.Server(newConn(localAddr, remoteAddr, version), cs.tlsConf, cs.extraConf) return cs } diff --git a/internal/handshake/crypto_setup_test.go b/internal/handshake/crypto_setup_test.go index 6eebbe8a..8a663289 100644 --- a/internal/handshake/crypto_setup_test.go +++ b/internal/handshake/crypto_setup_test.go @@ -99,6 +99,7 @@ var _ = Describe("Crypto Setup TLS", func() { &utils.RTTStats{}, nil, utils.DefaultLogger.WithPrefix("server"), + protocol.VersionTLS, ) done := make(chan struct{}) @@ -139,6 +140,7 @@ var _ = Describe("Crypto Setup TLS", func() { &utils.RTTStats{}, nil, utils.DefaultLogger.WithPrefix("server"), + protocol.VersionTLS, ) done := make(chan struct{}) @@ -182,6 +184,7 @@ var _ = Describe("Crypto Setup TLS", func() { &utils.RTTStats{}, nil, utils.DefaultLogger.WithPrefix("server"), + protocol.VersionTLS, ) done := make(chan struct{}) @@ -218,6 +221,7 @@ var _ = Describe("Crypto Setup TLS", func() { &utils.RTTStats{}, nil, utils.DefaultLogger.WithPrefix("server"), + protocol.VersionTLS, ) done := make(chan struct{}) @@ -334,6 +338,7 @@ var _ = Describe("Crypto Setup TLS", func() { clientRTTStats, nil, utils.DefaultLogger.WithPrefix("client"), + protocol.VersionTLS, ) var sHandshakeComplete bool @@ -360,6 +365,7 @@ var _ = Describe("Crypto Setup TLS", func() { serverRTTStats, nil, utils.DefaultLogger.WithPrefix("server"), + protocol.VersionTLS, ) handshake(client, cChunkChan, server, sChunkChan) @@ -429,6 +435,7 @@ var _ = Describe("Crypto Setup TLS", func() { &utils.RTTStats{}, nil, utils.DefaultLogger.WithPrefix("client"), + protocol.VersionTLS, ) done := make(chan struct{}) @@ -471,6 +478,7 @@ var _ = Describe("Crypto Setup TLS", func() { &utils.RTTStats{}, nil, utils.DefaultLogger.WithPrefix("client"), + protocol.VersionTLS, ) sChunkChan, sInitialStream, sHandshakeStream := initStreams() @@ -495,6 +503,7 @@ var _ = Describe("Crypto Setup TLS", func() { &utils.RTTStats{}, nil, utils.DefaultLogger.WithPrefix("server"), + protocol.VersionTLS, ) done := make(chan struct{}) @@ -528,6 +537,7 @@ var _ = Describe("Crypto Setup TLS", func() { &utils.RTTStats{}, nil, utils.DefaultLogger.WithPrefix("client"), + protocol.VersionTLS, ) sChunkChan, sInitialStream, sHandshakeStream := initStreams() @@ -548,6 +558,7 @@ var _ = Describe("Crypto Setup TLS", func() { &utils.RTTStats{}, nil, utils.DefaultLogger.WithPrefix("server"), + protocol.VersionTLS, ) done := make(chan struct{}) @@ -588,6 +599,7 @@ var _ = Describe("Crypto Setup TLS", func() { &utils.RTTStats{}, nil, utils.DefaultLogger.WithPrefix("client"), + protocol.VersionTLS, ) sChunkChan, sInitialStream, sHandshakeStream := initStreams() @@ -608,6 +620,7 @@ var _ = Describe("Crypto Setup TLS", func() { &utils.RTTStats{}, nil, utils.DefaultLogger.WithPrefix("server"), + protocol.VersionTLS, ) done := make(chan struct{}) diff --git a/internal/handshake/interface.go b/internal/handshake/interface.go index b64cd015..90b7238a 100644 --- a/internal/handshake/interface.go +++ b/internal/handshake/interface.go @@ -3,6 +3,7 @@ package handshake import ( "errors" "io" + "net" "time" "github.com/lucas-clemente/quic-go/internal/protocol" @@ -90,3 +91,10 @@ type CryptoSetup interface { Get0RTTSealer() (LongHeaderSealer, error) Get1RTTSealer() (ShortHeaderSealer, error) } + +// ConnWithVersion is the connection used in the ClientHelloInfo. +// It can be used to determine the QUIC version in use. +type ConnWithVersion interface { + net.Conn + GetQUICVersion() protocol.VersionNumber +} diff --git a/session.go b/session.go index 5bede9e6..2275c834 100644 --- a/session.go +++ b/session.go @@ -325,6 +325,7 @@ var newSession = func( s.rttStats, tracer, logger, + s.version, ) s.cryptoStreamHandler = cs s.packer = newPacketPacker( @@ -442,6 +443,7 @@ var newClientSession = func( s.rttStats, tracer, logger, + s.version, ) s.clientHelloWritten = clientHelloWritten s.cryptoStreamHandler = cs