package quic import ( "context" "crypto/rand" "errors" "net" "sync" "sync/atomic" "time" tls "github.com/refraction-networking/utls" "golang.org/x/time/rate" "github.com/refraction-networking/uquic/internal/handshake" mocklogging "github.com/refraction-networking/uquic/internal/mocks/logging" "github.com/refraction-networking/uquic/internal/protocol" "github.com/refraction-networking/uquic/internal/qerr" "github.com/refraction-networking/uquic/internal/testdata" "github.com/refraction-networking/uquic/internal/utils" "github.com/refraction-networking/uquic/internal/wire" "github.com/refraction-networking/uquic/logging" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" "go.uber.org/mock/gomock" ) var _ = Describe("Server", func() { var ( conn *MockPacketConn tlsConf *tls.Config ) getPacket := func(hdr *wire.Header, p []byte) receivedPacket { buf := getPacketBuffer() hdr.Length = 4 + protocol.ByteCount(len(p)) + 16 var err error buf.Data, err = (&wire.ExtendedHeader{ Header: *hdr, PacketNumber: 0x42, PacketNumberLen: protocol.PacketNumberLen4, }).Append(buf.Data, protocol.Version1) Expect(err).ToNot(HaveOccurred()) n := len(buf.Data) buf.Data = append(buf.Data, p...) data := buf.Data sealer, _ := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveClient, hdr.Version) _ = sealer.Seal(data[n:n], data[n:], 0x42, data[:n]) data = data[:len(data)+16] sealer.EncryptHeader(data[n:n+16], &data[0], data[n-4:n]) return receivedPacket{ rcvTime: time.Now(), remoteAddr: &net.UDPAddr{IP: net.IPv4(4, 5, 6, 7), Port: 456}, data: data, buffer: buf, } } getInitial := func(destConnID protocol.ConnectionID) receivedPacket { senderAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 42} hdr := &wire.Header{ Type: protocol.PacketTypeInitial, SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), DestConnectionID: destConnID, Version: protocol.Version1, } p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) p.buffer = getPacketBuffer() p.remoteAddr = senderAddr return p } getInitialWithRandomDestConnID := func() receivedPacket { b := make([]byte, 10) _, err := rand.Read(b) Expect(err).ToNot(HaveOccurred()) return getInitial(protocol.ParseConnectionID(b)) } parseHeader := func(data []byte) *wire.Header { hdr, _, _, err := wire.ParsePacket(data) Expect(err).ToNot(HaveOccurred()) return hdr } checkConnectionCloseError := func(b []byte, origHdr *wire.Header, errorCode qerr.TransportErrorCode) { replyHdr := parseHeader(b) Expect(replyHdr.Type).To(Equal(protocol.PacketTypeInitial)) Expect(replyHdr.SrcConnectionID).To(Equal(origHdr.DestConnectionID)) Expect(replyHdr.DestConnectionID).To(Equal(origHdr.SrcConnectionID)) _, opener := handshake.NewInitialAEAD(origHdr.DestConnectionID, protocol.PerspectiveClient, replyHdr.Version) extHdr, err := unpackLongHeader(opener, replyHdr, b, origHdr.Version) Expect(err).ToNot(HaveOccurred()) data, err := opener.Open(nil, b[extHdr.ParsedLen():], extHdr.PacketNumber, b[:extHdr.ParsedLen()]) Expect(err).ToNot(HaveOccurred()) _, f, err := wire.NewFrameParser(false).ParseNext(data, protocol.EncryptionInitial, origHdr.Version) Expect(err).ToNot(HaveOccurred()) Expect(f).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{})) ccf := f.(*wire.ConnectionCloseFrame) Expect(ccf.IsApplicationError).To(BeFalse()) Expect(ccf.ErrorCode).To(BeEquivalentTo(errorCode)) Expect(ccf.ReasonPhrase).To(BeEmpty()) } BeforeEach(func() { conn = NewMockPacketConn(mockCtrl) conn.EXPECT().LocalAddr().Return(&net.UDPAddr{}).AnyTimes() wait := make(chan struct{}) conn.EXPECT().ReadFrom(gomock.Any()).DoAndReturn(func(_ []byte) (int, net.Addr, error) { <-wait return 0, nil, errors.New("done") }).MaxTimes(1) conn.EXPECT().SetReadDeadline(gomock.Any()).Do(func(time.Time) error { close(wait) conn.EXPECT().SetReadDeadline(time.Time{}) return nil }).MaxTimes(1) tlsConf = testdata.GetTLSConfig() tlsConf.NextProtos = []string{"proto1"} }) It("errors when no tls.Config is given", func() { _, err := ListenAddr("localhost:0", nil, nil) Expect(err).To(HaveOccurred()) Expect(err.Error()).To(ContainSubstring("quic: tls.Config not set")) }) It("errors when the Config contains an invalid version", func() { version := protocol.Version(0x1234) _, err := Listen(nil, tlsConf, &Config{Versions: []protocol.Version{version}}) Expect(err).To(MatchError("invalid QUIC version: 0x1234")) }) It("fills in default values if options are not set in the Config", func() { ln, err := Listen(conn, tlsConf, &Config{}) Expect(err).ToNot(HaveOccurred()) server := ln.baseServer Expect(server.config.Versions).To(Equal(protocol.SupportedVersions)) Expect(server.config.HandshakeIdleTimeout).To(Equal(protocol.DefaultHandshakeIdleTimeout)) Expect(server.config.MaxIdleTimeout).To(Equal(protocol.DefaultIdleTimeout)) Expect(server.config.KeepAlivePeriod).To(BeZero()) // stop the listener Expect(ln.Close()).To(Succeed()) }) It("setups with the right values", func() { supportedVersions := []protocol.Version{protocol.Version1} config := Config{ Versions: supportedVersions, HandshakeIdleTimeout: 1337 * time.Hour, MaxIdleTimeout: 42 * time.Minute, KeepAlivePeriod: 5 * time.Second, } ln, err := Listen(conn, tlsConf, &config) Expect(err).ToNot(HaveOccurred()) server := ln.baseServer Expect(server.connHandler).ToNot(BeNil()) Expect(server.config.Versions).To(Equal(supportedVersions)) Expect(server.config.HandshakeIdleTimeout).To(Equal(1337 * time.Hour)) Expect(server.config.MaxIdleTimeout).To(Equal(42 * time.Minute)) Expect(server.config.KeepAlivePeriod).To(Equal(5 * time.Second)) // stop the listener Expect(ln.Close()).To(Succeed()) }) It("listens on a given address", func() { addr := "127.0.0.1:13579" ln, err := ListenAddr(addr, tlsConf, &Config{}) Expect(err).ToNot(HaveOccurred()) Expect(ln.Addr().String()).To(Equal(addr)) // stop the listener Expect(ln.Close()).To(Succeed()) }) It("errors if given an invalid address", func() { addr := "127.0.0.1" _, err := ListenAddr(addr, tlsConf, &Config{}) Expect(err).To(BeAssignableToTypeOf(&net.AddrError{})) }) It("errors if given an invalid address", func() { addr := "1.1.1.1:1111" _, err := ListenAddr(addr, tlsConf, &Config{}) Expect(err).To(BeAssignableToTypeOf(&net.OpError{})) }) Context("server accepting connections that completed the handshake", func() { var ( tr *Transport serv *baseServer phm *MockPacketHandlerManager tracer *mocklogging.MockTracer ) BeforeEach(func() { var t *logging.Tracer t, tracer = mocklogging.NewMockTracer(mockCtrl) tr = &Transport{Conn: conn, Tracer: t} ln, err := tr.Listen(tlsConf, nil) Expect(err).ToNot(HaveOccurred()) serv = ln.baseServer phm = NewMockPacketHandlerManager(mockCtrl) serv.connHandler = phm }) AfterEach(func() { tracer.EXPECT().Close() tr.Close() }) Context("handling packets", func() { It("drops Initial packets with a too short connection ID", func() { p := getPacket(&wire.Header{ Type: protocol.PacketTypeInitial, DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}), Version: serv.config.Versions[0], }, nil) tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket) serv.handlePacket(p) // make sure there are no Write calls on the packet conn time.Sleep(50 * time.Millisecond) }) It("drops too small Initial", func() { p := getPacket(&wire.Header{ Type: protocol.PacketTypeInitial, DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), Version: serv.config.Versions[0], }, make([]byte, protocol.MinInitialPacketSize-100)) tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket) serv.handlePacket(p) // make sure there are no Write calls on the packet conn time.Sleep(50 * time.Millisecond) }) It("drops non-Initial packets", func() { p := getPacket(&wire.Header{ Type: protocol.PacketTypeHandshake, Version: serv.config.Versions[0], }, []byte("invalid")) tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketTypeHandshake, p.Size(), logging.PacketDropUnexpectedPacket) serv.handlePacket(p) // make sure there are no Write calls on the packet conn time.Sleep(50 * time.Millisecond) }) It("passes packets to existing connections", func() { connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) p := getPacket(&wire.Header{ Type: protocol.PacketTypeInitial, DestConnectionID: connID, Version: serv.config.Versions[0], }, make([]byte, protocol.MinInitialPacketSize)) conn := NewMockPacketHandler(mockCtrl) phm.EXPECT().Get(connID).Return(conn, true) handled := make(chan struct{}) conn.EXPECT().handlePacket(p).Do(func(receivedPacket) { close(handled) }) serv.handlePacket(p) Eventually(handled).Should(BeClosed()) }) It("creates a connection when the token is accepted", func() { serv.verifySourceAddress = func(net.Addr) bool { return true } raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} retryToken, err := serv.tokenGenerator.NewRetryToken( raddr, protocol.ParseConnectionID([]byte{0xde, 0xad, 0xc0, 0xde}), protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}), ) Expect(err).ToNot(HaveOccurred()) connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) hdr := &wire.Header{ Type: protocol.PacketTypeInitial, SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), DestConnectionID: connID, Version: protocol.Version1, Token: retryToken, } p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) p.remoteAddr = raddr run := make(chan struct{}) var token protocol.StatelessResetToken rand.Read(token[:]) var newConnID protocol.ConnectionID conn := NewMockQUICConn(mockCtrl) serv.newConn = func( _ sendConn, _ connRunner, origDestConnID protocol.ConnectionID, retrySrcConnID *protocol.ConnectionID, clientDestConnID protocol.ConnectionID, destConnID protocol.ConnectionID, srcConnID protocol.ConnectionID, _ ConnectionIDGenerator, tokenP protocol.StatelessResetToken, _ *Config, _ *tls.Config, _ *handshake.TokenGenerator, _ bool, _ *logging.ConnectionTracer, _ uint64, _ utils.Logger, _ protocol.Version, ) quicConn { Expect(origDestConnID).To(Equal(protocol.ParseConnectionID([]byte{0xde, 0xad, 0xc0, 0xde}))) Expect(*retrySrcConnID).To(Equal(protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}))) Expect(clientDestConnID).To(Equal(hdr.DestConnectionID)) Expect(destConnID).To(Equal(hdr.SrcConnectionID)) // make sure we're using a server-generated connection ID Expect(srcConnID).ToNot(Equal(hdr.DestConnectionID)) Expect(srcConnID).ToNot(Equal(hdr.SrcConnectionID)) newConnID = srcConnID Expect(tokenP).To(Equal(token)) conn.EXPECT().handlePacket(p) conn.EXPECT().run().Do(func() error { close(run); return nil }) conn.EXPECT().Context().Return(context.Background()) conn.EXPECT().HandshakeComplete().Return(make(chan struct{})) return conn } phm.EXPECT().Get(connID) phm.EXPECT().GetStatelessResetToken(gomock.Any()).Return(token) phm.EXPECT().AddWithConnID(connID, gomock.Any(), gomock.Any()).DoAndReturn(func(_, cid protocol.ConnectionID, h packetHandler) bool { Expect(cid).To(Equal(newConnID)) return true }) done := make(chan struct{}) go func() { defer GinkgoRecover() serv.handlePacket(p) // the Handshake packet is written by the connection. // Make sure there are no Write calls on the packet conn. time.Sleep(50 * time.Millisecond) close(done) }() // make sure we're using a server-generated connection ID Eventually(run).Should(BeClosed()) Eventually(done).Should(BeClosed()) // shutdown conn.EXPECT().closeWithTransportError(gomock.Any()) }) It("sends a Version Negotiation Packet for unsupported versions", func() { srcConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5}) destConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6}) packet := getPacket(&wire.Header{ Type: protocol.PacketTypeHandshake, SrcConnectionID: srcConnID, DestConnectionID: destConnID, Version: 0x42, }, make([]byte, protocol.MinUnknownVersionPacketSize)) raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} packet.remoteAddr = raddr tracer.EXPECT().SentVersionNegotiationPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, src, dest protocol.ArbitraryLenConnectionID, _ []protocol.Version) { Expect(src).To(Equal(protocol.ArbitraryLenConnectionID(destConnID.Bytes()))) Expect(dest).To(Equal(protocol.ArbitraryLenConnectionID(srcConnID.Bytes()))) }) done := make(chan struct{}) conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) { defer close(done) Expect(wire.IsVersionNegotiationPacket(b)).To(BeTrue()) dest, src, versions, err := wire.ParseVersionNegotiationPacket(b) Expect(err).ToNot(HaveOccurred()) Expect(dest).To(Equal(protocol.ArbitraryLenConnectionID(srcConnID.Bytes()))) Expect(src).To(Equal(protocol.ArbitraryLenConnectionID(destConnID.Bytes()))) Expect(versions).ToNot(ContainElement(protocol.Version(0x42))) return len(b), nil }) serv.handlePacket(packet) Eventually(done).Should(BeClosed()) }) It("doesn't send a Version Negotiation packets if sending them is disabled", func() { serv.disableVersionNegotiation = true srcConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5}) destConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6}) packet := getPacket(&wire.Header{ Type: protocol.PacketTypeHandshake, SrcConnectionID: srcConnID, DestConnectionID: destConnID, Version: 0x42, }, make([]byte, protocol.MinUnknownVersionPacketSize)) raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} packet.remoteAddr = raddr done := make(chan struct{}) serv.handlePacket(packet) Consistently(done, 50*time.Millisecond).ShouldNot(BeClosed()) }) It("ignores Version Negotiation packets", func() { data := wire.ComposeVersionNegotiation( protocol.ArbitraryLenConnectionID{1, 2, 3, 4}, protocol.ArbitraryLenConnectionID{4, 3, 2, 1}, []protocol.Version{1, 2, 3}, ) raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} done := make(chan struct{}) tracer.EXPECT().DroppedPacket(raddr, logging.PacketTypeVersionNegotiation, protocol.ByteCount(len(data)), logging.PacketDropUnexpectedPacket).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { close(done) }) serv.handlePacket(receivedPacket{ remoteAddr: raddr, data: data, buffer: getPacketBuffer(), }) Eventually(done).Should(BeClosed()) // make sure no other packet is sent time.Sleep(scaleDuration(20 * time.Millisecond)) }) It("doesn't send a Version Negotiation Packet for unsupported versions, if the packet is too small", func() { srcConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5}) destConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6}) p := getPacket(&wire.Header{ Type: protocol.PacketTypeHandshake, SrcConnectionID: srcConnID, DestConnectionID: destConnID, Version: 0x42, }, make([]byte, protocol.MinUnknownVersionPacketSize-50)) Expect(p.Size()).To(BeNumerically("<", protocol.MinUnknownVersionPacketSize)) raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} p.remoteAddr = raddr done := make(chan struct{}) tracer.EXPECT().DroppedPacket(raddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { close(done) }) serv.handlePacket(p) Eventually(done).Should(BeClosed()) // make sure no other packet is sent time.Sleep(scaleDuration(20 * time.Millisecond)) }) It("replies with a Retry packet, if a token is required", func() { connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} var called bool serv.verifySourceAddress = func(addr net.Addr) bool { Expect(addr).To(Equal(raddr)) called = true return true } hdr := &wire.Header{ Type: protocol.PacketTypeInitial, SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), DestConnectionID: connID, Version: protocol.Version1, } packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) packet.remoteAddr = raddr tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), nil).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, _ []logging.Frame) { Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry)) Expect(replyHdr.SrcConnectionID).ToNot(Equal(hdr.DestConnectionID)) Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID)) Expect(replyHdr.Token).ToNot(BeEmpty()) }) done := make(chan struct{}) conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) { defer close(done) replyHdr := parseHeader(b) Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry)) Expect(replyHdr.SrcConnectionID).ToNot(Equal(hdr.DestConnectionID)) Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID)) Expect(replyHdr.Token).ToNot(BeEmpty()) Expect(b[len(b)-16:]).To(Equal(handshake.GetRetryIntegrityTag(b[:len(b)-16], hdr.DestConnectionID, hdr.Version)[:])) return len(b), nil }) phm.EXPECT().Get(connID) serv.handlePacket(packet) Eventually(done).Should(BeClosed()) Expect(called).To(BeTrue()) }) It("creates a connection, if no token is required", func() { connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) hdr := &wire.Header{ Type: protocol.PacketTypeInitial, SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), DestConnectionID: connID, Version: protocol.Version1, } p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) run := make(chan struct{}) var token protocol.StatelessResetToken rand.Read(token[:]) var newConnID protocol.ConnectionID conn := NewMockQUICConn(mockCtrl) serv.newConn = func( _ sendConn, _ connRunner, origDestConnID protocol.ConnectionID, retrySrcConnID *protocol.ConnectionID, clientDestConnID protocol.ConnectionID, destConnID protocol.ConnectionID, srcConnID protocol.ConnectionID, _ ConnectionIDGenerator, tokenP protocol.StatelessResetToken, _ *Config, _ *tls.Config, _ *handshake.TokenGenerator, _ bool, _ *logging.ConnectionTracer, _ uint64, _ utils.Logger, _ protocol.Version, ) quicConn { Expect(origDestConnID).To(Equal(hdr.DestConnectionID)) Expect(retrySrcConnID).To(BeNil()) Expect(clientDestConnID).To(Equal(hdr.DestConnectionID)) Expect(destConnID).To(Equal(hdr.SrcConnectionID)) // make sure we're using a server-generated connection ID Expect(srcConnID).ToNot(Equal(hdr.DestConnectionID)) Expect(srcConnID).ToNot(Equal(hdr.SrcConnectionID)) newConnID = srcConnID Expect(tokenP).To(Equal(token)) conn.EXPECT().handlePacket(p) conn.EXPECT().run().Do(func() error { close(run); return nil }) conn.EXPECT().Context().Return(context.Background()) conn.EXPECT().HandshakeComplete().Return(make(chan struct{})) return conn } gomock.InOrder( phm.EXPECT().Get(connID), phm.EXPECT().GetStatelessResetToken(gomock.Any()).Return(token), phm.EXPECT().AddWithConnID(connID, gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, h packetHandler) bool { Expect(c).To(Equal(newConnID)) return true }), ) done := make(chan struct{}) go func() { defer GinkgoRecover() serv.handlePacket(p) // the Handshake packet is written by the connection // make sure there are no Write calls on the packet conn time.Sleep(50 * time.Millisecond) close(done) }() // make sure we're using a server-generated connection ID Eventually(run).Should(BeClosed()) Eventually(done).Should(BeClosed()) // shutdown conn.EXPECT().closeWithTransportError(gomock.Any()).MaxTimes(1) }) It("drops packets if the receive queue is full", func() { serv.verifySourceAddress = func(net.Addr) bool { return false } phm.EXPECT().Get(gomock.Any()).AnyTimes() phm.EXPECT().GetStatelessResetToken(gomock.Any()).AnyTimes() phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true).AnyTimes() acceptConn := make(chan struct{}) var counter atomic.Uint32 serv.newConn = func( _ sendConn, runner connRunner, _ protocol.ConnectionID, _ *protocol.ConnectionID, _ protocol.ConnectionID, _ protocol.ConnectionID, _ protocol.ConnectionID, _ ConnectionIDGenerator, _ protocol.StatelessResetToken, _ *Config, _ *tls.Config, _ *handshake.TokenGenerator, _ bool, _ *logging.ConnectionTracer, _ uint64, _ utils.Logger, _ protocol.Version, ) quicConn { <-acceptConn counter.Add(1) conn := NewMockQUICConn(mockCtrl) conn.EXPECT().handlePacket(gomock.Any()).MaxTimes(1) conn.EXPECT().run().MaxTimes(1) conn.EXPECT().Context().Return(context.Background()).MaxTimes(1) conn.EXPECT().HandshakeComplete().Return(make(chan struct{})).MaxTimes(1) // shutdown conn.EXPECT().closeWithTransportError(gomock.Any()).MaxTimes(1) return conn } p := getInitial(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})) serv.handlePacket(p) tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropDOSPrevention).MinTimes(1) var wg sync.WaitGroup for i := 0; i < 3*protocol.MaxServerUnprocessedPackets; i++ { wg.Add(1) go func() { defer GinkgoRecover() defer wg.Done() serv.handlePacket(getInitial(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}))) }() } wg.Wait() close(acceptConn) Eventually( func() uint32 { return counter.Load() }, scaleDuration(100*time.Millisecond), ).Should(BeEquivalentTo(protocol.MaxServerUnprocessedPackets + 1)) Consistently(func() uint32 { return counter.Load() }).Should(BeEquivalentTo(protocol.MaxServerUnprocessedPackets + 1)) }) It("only creates a single connection for a duplicate Initial", func() { done := make(chan struct{}) serv.newConn = func( _ sendConn, runner connRunner, _ protocol.ConnectionID, _ *protocol.ConnectionID, _ protocol.ConnectionID, _ protocol.ConnectionID, _ protocol.ConnectionID, _ ConnectionIDGenerator, _ protocol.StatelessResetToken, _ *Config, _ *tls.Config, _ *handshake.TokenGenerator, _ bool, _ *logging.ConnectionTracer, _ uint64, _ utils.Logger, _ protocol.Version, ) quicConn { conn := NewMockQUICConn(mockCtrl) conn.EXPECT().handlePacket(gomock.Any()) conn.EXPECT().closeWithTransportError(qerr.ConnectionRefused).Do(func(qerr.TransportErrorCode) { close(done) }) return conn } connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9}) p := getInitial(connID) phm.EXPECT().Get(connID) phm.EXPECT().GetStatelessResetToken(gomock.Any()) phm.EXPECT().AddWithConnID(connID, gomock.Any(), gomock.Any()).Return(false) // connection ID collision Expect(serv.handlePacketImpl(p)).To(BeTrue()) Eventually(done).Should(BeClosed()) }) It("limits the number of unvalidated handshakes", func() { const limit = 3 limiter := rate.NewLimiter(0, limit) serv.verifySourceAddress = func(net.Addr) bool { return !limiter.Allow() } phm.EXPECT().Get(gomock.Any()).AnyTimes() phm.EXPECT().GetStatelessResetToken(gomock.Any()).AnyTimes() phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true).AnyTimes() connChan := make(chan *MockQUICConn, 1) var wg sync.WaitGroup wg.Add(limit) done := make(chan struct{}) serv.newConn = func( _ sendConn, runner connRunner, _ protocol.ConnectionID, _ *protocol.ConnectionID, _ protocol.ConnectionID, _ protocol.ConnectionID, _ protocol.ConnectionID, _ ConnectionIDGenerator, _ protocol.StatelessResetToken, _ *Config, _ *tls.Config, _ *handshake.TokenGenerator, _ bool, _ *logging.ConnectionTracer, _ uint64, _ utils.Logger, _ protocol.Version, ) quicConn { conn := <-connChan conn.EXPECT().handlePacket(gomock.Any()) conn.EXPECT().run() conn.EXPECT().Context().Return(context.Background()) conn.EXPECT().HandshakeComplete().DoAndReturn(func() <-chan struct{} { wg.Done(); return done }) return conn } // Initiate the maximum number of allowed connection attempts. for i := 0; i < limit; i++ { conn := NewMockQUICConn(mockCtrl) connChan <- conn serv.handlePacket(getInitialWithRandomDestConnID()) } // Now initiate another connection attempt. p := getInitialWithRandomDestConnID() tracer.EXPECT().SentPacket(p.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, frames []logging.Frame) { defer GinkgoRecover() Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry)) }) conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).DoAndReturn(func(b []byte, _ net.Addr) (int, error) { defer GinkgoRecover() defer close(done) hdr, _, _, err := wire.ParsePacket(b) Expect(err).ToNot(HaveOccurred()) Expect(hdr.Type).To(Equal(protocol.PacketTypeRetry)) return len(b), nil }) serv.handlePacket(p) Eventually(done).Should(BeClosed()) for i := 0; i < limit; i++ { _, err := serv.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) } wg.Wait() }) }) Context("token validation", func() { It("decodes the token from the token field", func() { serv.newConn = func( _ sendConn, _ connRunner, _ protocol.ConnectionID, _ *protocol.ConnectionID, _ protocol.ConnectionID, _ protocol.ConnectionID, _ protocol.ConnectionID, _ ConnectionIDGenerator, _ protocol.StatelessResetToken, _ *Config, _ *tls.Config, _ *handshake.TokenGenerator, _ bool, _ *logging.ConnectionTracer, _ uint64, _ utils.Logger, _ protocol.Version, ) quicConn { c := NewMockQUICConn(mockCtrl) c.EXPECT().handlePacket(gomock.Any()) c.EXPECT().run() c.EXPECT().HandshakeComplete() ctx, cancel := context.WithCancel(context.Background()) cancel() c.EXPECT().Context().Return(ctx) return c } raddr := &net.UDPAddr{IP: net.IPv4(192, 168, 13, 37), Port: 1337} token, err := serv.tokenGenerator.NewRetryToken(raddr, protocol.ConnectionID{}, protocol.ConnectionID{}) Expect(err).ToNot(HaveOccurred()) packet := getPacket(&wire.Header{ Type: protocol.PacketTypeInitial, Token: token, Version: serv.config.Versions[0], }, make([]byte, protocol.MinInitialPacketSize)) packet.remoteAddr = raddr conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).MaxTimes(1) tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1) done := make(chan struct{}) phm.EXPECT().Get(gomock.Any()) phm.EXPECT().GetStatelessResetToken(gomock.Any()) phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, _ packetHandler) bool { close(done) return true }) phm.EXPECT().Remove(gomock.Any()).AnyTimes() serv.handlePacket(packet) Eventually(done).Should(BeClosed()) }) It("sends an INVALID_TOKEN error, if an invalid retry token is received", func() { serv.verifySourceAddress = func(net.Addr) bool { return true } token, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, protocol.ConnectionID{}, protocol.ConnectionID{}) Expect(err).ToNot(HaveOccurred()) hdr := &wire.Header{ Type: protocol.PacketTypeInitial, SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), Token: token, Version: protocol.Version1, } packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) packet.data = append(packet.data, []byte("coalesced packet")...) // add some garbage to simulate a coalesced packet raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} packet.remoteAddr = raddr tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, frames []logging.Frame) { Expect(replyHdr.Type).To(Equal(protocol.PacketTypeInitial)) Expect(replyHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID)) Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID)) Expect(frames).To(HaveLen(1)) Expect(frames[0]).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{})) ccf := frames[0].(*logging.ConnectionCloseFrame) Expect(ccf.IsApplicationError).To(BeFalse()) Expect(ccf.ErrorCode).To(BeEquivalentTo(qerr.InvalidToken)) }) done := make(chan struct{}) conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) { defer close(done) checkConnectionCloseError(b, hdr, qerr.InvalidToken) return len(b), nil }) phm.EXPECT().Get(gomock.Any()) serv.handlePacket(packet) Eventually(done).Should(BeClosed()) }) It("sends an INVALID_TOKEN error, if an expired retry token is received", func() { serv.verifySourceAddress = func(net.Addr) bool { return true } serv.config.HandshakeIdleTimeout = time.Millisecond / 2 // the maximum retry token age is equivalent to the handshake timeout Expect(serv.config.maxRetryTokenAge()).To(Equal(time.Millisecond)) raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} token, err := serv.tokenGenerator.NewRetryToken(raddr, protocol.ConnectionID{}, protocol.ConnectionID{}) Expect(err).ToNot(HaveOccurred()) time.Sleep(2 * time.Millisecond) // make sure the token is expired hdr := &wire.Header{ Type: protocol.PacketTypeInitial, SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), Token: token, Version: protocol.Version1, } packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) packet.remoteAddr = raddr tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, frames []logging.Frame) { Expect(replyHdr.Type).To(Equal(protocol.PacketTypeInitial)) Expect(replyHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID)) Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID)) Expect(frames).To(HaveLen(1)) Expect(frames[0]).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{})) ccf := frames[0].(*logging.ConnectionCloseFrame) Expect(ccf.IsApplicationError).To(BeFalse()) Expect(ccf.ErrorCode).To(BeEquivalentTo(qerr.InvalidToken)) }) done := make(chan struct{}) conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) { defer close(done) checkConnectionCloseError(b, hdr, qerr.InvalidToken) return len(b), nil }) phm.EXPECT().Get(gomock.Any()) serv.handlePacket(packet) Eventually(done).Should(BeClosed()) }) It("doesn't send an INVALID_TOKEN error, if an invalid non-retry token is received", func() { serv.verifySourceAddress = func(net.Addr) bool { return true } token, err := serv.tokenGenerator.NewToken(&net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337}) Expect(err).ToNot(HaveOccurred()) hdr := &wire.Header{ Type: protocol.PacketTypeInitial, SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), Token: token, Version: protocol.Version1, } packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) packet.data[len(packet.data)-10] ^= 0xff // corrupt the packet raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} packet.remoteAddr = raddr tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1) done := make(chan struct{}) conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) { defer close(done) replyHdr := parseHeader(b) Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry)) return len(b), nil }) phm.EXPECT().Get(gomock.Any()) serv.handlePacket(packet) // make sure there are no Write calls on the packet conn Eventually(done).Should(BeClosed()) }) It("sends an INVALID_TOKEN error, if an expired non-retry token is received", func() { serv.verifySourceAddress = func(net.Addr) bool { return true } serv.maxTokenAge = time.Millisecond raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} token, err := serv.tokenGenerator.NewToken(raddr) Expect(err).ToNot(HaveOccurred()) time.Sleep(2 * time.Millisecond) // make sure the token is expired hdr := &wire.Header{ Type: protocol.PacketTypeInitial, SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), Token: token, Version: protocol.Version1, } packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) packet.remoteAddr = raddr tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, frames []logging.Frame) { Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry)) }) done := make(chan struct{}) conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) { defer close(done) return len(b), nil }) phm.EXPECT().Get(gomock.Any()) serv.handlePacket(packet) Eventually(done).Should(BeClosed()) }) It("doesn't send an INVALID_TOKEN error, if the packet is corrupted", func() { token, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, protocol.ConnectionID{}, protocol.ConnectionID{}) Expect(err).ToNot(HaveOccurred()) hdr := &wire.Header{ Type: protocol.PacketTypeInitial, SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), Token: token, Version: protocol.Version1, } packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) packet.data[len(packet.data)-10] ^= 0xff // corrupt the packet packet.remoteAddr = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} done := make(chan struct{}) tracer.EXPECT().DroppedPacket(packet.remoteAddr, logging.PacketTypeInitial, packet.Size(), logging.PacketDropPayloadDecryptError).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { close(done) }) phm.EXPECT().Get(gomock.Any()) serv.handlePacket(packet) // make sure there are no Write calls on the packet conn time.Sleep(50 * time.Millisecond) Eventually(done).Should(BeClosed()) }) }) Context("accepting connections", func() { It("returns Accept when closed", func() { done := make(chan struct{}) go func() { defer GinkgoRecover() _, err := serv.Accept(context.Background()) Expect(err).To(MatchError(ErrServerClosed)) close(done) }() serv.Close() Eventually(done).Should(BeClosed()) }) It("returns immediately, if an error occurred before", func() { serv.Close() for i := 0; i < 3; i++ { _, err := serv.Accept(context.Background()) Expect(err).To(MatchError(ErrServerClosed)) } }) PIt("closes connection that are still handshaking after Close", func() { serv.Close() destroyed := make(chan struct{}) serv.newConn = func( _ sendConn, _ connRunner, _ protocol.ConnectionID, _ *protocol.ConnectionID, _ protocol.ConnectionID, _ protocol.ConnectionID, _ protocol.ConnectionID, _ ConnectionIDGenerator, _ protocol.StatelessResetToken, conf *Config, _ *tls.Config, _ *handshake.TokenGenerator, _ bool, _ *logging.ConnectionTracer, _ uint64, _ utils.Logger, _ protocol.Version, ) quicConn { conn := NewMockQUICConn(mockCtrl) conn.EXPECT().handlePacket(gomock.Any()) conn.EXPECT().closeWithTransportError(ConnectionRefused).Do(func(TransportErrorCode) { close(destroyed) }) conn.EXPECT().HandshakeComplete().Return(make(chan struct{})) conn.EXPECT().run().MaxTimes(1) conn.EXPECT().Context().Return(context.Background()) return conn } phm.EXPECT().Get(gomock.Any()) phm.EXPECT().GetStatelessResetToken(gomock.Any()) phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true) serv.handleInitialImpl( receivedPacket{buffer: getPacketBuffer()}, &wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})}, ) Eventually(destroyed).Should(BeClosed()) }) It("returns when the context is canceled", func() { ctx, cancel := context.WithCancel(context.Background()) done := make(chan struct{}) go func() { defer GinkgoRecover() _, err := serv.Accept(ctx) Expect(err).To(MatchError("context canceled")) close(done) }() Consistently(done).ShouldNot(BeClosed()) cancel() Eventually(done).Should(BeClosed()) }) It("uses the config returned by GetConfigClient", func() { conn := NewMockQUICConn(mockCtrl) conf := &Config{MaxIncomingStreams: 1234} serv.config = populateConfig(&Config{GetConfigForClient: func(*ClientHelloInfo) (*Config, error) { return conf, nil }}) done := make(chan struct{}) go func() { defer GinkgoRecover() s, err := serv.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) Expect(s).To(Equal(conn)) close(done) }() handshakeChan := make(chan struct{}) serv.newConn = func( _ sendConn, _ connRunner, _ protocol.ConnectionID, _ *protocol.ConnectionID, _ protocol.ConnectionID, _ protocol.ConnectionID, _ protocol.ConnectionID, _ ConnectionIDGenerator, _ protocol.StatelessResetToken, conf *Config, _ *tls.Config, _ *handshake.TokenGenerator, _ bool, _ *logging.ConnectionTracer, _ uint64, _ utils.Logger, _ protocol.Version, ) quicConn { Expect(conf.MaxIncomingStreams).To(BeEquivalentTo(1234)) conn.EXPECT().handlePacket(gomock.Any()) conn.EXPECT().HandshakeComplete().Return(handshakeChan) conn.EXPECT().run() conn.EXPECT().Context().Return(context.Background()) return conn } phm.EXPECT().Get(gomock.Any()) phm.EXPECT().GetStatelessResetToken(gomock.Any()) phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true) serv.handleInitialImpl( receivedPacket{buffer: getPacketBuffer()}, &wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})}, ) Consistently(done).ShouldNot(BeClosed()) close(handshakeChan) // complete the handshake Eventually(done).Should(BeClosed()) }) It("rejects a connection attempt when GetConfigClient returns an error", func() { serv.config = populateConfig(&Config{GetConfigForClient: func(*ClientHelloInfo) (*Config, error) { return nil, errors.New("rejected") }}) phm.EXPECT().Get(gomock.Any()) done := make(chan struct{}) tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).DoAndReturn(func(b []byte, _ net.Addr) (int, error) { defer close(done) rejectHdr := parseHeader(b) Expect(rejectHdr.Type).To(Equal(protocol.PacketTypeInitial)) return len(b), nil }) serv.handleInitialImpl( receivedPacket{buffer: getPacketBuffer()}, &wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), Version: protocol.Version1}, ) Eventually(done).Should(BeClosed()) }) It("accepts new connections when the handshake completes", func() { conn := NewMockQUICConn(mockCtrl) done := make(chan struct{}) go func() { defer GinkgoRecover() s, err := serv.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) Expect(s).To(Equal(conn)) close(done) }() handshakeChan := make(chan struct{}) serv.newConn = func( _ sendConn, runner connRunner, _ protocol.ConnectionID, _ *protocol.ConnectionID, _ protocol.ConnectionID, _ protocol.ConnectionID, _ protocol.ConnectionID, _ ConnectionIDGenerator, _ protocol.StatelessResetToken, _ *Config, _ *tls.Config, _ *handshake.TokenGenerator, _ bool, _ *logging.ConnectionTracer, _ uint64, _ utils.Logger, _ protocol.Version, ) quicConn { conn.EXPECT().handlePacket(gomock.Any()) conn.EXPECT().HandshakeComplete().Return(handshakeChan) conn.EXPECT().run() conn.EXPECT().Context().Return(context.Background()) return conn } phm.EXPECT().Get(gomock.Any()) phm.EXPECT().GetStatelessResetToken(gomock.Any()) phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true) serv.handleInitialImpl( receivedPacket{buffer: getPacketBuffer()}, &wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})}, ) Consistently(done).ShouldNot(BeClosed()) close(handshakeChan) // complete the handshake Eventually(done).Should(BeClosed()) }) }) }) Context("server accepting connections that haven't completed the handshake", func() { var ( serv *EarlyListener phm *MockPacketHandlerManager ) BeforeEach(func() { var err error serv, err = ListenEarly(conn, tlsConf, nil) Expect(err).ToNot(HaveOccurred()) phm = NewMockPacketHandlerManager(mockCtrl) serv.baseServer.connHandler = phm }) AfterEach(func() { serv.Close() }) It("accepts new connections when they become ready", func() { conn := NewMockQUICConn(mockCtrl) done := make(chan struct{}) go func() { defer GinkgoRecover() s, err := serv.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) Expect(s).To(Equal(conn)) close(done) }() ready := make(chan struct{}) serv.baseServer.newConn = func( _ sendConn, runner connRunner, _ protocol.ConnectionID, _ *protocol.ConnectionID, _ protocol.ConnectionID, _ protocol.ConnectionID, _ protocol.ConnectionID, _ ConnectionIDGenerator, _ protocol.StatelessResetToken, _ *Config, _ *tls.Config, _ *handshake.TokenGenerator, _ bool, _ *logging.ConnectionTracer, _ uint64, _ utils.Logger, _ protocol.Version, ) quicConn { conn.EXPECT().handlePacket(gomock.Any()) conn.EXPECT().run() conn.EXPECT().earlyConnReady().Return(ready) conn.EXPECT().Context().Return(context.Background()) return conn } phm.EXPECT().Get(gomock.Any()) phm.EXPECT().GetStatelessResetToken(gomock.Any()) phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true) serv.baseServer.handleInitialImpl( receivedPacket{buffer: getPacketBuffer()}, &wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})}, ) Consistently(done).ShouldNot(BeClosed()) close(ready) Eventually(done).Should(BeClosed()) }) It("rejects new connection attempts if the accept queue is full", func() { connChan := make(chan *MockQUICConn, 1) var wg sync.WaitGroup // to make sure the test fully completes wg.Add(protocol.MaxAcceptQueueSize + 1) serv.baseServer.newConn = func( _ sendConn, runner connRunner, _ protocol.ConnectionID, _ *protocol.ConnectionID, _ protocol.ConnectionID, _ protocol.ConnectionID, _ protocol.ConnectionID, _ ConnectionIDGenerator, _ protocol.StatelessResetToken, _ *Config, _ *tls.Config, _ *handshake.TokenGenerator, _ bool, _ *logging.ConnectionTracer, _ uint64, _ utils.Logger, _ protocol.Version, ) quicConn { defer wg.Done() ready := make(chan struct{}) close(ready) conn := <-connChan conn.EXPECT().handlePacket(gomock.Any()) conn.EXPECT().run() conn.EXPECT().earlyConnReady().Return(ready) conn.EXPECT().Context().Return(context.Background()) return conn } phm.EXPECT().Get(gomock.Any()).AnyTimes() phm.EXPECT().GetStatelessResetToken(gomock.Any()).Times(protocol.MaxAcceptQueueSize) phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true).Times(protocol.MaxAcceptQueueSize) for i := 0; i < protocol.MaxAcceptQueueSize; i++ { conn := NewMockQUICConn(mockCtrl) connChan <- conn serv.baseServer.handlePacket(getInitialWithRandomDestConnID()) } Eventually(serv.baseServer.connQueue).Should(HaveLen(protocol.MaxAcceptQueueSize)) phm.EXPECT().GetStatelessResetToken(gomock.Any()) phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true) conn := NewMockQUICConn(mockCtrl) conn.EXPECT().closeWithTransportError(ConnectionRefused) connChan <- conn serv.baseServer.handlePacket(getInitialWithRandomDestConnID()) wg.Wait() }) It("doesn't accept new connections if they were closed in the mean time", func() { p := getInitial(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})) ctx, cancel := context.WithCancel(context.Background()) connCreated := make(chan struct{}) conn := NewMockQUICConn(mockCtrl) serv.baseServer.newConn = func( _ sendConn, runner connRunner, _ protocol.ConnectionID, _ *protocol.ConnectionID, _ protocol.ConnectionID, _ protocol.ConnectionID, _ protocol.ConnectionID, _ ConnectionIDGenerator, _ protocol.StatelessResetToken, _ *Config, _ *tls.Config, _ *handshake.TokenGenerator, _ bool, _ *logging.ConnectionTracer, _ uint64, _ utils.Logger, _ protocol.Version, ) quicConn { conn.EXPECT().handlePacket(p) conn.EXPECT().run() conn.EXPECT().earlyConnReady() conn.EXPECT().Context().Return(ctx) close(connCreated) return conn } phm.EXPECT().Get(gomock.Any()) phm.EXPECT().GetStatelessResetToken(gomock.Any()) phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true) serv.baseServer.handlePacket(p) // make sure there are no Write calls on the packet conn time.Sleep(50 * time.Millisecond) Eventually(connCreated).Should(BeClosed()) cancel() time.Sleep(scaleDuration(200 * time.Millisecond)) done := make(chan struct{}) go func() { defer GinkgoRecover() serv.Accept(context.Background()) close(done) }() Consistently(done).ShouldNot(BeClosed()) // make the go routine return Expect(serv.Close()).To(Succeed()) Eventually(done).Should(BeClosed()) }) }) Context("0-RTT", func() { var ( tr *Transport serv *baseServer phm *MockPacketHandlerManager tracer *mocklogging.MockTracer ) BeforeEach(func() { var t *logging.Tracer t, tracer = mocklogging.NewMockTracer(mockCtrl) tr = &Transport{Conn: conn, Tracer: t} ln, err := tr.ListenEarly(tlsConf, nil) Expect(err).ToNot(HaveOccurred()) phm = NewMockPacketHandlerManager(mockCtrl) serv = ln.baseServer serv.connHandler = phm }) AfterEach(func() { tracer.EXPECT().Close() Expect(tr.Close()).To(Succeed()) }) It("passes packets to existing connections", func() { connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) p := getPacket(&wire.Header{ Type: protocol.PacketType0RTT, DestConnectionID: connID, Version: serv.config.Versions[0], }, make([]byte, 100)) conn := NewMockPacketHandler(mockCtrl) phm.EXPECT().Get(connID).Return(conn, true) handled := make(chan struct{}) conn.EXPECT().handlePacket(p).Do(func(receivedPacket) { close(handled) }) serv.handlePacket(p) Eventually(handled).Should(BeClosed()) }) It("queues 0-RTT packets, up to Max0RTTQueueSize", func() { connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) var zeroRTTPackets []receivedPacket for i := 0; i < protocol.Max0RTTQueueLen; i++ { p := getPacket(&wire.Header{ Type: protocol.PacketType0RTT, DestConnectionID: connID, Version: serv.config.Versions[0], }, make([]byte, 100+i)) phm.EXPECT().Get(connID) serv.handlePacket(p) zeroRTTPackets = append(zeroRTTPackets, p) } // send one more packet, this one should be dropped p := getPacket(&wire.Header{ Type: protocol.PacketType0RTT, DestConnectionID: connID, Version: serv.config.Versions[0], }, make([]byte, 200)) phm.EXPECT().Get(connID) tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention) serv.handlePacket(p) initial := getPacket(&wire.Header{ Type: protocol.PacketTypeInitial, DestConnectionID: connID, Version: serv.config.Versions[0], }, make([]byte, protocol.MinInitialPacketSize)) called := make(chan struct{}) serv.newConn = func( _ sendConn, _ connRunner, _ protocol.ConnectionID, _ *protocol.ConnectionID, _ protocol.ConnectionID, _ protocol.ConnectionID, _ protocol.ConnectionID, _ ConnectionIDGenerator, _ protocol.StatelessResetToken, _ *Config, _ *tls.Config, _ *handshake.TokenGenerator, _ bool, _ *logging.ConnectionTracer, _ uint64, _ utils.Logger, _ protocol.Version, ) quicConn { conn := NewMockQUICConn(mockCtrl) var calls []any calls = append(calls, conn.EXPECT().handlePacket(initial)) for _, p := range zeroRTTPackets { calls = append(calls, conn.EXPECT().handlePacket(p)) } gomock.InOrder(calls...) conn.EXPECT().run() conn.EXPECT().earlyConnReady() conn.EXPECT().Context().Return(context.Background()) close(called) // shutdown conn.EXPECT().closeWithTransportError(gomock.Any()) return conn } phm.EXPECT().Get(connID) phm.EXPECT().GetStatelessResetToken(gomock.Any()) phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true) serv.handlePacket(initial) Eventually(called).Should(BeClosed()) }) It("limits the number of queues", func() { for i := 0; i < protocol.Max0RTTQueues; i++ { b := make([]byte, 16) rand.Read(b) connID := protocol.ParseConnectionID(b) p := getPacket(&wire.Header{ Type: protocol.PacketType0RTT, DestConnectionID: connID, Version: serv.config.Versions[0], }, make([]byte, 100+i)) phm.EXPECT().Get(connID) serv.handlePacket(p) } connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) p := getPacket(&wire.Header{ Type: protocol.PacketType0RTT, DestConnectionID: connID, Version: serv.config.Versions[0], }, make([]byte, 200)) phm.EXPECT().Get(connID) dropped := make(chan struct{}) tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { close(dropped) }) serv.handlePacket(p) Eventually(dropped).Should(BeClosed()) }) It("drops queues after a while", func() { now := time.Now() connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) p := getPacket(&wire.Header{ Type: protocol.PacketType0RTT, DestConnectionID: connID, Version: serv.config.Versions[0], }, make([]byte, 200)) p.rcvTime = now connID2 := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 9}) p2Time := now.Add(protocol.Max0RTTQueueingDuration / 2) p2 := getPacket(&wire.Header{ Type: protocol.PacketType0RTT, DestConnectionID: connID2, Version: serv.config.Versions[0], }, make([]byte, 300)) p2.rcvTime = p2Time // doesn't trigger the cleanup of the first packet dropped1 := make(chan struct{}) dropped2 := make(chan struct{}) // need to register the call before handling the packet to avoid race condition gomock.InOrder( tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { close(dropped1) }), tracer.EXPECT().DroppedPacket(p2.remoteAddr, logging.PacketType0RTT, p2.Size(), logging.PacketDropDOSPrevention).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { close(dropped2) }), ) phm.EXPECT().Get(connID) serv.handlePacket(p) // There's no cleanup Go routine. // Cleanup is triggered when new packets are received. phm.EXPECT().Get(connID2) serv.handlePacket(p2) // make sure no cleanup is executed Consistently(dropped1, 50*time.Millisecond).ShouldNot(BeClosed()) // There's no cleanup Go routine. // Cleanup is triggered when new packets are received. connID3 := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 0}) p3 := getPacket(&wire.Header{ Type: protocol.PacketType0RTT, DestConnectionID: connID3, Version: serv.config.Versions[0], }, make([]byte, 200)) p3.rcvTime = now.Add(protocol.Max0RTTQueueingDuration + time.Nanosecond) // now triggers the cleanup phm.EXPECT().Get(connID3) serv.handlePacket(p3) Eventually(dropped1).Should(BeClosed()) Consistently(dropped2, 50*time.Millisecond).ShouldNot(BeClosed()) // make sure the second packet is also cleaned up connID4 := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 1}) p4 := getPacket(&wire.Header{ Type: protocol.PacketType0RTT, DestConnectionID: connID4, Version: serv.config.Versions[0], }, make([]byte, 200)) p4.rcvTime = p2Time.Add(protocol.Max0RTTQueueingDuration + time.Nanosecond) // now triggers the cleanup phm.EXPECT().Get(connID4) serv.handlePacket(p4) Eventually(dropped2).Should(BeClosed()) }) }) })