add a GetConfigForClient callback to the Config

This commit is contained in:
Marten Seemann 2023-04-25 11:31:01 +02:00
parent ba942715db
commit bc7cb706c5
9 changed files with 223 additions and 47 deletions

View file

@ -103,6 +103,7 @@ func populateConfig(config *Config) *Config {
}
return &Config{
GetConfigForClient: config.GetConfigForClient,
Versions: versions,
HandshakeIdleTimeout: handshakeIdleTimeout,
MaxIdleTimeout: idleTimeout,

View file

@ -1,6 +1,7 @@
package quic
import (
"errors"
"fmt"
"net"
"reflect"
@ -45,7 +46,7 @@ var _ = Describe("Config", func() {
}
switch fn := typ.Field(i).Name; fn {
case "RequireAddressValidation", "GetLogWriter", "AllowConnectionWindowIncrease", "Allow0RTT":
case "GetConfigForClient", "RequireAddressValidation", "GetLogWriter", "AllowConnectionWindowIncrease", "Allow0RTT":
// Can't compare functions.
case "Versions":
f.Set(reflect.ValueOf([]VersionNumber{1, 2, 3}))
@ -108,6 +109,7 @@ var _ = Describe("Config", func() {
It("clones function fields", func() {
var calledAddrValidation, calledAllowConnectionWindowIncrease bool
c1 := &Config{
GetConfigForClient: func(info *ClientHelloInfo) (*Config, error) { return nil, errors.New("nope") },
AllowConnectionWindowIncrease: func(Connection, uint64) bool { calledAllowConnectionWindowIncrease = true; return true },
RequireAddressValidation: func(net.Addr) bool { calledAddrValidation = true; return true },
}
@ -116,6 +118,8 @@ var _ = Describe("Config", func() {
Expect(calledAddrValidation).To(BeTrue())
c2.AllowConnectionWindowIncrease(nil, 1234)
Expect(calledAllowConnectionWindowIncrease).To(BeTrue())
_, err := c2.GetConfigForClient(&ClientHelloInfo{})
Expect(err).To(MatchError("nope"))
})
It("clones non-function fields", func() {
@ -164,6 +168,7 @@ var _ = Describe("Config", func() {
Expect(c.MaxIncomingUniStreams).To(BeEquivalentTo(protocol.DefaultMaxIncomingUniStreams))
Expect(c.DisableVersionNegotiationPackets).To(BeFalse())
Expect(c.DisablePathMTUDiscovery).To(BeFalse())
Expect(c.GetConfigForClient).To(BeNil())
})
It("populates empty fields with default values, for the server", func() {

View file

@ -436,6 +436,72 @@ var _ = Describe("Handshake tests", func() {
})
})
Context("GetConfigForClient", func() {
It("uses the quic.Config returned by GetConfigForClient", func() {
serverConfig.EnableDatagrams = false
var calledFrom net.Addr
serverConfig.GetConfigForClient = func(info *quic.ClientHelloInfo) (*quic.Config, error) {
conf := serverConfig.Clone()
conf.EnableDatagrams = true
calledFrom = info.RemoteAddr
return getQuicConfig(conf), nil
}
ln, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig)
Expect(err).ToNot(HaveOccurred())
done := make(chan struct{})
go func() {
defer GinkgoRecover()
_, err := ln.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
close(done)
}()
conn, err := quic.DialAddr(
context.Background(),
fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port),
getTLSClientConfig(),
getQuicConfig(&quic.Config{EnableDatagrams: true}),
)
Expect(err).ToNot(HaveOccurred())
defer conn.CloseWithError(0, "")
cs := conn.ConnectionState()
Expect(cs.SupportsDatagrams).To(BeTrue())
Eventually(done).Should(BeClosed())
Expect(ln.Close()).To(Succeed())
Expect(calledFrom.(*net.UDPAddr).Port).To(Equal(conn.LocalAddr().(*net.UDPAddr).Port))
})
It("rejects the connection attempt if GetConfigForClient errors", func() {
serverConfig.EnableDatagrams = false
serverConfig.GetConfigForClient = func(info *quic.ClientHelloInfo) (*quic.Config, error) {
return nil, errors.New("rejected")
}
ln, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig)
Expect(err).ToNot(HaveOccurred())
defer ln.Close()
done := make(chan struct{})
go func() {
defer GinkgoRecover()
_, err := ln.Accept(context.Background())
Expect(err).To(HaveOccurred()) // we don't expect to accept any connection
close(done)
}()
_, err = quic.DialAddr(
context.Background(),
fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port),
getTLSClientConfig(),
getQuicConfig(&quic.Config{EnableDatagrams: true}),
)
Expect(err).To(HaveOccurred())
var transportErr *quic.TransportError
Expect(errors.As(err, &transportErr)).To(BeTrue())
Expect(transportErr.ErrorCode).To(Equal(qerr.ConnectionRefused))
})
})
It("doesn't send any packets when generating the ClientHello fails", func() {
ln, err := net.ListenUDP("udp", nil)
Expect(err).ToNot(HaveOccurred())

View file

@ -239,6 +239,9 @@ type ConnectionIDGenerator interface {
// Config contains all configuration data needed for a QUIC server or client.
type Config struct {
// GetConfigForClient is called for incoming connections.
// If the error is not nil, the connection attempt is refused.
GetConfigForClient func(info *ClientHelloInfo) (*Config, error)
// The QUIC versions that can be negotiated.
// If not set, it uses all versions available.
Versions []VersionNumber
@ -324,6 +327,10 @@ type Config struct {
Tracer logging.Tracer
}
type ClientHelloInfo struct {
RemoteAddr net.Addr
}
// ConnectionState records basic details about a QUIC connection
type ConnectionState struct {
TLS handshake.ConnectionState

View file

@ -61,7 +61,7 @@ func (mr *MockPacketHandlerManagerMockRecorder) AddResetToken(arg0, arg1 interfa
}
// AddWithConnID mocks base method.
func (m *MockPacketHandlerManager) AddWithConnID(arg0, arg1 protocol.ConnectionID, arg2 func() packetHandler) bool {
func (m *MockPacketHandlerManager) AddWithConnID(arg0, arg1 protocol.ConnectionID, arg2 func() (packetHandler, bool)) bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AddWithConnID", arg0, arg1, arg2)
ret0, _ := ret[0].(bool)

View file

@ -122,7 +122,7 @@ func (h *packetHandlerMap) Add(id protocol.ConnectionID, handler packetHandler)
return true
}
func (h *packetHandlerMap) AddWithConnID(clientDestConnID, newConnID protocol.ConnectionID, fn func() packetHandler) bool {
func (h *packetHandlerMap) AddWithConnID(clientDestConnID, newConnID protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
h.mutex.Lock()
defer h.mutex.Unlock()
@ -130,7 +130,10 @@ func (h *packetHandlerMap) AddWithConnID(clientDestConnID, newConnID protocol.Co
h.logger.Debugf("Not adding connection ID %s for a new connection, as it already exists.", clientDestConnID)
return false
}
conn := fn()
conn, ok := fn()
if !ok {
return false
}
h.handlers[clientDestConnID] = conn
h.handlers[newConnID] = conn
h.logger.Debugf("Adding connection IDs %s and %s for a new connection.", clientDestConnID, newConnID)

View file

@ -62,14 +62,14 @@ var _ = Describe("Packet Handler Map", func() {
var called bool
connID1 := protocol.ParseConnectionID([]byte{1, 2, 3, 4})
connID2 := protocol.ParseConnectionID([]byte{4, 3, 2, 1})
Expect(m.AddWithConnID(connID1, connID2, func() packetHandler {
Expect(m.AddWithConnID(connID1, connID2, func() (packetHandler, bool) {
called = true
return NewMockPacketHandler(mockCtrl)
return NewMockPacketHandler(mockCtrl), true
})).To(BeTrue())
Expect(called).To(BeTrue())
Expect(m.AddWithConnID(connID1, protocol.ParseConnectionID([]byte{1, 2, 3}), func() packetHandler {
Expect(m.AddWithConnID(connID1, protocol.ParseConnectionID([]byte{1, 2, 3}), func() (packetHandler, bool) {
Fail("didn't expect the constructor to be executed")
return nil
return nil, false
})).To(BeFalse())
})

View file

@ -33,7 +33,7 @@ type packetHandler interface {
type packetHandlerManager interface {
Get(protocol.ConnectionID) (packetHandler, bool)
GetByResetToken(protocol.StatelessResetToken) (packetHandler, bool)
AddWithConnID(protocol.ConnectionID, protocol.ConnectionID, func() packetHandler) bool
AddWithConnID(protocol.ConnectionID, protocol.ConnectionID, func() (packetHandler, bool)) bool
Close(error)
CloseServer()
connRunner
@ -584,7 +584,7 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro
s.logger.Debugf("Changing connection ID to %s.", connID)
var conn quicConn
tracingID := nextConnTracingID()
if added := s.connHandler.AddWithConnID(hdr.DestConnectionID, connID, func() packetHandler {
if added := s.connHandler.AddWithConnID(hdr.DestConnectionID, connID, func() (packetHandler, bool) {
var tracer logging.ConnectionTracer
if s.config.Tracer != nil {
// Use the same connection ID that is passed to the client's GetLogWriter callback.
@ -598,6 +598,15 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro
connID,
)
}
config := s.config
if s.config.GetConfigForClient != nil {
conf, err := s.config.GetConfigForClient(&ClientHelloInfo{RemoteAddr: p.remoteAddr})
if err != nil {
s.logger.Debugf("Rejecting new connection due to GetConfigForClient callback")
return nil, false
}
config = populateConfig(conf)
}
conn = s.newConn(
newSendConn(s.conn, p.remoteAddr, p.info),
s.connHandler,
@ -608,7 +617,7 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro
connID,
s.connIDGenerator,
s.connHandler.GetStatelessResetToken(connID),
s.config,
config,
s.tlsConf,
s.tokenGenerator,
clientAddrIsValid,
@ -626,10 +635,14 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro
delete(s.zeroRTTQueues, hdr.DestConnectionID)
}
return conn
return conn, true
}); !added {
// TODO: don't just drop the packet
// Properly reject the connection attempt.
go func() {
defer p.buffer.Release()
if err := s.sendConnectionRefused(p.remoteAddr, hdr, p.info); err != nil {
s.logger.Debugf("Error rejecting connection: %s", err)
}
}()
return nil
}
go conn.run()

View file

@ -267,14 +267,14 @@ var _ = Describe("Server", func() {
var newConnID protocol.ConnectionID
phm.EXPECT().Get(connID)
phm.EXPECT().AddWithConnID(connID, gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, fn func() packetHandler) bool {
phm.EXPECT().AddWithConnID(connID, gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
newConnID = c
phm.EXPECT().GetStatelessResetToken(gomock.Any()).DoAndReturn(func(c protocol.ConnectionID) protocol.StatelessResetToken {
newConnID = c
return token
})
fn()
return true
_, ok := fn()
return ok
})
tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, protocol.ParseConnectionID([]byte{0xde, 0xad, 0xc0, 0xde}))
conn := NewMockQUICConn(mockCtrl)
@ -468,14 +468,14 @@ var _ = Describe("Server", func() {
var newConnID protocol.ConnectionID
gomock.InOrder(
phm.EXPECT().Get(connID),
phm.EXPECT().AddWithConnID(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, fn func() packetHandler) bool {
phm.EXPECT().AddWithConnID(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
newConnID = c
phm.EXPECT().GetStatelessResetToken(gomock.Any()).DoAndReturn(func(c protocol.ConnectionID) protocol.StatelessResetToken {
newConnID = c
return token
})
fn()
return true
_, ok := fn()
return ok
}),
)
tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}))
@ -532,10 +532,10 @@ var _ = Describe("Server", func() {
It("drops packets if the receive queue is full", func() {
phm.EXPECT().Get(gomock.Any()).AnyTimes()
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
phm.EXPECT().GetStatelessResetToken(gomock.Any())
fn()
return true
_, ok := fn()
return ok
}).AnyTimes()
tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, gomock.Any()).AnyTimes()
@ -594,7 +594,6 @@ var _ = Describe("Server", func() {
It("only creates a single connection for a duplicate Initial", func() {
var createdConn bool
conn := NewMockQUICConn(mockCtrl)
serv.newConn = func(
_ sendConn,
runner connRunner,
@ -615,15 +614,19 @@ var _ = Describe("Server", func() {
_ protocol.VersionNumber,
) quicConn {
createdConn = true
return conn
return NewMockQUICConn(mockCtrl)
}
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9})
p := getInitial(connID)
phm.EXPECT().Get(connID)
phm.EXPECT().AddWithConnID(connID, gomock.Any(), gomock.Any()).Return(false)
tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
done := make(chan struct{})
conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).Do(func([]byte, net.Addr) { close(done) })
Expect(serv.handlePacketImpl(p)).To(BeTrue())
Expect(createdConn).To(BeFalse())
Eventually(done).Should(BeClosed())
})
It("rejects new connection attempts if the accept queue is full", func() {
@ -657,10 +660,10 @@ var _ = Describe("Server", func() {
}
phm.EXPECT().Get(gomock.Any()).Times(protocol.MaxAcceptQueueSize + 1)
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
phm.EXPECT().GetStatelessResetToken(gomock.Any())
fn()
return true
_, ok := fn()
return ok
}).Times(protocol.MaxAcceptQueueSize)
tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, gomock.Any()).Times(protocol.MaxAcceptQueueSize)
@ -729,10 +732,10 @@ var _ = Describe("Server", func() {
}
phm.EXPECT().Get(gomock.Any())
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
phm.EXPECT().GetStatelessResetToken(gomock.Any())
fn()
return true
_, ok := fn()
return ok
})
tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, gomock.Any())
@ -792,7 +795,7 @@ var _ = Describe("Server", func() {
done := make(chan struct{})
phm.EXPECT().Get(gomock.Any())
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_, _ protocol.ConnectionID, _ func() packetHandler) { close(done) })
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_, _ protocol.ConnectionID, _ func() (packetHandler, bool)) { close(done) })
serv.handlePacket(packet)
Eventually(done).Should(BeClosed())
})
@ -994,6 +997,84 @@ var _ = Describe("Server", func() {
Eventually(done).Should(BeClosed())
})
It("uses the config returned by GetConfigClient", func() {
conn := NewMockQUICConn(mockCtrl)
conf := &Config{MaxIncomingStreams: 1234}
serv.config = populateServerConfig(&Config{GetConfigForClient: func(*ClientHelloInfo) (*Config, error) { return conf, nil }})
done := make(chan struct{})
go func() {
defer GinkgoRecover()
s, err := serv.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
Expect(s).To(Equal(conn))
close(done)
}()
handshakeChan := make(chan struct{})
serv.newConn = func(
_ sendConn,
_ connRunner,
_ protocol.ConnectionID,
_ *protocol.ConnectionID,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ ConnectionIDGenerator,
_ protocol.StatelessResetToken,
conf *Config,
_ *tls.Config,
_ *handshake.TokenGenerator,
_ bool,
_ logging.ConnectionTracer,
_ uint64,
_ utils.Logger,
_ protocol.VersionNumber,
) quicConn {
Expect(conf.MaxIncomingStreams).To(BeEquivalentTo(1234))
conn.EXPECT().handlePacket(gomock.Any())
conn.EXPECT().HandshakeComplete().Return(handshakeChan)
conn.EXPECT().run().Do(func() {})
conn.EXPECT().Context().Return(context.Background())
return conn
}
phm.EXPECT().Get(gomock.Any())
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
phm.EXPECT().GetStatelessResetToken(gomock.Any())
_, ok := fn()
return ok
})
serv.handleInitialImpl(
&receivedPacket{buffer: getPacketBuffer()},
&wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})},
)
Consistently(done).ShouldNot(BeClosed())
close(handshakeChan) // complete the handshake
Eventually(done).Should(BeClosed())
})
It("rejects a connection attempt when GetConfigClient returns an error", func() {
serv.config = populateServerConfig(&Config{GetConfigForClient: func(*ClientHelloInfo) (*Config, error) { return nil, errors.New("rejected") }})
phm.EXPECT().Get(gomock.Any())
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
_, ok := fn()
return ok
})
done := make(chan struct{})
conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
defer close(done)
rejectHdr := parseHeader(b)
Expect(rejectHdr.Type).To(Equal(protocol.PacketTypeInitial))
return len(b), nil
})
serv.handleInitialImpl(
&receivedPacket{buffer: getPacketBuffer()},
&wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), Version: protocol.Version1},
)
Eventually(done).Should(BeClosed())
})
It("accepts new connections when the handshake completes", func() {
conn := NewMockQUICConn(mockCtrl)
@ -1033,10 +1114,10 @@ var _ = Describe("Server", func() {
return conn
}
phm.EXPECT().Get(gomock.Any())
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
phm.EXPECT().GetStatelessResetToken(gomock.Any())
fn()
return true
_, ok := fn()
return ok
})
tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, gomock.Any())
serv.handleInitialImpl(
@ -1107,10 +1188,10 @@ var _ = Describe("Server", func() {
return conn
}
phm.EXPECT().Get(gomock.Any())
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
phm.EXPECT().GetStatelessResetToken(gomock.Any())
fn()
return true
_, ok := fn()
return ok
})
serv.baseServer.handleInitialImpl(
&receivedPacket{buffer: getPacketBuffer()},
@ -1154,10 +1235,10 @@ var _ = Describe("Server", func() {
}
phm.EXPECT().Get(gomock.Any()).AnyTimes()
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
phm.EXPECT().GetStatelessResetToken(gomock.Any())
fn()
return true
_, ok := fn()
return ok
}).Times(protocol.MaxAcceptQueueSize)
for i := 0; i < protocol.MaxAcceptQueueSize; i++ {
serv.baseServer.handlePacket(getInitialWithRandomDestConnID())
@ -1216,10 +1297,10 @@ var _ = Describe("Server", func() {
}
phm.EXPECT().Get(gomock.Any())
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
phm.EXPECT().GetStatelessResetToken(gomock.Any())
fn()
return true
_, ok := fn()
return ok
})
serv.baseServer.handlePacket(p)
// make sure there are no Write calls on the packet conn
@ -1346,10 +1427,10 @@ var _ = Describe("Server", func() {
tracer.EXPECT().TracerForConnection(gomock.Any(), gomock.Any(), gomock.Any())
phm.EXPECT().Get(connID)
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
phm.EXPECT().GetStatelessResetToken(gomock.Any())
fn()
return true
_, ok := fn()
return ok
})
serv.handlePacket(initial)
Eventually(called).Should(BeClosed())