mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-04 12:47:36 +03:00
use Transport.VerifySourceAddress to control the Retry Mechanism (#4362)
* use Transport.VerifySourceAddress to control the Retry Mechanism This can be used to rate-limit handshakes originating from unverified source addresses. Rate-limiting for handshakes can be implemented using the GetConfigForClient callback on the Config. * pass the remote address to Transport.VerifySourceAddress
This commit is contained in:
parent
497d3f58a5
commit
9971fedd42
12 changed files with 120 additions and 382 deletions
1
go.mod
1
go.mod
|
@ -13,6 +13,7 @@ require (
|
||||||
golang.org/x/net v0.10.0
|
golang.org/x/net v0.10.0
|
||||||
golang.org/x/sync v0.2.0
|
golang.org/x/sync v0.2.0
|
||||||
golang.org/x/sys v0.8.0
|
golang.org/x/sys v0.8.0
|
||||||
|
golang.org/x/time v0.5.0
|
||||||
)
|
)
|
||||||
|
|
||||||
require (
|
require (
|
||||||
|
|
2
go.sum
2
go.sum
|
@ -175,6 +175,8 @@ golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE=
|
||||||
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
|
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
|
||||||
golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||||
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||||
|
golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
|
||||||
|
golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
|
||||||
golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||||
golang.org/x/tools v0.0.0-20181030000716-a0a13e073c7b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
golang.org/x/tools v0.0.0-20181030000716-a0a13e073c7b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||||
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||||
|
|
|
@ -41,6 +41,8 @@ golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU=
|
||||||
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE=
|
golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE=
|
||||||
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
|
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
|
||||||
|
golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
|
||||||
|
golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
|
||||||
golang.org/x/tools v0.9.1 h1:8WMNJAz3zrtPmnYC7ISf5dEn3MT0gY7jBJfw27yrrLo=
|
golang.org/x/tools v0.9.1 h1:8WMNJAz3zrtPmnYC7ISf5dEn3MT0gY7jBJfw27yrrLo=
|
||||||
golang.org/x/tools v0.9.1/go.mod h1:owI94Op576fPu3cIGQeHs3joujW/2Oc6MtlxbF5dfNc=
|
golang.org/x/tools v0.9.1/go.mod h1:owI94Op576fPu3cIGQeHs3joujW/2Oc6MtlxbF5dfNc=
|
||||||
google.golang.org/protobuf v1.28.0 h1:w43yiav+6bVFTBQFZX0r7ipe9JQ1QsbMgHwbBziscLw=
|
google.golang.org/protobuf v1.28.0 h1:w43yiav+6bVFTBQFZX0r7ipe9JQ1QsbMgHwbBziscLw=
|
||||||
|
|
|
@ -11,11 +11,10 @@ import (
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/quic-go/quic-go/quicvarint"
|
|
||||||
|
|
||||||
"github.com/quic-go/quic-go"
|
"github.com/quic-go/quic-go"
|
||||||
quicproxy "github.com/quic-go/quic-go/integrationtests/tools/proxy"
|
quicproxy "github.com/quic-go/quic-go/integrationtests/tools/proxy"
|
||||||
"github.com/quic-go/quic-go/internal/wire"
|
"github.com/quic-go/quic-go/internal/wire"
|
||||||
|
"github.com/quic-go/quic-go/quicvarint"
|
||||||
|
|
||||||
. "github.com/onsi/ginkgo/v2"
|
. "github.com/onsi/ginkgo/v2"
|
||||||
. "github.com/onsi/gomega"
|
. "github.com/onsi/gomega"
|
||||||
|
@ -50,7 +49,7 @@ var _ = Describe("Handshake drop tests", func() {
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
tr := &quic.Transport{Conn: conn}
|
tr := &quic.Transport{Conn: conn}
|
||||||
if doRetry {
|
if doRetry {
|
||||||
tr.MaxUnvalidatedHandshakes = -1
|
tr.VerifySourceAddress = func(net.Addr) bool { return true }
|
||||||
}
|
}
|
||||||
ln, err = tr.Listen(tlsConf, conf)
|
ln, err = tr.Listen(tlsConf, conf)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
|
@ -54,7 +54,7 @@ var _ = Describe("Handshake RTT tests", func() {
|
||||||
|
|
||||||
// 1 RTT for verifying the source address
|
// 1 RTT for verifying the source address
|
||||||
// 1 RTT for the TLS handshake
|
// 1 RTT for the TLS handshake
|
||||||
It("is forward-secure after 2 RTTs", func() {
|
It("is forward-secure after 2 RTTs with Retry", func() {
|
||||||
laddr, err := net.ResolveUDPAddr("udp", "localhost:0")
|
laddr, err := net.ResolveUDPAddr("udp", "localhost:0")
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
udpConn, err := net.ListenUDP("udp", laddr)
|
udpConn, err := net.ListenUDP("udp", laddr)
|
||||||
|
@ -62,7 +62,7 @@ var _ = Describe("Handshake RTT tests", func() {
|
||||||
defer udpConn.Close()
|
defer udpConn.Close()
|
||||||
tr := &quic.Transport{
|
tr := &quic.Transport{
|
||||||
Conn: udpConn,
|
Conn: udpConn,
|
||||||
MaxUnvalidatedHandshakes: -1,
|
VerifySourceAddress: func(net.Addr) bool { return true },
|
||||||
}
|
}
|
||||||
addTracer(tr)
|
addTracer(tr)
|
||||||
defer tr.Close()
|
defer tr.Close()
|
||||||
|
|
|
@ -7,7 +7,6 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"sync/atomic"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/quic-go/quic-go"
|
"github.com/quic-go/quic-go"
|
||||||
|
@ -15,7 +14,6 @@ import (
|
||||||
"github.com/quic-go/quic-go/internal/protocol"
|
"github.com/quic-go/quic-go/internal/protocol"
|
||||||
"github.com/quic-go/quic-go/internal/qerr"
|
"github.com/quic-go/quic-go/internal/qerr"
|
||||||
"github.com/quic-go/quic-go/internal/qtls"
|
"github.com/quic-go/quic-go/internal/qtls"
|
||||||
"github.com/quic-go/quic-go/logging"
|
|
||||||
|
|
||||||
. "github.com/onsi/ginkgo/v2"
|
. "github.com/onsi/ginkgo/v2"
|
||||||
. "github.com/onsi/gomega"
|
. "github.com/onsi/gomega"
|
||||||
|
@ -464,147 +462,6 @@ var _ = Describe("Handshake tests", func() {
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
Context("limiting handshakes", func() {
|
|
||||||
var conn *net.UDPConn
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
addr, err := net.ResolveUDPAddr("udp", "localhost:0")
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
conn, err = net.ListenUDP("udp", addr)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
})
|
|
||||||
|
|
||||||
AfterEach(func() { conn.Close() })
|
|
||||||
|
|
||||||
It("sends a Retry when the number of handshakes reaches MaxUnvalidatedHandshakes", func() {
|
|
||||||
const limit = 3
|
|
||||||
tr := &quic.Transport{
|
|
||||||
Conn: conn,
|
|
||||||
MaxUnvalidatedHandshakes: limit,
|
|
||||||
}
|
|
||||||
addTracer(tr)
|
|
||||||
defer tr.Close()
|
|
||||||
|
|
||||||
// Block all handshakes.
|
|
||||||
handshakes := make(chan struct{})
|
|
||||||
var tlsConf tls.Config
|
|
||||||
tlsConf.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) {
|
|
||||||
handshakes <- struct{}{}
|
|
||||||
return getTLSConfig(), nil
|
|
||||||
}
|
|
||||||
ln, err := tr.Listen(&tlsConf, getQuicConfig(nil))
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
defer ln.Close()
|
|
||||||
|
|
||||||
const additional = 2
|
|
||||||
results := make([]struct{ retry, closed atomic.Bool }, limit+additional)
|
|
||||||
// Dial the server from multiple clients. All handshakes will get blocked on the handshakes channel.
|
|
||||||
// Since we're dialing limit+2 times, we expect limit handshakes to go through with a Retry, and
|
|
||||||
// exactly 2 to experience a Retry.
|
|
||||||
for i := 0; i < limit+additional; i++ {
|
|
||||||
go func(index int) {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
quicConf := getQuicConfig(&quic.Config{
|
|
||||||
Tracer: func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer {
|
|
||||||
return &logging.ConnectionTracer{
|
|
||||||
ReceivedRetry: func(*logging.Header) { results[index].retry.Store(true) },
|
|
||||||
ClosedConnection: func(error) { results[index].closed.Store(true) },
|
|
||||||
}
|
|
||||||
},
|
|
||||||
})
|
|
||||||
conn, err := quic.DialAddr(context.Background(), ln.Addr().String(), getTLSClientConfig(), quicConf)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
conn.CloseWithError(0, "")
|
|
||||||
}(i)
|
|
||||||
}
|
|
||||||
numRetries := func() (n int) {
|
|
||||||
for i := 0; i < limit+additional; i++ {
|
|
||||||
if results[i].retry.Load() {
|
|
||||||
n++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
numClosed := func() (n int) {
|
|
||||||
for i := 0; i < limit+2; i++ {
|
|
||||||
if results[i].closed.Load() {
|
|
||||||
n++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
Eventually(numRetries).Should(Equal(additional))
|
|
||||||
// allow the handshakes to complete
|
|
||||||
for i := 0; i < limit+additional; i++ {
|
|
||||||
Eventually(handshakes).Should(Receive())
|
|
||||||
}
|
|
||||||
Eventually(numClosed).Should(Equal(limit + additional))
|
|
||||||
Expect(numRetries()).To(Equal(additional)) // just to be on the safe side
|
|
||||||
})
|
|
||||||
|
|
||||||
It("rejects connections when the number of handshakes reaches MaxHandshakes", func() {
|
|
||||||
const limit = 3
|
|
||||||
tr := &quic.Transport{
|
|
||||||
Conn: conn,
|
|
||||||
MaxHandshakes: limit,
|
|
||||||
}
|
|
||||||
addTracer(tr)
|
|
||||||
defer tr.Close()
|
|
||||||
|
|
||||||
// Block all handshakes.
|
|
||||||
handshakes := make(chan struct{})
|
|
||||||
var tlsConf tls.Config
|
|
||||||
tlsConf.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) {
|
|
||||||
handshakes <- struct{}{}
|
|
||||||
return getTLSConfig(), nil
|
|
||||||
}
|
|
||||||
ln, err := tr.Listen(&tlsConf, getQuicConfig(nil))
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
defer ln.Close()
|
|
||||||
|
|
||||||
const additional = 2
|
|
||||||
// Dial the server from multiple clients. All handshakes will get blocked on the handshakes channel.
|
|
||||||
// Since we're dialing limit+2 times, we expect limit handshakes to go through with a Retry, and
|
|
||||||
// exactly 2 to experience a Retry.
|
|
||||||
var numSuccessful, numFailed atomic.Int32
|
|
||||||
for i := 0; i < limit+additional; i++ {
|
|
||||||
go func() {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
quicConf := getQuicConfig(&quic.Config{
|
|
||||||
Tracer: func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer {
|
|
||||||
return &logging.ConnectionTracer{
|
|
||||||
ReceivedRetry: func(*logging.Header) { Fail("didn't expect any Retry") },
|
|
||||||
}
|
|
||||||
},
|
|
||||||
})
|
|
||||||
conn, err := quic.DialAddr(context.Background(), ln.Addr().String(), getTLSClientConfig(), quicConf)
|
|
||||||
if err != nil {
|
|
||||||
var transportErr *quic.TransportError
|
|
||||||
if !errors.As(err, &transportErr) || transportErr.ErrorCode != qerr.ConnectionRefused {
|
|
||||||
Fail(fmt.Sprintf("expected CONNECTION_REFUSED error, got %v", err))
|
|
||||||
}
|
|
||||||
numFailed.Add(1)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
numSuccessful.Add(1)
|
|
||||||
conn.CloseWithError(0, "")
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
Eventually(func() int { return int(numFailed.Load()) }).Should(Equal(additional))
|
|
||||||
// allow the handshakes to complete
|
|
||||||
for i := 0; i < limit; i++ {
|
|
||||||
Eventually(handshakes).Should(Receive())
|
|
||||||
}
|
|
||||||
Eventually(func() int { return int(numSuccessful.Load()) }).Should(Equal(limit))
|
|
||||||
|
|
||||||
// make sure that the server is reachable again after these handshakes have completed
|
|
||||||
go func() { <-handshakes }() // allow this handshake to complete immediately
|
|
||||||
conn, err := quic.DialAddr(context.Background(), ln.Addr().String(), getTLSClientConfig(), getQuicConfig(nil))
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
conn.CloseWithError(0, "")
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("ALPN", func() {
|
Context("ALPN", func() {
|
||||||
It("negotiates an application protocol", func() {
|
It("negotiates an application protocol", func() {
|
||||||
ln, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig)
|
ln, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig)
|
||||||
|
@ -719,7 +576,7 @@ var _ = Describe("Handshake tests", func() {
|
||||||
defer udpConn.Close()
|
defer udpConn.Close()
|
||||||
tr := &quic.Transport{
|
tr := &quic.Transport{
|
||||||
Conn: udpConn,
|
Conn: udpConn,
|
||||||
MaxUnvalidatedHandshakes: -1,
|
VerifySourceAddress: func(net.Addr) bool { return true },
|
||||||
}
|
}
|
||||||
addTracer(tr)
|
addTracer(tr)
|
||||||
defer tr.Close()
|
defer tr.Close()
|
||||||
|
|
|
@ -43,7 +43,7 @@ var _ = Describe("MITM test", func() {
|
||||||
}
|
}
|
||||||
addTracer(serverTransport)
|
addTracer(serverTransport)
|
||||||
if forceAddressValidation {
|
if forceAddressValidation {
|
||||||
serverTransport.MaxUnvalidatedHandshakes = -1
|
serverTransport.VerifySourceAddress = func(net.Addr) bool { return true }
|
||||||
}
|
}
|
||||||
ln, err := serverTransport.Listen(getTLSConfig(), serverConfig)
|
ln, err := serverTransport.Listen(getTLSConfig(), serverConfig)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
|
@ -462,7 +462,7 @@ var _ = Describe("0-RTT", func() {
|
||||||
defer udpConn.Close()
|
defer udpConn.Close()
|
||||||
tr := &quic.Transport{
|
tr := &quic.Transport{
|
||||||
Conn: udpConn,
|
Conn: udpConn,
|
||||||
MaxUnvalidatedHandshakes: -1,
|
VerifySourceAddress: func(net.Addr) bool { return true },
|
||||||
}
|
}
|
||||||
addTracer(tr)
|
addTracer(tr)
|
||||||
defer tr.Close()
|
defer tr.Close()
|
||||||
|
|
|
@ -71,7 +71,7 @@ func (s *Server) ListenAndServe() error {
|
||||||
tlsConf.NextProtos = []string{h09alpn}
|
tlsConf.NextProtos = []string{h09alpn}
|
||||||
tr := quic.Transport{Conn: conn}
|
tr := quic.Transport{Conn: conn}
|
||||||
if s.ForceRetry {
|
if s.ForceRetry {
|
||||||
tr.MaxUnvalidatedHandshakes = -1
|
tr.VerifySourceAddress = func(net.Addr) bool { return true }
|
||||||
}
|
}
|
||||||
ln, err := tr.ListenEarly(tlsConf, s.QuicConfig)
|
ln, err := tr.ListenEarly(tlsConf, s.QuicConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
72
server.go
72
server.go
|
@ -7,7 +7,6 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/quic-go/quic-go/internal/handshake"
|
"github.com/quic-go/quic-go/internal/handshake"
|
||||||
|
@ -108,10 +107,7 @@ type baseServer struct {
|
||||||
connectionRefusedQueue chan rejectedPacket
|
connectionRefusedQueue chan rejectedPacket
|
||||||
retryQueue chan rejectedPacket
|
retryQueue chan rejectedPacket
|
||||||
|
|
||||||
maxNumHandshakesUnvalidated int
|
verifySourceAddress func(net.Addr) bool
|
||||||
maxNumHandshakesTotal int
|
|
||||||
numHandshakesUnvalidated atomic.Int64
|
|
||||||
numHandshakesValidated atomic.Int64
|
|
||||||
|
|
||||||
connQueue chan quicConn
|
connQueue chan quicConn
|
||||||
|
|
||||||
|
@ -241,7 +237,7 @@ func newServer(
|
||||||
onClose func(),
|
onClose func(),
|
||||||
tokenGeneratorKey TokenGeneratorKey,
|
tokenGeneratorKey TokenGeneratorKey,
|
||||||
maxTokenAge time.Duration,
|
maxTokenAge time.Duration,
|
||||||
maxNumHandshakesUnvalidated, maxNumHandshakesTotal int,
|
verifySourceAddress func(net.Addr) bool,
|
||||||
disableVersionNegotiation bool,
|
disableVersionNegotiation bool,
|
||||||
acceptEarly bool,
|
acceptEarly bool,
|
||||||
) *baseServer {
|
) *baseServer {
|
||||||
|
@ -251,8 +247,7 @@ func newServer(
|
||||||
config: config,
|
config: config,
|
||||||
tokenGenerator: handshake.NewTokenGenerator(tokenGeneratorKey),
|
tokenGenerator: handshake.NewTokenGenerator(tokenGeneratorKey),
|
||||||
maxTokenAge: maxTokenAge,
|
maxTokenAge: maxTokenAge,
|
||||||
maxNumHandshakesUnvalidated: maxNumHandshakesUnvalidated,
|
verifySourceAddress: verifySourceAddress,
|
||||||
maxNumHandshakesTotal: maxNumHandshakesTotal,
|
|
||||||
connIDGenerator: connIDGenerator,
|
connIDGenerator: connIDGenerator,
|
||||||
connHandler: connHandler,
|
connHandler: connHandler,
|
||||||
connQueue: make(chan quicConn, protocol.MaxAcceptQueueSize),
|
connQueue: make(chan quicConn, protocol.MaxAcceptQueueSize),
|
||||||
|
@ -567,6 +562,7 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
|
||||||
var (
|
var (
|
||||||
token *handshake.Token
|
token *handshake.Token
|
||||||
retrySrcConnID *protocol.ConnectionID
|
retrySrcConnID *protocol.ConnectionID
|
||||||
|
clientAddrVerified bool
|
||||||
)
|
)
|
||||||
origDestConnID := hdr.DestConnectionID
|
origDestConnID := hdr.DestConnectionID
|
||||||
if len(hdr.Token) > 0 {
|
if len(hdr.Token) > 0 {
|
||||||
|
@ -579,9 +575,9 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
|
||||||
token = tok
|
token = tok
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if token != nil {
|
||||||
clientAddrValidated := s.validateToken(token, p.remoteAddr)
|
clientAddrVerified = s.validateToken(token, p.remoteAddr)
|
||||||
if token != nil && !clientAddrValidated {
|
if !clientAddrVerified {
|
||||||
// For invalid and expired non-retry tokens, we don't send an INVALID_TOKEN error.
|
// For invalid and expired non-retry tokens, we don't send an INVALID_TOKEN error.
|
||||||
// We just ignore them, and act as if there was no token on this packet at all.
|
// We just ignore them, and act as if there was no token on this packet at all.
|
||||||
// This also means we might send a Retry later.
|
// This also means we might send a Retry later.
|
||||||
|
@ -600,25 +596,9 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Until the next call to handleInitialImpl, these numbers are guaranteed to not increase.
|
|
||||||
// They might decrease if another connection completes the handshake.
|
|
||||||
numHandshakesUnvalidated := s.numHandshakesUnvalidated.Load()
|
|
||||||
numHandshakesValidated := s.numHandshakesValidated.Load()
|
|
||||||
|
|
||||||
// Check the total handshake limit first. It's better to reject than to initiate a retry.
|
|
||||||
if total := numHandshakesUnvalidated + numHandshakesValidated; total >= int64(s.maxNumHandshakesTotal) {
|
|
||||||
s.logger.Debugf("Rejecting new connection. Server currently busy. Currently handshaking: %d (max %d)", total, s.maxNumHandshakesTotal)
|
|
||||||
delete(s.zeroRTTQueues, hdr.DestConnectionID)
|
|
||||||
select {
|
|
||||||
case s.connectionRefusedQueue <- rejectedPacket{receivedPacket: p, hdr: hdr}:
|
|
||||||
default:
|
|
||||||
// drop packet if we can't send out the CONNECTION_REFUSED fast enough
|
|
||||||
p.buffer.Release()
|
|
||||||
}
|
}
|
||||||
return nil
|
|
||||||
}
|
if token == nil && s.verifySourceAddress != nil && s.verifySourceAddress(p.remoteAddr) {
|
||||||
if token == nil && numHandshakesUnvalidated >= int64(s.maxNumHandshakesUnvalidated) {
|
|
||||||
// Retry invalidates all 0-RTT packets sent.
|
// Retry invalidates all 0-RTT packets sent.
|
||||||
delete(s.zeroRTTQueues, hdr.DestConnectionID)
|
delete(s.zeroRTTQueues, hdr.DestConnectionID)
|
||||||
select {
|
select {
|
||||||
|
@ -630,18 +610,11 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
connID, err := s.connIDGenerator.GenerateConnectionID()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
s.logger.Debugf("Changing connection ID to %s.", connID)
|
|
||||||
var conn quicConn
|
|
||||||
tracingID := nextConnTracingID()
|
|
||||||
config := s.config
|
config := s.config
|
||||||
if s.config.GetConfigForClient != nil {
|
if s.config.GetConfigForClient != nil {
|
||||||
conf, err := s.config.GetConfigForClient(&ClientHelloInfo{
|
conf, err := s.config.GetConfigForClient(&ClientHelloInfo{
|
||||||
RemoteAddr: p.remoteAddr,
|
RemoteAddr: p.remoteAddr,
|
||||||
AddrVerified: clientAddrValidated,
|
AddrVerified: clientAddrVerified,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.logger.Debugf("Rejecting new connection due to GetConfigForClient callback")
|
s.logger.Debugf("Rejecting new connection due to GetConfigForClient callback")
|
||||||
|
@ -656,6 +629,9 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
|
||||||
}
|
}
|
||||||
config = populateConfig(conf)
|
config = populateConfig(conf)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var conn quicConn
|
||||||
|
tracingID := nextConnTracingID()
|
||||||
var tracer *logging.ConnectionTracer
|
var tracer *logging.ConnectionTracer
|
||||||
if config.Tracer != nil {
|
if config.Tracer != nil {
|
||||||
// Use the same connection ID that is passed to the client's GetLogWriter callback.
|
// Use the same connection ID that is passed to the client's GetLogWriter callback.
|
||||||
|
@ -665,6 +641,11 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
|
||||||
}
|
}
|
||||||
tracer = config.Tracer(context.WithValue(context.Background(), ConnectionTracingKey, tracingID), protocol.PerspectiveServer, connID)
|
tracer = config.Tracer(context.WithValue(context.Background(), ConnectionTracingKey, tracingID), protocol.PerspectiveServer, connID)
|
||||||
}
|
}
|
||||||
|
connID, err := s.connIDGenerator.GenerateConnectionID()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
s.logger.Debugf("Changing connection ID to %s.", connID)
|
||||||
conn = s.newConn(
|
conn = s.newConn(
|
||||||
newSendConn(s.conn, p.remoteAddr, p.info, s.logger),
|
newSendConn(s.conn, p.remoteAddr, p.info, s.logger),
|
||||||
s.connHandler,
|
s.connHandler,
|
||||||
|
@ -678,7 +659,7 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
|
||||||
config,
|
config,
|
||||||
s.tlsConf,
|
s.tlsConf,
|
||||||
s.tokenGenerator,
|
s.tokenGenerator,
|
||||||
clientAddrValidated,
|
clientAddrVerified,
|
||||||
tracer,
|
tracer,
|
||||||
tracingID,
|
tracingID,
|
||||||
s.logger,
|
s.logger,
|
||||||
|
@ -702,22 +683,9 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
|
||||||
delete(s.zeroRTTQueues, hdr.DestConnectionID)
|
delete(s.zeroRTTQueues, hdr.DestConnectionID)
|
||||||
}
|
}
|
||||||
|
|
||||||
if clientAddrValidated {
|
|
||||||
s.numHandshakesValidated.Add(1)
|
|
||||||
} else {
|
|
||||||
s.numHandshakesUnvalidated.Add(1)
|
|
||||||
}
|
|
||||||
go conn.run()
|
go conn.run()
|
||||||
go func() {
|
go func() {
|
||||||
completed := s.handleNewConn(conn)
|
if completed := s.handleNewConn(conn); !completed {
|
||||||
if clientAddrValidated {
|
|
||||||
if s.numHandshakesValidated.Add(-1) < 0 {
|
|
||||||
panic("server BUG: number of validated handshakes negative")
|
|
||||||
}
|
|
||||||
} else if s.numHandshakesUnvalidated.Add(-1) < 0 {
|
|
||||||
panic("server BUG: number of unvalidated handshakes negative")
|
|
||||||
}
|
|
||||||
if !completed {
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
131
server_test.go
131
server_test.go
|
@ -10,6 +10,8 @@ import (
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/time/rate"
|
||||||
|
|
||||||
"github.com/quic-go/quic-go/internal/handshake"
|
"github.com/quic-go/quic-go/internal/handshake"
|
||||||
mocklogging "github.com/quic-go/quic-go/internal/mocks/logging"
|
mocklogging "github.com/quic-go/quic-go/internal/mocks/logging"
|
||||||
"github.com/quic-go/quic-go/internal/protocol"
|
"github.com/quic-go/quic-go/internal/protocol"
|
||||||
|
@ -48,6 +50,7 @@ var _ = Describe("Server", func() {
|
||||||
data = data[:len(data)+16]
|
data = data[:len(data)+16]
|
||||||
sealer.EncryptHeader(data[n:n+16], &data[0], data[n-4:n])
|
sealer.EncryptHeader(data[n:n+16], &data[0], data[n-4:n])
|
||||||
return receivedPacket{
|
return receivedPacket{
|
||||||
|
rcvTime: time.Now(),
|
||||||
remoteAddr: &net.UDPAddr{IP: net.IPv4(4, 5, 6, 7), Port: 456},
|
remoteAddr: &net.UDPAddr{IP: net.IPv4(4, 5, 6, 7), Port: 456},
|
||||||
data: data,
|
data: data,
|
||||||
buffer: buf,
|
buffer: buf,
|
||||||
|
@ -259,7 +262,7 @@ var _ = Describe("Server", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
It("creates a connection when the token is accepted", func() {
|
It("creates a connection when the token is accepted", func() {
|
||||||
serv.maxNumHandshakesUnvalidated = 0
|
serv.verifySourceAddress = func(net.Addr) bool { return true }
|
||||||
raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
|
raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
|
||||||
retryToken, err := serv.tokenGenerator.NewRetryToken(
|
retryToken, err := serv.tokenGenerator.NewRetryToken(
|
||||||
raddr,
|
raddr,
|
||||||
|
@ -432,7 +435,13 @@ var _ = Describe("Server", func() {
|
||||||
|
|
||||||
It("replies with a Retry packet, if a token is required", func() {
|
It("replies with a Retry packet, if a token is required", func() {
|
||||||
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})
|
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})
|
||||||
serv.maxNumHandshakesUnvalidated = 0
|
raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
|
||||||
|
var called bool
|
||||||
|
serv.verifySourceAddress = func(addr net.Addr) bool {
|
||||||
|
Expect(addr).To(Equal(raddr))
|
||||||
|
called = true
|
||||||
|
return true
|
||||||
|
}
|
||||||
hdr := &wire.Header{
|
hdr := &wire.Header{
|
||||||
Type: protocol.PacketTypeInitial,
|
Type: protocol.PacketTypeInitial,
|
||||||
SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
|
SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
|
||||||
|
@ -440,7 +449,6 @@ var _ = Describe("Server", func() {
|
||||||
Version: protocol.Version1,
|
Version: protocol.Version1,
|
||||||
}
|
}
|
||||||
packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
|
packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
|
||||||
raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
|
|
||||||
packet.remoteAddr = raddr
|
packet.remoteAddr = raddr
|
||||||
tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), nil).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, _ []logging.Frame) {
|
tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), nil).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, _ []logging.Frame) {
|
||||||
Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry))
|
Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry))
|
||||||
|
@ -462,6 +470,7 @@ var _ = Describe("Server", func() {
|
||||||
phm.EXPECT().Get(connID)
|
phm.EXPECT().Get(connID)
|
||||||
serv.handlePacket(packet)
|
serv.handlePacket(packet)
|
||||||
Eventually(done).Should(BeClosed())
|
Eventually(done).Should(BeClosed())
|
||||||
|
Expect(called).To(BeTrue())
|
||||||
})
|
})
|
||||||
|
|
||||||
It("creates a connection, if no token is required", func() {
|
It("creates a connection, if no token is required", func() {
|
||||||
|
@ -539,8 +548,7 @@ var _ = Describe("Server", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
It("drops packets if the receive queue is full", func() {
|
It("drops packets if the receive queue is full", func() {
|
||||||
serv.maxNumHandshakesTotal = 10000
|
serv.verifySourceAddress = func(net.Addr) bool { return false }
|
||||||
serv.maxNumHandshakesUnvalidated = 10000
|
|
||||||
|
|
||||||
phm.EXPECT().Get(gomock.Any()).AnyTimes()
|
phm.EXPECT().Get(gomock.Any()).AnyTimes()
|
||||||
phm.EXPECT().GetStatelessResetToken(gomock.Any()).AnyTimes()
|
phm.EXPECT().GetStatelessResetToken(gomock.Any()).AnyTimes()
|
||||||
|
@ -641,17 +649,17 @@ var _ = Describe("Server", func() {
|
||||||
|
|
||||||
It("limits the number of unvalidated handshakes", func() {
|
It("limits the number of unvalidated handshakes", func() {
|
||||||
const limit = 3
|
const limit = 3
|
||||||
serv.maxNumHandshakesTotal = 10000
|
limiter := rate.NewLimiter(0, limit)
|
||||||
serv.maxNumHandshakesUnvalidated = limit
|
serv.verifySourceAddress = func(net.Addr) bool { return !limiter.Allow() }
|
||||||
|
|
||||||
phm.EXPECT().Get(gomock.Any()).AnyTimes()
|
phm.EXPECT().Get(gomock.Any()).AnyTimes()
|
||||||
phm.EXPECT().GetStatelessResetToken(gomock.Any()).AnyTimes()
|
phm.EXPECT().GetStatelessResetToken(gomock.Any()).AnyTimes()
|
||||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true).AnyTimes()
|
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true).AnyTimes()
|
||||||
|
|
||||||
handshakeChan := make(chan struct{})
|
|
||||||
connChan := make(chan *MockQUICConn, 1)
|
connChan := make(chan *MockQUICConn, 1)
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
wg.Add(2 * limit)
|
wg.Add(limit)
|
||||||
|
done := make(chan struct{})
|
||||||
serv.newConn = func(
|
serv.newConn = func(
|
||||||
_ sendConn,
|
_ sendConn,
|
||||||
runner connRunner,
|
runner connRunner,
|
||||||
|
@ -675,7 +683,7 @@ var _ = Describe("Server", func() {
|
||||||
conn.EXPECT().handlePacket(gomock.Any())
|
conn.EXPECT().handlePacket(gomock.Any())
|
||||||
conn.EXPECT().run()
|
conn.EXPECT().run()
|
||||||
conn.EXPECT().Context().Return(context.Background())
|
conn.EXPECT().Context().Return(context.Background())
|
||||||
conn.EXPECT().HandshakeComplete().Return(handshakeChan).Do(func() <-chan struct{} { wg.Done(); return nil })
|
conn.EXPECT().HandshakeComplete().DoAndReturn(func() <-chan struct{} { wg.Done(); return done })
|
||||||
return conn
|
return conn
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -688,7 +696,6 @@ var _ = Describe("Server", func() {
|
||||||
|
|
||||||
// Now initiate another connection attempt.
|
// Now initiate another connection attempt.
|
||||||
p := getInitialWithRandomDestConnID()
|
p := getInitialWithRandomDestConnID()
|
||||||
done := make(chan struct{})
|
|
||||||
tracer.EXPECT().SentPacket(p.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, frames []logging.Frame) {
|
tracer.EXPECT().SentPacket(p.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, frames []logging.Frame) {
|
||||||
defer GinkgoRecover()
|
defer GinkgoRecover()
|
||||||
Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry))
|
Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry))
|
||||||
|
@ -704,34 +711,19 @@ var _ = Describe("Server", func() {
|
||||||
serv.handlePacket(p)
|
serv.handlePacket(p)
|
||||||
Eventually(done).Should(BeClosed())
|
Eventually(done).Should(BeClosed())
|
||||||
|
|
||||||
close(handshakeChan)
|
|
||||||
for i := 0; i < limit; i++ {
|
for i := 0; i < limit; i++ {
|
||||||
_, err := serv.Accept(context.Background())
|
_, err := serv.Accept(context.Background())
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
}
|
}
|
||||||
for i := 0; i < limit; i++ {
|
|
||||||
conn := NewMockQUICConn(mockCtrl)
|
|
||||||
conn.EXPECT().closeWithTransportError(gomock.Any()).MaxTimes(1) // called when the server is closed
|
|
||||||
connChan <- conn
|
|
||||||
serv.handlePacket(getInitialWithRandomDestConnID())
|
|
||||||
}
|
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
})
|
})
|
||||||
|
})
|
||||||
|
|
||||||
It("limits the number of total handshakes", func() {
|
Context("token validation", func() {
|
||||||
const limit = 3
|
It("decodes the token from the token field", func() {
|
||||||
serv.maxNumHandshakesTotal = limit
|
|
||||||
serv.maxNumHandshakesUnvalidated = limit // same limit, but we check that we send CONNECTION_REFUSED and not Retry
|
|
||||||
|
|
||||||
phm.EXPECT().Get(gomock.Any()).AnyTimes()
|
|
||||||
phm.EXPECT().GetStatelessResetToken(gomock.Any()).AnyTimes()
|
|
||||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true).AnyTimes()
|
|
||||||
|
|
||||||
handshakeChan := make(chan struct{})
|
|
||||||
connChan := make(chan *MockQUICConn, 1)
|
|
||||||
serv.newConn = func(
|
serv.newConn = func(
|
||||||
_ sendConn,
|
_ sendConn,
|
||||||
runner connRunner,
|
_ connRunner,
|
||||||
_ protocol.ConnectionID,
|
_ protocol.ConnectionID,
|
||||||
_ *protocol.ConnectionID,
|
_ *protocol.ConnectionID,
|
||||||
_ protocol.ConnectionID,
|
_ protocol.ConnectionID,
|
||||||
|
@ -748,67 +740,15 @@ var _ = Describe("Server", func() {
|
||||||
_ utils.Logger,
|
_ utils.Logger,
|
||||||
_ protocol.Version,
|
_ protocol.Version,
|
||||||
) quicConn {
|
) quicConn {
|
||||||
conn := <-connChan
|
c := NewMockQUICConn(mockCtrl)
|
||||||
conn.EXPECT().handlePacket(gomock.Any())
|
c.EXPECT().handlePacket(gomock.Any())
|
||||||
conn.EXPECT().run()
|
c.EXPECT().run()
|
||||||
conn.EXPECT().Context().Return(context.Background())
|
c.EXPECT().HandshakeComplete()
|
||||||
conn.EXPECT().HandshakeComplete().Return(handshakeChan)
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
return conn
|
cancel()
|
||||||
|
c.EXPECT().Context().Return(ctx)
|
||||||
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := 0; i < limit; i++ {
|
|
||||||
conn := NewMockQUICConn(mockCtrl)
|
|
||||||
connChan <- conn
|
|
||||||
serv.handlePacket(getInitialWithRandomDestConnID())
|
|
||||||
}
|
|
||||||
|
|
||||||
p := getInitialWithRandomDestConnID()
|
|
||||||
done := make(chan struct{})
|
|
||||||
tracer.EXPECT().SentPacket(p.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, frames []logging.Frame) {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
hdr, _, _, err := wire.ParsePacket(p.data)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(replyHdr.Type).To(Equal(protocol.PacketTypeInitial))
|
|
||||||
Expect(replyHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID))
|
|
||||||
Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID))
|
|
||||||
Expect(frames).To(HaveLen(1))
|
|
||||||
Expect(frames[0]).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{}))
|
|
||||||
ccf := frames[0].(*logging.ConnectionCloseFrame)
|
|
||||||
Expect(ccf.IsApplicationError).To(BeFalse())
|
|
||||||
Expect(ccf.ErrorCode).To(BeEquivalentTo(qerr.ConnectionRefused))
|
|
||||||
})
|
|
||||||
conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
defer close(done)
|
|
||||||
hdr, _, _, err := wire.ParsePacket(p.data)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
checkConnectionCloseError(b, hdr, qerr.ConnectionRefused)
|
|
||||||
return len(b), nil
|
|
||||||
})
|
|
||||||
serv.handlePacket(p)
|
|
||||||
Eventually(done).Should(BeClosed())
|
|
||||||
|
|
||||||
close(handshakeChan)
|
|
||||||
for i := 0; i < limit; i++ {
|
|
||||||
_, err := serv.Accept(context.Background())
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
}
|
|
||||||
// make sure we can enqueue and accept more connections after that
|
|
||||||
for i := 0; i < limit; i++ {
|
|
||||||
conn := NewMockQUICConn(mockCtrl)
|
|
||||||
conn.EXPECT().closeWithTransportError(gomock.Any()).MaxTimes(1) // called when the server is closed
|
|
||||||
connChan <- conn
|
|
||||||
serv.handlePacket(getInitialWithRandomDestConnID())
|
|
||||||
}
|
|
||||||
for i := 0; i < limit; i++ {
|
|
||||||
_, err := serv.Accept(context.Background())
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
}
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("token validation", func() {
|
|
||||||
It("decodes the token from the token field", func() {
|
|
||||||
raddr := &net.UDPAddr{IP: net.IPv4(192, 168, 13, 37), Port: 1337}
|
raddr := &net.UDPAddr{IP: net.IPv4(192, 168, 13, 37), Port: 1337}
|
||||||
token, err := serv.tokenGenerator.NewRetryToken(raddr, protocol.ConnectionID{}, protocol.ConnectionID{})
|
token, err := serv.tokenGenerator.NewRetryToken(raddr, protocol.ConnectionID{}, protocol.ConnectionID{})
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
@ -834,7 +774,7 @@ var _ = Describe("Server", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
It("sends an INVALID_TOKEN error, if an invalid retry token is received", func() {
|
It("sends an INVALID_TOKEN error, if an invalid retry token is received", func() {
|
||||||
serv.maxNumHandshakesUnvalidated = 0
|
serv.verifySourceAddress = func(net.Addr) bool { return true }
|
||||||
token, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, protocol.ConnectionID{}, protocol.ConnectionID{})
|
token, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, protocol.ConnectionID{}, protocol.ConnectionID{})
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
hdr := &wire.Header{
|
hdr := &wire.Header{
|
||||||
|
@ -870,7 +810,7 @@ var _ = Describe("Server", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
It("sends an INVALID_TOKEN error, if an expired retry token is received", func() {
|
It("sends an INVALID_TOKEN error, if an expired retry token is received", func() {
|
||||||
serv.maxNumHandshakesUnvalidated = 0
|
serv.verifySourceAddress = func(net.Addr) bool { return true }
|
||||||
serv.config.HandshakeIdleTimeout = time.Millisecond / 2 // the maximum retry token age is equivalent to the handshake timeout
|
serv.config.HandshakeIdleTimeout = time.Millisecond / 2 // the maximum retry token age is equivalent to the handshake timeout
|
||||||
Expect(serv.config.maxRetryTokenAge()).To(Equal(time.Millisecond))
|
Expect(serv.config.maxRetryTokenAge()).To(Equal(time.Millisecond))
|
||||||
raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
|
raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
|
||||||
|
@ -908,7 +848,7 @@ var _ = Describe("Server", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
It("doesn't send an INVALID_TOKEN error, if an invalid non-retry token is received", func() {
|
It("doesn't send an INVALID_TOKEN error, if an invalid non-retry token is received", func() {
|
||||||
serv.maxNumHandshakesUnvalidated = 0
|
serv.verifySourceAddress = func(net.Addr) bool { return true }
|
||||||
token, err := serv.tokenGenerator.NewToken(&net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337})
|
token, err := serv.tokenGenerator.NewToken(&net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337})
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
hdr := &wire.Header{
|
hdr := &wire.Header{
|
||||||
|
@ -937,7 +877,7 @@ var _ = Describe("Server", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
It("sends an INVALID_TOKEN error, if an expired non-retry token is received", func() {
|
It("sends an INVALID_TOKEN error, if an expired non-retry token is received", func() {
|
||||||
serv.maxNumHandshakesUnvalidated = 0
|
serv.verifySourceAddress = func(net.Addr) bool { return true }
|
||||||
serv.maxTokenAge = time.Millisecond
|
serv.maxTokenAge = time.Millisecond
|
||||||
raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
|
raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
|
||||||
token, err := serv.tokenGenerator.NewToken(raddr)
|
token, err := serv.tokenGenerator.NewToken(raddr)
|
||||||
|
@ -966,7 +906,6 @@ var _ = Describe("Server", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
It("doesn't send an INVALID_TOKEN error, if the packet is corrupted", func() {
|
It("doesn't send an INVALID_TOKEN error, if the packet is corrupted", func() {
|
||||||
serv.maxNumHandshakesUnvalidated = 0
|
|
||||||
token, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, protocol.ConnectionID{}, protocol.ConnectionID{})
|
token, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, protocol.ConnectionID{}, protocol.ConnectionID{})
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
hdr := &wire.Header{
|
hdr := &wire.Header{
|
||||||
|
@ -1011,7 +950,7 @@ var _ = Describe("Server", func() {
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
It("closes connection that are still handshaking after Close", func() {
|
PIt("closes connection that are still handshaking after Close", func() {
|
||||||
serv.Close()
|
serv.Close()
|
||||||
|
|
||||||
destroyed := make(chan struct{})
|
destroyed := make(chan struct{})
|
||||||
|
|
52
transport.go
52
transport.go
|
@ -5,7 +5,6 @@ import (
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"errors"
|
"errors"
|
||||||
"math"
|
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
@ -19,18 +18,6 @@ import (
|
||||||
|
|
||||||
var errListenerAlreadySet = errors.New("listener already set")
|
var errListenerAlreadySet = errors.New("listener already set")
|
||||||
|
|
||||||
const (
|
|
||||||
// defaultMaxNumUnvalidatedHandshakes is the default value for Transport.MaxUnvalidatedHandshakes.
|
|
||||||
defaultMaxNumUnvalidatedHandshakes = 128
|
|
||||||
// defaultMaxNumHandshakes is the default value for Transport.MaxHandshakes.
|
|
||||||
// It's not clear how to choose a reasonable value that works for all use cases.
|
|
||||||
// In production, implementations should:
|
|
||||||
// 1. Choose a lower value.
|
|
||||||
// 2. Implement some kind of IP-address based filtering using the Config.GetConfigForClient
|
|
||||||
// callback in order to prevent flooding attacks from a single / small number of IP addresses.
|
|
||||||
defaultMaxNumHandshakes = math.MaxInt32
|
|
||||||
)
|
|
||||||
|
|
||||||
// The Transport is the central point to manage incoming and outgoing QUIC connections.
|
// The Transport is the central point to manage incoming and outgoing QUIC connections.
|
||||||
// QUIC demultiplexes connections based on their QUIC Connection IDs, not based on the 4-tuple.
|
// QUIC demultiplexes connections based on their QUIC Connection IDs, not based on the 4-tuple.
|
||||||
// This means that a single UDP socket can be used for listening for incoming connections, as well as
|
// This means that a single UDP socket can be used for listening for incoming connections, as well as
|
||||||
|
@ -91,24 +78,16 @@ type Transport struct {
|
||||||
// It has no effect for clients.
|
// It has no effect for clients.
|
||||||
DisableVersionNegotiationPackets bool
|
DisableVersionNegotiationPackets bool
|
||||||
|
|
||||||
// MaxUnvalidatedHandshakes is the maximum number of concurrent incoming QUIC handshakes
|
// VerifySourceAddress decides if a connection attempt originating from unvalidated source
|
||||||
// originating from unvalidated source addresses.
|
// addresses first needs to go through source address validation using QUIC's Retry mechanism,
|
||||||
// If the number of handshakes from unvalidated addresses reaches this number, new incoming
|
// as described in RFC 9000 section 8.1.2.
|
||||||
// connection attempts will need to proof reachability at the respective source address using the
|
// Note that the address passed to this callback is unvalidated, and might be spoofed in case
|
||||||
// Retry mechanism, as described in RFC 9000 section 8.1.2.
|
// of an attack.
|
||||||
// Validating the source address adds one additional network roundtrip to the handshake.
|
// Validating the source address adds one additional network roundtrip to the handshake,
|
||||||
// If unset, a default value of 128 will be used.
|
// and should therefore only be used if a suspiciously high number of incoming connection is recorded.
|
||||||
// When set to a negative value, every connection attempt will need to validate the source address.
|
// For most use cases, wrapping the Allow function of a rate.Limiter will be a reasonable
|
||||||
// It does not make sense to set this value higher than MaxHandshakes.
|
// implementation of this callback (negating its return value).
|
||||||
MaxUnvalidatedHandshakes int
|
VerifySourceAddress func(net.Addr) bool
|
||||||
// MaxHandshakes is the maximum number of concurrent incoming handshakes, both from validated
|
|
||||||
// and unvalidated source addresses.
|
|
||||||
// If unset, the number of concurrent handshakes will not be limited.
|
|
||||||
// Applications should choose a reasonable value based on their thread model, and consider
|
|
||||||
// implementing IP-based rate limiting using Config.GetConfigForClient.
|
|
||||||
// If the number of handshakes reaches this number, new connection attempts will be rejected by
|
|
||||||
// terminating the connection attempt using a CONNECTION_REFUSED error.
|
|
||||||
MaxHandshakes int
|
|
||||||
|
|
||||||
// A Tracer traces events that don't belong to a single QUIC connection.
|
// A Tracer traces events that don't belong to a single QUIC connection.
|
||||||
// Tracer.Close is called when the transport is closed.
|
// Tracer.Close is called when the transport is closed.
|
||||||
|
@ -185,14 +164,6 @@ func (t *Transport) createServer(tlsConf *tls.Config, conf *Config, allow0RTT bo
|
||||||
if err := t.init(false); err != nil {
|
if err := t.init(false); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
maxUnvalidatedHandshakes := t.MaxUnvalidatedHandshakes
|
|
||||||
if maxUnvalidatedHandshakes == 0 {
|
|
||||||
maxUnvalidatedHandshakes = defaultMaxNumUnvalidatedHandshakes
|
|
||||||
}
|
|
||||||
maxHandshakes := t.MaxHandshakes
|
|
||||||
if maxHandshakes == 0 {
|
|
||||||
maxHandshakes = defaultMaxNumHandshakes
|
|
||||||
}
|
|
||||||
s := newServer(
|
s := newServer(
|
||||||
t.conn,
|
t.conn,
|
||||||
t.handlerMap,
|
t.handlerMap,
|
||||||
|
@ -203,8 +174,7 @@ func (t *Transport) createServer(tlsConf *tls.Config, conf *Config, allow0RTT bo
|
||||||
t.closeServer,
|
t.closeServer,
|
||||||
*t.TokenGeneratorKey,
|
*t.TokenGeneratorKey,
|
||||||
t.MaxTokenAge,
|
t.MaxTokenAge,
|
||||||
maxUnvalidatedHandshakes,
|
t.VerifySourceAddress,
|
||||||
maxHandshakes,
|
|
||||||
t.DisableVersionNegotiationPackets,
|
t.DisableVersionNegotiationPackets,
|
||||||
allow0RTT,
|
allow0RTT,
|
||||||
)
|
)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue