Disable anti-amplification limit by address validation token (#3326)

This commit is contained in:
Benedikt Spies 2022-08-20 17:02:17 +02:00 committed by GitHub
parent 8c0c481da1
commit 7da024da5a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 75 additions and 12 deletions

View file

@ -242,6 +242,7 @@ var newConnection = func(
tlsConf *tls.Config, tlsConf *tls.Config,
tokenGenerator *handshake.TokenGenerator, tokenGenerator *handshake.TokenGenerator,
enable0RTT bool, enable0RTT bool,
clientAddressValidated bool,
tracer logging.ConnectionTracer, tracer logging.ConnectionTracer,
tracingID uint64, tracingID uint64,
logger utils.Logger, logger utils.Logger,
@ -288,6 +289,7 @@ var newConnection = func(
0, 0,
getMaxPacketSize(s.conn.RemoteAddr()), getMaxPacketSize(s.conn.RemoteAddr()),
s.rttStats, s.rttStats,
clientAddressValidated,
s.perspective, s.perspective,
s.tracer, s.tracer,
s.logger, s.logger,
@ -415,6 +417,7 @@ var newClientConnection = func(
initialPacketNumber, initialPacketNumber,
getMaxPacketSize(s.conn.RemoteAddr()), getMaxPacketSize(s.conn.RemoteAddr()),
s.rttStats, s.rttStats,
false, /* has no effect */
s.perspective, s.perspective,
s.tracer, s.tracer,
s.logger, s.logger,

View file

@ -107,6 +107,7 @@ var _ = Describe("Connection", func() {
nil, // tls.Config nil, // tls.Config
tokenGenerator, tokenGenerator,
false, false,
false,
tracer, tracer,
1234, 1234,
utils.DefaultLogger, utils.DefaultLogger,

View file

@ -6,16 +6,19 @@ import (
"github.com/lucas-clemente/quic-go/logging" "github.com/lucas-clemente/quic-go/logging"
) )
// NewAckHandler creates a new SentPacketHandler and a new ReceivedPacketHandler // NewAckHandler creates a new SentPacketHandler and a new ReceivedPacketHandler.
// clientAddressValidated indicates whether the address was validated beforehand by an address validation token.
// clientAddressValidated has no effect for a client.
func NewAckHandler( func NewAckHandler(
initialPacketNumber protocol.PacketNumber, initialPacketNumber protocol.PacketNumber,
initialMaxDatagramSize protocol.ByteCount, initialMaxDatagramSize protocol.ByteCount,
rttStats *utils.RTTStats, rttStats *utils.RTTStats,
clientAddressValidated bool,
pers protocol.Perspective, pers protocol.Perspective,
tracer logging.ConnectionTracer, tracer logging.ConnectionTracer,
logger utils.Logger, logger utils.Logger,
version protocol.VersionNumber, version protocol.VersionNumber,
) (SentPacketHandler, ReceivedPacketHandler) { ) (SentPacketHandler, ReceivedPacketHandler) {
sph := newSentPacketHandler(initialPacketNumber, initialMaxDatagramSize, rttStats, pers, tracer, logger) sph := newSentPacketHandler(initialPacketNumber, initialMaxDatagramSize, rttStats, clientAddressValidated, pers, tracer, logger)
return sph, newReceivedPacketHandler(sph, rttStats, logger, version) return sph, newReceivedPacketHandler(sph, rttStats, logger, version)
} }

View file

@ -101,10 +101,13 @@ var (
_ sentPacketTracker = &sentPacketHandler{} _ sentPacketTracker = &sentPacketHandler{}
) )
// clientAddressValidated indicates whether the address was validated beforehand by an address validation token.
// If the address was validated, the amplification limit doesn't apply. It has no effect for a client.
func newSentPacketHandler( func newSentPacketHandler(
initialPN protocol.PacketNumber, initialPN protocol.PacketNumber,
initialMaxDatagramSize protocol.ByteCount, initialMaxDatagramSize protocol.ByteCount,
rttStats *utils.RTTStats, rttStats *utils.RTTStats,
clientAddressValidated bool,
pers protocol.Perspective, pers protocol.Perspective,
tracer logging.ConnectionTracer, tracer logging.ConnectionTracer,
logger utils.Logger, logger utils.Logger,
@ -119,7 +122,7 @@ func newSentPacketHandler(
return &sentPacketHandler{ return &sentPacketHandler{
peerCompletedAddressValidation: pers == protocol.PerspectiveServer, peerCompletedAddressValidation: pers == protocol.PerspectiveServer,
peerAddressValidated: pers == protocol.PerspectiveClient, peerAddressValidated: pers == protocol.PerspectiveClient || clientAddressValidated,
initialPackets: newPacketNumberSpace(initialPN, false, rttStats), initialPackets: newPacketNumberSpace(initialPN, false, rttStats),
handshakePackets: newPacketNumberSpace(0, false, rttStats), handshakePackets: newPacketNumberSpace(0, false, rttStats),
appDataPackets: newPacketNumberSpace(0, true, rttStats), appDataPackets: newPacketNumberSpace(0, true, rttStats),

View file

@ -29,7 +29,7 @@ var _ = Describe("SentPacketHandler", func() {
JustBeforeEach(func() { JustBeforeEach(func() {
lostPackets = nil lostPackets = nil
rttStats := utils.NewRTTStats() rttStats := utils.NewRTTStats()
handler = newSentPacketHandler(42, protocol.InitialPacketSizeIPv4, rttStats, perspective, nil, utils.DefaultLogger) handler = newSentPacketHandler(42, protocol.InitialPacketSizeIPv4, rttStats, false, perspective, nil, utils.DefaultLogger)
streamFrame = wire.StreamFrame{ streamFrame = wire.StreamFrame{
StreamID: 5, StreamID: 5,
Data: []byte{0x13, 0x37}, Data: []byte{0x13, 0x37},
@ -944,6 +944,26 @@ var _ = Describe("SentPacketHandler", func() {
}) })
}) })
Context("amplification limit, for the server, with validated address", func() {
JustBeforeEach(func() {
rttStats := utils.NewRTTStats()
handler = newSentPacketHandler(42, protocol.InitialPacketSizeIPv4, rttStats, true, perspective, nil, utils.DefaultLogger)
})
It("do not limits the window", func() {
handler.ReceivedBytes(0)
Expect(handler.SendMode()).To(Equal(SendAny))
handler.SentPacket(&Packet{
PacketNumber: 1,
Length: 900,
EncryptionLevel: protocol.EncryptionInitial,
Frames: []Frame{{Frame: &wire.PingFrame{}}},
SendTime: time.Now(),
})
Expect(handler.SendMode()).To(Equal(SendAny))
})
})
Context("amplification limit, for the client", func() { Context("amplification limit, for the client", func() {
BeforeEach(func() { BeforeEach(func() {
perspective = protocol.PerspectiveClient perspective = protocol.PerspectiveClient

View file

@ -26,6 +26,7 @@ type Token struct {
RetrySrcConnectionID protocol.ConnectionID RetrySrcConnectionID protocol.ConnectionID
} }
// ValidateRemoteAddr validates the address, but does not check expiration
func (t *Token) ValidateRemoteAddr(addr net.Addr) bool { func (t *Token) ValidateRemoteAddr(addr net.Addr) bool {
return bytes.Equal(encodeRemoteAddr(addr), t.encodedRemoteAddr) return bytes.Equal(encodeRemoteAddr(addr), t.encodedRemoteAddr)
} }

View file

@ -89,6 +89,7 @@ type baseServer struct {
*tls.Config, *tls.Config,
*handshake.TokenGenerator, *handshake.TokenGenerator,
bool, /* enable 0-RTT */ bool, /* enable 0-RTT */
bool, /* client address validated by an address validation token */
logging.ConnectionTracer, logging.ConnectionTracer,
uint64, uint64,
utils.Logger, utils.Logger,
@ -375,6 +376,26 @@ func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* is the buffer s
return true return true
} }
// validateToken returns false if:
// - address is invalid
// - token is expired
// - token is null
func (s *baseServer) validateToken(token *handshake.Token, addr net.Addr) bool {
if token == nil {
return false
}
if !token.ValidateRemoteAddr(addr) {
return false
}
if !token.IsRetryToken && time.Since(token.SentTime) > s.config.MaxTokenAge {
return false
}
if token.IsRetryToken && time.Since(token.SentTime) > s.config.MaxRetryTokenAge {
return false
}
return true
}
func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) error { func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) error {
if len(hdr.Token) == 0 && hdr.DestConnectionID.Len() < protocol.MinConnectionIDLenInitial { if len(hdr.Token) == 0 && hdr.DestConnectionID.Len() < protocol.MinConnectionIDLenInitial {
p.buffer.Release() p.buffer.Release()
@ -399,23 +420,23 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro
token = tok token = tok
} }
} }
if token != nil {
addrIsValid := token.ValidateRemoteAddr(p.remoteAddr) clientAddrIsValid := s.validateToken(token, p.remoteAddr)
if token != nil && !clientAddrIsValid {
// For invalid and expired non-retry tokens, we don't send an INVALID_TOKEN error. // For invalid and expired non-retry tokens, we don't send an INVALID_TOKEN error.
// We just ignore them, and act as if there was no token on this packet at all. // We just ignore them, and act as if there was no token on this packet at all.
// This also means we might send a Retry later. // This also means we might send a Retry later.
if !token.IsRetryToken && (time.Since(token.SentTime) > s.config.MaxTokenAge || !addrIsValid) { if !token.IsRetryToken {
token = nil token = nil
} else if token.IsRetryToken && (time.Since(token.SentTime) > s.config.MaxRetryTokenAge || !addrIsValid) { } else {
// For Retry tokens, we send an INVALID_ERROR if // For Retry tokens, we send an INVALID_ERROR if
// * the token is too old, or // * the token is too old, or
// * the token is invalid, in case of a retry token. // * the token is invalid, in case of a retry token.
go func() { go func() {
defer p.buffer.Release() defer p.buffer.Release()
if token != nil && token.IsRetryToken { if err := s.maybeSendInvalidToken(p, hdr); err != nil {
if err := s.maybeSendInvalidToken(p, hdr); err != nil { s.logger.Debugf("Error sending INVALID_TOKEN error: %s", err)
s.logger.Debugf("Error sending INVALID_TOKEN error: %s", err)
}
} }
}() }()
return nil return nil
@ -476,6 +497,7 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro
s.tlsConf, s.tlsConf,
s.tokenGenerator, s.tokenGenerator,
s.acceptEarlyConns, s.acceptEarlyConns,
clientAddrIsValid,
tracer, tracer,
tracingID, tracingID,
s.logger, s.logger,

View file

@ -287,6 +287,7 @@ var _ = Describe("Server", func() {
_ *tls.Config, _ *tls.Config,
_ *handshake.TokenGenerator, _ *handshake.TokenGenerator,
enable0RTT bool, enable0RTT bool,
_ bool,
_ logging.ConnectionTracer, _ logging.ConnectionTracer,
_ uint64, _ uint64,
_ utils.Logger, _ utils.Logger,
@ -491,6 +492,7 @@ var _ = Describe("Server", func() {
_ *tls.Config, _ *tls.Config,
_ *handshake.TokenGenerator, _ *handshake.TokenGenerator,
enable0RTT bool, enable0RTT bool,
_ bool,
_ logging.ConnectionTracer, _ logging.ConnectionTracer,
_ uint64, _ uint64,
_ utils.Logger, _ utils.Logger,
@ -550,6 +552,7 @@ var _ = Describe("Server", func() {
_ *tls.Config, _ *tls.Config,
_ *handshake.TokenGenerator, _ *handshake.TokenGenerator,
_ bool, _ bool,
_ bool,
_ logging.ConnectionTracer, _ logging.ConnectionTracer,
_ uint64, _ uint64,
_ utils.Logger, _ utils.Logger,
@ -603,6 +606,7 @@ var _ = Describe("Server", func() {
_ *tls.Config, _ *tls.Config,
_ *handshake.TokenGenerator, _ *handshake.TokenGenerator,
_ bool, _ bool,
_ bool,
_ logging.ConnectionTracer, _ logging.ConnectionTracer,
_ uint64, _ uint64,
_ utils.Logger, _ utils.Logger,
@ -632,6 +636,7 @@ var _ = Describe("Server", func() {
_ *tls.Config, _ *tls.Config,
_ *handshake.TokenGenerator, _ *handshake.TokenGenerator,
_ bool, _ bool,
_ bool,
_ logging.ConnectionTracer, _ logging.ConnectionTracer,
_ uint64, _ uint64,
_ utils.Logger, _ utils.Logger,
@ -702,6 +707,7 @@ var _ = Describe("Server", func() {
_ *tls.Config, _ *tls.Config,
_ *handshake.TokenGenerator, _ *handshake.TokenGenerator,
_ bool, _ bool,
_ bool,
_ logging.ConnectionTracer, _ logging.ConnectionTracer,
_ uint64, _ uint64,
_ utils.Logger, _ utils.Logger,
@ -1007,6 +1013,7 @@ var _ = Describe("Server", func() {
_ *tls.Config, _ *tls.Config,
_ *handshake.TokenGenerator, _ *handshake.TokenGenerator,
_ bool, _ bool,
_ bool,
_ logging.ConnectionTracer, _ logging.ConnectionTracer,
_ uint64, _ uint64,
_ utils.Logger, _ utils.Logger,
@ -1080,6 +1087,7 @@ var _ = Describe("Server", func() {
_ *tls.Config, _ *tls.Config,
_ *handshake.TokenGenerator, _ *handshake.TokenGenerator,
enable0RTT bool, enable0RTT bool,
_ bool,
_ logging.ConnectionTracer, _ logging.ConnectionTracer,
_ uint64, _ uint64,
_ utils.Logger, _ utils.Logger,
@ -1122,6 +1130,7 @@ var _ = Describe("Server", func() {
_ *tls.Config, _ *tls.Config,
_ *handshake.TokenGenerator, _ *handshake.TokenGenerator,
_ bool, _ bool,
_ bool,
_ logging.ConnectionTracer, _ logging.ConnectionTracer,
_ uint64, _ uint64,
_ utils.Logger, _ utils.Logger,
@ -1184,6 +1193,7 @@ var _ = Describe("Server", func() {
_ *tls.Config, _ *tls.Config,
_ *handshake.TokenGenerator, _ *handshake.TokenGenerator,
_ bool, _ bool,
_ bool,
_ logging.ConnectionTracer, _ logging.ConnectionTracer,
_ uint64, _ uint64,
_ utils.Logger, _ utils.Logger,