From a2cf43d75cdcd3483b5c8ec42c4b7390f6ff14d0 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Mon, 22 Jan 2024 21:24:07 -0800 Subject: [PATCH] remove the RequireAddressValidation callback from the Config (#4253) --- config.go | 12 ---- config_test.go | 29 +------- connection_test.go | 2 +- integrationtests/self/handshake_drop_test.go | 71 +++++++++++--------- integrationtests/self/handshake_rtt_test.go | 13 +++- integrationtests/self/handshake_test.go | 14 +++- integrationtests/self/mitm_test.go | 35 ++++++---- integrationtests/self/zero_rtt_test.go | 19 ++++-- interface.go | 5 -- interop/http09/server.go | 7 +- interop/server/main.go | 13 ++-- server.go | 2 +- server_test.go | 31 ++++----- transport.go | 2 +- 14 files changed, 127 insertions(+), 128 deletions(-) diff --git a/config.go b/config.go index 49b9fc3f..ee032e6e 100644 --- a/config.go +++ b/config.go @@ -2,7 +2,6 @@ package quic import ( "fmt" - "net" "time" "github.com/quic-go/quic-go/internal/protocol" @@ -49,16 +48,6 @@ func validateConfig(config *Config) error { return nil } -// populateServerConfig populates fields in the quic.Config with their default values, if none are set -// it may be called with nil -func populateServerConfig(config *Config) *Config { - config = populateConfig(config) - if config.RequireAddressValidation == nil { - config.RequireAddressValidation = func(net.Addr) bool { return false } - } - return config -} - // populateConfig populates fields in the quic.Config with their default values, if none are set // it may be called with nil func populateConfig(config *Config) *Config { @@ -111,7 +100,6 @@ func populateConfig(config *Config) *Config { Versions: versions, HandshakeIdleTimeout: handshakeIdleTimeout, MaxIdleTimeout: idleTimeout, - RequireAddressValidation: config.RequireAddressValidation, KeepAlivePeriod: config.KeepAlivePeriod, InitialStreamReceiveWindow: initialStreamReceiveWindow, MaxStreamReceiveWindow: maxStreamReceiveWindow, diff --git a/config_test.go b/config_test.go index e40c1cfc..500eb1de 100644 --- a/config_test.go +++ b/config_test.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "net" "reflect" "time" @@ -23,7 +22,7 @@ var _ = Describe("Config", func() { }) It("validates a config with normal values", func() { - conf := populateServerConfig(&Config{ + conf := populateConfig(&Config{ MaxIncomingStreams: 5, MaxStreamReceiveWindow: 10, }) @@ -118,19 +117,16 @@ var _ = Describe("Config", func() { Context("cloning", func() { It("clones function fields", func() { - var calledAddrValidation, calledAllowConnectionWindowIncrease, calledTracer bool + var calledAllowConnectionWindowIncrease, calledTracer bool c1 := &Config{ GetConfigForClient: func(info *ClientHelloInfo) (*Config, error) { return nil, errors.New("nope") }, AllowConnectionWindowIncrease: func(Connection, uint64) bool { calledAllowConnectionWindowIncrease = true; return true }, - RequireAddressValidation: func(net.Addr) bool { calledAddrValidation = true; return true }, Tracer: func(context.Context, logging.Perspective, ConnectionID) *logging.ConnectionTracer { calledTracer = true return nil }, } c2 := c1.Clone() - c2.RequireAddressValidation(&net.UDPAddr{}) - Expect(calledAddrValidation).To(BeTrue()) c2.AllowConnectionWindowIncrease(nil, 1234) Expect(calledAllowConnectionWindowIncrease).To(BeTrue()) _, err := c2.GetConfigForClient(&ClientHelloInfo{}) @@ -145,29 +141,15 @@ var _ = Describe("Config", func() { }) It("returns a copy", func() { - c1 := &Config{ - MaxIncomingStreams: 100, - RequireAddressValidation: func(net.Addr) bool { return true }, - } + c1 := &Config{MaxIncomingStreams: 100} c2 := c1.Clone() c2.MaxIncomingStreams = 200 - c2.RequireAddressValidation = func(net.Addr) bool { return false } Expect(c1.MaxIncomingStreams).To(BeEquivalentTo(100)) - Expect(c1.RequireAddressValidation(&net.UDPAddr{})).To(BeTrue()) }) }) Context("populating", func() { - It("populates function fields", func() { - var calledAddrValidation bool - c1 := &Config{} - c1.RequireAddressValidation = func(net.Addr) bool { calledAddrValidation = true; return true } - c2 := populateConfig(c1) - c2.RequireAddressValidation(&net.UDPAddr{}) - Expect(calledAddrValidation).To(BeTrue()) - }) - It("copies non-function fields", func() { c := configWithNonZeroNonFunctionFields() Expect(populateConfig(c)).To(Equal(c)) @@ -186,10 +168,5 @@ var _ = Describe("Config", func() { Expect(c.DisablePathMTUDiscovery).To(BeFalse()) Expect(c.GetConfigForClient).To(BeNil()) }) - - It("populates empty fields with default values, for the server", func() { - c := populateServerConfig(&Config{}) - Expect(c.RequireAddressValidation).ToNot(BeNil()) - }) }) }) diff --git a/connection_test.go b/connection_test.go index 4d24bedc..3bdf38f0 100644 --- a/connection_test.go +++ b/connection_test.go @@ -118,7 +118,7 @@ var _ = Describe("Connection", func() { srcConnID, &protocol.DefaultConnectionIDGenerator{}, protocol.StatelessResetToken{}, - populateServerConfig(&Config{DisablePathMTUDiscovery: true}), + populateConfig(&Config{DisablePathMTUDiscovery: true}), &tls.Config{}, tokenGenerator, false, diff --git a/integrationtests/self/handshake_drop_test.go b/integrationtests/self/handshake_drop_test.go index a0f23479..894e4788 100644 --- a/integrationtests/self/handshake_drop_test.go +++ b/integrationtests/self/handshake_drop_test.go @@ -26,23 +26,17 @@ var directions = []quicproxy.Direction{quicproxy.DirectionIncoming, quicproxy.Di type applicationProtocol struct { name string - run func() + run func(ln *quic.Listener, port int) } var _ = Describe("Handshake drop tests", func() { - var ( - proxy *quicproxy.QuicProxy - ln *quic.Listener - ) - data := GeneratePRData(5000) const timeout = 2 * time.Minute - startListenerAndProxy := func(dropCallback quicproxy.DropCallback, doRetry bool, longCertChain bool) { + startListenerAndProxy := func(dropCallback quicproxy.DropCallback, doRetry bool, longCertChain bool) (ln *quic.Listener, proxyPort int, closeFn func()) { conf := getQuicConfig(&quic.Config{ - MaxIdleTimeout: timeout, - HandshakeIdleTimeout: timeout, - RequireAddressValidation: func(net.Addr) bool { return doRetry }, + MaxIdleTimeout: timeout, + HandshakeIdleTimeout: timeout, }) var tlsConf *tls.Config if longCertChain { @@ -50,11 +44,18 @@ var _ = Describe("Handshake drop tests", func() { } else { tlsConf = getTLSConfig() } - var err error - ln, err = quic.ListenAddr("localhost:0", tlsConf, conf) + laddr, err := net.ResolveUDPAddr("udp", "localhost:0") + Expect(err).ToNot(HaveOccurred()) + conn, err := net.ListenUDP("udp", laddr) + Expect(err).ToNot(HaveOccurred()) + tr := &quic.Transport{Conn: conn} + if doRetry { + tr.MaxUnvalidatedHandshakes = -1 + } + ln, err = tr.Listen(tlsConf, conf) Expect(err).ToNot(HaveOccurred()) serverPort := ln.Addr().(*net.UDPAddr).Port - proxy, err = quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ + proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ RemoteAddr: fmt.Sprintf("localhost:%d", serverPort), DropPacket: dropCallback, DelayPacket: func(dir quicproxy.Direction, packet []byte) time.Duration { @@ -62,11 +63,18 @@ var _ = Describe("Handshake drop tests", func() { }, }) Expect(err).ToNot(HaveOccurred()) + + return ln, proxy.LocalPort(), func() { + ln.Close() + tr.Close() + conn.Close() + proxy.Close() + } } clientSpeaksFirst := &applicationProtocol{ name: "client speaks first", - run: func() { + run: func(ln *quic.Listener, port int) { serverConnChan := make(chan quic.Connection) go func() { defer GinkgoRecover() @@ -82,7 +90,7 @@ var _ = Describe("Handshake drop tests", func() { }() conn, err := quic.DialAddr( context.Background(), - fmt.Sprintf("localhost:%d", proxy.LocalPort()), + fmt.Sprintf("localhost:%d", port), getTLSClientConfig(), getQuicConfig(&quic.Config{ MaxIdleTimeout: timeout, @@ -105,7 +113,7 @@ var _ = Describe("Handshake drop tests", func() { serverSpeaksFirst := &applicationProtocol{ name: "server speaks first", - run: func() { + run: func(ln *quic.Listener, port int) { serverConnChan := make(chan quic.Connection) go func() { defer GinkgoRecover() @@ -120,7 +128,7 @@ var _ = Describe("Handshake drop tests", func() { }() conn, err := quic.DialAddr( context.Background(), - fmt.Sprintf("localhost:%d", proxy.LocalPort()), + fmt.Sprintf("localhost:%d", port), getTLSClientConfig(), getQuicConfig(&quic.Config{ MaxIdleTimeout: timeout, @@ -143,7 +151,7 @@ var _ = Describe("Handshake drop tests", func() { nobodySpeaks := &applicationProtocol{ name: "nobody speaks", - run: func() { + run: func(ln *quic.Listener, port int) { serverConnChan := make(chan quic.Connection) go func() { defer GinkgoRecover() @@ -153,7 +161,7 @@ var _ = Describe("Handshake drop tests", func() { }() conn, err := quic.DialAddr( context.Background(), - fmt.Sprintf("localhost:%d", proxy.LocalPort()), + fmt.Sprintf("localhost:%d", port), getTLSClientConfig(), getQuicConfig(&quic.Config{ MaxIdleTimeout: timeout, @@ -169,11 +177,6 @@ var _ = Describe("Handshake drop tests", func() { }, } - AfterEach(func() { - Expect(ln.Close()).To(Succeed()) - Expect(proxy.Close()).To(Succeed()) - }) - for _, d := range directions { direction := d @@ -195,7 +198,7 @@ var _ = Describe("Handshake drop tests", func() { Context(app.name, func() { It(fmt.Sprintf("establishes a connection when the first packet is lost in %s direction", direction), func() { var incoming, outgoing atomic.Int32 - startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool { + ln, proxyPort, closeFn := startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool { var p int32 //nolint:exhaustive switch d { @@ -206,12 +209,13 @@ var _ = Describe("Handshake drop tests", func() { } return p == 1 && d.Is(direction) }, doRetry, longCertChain) - app.run() + defer closeFn() + app.run(ln, proxyPort) }) It(fmt.Sprintf("establishes a connection when the second packet is lost in %s direction", direction), func() { var incoming, outgoing atomic.Int32 - startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool { + ln, proxyPort, closeFn := startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool { var p int32 //nolint:exhaustive switch d { @@ -222,7 +226,8 @@ var _ = Describe("Handshake drop tests", func() { } return p == 2 && d.Is(direction) }, doRetry, longCertChain) - app.run() + defer closeFn() + app.run(ln, proxyPort) }) It(fmt.Sprintf("establishes a connection when 1/3 of the packets are lost in %s direction", direction), func() { @@ -230,7 +235,7 @@ var _ = Describe("Handshake drop tests", func() { var mx sync.Mutex var incoming, outgoing int - startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool { + ln, proxyPort, closeFn := startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool { drop := mrand.Int63n(int64(3)) == 0 mx.Lock() @@ -260,7 +265,8 @@ var _ = Describe("Handshake drop tests", func() { } return drop }, doRetry, longCertChain) - app.run() + defer closeFn() + app.run(ln, proxyPort) }) }) } @@ -281,13 +287,14 @@ var _ = Describe("Handshake drop tests", func() { uint64(27+31*(1000+mrand.Int63()/31)) % quicvarint.Max: b, } - startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool { + ln, proxyPort, closeFn := startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool { if d == quicproxy.DirectionOutgoing { return false } return mrand.Intn(3) == 0 }, false, false) - clientSpeaksFirst.run() + defer closeFn() + clientSpeaksFirst.run(ln, proxyPort) }) } }) diff --git a/integrationtests/self/handshake_rtt_test.go b/integrationtests/self/handshake_rtt_test.go index 36ea7c78..40e541ab 100644 --- a/integrationtests/self/handshake_rtt_test.go +++ b/integrationtests/self/handshake_rtt_test.go @@ -55,8 +55,17 @@ var _ = Describe("Handshake RTT tests", func() { // 1 RTT for verifying the source address // 1 RTT for the TLS handshake It("is forward-secure after 2 RTTs", func() { - serverConfig.RequireAddressValidation = func(net.Addr) bool { return true } - ln, err := quic.ListenAddr("localhost:0", serverTLSConfig, serverConfig) + laddr, err := net.ResolveUDPAddr("udp", "localhost:0") + Expect(err).ToNot(HaveOccurred()) + udpConn, err := net.ListenUDP("udp", laddr) + Expect(err).ToNot(HaveOccurred()) + defer udpConn.Close() + tr := &quic.Transport{ + Conn: udpConn, + MaxUnvalidatedHandshakes: -1, + } + defer tr.Close() + ln, err := tr.Listen(serverTLSConfig, serverConfig) Expect(err).ToNot(HaveOccurred()) defer ln.Close() diff --git a/integrationtests/self/handshake_test.go b/integrationtests/self/handshake_test.go index 73f0973f..5d7f5868 100644 --- a/integrationtests/self/handshake_test.go +++ b/integrationtests/self/handshake_test.go @@ -701,14 +701,24 @@ var _ = Describe("Handshake tests", func() { It("rejects invalid Retry token with the INVALID_TOKEN error", func() { const rtt = 10 * time.Millisecond - serverConfig.RequireAddressValidation = func(net.Addr) bool { return true } + // The validity period of the retry token is the handshake timeout, // which is twice the handshake idle timeout. // By setting the handshake timeout shorter than the RTT, the token will have expired by the time // it reaches the server. serverConfig.HandshakeIdleTimeout = rtt / 5 - server, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig) + laddr, err := net.ResolveUDPAddr("udp", "localhost:0") + Expect(err).ToNot(HaveOccurred()) + udpConn, err := net.ListenUDP("udp", laddr) + Expect(err).ToNot(HaveOccurred()) + defer udpConn.Close() + tr := &quic.Transport{ + Conn: udpConn, + MaxUnvalidatedHandshakes: -1, + } + defer tr.Close() + server, err := tr.Listen(getTLSConfig(), serverConfig) Expect(err).ToNot(HaveOccurred()) defer server.Close() diff --git a/integrationtests/self/mitm_test.go b/integrationtests/self/mitm_test.go index 0eb68d06..fe383d9b 100644 --- a/integrationtests/self/mitm_test.go +++ b/integrationtests/self/mitm_test.go @@ -32,7 +32,7 @@ var _ = Describe("MITM test", func() { serverConfig *quic.Config ) - startServerAndProxy := func(delayCb quicproxy.DelayCallback, dropCb quicproxy.DropCallback) (proxyPort int, closeFn func()) { + startServerAndProxy := func(delayCb quicproxy.DelayCallback, dropCb quicproxy.DropCallback, forceAddressValidation bool) (proxyPort int, closeFn func()) { addr, err := net.ResolveUDPAddr("udp", "localhost:0") Expect(err).ToNot(HaveOccurred()) c, err := net.ListenUDP("udp", addr) @@ -41,6 +41,9 @@ var _ = Describe("MITM test", func() { Conn: c, ConnectionIDLength: connIDLen, } + if forceAddressValidation { + serverTransport.MaxUnvalidatedHandshakes = -1 + } ln, err := serverTransport.Listen(getTLSConfig(), serverConfig) Expect(err).ToNot(HaveOccurred()) done := make(chan struct{}) @@ -153,7 +156,7 @@ var _ = Describe("MITM test", func() { } runTest := func(delayCb quicproxy.DelayCallback) { - proxyPort, closeFn := startServerAndProxy(delayCb, nil) + proxyPort, closeFn := startServerAndProxy(delayCb, nil, false) defer closeFn() raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort)) Expect(err).ToNot(HaveOccurred()) @@ -196,7 +199,7 @@ var _ = Describe("MITM test", func() { }) runTest := func(dropCb quicproxy.DropCallback) { - proxyPort, closeFn := startServerAndProxy(nil, dropCb) + proxyPort, closeFn := startServerAndProxy(nil, dropCb, false) defer closeFn() raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort)) Expect(err).ToNot(HaveOccurred()) @@ -310,17 +313,16 @@ var _ = Describe("MITM test", func() { const rtt = 20 * time.Millisecond - runTest := func(delayCb quicproxy.DelayCallback) (closeFn func(), err error) { - proxyPort, serverCloseFn := startServerAndProxy(delayCb, nil) + runTest := func(proxyPort int) (closeFn func(), err error) { raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort)) Expect(err).ToNot(HaveOccurred()) _, err = clientTransport.Dial( context.Background(), raddr, getTLSClientConfig(), - getQuicConfig(&quic.Config{HandshakeIdleTimeout: 2 * time.Second}), + getQuicConfig(&quic.Config{HandshakeIdleTimeout: scaleDuration(200 * time.Millisecond)}), ) - return func() { clientTransport.Close(); serverCloseFn() }, err + return func() { clientTransport.Close() }, err } // fails immediately because client connection closes when it can't find compatible version @@ -352,7 +354,9 @@ var _ = Describe("MITM test", func() { } return rtt / 2 } - closeFn, err := runTest(delayCb) + proxyPort, serverCloseFn := startServerAndProxy(delayCb, nil, false) + defer serverCloseFn() + closeFn, err := runTest(proxyPort) defer closeFn() Expect(err).To(HaveOccurred()) vnErr := &quic.VersionNegotiationError{} @@ -363,8 +367,7 @@ var _ = Describe("MITM test", func() { // times out, because client doesn't accept subsequent real retry packets from server // as it has already accepted a retry. // TODO: determine behavior when server does not send Retry packets - It("fails when a forged Retry packet with modified srcConnID is sent to client", func() { - serverConfig.RequireAddressValidation = func(net.Addr) bool { return true } + It("fails when a forged Retry packet with modified Source Connection ID is sent to client", func() { var initialPacketIntercepted bool done := make(chan struct{}) delayCb := func(dir quicproxy.Direction, raw []byte) time.Duration { @@ -388,7 +391,9 @@ var _ = Describe("MITM test", func() { } return rtt / 2 } - closeFn, err := runTest(delayCb) + proxyPort, serverCloseFn := startServerAndProxy(delayCb, nil, true) + defer serverCloseFn() + closeFn, err := runTest(proxyPort) defer closeFn() Expect(err).To(HaveOccurred()) Expect(err.(net.Error).Timeout()).To(BeTrue()) @@ -418,7 +423,9 @@ var _ = Describe("MITM test", func() { } return rtt } - closeFn, err := runTest(delayCb) + proxyPort, serverCloseFn := startServerAndProxy(delayCb, nil, false) + defer serverCloseFn() + closeFn, err := runTest(proxyPort) defer closeFn() Expect(err).To(HaveOccurred()) Expect(err.(net.Error).Timeout()).To(BeTrue()) @@ -448,7 +455,9 @@ var _ = Describe("MITM test", func() { } return rtt } - closeFn, err := runTest(delayCb) + proxyPort, serverCloseFn := startServerAndProxy(delayCb, nil, false) + defer serverCloseFn() + closeFn, err := runTest(proxyPort) defer closeFn() Expect(err).To(HaveOccurred()) var transportErr *quic.TransportError diff --git a/integrationtests/self/zero_rtt_test.go b/integrationtests/self/zero_rtt_test.go index 6356e633..85ad015e 100644 --- a/integrationtests/self/zero_rtt_test.go +++ b/integrationtests/self/zero_rtt_test.go @@ -464,14 +464,19 @@ var _ = Describe("0-RTT", func() { } counter, tracer := newPacketTracer() - ln, err := quic.ListenAddrEarly( - "localhost:0", + laddr, err := net.ResolveUDPAddr("udp", "localhost:0") + Expect(err).ToNot(HaveOccurred()) + udpConn, err := net.ListenUDP("udp", laddr) + Expect(err).ToNot(HaveOccurred()) + defer udpConn.Close() + tr := &quic.Transport{ + Conn: udpConn, + MaxUnvalidatedHandshakes: -1, + } + defer tr.Close() + ln, err := tr.ListenEarly( tlsConf, - getQuicConfig(&quic.Config{ - RequireAddressValidation: func(net.Addr) bool { return true }, - Allow0RTT: true, - Tracer: newTracer(tracer), - }), + getQuicConfig(&quic.Config{Allow0RTT: true, Tracer: newTracer(tracer)}), ) Expect(err).ToNot(HaveOccurred()) defer ln.Close() diff --git a/interface.go b/interface.go index b269d790..8741c48b 100644 --- a/interface.go +++ b/interface.go @@ -267,11 +267,6 @@ type Config struct { // If the timeout is exceeded, the connection is closed. // If this value is zero, the timeout is set to 30 seconds. MaxIdleTimeout time.Duration - // RequireAddressValidation determines if a QUIC Retry packet is sent. - // This allows the server to verify the client's address, at the cost of increasing the handshake latency by 1 RTT. - // See https://datatracker.ietf.org/doc/html/rfc9000#section-8 for details. - // If not set, every client is forced to prove its remote address. - RequireAddressValidation func(net.Addr) bool // The TokenStore stores tokens received from the server. // Tokens are used to skip address validation on future connection attempts. // The key used to store tokens is the ServerName from the tls.Config, if set diff --git a/interop/http09/server.go b/interop/http09/server.go index b7b510d8..e42a9ce1 100644 --- a/interop/http09/server.go +++ b/interop/http09/server.go @@ -37,6 +37,7 @@ func (w *responseWriter) WriteHeader(int) {} type Server struct { *http.Server + ForceRetry bool QuicConfig *quic.Config mutex sync.Mutex @@ -68,7 +69,11 @@ func (s *Server) ListenAndServe() error { tlsConf := s.TLSConfig.Clone() tlsConf.NextProtos = []string{h09alpn} - ln, err := quic.ListenEarly(conn, tlsConf, s.QuicConfig) + tr := quic.Transport{Conn: conn} + if s.ForceRetry { + tr.MaxUnvalidatedHandshakes = -1 + } + ln, err := tr.ListenEarly(tlsConf, s.QuicConfig) if err != nil { return err } diff --git a/interop/server/main.go b/interop/server/main.go index df704462..de059076 100644 --- a/interop/server/main.go +++ b/interop/server/main.go @@ -4,7 +4,6 @@ import ( "crypto/tls" "fmt" "log" - "net" "net/http" "os" @@ -38,9 +37,8 @@ func main() { testcase := os.Getenv("TESTCASE") quicConf := &quic.Config{ - RequireAddressValidation: func(net.Addr) bool { return testcase == "retry" }, - Allow0RTT: testcase == "zerortt", - Tracer: utils.NewQLOGConnectionTracer, + Allow0RTT: testcase == "zerortt", + Tracer: utils.NewQLOGConnectionTracer, } cert, err := tls.LoadX509KeyPair("/certs/cert.pem", "/certs/priv.key") if err != nil { @@ -54,11 +52,11 @@ func main() { switch testcase { case "versionnegotiation", "handshake", "retry", "transfer", "resumption", "multiconnect", "zerortt": - err = runHTTP09Server(quicConf) + err = runHTTP09Server(quicConf, testcase == "retry") case "chacha20": reset := qtls.SetCipherSuite(tls.TLS_CHACHA20_POLY1305_SHA256) defer reset() - err = runHTTP09Server(quicConf) + err = runHTTP09Server(quicConf, false) case "http3": err = runHTTP3Server(quicConf) default: @@ -72,12 +70,13 @@ func main() { } } -func runHTTP09Server(quicConf *quic.Config) error { +func runHTTP09Server(quicConf *quic.Config, forceRetry bool) error { server := http09.Server{ Server: &http.Server{ Addr: ":443", TLSConfig: tlsConf, }, + ForceRetry: forceRetry, QuicConfig: quicConf, } http.DefaultServeMux.Handle("/", http.FileServer(http.Dir("/www"))) diff --git a/server.go b/server.go index 87bdd6c2..a8c6fd7e 100644 --- a/server.go +++ b/server.go @@ -617,7 +617,7 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error } return nil } - if token == nil && (s.config.RequireAddressValidation(p.remoteAddr) || numHandshakesUnvalidated >= int64(s.maxNumHandshakesUnvalidated)) { + if token == nil && numHandshakesUnvalidated >= int64(s.maxNumHandshakesUnvalidated) { // Retry invalidates all 0-RTT packets sent. delete(s.zeroRTTQueues, hdr.DestConnectionID) select { diff --git a/server_test.go b/server_test.go index 6bd8e034..d0319eb8 100644 --- a/server_test.go +++ b/server_test.go @@ -6,7 +6,6 @@ import ( "crypto/tls" "errors" "net" - "reflect" "sync" "sync/atomic" "time" @@ -138,7 +137,6 @@ var _ = Describe("Server", func() { Expect(server.config.Versions).To(Equal(protocol.SupportedVersions)) Expect(server.config.HandshakeIdleTimeout).To(Equal(protocol.DefaultHandshakeIdleTimeout)) Expect(server.config.MaxIdleTimeout).To(Equal(protocol.DefaultIdleTimeout)) - Expect(server.config.RequireAddressValidation).ToNot(BeNil()) Expect(server.config.KeepAlivePeriod).To(BeZero()) // stop the listener Expect(ln.Close()).To(Succeed()) @@ -146,13 +144,11 @@ var _ = Describe("Server", func() { It("setups with the right values", func() { supportedVersions := []protocol.VersionNumber{protocol.Version1} - requireAddrVal := func(net.Addr) bool { return true } config := Config{ - Versions: supportedVersions, - HandshakeIdleTimeout: 1337 * time.Hour, - MaxIdleTimeout: 42 * time.Minute, - KeepAlivePeriod: 5 * time.Second, - RequireAddressValidation: requireAddrVal, + Versions: supportedVersions, + HandshakeIdleTimeout: 1337 * time.Hour, + MaxIdleTimeout: 42 * time.Minute, + KeepAlivePeriod: 5 * time.Second, } ln, err := Listen(conn, tlsConf, &config) Expect(err).ToNot(HaveOccurred()) @@ -161,7 +157,6 @@ var _ = Describe("Server", func() { Expect(server.config.Versions).To(Equal(supportedVersions)) Expect(server.config.HandshakeIdleTimeout).To(Equal(1337 * time.Hour)) Expect(server.config.MaxIdleTimeout).To(Equal(42 * time.Minute)) - Expect(reflect.ValueOf(server.config.RequireAddressValidation)).To(Equal(reflect.ValueOf(requireAddrVal))) Expect(server.config.KeepAlivePeriod).To(Equal(5 * time.Second)) // stop the listener Expect(ln.Close()).To(Succeed()) @@ -263,7 +258,7 @@ var _ = Describe("Server", func() { }) It("creates a connection when the token is accepted", func() { - serv.config.RequireAddressValidation = func(net.Addr) bool { return true } + serv.maxNumHandshakesUnvalidated = 0 raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} retryToken, err := serv.tokenGenerator.NewRetryToken( raddr, @@ -441,7 +436,7 @@ var _ = Describe("Server", func() { It("replies with a Retry packet, if a token is required", func() { connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) - serv.config.RequireAddressValidation = func(net.Addr) bool { return true } + serv.maxNumHandshakesUnvalidated = 0 hdr := &wire.Header{ Type: protocol.PacketTypeInitial, SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), @@ -846,7 +841,7 @@ var _ = Describe("Server", func() { }) It("sends an INVALID_TOKEN error, if an invalid retry token is received", func() { - serv.config.RequireAddressValidation = func(net.Addr) bool { return true } + serv.maxNumHandshakesUnvalidated = 0 token, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, protocol.ConnectionID{}, protocol.ConnectionID{}) Expect(err).ToNot(HaveOccurred()) hdr := &wire.Header{ @@ -882,7 +877,7 @@ var _ = Describe("Server", func() { }) It("sends an INVALID_TOKEN error, if an expired retry token is received", func() { - serv.config.RequireAddressValidation = func(net.Addr) bool { return true } + serv.maxNumHandshakesUnvalidated = 0 serv.config.HandshakeIdleTimeout = time.Millisecond / 2 // the maximum retry token age is equivalent to the handshake timeout Expect(serv.config.maxRetryTokenAge()).To(Equal(time.Millisecond)) raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} @@ -920,7 +915,7 @@ var _ = Describe("Server", func() { }) It("doesn't send an INVALID_TOKEN error, if an invalid non-retry token is received", func() { - serv.config.RequireAddressValidation = func(net.Addr) bool { return true } + serv.maxNumHandshakesUnvalidated = 0 token, err := serv.tokenGenerator.NewToken(&net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337}) Expect(err).ToNot(HaveOccurred()) hdr := &wire.Header{ @@ -949,7 +944,7 @@ var _ = Describe("Server", func() { }) It("sends an INVALID_TOKEN error, if an expired non-retry token is received", func() { - serv.config.RequireAddressValidation = func(net.Addr) bool { return true } + serv.maxNumHandshakesUnvalidated = 0 serv.maxTokenAge = time.Millisecond raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} token, err := serv.tokenGenerator.NewToken(raddr) @@ -978,7 +973,7 @@ var _ = Describe("Server", func() { }) It("doesn't send an INVALID_TOKEN error, if the packet is corrupted", func() { - serv.config.RequireAddressValidation = func(net.Addr) bool { return true } + serv.maxNumHandshakesUnvalidated = 0 token, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, protocol.ConnectionID{}, protocol.ConnectionID{}) Expect(err).ToNot(HaveOccurred()) hdr := &wire.Header{ @@ -1086,7 +1081,7 @@ var _ = Describe("Server", func() { conn := NewMockQUICConn(mockCtrl) conf := &Config{MaxIncomingStreams: 1234} - serv.config = populateServerConfig(&Config{GetConfigForClient: func(*ClientHelloInfo) (*Config, error) { return conf, nil }}) + serv.config = populateConfig(&Config{GetConfigForClient: func(*ClientHelloInfo) (*Config, error) { return conf, nil }}) done := make(chan struct{}) go func() { defer GinkgoRecover() @@ -1139,7 +1134,7 @@ var _ = Describe("Server", func() { }) It("rejects a connection attempt when GetConfigClient returns an error", func() { - serv.config = populateServerConfig(&Config{GetConfigForClient: func(*ClientHelloInfo) (*Config, error) { return nil, errors.New("rejected") }}) + serv.config = populateConfig(&Config{GetConfigForClient: func(*ClientHelloInfo) (*Config, error) { return nil, errors.New("rejected") }}) phm.EXPECT().Get(gomock.Any()) phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool { diff --git a/transport.go b/transport.go index 5316754a..ccbe32c2 100644 --- a/transport.go +++ b/transport.go @@ -179,7 +179,7 @@ func (t *Transport) createServer(tlsConf *tls.Config, conf *Config, allow0RTT bo if t.server != nil { return nil, errListenerAlreadySet } - conf = populateServerConfig(conf) + conf = populateConfig(conf) if err := t.init(false); err != nil { return nil, err }