diff --git a/connection.go b/connection.go index 24760a78..3316069a 100644 --- a/connection.go +++ b/connection.go @@ -242,6 +242,7 @@ var newConnection = func( tlsConf *tls.Config, tokenGenerator *handshake.TokenGenerator, enable0RTT bool, + clientAddressValidated bool, tracer logging.ConnectionTracer, tracingID uint64, logger utils.Logger, @@ -288,6 +289,7 @@ var newConnection = func( 0, getMaxPacketSize(s.conn.RemoteAddr()), s.rttStats, + clientAddressValidated, s.perspective, s.tracer, s.logger, @@ -415,6 +417,7 @@ var newClientConnection = func( initialPacketNumber, getMaxPacketSize(s.conn.RemoteAddr()), s.rttStats, + false, /* has no effect */ s.perspective, s.tracer, s.logger, diff --git a/connection_test.go b/connection_test.go index 9c5d3fb9..2d792dd2 100644 --- a/connection_test.go +++ b/connection_test.go @@ -107,6 +107,7 @@ var _ = Describe("Connection", func() { nil, // tls.Config tokenGenerator, false, + false, tracer, 1234, utils.DefaultLogger, diff --git a/internal/ackhandler/ackhandler.go b/internal/ackhandler/ackhandler.go index 26291321..2fc9ae4e 100644 --- a/internal/ackhandler/ackhandler.go +++ b/internal/ackhandler/ackhandler.go @@ -6,16 +6,19 @@ import ( "github.com/lucas-clemente/quic-go/logging" ) -// NewAckHandler creates a new SentPacketHandler and a new ReceivedPacketHandler +// NewAckHandler creates a new SentPacketHandler and a new ReceivedPacketHandler. +// clientAddressValidated indicates whether the address was validated beforehand by an address validation token. +// clientAddressValidated has no effect for a client. func NewAckHandler( initialPacketNumber protocol.PacketNumber, initialMaxDatagramSize protocol.ByteCount, rttStats *utils.RTTStats, + clientAddressValidated bool, pers protocol.Perspective, tracer logging.ConnectionTracer, logger utils.Logger, version protocol.VersionNumber, ) (SentPacketHandler, ReceivedPacketHandler) { - sph := newSentPacketHandler(initialPacketNumber, initialMaxDatagramSize, rttStats, pers, tracer, logger) + sph := newSentPacketHandler(initialPacketNumber, initialMaxDatagramSize, rttStats, clientAddressValidated, pers, tracer, logger) return sph, newReceivedPacketHandler(sph, rttStats, logger, version) } diff --git a/internal/ackhandler/sent_packet_handler.go b/internal/ackhandler/sent_packet_handler.go index cea34d19..5a8cd70e 100644 --- a/internal/ackhandler/sent_packet_handler.go +++ b/internal/ackhandler/sent_packet_handler.go @@ -101,10 +101,13 @@ var ( _ sentPacketTracker = &sentPacketHandler{} ) +// clientAddressValidated indicates whether the address was validated beforehand by an address validation token. +// If the address was validated, the amplification limit doesn't apply. It has no effect for a client. func newSentPacketHandler( initialPN protocol.PacketNumber, initialMaxDatagramSize protocol.ByteCount, rttStats *utils.RTTStats, + clientAddressValidated bool, pers protocol.Perspective, tracer logging.ConnectionTracer, logger utils.Logger, @@ -119,7 +122,7 @@ func newSentPacketHandler( return &sentPacketHandler{ peerCompletedAddressValidation: pers == protocol.PerspectiveServer, - peerAddressValidated: pers == protocol.PerspectiveClient, + peerAddressValidated: pers == protocol.PerspectiveClient || clientAddressValidated, initialPackets: newPacketNumberSpace(initialPN, false, rttStats), handshakePackets: newPacketNumberSpace(0, false, rttStats), appDataPackets: newPacketNumberSpace(0, true, rttStats), diff --git a/internal/ackhandler/sent_packet_handler_test.go b/internal/ackhandler/sent_packet_handler_test.go index dfa76036..7f8f8d02 100644 --- a/internal/ackhandler/sent_packet_handler_test.go +++ b/internal/ackhandler/sent_packet_handler_test.go @@ -29,7 +29,7 @@ var _ = Describe("SentPacketHandler", func() { JustBeforeEach(func() { lostPackets = nil rttStats := utils.NewRTTStats() - handler = newSentPacketHandler(42, protocol.InitialPacketSizeIPv4, rttStats, perspective, nil, utils.DefaultLogger) + handler = newSentPacketHandler(42, protocol.InitialPacketSizeIPv4, rttStats, false, perspective, nil, utils.DefaultLogger) streamFrame = wire.StreamFrame{ StreamID: 5, Data: []byte{0x13, 0x37}, @@ -944,6 +944,26 @@ var _ = Describe("SentPacketHandler", func() { }) }) + Context("amplification limit, for the server, with validated address", func() { + JustBeforeEach(func() { + rttStats := utils.NewRTTStats() + handler = newSentPacketHandler(42, protocol.InitialPacketSizeIPv4, rttStats, true, perspective, nil, utils.DefaultLogger) + }) + + It("do not limits the window", func() { + handler.ReceivedBytes(0) + Expect(handler.SendMode()).To(Equal(SendAny)) + handler.SentPacket(&Packet{ + PacketNumber: 1, + Length: 900, + EncryptionLevel: protocol.EncryptionInitial, + Frames: []Frame{{Frame: &wire.PingFrame{}}}, + SendTime: time.Now(), + }) + Expect(handler.SendMode()).To(Equal(SendAny)) + }) + }) + Context("amplification limit, for the client", func() { BeforeEach(func() { perspective = protocol.PerspectiveClient diff --git a/internal/handshake/token_generator.go b/internal/handshake/token_generator.go index 228b2fa6..a8dda91e 100644 --- a/internal/handshake/token_generator.go +++ b/internal/handshake/token_generator.go @@ -26,6 +26,7 @@ type Token struct { RetrySrcConnectionID protocol.ConnectionID } +// ValidateRemoteAddr validates the address, but does not check expiration func (t *Token) ValidateRemoteAddr(addr net.Addr) bool { return bytes.Equal(encodeRemoteAddr(addr), t.encodedRemoteAddr) } diff --git a/server.go b/server.go index 1d14302e..726adcfa 100644 --- a/server.go +++ b/server.go @@ -89,6 +89,7 @@ type baseServer struct { *tls.Config, *handshake.TokenGenerator, bool, /* enable 0-RTT */ + bool, /* client address validated by an address validation token */ logging.ConnectionTracer, uint64, utils.Logger, @@ -375,6 +376,26 @@ func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* is the buffer s return true } +// validateToken returns false if: +// - address is invalid +// - token is expired +// - token is null +func (s *baseServer) validateToken(token *handshake.Token, addr net.Addr) bool { + if token == nil { + return false + } + if !token.ValidateRemoteAddr(addr) { + return false + } + if !token.IsRetryToken && time.Since(token.SentTime) > s.config.MaxTokenAge { + return false + } + if token.IsRetryToken && time.Since(token.SentTime) > s.config.MaxRetryTokenAge { + return false + } + return true +} + func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) error { if len(hdr.Token) == 0 && hdr.DestConnectionID.Len() < protocol.MinConnectionIDLenInitial { p.buffer.Release() @@ -399,23 +420,23 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro token = tok } } - if token != nil { - addrIsValid := token.ValidateRemoteAddr(p.remoteAddr) + + clientAddrIsValid := s.validateToken(token, p.remoteAddr) + + if token != nil && !clientAddrIsValid { // For invalid and expired non-retry tokens, we don't send an INVALID_TOKEN error. // We just ignore them, and act as if there was no token on this packet at all. // This also means we might send a Retry later. - if !token.IsRetryToken && (time.Since(token.SentTime) > s.config.MaxTokenAge || !addrIsValid) { + if !token.IsRetryToken { token = nil - } else if token.IsRetryToken && (time.Since(token.SentTime) > s.config.MaxRetryTokenAge || !addrIsValid) { + } else { // For Retry tokens, we send an INVALID_ERROR if // * the token is too old, or // * the token is invalid, in case of a retry token. go func() { defer p.buffer.Release() - if token != nil && token.IsRetryToken { - if err := s.maybeSendInvalidToken(p, hdr); err != nil { - s.logger.Debugf("Error sending INVALID_TOKEN error: %s", err) - } + if err := s.maybeSendInvalidToken(p, hdr); err != nil { + s.logger.Debugf("Error sending INVALID_TOKEN error: %s", err) } }() return nil @@ -476,6 +497,7 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro s.tlsConf, s.tokenGenerator, s.acceptEarlyConns, + clientAddrIsValid, tracer, tracingID, s.logger, diff --git a/server_test.go b/server_test.go index 4944bccf..eafba0cc 100644 --- a/server_test.go +++ b/server_test.go @@ -287,6 +287,7 @@ var _ = Describe("Server", func() { _ *tls.Config, _ *handshake.TokenGenerator, enable0RTT bool, + _ bool, _ logging.ConnectionTracer, _ uint64, _ utils.Logger, @@ -491,6 +492,7 @@ var _ = Describe("Server", func() { _ *tls.Config, _ *handshake.TokenGenerator, enable0RTT bool, + _ bool, _ logging.ConnectionTracer, _ uint64, _ utils.Logger, @@ -550,6 +552,7 @@ var _ = Describe("Server", func() { _ *tls.Config, _ *handshake.TokenGenerator, _ bool, + _ bool, _ logging.ConnectionTracer, _ uint64, _ utils.Logger, @@ -603,6 +606,7 @@ var _ = Describe("Server", func() { _ *tls.Config, _ *handshake.TokenGenerator, _ bool, + _ bool, _ logging.ConnectionTracer, _ uint64, _ utils.Logger, @@ -632,6 +636,7 @@ var _ = Describe("Server", func() { _ *tls.Config, _ *handshake.TokenGenerator, _ bool, + _ bool, _ logging.ConnectionTracer, _ uint64, _ utils.Logger, @@ -702,6 +707,7 @@ var _ = Describe("Server", func() { _ *tls.Config, _ *handshake.TokenGenerator, _ bool, + _ bool, _ logging.ConnectionTracer, _ uint64, _ utils.Logger, @@ -1007,6 +1013,7 @@ var _ = Describe("Server", func() { _ *tls.Config, _ *handshake.TokenGenerator, _ bool, + _ bool, _ logging.ConnectionTracer, _ uint64, _ utils.Logger, @@ -1080,6 +1087,7 @@ var _ = Describe("Server", func() { _ *tls.Config, _ *handshake.TokenGenerator, enable0RTT bool, + _ bool, _ logging.ConnectionTracer, _ uint64, _ utils.Logger, @@ -1122,6 +1130,7 @@ var _ = Describe("Server", func() { _ *tls.Config, _ *handshake.TokenGenerator, _ bool, + _ bool, _ logging.ConnectionTracer, _ uint64, _ utils.Logger, @@ -1184,6 +1193,7 @@ var _ = Describe("Server", func() { _ *tls.Config, _ *handshake.TokenGenerator, _ bool, + _ bool, _ logging.ConnectionTracer, _ uint64, _ utils.Logger,