diff --git a/config.go b/config.go index 377c3ae9..58dba802 100644 --- a/config.go +++ b/config.go @@ -2,11 +2,11 @@ package quic import ( "errors" + "net" "time" - "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" ) // Clone clones a Config @@ -39,8 +39,14 @@ func populateServerConfig(config *Config) *Config { if config.ConnectionIDLength == 0 { config.ConnectionIDLength = protocol.DefaultConnectionIDLength } - if config.AcceptToken == nil { - config.AcceptToken = defaultAcceptToken + if config.MaxTokenAge == 0 { + config.MaxTokenAge = protocol.TokenValidity + } + if config.MaxRetryTokenAge == 0 { + config.MaxRetryTokenAge = protocol.RetryTokenValidity + } + if config.RequireAddressValidation == nil { + config.RequireAddressValidation = func(net.Addr) bool { return true } } return config } @@ -104,7 +110,9 @@ func populateConfig(config *Config) *Config { Versions: versions, HandshakeIdleTimeout: handshakeIdleTimeout, MaxIdleTimeout: idleTimeout, - AcceptToken: config.AcceptToken, + MaxTokenAge: config.MaxTokenAge, + MaxRetryTokenAge: config.MaxRetryTokenAge, + RequireAddressValidation: config.RequireAddressValidation, KeepAlivePeriod: config.KeepAlivePeriod, InitialStreamReceiveWindow: initialStreamReceiveWindow, MaxStreamReceiveWindow: maxStreamReceiveWindow, diff --git a/config_test.go b/config_test.go index 692952f4..f4cfe41d 100644 --- a/config_test.go +++ b/config_test.go @@ -45,7 +45,7 @@ var _ = Describe("Config", func() { } switch fn := typ.Field(i).Name; fn { - case "AcceptToken", "GetLogWriter", "AllowConnectionWindowIncrease": + case "RequireAddressValidation", "GetLogWriter", "AllowConnectionWindowIncrease": // Can't compare functions. case "Versions": f.Set(reflect.ValueOf([]VersionNumber{1, 2, 3})) @@ -55,6 +55,10 @@ var _ = Describe("Config", func() { f.Set(reflect.ValueOf(time.Second)) case "MaxIdleTimeout": f.Set(reflect.ValueOf(time.Hour)) + case "MaxTokenAge": + f.Set(reflect.ValueOf(2 * time.Hour)) + case "MaxRetryTokenAge": + f.Set(reflect.ValueOf(2 * time.Minute)) case "TokenStore": f.Set(reflect.ValueOf(NewLRUTokenStore(2, 3))) case "InitialStreamReceiveWindow": @@ -100,14 +104,14 @@ var _ = Describe("Config", func() { Context("cloning", func() { It("clones function fields", func() { - var calledAcceptToken, calledAllowConnectionWindowIncrease bool + var calledAddrValidation, calledAllowConnectionWindowIncrease bool c1 := &Config{ - AcceptToken: func(_ net.Addr, _ *Token) bool { calledAcceptToken = true; return true }, AllowConnectionWindowIncrease: func(Connection, uint64) bool { calledAllowConnectionWindowIncrease = true; return true }, + RequireAddressValidation: func(net.Addr) bool { calledAddrValidation = true; return true }, } c2 := c1.Clone() - c2.AcceptToken(&net.UDPAddr{}, &Token{}) - Expect(calledAcceptToken).To(BeTrue()) + c2.RequireAddressValidation(&net.UDPAddr{}) + Expect(calledAddrValidation).To(BeTrue()) c2.AllowConnectionWindowIncrease(nil, 1234) Expect(calledAllowConnectionWindowIncrease).To(BeTrue()) }) @@ -119,27 +123,26 @@ var _ = Describe("Config", func() { It("returns a copy", func() { c1 := &Config{ - MaxIncomingStreams: 100, - AcceptToken: func(_ net.Addr, _ *Token) bool { return true }, + MaxIncomingStreams: 100, + RequireAddressValidation: func(net.Addr) bool { return true }, } c2 := c1.Clone() c2.MaxIncomingStreams = 200 - c2.AcceptToken = func(_ net.Addr, _ *Token) bool { return false } + c2.RequireAddressValidation = func(net.Addr) bool { return false } Expect(c1.MaxIncomingStreams).To(BeEquivalentTo(100)) - Expect(c1.AcceptToken(&net.UDPAddr{}, nil)).To(BeTrue()) + Expect(c1.RequireAddressValidation(&net.UDPAddr{})).To(BeTrue()) }) }) Context("populating", func() { It("populates function fields", func() { - var calledAcceptToken bool - c1 := &Config{ - AcceptToken: func(_ net.Addr, _ *Token) bool { calledAcceptToken = true; return true }, - } + var calledAddrValidation bool + c1 := &Config{} + c1.RequireAddressValidation = func(net.Addr) bool { calledAddrValidation = true; return true } c2 := populateConfig(c1) - c2.AcceptToken(&net.UDPAddr{}, &Token{}) - Expect(calledAcceptToken).To(BeTrue()) + c2.RequireAddressValidation(&net.UDPAddr{}) + Expect(calledAddrValidation).To(BeTrue()) }) It("copies non-function fields", func() { @@ -164,7 +167,7 @@ var _ = Describe("Config", func() { It("populates empty fields with default values, for the server", func() { c := populateServerConfig(&Config{}) Expect(c.ConnectionIDLength).To(Equal(protocol.DefaultConnectionIDLength)) - Expect(c.AcceptToken).ToNot(BeNil()) + Expect(c.RequireAddressValidation).ToNot(BeNil()) }) It("sets a default connection ID length if we didn't create the conn, for the client", func() { diff --git a/fuzzing/tokens/fuzz.go b/fuzzing/tokens/fuzz.go index 9d414f77..1e1904ba 100644 --- a/fuzzing/tokens/fuzz.go +++ b/fuzzing/tokens/fuzz.go @@ -2,7 +2,6 @@ package tokens import ( "encoding/binary" - "fmt" "math/rand" "net" "time" @@ -77,7 +76,6 @@ func newToken(tg *handshake.TokenGenerator, data []byte) int { if token.OriginalDestConnectionID != nil || token.RetrySrcConnectionID != nil { panic("didn't expect connection IDs") } - checkAddr(token.RemoteAddr, addr) return 1 } @@ -140,22 +138,5 @@ func newRetryToken(tg *handshake.TokenGenerator, data []byte) int { if !token.RetrySrcConnectionID.Equal(retrySrcConnID) { panic("retry src conn ID doesn't match") } - checkAddr(token.RemoteAddr, addr) return 1 } - -func checkAddr(tokenAddr string, addr net.Addr) { - if udpAddr, ok := addr.(*net.UDPAddr); ok { - // For UDP addresses, we encode only the IP (not the port). - if ip := udpAddr.IP.String(); tokenAddr != ip { - fmt.Printf("%s vs %s", tokenAddr, ip) - panic("wrong remote address for a net.UDPAddr") - } - return - } - - if tokenAddr != addr.String() { - fmt.Printf("%s vs %s", tokenAddr, addr.String()) - panic("wrong remote address") - } -} diff --git a/integrationtests/self/handshake_drop_test.go b/integrationtests/self/handshake_drop_test.go index b788533d..525f1999 100644 --- a/integrationtests/self/handshake_drop_test.go +++ b/integrationtests/self/handshake_drop_test.go @@ -41,8 +41,8 @@ var _ = Describe("Handshake drop tests", func() { HandshakeIdleTimeout: timeout, Versions: []protocol.VersionNumber{version}, }) - if !doRetry { - conf.AcceptToken = func(net.Addr, *quic.Token) bool { return true } + if doRetry { + conf.RequireAddressValidation = func(net.Addr) bool { return true } } var tlsConf *tls.Config if longCertChain { diff --git a/integrationtests/self/handshake_rtt_test.go b/integrationtests/self/handshake_rtt_test.go index df092086..22496b26 100644 --- a/integrationtests/self/handshake_rtt_test.go +++ b/integrationtests/self/handshake_rtt_test.go @@ -112,9 +112,7 @@ var _ = Describe("Handshake RTT tests", func() { }) It("establishes a connection in 1 RTT when the server doesn't require a token", func() { - serverConfig.AcceptToken = func(_ net.Addr, _ *quic.Token) bool { - return true - } + serverConfig.RequireAddressValidation = func(net.Addr) bool { return false } runServerAndProxy() _, err := quic.DialAddr( fmt.Sprintf("localhost:%d", proxy.LocalAddr().(*net.UDPAddr).Port), @@ -126,9 +124,7 @@ var _ = Describe("Handshake RTT tests", func() { }) It("establishes a connection in 2 RTTs if a HelloRetryRequest is performed", func() { - serverConfig.AcceptToken = func(_ net.Addr, _ *quic.Token) bool { - return true - } + serverConfig.RequireAddressValidation = func(net.Addr) bool { return false } serverTLSConfig.CurvePreferences = []tls.CurveID{tls.CurveP384} runServerAndProxy() _, err := quic.DialAddr( @@ -139,21 +135,4 @@ var _ = Describe("Handshake RTT tests", func() { Expect(err).ToNot(HaveOccurred()) expectDurationInRTTs(2) }) - - It("doesn't complete the handshake when the server never accepts the token", func() { - serverConfig.AcceptToken = func(_ net.Addr, _ *quic.Token) bool { - return false - } - clientConfig.HandshakeIdleTimeout = 500 * time.Millisecond - runServerAndProxy() - _, err := quic.DialAddr( - fmt.Sprintf("localhost:%d", proxy.LocalAddr().(*net.UDPAddr).Port), - getTLSClientConfig(), - clientConfig, - ) - Expect(err).To(HaveOccurred()) - nerr, ok := err.(net.Error) - Expect(ok).To(BeTrue()) - Expect(nerr.Timeout()).To(BeTrue()) - }) }) diff --git a/integrationtests/self/handshake_test.go b/integrationtests/self/handshake_test.go index dee162f7..e494535c 100644 --- a/integrationtests/self/handshake_test.go +++ b/integrationtests/self/handshake_test.go @@ -344,12 +344,7 @@ var _ = Describe("Handshake tests", func() { } BeforeEach(func() { - serverConfig.AcceptToken = func(addr net.Addr, token *quic.Token) bool { - if token != nil { - Expect(token.IsRetryToken).To(BeFalse()) - } - return true - } + serverConfig.RequireAddressValidation = func(net.Addr) bool { return false } var err error // start the server, but don't call Accept server, err = quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig) @@ -479,13 +474,7 @@ var _ = Describe("Handshake tests", func() { Context("using tokens", func() { It("uses tokens provided in NEW_TOKEN frames", func() { - tokenChan := make(chan *quic.Token, 100) - serverConfig.AcceptToken = func(addr net.Addr, token *quic.Token) bool { - if token != nil && !token.IsRetryToken { - tokenChan <- token - } - return true - } + serverConfig.RequireAddressValidation = func(net.Addr) bool { return false } server, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig) Expect(err).ToNot(HaveOccurred()) @@ -509,7 +498,6 @@ var _ = Describe("Handshake tests", func() { Expect(err).ToNot(HaveOccurred()) Expect(gets).To(Receive()) Eventually(puts).Should(Receive()) - Expect(tokenChan).ToNot(Receive()) // received a token. Close this connection. Expect(conn.CloseWithError(0, "")).To(Succeed()) @@ -529,17 +517,13 @@ var _ = Describe("Handshake tests", func() { Expect(err).ToNot(HaveOccurred()) defer conn.CloseWithError(0, "") Expect(gets).To(Receive()) - Expect(tokenChan).To(Receive()) Eventually(done).Should(BeClosed()) }) It("rejects invalid Retry token with the INVALID_TOKEN error", func() { - tokenChan := make(chan *quic.Token, 10) - serverConfig.AcceptToken = func(addr net.Addr, token *quic.Token) bool { - tokenChan <- token - return false - } + serverConfig.RequireAddressValidation = func(net.Addr) bool { return true } + serverConfig.MaxRetryTokenAge = time.Nanosecond server, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig) Expect(err).ToNot(HaveOccurred()) @@ -554,18 +538,6 @@ var _ = Describe("Handshake tests", func() { var transportErr *quic.TransportError Expect(errors.As(err, &transportErr)).To(BeTrue()) Expect(transportErr.ErrorCode).To(Equal(quic.InvalidToken)) - // Receiving a Retry might lead the client to measure a very small RTT. - // Then, it sometimes would retransmit the ClientHello before receiving the ServerHello. - Expect(len(tokenChan)).To(BeNumerically(">=", 2)) - token := <-tokenChan - Expect(token).To(BeNil()) - token = <-tokenChan - Expect(token).ToNot(BeNil()) - // If the ClientHello was retransmitted, make sure that it contained the same Retry token. - for i := 2; i < len(tokenChan); i++ { - Expect(<-tokenChan).To(Equal(token)) - } - Expect(token.IsRetryToken).To(BeTrue()) }) }) diff --git a/integrationtests/self/packetization_test.go b/integrationtests/self/packetization_test.go index d7d1c659..497bd43d 100644 --- a/integrationtests/self/packetization_test.go +++ b/integrationtests/self/packetization_test.go @@ -26,9 +26,9 @@ var _ = Describe("Packetization", func() { "localhost:0", getTLSConfig(), getQuicConfig(&quic.Config{ - AcceptToken: func(net.Addr, *quic.Token) bool { return true }, - DisablePathMTUDiscovery: true, - Tracer: newTracer(func() logging.ConnectionTracer { return serverTracer }), + RequireAddressValidation: func(net.Addr) bool { return false }, + DisablePathMTUDiscovery: true, + Tracer: newTracer(func() logging.ConnectionTracer { return serverTracer }), }), ) Expect(err).ToNot(HaveOccurred()) diff --git a/integrationtests/self/zero_rtt_test.go b/integrationtests/self/zero_rtt_test.go index 39be9ead..a3c2f92c 100644 --- a/integrationtests/self/zero_rtt_test.go +++ b/integrationtests/self/zero_rtt_test.go @@ -56,7 +56,7 @@ var _ = Describe("0-RTT", func() { tlsConf := getTLSConfig() if serverConf == nil { serverConf = getQuicConfig(&quic.Config{ - AcceptToken: func(_ net.Addr, _ *quic.Token) bool { return true }, + RequireAddressValidation: func(net.Addr) bool { return false }, }) serverConf.Versions = []protocol.VersionNumber{version} } @@ -197,9 +197,9 @@ var _ = Describe("0-RTT", func() { "localhost:0", tlsConf, getQuicConfig(&quic.Config{ - Versions: []protocol.VersionNumber{version}, - AcceptToken: func(_ net.Addr, _ *quic.Token) bool { return true }, - Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), + Versions: []protocol.VersionNumber{version}, + RequireAddressValidation: func(net.Addr) bool { return false }, + Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), }), ) Expect(err).ToNot(HaveOccurred()) @@ -255,9 +255,9 @@ var _ = Describe("0-RTT", func() { "localhost:0", tlsConf, getQuicConfig(&quic.Config{ - Versions: []protocol.VersionNumber{version}, - AcceptToken: func(_ net.Addr, _ *quic.Token) bool { return true }, - Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), + Versions: []protocol.VersionNumber{version}, + RequireAddressValidation: func(net.Addr) bool { return false }, + Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), }), ) Expect(err).ToNot(HaveOccurred()) @@ -452,8 +452,8 @@ var _ = Describe("0-RTT", func() { It("doesn't reject 0-RTT when the server's transport stream limit increased", func() { const maxStreams = 1 tlsConf, clientConf := dialAndReceiveSessionTicket(getQuicConfig(&quic.Config{ - MaxIncomingUniStreams: maxStreams, - AcceptToken: func(_ net.Addr, _ *quic.Token) bool { return true }, + MaxIncomingUniStreams: maxStreams, + RequireAddressValidation: func(net.Addr) bool { return false }, })) tracer := newPacketTracer() @@ -461,10 +461,10 @@ var _ = Describe("0-RTT", func() { "localhost:0", tlsConf, getQuicConfig(&quic.Config{ - Versions: []protocol.VersionNumber{version}, - AcceptToken: func(_ net.Addr, _ *quic.Token) bool { return true }, - MaxIncomingUniStreams: maxStreams + 1, - Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), + Versions: []protocol.VersionNumber{version}, + RequireAddressValidation: func(net.Addr) bool { return false }, + MaxIncomingUniStreams: maxStreams + 1, + Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), }), ) Expect(err).ToNot(HaveOccurred()) @@ -498,8 +498,8 @@ var _ = Describe("0-RTT", func() { It("rejects 0-RTT when the server's stream limit decreased", func() { const maxStreams = 42 tlsConf, clientConf := dialAndReceiveSessionTicket(getQuicConfig(&quic.Config{ - MaxIncomingStreams: maxStreams, - AcceptToken: func(_ net.Addr, _ *quic.Token) bool { return true }, + MaxIncomingStreams: maxStreams, + RequireAddressValidation: func(net.Addr) bool { return false }, })) tracer := newPacketTracer() @@ -507,10 +507,10 @@ var _ = Describe("0-RTT", func() { "localhost:0", tlsConf, getQuicConfig(&quic.Config{ - Versions: []protocol.VersionNumber{version}, - AcceptToken: func(_ net.Addr, _ *quic.Token) bool { return true }, - MaxIncomingStreams: maxStreams - 1, - Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), + Versions: []protocol.VersionNumber{version}, + RequireAddressValidation: func(net.Addr) bool { return false }, + MaxIncomingStreams: maxStreams - 1, + Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), }), ) Expect(err).ToNot(HaveOccurred()) @@ -537,9 +537,9 @@ var _ = Describe("0-RTT", func() { "localhost:0", tlsConf, getQuicConfig(&quic.Config{ - Versions: []protocol.VersionNumber{version}, - AcceptToken: func(_ net.Addr, _ *quic.Token) bool { return true }, - Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), + Versions: []protocol.VersionNumber{version}, + RequireAddressValidation: func(net.Addr) bool { return false }, + Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), }), ) Expect(err).ToNot(HaveOccurred()) @@ -560,16 +560,16 @@ var _ = Describe("0-RTT", func() { func(addFlowControlLimit func(*quic.Config, uint64)) { tracer := newPacketTracer() firstConf := getQuicConfig(&quic.Config{ - AcceptToken: func(_ net.Addr, _ *quic.Token) bool { return true }, - Versions: []protocol.VersionNumber{version}, + RequireAddressValidation: func(net.Addr) bool { return false }, + Versions: []protocol.VersionNumber{version}, }) addFlowControlLimit(firstConf, 3) tlsConf, clientConf := dialAndReceiveSessionTicket(firstConf) secondConf := getQuicConfig(&quic.Config{ - Versions: []protocol.VersionNumber{version}, - AcceptToken: func(_ net.Addr, _ *quic.Token) bool { return true }, - Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), + Versions: []protocol.VersionNumber{version}, + RequireAddressValidation: func(net.Addr) bool { return false }, + Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), }) addFlowControlLimit(secondConf, 100) ln, err := quic.ListenAddrEarly( @@ -722,9 +722,9 @@ var _ = Describe("0-RTT", func() { "localhost:0", tlsConf, getQuicConfig(&quic.Config{ - Versions: []protocol.VersionNumber{version}, - AcceptToken: func(_ net.Addr, _ *quic.Token) bool { return true }, - Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), + Versions: []protocol.VersionNumber{version}, + RequireAddressValidation: func(net.Addr) bool { return false }, + Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), }), ) Expect(err).ToNot(HaveOccurred()) diff --git a/interface.go b/interface.go index 6130b549..c19a5c93 100644 --- a/interface.go +++ b/interface.go @@ -26,16 +26,6 @@ const ( Version2 = protocol.Version2 ) -// A Token can be used to verify the ownership of the client address. -type Token struct { - // IsRetryToken encodes how the client received the token. There are two ways: - // * In a Retry packet sent when trying to establish a new connection. - // * In a NEW_TOKEN frame on a previous connection. - IsRetryToken bool - RemoteAddr string - SentTime time.Time -} - // A ClientToken is a token received by the client. // It can be used to skip address validation on future connection attempts. type ClientToken struct { @@ -233,14 +223,18 @@ type Config struct { // If the timeout is exceeded, the connection is closed. // If this value is zero, the timeout is set to 30 seconds. MaxIdleTimeout time.Duration - // AcceptToken determines if a Token is accepted. - // It is called with token = nil if the client didn't send a token. - // If not set, a default verification function is used: - // * it verifies that the address matches, and - // * if the token is a retry token, that it was issued within the last 5 seconds - // * else, that it was issued within the last 24 hours. - // This option is only valid for the server. - AcceptToken func(clientAddr net.Addr, token *Token) bool + // RequireAddressValidation determines if a QUIC Retry packet is sent. + // This allows the server to verify the client's address, at the cost of increasing the handshake latency by 1 RTT. + // See https://datatracker.ietf.org/doc/html/rfc9000#section-8 for details. + // If not set, every client is forced to prove its remote address. + RequireAddressValidation func(net.Addr) bool + // MaxRetryTokenAge is the maximum age of a Retry token. + // If not set, it defaults to 5 seconds. Only valid for a server. + MaxRetryTokenAge time.Duration + // MaxTokenAge is the maximum age of the token presented during the handshake, + // for tokens that were issued on a previous connection. + // If not set, it defaults to 24 hours. Only valid for a server. + MaxTokenAge time.Duration // The TokenStore stores tokens received from the server. // Tokens are used to skip address validation on future connection attempts. // The key used to store tokens is the ServerName from the tls.Config, if set diff --git a/internal/handshake/token_generator.go b/internal/handshake/token_generator.go index 2df5fcd8..228b2fa6 100644 --- a/internal/handshake/token_generator.go +++ b/internal/handshake/token_generator.go @@ -1,6 +1,7 @@ package handshake import ( + "bytes" "encoding/asn1" "fmt" "io" @@ -17,14 +18,18 @@ const ( // A Token is derived from the client address and can be used to verify the ownership of this address. type Token struct { - IsRetryToken bool - RemoteAddr string - SentTime time.Time + IsRetryToken bool + SentTime time.Time + encodedRemoteAddr []byte // only set for retry tokens OriginalDestConnectionID protocol.ConnectionID RetrySrcConnectionID protocol.ConnectionID } +func (t *Token) ValidateRemoteAddr(addr net.Addr) bool { + return bytes.Equal(encodeRemoteAddr(addr), t.encodedRemoteAddr) +} + // token is the struct that is used for ASN1 serialization and deserialization type token struct { IsRetryToken bool @@ -101,9 +106,9 @@ func (g *TokenGenerator) DecodeToken(encrypted []byte) (*Token, error) { return nil, fmt.Errorf("rest when unpacking token: %d", len(rest)) } token := &Token{ - IsRetryToken: t.IsRetryToken, - RemoteAddr: decodeRemoteAddr(t.RemoteAddr), - SentTime: time.Unix(0, t.Timestamp), + IsRetryToken: t.IsRetryToken, + SentTime: time.Unix(0, t.Timestamp), + encodedRemoteAddr: t.RemoteAddr, } if t.IsRetryToken { token.OriginalDestConnectionID = protocol.ConnectionID(t.OriginalDestConnectionID) @@ -119,16 +124,3 @@ func encodeRemoteAddr(remoteAddr net.Addr) []byte { } return append([]byte{tokenPrefixString}, []byte(remoteAddr.String())...) } - -// decodeRemoteAddr decodes the remote address saved in the token -func decodeRemoteAddr(data []byte) string { - // data will never be empty for a token that we generated. - // Check it to be on the safe side - if len(data) == 0 { - return "" - } - if data[0] == tokenPrefixIP { - return net.IP(data[1:]).String() - } - return string(data[1:]) -} diff --git a/internal/handshake/token_generator_test.go b/internal/handshake/token_generator_test.go index 3aef6a3d..4d4be0f2 100644 --- a/internal/handshake/token_generator_test.go +++ b/internal/handshake/token_generator_test.go @@ -35,16 +35,13 @@ var _ = Describe("Token Generator", func() { }) It("accepts a valid token", func() { - ip := net.IPv4(192, 168, 0, 1) - tokenEnc, err := tokenGen.NewRetryToken( - &net.UDPAddr{IP: ip, Port: 1337}, - nil, - nil, - ) + addr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337} + tokenEnc, err := tokenGen.NewRetryToken(addr, nil, nil) Expect(err).ToNot(HaveOccurred()) token, err := tokenGen.DecodeToken(tokenEnc) Expect(err).ToNot(HaveOccurred()) - Expect(token.RemoteAddr).To(Equal("192.168.0.1")) + Expect(token.ValidateRemoteAddr(addr)).To(BeTrue()) + Expect(token.ValidateRemoteAddr(&net.UDPAddr{IP: net.IPv4(192, 168, 0, 2), Port: 1337})).To(BeFalse()) Expect(token.SentTime).To(BeTemporally("~", time.Now(), 100*time.Millisecond)) Expect(token.OriginalDestConnectionID.Len()).To(BeZero()) Expect(token.RetrySrcConnectionID.Len()).To(BeZero()) @@ -110,7 +107,7 @@ var _ = Describe("Token Generator", func() { Expect(err).ToNot(HaveOccurred()) token, err := tokenGen.DecodeToken(tokenEnc) Expect(err).ToNot(HaveOccurred()) - Expect(token.RemoteAddr).To(Equal(ip.String())) + Expect(token.ValidateRemoteAddr(raddr)).To(BeTrue()) Expect(token.SentTime).To(BeTemporally("~", time.Now(), 100*time.Millisecond)) } }) @@ -121,7 +118,8 @@ var _ = Describe("Token Generator", func() { Expect(err).ToNot(HaveOccurred()) token, err := tokenGen.DecodeToken(tokenEnc) Expect(err).ToNot(HaveOccurred()) - Expect(token.RemoteAddr).To(Equal("192.168.13.37:1337")) + Expect(token.ValidateRemoteAddr(raddr)).To(BeTrue()) + Expect(token.ValidateRemoteAddr(&net.TCPAddr{IP: net.IPv4(192, 168, 13, 37), Port: 1338})).To(BeFalse()) Expect(token.SentTime).To(BeTemporally("~", time.Now(), 100*time.Millisecond)) }) }) diff --git a/interop/server/main.go b/interop/server/main.go index 2674fd5b..d92ec1d3 100644 --- a/interop/server/main.go +++ b/interop/server/main.go @@ -44,8 +44,8 @@ func main() { } // a quic.Config that doesn't do a Retry quicConf := &quic.Config{ - AcceptToken: func(_ net.Addr, _ *quic.Token) bool { return true }, - Tracer: qlog.NewTracer(getLogWriter), + RequireAddressValidation: func(net.Addr) bool { return testcase == "retry" }, + Tracer: qlog.NewTracer(getLogWriter), } cert, err := tls.LoadX509KeyPair("/certs/cert.pem", "/certs/priv.key") if err != nil { @@ -58,15 +58,11 @@ func main() { } switch testcase { - case "versionnegotiation", "handshake", "transfer", "resumption", "zerortt", "multiconnect": + case "versionnegotiation", "handshake", "retry", "transfer", "resumption", "zerortt", "multiconnect": err = runHTTP09Server(quicConf) case "chacha20": tlsConf.CipherSuites = []uint16{tls.TLS_CHACHA20_POLY1305_SHA256} err = runHTTP09Server(quicConf) - case "retry": - // By default, quic-go performs a Retry on every incoming connection. - quicConf.AcceptToken = nil - err = runHTTP09Server(quicConf) case "http3": err = runHTTP3Server(quicConf) default: diff --git a/server.go b/server.go index 0e642970..1d14302e 100644 --- a/server.go +++ b/server.go @@ -241,26 +241,6 @@ func (s *baseServer) run() { } } -var defaultAcceptToken = func(clientAddr net.Addr, token *Token) bool { - if token == nil { - return false - } - validity := protocol.TokenValidity - if token.IsRetryToken { - validity = protocol.RetryTokenValidity - } - if time.Now().After(token.SentTime.Add(validity)) { - return false - } - var sourceAddr string - if udpAddr, ok := clientAddr.(*net.UDPAddr); ok { - sourceAddr = udpAddr.IP.String() - } else { - sourceAddr = clientAddr.String() - } - return sourceAddr == token.RemoteAddr -} - // Accept returns connections that already completed the handshake. // It is only valid if acceptEarlyConns is false. func (s *baseServer) Accept(ctx context.Context) (Connection, error) { @@ -405,33 +385,45 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro } var ( - token *Token + token *handshake.Token retrySrcConnID *protocol.ConnectionID ) origDestConnID := hdr.DestConnectionID if len(hdr.Token) > 0 { - c, err := s.tokenGenerator.DecodeToken(hdr.Token) + tok, err := s.tokenGenerator.DecodeToken(hdr.Token) if err == nil { - token = &Token{ - IsRetryToken: c.IsRetryToken, - RemoteAddr: c.RemoteAddr, - SentTime: c.SentTime, - } - if token.IsRetryToken { - origDestConnID = c.OriginalDestConnectionID - retrySrcConnID = &c.RetrySrcConnectionID + if tok.IsRetryToken { + origDestConnID = tok.OriginalDestConnectionID + retrySrcConnID = &tok.RetrySrcConnectionID } + token = tok } } - if !s.config.AcceptToken(p.remoteAddr, token) { + if token != nil { + addrIsValid := token.ValidateRemoteAddr(p.remoteAddr) + // For invalid and expired non-retry tokens, we don't send an INVALID_TOKEN error. + // We just ignore them, and act as if there was no token on this packet at all. + // This also means we might send a Retry later. + if !token.IsRetryToken && (time.Since(token.SentTime) > s.config.MaxTokenAge || !addrIsValid) { + token = nil + } else if token.IsRetryToken && (time.Since(token.SentTime) > s.config.MaxRetryTokenAge || !addrIsValid) { + // For Retry tokens, we send an INVALID_ERROR if + // * the token is too old, or + // * the token is invalid, in case of a retry token. + go func() { + defer p.buffer.Release() + if token != nil && token.IsRetryToken { + if err := s.maybeSendInvalidToken(p, hdr); err != nil { + s.logger.Debugf("Error sending INVALID_TOKEN error: %s", err) + } + } + }() + return nil + } + } + if token == nil && s.config.RequireAddressValidation(p.remoteAddr) { go func() { defer p.buffer.Release() - if token != nil && token.IsRetryToken { - if err := s.maybeSendInvalidToken(p, hdr); err != nil { - s.logger.Debugf("Error sending INVALID_TOKEN error: %s", err) - } - return - } if err := s.sendRetry(p.remoteAddr, hdr, p.info); err != nil { s.logger.Debugf("Error sending Retry: %s", err) } diff --git a/server_test.go b/server_test.go index 2cdb98d0..5640a05a 100644 --- a/server_test.go +++ b/server_test.go @@ -126,22 +126,22 @@ var _ = Describe("Server", func() { Expect(server.config.Versions).To(Equal(protocol.SupportedVersions)) Expect(server.config.HandshakeIdleTimeout).To(Equal(protocol.DefaultHandshakeIdleTimeout)) Expect(server.config.MaxIdleTimeout).To(Equal(protocol.DefaultIdleTimeout)) - Expect(reflect.ValueOf(server.config.AcceptToken)).To(Equal(reflect.ValueOf(defaultAcceptToken))) - Expect(server.config.KeepAlivePeriod).To(Equal(0 * time.Second)) + Expect(server.config.RequireAddressValidation).ToNot(BeNil()) + Expect(server.config.KeepAlivePeriod).To(BeZero()) // stop the listener Expect(ln.Close()).To(Succeed()) }) It("setups with the right values", func() { supportedVersions := []protocol.VersionNumber{protocol.VersionTLS} - acceptToken := func(_ net.Addr, _ *Token) bool { return true } + requireAddrVal := func(net.Addr) bool { return true } config := Config{ - Versions: supportedVersions, - AcceptToken: acceptToken, - HandshakeIdleTimeout: 1337 * time.Hour, - MaxIdleTimeout: 42 * time.Minute, - KeepAlivePeriod: 5 * time.Second, - StatelessResetKey: []byte("foobar"), + Versions: supportedVersions, + HandshakeIdleTimeout: 1337 * time.Hour, + MaxIdleTimeout: 42 * time.Minute, + KeepAlivePeriod: 5 * time.Second, + StatelessResetKey: []byte("foobar"), + RequireAddressValidation: requireAddrVal, } ln, err := Listen(conn, tlsConf, &config) Expect(err).ToNot(HaveOccurred()) @@ -150,7 +150,7 @@ var _ = Describe("Server", func() { Expect(server.config.Versions).To(Equal(supportedVersions)) Expect(server.config.HandshakeIdleTimeout).To(Equal(1337 * time.Hour)) Expect(server.config.MaxIdleTimeout).To(Equal(42 * time.Minute)) - Expect(reflect.ValueOf(server.config.AcceptToken)).To(Equal(reflect.ValueOf(acceptToken))) + Expect(reflect.ValueOf(server.config.RequireAddressValidation)).To(Equal(reflect.ValueOf(requireAddrVal))) Expect(server.config.KeepAlivePeriod).To(Equal(5 * time.Second)) Expect(server.config.StatelessResetKey).To(Equal([]byte("foobar"))) // stop the listener @@ -239,60 +239,8 @@ var _ = Describe("Server", func() { time.Sleep(50 * time.Millisecond) }) - It("decodes the token from the Token field", func() { - raddr := &net.UDPAddr{ - IP: net.IPv4(192, 168, 13, 37), - Port: 1337, - } - done := make(chan struct{}) - serv.config.AcceptToken = func(addr net.Addr, token *Token) bool { - Expect(addr).To(Equal(raddr)) - Expect(token).ToNot(BeNil()) - close(done) - return false - } - token, err := serv.tokenGenerator.NewRetryToken(raddr, nil, nil) - Expect(err).ToNot(HaveOccurred()) - packet := getPacket(&wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeInitial, - Token: token, - Version: serv.config.Versions[0], - }, make([]byte, protocol.MinInitialPacketSize)) - packet.remoteAddr = raddr - conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).MaxTimes(1) - tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1) - serv.handlePacket(packet) - Eventually(done).Should(BeClosed()) - }) - - It("passes an empty token to the callback, if decoding fails", func() { - raddr := &net.UDPAddr{ - IP: net.IPv4(192, 168, 13, 37), - Port: 1337, - } - done := make(chan struct{}) - serv.config.AcceptToken = func(addr net.Addr, token *Token) bool { - Expect(addr).To(Equal(raddr)) - Expect(token).To(BeNil()) - close(done) - return false - } - packet := getPacket(&wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeInitial, - Token: []byte("foobar"), - Version: serv.config.Versions[0], - }, make([]byte, protocol.MinInitialPacketSize)) - packet.remoteAddr = raddr - conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).MaxTimes(1) - tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1) - serv.handlePacket(packet) - Eventually(done).Should(BeClosed()) - }) - It("creates a connection when the token is accepted", func() { - serv.config.AcceptToken = func(_ net.Addr, token *Token) bool { return true } + serv.config.RequireAddressValidation = func(net.Addr) bool { return true } retryToken, err := serv.tokenGenerator.NewRetryToken( &net.UDPAddr{}, protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde}, @@ -469,8 +417,8 @@ var _ = Describe("Server", func() { time.Sleep(scaleDuration(20 * time.Millisecond)) }) - It("replies with a Retry packet, if a Token is required", func() { - serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return false } + It("replies with a Retry packet, if a token is required", func() { + serv.config.RequireAddressValidation = func(net.Addr) bool { return true } hdr := &wire.Header{ IsLongHeader: true, Type: protocol.PacketTypeInitial, @@ -502,81 +450,8 @@ var _ = Describe("Server", func() { Eventually(done).Should(BeClosed()) }) - It("sends an INVALID_TOKEN error, if an invalid retry token is received", func() { - serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return false } - token, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, nil, nil) - Expect(err).ToNot(HaveOccurred()) - hdr := &wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeInitial, - SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1}, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, - Token: token, - Version: protocol.VersionTLS, - } - packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) - packet.data = append(packet.data, []byte("coalesced packet")...) // add some garbage to simulate a coalesced packet - raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} - packet.remoteAddr = raddr - tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, frames []logging.Frame) { - Expect(replyHdr.Type).To(Equal(protocol.PacketTypeInitial)) - Expect(replyHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID)) - Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID)) - Expect(frames).To(HaveLen(1)) - Expect(frames[0]).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{})) - ccf := frames[0].(*logging.ConnectionCloseFrame) - Expect(ccf.IsApplicationError).To(BeFalse()) - Expect(ccf.ErrorCode).To(BeEquivalentTo(qerr.InvalidToken)) - }) - done := make(chan struct{}) - conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) { - defer close(done) - replyHdr := parseHeader(b) - Expect(replyHdr.Type).To(Equal(protocol.PacketTypeInitial)) - Expect(replyHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID)) - Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID)) - _, opener := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveClient, replyHdr.Version) - extHdr, err := unpackHeader(opener, replyHdr, b, hdr.Version) - Expect(err).ToNot(HaveOccurred()) - data, err := opener.Open(nil, b[extHdr.ParsedLen():], extHdr.PacketNumber, b[:extHdr.ParsedLen()]) - Expect(err).ToNot(HaveOccurred()) - f, err := wire.NewFrameParser(false, hdr.Version).ParseNext(bytes.NewReader(data), protocol.EncryptionInitial) - Expect(err).ToNot(HaveOccurred()) - Expect(f).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{})) - ccf := f.(*wire.ConnectionCloseFrame) - Expect(ccf.ErrorCode).To(BeEquivalentTo(qerr.InvalidToken)) - Expect(ccf.ReasonPhrase).To(BeEmpty()) - return len(b), nil - }) - serv.handlePacket(packet) - Eventually(done).Should(BeClosed()) - }) - - It("doesn't send an INVALID_TOKEN error, if the packet is corrupted", func() { - serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return false } - token, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, nil, nil) - Expect(err).ToNot(HaveOccurred()) - hdr := &wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeInitial, - SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1}, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, - Token: token, - Version: protocol.VersionTLS, - } - packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) - packet.data[len(packet.data)-10] ^= 0xff // corrupt the packet - packet.remoteAddr = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} - done := make(chan struct{}) - tracer.EXPECT().DroppedPacket(packet.remoteAddr, logging.PacketTypeInitial, packet.Size(), logging.PacketDropPayloadDecryptError).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { close(done) }) - serv.handlePacket(packet) - // make sure there are no Write calls on the packet conn - time.Sleep(50 * time.Millisecond) - Eventually(done).Should(BeClosed()) - }) - - It("creates a connection, if no Token is required", func() { - serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true } + It("creates a connection, if no token is required", func() { + serv.config.RequireAddressValidation = func(net.Addr) bool { return false } hdr := &wire.Header{ IsLongHeader: true, Type: protocol.PacketTypeInitial, @@ -659,7 +534,7 @@ var _ = Describe("Server", func() { }).AnyTimes() tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, gomock.Any()).AnyTimes() - serv.config.AcceptToken = func(net.Addr, *Token) bool { return true } + serv.config.RequireAddressValidation = func(net.Addr) bool { return false } acceptConn := make(chan struct{}) var counter uint32 // to be used as an atomic, so we query it in Eventually serv.newConn = func( @@ -713,7 +588,7 @@ var _ = Describe("Server", func() { }) It("only creates a single connection for a duplicate Initial", func() { - serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true } + serv.config.RequireAddressValidation = func(net.Addr) bool { return false } var createdConn bool conn := NewMockQuicConn(mockCtrl) serv.newConn = func( @@ -745,7 +620,7 @@ var _ = Describe("Server", func() { }) It("rejects new connection attempts if the accept queue is full", func() { - serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true } + serv.config.RequireAddressValidation = func(net.Addr) bool { return false } serv.newConn = func( _ sendConn, @@ -813,7 +688,7 @@ var _ = Describe("Server", func() { }) It("doesn't accept new connections if they were closed in the mean time", func() { - serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true } + serv.config.RequireAddressValidation = func(net.Addr) bool { return false } p := getInitial(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) ctx, cancel := context.WithCancel(context.Background()) @@ -877,6 +752,200 @@ var _ = Describe("Server", func() { }) }) + Context("token validation", func() { + checkInvalidToken := func(b []byte, origHdr *wire.Header) { + replyHdr := parseHeader(b) + Expect(replyHdr.Type).To(Equal(protocol.PacketTypeInitial)) + Expect(replyHdr.SrcConnectionID).To(Equal(origHdr.DestConnectionID)) + Expect(replyHdr.DestConnectionID).To(Equal(origHdr.SrcConnectionID)) + _, opener := handshake.NewInitialAEAD(origHdr.DestConnectionID, protocol.PerspectiveClient, replyHdr.Version) + extHdr, err := unpackHeader(opener, replyHdr, b, origHdr.Version) + Expect(err).ToNot(HaveOccurred()) + data, err := opener.Open(nil, b[extHdr.ParsedLen():], extHdr.PacketNumber, b[:extHdr.ParsedLen()]) + Expect(err).ToNot(HaveOccurred()) + f, err := wire.NewFrameParser(false, origHdr.Version).ParseNext(bytes.NewReader(data), protocol.EncryptionInitial) + Expect(err).ToNot(HaveOccurred()) + Expect(f).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{})) + ccf := f.(*wire.ConnectionCloseFrame) + Expect(ccf.ErrorCode).To(BeEquivalentTo(qerr.InvalidToken)) + Expect(ccf.ReasonPhrase).To(BeEmpty()) + } + + It("decodes the token from the token field", func() { + raddr := &net.UDPAddr{IP: net.IPv4(192, 168, 13, 37), Port: 1337} + token, err := serv.tokenGenerator.NewRetryToken(raddr, nil, nil) + Expect(err).ToNot(HaveOccurred()) + packet := getPacket(&wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeInitial, + Token: token, + Version: serv.config.Versions[0], + }, make([]byte, protocol.MinInitialPacketSize)) + packet.remoteAddr = raddr + conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).MaxTimes(1) + tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1) + + done := make(chan struct{}) + phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_, _ protocol.ConnectionID, _ func() packetHandler) { close(done) }) + serv.handlePacket(packet) + Eventually(done).Should(BeClosed()) + }) + + It("sends an INVALID_TOKEN error, if an invalid retry token is received", func() { + serv.config.RequireAddressValidation = func(net.Addr) bool { return true } + token, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, nil, nil) + Expect(err).ToNot(HaveOccurred()) + hdr := &wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeInitial, + SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1}, + DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + Token: token, + Version: protocol.VersionTLS, + } + packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) + packet.data = append(packet.data, []byte("coalesced packet")...) // add some garbage to simulate a coalesced packet + raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} + packet.remoteAddr = raddr + tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, frames []logging.Frame) { + Expect(replyHdr.Type).To(Equal(protocol.PacketTypeInitial)) + Expect(replyHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID)) + Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID)) + Expect(frames).To(HaveLen(1)) + Expect(frames[0]).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{})) + ccf := frames[0].(*logging.ConnectionCloseFrame) + Expect(ccf.IsApplicationError).To(BeFalse()) + Expect(ccf.ErrorCode).To(BeEquivalentTo(qerr.InvalidToken)) + }) + done := make(chan struct{}) + conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) { + defer close(done) + checkInvalidToken(b, hdr) + return len(b), nil + }) + serv.handlePacket(packet) + Eventually(done).Should(BeClosed()) + }) + + It("sends an INVALID_TOKEN error, if an expired retry token is received", func() { + serv.config.RequireAddressValidation = func(net.Addr) bool { return true } + serv.config.MaxRetryTokenAge = time.Millisecond + raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} + token, err := serv.tokenGenerator.NewRetryToken(raddr, nil, nil) + Expect(err).ToNot(HaveOccurred()) + time.Sleep(2 * time.Millisecond) // make sure the token is expired + hdr := &wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeInitial, + SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1}, + DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + Token: token, + Version: protocol.VersionTLS, + } + packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) + packet.remoteAddr = raddr + tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, frames []logging.Frame) { + Expect(replyHdr.Type).To(Equal(protocol.PacketTypeInitial)) + Expect(replyHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID)) + Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID)) + Expect(frames).To(HaveLen(1)) + Expect(frames[0]).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{})) + ccf := frames[0].(*logging.ConnectionCloseFrame) + Expect(ccf.IsApplicationError).To(BeFalse()) + Expect(ccf.ErrorCode).To(BeEquivalentTo(qerr.InvalidToken)) + }) + done := make(chan struct{}) + conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) { + defer close(done) + checkInvalidToken(b, hdr) + return len(b), nil + }) + serv.handlePacket(packet) + Eventually(done).Should(BeClosed()) + }) + + It("doesn't send an INVALID_TOKEN error, if an invalid non-retry token is received", func() { + serv.config.RequireAddressValidation = func(net.Addr) bool { return true } + token, err := serv.tokenGenerator.NewToken(&net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337}) + Expect(err).ToNot(HaveOccurred()) + hdr := &wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeInitial, + SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1}, + DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + Token: token, + Version: protocol.VersionTLS, + } + packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) + packet.data[len(packet.data)-10] ^= 0xff // corrupt the packet + raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} + packet.remoteAddr = raddr + tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1) + done := make(chan struct{}) + conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) { + defer close(done) + replyHdr := parseHeader(b) + Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry)) + return len(b), nil + }) + serv.handlePacket(packet) + // make sure there are no Write calls on the packet conn + Eventually(done).Should(BeClosed()) + }) + + It("sends an INVALID_TOKEN error, if an expired non-retry token is received", func() { + serv.config.RequireAddressValidation = func(net.Addr) bool { return true } + serv.config.MaxTokenAge = time.Millisecond + raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} + token, err := serv.tokenGenerator.NewToken(raddr) + Expect(err).ToNot(HaveOccurred()) + time.Sleep(2 * time.Millisecond) // make sure the token is expired + hdr := &wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeInitial, + SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1}, + DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + Token: token, + Version: protocol.VersionTLS, + } + packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) + packet.remoteAddr = raddr + tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, frames []logging.Frame) { + Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry)) + }) + done := make(chan struct{}) + conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) { + defer close(done) + return len(b), nil + }) + serv.handlePacket(packet) + Eventually(done).Should(BeClosed()) + }) + + It("doesn't send an INVALID_TOKEN error, if the packet is corrupted", func() { + serv.config.RequireAddressValidation = func(net.Addr) bool { return true } + token, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, nil, nil) + Expect(err).ToNot(HaveOccurred()) + hdr := &wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeInitial, + SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1}, + DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + Token: token, + Version: protocol.VersionTLS, + } + packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) + packet.data[len(packet.data)-10] ^= 0xff // corrupt the packet + packet.remoteAddr = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} + done := make(chan struct{}) + tracer.EXPECT().DroppedPacket(packet.remoteAddr, logging.PacketTypeInitial, packet.Size(), logging.PacketDropPayloadDecryptError).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { close(done) }) + serv.handlePacket(packet) + // make sure there are no Write calls on the packet conn + time.Sleep(50 * time.Millisecond) + Eventually(done).Should(BeClosed()) + }) + }) + Context("accepting connections", func() { It("returns Accept when an error occurs", func() { testErr := errors.New("test err") @@ -930,7 +999,7 @@ var _ = Describe("Server", func() { }() ctx, cancel := context.WithCancel(context.Background()) // handshake context - serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true } + serv.config.RequireAddressValidation = func(net.Addr) bool { return false } serv.newConn = func( _ sendConn, runner connRunner, @@ -1004,7 +1073,7 @@ var _ = Describe("Server", func() { }() ready := make(chan struct{}) - serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true } + serv.config.RequireAddressValidation = func(net.Addr) bool { return false } serv.newConn = func( _ sendConn, runner connRunner, @@ -1045,7 +1114,7 @@ var _ = Describe("Server", func() { }) It("rejects new connection attempts if the accept queue is full", func() { - serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true } + serv.config.RequireAddressValidation = func(net.Addr) bool { return false } senderAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 42} serv.newConn = func( @@ -1106,7 +1175,7 @@ var _ = Describe("Server", func() { }) It("doesn't accept new connections if they were closed in the mean time", func() { - serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true } + serv.config.RequireAddressValidation = func(net.Addr) bool { return false } p := getInitial(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) ctx, cancel := context.WithCancel(context.Background()) @@ -1166,72 +1235,3 @@ var _ = Describe("Server", func() { }) }) }) - -var _ = Describe("default source address verification", func() { - It("accepts a token", func() { - remoteAddr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1)} - token := &Token{ - IsRetryToken: true, - RemoteAddr: "192.168.0.1", - SentTime: time.Now().Add(-protocol.RetryTokenValidity).Add(time.Second), // will expire in 1 second - } - Expect(defaultAcceptToken(remoteAddr, token)).To(BeTrue()) - }) - - It("requests verification if no token is provided", func() { - remoteAddr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1)} - Expect(defaultAcceptToken(remoteAddr, nil)).To(BeFalse()) - }) - - It("rejects a token if the address doesn't match", func() { - remoteAddr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1)} - token := &Token{ - IsRetryToken: true, - RemoteAddr: "127.0.0.1", - SentTime: time.Now(), - } - Expect(defaultAcceptToken(remoteAddr, token)).To(BeFalse()) - }) - - It("accepts a token for a remote address is not a UDP address", func() { - remoteAddr := &net.TCPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337} - token := &Token{ - IsRetryToken: true, - RemoteAddr: "192.168.0.1:1337", - SentTime: time.Now(), - } - Expect(defaultAcceptToken(remoteAddr, token)).To(BeTrue()) - }) - - It("rejects an invalid token for a remote address is not a UDP address", func() { - remoteAddr := &net.TCPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337} - token := &Token{ - IsRetryToken: true, - RemoteAddr: "192.168.0.1:7331", // mismatching port - SentTime: time.Now(), - } - Expect(defaultAcceptToken(remoteAddr, token)).To(BeFalse()) - }) - - It("rejects an expired token", func() { - remoteAddr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1)} - token := &Token{ - IsRetryToken: true, - RemoteAddr: "192.168.0.1", - SentTime: time.Now().Add(-protocol.RetryTokenValidity).Add(-time.Second), // expired 1 second ago - } - Expect(defaultAcceptToken(remoteAddr, token)).To(BeFalse()) - }) - - It("accepts a non-retry token", func() { - Expect(protocol.RetryTokenValidity).To(BeNumerically("<", protocol.TokenValidity)) - remoteAddr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1)} - token := &Token{ - IsRetryToken: false, - RemoteAddr: "192.168.0.1", - // if this was a retry token, it would have expired one second ago - SentTime: time.Now().Add(-protocol.RetryTokenValidity).Add(-time.Second), - } - Expect(defaultAcceptToken(remoteAddr, token)).To(BeTrue()) - }) -})