mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-04 20:57:36 +03:00
Disable anti-amplification limit by address validation token (#3326)
This commit is contained in:
parent
8c0c481da1
commit
7da024da5a
8 changed files with 75 additions and 12 deletions
|
@ -242,6 +242,7 @@ var newConnection = func(
|
||||||
tlsConf *tls.Config,
|
tlsConf *tls.Config,
|
||||||
tokenGenerator *handshake.TokenGenerator,
|
tokenGenerator *handshake.TokenGenerator,
|
||||||
enable0RTT bool,
|
enable0RTT bool,
|
||||||
|
clientAddressValidated bool,
|
||||||
tracer logging.ConnectionTracer,
|
tracer logging.ConnectionTracer,
|
||||||
tracingID uint64,
|
tracingID uint64,
|
||||||
logger utils.Logger,
|
logger utils.Logger,
|
||||||
|
@ -288,6 +289,7 @@ var newConnection = func(
|
||||||
0,
|
0,
|
||||||
getMaxPacketSize(s.conn.RemoteAddr()),
|
getMaxPacketSize(s.conn.RemoteAddr()),
|
||||||
s.rttStats,
|
s.rttStats,
|
||||||
|
clientAddressValidated,
|
||||||
s.perspective,
|
s.perspective,
|
||||||
s.tracer,
|
s.tracer,
|
||||||
s.logger,
|
s.logger,
|
||||||
|
@ -415,6 +417,7 @@ var newClientConnection = func(
|
||||||
initialPacketNumber,
|
initialPacketNumber,
|
||||||
getMaxPacketSize(s.conn.RemoteAddr()),
|
getMaxPacketSize(s.conn.RemoteAddr()),
|
||||||
s.rttStats,
|
s.rttStats,
|
||||||
|
false, /* has no effect */
|
||||||
s.perspective,
|
s.perspective,
|
||||||
s.tracer,
|
s.tracer,
|
||||||
s.logger,
|
s.logger,
|
||||||
|
|
|
@ -107,6 +107,7 @@ var _ = Describe("Connection", func() {
|
||||||
nil, // tls.Config
|
nil, // tls.Config
|
||||||
tokenGenerator,
|
tokenGenerator,
|
||||||
false,
|
false,
|
||||||
|
false,
|
||||||
tracer,
|
tracer,
|
||||||
1234,
|
1234,
|
||||||
utils.DefaultLogger,
|
utils.DefaultLogger,
|
||||||
|
|
|
@ -6,16 +6,19 @@ import (
|
||||||
"github.com/lucas-clemente/quic-go/logging"
|
"github.com/lucas-clemente/quic-go/logging"
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewAckHandler creates a new SentPacketHandler and a new ReceivedPacketHandler
|
// NewAckHandler creates a new SentPacketHandler and a new ReceivedPacketHandler.
|
||||||
|
// clientAddressValidated indicates whether the address was validated beforehand by an address validation token.
|
||||||
|
// clientAddressValidated has no effect for a client.
|
||||||
func NewAckHandler(
|
func NewAckHandler(
|
||||||
initialPacketNumber protocol.PacketNumber,
|
initialPacketNumber protocol.PacketNumber,
|
||||||
initialMaxDatagramSize protocol.ByteCount,
|
initialMaxDatagramSize protocol.ByteCount,
|
||||||
rttStats *utils.RTTStats,
|
rttStats *utils.RTTStats,
|
||||||
|
clientAddressValidated bool,
|
||||||
pers protocol.Perspective,
|
pers protocol.Perspective,
|
||||||
tracer logging.ConnectionTracer,
|
tracer logging.ConnectionTracer,
|
||||||
logger utils.Logger,
|
logger utils.Logger,
|
||||||
version protocol.VersionNumber,
|
version protocol.VersionNumber,
|
||||||
) (SentPacketHandler, ReceivedPacketHandler) {
|
) (SentPacketHandler, ReceivedPacketHandler) {
|
||||||
sph := newSentPacketHandler(initialPacketNumber, initialMaxDatagramSize, rttStats, pers, tracer, logger)
|
sph := newSentPacketHandler(initialPacketNumber, initialMaxDatagramSize, rttStats, clientAddressValidated, pers, tracer, logger)
|
||||||
return sph, newReceivedPacketHandler(sph, rttStats, logger, version)
|
return sph, newReceivedPacketHandler(sph, rttStats, logger, version)
|
||||||
}
|
}
|
||||||
|
|
|
@ -101,10 +101,13 @@ var (
|
||||||
_ sentPacketTracker = &sentPacketHandler{}
|
_ sentPacketTracker = &sentPacketHandler{}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// clientAddressValidated indicates whether the address was validated beforehand by an address validation token.
|
||||||
|
// If the address was validated, the amplification limit doesn't apply. It has no effect for a client.
|
||||||
func newSentPacketHandler(
|
func newSentPacketHandler(
|
||||||
initialPN protocol.PacketNumber,
|
initialPN protocol.PacketNumber,
|
||||||
initialMaxDatagramSize protocol.ByteCount,
|
initialMaxDatagramSize protocol.ByteCount,
|
||||||
rttStats *utils.RTTStats,
|
rttStats *utils.RTTStats,
|
||||||
|
clientAddressValidated bool,
|
||||||
pers protocol.Perspective,
|
pers protocol.Perspective,
|
||||||
tracer logging.ConnectionTracer,
|
tracer logging.ConnectionTracer,
|
||||||
logger utils.Logger,
|
logger utils.Logger,
|
||||||
|
@ -119,7 +122,7 @@ func newSentPacketHandler(
|
||||||
|
|
||||||
return &sentPacketHandler{
|
return &sentPacketHandler{
|
||||||
peerCompletedAddressValidation: pers == protocol.PerspectiveServer,
|
peerCompletedAddressValidation: pers == protocol.PerspectiveServer,
|
||||||
peerAddressValidated: pers == protocol.PerspectiveClient,
|
peerAddressValidated: pers == protocol.PerspectiveClient || clientAddressValidated,
|
||||||
initialPackets: newPacketNumberSpace(initialPN, false, rttStats),
|
initialPackets: newPacketNumberSpace(initialPN, false, rttStats),
|
||||||
handshakePackets: newPacketNumberSpace(0, false, rttStats),
|
handshakePackets: newPacketNumberSpace(0, false, rttStats),
|
||||||
appDataPackets: newPacketNumberSpace(0, true, rttStats),
|
appDataPackets: newPacketNumberSpace(0, true, rttStats),
|
||||||
|
|
|
@ -29,7 +29,7 @@ var _ = Describe("SentPacketHandler", func() {
|
||||||
JustBeforeEach(func() {
|
JustBeforeEach(func() {
|
||||||
lostPackets = nil
|
lostPackets = nil
|
||||||
rttStats := utils.NewRTTStats()
|
rttStats := utils.NewRTTStats()
|
||||||
handler = newSentPacketHandler(42, protocol.InitialPacketSizeIPv4, rttStats, perspective, nil, utils.DefaultLogger)
|
handler = newSentPacketHandler(42, protocol.InitialPacketSizeIPv4, rttStats, false, perspective, nil, utils.DefaultLogger)
|
||||||
streamFrame = wire.StreamFrame{
|
streamFrame = wire.StreamFrame{
|
||||||
StreamID: 5,
|
StreamID: 5,
|
||||||
Data: []byte{0x13, 0x37},
|
Data: []byte{0x13, 0x37},
|
||||||
|
@ -944,6 +944,26 @@ var _ = Describe("SentPacketHandler", func() {
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
Context("amplification limit, for the server, with validated address", func() {
|
||||||
|
JustBeforeEach(func() {
|
||||||
|
rttStats := utils.NewRTTStats()
|
||||||
|
handler = newSentPacketHandler(42, protocol.InitialPacketSizeIPv4, rttStats, true, perspective, nil, utils.DefaultLogger)
|
||||||
|
})
|
||||||
|
|
||||||
|
It("do not limits the window", func() {
|
||||||
|
handler.ReceivedBytes(0)
|
||||||
|
Expect(handler.SendMode()).To(Equal(SendAny))
|
||||||
|
handler.SentPacket(&Packet{
|
||||||
|
PacketNumber: 1,
|
||||||
|
Length: 900,
|
||||||
|
EncryptionLevel: protocol.EncryptionInitial,
|
||||||
|
Frames: []Frame{{Frame: &wire.PingFrame{}}},
|
||||||
|
SendTime: time.Now(),
|
||||||
|
})
|
||||||
|
Expect(handler.SendMode()).To(Equal(SendAny))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
Context("amplification limit, for the client", func() {
|
Context("amplification limit, for the client", func() {
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
perspective = protocol.PerspectiveClient
|
perspective = protocol.PerspectiveClient
|
||||||
|
|
|
@ -26,6 +26,7 @@ type Token struct {
|
||||||
RetrySrcConnectionID protocol.ConnectionID
|
RetrySrcConnectionID protocol.ConnectionID
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ValidateRemoteAddr validates the address, but does not check expiration
|
||||||
func (t *Token) ValidateRemoteAddr(addr net.Addr) bool {
|
func (t *Token) ValidateRemoteAddr(addr net.Addr) bool {
|
||||||
return bytes.Equal(encodeRemoteAddr(addr), t.encodedRemoteAddr)
|
return bytes.Equal(encodeRemoteAddr(addr), t.encodedRemoteAddr)
|
||||||
}
|
}
|
||||||
|
|
38
server.go
38
server.go
|
@ -89,6 +89,7 @@ type baseServer struct {
|
||||||
*tls.Config,
|
*tls.Config,
|
||||||
*handshake.TokenGenerator,
|
*handshake.TokenGenerator,
|
||||||
bool, /* enable 0-RTT */
|
bool, /* enable 0-RTT */
|
||||||
|
bool, /* client address validated by an address validation token */
|
||||||
logging.ConnectionTracer,
|
logging.ConnectionTracer,
|
||||||
uint64,
|
uint64,
|
||||||
utils.Logger,
|
utils.Logger,
|
||||||
|
@ -375,6 +376,26 @@ func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* is the buffer s
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// validateToken returns false if:
|
||||||
|
// - address is invalid
|
||||||
|
// - token is expired
|
||||||
|
// - token is null
|
||||||
|
func (s *baseServer) validateToken(token *handshake.Token, addr net.Addr) bool {
|
||||||
|
if token == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if !token.ValidateRemoteAddr(addr) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if !token.IsRetryToken && time.Since(token.SentTime) > s.config.MaxTokenAge {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if token.IsRetryToken && time.Since(token.SentTime) > s.config.MaxRetryTokenAge {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) error {
|
func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) error {
|
||||||
if len(hdr.Token) == 0 && hdr.DestConnectionID.Len() < protocol.MinConnectionIDLenInitial {
|
if len(hdr.Token) == 0 && hdr.DestConnectionID.Len() < protocol.MinConnectionIDLenInitial {
|
||||||
p.buffer.Release()
|
p.buffer.Release()
|
||||||
|
@ -399,23 +420,23 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro
|
||||||
token = tok
|
token = tok
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if token != nil {
|
|
||||||
addrIsValid := token.ValidateRemoteAddr(p.remoteAddr)
|
clientAddrIsValid := s.validateToken(token, p.remoteAddr)
|
||||||
|
|
||||||
|
if token != nil && !clientAddrIsValid {
|
||||||
// For invalid and expired non-retry tokens, we don't send an INVALID_TOKEN error.
|
// For invalid and expired non-retry tokens, we don't send an INVALID_TOKEN error.
|
||||||
// We just ignore them, and act as if there was no token on this packet at all.
|
// We just ignore them, and act as if there was no token on this packet at all.
|
||||||
// This also means we might send a Retry later.
|
// This also means we might send a Retry later.
|
||||||
if !token.IsRetryToken && (time.Since(token.SentTime) > s.config.MaxTokenAge || !addrIsValid) {
|
if !token.IsRetryToken {
|
||||||
token = nil
|
token = nil
|
||||||
} else if token.IsRetryToken && (time.Since(token.SentTime) > s.config.MaxRetryTokenAge || !addrIsValid) {
|
} else {
|
||||||
// For Retry tokens, we send an INVALID_ERROR if
|
// For Retry tokens, we send an INVALID_ERROR if
|
||||||
// * the token is too old, or
|
// * the token is too old, or
|
||||||
// * the token is invalid, in case of a retry token.
|
// * the token is invalid, in case of a retry token.
|
||||||
go func() {
|
go func() {
|
||||||
defer p.buffer.Release()
|
defer p.buffer.Release()
|
||||||
if token != nil && token.IsRetryToken {
|
if err := s.maybeSendInvalidToken(p, hdr); err != nil {
|
||||||
if err := s.maybeSendInvalidToken(p, hdr); err != nil {
|
s.logger.Debugf("Error sending INVALID_TOKEN error: %s", err)
|
||||||
s.logger.Debugf("Error sending INVALID_TOKEN error: %s", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
return nil
|
return nil
|
||||||
|
@ -476,6 +497,7 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro
|
||||||
s.tlsConf,
|
s.tlsConf,
|
||||||
s.tokenGenerator,
|
s.tokenGenerator,
|
||||||
s.acceptEarlyConns,
|
s.acceptEarlyConns,
|
||||||
|
clientAddrIsValid,
|
||||||
tracer,
|
tracer,
|
||||||
tracingID,
|
tracingID,
|
||||||
s.logger,
|
s.logger,
|
||||||
|
|
|
@ -287,6 +287,7 @@ var _ = Describe("Server", func() {
|
||||||
_ *tls.Config,
|
_ *tls.Config,
|
||||||
_ *handshake.TokenGenerator,
|
_ *handshake.TokenGenerator,
|
||||||
enable0RTT bool,
|
enable0RTT bool,
|
||||||
|
_ bool,
|
||||||
_ logging.ConnectionTracer,
|
_ logging.ConnectionTracer,
|
||||||
_ uint64,
|
_ uint64,
|
||||||
_ utils.Logger,
|
_ utils.Logger,
|
||||||
|
@ -491,6 +492,7 @@ var _ = Describe("Server", func() {
|
||||||
_ *tls.Config,
|
_ *tls.Config,
|
||||||
_ *handshake.TokenGenerator,
|
_ *handshake.TokenGenerator,
|
||||||
enable0RTT bool,
|
enable0RTT bool,
|
||||||
|
_ bool,
|
||||||
_ logging.ConnectionTracer,
|
_ logging.ConnectionTracer,
|
||||||
_ uint64,
|
_ uint64,
|
||||||
_ utils.Logger,
|
_ utils.Logger,
|
||||||
|
@ -550,6 +552,7 @@ var _ = Describe("Server", func() {
|
||||||
_ *tls.Config,
|
_ *tls.Config,
|
||||||
_ *handshake.TokenGenerator,
|
_ *handshake.TokenGenerator,
|
||||||
_ bool,
|
_ bool,
|
||||||
|
_ bool,
|
||||||
_ logging.ConnectionTracer,
|
_ logging.ConnectionTracer,
|
||||||
_ uint64,
|
_ uint64,
|
||||||
_ utils.Logger,
|
_ utils.Logger,
|
||||||
|
@ -603,6 +606,7 @@ var _ = Describe("Server", func() {
|
||||||
_ *tls.Config,
|
_ *tls.Config,
|
||||||
_ *handshake.TokenGenerator,
|
_ *handshake.TokenGenerator,
|
||||||
_ bool,
|
_ bool,
|
||||||
|
_ bool,
|
||||||
_ logging.ConnectionTracer,
|
_ logging.ConnectionTracer,
|
||||||
_ uint64,
|
_ uint64,
|
||||||
_ utils.Logger,
|
_ utils.Logger,
|
||||||
|
@ -632,6 +636,7 @@ var _ = Describe("Server", func() {
|
||||||
_ *tls.Config,
|
_ *tls.Config,
|
||||||
_ *handshake.TokenGenerator,
|
_ *handshake.TokenGenerator,
|
||||||
_ bool,
|
_ bool,
|
||||||
|
_ bool,
|
||||||
_ logging.ConnectionTracer,
|
_ logging.ConnectionTracer,
|
||||||
_ uint64,
|
_ uint64,
|
||||||
_ utils.Logger,
|
_ utils.Logger,
|
||||||
|
@ -702,6 +707,7 @@ var _ = Describe("Server", func() {
|
||||||
_ *tls.Config,
|
_ *tls.Config,
|
||||||
_ *handshake.TokenGenerator,
|
_ *handshake.TokenGenerator,
|
||||||
_ bool,
|
_ bool,
|
||||||
|
_ bool,
|
||||||
_ logging.ConnectionTracer,
|
_ logging.ConnectionTracer,
|
||||||
_ uint64,
|
_ uint64,
|
||||||
_ utils.Logger,
|
_ utils.Logger,
|
||||||
|
@ -1007,6 +1013,7 @@ var _ = Describe("Server", func() {
|
||||||
_ *tls.Config,
|
_ *tls.Config,
|
||||||
_ *handshake.TokenGenerator,
|
_ *handshake.TokenGenerator,
|
||||||
_ bool,
|
_ bool,
|
||||||
|
_ bool,
|
||||||
_ logging.ConnectionTracer,
|
_ logging.ConnectionTracer,
|
||||||
_ uint64,
|
_ uint64,
|
||||||
_ utils.Logger,
|
_ utils.Logger,
|
||||||
|
@ -1080,6 +1087,7 @@ var _ = Describe("Server", func() {
|
||||||
_ *tls.Config,
|
_ *tls.Config,
|
||||||
_ *handshake.TokenGenerator,
|
_ *handshake.TokenGenerator,
|
||||||
enable0RTT bool,
|
enable0RTT bool,
|
||||||
|
_ bool,
|
||||||
_ logging.ConnectionTracer,
|
_ logging.ConnectionTracer,
|
||||||
_ uint64,
|
_ uint64,
|
||||||
_ utils.Logger,
|
_ utils.Logger,
|
||||||
|
@ -1122,6 +1130,7 @@ var _ = Describe("Server", func() {
|
||||||
_ *tls.Config,
|
_ *tls.Config,
|
||||||
_ *handshake.TokenGenerator,
|
_ *handshake.TokenGenerator,
|
||||||
_ bool,
|
_ bool,
|
||||||
|
_ bool,
|
||||||
_ logging.ConnectionTracer,
|
_ logging.ConnectionTracer,
|
||||||
_ uint64,
|
_ uint64,
|
||||||
_ utils.Logger,
|
_ utils.Logger,
|
||||||
|
@ -1184,6 +1193,7 @@ var _ = Describe("Server", func() {
|
||||||
_ *tls.Config,
|
_ *tls.Config,
|
||||||
_ *handshake.TokenGenerator,
|
_ *handshake.TokenGenerator,
|
||||||
_ bool,
|
_ bool,
|
||||||
|
_ bool,
|
||||||
_ logging.ConnectionTracer,
|
_ logging.ConnectionTracer,
|
||||||
_ uint64,
|
_ uint64,
|
||||||
_ utils.Logger,
|
_ utils.Logger,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue