diff --git a/config.go b/config.go index ceb4fb68..fbd6da17 100644 --- a/config.go +++ b/config.go @@ -103,6 +103,7 @@ func populateConfig(config *Config) *Config { } return &Config{ + GetConfigForClient: config.GetConfigForClient, Versions: versions, HandshakeIdleTimeout: handshakeIdleTimeout, MaxIdleTimeout: idleTimeout, diff --git a/config_test.go b/config_test.go index f319deb2..3de7a173 100644 --- a/config_test.go +++ b/config_test.go @@ -1,6 +1,7 @@ package quic import ( + "errors" "fmt" "net" "reflect" @@ -45,7 +46,7 @@ var _ = Describe("Config", func() { } switch fn := typ.Field(i).Name; fn { - case "RequireAddressValidation", "GetLogWriter", "AllowConnectionWindowIncrease", "Allow0RTT": + case "GetConfigForClient", "RequireAddressValidation", "GetLogWriter", "AllowConnectionWindowIncrease", "Allow0RTT": // Can't compare functions. case "Versions": f.Set(reflect.ValueOf([]VersionNumber{1, 2, 3})) @@ -108,6 +109,7 @@ var _ = Describe("Config", func() { It("clones function fields", func() { var calledAddrValidation, calledAllowConnectionWindowIncrease bool c1 := &Config{ + GetConfigForClient: func(info *ClientHelloInfo) (*Config, error) { return nil, errors.New("nope") }, AllowConnectionWindowIncrease: func(Connection, uint64) bool { calledAllowConnectionWindowIncrease = true; return true }, RequireAddressValidation: func(net.Addr) bool { calledAddrValidation = true; return true }, } @@ -116,6 +118,8 @@ var _ = Describe("Config", func() { Expect(calledAddrValidation).To(BeTrue()) c2.AllowConnectionWindowIncrease(nil, 1234) Expect(calledAllowConnectionWindowIncrease).To(BeTrue()) + _, err := c2.GetConfigForClient(&ClientHelloInfo{}) + Expect(err).To(MatchError("nope")) }) It("clones non-function fields", func() { @@ -164,6 +168,7 @@ var _ = Describe("Config", func() { Expect(c.MaxIncomingUniStreams).To(BeEquivalentTo(protocol.DefaultMaxIncomingUniStreams)) Expect(c.DisableVersionNegotiationPackets).To(BeFalse()) Expect(c.DisablePathMTUDiscovery).To(BeFalse()) + Expect(c.GetConfigForClient).To(BeNil()) }) It("populates empty fields with default values, for the server", func() { diff --git a/integrationtests/self/handshake_test.go b/integrationtests/self/handshake_test.go index fc77f424..0f1c3678 100644 --- a/integrationtests/self/handshake_test.go +++ b/integrationtests/self/handshake_test.go @@ -436,6 +436,72 @@ var _ = Describe("Handshake tests", func() { }) }) + Context("GetConfigForClient", func() { + It("uses the quic.Config returned by GetConfigForClient", func() { + serverConfig.EnableDatagrams = false + var calledFrom net.Addr + serverConfig.GetConfigForClient = func(info *quic.ClientHelloInfo) (*quic.Config, error) { + conf := serverConfig.Clone() + conf.EnableDatagrams = true + calledFrom = info.RemoteAddr + return getQuicConfig(conf), nil + } + ln, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig) + Expect(err).ToNot(HaveOccurred()) + + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + _, err := ln.Accept(context.Background()) + Expect(err).ToNot(HaveOccurred()) + close(done) + }() + + conn, err := quic.DialAddr( + context.Background(), + fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port), + getTLSClientConfig(), + getQuicConfig(&quic.Config{EnableDatagrams: true}), + ) + Expect(err).ToNot(HaveOccurred()) + defer conn.CloseWithError(0, "") + cs := conn.ConnectionState() + Expect(cs.SupportsDatagrams).To(BeTrue()) + Eventually(done).Should(BeClosed()) + Expect(ln.Close()).To(Succeed()) + Expect(calledFrom.(*net.UDPAddr).Port).To(Equal(conn.LocalAddr().(*net.UDPAddr).Port)) + }) + + It("rejects the connection attempt if GetConfigForClient errors", func() { + serverConfig.EnableDatagrams = false + serverConfig.GetConfigForClient = func(info *quic.ClientHelloInfo) (*quic.Config, error) { + return nil, errors.New("rejected") + } + ln, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig) + Expect(err).ToNot(HaveOccurred()) + defer ln.Close() + + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + _, err := ln.Accept(context.Background()) + Expect(err).To(HaveOccurred()) // we don't expect to accept any connection + close(done) + }() + + _, err = quic.DialAddr( + context.Background(), + fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port), + getTLSClientConfig(), + getQuicConfig(&quic.Config{EnableDatagrams: true}), + ) + Expect(err).To(HaveOccurred()) + var transportErr *quic.TransportError + Expect(errors.As(err, &transportErr)).To(BeTrue()) + Expect(transportErr.ErrorCode).To(Equal(qerr.ConnectionRefused)) + }) + }) + It("doesn't send any packets when generating the ClientHello fails", func() { ln, err := net.ListenUDP("udp", nil) Expect(err).ToNot(HaveOccurred()) diff --git a/interface.go b/interface.go index 267c07f5..29a20958 100644 --- a/interface.go +++ b/interface.go @@ -239,6 +239,9 @@ type ConnectionIDGenerator interface { // Config contains all configuration data needed for a QUIC server or client. type Config struct { + // GetConfigForClient is called for incoming connections. + // If the error is not nil, the connection attempt is refused. + GetConfigForClient func(info *ClientHelloInfo) (*Config, error) // The QUIC versions that can be negotiated. // If not set, it uses all versions available. Versions []VersionNumber @@ -324,6 +327,10 @@ type Config struct { Tracer logging.Tracer } +type ClientHelloInfo struct { + RemoteAddr net.Addr +} + // ConnectionState records basic details about a QUIC connection type ConnectionState struct { TLS handshake.ConnectionState diff --git a/mock_packet_handler_manager_test.go b/mock_packet_handler_manager_test.go index 25ae5420..7b70a8db 100644 --- a/mock_packet_handler_manager_test.go +++ b/mock_packet_handler_manager_test.go @@ -61,7 +61,7 @@ func (mr *MockPacketHandlerManagerMockRecorder) AddResetToken(arg0, arg1 interfa } // AddWithConnID mocks base method. -func (m *MockPacketHandlerManager) AddWithConnID(arg0, arg1 protocol.ConnectionID, arg2 func() packetHandler) bool { +func (m *MockPacketHandlerManager) AddWithConnID(arg0, arg1 protocol.ConnectionID, arg2 func() (packetHandler, bool)) bool { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "AddWithConnID", arg0, arg1, arg2) ret0, _ := ret[0].(bool) diff --git a/packet_handler_map.go b/packet_handler_map.go index 2a08359a..83caa192 100644 --- a/packet_handler_map.go +++ b/packet_handler_map.go @@ -122,7 +122,7 @@ func (h *packetHandlerMap) Add(id protocol.ConnectionID, handler packetHandler) return true } -func (h *packetHandlerMap) AddWithConnID(clientDestConnID, newConnID protocol.ConnectionID, fn func() packetHandler) bool { +func (h *packetHandlerMap) AddWithConnID(clientDestConnID, newConnID protocol.ConnectionID, fn func() (packetHandler, bool)) bool { h.mutex.Lock() defer h.mutex.Unlock() @@ -130,7 +130,10 @@ func (h *packetHandlerMap) AddWithConnID(clientDestConnID, newConnID protocol.Co h.logger.Debugf("Not adding connection ID %s for a new connection, as it already exists.", clientDestConnID) return false } - conn := fn() + conn, ok := fn() + if !ok { + return false + } h.handlers[clientDestConnID] = conn h.handlers[newConnID] = conn h.logger.Debugf("Adding connection IDs %s and %s for a new connection.", clientDestConnID, newConnID) diff --git a/packet_handler_map_test.go b/packet_handler_map_test.go index e87a75f8..2969bb5b 100644 --- a/packet_handler_map_test.go +++ b/packet_handler_map_test.go @@ -62,14 +62,14 @@ var _ = Describe("Packet Handler Map", func() { var called bool connID1 := protocol.ParseConnectionID([]byte{1, 2, 3, 4}) connID2 := protocol.ParseConnectionID([]byte{4, 3, 2, 1}) - Expect(m.AddWithConnID(connID1, connID2, func() packetHandler { + Expect(m.AddWithConnID(connID1, connID2, func() (packetHandler, bool) { called = true - return NewMockPacketHandler(mockCtrl) + return NewMockPacketHandler(mockCtrl), true })).To(BeTrue()) Expect(called).To(BeTrue()) - Expect(m.AddWithConnID(connID1, protocol.ParseConnectionID([]byte{1, 2, 3}), func() packetHandler { + Expect(m.AddWithConnID(connID1, protocol.ParseConnectionID([]byte{1, 2, 3}), func() (packetHandler, bool) { Fail("didn't expect the constructor to be executed") - return nil + return nil, false })).To(BeFalse()) }) diff --git a/server.go b/server.go index a80709ea..a38a5784 100644 --- a/server.go +++ b/server.go @@ -33,7 +33,7 @@ type packetHandler interface { type packetHandlerManager interface { Get(protocol.ConnectionID) (packetHandler, bool) GetByResetToken(protocol.StatelessResetToken) (packetHandler, bool) - AddWithConnID(protocol.ConnectionID, protocol.ConnectionID, func() packetHandler) bool + AddWithConnID(protocol.ConnectionID, protocol.ConnectionID, func() (packetHandler, bool)) bool Close(error) CloseServer() connRunner @@ -584,7 +584,7 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro s.logger.Debugf("Changing connection ID to %s.", connID) var conn quicConn tracingID := nextConnTracingID() - if added := s.connHandler.AddWithConnID(hdr.DestConnectionID, connID, func() packetHandler { + if added := s.connHandler.AddWithConnID(hdr.DestConnectionID, connID, func() (packetHandler, bool) { var tracer logging.ConnectionTracer if s.config.Tracer != nil { // Use the same connection ID that is passed to the client's GetLogWriter callback. @@ -598,6 +598,15 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro connID, ) } + config := s.config + if s.config.GetConfigForClient != nil { + conf, err := s.config.GetConfigForClient(&ClientHelloInfo{RemoteAddr: p.remoteAddr}) + if err != nil { + s.logger.Debugf("Rejecting new connection due to GetConfigForClient callback") + return nil, false + } + config = populateConfig(conf) + } conn = s.newConn( newSendConn(s.conn, p.remoteAddr, p.info), s.connHandler, @@ -608,7 +617,7 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro connID, s.connIDGenerator, s.connHandler.GetStatelessResetToken(connID), - s.config, + config, s.tlsConf, s.tokenGenerator, clientAddrIsValid, @@ -626,10 +635,14 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro delete(s.zeroRTTQueues, hdr.DestConnectionID) } - return conn + return conn, true }); !added { - // TODO: don't just drop the packet - // Properly reject the connection attempt. + go func() { + defer p.buffer.Release() + if err := s.sendConnectionRefused(p.remoteAddr, hdr, p.info); err != nil { + s.logger.Debugf("Error rejecting connection: %s", err) + } + }() return nil } go conn.run() diff --git a/server_test.go b/server_test.go index 7a17ffd5..e78356c8 100644 --- a/server_test.go +++ b/server_test.go @@ -267,14 +267,14 @@ var _ = Describe("Server", func() { var newConnID protocol.ConnectionID phm.EXPECT().Get(connID) - phm.EXPECT().AddWithConnID(connID, gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, fn func() packetHandler) bool { + phm.EXPECT().AddWithConnID(connID, gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, fn func() (packetHandler, bool)) bool { newConnID = c phm.EXPECT().GetStatelessResetToken(gomock.Any()).DoAndReturn(func(c protocol.ConnectionID) protocol.StatelessResetToken { newConnID = c return token }) - fn() - return true + _, ok := fn() + return ok }) tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, protocol.ParseConnectionID([]byte{0xde, 0xad, 0xc0, 0xde})) conn := NewMockQUICConn(mockCtrl) @@ -468,14 +468,14 @@ var _ = Describe("Server", func() { var newConnID protocol.ConnectionID gomock.InOrder( phm.EXPECT().Get(connID), - phm.EXPECT().AddWithConnID(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, fn func() packetHandler) bool { + phm.EXPECT().AddWithConnID(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, fn func() (packetHandler, bool)) bool { newConnID = c phm.EXPECT().GetStatelessResetToken(gomock.Any()).DoAndReturn(func(c protocol.ConnectionID) protocol.StatelessResetToken { newConnID = c return token }) - fn() - return true + _, ok := fn() + return ok }), ) tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})) @@ -532,10 +532,10 @@ var _ = Describe("Server", func() { It("drops packets if the receive queue is full", func() { phm.EXPECT().Get(gomock.Any()).AnyTimes() - phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool { + phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool { phm.EXPECT().GetStatelessResetToken(gomock.Any()) - fn() - return true + _, ok := fn() + return ok }).AnyTimes() tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, gomock.Any()).AnyTimes() @@ -594,7 +594,6 @@ var _ = Describe("Server", func() { It("only creates a single connection for a duplicate Initial", func() { var createdConn bool - conn := NewMockQUICConn(mockCtrl) serv.newConn = func( _ sendConn, runner connRunner, @@ -615,15 +614,19 @@ var _ = Describe("Server", func() { _ protocol.VersionNumber, ) quicConn { createdConn = true - return conn + return NewMockQUICConn(mockCtrl) } connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9}) p := getInitial(connID) phm.EXPECT().Get(connID) phm.EXPECT().AddWithConnID(connID, gomock.Any(), gomock.Any()).Return(false) + tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) + done := make(chan struct{}) + conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).Do(func([]byte, net.Addr) { close(done) }) Expect(serv.handlePacketImpl(p)).To(BeTrue()) Expect(createdConn).To(BeFalse()) + Eventually(done).Should(BeClosed()) }) It("rejects new connection attempts if the accept queue is full", func() { @@ -657,10 +660,10 @@ var _ = Describe("Server", func() { } phm.EXPECT().Get(gomock.Any()).Times(protocol.MaxAcceptQueueSize + 1) - phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool { + phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool { phm.EXPECT().GetStatelessResetToken(gomock.Any()) - fn() - return true + _, ok := fn() + return ok }).Times(protocol.MaxAcceptQueueSize) tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, gomock.Any()).Times(protocol.MaxAcceptQueueSize) @@ -729,10 +732,10 @@ var _ = Describe("Server", func() { } phm.EXPECT().Get(gomock.Any()) - phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool { + phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool { phm.EXPECT().GetStatelessResetToken(gomock.Any()) - fn() - return true + _, ok := fn() + return ok }) tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, gomock.Any()) @@ -792,7 +795,7 @@ var _ = Describe("Server", func() { done := make(chan struct{}) phm.EXPECT().Get(gomock.Any()) - phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_, _ protocol.ConnectionID, _ func() packetHandler) { close(done) }) + phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_, _ protocol.ConnectionID, _ func() (packetHandler, bool)) { close(done) }) serv.handlePacket(packet) Eventually(done).Should(BeClosed()) }) @@ -994,6 +997,84 @@ var _ = Describe("Server", func() { Eventually(done).Should(BeClosed()) }) + It("uses the config returned by GetConfigClient", func() { + conn := NewMockQUICConn(mockCtrl) + + conf := &Config{MaxIncomingStreams: 1234} + serv.config = populateServerConfig(&Config{GetConfigForClient: func(*ClientHelloInfo) (*Config, error) { return conf, nil }}) + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + s, err := serv.Accept(context.Background()) + Expect(err).ToNot(HaveOccurred()) + Expect(s).To(Equal(conn)) + close(done) + }() + + handshakeChan := make(chan struct{}) + serv.newConn = func( + _ sendConn, + _ connRunner, + _ protocol.ConnectionID, + _ *protocol.ConnectionID, + _ protocol.ConnectionID, + _ protocol.ConnectionID, + _ protocol.ConnectionID, + _ ConnectionIDGenerator, + _ protocol.StatelessResetToken, + conf *Config, + _ *tls.Config, + _ *handshake.TokenGenerator, + _ bool, + _ logging.ConnectionTracer, + _ uint64, + _ utils.Logger, + _ protocol.VersionNumber, + ) quicConn { + Expect(conf.MaxIncomingStreams).To(BeEquivalentTo(1234)) + conn.EXPECT().handlePacket(gomock.Any()) + conn.EXPECT().HandshakeComplete().Return(handshakeChan) + conn.EXPECT().run().Do(func() {}) + conn.EXPECT().Context().Return(context.Background()) + return conn + } + phm.EXPECT().Get(gomock.Any()) + phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool { + phm.EXPECT().GetStatelessResetToken(gomock.Any()) + _, ok := fn() + return ok + }) + serv.handleInitialImpl( + &receivedPacket{buffer: getPacketBuffer()}, + &wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})}, + ) + Consistently(done).ShouldNot(BeClosed()) + close(handshakeChan) // complete the handshake + Eventually(done).Should(BeClosed()) + }) + + It("rejects a connection attempt when GetConfigClient returns an error", func() { + serv.config = populateServerConfig(&Config{GetConfigForClient: func(*ClientHelloInfo) (*Config, error) { return nil, errors.New("rejected") }}) + + phm.EXPECT().Get(gomock.Any()) + phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool { + _, ok := fn() + return ok + }) + done := make(chan struct{}) + conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).DoAndReturn(func(b []byte, _ net.Addr) (int, error) { + defer close(done) + rejectHdr := parseHeader(b) + Expect(rejectHdr.Type).To(Equal(protocol.PacketTypeInitial)) + return len(b), nil + }) + serv.handleInitialImpl( + &receivedPacket{buffer: getPacketBuffer()}, + &wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), Version: protocol.Version1}, + ) + Eventually(done).Should(BeClosed()) + }) + It("accepts new connections when the handshake completes", func() { conn := NewMockQUICConn(mockCtrl) @@ -1033,10 +1114,10 @@ var _ = Describe("Server", func() { return conn } phm.EXPECT().Get(gomock.Any()) - phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool { + phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool { phm.EXPECT().GetStatelessResetToken(gomock.Any()) - fn() - return true + _, ok := fn() + return ok }) tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, gomock.Any()) serv.handleInitialImpl( @@ -1107,10 +1188,10 @@ var _ = Describe("Server", func() { return conn } phm.EXPECT().Get(gomock.Any()) - phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool { + phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool { phm.EXPECT().GetStatelessResetToken(gomock.Any()) - fn() - return true + _, ok := fn() + return ok }) serv.baseServer.handleInitialImpl( &receivedPacket{buffer: getPacketBuffer()}, @@ -1154,10 +1235,10 @@ var _ = Describe("Server", func() { } phm.EXPECT().Get(gomock.Any()).AnyTimes() - phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool { + phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool { phm.EXPECT().GetStatelessResetToken(gomock.Any()) - fn() - return true + _, ok := fn() + return ok }).Times(protocol.MaxAcceptQueueSize) for i := 0; i < protocol.MaxAcceptQueueSize; i++ { serv.baseServer.handlePacket(getInitialWithRandomDestConnID()) @@ -1216,10 +1297,10 @@ var _ = Describe("Server", func() { } phm.EXPECT().Get(gomock.Any()) - phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool { + phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool { phm.EXPECT().GetStatelessResetToken(gomock.Any()) - fn() - return true + _, ok := fn() + return ok }) serv.baseServer.handlePacket(p) // make sure there are no Write calls on the packet conn @@ -1346,10 +1427,10 @@ var _ = Describe("Server", func() { tracer.EXPECT().TracerForConnection(gomock.Any(), gomock.Any(), gomock.Any()) phm.EXPECT().Get(connID) - phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool { + phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool { phm.EXPECT().GetStatelessResetToken(gomock.Any()) - fn() - return true + _, ok := fn() + return ok }) serv.handlePacket(initial) Eventually(called).Should(BeClosed())