uquic/server_test.go
Marten Seemann 8e93770dd3
avoid lock contention when accepting new connections (#4313)
* avoid lock contention when accepting new connections

The server used to hold the packet handler map's lock while creating the
connection struct for a newly accepted connection. This was intended to
make sure that no two connections with the same Destination Connection
ID could be created.

This is a corner case: it can only happen if two Initial packets with
the same Destination Connection ID are received at the same time. If
the second one is received after the first one has already been
processed, it would be routed to the first connection. We don't need to
optimize for this corner case. It's ok to create a new connection in
that case, and immediately close it if this collision is detected.

* only pass 0-RTT to the connection if it was actually accepted
2024-02-08 19:34:42 -08:00

1584 lines
58 KiB
Go

package quic
import (
"context"
"crypto/rand"
"crypto/tls"
"errors"
"net"
"sync"
"sync/atomic"
"time"
"github.com/quic-go/quic-go/internal/handshake"
mocklogging "github.com/quic-go/quic-go/internal/mocks/logging"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qerr"
"github.com/quic-go/quic-go/internal/testdata"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/internal/wire"
"github.com/quic-go/quic-go/logging"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"go.uber.org/mock/gomock"
)
var _ = Describe("Server", func() {
// Fixtures shared by the specs of this Describe; both are (re)assigned in the
// BeforeEach below.
var (
// conn is the mocked packet conn passed to Listen / used by the Transport.
conn *MockPacketConn
// tlsConf is the TLS configuration passed to the listeners under test.
tlsConf *tls.Config
)
// getPacket serializes hdr as a long header packet with a 4-byte packet number
// (0x42), appends p as payload, seals the payload with the client-side Initial
// keys derived from hdr.DestConnectionID and applies header protection.
// Note that hdr.Length is overwritten. The returned packet uses a fixed
// remote address (4.5.6.7:456).
getPacket := func(hdr *wire.Header, p []byte) receivedPacket {
	const (
		pn           = 0x42
		pnLen        = 4
		sealOverhead = 16 // number of bytes the slice is grown by after sealing
		sampleLen    = 16 // number of bytes handed to the header protection as sample
	)

	pktBuf := getPacketBuffer()
	hdr.Length = pnLen + protocol.ByteCount(len(p)) + sealOverhead

	extHdr := &wire.ExtendedHeader{
		Header:          *hdr,
		PacketNumber:    pn,
		PacketNumberLen: protocol.PacketNumberLen4,
	}
	var err error
	pktBuf.Data, err = extHdr.Append(pktBuf.Data, protocol.Version1)
	Expect(err).ToNot(HaveOccurred())

	hdrLen := len(pktBuf.Data)
	pktBuf.Data = append(pktBuf.Data, p...)
	raw := pktBuf.Data

	sealer, _ := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveClient, hdr.Version)
	// Seal in place: the header is the associated data, the payload is encrypted where it sits.
	_ = sealer.Seal(raw[hdrLen:hdrLen], raw[hdrLen:], pn, raw[:hdrLen])
	// Extend the slice so that it also covers the bytes appended by Seal.
	raw = raw[:len(raw)+sealOverhead]
	sealer.EncryptHeader(raw[hdrLen:hdrLen+sampleLen], &raw[0], raw[hdrLen-pnLen:hdrLen])

	return receivedPacket{
		remoteAddr: &net.UDPAddr{IP: net.IPv4(4, 5, 6, 7), Port: 456},
		data:       raw,
		buffer:     pktBuf,
	}
}
// getInitial builds a QUIC v1 Initial packet addressed to destConnID, with a
// payload of protocol.MinInitialPacketSize zero bytes. The packet is given a
// fresh packet buffer and the remote address 1.2.3.4:42.
getInitial := func(destConnID protocol.ConnectionID) receivedPacket {
	pkt := getPacket(
		&wire.Header{
			Type:             protocol.PacketTypeInitial,
			SrcConnectionID:  protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
			DestConnectionID: destConnID,
			Version:          protocol.Version1,
		},
		make([]byte, protocol.MinInitialPacketSize),
	)
	pkt.buffer = getPacketBuffer()
	pkt.remoteAddr = &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 42}
	return pkt
}
getInitialWithRandomDestConnID := func() receivedPacket {
b := make([]byte, 10)
_, err := rand.Read(b)
Expect(err).ToNot(HaveOccurred())
return getInitial(protocol.ParseConnectionID(b))
}
// parseHeader parses the first packet contained in data and returns its
// header. The spec fails if parsing returns an error.
parseHeader := func(data []byte) *wire.Header {
	parsed, _, _, parseErr := wire.ParsePacket(data)
	Expect(parseErr).ToNot(HaveOccurred())
	return parsed
}
// checkConnectionCloseError asserts that b is an Initial packet replying to
// origHdr (connection IDs swapped) whose first frame is a transport-level
// CONNECTION_CLOSE carrying errorCode and an empty reason phrase.
checkConnectionCloseError := func(b []byte, origHdr *wire.Header, errorCode qerr.TransportErrorCode) {
	reply := parseHeader(b)
	Expect(reply.Type).To(Equal(protocol.PacketTypeInitial))
	Expect(reply.SrcConnectionID).To(Equal(origHdr.DestConnectionID))
	Expect(reply.DestConnectionID).To(Equal(origHdr.SrcConnectionID))

	// The opener is derived from the connection ID of the original packet.
	_, opener := handshake.NewInitialAEAD(origHdr.DestConnectionID, protocol.PerspectiveClient, reply.Version)
	unpacked, err := unpackLongHeader(opener, reply, b, origHdr.Version)
	Expect(err).ToNot(HaveOccurred())
	payload, err := opener.Open(nil, b[unpacked.ParsedLen():], unpacked.PacketNumber, b[:unpacked.ParsedLen()])
	Expect(err).ToNot(HaveOccurred())

	_, frame, err := wire.NewFrameParser(false).ParseNext(payload, protocol.EncryptionInitial, origHdr.Version)
	Expect(err).ToNot(HaveOccurred())
	Expect(frame).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{}))
	closeFrame := frame.(*wire.ConnectionCloseFrame)
	Expect(closeFrame.IsApplicationError).To(BeFalse())
	Expect(closeFrame.ErrorCode).To(BeEquivalentTo(errorCode))
	Expect(closeFrame.ReasonPhrase).To(BeEmpty())
}
// Creates a fresh mocked packet conn and TLS config for every spec.
BeforeEach(func() {
	conn = NewMockPacketConn(mockCtrl)
	conn.EXPECT().LocalAddr().Return(&net.UDPAddr{}).AnyTimes()

	// A ReadFrom call (at most one is allowed) blocks until SetReadDeadline
	// is called, and then returns an error.
	unblockRead := make(chan struct{})
	blockingRead := func(_ []byte) (int, net.Addr, error) {
		<-unblockRead
		return 0, nil, errors.New("done")
	}
	// The first SetReadDeadline call releases the blocked read and registers
	// the expectation for a follow-up call that resets the deadline.
	onFirstDeadline := func(time.Time) error {
		close(unblockRead)
		conn.EXPECT().SetReadDeadline(time.Time{})
		return nil
	}
	conn.EXPECT().ReadFrom(gomock.Any()).DoAndReturn(blockingRead).MaxTimes(1)
	conn.EXPECT().SetReadDeadline(gomock.Any()).Do(onFirstDeadline).MaxTimes(1)

	tlsConf = testdata.GetTLSConfig()
	tlsConf.NextProtos = []string{"proto1"}
})
// ListenAddr must reject a nil tls.Config.
It("errors when no tls.Config is given", func() {
	_, listenErr := ListenAddr("localhost:0", nil, nil)
	// MatchError fails on a nil error, and applies the matcher to listenErr.Error().
	Expect(listenErr).To(MatchError(ContainSubstring("quic: tls.Config not set")))
})
// Listen must reject a Config that lists a version number it doesn't know.
It("errors when the Config contains an invalid version", func() {
	cfg := &Config{Versions: []protocol.Version{0x1234}}
	_, err := Listen(nil, tlsConf, cfg)
	Expect(err).To(MatchError("invalid QUIC version: 0x1234"))
})
// An empty Config must be populated with the package defaults.
It("fills in default values if options are not set in the Config", func() {
	ln, err := Listen(conn, tlsConf, &Config{})
	Expect(err).ToNot(HaveOccurred())

	cfg := ln.baseServer.config
	Expect(cfg.Versions).To(Equal(protocol.SupportedVersions))
	Expect(cfg.HandshakeIdleTimeout).To(Equal(protocol.DefaultHandshakeIdleTimeout))
	Expect(cfg.MaxIdleTimeout).To(Equal(protocol.DefaultIdleTimeout))
	Expect(cfg.KeepAlivePeriod).To(BeZero())

	// stop the listener
	Expect(ln.Close()).To(Succeed())
})
It("setups with the right values", func() {
supportedVersions := []protocol.Version{protocol.Version1}
config := Config{
Versions: supportedVersions,
HandshakeIdleTimeout: 1337 * time.Hour,
MaxIdleTimeout: 42 * time.Minute,
KeepAlivePeriod: 5 * time.Second,
}
ln, err := Listen(conn, tlsConf, &config)
Expect(err).ToNot(HaveOccurred())
server := ln.baseServer
Expect(server.connHandler).ToNot(BeNil())
Expect(server.config.Versions).To(Equal(supportedVersions))
Expect(server.config.HandshakeIdleTimeout).To(Equal(1337 * time.Hour))
Expect(server.config.MaxIdleTimeout).To(Equal(42 * time.Minute))
Expect(server.config.KeepAlivePeriod).To(Equal(5 * time.Second))
// stop the listener
Expect(ln.Close()).To(Succeed())
})
It("listens on a given address", func() {
addr := "127.0.0.1:13579"
ln, err := ListenAddr(addr, tlsConf, &Config{})
Expect(err).ToNot(HaveOccurred())
Expect(ln.Addr().String()).To(Equal(addr))
// stop the listener
Expect(ln.Close()).To(Succeed())
})
// An address without a port is expected to yield a *net.AddrError.
It("errors if given an invalid address", func() {
	_, err := ListenAddr("127.0.0.1", tlsConf, &Config{})
	Expect(err).To(BeAssignableToTypeOf(&net.AddrError{}))
})
// The address below is syntactically valid, so the expected failure is a
// *net.OpError rather than the *net.AddrError of a malformed address
// (presumably because 1.1.1.1 is not assigned to a local interface and the
// bind fails). The spec text is distinct from the malformed-address spec so
// that the two can be told apart in reports and focus filters.
It("errors if it can't listen on the given address", func() {
	addr := "1.1.1.1:1111"
	_, err := ListenAddr(addr, tlsConf, &Config{})
	Expect(err).To(BeAssignableToTypeOf(&net.OpError{}))
})
Context("server accepting connections that completed the handshake", func() {
// Fixtures for the specs of this Context; all of them are assigned in the
// BeforeEach below.
var (
// tr is the Transport (wrapping the mocked conn) that the server listens on.
tr *Transport
// serv is the baseServer of the listener returned by tr.Listen.
serv *baseServer
// phm is the mock installed as serv.connHandler.
phm *MockPacketHandlerManager
// tracer is the mock behind the Transport's Tracer.
tracer *mocklogging.MockTracer
)
BeforeEach(func() {
var t *logging.Tracer
t, tracer = mocklogging.NewMockTracer(mockCtrl)
tr = &Transport{Conn: conn, Tracer: t}
ln, err := tr.Listen(tlsConf, nil)
Expect(err).ToNot(HaveOccurred())
serv = ln.baseServer
phm = NewMockPacketHandlerManager(mockCtrl)
serv.connHandler = phm
})
// Tears down the Transport created in BeforeEach.
AfterEach(func() {
// Closing the Transport is expected to close the tracer exactly once.
tracer.EXPECT().Close()
tr.Close()
})
Context("handling packets", func() {
// An Initial with a 4-byte Destination Connection ID must be dropped and
// reported to the tracer as an unexpected packet.
It("drops Initial packets with a too short connection ID", func() {
	hdr := &wire.Header{
		Type:             protocol.PacketTypeInitial,
		DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}),
		Version:          serv.config.Versions[0],
	}
	pkt := getPacket(hdr, nil)
	tracer.EXPECT().DroppedPacket(
		pkt.remoteAddr,
		logging.PacketTypeInitial,
		pkt.Size(),
		logging.PacketDropUnexpectedPacket,
	)
	serv.handlePacket(pkt)
	// make sure there are no Write calls on the packet conn
	time.Sleep(50 * time.Millisecond)
})
// An Initial whose payload is 100 bytes short of protocol.MinInitialPacketSize
// must be dropped and reported to the tracer as an unexpected packet.
It("drops too small Initial", func() {
	hdr := &wire.Header{
		Type:             protocol.PacketTypeInitial,
		DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}),
		Version:          serv.config.Versions[0],
	}
	pkt := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize-100))
	tracer.EXPECT().DroppedPacket(
		pkt.remoteAddr,
		logging.PacketTypeInitial,
		pkt.Size(),
		logging.PacketDropUnexpectedPacket,
	)
	serv.handlePacket(pkt)
	// make sure there are no Write calls on the packet conn
	time.Sleep(50 * time.Millisecond)
})
// A Handshake packet must be dropped and reported to the tracer as an
// unexpected packet.
It("drops non-Initial packets", func() {
	hdr := &wire.Header{
		Type:    protocol.PacketTypeHandshake,
		Version: serv.config.Versions[0],
	}
	pkt := getPacket(hdr, []byte("invalid"))
	tracer.EXPECT().DroppedPacket(
		pkt.remoteAddr,
		logging.PacketTypeHandshake,
		pkt.Size(),
		logging.PacketDropUnexpectedPacket,
	)
	serv.handlePacket(pkt)
	// make sure there are no Write calls on the packet conn
	time.Sleep(50 * time.Millisecond)
})
// If the packet handler map already knows the Destination Connection ID, the
// packet must be forwarded to that handler.
It("passes packets to existing connections", func() {
	connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
	hdr := &wire.Header{
		Type:             protocol.PacketTypeInitial,
		DestConnectionID: connID,
		Version:          serv.config.Versions[0],
	}
	pkt := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))

	handled := make(chan struct{})
	handler := NewMockPacketHandler(mockCtrl)
	handler.EXPECT().handlePacket(pkt).Do(func(receivedPacket) { close(handled) })
	phm.EXPECT().Get(connID).Return(handler, true)

	serv.handlePacket(pkt)
	Eventually(handled).Should(BeClosed())
})
// Sends an Initial carrying a Retry token generated for its remote address,
// and checks that newConn receives the connection IDs stored in that token.
It("creates a connection when the token is accepted", func() {
// NOTE(review): a limit of 0 presumably means that only the token can get this Initial accepted — confirm against the server.
serv.maxNumHandshakesUnvalidated = 0
raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
// The two connection IDs passed here are the ones asserted inside newConn
// (origDestConnID and retrySrcConnID).
retryToken, err := serv.tokenGenerator.NewRetryToken(
raddr,
protocol.ParseConnectionID([]byte{0xde, 0xad, 0xc0, 0xde}),
protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}),
)
Expect(err).ToNot(HaveOccurred())
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})
hdr := &wire.Header{
Type: protocol.PacketTypeInitial,
SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
DestConnectionID: connID,
Version: protocol.Version1,
Token: retryToken,
}
p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
p.remoteAddr = raddr
// run is closed once the mocked connection's run method is called.
run := make(chan struct{})
// Random stateless reset token, returned by the mocked handler map and
// expected to be passed on to newConn. The error of rand.Read is not checked.
var token protocol.StatelessResetToken
rand.Read(token[:])
// newConnID records the source connection ID that newConn was called with.
var newConnID protocol.ConnectionID
conn := NewMockQUICConn(mockCtrl)
serv.newConn = func(
_ sendConn,
_ connRunner,
origDestConnID protocol.ConnectionID,
retrySrcConnID *protocol.ConnectionID,
clientDestConnID protocol.ConnectionID,
destConnID protocol.ConnectionID,
srcConnID protocol.ConnectionID,
_ ConnectionIDGenerator,
tokenP protocol.StatelessResetToken,
_ *Config,
_ *tls.Config,
_ *handshake.TokenGenerator,
_ bool,
_ *logging.ConnectionTracer,
_ uint64,
_ utils.Logger,
_ protocol.Version,
) quicConn {
Expect(origDestConnID).To(Equal(protocol.ParseConnectionID([]byte{0xde, 0xad, 0xc0, 0xde})))
Expect(*retrySrcConnID).To(Equal(protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad})))
Expect(clientDestConnID).To(Equal(hdr.DestConnectionID))
Expect(destConnID).To(Equal(hdr.SrcConnectionID))
// make sure we're using a server-generated connection ID
Expect(srcConnID).ToNot(Equal(hdr.DestConnectionID))
Expect(srcConnID).ToNot(Equal(hdr.SrcConnectionID))
newConnID = srcConnID
Expect(tokenP).To(Equal(token))
// The expectations on the connection are only registered once it is created.
conn.EXPECT().handlePacket(p)
conn.EXPECT().run().Do(func() error { close(run); return nil })
conn.EXPECT().Context().Return(context.Background())
conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
return conn
}
phm.EXPECT().Get(connID)
phm.EXPECT().GetStatelessResetToken(gomock.Any()).Return(token)
// The connection must be registered under the client-chosen connection ID
// as well as under the server-generated one captured in newConn.
phm.EXPECT().AddWithConnID(connID, gomock.Any(), gomock.Any()).DoAndReturn(func(_, cid protocol.ConnectionID, h packetHandler) bool {
Expect(cid).To(Equal(newConnID))
return true
})
done := make(chan struct{})
go func() {
defer GinkgoRecover()
serv.handlePacket(p)
// the Handshake packet is written by the connection.
// Make sure there are no Write calls on the packet conn.
time.Sleep(50 * time.Millisecond)
close(done)
}()
// wait until the connection was run, and until handlePacket has returned
Eventually(run).Should(BeClosed())
Eventually(done).Should(BeClosed())
// shutdown (presumably triggered when the Transport is closed in AfterEach)
conn.EXPECT().closeWithTransportError(gomock.Any())
})
// A sufficiently large packet with an unknown version (0x42) must be answered
// with a Version Negotiation packet that swaps the connection IDs.
It("sends a Version Negotiation Packet for unsupported versions", func() {
srcConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5})
destConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6})
packet := getPacket(&wire.Header{
Type: protocol.PacketTypeHandshake,
SrcConnectionID: srcConnID,
DestConnectionID: destConnID,
Version: 0x42,
}, make([]byte, protocol.MinUnknownVersionPacketSize))
raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
packet.remoteAddr = raddr
// The tracer must see the reply with source and destination connection ID swapped.
tracer.EXPECT().SentVersionNegotiationPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, src, dest protocol.ArbitraryLenConnectionID, _ []protocol.Version) {
Expect(src).To(Equal(protocol.ArbitraryLenConnectionID(destConnID.Bytes())))
Expect(dest).To(Equal(protocol.ArbitraryLenConnectionID(srcConnID.Bytes())))
})
done := make(chan struct{})
// The same properties are checked on the bytes written to the packet conn.
conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
defer close(done)
Expect(wire.IsVersionNegotiationPacket(b)).To(BeTrue())
dest, src, versions, err := wire.ParseVersionNegotiationPacket(b)
Expect(err).ToNot(HaveOccurred())
Expect(dest).To(Equal(protocol.ArbitraryLenConnectionID(srcConnID.Bytes())))
Expect(src).To(Equal(protocol.ArbitraryLenConnectionID(destConnID.Bytes())))
// the version the client used must not be among the offered versions
Expect(versions).ToNot(ContainElement(protocol.Version(0x42)))
return len(b), nil
})
serv.handlePacket(packet)
Eventually(done).Should(BeClosed())
})
// With Version Negotiation disabled, a packet with an unknown version (0x42)
// must not trigger any reply.
It("doesn't send a Version Negotiation packets if sending them is disabled", func() {
	serv.disableVersionNegotiation = true
	packet := getPacket(&wire.Header{
		Type:             protocol.PacketTypeHandshake,
		SrcConnectionID:  protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5}),
		DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6}),
		Version:          0x42,
	}, make([]byte, protocol.MinUnknownVersionPacketSize))
	packet.remoteAddr = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
	serv.handlePacket(packet)
	// Make sure there are no Write calls on the packet conn: no WriteTo
	// expectation is registered on the mock, so a write during this window
	// would be an unexpected call.
	time.Sleep(50 * time.Millisecond)
})
// A Version Negotiation packet received by the server must be dropped and
// reported to the tracer; nothing is sent in response.
It("ignores Version Negotiation packets", func() {
data := wire.ComposeVersionNegotiation(
protocol.ArbitraryLenConnectionID{1, 2, 3, 4},
protocol.ArbitraryLenConnectionID{4, 3, 2, 1},
[]protocol.Version{1, 2, 3},
)
raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
// done is closed as soon as the drop is reported to the tracer.
done := make(chan struct{})
tracer.EXPECT().DroppedPacket(raddr, logging.PacketTypeVersionNegotiation, protocol.ByteCount(len(data)), logging.PacketDropUnexpectedPacket).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) {
close(done)
})
serv.handlePacket(receivedPacket{
remoteAddr: raddr,
data: data,
buffer: getPacketBuffer(),
})
Eventually(done).Should(BeClosed())
// make sure no other packet is sent
time.Sleep(scaleDuration(20 * time.Millisecond))
})
// A packet with an unknown version that is smaller than
// protocol.MinUnknownVersionPacketSize must be dropped without a reply.
It("doesn't send a Version Negotiation Packet for unsupported versions, if the packet is too small", func() {
srcConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5})
destConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6})
p := getPacket(&wire.Header{
Type: protocol.PacketTypeHandshake,
SrcConnectionID: srcConnID,
DestConnectionID: destConnID,
Version: 0x42,
}, make([]byte, protocol.MinUnknownVersionPacketSize-50))
// sanity check: header and sealing overhead must not push the packet over the threshold
Expect(p.Size()).To(BeNumerically("<", protocol.MinUnknownVersionPacketSize))
raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
p.remoteAddr = raddr
// done is closed as soon as the drop is reported to the tracer.
done := make(chan struct{})
tracer.EXPECT().DroppedPacket(raddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) {
close(done)
})
serv.handlePacket(p)
Eventually(done).Should(BeClosed())
// make sure no other packet is sent
time.Sleep(scaleDuration(20 * time.Millisecond))
})
// With the unvalidated-handshake limit set to 0, an Initial without a token
// must be answered with a Retry packet.
It("replies with a Retry packet, if a token is required", func() {
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})
serv.maxNumHandshakesUnvalidated = 0
hdr := &wire.Header{
Type: protocol.PacketTypeInitial,
SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
DestConnectionID: connID,
Version: protocol.Version1,
}
packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
packet.remoteAddr = raddr
// The Retry is traced with a nil frame list, a new source connection ID and a non-empty token.
tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), nil).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, _ []logging.Frame) {
Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry))
Expect(replyHdr.SrcConnectionID).ToNot(Equal(hdr.DestConnectionID))
Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID))
Expect(replyHdr.Token).ToNot(BeEmpty())
})
done := make(chan struct{})
conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
defer close(done)
replyHdr := parseHeader(b)
Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry))
Expect(replyHdr.SrcConnectionID).ToNot(Equal(hdr.DestConnectionID))
Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID))
Expect(replyHdr.Token).ToNot(BeEmpty())
// the last 16 bytes must be the Retry integrity tag computed over the rest of the packet
Expect(b[len(b)-16:]).To(Equal(handshake.GetRetryIntegrityTag(b[:len(b)-16], hdr.DestConnectionID, hdr.Version)[:]))
return len(b), nil
})
phm.EXPECT().Get(connID)
serv.handlePacket(packet)
Eventually(done).Should(BeClosed())
})
// With the default limits, an Initial without a token must directly lead to
// the creation of a connection (no Retry).
It("creates a connection, if no token is required", func() {
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})
hdr := &wire.Header{
Type: protocol.PacketTypeInitial,
SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
DestConnectionID: connID,
Version: protocol.Version1,
}
p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
// run is closed once the mocked connection's run method is called.
run := make(chan struct{})
// Random stateless reset token, returned by the mocked handler map and
// expected to be passed on to newConn. The error of rand.Read is not checked.
var token protocol.StatelessResetToken
rand.Read(token[:])
// newConnID records the source connection ID that newConn was called with.
var newConnID protocol.ConnectionID
conn := NewMockQUICConn(mockCtrl)
serv.newConn = func(
_ sendConn,
_ connRunner,
origDestConnID protocol.ConnectionID,
retrySrcConnID *protocol.ConnectionID,
clientDestConnID protocol.ConnectionID,
destConnID protocol.ConnectionID,
srcConnID protocol.ConnectionID,
_ ConnectionIDGenerator,
tokenP protocol.StatelessResetToken,
_ *Config,
_ *tls.Config,
_ *handshake.TokenGenerator,
_ bool,
_ *logging.ConnectionTracer,
_ uint64,
_ utils.Logger,
_ protocol.Version,
) quicConn {
// no Retry was performed: the original destination connection ID is the one from the packet
Expect(origDestConnID).To(Equal(hdr.DestConnectionID))
Expect(retrySrcConnID).To(BeNil())
Expect(clientDestConnID).To(Equal(hdr.DestConnectionID))
Expect(destConnID).To(Equal(hdr.SrcConnectionID))
// make sure we're using a server-generated connection ID
Expect(srcConnID).ToNot(Equal(hdr.DestConnectionID))
Expect(srcConnID).ToNot(Equal(hdr.SrcConnectionID))
newConnID = srcConnID
Expect(tokenP).To(Equal(token))
conn.EXPECT().handlePacket(p)
conn.EXPECT().run().Do(func() error { close(run); return nil })
conn.EXPECT().Context().Return(context.Background())
conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
return conn
}
// The handler map must be consulted in exactly this order.
gomock.InOrder(
phm.EXPECT().Get(connID),
phm.EXPECT().GetStatelessResetToken(gomock.Any()).Return(token),
phm.EXPECT().AddWithConnID(connID, gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, h packetHandler) bool {
Expect(c).To(Equal(newConnID))
return true
}),
)
done := make(chan struct{})
go func() {
defer GinkgoRecover()
serv.handlePacket(p)
// the Handshake packet is written by the connection
// make sure there are no Write calls on the packet conn
time.Sleep(50 * time.Millisecond)
close(done)
}()
// wait until the connection was run, and until handlePacket has returned
Eventually(run).Should(BeClosed())
Eventually(done).Should(BeClosed())
// shutdown
conn.EXPECT().closeWithTransportError(gomock.Any()).MaxTimes(1)
})
// Floods the server with Initials while connection creation is blocked, and
// checks that only protocol.MaxServerUnprocessedPackets+1 connections end up
// being created, with the excess packets reported as dropped.
It("drops packets if the receive queue is full", func() {
// NOTE(review): the handshake limits are presumably raised so that they don't interfere with this test — confirm.
serv.maxNumHandshakesTotal = 10000
serv.maxNumHandshakesUnvalidated = 10000
phm.EXPECT().Get(gomock.Any()).AnyTimes()
phm.EXPECT().GetStatelessResetToken(gomock.Any()).AnyTimes()
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true).AnyTimes()
// newConn blocks until acceptConn is closed; counter counts the connections created.
acceptConn := make(chan struct{})
var counter atomic.Uint32
serv.newConn = func(
_ sendConn,
runner connRunner,
_ protocol.ConnectionID,
_ *protocol.ConnectionID,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ ConnectionIDGenerator,
_ protocol.StatelessResetToken,
_ *Config,
_ *tls.Config,
_ *handshake.TokenGenerator,
_ bool,
_ *logging.ConnectionTracer,
_ uint64,
_ utils.Logger,
_ protocol.Version,
) quicConn {
<-acceptConn
counter.Add(1)
conn := NewMockQUICConn(mockCtrl)
conn.EXPECT().handlePacket(gomock.Any()).MaxTimes(1)
conn.EXPECT().run().MaxTimes(1)
conn.EXPECT().Context().Return(context.Background()).MaxTimes(1)
conn.EXPECT().HandshakeComplete().Return(make(chan struct{})).MaxTimes(1)
// shutdown
conn.EXPECT().closeWithTransportError(gomock.Any()).MaxTimes(1)
return conn
}
// The first packet is handed to the server before the drop expectation is registered.
p := getInitial(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}))
serv.handlePacket(p)
tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropDOSPrevention).MinTimes(1)
// Send three times as many packets as the MaxServerUnprocessedPackets limit, concurrently.
var wg sync.WaitGroup
for i := 0; i < 3*protocol.MaxServerUnprocessedPackets; i++ {
wg.Add(1)
go func() {
defer GinkgoRecover()
defer wg.Done()
serv.handlePacket(getInitial(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})))
}()
}
wg.Wait()
// Unblock connection creation.
close(acceptConn)
// The count must reach MaxServerUnprocessedPackets+1 (the "+1" presumably being the first packet sent above) ...
Eventually(
func() uint32 { return counter.Load() },
scaleDuration(100*time.Millisecond),
).Should(BeEquivalentTo(protocol.MaxServerUnprocessedPackets + 1))
// ... and must not grow beyond that.
Consistently(func() uint32 { return counter.Load() }).Should(BeEquivalentTo(protocol.MaxServerUnprocessedPackets + 1))
})
// Pending spec (PIt): Ginkgo does not run it.
// It describes the case where AddWithConnID reports a connection ID collision,
// and expects that no connection is created and a packet is written instead.
PIt("only creates a single connection for a duplicate Initial", func() {
// createdConn is set if newConn is called at all.
var createdConn bool
serv.newConn = func(
_ sendConn,
runner connRunner,
_ protocol.ConnectionID,
_ *protocol.ConnectionID,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ ConnectionIDGenerator,
_ protocol.StatelessResetToken,
_ *Config,
_ *tls.Config,
_ *handshake.TokenGenerator,
_ bool,
_ *logging.ConnectionTracer,
_ uint64,
_ utils.Logger,
_ protocol.Version,
) quicConn {
createdConn = true
return NewMockQUICConn(mockCtrl)
}
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9})
p := getInitial(connID)
phm.EXPECT().Get(connID)
phm.EXPECT().AddWithConnID(connID, gomock.Any(), gomock.Any()).Return(false) // connection ID collision
tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
// done is closed when something is written to the packet conn.
done := make(chan struct{})
conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).Do(func([]byte, net.Addr) (int, error) { close(done); return 0, nil })
Expect(serv.handlePacketImpl(p)).To(BeTrue())
Expect(createdConn).To(BeFalse())
Eventually(done).Should(BeClosed())
})
// Starts `limit` handshakes without a token, checks that the next Initial is
// answered with a Retry, and that new handshakes are possible again once the
// first ones have completed and been accepted.
It("limits the number of unvalidated handshakes", func() {
const limit = 3
serv.maxNumHandshakesTotal = 10000
serv.maxNumHandshakesUnvalidated = limit
phm.EXPECT().Get(gomock.Any()).AnyTimes()
phm.EXPECT().GetStatelessResetToken(gomock.Any()).AnyTimes()
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true).AnyTimes()
// handshakeChan is returned by HandshakeComplete of every mocked connection; closing it completes all handshakes at once.
handshakeChan := make(chan struct{})
// connChan hands the next mock to newConn.
connChan := make(chan *MockQUICConn, 1)
// 2*limit connections are created in total (two loops of `limit` each);
// HandshakeComplete is expected to be called once on each of them.
var wg sync.WaitGroup
wg.Add(2 * limit)
serv.newConn = func(
_ sendConn,
runner connRunner,
_ protocol.ConnectionID,
_ *protocol.ConnectionID,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ ConnectionIDGenerator,
_ protocol.StatelessResetToken,
_ *Config,
_ *tls.Config,
_ *handshake.TokenGenerator,
_ bool,
_ *logging.ConnectionTracer,
_ uint64,
_ utils.Logger,
_ protocol.Version,
) quicConn {
conn := <-connChan
conn.EXPECT().handlePacket(gomock.Any())
conn.EXPECT().run()
conn.EXPECT().Context().Return(context.Background())
// NOTE(review): both Return and Do are set; presumably the value from Return (handshakeChan) is what the caller receives, and Do only signals the WaitGroup — confirm against gomock.
conn.EXPECT().HandshakeComplete().Return(handshakeChan).Do(func() <-chan struct{} { wg.Done(); return nil })
return conn
}
// Initiate the maximum number of allowed connection attempts.
for i := 0; i < limit; i++ {
conn := NewMockQUICConn(mockCtrl)
connChan <- conn
serv.handlePacket(getInitialWithRandomDestConnID())
}
// Now initiate another connection attempt.
p := getInitialWithRandomDestConnID()
done := make(chan struct{})
// This one must be answered with a Retry, both as seen by the tracer and on the wire.
tracer.EXPECT().SentPacket(p.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, frames []logging.Frame) {
defer GinkgoRecover()
Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry))
})
conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
defer GinkgoRecover()
defer close(done)
hdr, _, _, err := wire.ParsePacket(b)
Expect(err).ToNot(HaveOccurred())
Expect(hdr.Type).To(Equal(protocol.PacketTypeRetry))
return len(b), nil
})
serv.handlePacket(p)
Eventually(done).Should(BeClosed())
// Complete the pending handshakes and accept the resulting connections.
close(handshakeChan)
for i := 0; i < limit; i++ {
_, err := serv.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
}
// New connection attempts must now lead to newConn being called again
// (otherwise wg.Wait below would block).
for i := 0; i < limit; i++ {
conn := NewMockQUICConn(mockCtrl)
conn.EXPECT().closeWithTransportError(gomock.Any()).MaxTimes(1) // called when the server is closed
connChan <- conn
serv.handlePacket(getInitialWithRandomDestConnID())
}
wg.Wait()
})
// Starts `limit` handshakes, checks that the next Initial is rejected with a
// CONNECTION_REFUSED error, and that new handshakes are possible again once
// the first ones have completed and been accepted.
It("limits the number of total handshakes", func() {
const limit = 3
serv.maxNumHandshakesTotal = limit
serv.maxNumHandshakesUnvalidated = limit // same limit, but we check that we send CONNECTION_REFUSED and not Retry
phm.EXPECT().Get(gomock.Any()).AnyTimes()
phm.EXPECT().GetStatelessResetToken(gomock.Any()).AnyTimes()
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true).AnyTimes()
// handshakeChan is returned by HandshakeComplete of every mocked connection; closing it completes all handshakes at once.
handshakeChan := make(chan struct{})
// connChan hands the next mock to newConn.
connChan := make(chan *MockQUICConn, 1)
serv.newConn = func(
_ sendConn,
runner connRunner,
_ protocol.ConnectionID,
_ *protocol.ConnectionID,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ ConnectionIDGenerator,
_ protocol.StatelessResetToken,
_ *Config,
_ *tls.Config,
_ *handshake.TokenGenerator,
_ bool,
_ *logging.ConnectionTracer,
_ uint64,
_ utils.Logger,
_ protocol.Version,
) quicConn {
conn := <-connChan
conn.EXPECT().handlePacket(gomock.Any())
conn.EXPECT().run()
conn.EXPECT().Context().Return(context.Background())
conn.EXPECT().HandshakeComplete().Return(handshakeChan)
return conn
}
// Initiate the maximum number of allowed connection attempts.
for i := 0; i < limit; i++ {
conn := NewMockQUICConn(mockCtrl)
connChan <- conn
serv.handlePacket(getInitialWithRandomDestConnID())
}
// One more attempt: it must be refused, as seen by the tracer and on the wire.
p := getInitialWithRandomDestConnID()
done := make(chan struct{})
tracer.EXPECT().SentPacket(p.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, frames []logging.Frame) {
defer GinkgoRecover()
hdr, _, _, err := wire.ParsePacket(p.data)
Expect(err).ToNot(HaveOccurred())
Expect(replyHdr.Type).To(Equal(protocol.PacketTypeInitial))
Expect(replyHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID))
Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID))
Expect(frames).To(HaveLen(1))
Expect(frames[0]).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{}))
ccf := frames[0].(*logging.ConnectionCloseFrame)
Expect(ccf.IsApplicationError).To(BeFalse())
Expect(ccf.ErrorCode).To(BeEquivalentTo(qerr.ConnectionRefused))
})
conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
defer GinkgoRecover()
defer close(done)
hdr, _, _, err := wire.ParsePacket(p.data)
Expect(err).ToNot(HaveOccurred())
checkConnectionCloseError(b, hdr, qerr.ConnectionRefused)
return len(b), nil
})
serv.handlePacket(p)
Eventually(done).Should(BeClosed())
// Complete the pending handshakes and accept the resulting connections.
close(handshakeChan)
for i := 0; i < limit; i++ {
_, err := serv.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
}
// make sure we can enqueue and accept more connections after that
for i := 0; i < limit; i++ {
conn := NewMockQUICConn(mockCtrl)
conn.EXPECT().closeWithTransportError(gomock.Any()).MaxTimes(1) // called when the server is closed
connChan <- conn
serv.handlePacket(getInitialWithRandomDestConnID())
}
for i := 0; i < limit; i++ {
_, err := serv.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
}
})
})
Context("token validation", func() {
// Sends an Initial carrying a Retry token generated for its remote address,
// and waits until a connection is added to the packet handler map.
It("decodes the token from the token field", func() {
raddr := &net.UDPAddr{IP: net.IPv4(192, 168, 13, 37), Port: 1337}
token, err := serv.tokenGenerator.NewRetryToken(raddr, protocol.ConnectionID{}, protocol.ConnectionID{})
Expect(err).ToNot(HaveOccurred())
packet := getPacket(&wire.Header{
Type: protocol.PacketTypeInitial,
Token: token,
Version: serv.config.Versions[0],
}, make([]byte, protocol.MinInitialPacketSize))
packet.remoteAddr = raddr
// serv.newConn is not replaced in this spec; at most one write and one traced packet are tolerated.
conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).MaxTimes(1)
tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1)
// done is closed once the server registers the new connection.
done := make(chan struct{})
phm.EXPECT().Get(gomock.Any())
phm.EXPECT().GetStatelessResetToken(gomock.Any())
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, _ packetHandler) bool {
close(done)
return true
})
phm.EXPECT().Remove(gomock.Any()).AnyTimes()
serv.handlePacket(packet)
Eventually(done).Should(BeClosed())
})
// A Retry token generated for a different remote address (an empty UDPAddr)
// must be rejected with an INVALID_TOKEN connection close.
It("sends an INVALID_TOKEN error, if an invalid retry token is received", func() {
serv.maxNumHandshakesUnvalidated = 0
token, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, protocol.ConnectionID{}, protocol.ConnectionID{})
Expect(err).ToNot(HaveOccurred())
hdr := &wire.Header{
Type: protocol.PacketTypeInitial,
SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}),
Token: token,
Version: protocol.Version1,
}
packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
packet.data = append(packet.data, []byte("coalesced packet")...) // add some garbage to simulate a coalesced packet
raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
packet.remoteAddr = raddr
// The reply is traced as an Initial with swapped connection IDs, carrying a single CONNECTION_CLOSE frame.
tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, frames []logging.Frame) {
Expect(replyHdr.Type).To(Equal(protocol.PacketTypeInitial))
Expect(replyHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID))
Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID))
Expect(frames).To(HaveLen(1))
Expect(frames[0]).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{}))
ccf := frames[0].(*logging.ConnectionCloseFrame)
Expect(ccf.IsApplicationError).To(BeFalse())
Expect(ccf.ErrorCode).To(BeEquivalentTo(qerr.InvalidToken))
})
done := make(chan struct{})
// The packet written to the conn is decrypted and checked as well.
conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
defer close(done)
checkConnectionCloseError(b, hdr, qerr.InvalidToken)
return len(b), nil
})
phm.EXPECT().Get(gomock.Any())
serv.handlePacket(packet)
Eventually(done).Should(BeClosed())
})
// A Retry token for the correct remote address that is older than the maximum
// retry token age must be rejected with an INVALID_TOKEN connection close.
It("sends an INVALID_TOKEN error, if an expired retry token is received", func() {
serv.maxNumHandshakesUnvalidated = 0
serv.config.HandshakeIdleTimeout = time.Millisecond / 2 // the maximum retry token age is equivalent to the handshake timeout
// with half a millisecond of handshake idle timeout, the maximum retry token age is one millisecond
Expect(serv.config.maxRetryTokenAge()).To(Equal(time.Millisecond))
raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
token, err := serv.tokenGenerator.NewRetryToken(raddr, protocol.ConnectionID{}, protocol.ConnectionID{})
Expect(err).ToNot(HaveOccurred())
time.Sleep(2 * time.Millisecond) // make sure the token is expired
hdr := &wire.Header{
Type: protocol.PacketTypeInitial,
SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}),
Token: token,
Version: protocol.Version1,
}
packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
packet.remoteAddr = raddr
// The reply is traced as an Initial with swapped connection IDs, carrying a single CONNECTION_CLOSE frame.
tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, frames []logging.Frame) {
Expect(replyHdr.Type).To(Equal(protocol.PacketTypeInitial))
Expect(replyHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID))
Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID))
Expect(frames).To(HaveLen(1))
Expect(frames[0]).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{}))
ccf := frames[0].(*logging.ConnectionCloseFrame)
Expect(ccf.IsApplicationError).To(BeFalse())
Expect(ccf.ErrorCode).To(BeEquivalentTo(qerr.InvalidToken))
})
done := make(chan struct{})
// The packet written to the conn is decrypted and checked as well.
conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
defer close(done)
checkConnectionCloseError(b, hdr, qerr.InvalidToken)
return len(b), nil
})
phm.EXPECT().Get(gomock.Any())
serv.handlePacket(packet)
Eventually(done).Should(BeClosed())
})
// A non-Retry token generated for a different remote address must not lead
// to an INVALID_TOKEN error: the spec expects a Retry packet to be written.
It("doesn't send an INVALID_TOKEN error, if an invalid non-retry token is received", func() {
serv.maxNumHandshakesUnvalidated = 0
// the token is issued for 192.168.0.1, but the packet is sent from 127.0.0.1 (see raddr below)
token, err := serv.tokenGenerator.NewToken(&net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337})
Expect(err).ToNot(HaveOccurred())
hdr := &wire.Header{
Type: protocol.PacketTypeInitial,
SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}),
Token: token,
Version: protocol.Version1,
}
packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
packet.data[len(packet.data)-10] ^= 0xff // corrupt the packet
raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
packet.remoteAddr = raddr
tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1)
done := make(chan struct{})
conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
defer close(done)
replyHdr := parseHeader(b)
Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry))
return len(b), nil
})
phm.EXPECT().Get(gomock.Any())
serv.handlePacket(packet)
// wait for the Retry packet to be written to the packet conn
Eventually(done).Should(BeClosed())
})
It("sends an INVALID_TOKEN error, if an expired non-retry token is received", func() {
	// NOTE(review): despite the spec title, the tracer assertion below expects a Retry packet.
	serv.maxNumHandshakesUnvalidated = 0
	serv.maxTokenAge = time.Millisecond

	raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
	token, err := serv.tokenGenerator.NewToken(raddr)
	Expect(err).ToNot(HaveOccurred())
	time.Sleep(2 * time.Millisecond) // make sure the token is expired

	packet := getPacket(&wire.Header{
		Type:             protocol.PacketTypeInitial,
		SrcConnectionID:  protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
		DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}),
		Token:            token,
		Version:          protocol.Version1,
	}, make([]byte, protocol.MinInitialPacketSize))
	packet.remoteAddr = raddr

	// The traced reply header has to be of type Retry.
	checkReply := func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, _ []logging.Frame) {
		Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry))
	}
	tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(checkReply)

	written := make(chan struct{})
	conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
		defer close(written)
		return len(b), nil
	})
	phm.EXPECT().Get(gomock.Any())

	serv.handlePacket(packet)
	Eventually(written).Should(BeClosed())
})
It("doesn't send an INVALID_TOKEN error, if the packet is corrupted", func() {
	serv.maxNumHandshakesUnvalidated = 0
	retryToken, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, protocol.ConnectionID{}, protocol.ConnectionID{})
	Expect(err).ToNot(HaveOccurred())

	packet := getPacket(&wire.Header{
		Type:             protocol.PacketTypeInitial,
		SrcConnectionID:  protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
		DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}),
		Token:            retryToken,
		Version:          protocol.Version1,
	}, make([]byte, protocol.MinInitialPacketSize))
	packet.data[len(packet.data)-10] ^= 0xff // corrupt the packet
	packet.remoteAddr = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}

	// The packet has to be reported as dropped with a payload decryption error.
	dropped := make(chan struct{})
	onDrop := func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { close(dropped) }
	tracer.EXPECT().
		DroppedPacket(packet.remoteAddr, logging.PacketTypeInitial, packet.Size(), logging.PacketDropPayloadDecryptError).
		Do(onDrop)
	phm.EXPECT().Get(gomock.Any())

	serv.handlePacket(packet)
	// make sure there are no Write calls on the packet conn
	time.Sleep(50 * time.Millisecond)
	Eventually(dropped).Should(BeClosed())
})
})
Context("accepting connections", func() {
It("returns Accept when closed", func() {
	returned := make(chan struct{})
	go func() {
		defer GinkgoRecover()
		_, acceptErr := serv.Accept(context.Background())
		Expect(acceptErr).To(MatchError(ErrServerClosed))
		close(returned)
	}()

	// Closing the server has to unblock the pending Accept call.
	serv.Close()
	Eventually(returned).Should(BeClosed())
})
It("returns immediately, if an error occurred before", func() {
	serv.Close()
	// Every Accept call on the closed server reports the same error.
	for attempt := 0; attempt < 3; attempt++ {
		_, acceptErr := serv.Accept(context.Background())
		Expect(acceptErr).To(MatchError(ErrServerClosed))
	}
})
It("closes connection that are still handshaking after Close", func() {
	serv.Close()

	refused := make(chan struct{})
	// Stub the connection constructor: the connection created for the Initial below
	// is expected to be closed with ConnectionRefused.
	serv.newConn = func(
		_ sendConn,
		_ connRunner,
		_ protocol.ConnectionID,
		_ *protocol.ConnectionID,
		_, _, _ protocol.ConnectionID,
		_ ConnectionIDGenerator,
		_ protocol.StatelessResetToken,
		_ *Config,
		_ *tls.Config,
		_ *handshake.TokenGenerator,
		_ bool,
		_ *logging.ConnectionTracer,
		_ uint64,
		_ utils.Logger,
		_ protocol.Version,
	) quicConn {
		mockConn := NewMockQUICConn(mockCtrl)
		mockConn.EXPECT().Context().Return(context.Background())
		// the handshake never completes: this channel is never closed
		mockConn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
		mockConn.EXPECT().run().MaxTimes(1)
		mockConn.EXPECT().handlePacket(gomock.Any())
		mockConn.EXPECT().closeWithTransportError(ConnectionRefused).Do(func(TransportErrorCode) { close(refused) })
		return mockConn
	}

	phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true)
	phm.EXPECT().GetStatelessResetToken(gomock.Any())
	phm.EXPECT().Get(gomock.Any())

	hdr := &wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})}
	serv.handleInitialImpl(receivedPacket{buffer: getPacketBuffer()}, hdr)
	Eventually(refused).Should(BeClosed())
})
It("returns when the context is canceled", func() {
	ctx, cancel := context.WithCancel(context.Background())
	returned := make(chan struct{})
	go func() {
		defer GinkgoRecover()
		_, acceptErr := serv.Accept(ctx)
		Expect(acceptErr).To(MatchError("context canceled"))
		close(returned)
	}()

	// Accept keeps blocking until the context is canceled.
	Consistently(returned).ShouldNot(BeClosed())
	cancel()
	Eventually(returned).Should(BeClosed())
})
It("uses the config returned by GetConfigClient", func() {
	mockConn := NewMockQUICConn(mockCtrl)
	clientConf := &Config{MaxIncomingStreams: 1234}
	serv.config = populateConfig(&Config{
		GetConfigForClient: func(*ClientHelloInfo) (*Config, error) { return clientConf, nil },
	})

	accepted := make(chan struct{})
	go func() {
		defer GinkgoRecover()
		got, err := serv.Accept(context.Background())
		Expect(err).ToNot(HaveOccurred())
		Expect(got).To(Equal(mockConn))
		close(accepted)
	}()

	handshakeChan := make(chan struct{})
	serv.newConn = func(
		_ sendConn,
		_ connRunner,
		_ protocol.ConnectionID,
		_ *protocol.ConnectionID,
		_, _, _ protocol.ConnectionID,
		_ ConnectionIDGenerator,
		_ protocol.StatelessResetToken,
		conf *Config,
		_ *tls.Config,
		_ *handshake.TokenGenerator,
		_ bool,
		_ *logging.ConnectionTracer,
		_ uint64,
		_ utils.Logger,
		_ protocol.Version,
	) quicConn {
		// The connection has to be created with the value from the per-client config.
		Expect(conf.MaxIncomingStreams).To(BeEquivalentTo(1234))
		mockConn.EXPECT().Context().Return(context.Background())
		mockConn.EXPECT().run()
		mockConn.EXPECT().HandshakeComplete().Return(handshakeChan)
		mockConn.EXPECT().handlePacket(gomock.Any())
		return mockConn
	}

	phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true)
	phm.EXPECT().GetStatelessResetToken(gomock.Any())
	phm.EXPECT().Get(gomock.Any())

	hdr := &wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})}
	serv.handleInitialImpl(receivedPacket{buffer: getPacketBuffer()}, hdr)

	// Accept must not return before the handshake is complete.
	Consistently(accepted).ShouldNot(BeClosed())
	close(handshakeChan) // complete the handshake
	Eventually(accepted).Should(BeClosed())
})
It("rejects a connection attempt when GetConfigClient returns an error", func() {
	rejectAll := func(*ClientHelloInfo) (*Config, error) { return nil, errors.New("rejected") }
	serv.config = populateConfig(&Config{GetConfigForClient: rejectAll})

	written := make(chan struct{})
	phm.EXPECT().Get(gomock.Any())
	tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
	conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
		defer close(written)
		// The rejection is written as an Initial packet.
		Expect(parseHeader(b).Type).To(Equal(protocol.PacketTypeInitial))
		return len(b), nil
	})

	hdr := &wire.Header{
		DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}),
		Version:          protocol.Version1,
	}
	serv.handleInitialImpl(receivedPacket{buffer: getPacketBuffer()}, hdr)
	Eventually(written).Should(BeClosed())
})
It("accepts new connections when the handshake completes", func() {
	mockConn := NewMockQUICConn(mockCtrl)

	accepted := make(chan struct{})
	go func() {
		defer GinkgoRecover()
		got, err := serv.Accept(context.Background())
		Expect(err).ToNot(HaveOccurred())
		Expect(got).To(Equal(mockConn))
		close(accepted)
	}()

	handshakeChan := make(chan struct{})
	serv.newConn = func(
		_ sendConn,
		_ connRunner,
		_ protocol.ConnectionID,
		_ *protocol.ConnectionID,
		_, _, _ protocol.ConnectionID,
		_ ConnectionIDGenerator,
		_ protocol.StatelessResetToken,
		_ *Config,
		_ *tls.Config,
		_ *handshake.TokenGenerator,
		_ bool,
		_ *logging.ConnectionTracer,
		_ uint64,
		_ utils.Logger,
		_ protocol.Version,
	) quicConn {
		mockConn.EXPECT().Context().Return(context.Background())
		mockConn.EXPECT().run()
		mockConn.EXPECT().HandshakeComplete().Return(handshakeChan)
		mockConn.EXPECT().handlePacket(gomock.Any())
		return mockConn
	}

	phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true)
	phm.EXPECT().GetStatelessResetToken(gomock.Any())
	phm.EXPECT().Get(gomock.Any())

	hdr := &wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})}
	serv.handleInitialImpl(receivedPacket{buffer: getPacketBuffer()}, hdr)

	// Accept must not return while the handshake is still in progress.
	Consistently(accepted).ShouldNot(BeClosed())
	close(handshakeChan) // complete the handshake
	Eventually(accepted).Should(BeClosed())
})
})
})
Context("server accepting connections that haven't completed the handshake", func() {
var (
	phm  *MockPacketHandlerManager
	serv *EarlyListener
)

// Every spec gets a fresh EarlyListener whose connection handler is swapped for a mock.
BeforeEach(func() {
	var err error
	serv, err = ListenEarly(conn, tlsConf, nil)
	Expect(err).ToNot(HaveOccurred())

	phm = NewMockPacketHandlerManager(mockCtrl)
	serv.baseServer.connHandler = phm
})

AfterEach(func() { serv.Close() })
It("accepts new connections when they become ready", func() {
	mockConn := NewMockQUICConn(mockCtrl)

	accepted := make(chan struct{})
	go func() {
		defer GinkgoRecover()
		got, err := serv.Accept(context.Background())
		Expect(err).ToNot(HaveOccurred())
		Expect(got).To(Equal(mockConn))
		close(accepted)
	}()

	ready := make(chan struct{})
	serv.baseServer.newConn = func(
		_ sendConn,
		_ connRunner,
		_ protocol.ConnectionID,
		_ *protocol.ConnectionID,
		_, _, _ protocol.ConnectionID,
		_ ConnectionIDGenerator,
		_ protocol.StatelessResetToken,
		_ *Config,
		_ *tls.Config,
		_ *handshake.TokenGenerator,
		_ bool,
		_ *logging.ConnectionTracer,
		_ uint64,
		_ utils.Logger,
		_ protocol.Version,
	) quicConn {
		mockConn.EXPECT().Context().Return(context.Background())
		mockConn.EXPECT().earlyConnReady().Return(ready)
		mockConn.EXPECT().run()
		mockConn.EXPECT().handlePacket(gomock.Any())
		return mockConn
	}

	phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true)
	phm.EXPECT().GetStatelessResetToken(gomock.Any())
	phm.EXPECT().Get(gomock.Any())

	hdr := &wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})}
	serv.baseServer.handleInitialImpl(receivedPacket{buffer: getPacketBuffer()}, hdr)

	// Accept must not return until the ready channel is closed.
	Consistently(accepted).ShouldNot(BeClosed())
	close(ready)
	Eventually(accepted).Should(BeClosed())
})
// Fills the accept queue with protocol.MaxAcceptQueueSize connections, then sends one more
// Initial and expects the resulting connection to be closed with ConnectionRefused.
It("rejects new connection attempts if the accept queue is full", func() {
// connChan hands each pre-built mock connection to the newConn stub, which receives from it.
connChan := make(chan *MockQUICConn, 1)
var wg sync.WaitGroup // to make sure the test fully completes
// newConn calls wg.Done once per invocation: MaxAcceptQueueSize queued connections plus the rejected one.
wg.Add(protocol.MaxAcceptQueueSize + 1)
serv.baseServer.newConn = func(
_ sendConn,
runner connRunner,
_ protocol.ConnectionID,
_ *protocol.ConnectionID,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ ConnectionIDGenerator,
_ protocol.StatelessResetToken,
_ *Config,
_ *tls.Config,
_ *handshake.TokenGenerator,
_ bool,
_ *logging.ConnectionTracer,
_ uint64,
_ utils.Logger,
_ protocol.Version,
) quicConn {
defer wg.Done()
// earlyConnReady returns an already-closed channel, so receiving from it never blocks.
ready := make(chan struct{})
close(ready)
conn := <-connChan
conn.EXPECT().handlePacket(gomock.Any())
conn.EXPECT().run()
conn.EXPECT().earlyConnReady().Return(ready)
conn.EXPECT().Context().Return(context.Background())
return conn
}
phm.EXPECT().Get(gomock.Any()).AnyTimes()
// These expectations cover exactly the connections that fill the queue;
// the ones for the extra connection are registered separately further down.
phm.EXPECT().GetStatelessResetToken(gomock.Any()).Times(protocol.MaxAcceptQueueSize)
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true).Times(protocol.MaxAcceptQueueSize)
for i := 0; i < protocol.MaxAcceptQueueSize; i++ {
conn := NewMockQUICConn(mockCtrl)
connChan <- conn
serv.baseServer.handlePacket(getInitialWithRandomDestConnID())
}
// Wait until the queue holds MaxAcceptQueueSize connections before sending the extra Initial.
Eventually(serv.baseServer.connQueue).Should(HaveLen(protocol.MaxAcceptQueueSize))
phm.EXPECT().GetStatelessResetToken(gomock.Any())
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true)
// The connection created for the extra Initial must be closed with ConnectionRefused.
conn := NewMockQUICConn(mockCtrl)
conn.EXPECT().closeWithTransportError(ConnectionRefused)
connChan <- conn
serv.baseServer.handlePacket(getInitialWithRandomDestConnID())
wg.Wait()
})
// A connection whose context is canceled before it becomes ready must never be handed
// out by Accept. Accept is expected to keep blocking until the listener is closed, and
// then to fail with ErrServerClosed.
It("doesn't accept new connections if they were closed in the mean time", func() {
	p := getInitial(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}))
	ctx, cancel := context.WithCancel(context.Background())
	connCreated := make(chan struct{})
	conn := NewMockQUICConn(mockCtrl)
	serv.baseServer.newConn = func(
		_ sendConn,
		_ connRunner,
		_ protocol.ConnectionID,
		_ *protocol.ConnectionID,
		_ protocol.ConnectionID,
		_ protocol.ConnectionID,
		_ protocol.ConnectionID,
		_ ConnectionIDGenerator,
		_ protocol.StatelessResetToken,
		_ *Config,
		_ *tls.Config,
		_ *handshake.TokenGenerator,
		_ bool,
		_ *logging.ConnectionTracer,
		_ uint64,
		_ utils.Logger,
		_ protocol.Version,
	) quicConn {
		conn.EXPECT().handlePacket(p)
		conn.EXPECT().run()
		// No return value is configured, so the mock hands back a nil channel:
		// the connection never signals that it is ready.
		conn.EXPECT().earlyConnReady()
		conn.EXPECT().Context().Return(ctx)
		close(connCreated)
		return conn
	}
	phm.EXPECT().Get(gomock.Any())
	phm.EXPECT().GetStatelessResetToken(gomock.Any())
	phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true)
	serv.baseServer.handlePacket(p)
	Eventually(connCreated).Should(BeClosed())

	// Cancel the connection's context, then give the server time to notice.
	cancel()
	time.Sleep(scaleDuration(200 * time.Millisecond))

	done := make(chan struct{})
	go func() {
		defer GinkgoRecover()
		// Accept must not return the canceled connection; it only returns once
		// the listener is closed below.
		_, err := serv.Accept(context.Background())
		Expect(err).To(MatchError(ErrServerClosed))
		close(done)
	}()
	Consistently(done).ShouldNot(BeClosed())
	// make the go routine return
	conn.EXPECT().getPerspective().MaxTimes(2) // initOnce for every conn ID
	Expect(serv.Close()).To(Succeed())
	Eventually(done).Should(BeClosed())
})
})
Context("0-RTT", func() {
var (
tr *Transport
serv *baseServer
phm *MockPacketHandlerManager
tracer *mocklogging.MockTracer
)
BeforeEach(func() {
var t *logging.Tracer
t, tracer = mocklogging.NewMockTracer(mockCtrl)
tr = &Transport{Conn: conn, Tracer: t}
ln, err := tr.ListenEarly(tlsConf, nil)
Expect(err).ToNot(HaveOccurred())
phm = NewMockPacketHandlerManager(mockCtrl)
serv = ln.baseServer
serv.connHandler = phm
})
AfterEach(func() {
tracer.EXPECT().Close()
Expect(tr.Close()).To(Succeed())
})
It("passes packets to existing connections", func() {
	connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
	hdr := &wire.Header{
		Type:             protocol.PacketType0RTT,
		DestConnectionID: connID,
		Version:          serv.config.Versions[0],
	}
	p := getPacket(hdr, make([]byte, 100))

	// The handler manager already knows a handler for this connection ID,
	// so the packet has to be delivered to that handler.
	handler := NewMockPacketHandler(mockCtrl)
	handled := make(chan struct{})
	handler.EXPECT().handlePacket(p).Do(func(receivedPacket) { close(handled) })
	phm.EXPECT().Get(connID).Return(handler, true)

	serv.handlePacket(p)
	Eventually(handled).Should(BeClosed())
})
// Sends protocol.Max0RTTQueueLen 0-RTT packets for one connection ID, expects one further
// packet to be dropped, and then checks that the connection created for the Initial
// receives the Initial followed by the queued 0-RTT packets in order.
It("queues 0-RTT packets, up to Max0RTTQueueSize", func() {
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
var zeroRTTPackets []receivedPacket
// NOTE: the limit used here is protocol.Max0RTTQueueLen (the spec title calls it Max0RTTQueueSize).
for i := 0; i < protocol.Max0RTTQueueLen; i++ {
// each packet gets a different payload length (100+i)
p := getPacket(&wire.Header{
Type: protocol.PacketType0RTT,
DestConnectionID: connID,
Version: serv.config.Versions[0],
}, make([]byte, 100+i))
phm.EXPECT().Get(connID)
serv.handlePacket(p)
zeroRTTPackets = append(zeroRTTPackets, p)
}
// send one more packet, this one should be dropped
p := getPacket(&wire.Header{
Type: protocol.PacketType0RTT,
DestConnectionID: connID,
Version: serv.config.Versions[0],
}, make([]byte, 200))
phm.EXPECT().Get(connID)
tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention)
serv.handlePacket(p)
// The Initial for the same connection ID leads to the creation of the connection.
initial := getPacket(&wire.Header{
Type: protocol.PacketTypeInitial,
DestConnectionID: connID,
Version: serv.config.Versions[0],
}, make([]byte, protocol.MinInitialPacketSize))
called := make(chan struct{})
serv.newConn = func(
_ sendConn,
_ connRunner,
_ protocol.ConnectionID,
_ *protocol.ConnectionID,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ ConnectionIDGenerator,
_ protocol.StatelessResetToken,
_ *Config,
_ *tls.Config,
_ *handshake.TokenGenerator,
_ bool,
_ *logging.ConnectionTracer,
_ uint64,
_ utils.Logger,
_ protocol.Version,
) quicConn {
conn := NewMockQUICConn(mockCtrl)
// Expected delivery order: the Initial first, then the queued 0-RTT packets
// in the order in which they were handled above.
var calls []any
calls = append(calls, conn.EXPECT().handlePacket(initial))
for _, p := range zeroRTTPackets {
calls = append(calls, conn.EXPECT().handlePacket(p))
}
gomock.InOrder(calls...)
conn.EXPECT().run()
conn.EXPECT().earlyConnReady()
conn.EXPECT().Context().Return(context.Background())
close(called)
// shutdown
conn.EXPECT().closeWithTransportError(gomock.Any())
return conn
}
phm.EXPECT().Get(connID)
phm.EXPECT().GetStatelessResetToken(gomock.Any())
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true)
serv.handlePacket(initial)
Eventually(called).Should(BeClosed())
})
// Opens protocol.Max0RTTQueues queues by sending 0-RTT packets for that many random
// connection IDs, then expects a packet for one further connection ID to be dropped.
It("limits the number of queues", func() {
	for i := 0; i < protocol.Max0RTTQueues; i++ {
		// A random 16-byte connection ID per iteration; a failure of the random source
		// fails the spec instead of silently producing an all-zero connection ID.
		b := make([]byte, 16)
		_, err := rand.Read(b)
		Expect(err).ToNot(HaveOccurred())
		connID := protocol.ParseConnectionID(b)
		p := getPacket(&wire.Header{
			Type:             protocol.PacketType0RTT,
			DestConnectionID: connID,
			Version:          serv.config.Versions[0],
		}, make([]byte, 100+i))
		phm.EXPECT().Get(connID)
		serv.handlePacket(p)
	}

	// One more connection ID than there are queues: this packet has to be dropped.
	connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
	p := getPacket(&wire.Header{
		Type:             protocol.PacketType0RTT,
		DestConnectionID: connID,
		Version:          serv.config.Versions[0],
	}, make([]byte, 200))
	phm.EXPECT().Get(connID)
	dropped := make(chan struct{})
	tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) {
		close(dropped)
	})
	serv.handlePacket(p)
	Eventually(dropped).Should(BeClosed())
})
// Queues two 0-RTT packets with different receive times and checks that each one is
// dropped only once a later packet arrives more than protocol.Max0RTTQueueingDuration
// after its own receive time. Receive times are set explicitly via rcvTime.
It("drops queues after a while", func() {
now := time.Now()
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
p := getPacket(&wire.Header{
Type: protocol.PacketType0RTT,
DestConnectionID: connID,
Version: serv.config.Versions[0],
}, make([]byte, 200))
p.rcvTime = now
connID2 := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 9})
// The second packet is received half a queueing duration after the first one.
p2Time := now.Add(protocol.Max0RTTQueueingDuration / 2)
p2 := getPacket(&wire.Header{
Type: protocol.PacketType0RTT,
DestConnectionID: connID2,
Version: serv.config.Versions[0],
}, make([]byte, 300))
p2.rcvTime = p2Time // doesn't trigger the cleanup of the first packet
dropped1 := make(chan struct{})
dropped2 := make(chan struct{})
// need to register the call before handling the packet to avoid race condition
// The drops have to happen in this order: first p, then p2.
gomock.InOrder(
tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) {
close(dropped1)
}),
tracer.EXPECT().DroppedPacket(p2.remoteAddr, logging.PacketType0RTT, p2.Size(), logging.PacketDropDOSPrevention).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) {
close(dropped2)
}),
)
phm.EXPECT().Get(connID)
serv.handlePacket(p)
// There's no cleanup Go routine.
// Cleanup is triggered when new packets are received.
phm.EXPECT().Get(connID2)
serv.handlePacket(p2)
// make sure no cleanup is executed
Consistently(dropped1, 50*time.Millisecond).ShouldNot(BeClosed())
// A third packet, received just after the first packet's queueing duration has elapsed,
// triggers the cleanup of the first packet only.
connID3 := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 0})
p3 := getPacket(&wire.Header{
Type: protocol.PacketType0RTT,
DestConnectionID: connID3,
Version: serv.config.Versions[0],
}, make([]byte, 200))
p3.rcvTime = now.Add(protocol.Max0RTTQueueingDuration + time.Nanosecond) // now triggers the cleanup
phm.EXPECT().Get(connID3)
serv.handlePacket(p3)
Eventually(dropped1).Should(BeClosed())
// the second packet has not been queued for long enough yet
Consistently(dropped2, 50*time.Millisecond).ShouldNot(BeClosed())
// make sure the second packet is also cleaned up
connID4 := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 1})
p4 := getPacket(&wire.Header{
Type: protocol.PacketType0RTT,
DestConnectionID: connID4,
Version: serv.config.Versions[0],
}, make([]byte, 200))
p4.rcvTime = p2Time.Add(protocol.Max0RTTQueueingDuration + time.Nanosecond) // now triggers the cleanup
phm.EXPECT().Get(connID4)
serv.handlePacket(p4)
Eventually(dropped2).Should(BeClosed())
})
})
})