From d7334c16e7d03fdf692c8e2a3d110687b90a1466 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Thu, 31 Aug 2023 13:33:40 +0700 Subject: [PATCH] move the DisableVersionNegotiationPackets flag to the Transport (#4047) * move the DisableVersionNegotiationPackets flag to the Transport * add an integration test for DisableVersionNegotiationPackets --- config.go | 41 ++++++++-------- config_test.go | 1 - .../versionnegotiation/handshake_test.go | 44 ++++++++++++++++- interface.go | 4 -- server.go | 41 ++++++++-------- server_test.go | 2 +- transport.go | 47 ++++++++++--------- 7 files changed, 112 insertions(+), 68 deletions(-) diff --git a/config.go b/config.go index 59df4cfd..59a4d922 100644 --- a/config.go +++ b/config.go @@ -110,26 +110,25 @@ func populateConfig(config *Config) *Config { } return &Config{ - GetConfigForClient: config.GetConfigForClient, - Versions: versions, - HandshakeIdleTimeout: handshakeIdleTimeout, - MaxIdleTimeout: idleTimeout, - MaxTokenAge: config.MaxTokenAge, - MaxRetryTokenAge: config.MaxRetryTokenAge, - RequireAddressValidation: config.RequireAddressValidation, - KeepAlivePeriod: config.KeepAlivePeriod, - InitialStreamReceiveWindow: initialStreamReceiveWindow, - MaxStreamReceiveWindow: maxStreamReceiveWindow, - InitialConnectionReceiveWindow: initialConnectionReceiveWindow, - MaxConnectionReceiveWindow: maxConnectionReceiveWindow, - AllowConnectionWindowIncrease: config.AllowConnectionWindowIncrease, - MaxIncomingStreams: maxIncomingStreams, - MaxIncomingUniStreams: maxIncomingUniStreams, - TokenStore: config.TokenStore, - EnableDatagrams: config.EnableDatagrams, - DisablePathMTUDiscovery: config.DisablePathMTUDiscovery, - DisableVersionNegotiationPackets: config.DisableVersionNegotiationPackets, - Allow0RTT: config.Allow0RTT, - Tracer: config.Tracer, + GetConfigForClient: config.GetConfigForClient, + Versions: versions, + HandshakeIdleTimeout: handshakeIdleTimeout, + MaxIdleTimeout: idleTimeout, + MaxTokenAge: config.MaxTokenAge, + MaxRetryTokenAge: config.MaxRetryTokenAge, + RequireAddressValidation: config.RequireAddressValidation, + KeepAlivePeriod: config.KeepAlivePeriod, + InitialStreamReceiveWindow: initialStreamReceiveWindow, + MaxStreamReceiveWindow: maxStreamReceiveWindow, + InitialConnectionReceiveWindow: initialConnectionReceiveWindow, + MaxConnectionReceiveWindow: maxConnectionReceiveWindow, + AllowConnectionWindowIncrease: config.AllowConnectionWindowIncrease, + MaxIncomingStreams: maxIncomingStreams, + MaxIncomingUniStreams: maxIncomingUniStreams, + TokenStore: config.TokenStore, + EnableDatagrams: config.EnableDatagrams, + DisablePathMTUDiscovery: config.DisablePathMTUDiscovery, + Allow0RTT: config.Allow0RTT, + Tracer: config.Tracer, } } diff --git a/config_test.go b/config_test.go index 1eca3d5d..7208b4ad 100644 --- a/config_test.go +++ b/config_test.go @@ -192,7 +192,6 @@ var _ = Describe("Config", func() { Expect(c.MaxConnectionReceiveWindow).To(BeEquivalentTo(protocol.DefaultMaxReceiveConnectionFlowControlWindow)) Expect(c.MaxIncomingStreams).To(BeEquivalentTo(protocol.DefaultMaxIncomingStreams)) Expect(c.MaxIncomingUniStreams).To(BeEquivalentTo(protocol.DefaultMaxIncomingUniStreams)) - Expect(c.DisableVersionNegotiationPackets).To(BeFalse()) Expect(c.DisablePathMTUDiscovery).To(BeFalse()) Expect(c.GetConfigForClient).To(BeNil()) }) diff --git a/integrationtests/versionnegotiation/handshake_test.go b/integrationtests/versionnegotiation/handshake_test.go index 965700c1..a079f6e1 100644 --- a/integrationtests/versionnegotiation/handshake_test.go +++ b/integrationtests/versionnegotiation/handshake_test.go @@ -3,8 +3,10 @@ package versionnegotiation import ( "context" "crypto/tls" + "errors" "fmt" "net" + "time" "github.com/quic-go/quic-go" "github.com/quic-go/quic-go/integrationtests/tools/israce" @@ -113,7 +115,7 @@ var _ = Describe("Handshake tests", func() { It("when the client supports more versions than the server supports", func() { expectedVersion := protocol.SupportedVersions[0] - // the server doesn't support the highest supported version, which is the first one the client will try + // The server doesn't support the highest supported version, which is the first one the client will try, // but it supports a bunch of versions that the client doesn't speak serverTracer := &versionNegotiationTracer{} serverConfig := &quic.Config{} @@ -147,5 +149,45 @@ var _ = Describe("Handshake tests", func() { Expect(serverTracer.serverVersions).To(Equal(serverConfig.Versions)) Expect(serverTracer.clientVersions).To(BeEmpty()) }) + + It("fails if the server disables version negotiation", func() { + // The server doesn't support the highest supported version, which is the first one the client will try, + // but it supports a bunch of versions that the client doesn't speak + serverTracer := &versionNegotiationTracer{} + serverConfig := &quic.Config{} + serverConfig.Versions = supportedVersions + serverConfig.Tracer = func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer { + return serverTracer + } + conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1)}) + Expect(err).ToNot(HaveOccurred()) + tr := &quic.Transport{ + Conn: conn, + DisableVersionNegotiationPackets: true, + } + ln, err := tr.Listen(getTLSConfig(), serverConfig) + Expect(err).ToNot(HaveOccurred()) + defer ln.Close() + + clientVersions := []protocol.VersionNumber{7, 8, 9, protocol.SupportedVersions[0], 10} + clientTracer := &versionNegotiationTracer{} + _, err = quic.DialAddr( + context.Background(), + fmt.Sprintf("localhost:%d", conn.LocalAddr().(*net.UDPAddr).Port), + getTLSClientConfig(), + maybeAddQLOGTracer(&quic.Config{ + Versions: clientVersions, + Tracer: func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer { + return clientTracer + }, + HandshakeIdleTimeout: 100 * time.Millisecond, + }), + ) + Expect(err).To(HaveOccurred()) + var nerr net.Error + Expect(errors.As(err, &nerr)).To(BeTrue()) + Expect(nerr.Timeout()).To(BeTrue()) + Expect(clientTracer.receivedVersionNegotiation).To(BeFalse()) + }) } }) diff --git a/interface.go b/interface.go index c3fb2b10..5d5ab5b0 100644 --- a/interface.go +++ b/interface.go @@ -322,10 +322,6 @@ type Config struct { // Path MTU discovery is only available on systems that allow setting of the Don't Fragment (DF) bit. // If unavailable or disabled, packets will be at most 1252 (IPv4) / 1232 (IPv6) bytes in size. DisablePathMTUDiscovery bool - // DisableVersionNegotiationPackets disables the sending of Version Negotiation packets. - // This can be useful if version information is exchanged out-of-band. - // It has no effect for a client. - DisableVersionNegotiationPackets bool // Allow0RTT allows the application to decide if a 0-RTT connection attempt should be accepted. // Only valid for the server. Allow0RTT bool diff --git a/server.go b/server.go index c06228c9..14cc6f82 100644 --- a/server.go +++ b/server.go @@ -59,7 +59,8 @@ type zeroRTTQueue struct { type baseServer struct { mutex sync.Mutex - acceptEarlyConns bool + disableVersionNegotiation bool + acceptEarlyConns bool tlsConf *tls.Config config *Config @@ -226,6 +227,7 @@ func newServer( config *Config, tracer logging.Tracer, onClose func(), + disableVersionNegotiation bool, acceptEarly bool, ) (*baseServer, error) { tokenGenerator, err := handshake.NewTokenGenerator(rand.Reader) @@ -233,23 +235,24 @@ func newServer( return nil, err } s := &baseServer{ - conn: conn, - tlsConf: tlsConf, - config: config, - tokenGenerator: tokenGenerator, - connIDGenerator: connIDGenerator, - connHandler: connHandler, - connQueue: make(chan quicConn), - errorChan: make(chan struct{}), - running: make(chan struct{}), - receivedPackets: make(chan receivedPacket, protocol.MaxServerUnprocessedPackets), - versionNegotiationQueue: make(chan receivedPacket, 4), - invalidTokenQueue: make(chan receivedPacket, 4), - newConn: newConnection, - tracer: tracer, - logger: utils.DefaultLogger.WithPrefix("server"), - acceptEarlyConns: acceptEarly, - onClose: onClose, + conn: conn, + tlsConf: tlsConf, + config: config, + tokenGenerator: tokenGenerator, + connIDGenerator: connIDGenerator, + connHandler: connHandler, + connQueue: make(chan quicConn), + errorChan: make(chan struct{}), + running: make(chan struct{}), + receivedPackets: make(chan receivedPacket, protocol.MaxServerUnprocessedPackets), + versionNegotiationQueue: make(chan receivedPacket, 4), + invalidTokenQueue: make(chan receivedPacket, 4), + newConn: newConnection, + tracer: tracer, + logger: utils.DefaultLogger.WithPrefix("server"), + acceptEarlyConns: acceptEarly, + disableVersionNegotiation: disableVersionNegotiation, + onClose: onClose, } if acceptEarly { s.zeroRTTQueues = map[protocol.ConnectionID]*zeroRTTQueue{} @@ -383,7 +386,7 @@ func (s *baseServer) handlePacketImpl(p receivedPacket) bool /* is the buffer st } // send a Version Negotiation Packet if the client is speaking a different protocol version if !protocol.IsSupportedVersion(s.config.Versions, v) { - if s.config.DisableVersionNegotiationPackets { + if s.disableVersionNegotiation { return false } diff --git a/server_test.go b/server_test.go index 4705225a..e959aee5 100644 --- a/server_test.go +++ b/server_test.go @@ -357,7 +357,7 @@ var _ = Describe("Server", func() { }) It("doesn't send a Version Negotiation packets if sending them is disabled", func() { - serv.config.DisableVersionNegotiationPackets = true + serv.disableVersionNegotiation = true srcConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5}) destConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6}) packet := getPacket(&wire.Header{ diff --git a/transport.go b/transport.go index d8da9b1a..9f93cbb9 100644 --- a/transport.go +++ b/transport.go @@ -57,6 +57,11 @@ type Transport struct { // See section 10.3 of RFC 9000 for details. StatelessResetKey *StatelessResetKey + // DisableVersionNegotiationPackets disables the sending of Version Negotiation packets. + // This can be useful if version information is exchanged out-of-band. + // It has no effect for clients. + DisableVersionNegotiationPackets bool + // A Tracer traces events that don't belong to a single QUIC connection. Tracer logging.Tracer @@ -95,28 +100,10 @@ type Transport struct { // There can only be a single listener on any net.PacketConn. // Listen may only be called again after the current Listener was closed. func (t *Transport) Listen(tlsConf *tls.Config, conf *Config) (*Listener, error) { - if tlsConf == nil { - return nil, errors.New("quic: tls.Config not set") - } - if err := validateConfig(conf); err != nil { - return nil, err - } - - t.mutex.Lock() - defer t.mutex.Unlock() - - if t.server != nil { - return nil, errListenerAlreadySet - } - conf = populateServerConfig(conf) - if err := t.init(false); err != nil { - return nil, err - } - s, err := newServer(t.conn, t.handlerMap, t.connIDGenerator, tlsConf, conf, t.Tracer, t.closeServer, false) + s, err := t.createServer(tlsConf, conf, false) if err != nil { return nil, err } - t.server = s return &Listener{baseServer: s}, nil } @@ -124,6 +111,14 @@ func (t *Transport) Listen(tlsConf *tls.Config, conf *Config) (*Listener, error) // There can only be a single listener on any net.PacketConn. // Listen may only be called again after the current Listener was closed. func (t *Transport) ListenEarly(tlsConf *tls.Config, conf *Config) (*EarlyListener, error) { + s, err := t.createServer(tlsConf, conf, true) + if err != nil { + return nil, err + } + return &EarlyListener{baseServer: s}, nil +} + +func (t *Transport) createServer(tlsConf *tls.Config, conf *Config, allow0RTT bool) (*baseServer, error) { if tlsConf == nil { return nil, errors.New("quic: tls.Config not set") } @@ -141,12 +136,22 @@ func (t *Transport) ListenEarly(tlsConf *tls.Config, conf *Config) (*EarlyListen if err := t.init(false); err != nil { return nil, err } - s, err := newServer(t.conn, t.handlerMap, t.connIDGenerator, tlsConf, conf, t.Tracer, t.closeServer, true) + s, err := newServer( + t.conn, + t.handlerMap, + t.connIDGenerator, + tlsConf, + conf, + t.Tracer, + t.closeServer, + t.DisableVersionNegotiationPackets, + allow0RTT, + ) if err != nil { return nil, err } t.server = s - return &EarlyListener{baseServer: s}, nil + return s, nil } // Dial dials a new connection to a remote host (not using 0-RTT).