mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-06 21:57:36 +03:00
use Transport.VerifySourceAddress to control the Retry Mechanism (#4362)
* use Transport.VerifySourceAddress to control the Retry Mechanism This can be used to rate-limit handshakes originating from unverified source addresses. Rate-limiting for handshakes can be implemented using the GetConfigForClient callback on the Config. * pass the remote address to Transport.VerifySourceAddress
This commit is contained in:
parent
497d3f58a5
commit
9971fedd42
12 changed files with 120 additions and 382 deletions
131
server_test.go
131
server_test.go
|
@ -10,6 +10,8 @@ import (
|
|||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"golang.org/x/time/rate"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/handshake"
|
||||
mocklogging "github.com/quic-go/quic-go/internal/mocks/logging"
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
|
@ -48,6 +50,7 @@ var _ = Describe("Server", func() {
|
|||
data = data[:len(data)+16]
|
||||
sealer.EncryptHeader(data[n:n+16], &data[0], data[n-4:n])
|
||||
return receivedPacket{
|
||||
rcvTime: time.Now(),
|
||||
remoteAddr: &net.UDPAddr{IP: net.IPv4(4, 5, 6, 7), Port: 456},
|
||||
data: data,
|
||||
buffer: buf,
|
||||
|
@ -259,7 +262,7 @@ var _ = Describe("Server", func() {
|
|||
})
|
||||
|
||||
It("creates a connection when the token is accepted", func() {
|
||||
serv.maxNumHandshakesUnvalidated = 0
|
||||
serv.verifySourceAddress = func(net.Addr) bool { return true }
|
||||
raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
|
||||
retryToken, err := serv.tokenGenerator.NewRetryToken(
|
||||
raddr,
|
||||
|
@ -432,7 +435,13 @@ var _ = Describe("Server", func() {
|
|||
|
||||
It("replies with a Retry packet, if a token is required", func() {
|
||||
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})
|
||||
serv.maxNumHandshakesUnvalidated = 0
|
||||
raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
|
||||
var called bool
|
||||
serv.verifySourceAddress = func(addr net.Addr) bool {
|
||||
Expect(addr).To(Equal(raddr))
|
||||
called = true
|
||||
return true
|
||||
}
|
||||
hdr := &wire.Header{
|
||||
Type: protocol.PacketTypeInitial,
|
||||
SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
|
||||
|
@ -440,7 +449,6 @@ var _ = Describe("Server", func() {
|
|||
Version: protocol.Version1,
|
||||
}
|
||||
packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
|
||||
raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
|
||||
packet.remoteAddr = raddr
|
||||
tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), nil).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, _ []logging.Frame) {
|
||||
Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry))
|
||||
|
@ -462,6 +470,7 @@ var _ = Describe("Server", func() {
|
|||
phm.EXPECT().Get(connID)
|
||||
serv.handlePacket(packet)
|
||||
Eventually(done).Should(BeClosed())
|
||||
Expect(called).To(BeTrue())
|
||||
})
|
||||
|
||||
It("creates a connection, if no token is required", func() {
|
||||
|
@ -539,8 +548,7 @@ var _ = Describe("Server", func() {
|
|||
})
|
||||
|
||||
It("drops packets if the receive queue is full", func() {
|
||||
serv.maxNumHandshakesTotal = 10000
|
||||
serv.maxNumHandshakesUnvalidated = 10000
|
||||
serv.verifySourceAddress = func(net.Addr) bool { return false }
|
||||
|
||||
phm.EXPECT().Get(gomock.Any()).AnyTimes()
|
||||
phm.EXPECT().GetStatelessResetToken(gomock.Any()).AnyTimes()
|
||||
|
@ -641,17 +649,17 @@ var _ = Describe("Server", func() {
|
|||
|
||||
It("limits the number of unvalidated handshakes", func() {
|
||||
const limit = 3
|
||||
serv.maxNumHandshakesTotal = 10000
|
||||
serv.maxNumHandshakesUnvalidated = limit
|
||||
limiter := rate.NewLimiter(0, limit)
|
||||
serv.verifySourceAddress = func(net.Addr) bool { return !limiter.Allow() }
|
||||
|
||||
phm.EXPECT().Get(gomock.Any()).AnyTimes()
|
||||
phm.EXPECT().GetStatelessResetToken(gomock.Any()).AnyTimes()
|
||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true).AnyTimes()
|
||||
|
||||
handshakeChan := make(chan struct{})
|
||||
connChan := make(chan *MockQUICConn, 1)
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2 * limit)
|
||||
wg.Add(limit)
|
||||
done := make(chan struct{})
|
||||
serv.newConn = func(
|
||||
_ sendConn,
|
||||
runner connRunner,
|
||||
|
@ -675,7 +683,7 @@ var _ = Describe("Server", func() {
|
|||
conn.EXPECT().handlePacket(gomock.Any())
|
||||
conn.EXPECT().run()
|
||||
conn.EXPECT().Context().Return(context.Background())
|
||||
conn.EXPECT().HandshakeComplete().Return(handshakeChan).Do(func() <-chan struct{} { wg.Done(); return nil })
|
||||
conn.EXPECT().HandshakeComplete().DoAndReturn(func() <-chan struct{} { wg.Done(); return done })
|
||||
return conn
|
||||
}
|
||||
|
||||
|
@ -688,7 +696,6 @@ var _ = Describe("Server", func() {
|
|||
|
||||
// Now initiate another connection attempt.
|
||||
p := getInitialWithRandomDestConnID()
|
||||
done := make(chan struct{})
|
||||
tracer.EXPECT().SentPacket(p.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, frames []logging.Frame) {
|
||||
defer GinkgoRecover()
|
||||
Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry))
|
||||
|
@ -704,34 +711,19 @@ var _ = Describe("Server", func() {
|
|||
serv.handlePacket(p)
|
||||
Eventually(done).Should(BeClosed())
|
||||
|
||||
close(handshakeChan)
|
||||
for i := 0; i < limit; i++ {
|
||||
_, err := serv.Accept(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
for i := 0; i < limit; i++ {
|
||||
conn := NewMockQUICConn(mockCtrl)
|
||||
conn.EXPECT().closeWithTransportError(gomock.Any()).MaxTimes(1) // called when the server is closed
|
||||
connChan <- conn
|
||||
serv.handlePacket(getInitialWithRandomDestConnID())
|
||||
}
|
||||
wg.Wait()
|
||||
})
|
||||
})
|
||||
|
||||
It("limits the number of total handshakes", func() {
|
||||
const limit = 3
|
||||
serv.maxNumHandshakesTotal = limit
|
||||
serv.maxNumHandshakesUnvalidated = limit // same limit, but we check that we send CONNECTION_REFUSED and not Retry
|
||||
|
||||
phm.EXPECT().Get(gomock.Any()).AnyTimes()
|
||||
phm.EXPECT().GetStatelessResetToken(gomock.Any()).AnyTimes()
|
||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true).AnyTimes()
|
||||
|
||||
handshakeChan := make(chan struct{})
|
||||
connChan := make(chan *MockQUICConn, 1)
|
||||
Context("token validation", func() {
|
||||
It("decodes the token from the token field", func() {
|
||||
serv.newConn = func(
|
||||
_ sendConn,
|
||||
runner connRunner,
|
||||
_ connRunner,
|
||||
_ protocol.ConnectionID,
|
||||
_ *protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
|
@ -748,67 +740,15 @@ var _ = Describe("Server", func() {
|
|||
_ utils.Logger,
|
||||
_ protocol.Version,
|
||||
) quicConn {
|
||||
conn := <-connChan
|
||||
conn.EXPECT().handlePacket(gomock.Any())
|
||||
conn.EXPECT().run()
|
||||
conn.EXPECT().Context().Return(context.Background())
|
||||
conn.EXPECT().HandshakeComplete().Return(handshakeChan)
|
||||
return conn
|
||||
c := NewMockQUICConn(mockCtrl)
|
||||
c.EXPECT().handlePacket(gomock.Any())
|
||||
c.EXPECT().run()
|
||||
c.EXPECT().HandshakeComplete()
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
c.EXPECT().Context().Return(ctx)
|
||||
return c
|
||||
}
|
||||
|
||||
for i := 0; i < limit; i++ {
|
||||
conn := NewMockQUICConn(mockCtrl)
|
||||
connChan <- conn
|
||||
serv.handlePacket(getInitialWithRandomDestConnID())
|
||||
}
|
||||
|
||||
p := getInitialWithRandomDestConnID()
|
||||
done := make(chan struct{})
|
||||
tracer.EXPECT().SentPacket(p.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, frames []logging.Frame) {
|
||||
defer GinkgoRecover()
|
||||
hdr, _, _, err := wire.ParsePacket(p.data)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(replyHdr.Type).To(Equal(protocol.PacketTypeInitial))
|
||||
Expect(replyHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID))
|
||||
Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID))
|
||||
Expect(frames).To(HaveLen(1))
|
||||
Expect(frames[0]).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{}))
|
||||
ccf := frames[0].(*logging.ConnectionCloseFrame)
|
||||
Expect(ccf.IsApplicationError).To(BeFalse())
|
||||
Expect(ccf.ErrorCode).To(BeEquivalentTo(qerr.ConnectionRefused))
|
||||
})
|
||||
conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
|
||||
defer GinkgoRecover()
|
||||
defer close(done)
|
||||
hdr, _, _, err := wire.ParsePacket(p.data)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
checkConnectionCloseError(b, hdr, qerr.ConnectionRefused)
|
||||
return len(b), nil
|
||||
})
|
||||
serv.handlePacket(p)
|
||||
Eventually(done).Should(BeClosed())
|
||||
|
||||
close(handshakeChan)
|
||||
for i := 0; i < limit; i++ {
|
||||
_, err := serv.Accept(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
// make sure we can enqueue and accept more connections after that
|
||||
for i := 0; i < limit; i++ {
|
||||
conn := NewMockQUICConn(mockCtrl)
|
||||
conn.EXPECT().closeWithTransportError(gomock.Any()).MaxTimes(1) // called when the server is closed
|
||||
connChan <- conn
|
||||
serv.handlePacket(getInitialWithRandomDestConnID())
|
||||
}
|
||||
for i := 0; i < limit; i++ {
|
||||
_, err := serv.Accept(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
Context("token validation", func() {
|
||||
It("decodes the token from the token field", func() {
|
||||
raddr := &net.UDPAddr{IP: net.IPv4(192, 168, 13, 37), Port: 1337}
|
||||
token, err := serv.tokenGenerator.NewRetryToken(raddr, protocol.ConnectionID{}, protocol.ConnectionID{})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
@ -834,7 +774,7 @@ var _ = Describe("Server", func() {
|
|||
})
|
||||
|
||||
It("sends an INVALID_TOKEN error, if an invalid retry token is received", func() {
|
||||
serv.maxNumHandshakesUnvalidated = 0
|
||||
serv.verifySourceAddress = func(net.Addr) bool { return true }
|
||||
token, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, protocol.ConnectionID{}, protocol.ConnectionID{})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
hdr := &wire.Header{
|
||||
|
@ -870,7 +810,7 @@ var _ = Describe("Server", func() {
|
|||
})
|
||||
|
||||
It("sends an INVALID_TOKEN error, if an expired retry token is received", func() {
|
||||
serv.maxNumHandshakesUnvalidated = 0
|
||||
serv.verifySourceAddress = func(net.Addr) bool { return true }
|
||||
serv.config.HandshakeIdleTimeout = time.Millisecond / 2 // the maximum retry token age is equivalent to the handshake timeout
|
||||
Expect(serv.config.maxRetryTokenAge()).To(Equal(time.Millisecond))
|
||||
raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
|
||||
|
@ -908,7 +848,7 @@ var _ = Describe("Server", func() {
|
|||
})
|
||||
|
||||
It("doesn't send an INVALID_TOKEN error, if an invalid non-retry token is received", func() {
|
||||
serv.maxNumHandshakesUnvalidated = 0
|
||||
serv.verifySourceAddress = func(net.Addr) bool { return true }
|
||||
token, err := serv.tokenGenerator.NewToken(&net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
hdr := &wire.Header{
|
||||
|
@ -937,7 +877,7 @@ var _ = Describe("Server", func() {
|
|||
})
|
||||
|
||||
It("sends an INVALID_TOKEN error, if an expired non-retry token is received", func() {
|
||||
serv.maxNumHandshakesUnvalidated = 0
|
||||
serv.verifySourceAddress = func(net.Addr) bool { return true }
|
||||
serv.maxTokenAge = time.Millisecond
|
||||
raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
|
||||
token, err := serv.tokenGenerator.NewToken(raddr)
|
||||
|
@ -966,7 +906,6 @@ var _ = Describe("Server", func() {
|
|||
})
|
||||
|
||||
It("doesn't send an INVALID_TOKEN error, if the packet is corrupted", func() {
|
||||
serv.maxNumHandshakesUnvalidated = 0
|
||||
token, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, protocol.ConnectionID{}, protocol.ConnectionID{})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
hdr := &wire.Header{
|
||||
|
@ -1011,7 +950,7 @@ var _ = Describe("Server", func() {
|
|||
}
|
||||
})
|
||||
|
||||
It("closes connection that are still handshaking after Close", func() {
|
||||
PIt("closes connection that are still handshaking after Close", func() {
|
||||
serv.Close()
|
||||
|
||||
destroyed := make(chan struct{})
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue