From e65df402dd96c9af9874c747c901dce8cedd55c4 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Wed, 27 Nov 2019 09:24:56 +0700 Subject: [PATCH] only create a single session when two Initials arrive at the same time --- mock_packet_handler_manager_test.go | 28 +++++ packet_handler_map.go | 18 +++- packet_handler_map_test.go | 14 ++- server.go | 23 +++- server_test.go | 159 +++++++++++++++++++--------- session.go | 5 +- session_test.go | 2 +- 7 files changed, 190 insertions(+), 59 deletions(-) diff --git a/mock_packet_handler_manager_test.go b/mock_packet_handler_manager_test.go index 1172a20f..2ccdbe16 100644 --- a/mock_packet_handler_manager_test.go +++ b/mock_packet_handler_manager_test.go @@ -48,6 +48,20 @@ func (mr *MockPacketHandlerManagerMockRecorder) Add(arg0, arg1 interface{}) *gom return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockPacketHandlerManager)(nil).Add), arg0, arg1) } +// AddIfNotTaken mocks base method +func (m *MockPacketHandlerManager) AddIfNotTaken(arg0 protocol.ConnectionID, arg1 packetHandler) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AddIfNotTaken", arg0, arg1) + ret0, _ := ret[0].(bool) + return ret0 +} + +// AddIfNotTaken indicates an expected call of AddIfNotTaken +func (mr *MockPacketHandlerManagerMockRecorder) AddIfNotTaken(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddIfNotTaken", reflect.TypeOf((*MockPacketHandlerManager)(nil).AddIfNotTaken), arg0, arg1) +} + // AddResetToken mocks base method func (m *MockPacketHandlerManager) AddResetToken(arg0 [16]byte, arg1 packetHandler) { m.ctrl.T.Helper() @@ -86,6 +100,20 @@ func (mr *MockPacketHandlerManagerMockRecorder) CloseServer() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseServer", reflect.TypeOf((*MockPacketHandlerManager)(nil).CloseServer)) } +// GetStatelessResetToken mocks base method +func (m *MockPacketHandlerManager) GetStatelessResetToken(arg0 protocol.ConnectionID) [16]byte { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetStatelessResetToken", arg0) + ret0, _ := ret[0].([16]byte) + return ret0 +} + +// GetStatelessResetToken indicates an expected call of GetStatelessResetToken +func (mr *MockPacketHandlerManagerMockRecorder) GetStatelessResetToken(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStatelessResetToken", reflect.TypeOf((*MockPacketHandlerManager)(nil).GetStatelessResetToken), arg0) +} + // Remove mocks base method func (m *MockPacketHandlerManager) Remove(arg0 protocol.ConnectionID) { m.ctrl.T.Helper() diff --git a/packet_handler_map.go b/packet_handler_map.go index 83085bcb..7e45d2a6 100644 --- a/packet_handler_map.go +++ b/packet_handler_map.go @@ -99,7 +99,19 @@ func (h *packetHandlerMap) Add(id protocol.ConnectionID, handler packetHandler) h.mutex.Lock() h.handlers[string(id)] = handler h.mutex.Unlock() - return h.getStatelessResetToken(id) + return h.GetStatelessResetToken(id) +} + +func (h *packetHandlerMap) AddIfNotTaken(id protocol.ConnectionID, handler packetHandler) bool /* was added */ { + sid := string(id) + h.mutex.Lock() + defer h.mutex.Unlock() + + if _, ok := h.handlers[sid]; !ok { + h.handlers[sid] = handler + return true + } + return false } func (h *packetHandlerMap) Remove(id protocol.ConnectionID) { @@ -284,7 +296,7 @@ func (h *packetHandlerMap) maybeHandleStatelessReset(data []byte) bool { return false } -func (h *packetHandlerMap) getStatelessResetToken(connID protocol.ConnectionID) [16]byte { +func (h *packetHandlerMap) GetStatelessResetToken(connID protocol.ConnectionID) [16]byte { var token [16]byte if !h.statelessResetEnabled { // Return a random stateless reset token. @@ -311,7 +323,7 @@ func (h *packetHandlerMap) maybeSendStatelessReset(p *receivedPacket, connID pro if len(p.data) <= protocol.MinStatelessResetSize { return } - token := h.getStatelessResetToken(connID) + token := h.GetStatelessResetToken(connID) h.logger.Debugf("Sending stateless reset to %s (connection ID: %s). Token: %#x", p.remoteAddr, connID, token) data := make([]byte, protocol.MinStatelessResetSize-16, protocol.MinStatelessResetSize) rand.Read(data) diff --git a/packet_handler_map_test.go b/packet_handler_map_test.go index 2e805e10..1f15cea1 100644 --- a/packet_handler_map_test.go +++ b/packet_handler_map_test.go @@ -175,6 +175,12 @@ var _ = Describe("Packet Handler Map", func() { conn.Close() Eventually(done).Should(BeClosed()) }) + + It("says if a connection ID is already taken", func() { + connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} + Expect(handler.AddIfNotTaken(connID, NewMockPacketHandler(mockCtrl))).To(BeTrue()) + Expect(handler.AddIfNotTaken(connID, NewMockPacketHandler(mockCtrl))).To(BeFalse()) + }) }) Context("running a server", func() { @@ -289,7 +295,7 @@ var _ = Describe("Packet Handler Map", func() { statelessResetKey = key }) - It("generates stateless reset tokens", func() { + It("generates stateless reset tokens when adding new sessions", func() { connID1 := []byte{0xde, 0xad, 0xbe, 0xef} connID2 := []byte{0xde, 0xca, 0xfb, 0xad} token1 := handler.Add(connID1, nil) @@ -297,6 +303,12 @@ var _ = Describe("Packet Handler Map", func() { Expect(handler.Add(connID2, nil)).ToNot(Equal(token1)) }) + It("generates stateless reset tokens", func() { + connID1 := []byte{0xde, 0xad, 0xbe, 0xef} + connID2 := []byte{0xde, 0xca, 0xfb, 0xad} + Expect(handler.GetStatelessResetToken(connID1)).ToNot(Equal(handler.GetStatelessResetToken(connID2))) + }) + It("sends stateless resets", func() { addr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337} p := append([]byte{40}, make([]byte, 100)...) diff --git a/server.go b/server.go index 02f5e394..79a24767 100644 --- a/server.go +++ b/server.go @@ -17,6 +17,8 @@ import ( "github.com/lucas-clemente/quic-go/internal/qerr" "github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/internal/wire" + + "github.com/onsi/ginkgo" ) // packetHandler handles packets @@ -37,6 +39,8 @@ type packetHandlerManager interface { SetServer(unknownPacketHandler) CloseServer() sessionRunner + AddIfNotTaken(protocol.ConnectionID, packetHandler) bool + GetStatelessResetToken(protocol.ConnectionID) [16]byte } type quicSession interface { @@ -70,7 +74,7 @@ type baseServer struct { sessionHandler packetHandlerManager // set as a member, so they can be set in the tests - newSession func(connection, sessionRunner, protocol.ConnectionID /* original connection ID */, protocol.ConnectionID /* client dest connection ID */, protocol.ConnectionID /* destination connection ID */, protocol.ConnectionID /* source connection ID */, *Config, *tls.Config, *handshake.TokenGenerator, utils.Logger, protocol.VersionNumber) quicSession + newSession func(connection, sessionRunner, protocol.ConnectionID /* original connection ID */, protocol.ConnectionID /* client dest connection ID */, protocol.ConnectionID /* destination connection ID */, protocol.ConnectionID /* source connection ID */, [16]byte, *Config, *tls.Config, *handshake.TokenGenerator, utils.Logger, protocol.VersionNumber) quicSession serverError error errorChan chan struct{} @@ -327,6 +331,7 @@ func (s *baseServer) Addr() net.Addr { func (s *baseServer) handlePacket(p *receivedPacket) { go func() { + defer ginkgo.GinkgoRecover() if shouldReleaseBuffer := s.handlePacketImpl(p); !shouldReleaseBuffer { p.buffer.Release() } @@ -368,7 +373,9 @@ func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* was the packet s.logger.Errorf("Error occurred handling initial packet: %s", err) return false } - if sess == nil { // a retry was done, or the connection attempt was rejected + // A retry was done, or the connection attempt was rejected, + // or if the Initial was a duplicate. + if sess == nil { return false } // Don't put the packet buffer back if a new session was created. @@ -419,7 +426,9 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) (qui connID, hdr.Version, ) - sess.handlePacket(p) + if sess != nil { + sess.handlePacket(p) + } return sess, nil } @@ -438,12 +447,20 @@ func (s *baseServer) createNewSession( clientDestConnID, destConnID, srcConnID, + s.sessionHandler.GetStatelessResetToken(srcConnID), s.config, s.tlsConf, s.tokenGenerator, s.logger, version, ) + added := s.sessionHandler.AddIfNotTaken(clientDestConnID, sess) + // We're already keeping track of this connection ID. + // This might happen if we receive two copies of the Initial at the same time. + if !added { + return nil + } + s.sessionHandler.Add(srcConnID, sess) go sess.run() go s.handleNewSession(sess) return sess diff --git a/server_test.go b/server_test.go index 43479b51..2b0d0d47 100644 --- a/server_test.go +++ b/server_test.go @@ -3,6 +3,7 @@ package quic import ( "bytes" "context" + "crypto/rand" "crypto/tls" "errors" "net" @@ -10,6 +11,7 @@ import ( "sync" "time" + "github.com/golang/mock/gomock" "github.com/lucas-clemente/quic-go/internal/handshake" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/testdata" @@ -39,6 +41,29 @@ var _ = Describe("Server", func() { } } + getInitial := func(destConnID protocol.ConnectionID) *receivedPacket { + senderAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 42} + hdr := &wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeInitial, + SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1}, + DestConnectionID: destConnID, + Version: protocol.VersionTLS, + } + p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) + p.buffer = getPacketBuffer() + p.remoteAddr = senderAddr + return p + } + + getInitialWithRandomDestConnID := func() *receivedPacket { + destConnID := make([]byte, 10) + _, err := rand.Read(destConnID) + Expect(err).ToNot(HaveOccurred()) + + return getInitial(destConnID) + } + parseHeader := func(data []byte) *wire.Header { hdr, _, _, err := wire.ParsePacket(data, 0) Expect(err).ToNot(HaveOccurred()) @@ -127,12 +152,17 @@ var _ = Describe("Server", func() { }) Context("server accepting sessions that completed the handshake", func() { - var serv *baseServer + var ( + serv *baseServer + phm *MockPacketHandlerManager + ) BeforeEach(func() { ln, err := Listen(conn, tlsConf, nil) Expect(err).ToNot(HaveOccurred()) serv = ln.(*baseServer) + phm = NewMockPacketHandlerManager(mockCtrl) + serv.sessionHandler = phm }) Context("handling packets", func() { @@ -282,6 +312,14 @@ var _ = Describe("Server", func() { } p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) run := make(chan struct{}) + var token [16]byte + rand.Read(token[:]) + var newConnID protocol.ConnectionID + phm.EXPECT().GetStatelessResetToken(gomock.Any()).DoAndReturn(func(c protocol.ConnectionID) [16]byte { + newConnID = c + return token + }) + sess := NewMockQuicSession(mockCtrl) serv.newSession = func( _ connection, _ sessionRunner, @@ -289,6 +327,7 @@ var _ = Describe("Server", func() { origConnID protocol.ConnectionID, destConnID protocol.ConnectionID, srcConnID protocol.ConnectionID, + tokenP [16]byte, _ *Config, _ *tls.Config, _ *handshake.TokenGenerator, @@ -300,7 +339,8 @@ var _ = Describe("Server", func() { // make sure we're using a server-generated connection ID Expect(srcConnID).ToNot(Equal(hdr.DestConnectionID)) Expect(srcConnID).ToNot(Equal(hdr.SrcConnectionID)) - sess := NewMockQuicSession(mockCtrl) + Expect(srcConnID).To(Equal(newConnID)) + Expect(tokenP).To(Equal(token)) sess.EXPECT().handlePacket(p) sess.EXPECT().run().Do(func() { close(run) }) sess.EXPECT().Context().Return(context.Background()) @@ -308,6 +348,11 @@ var _ = Describe("Server", func() { return sess } + phm.EXPECT().AddIfNotTaken(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, sess).Return(true) + phm.EXPECT().Add(gomock.Any(), sess).Do(func(c protocol.ConnectionID, _ packetHandler) { + Expect(c).To(Equal(newConnID)) + }) + done := make(chan struct{}) go func() { defer GinkgoRecover() @@ -321,19 +366,10 @@ var _ = Describe("Server", func() { Eventually(done).Should(BeClosed()) }) - It("rejects new connection attempts if the accept queue is full", func() { + It("only creates a single session for a duplicate Initial", func() { serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true } - senderAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 42} - - hdr := &wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeInitial, - SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1}, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, - Version: protocol.VersionTLS, - } - p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) - p.remoteAddr = senderAddr + var createdSession bool + sess := NewMockQuicSession(mockCtrl) serv.newSession = func( _ connection, runner sessionRunner, @@ -341,6 +377,35 @@ var _ = Describe("Server", func() { _ protocol.ConnectionID, _ protocol.ConnectionID, _ protocol.ConnectionID, + _ [16]byte, + _ *Config, + _ *tls.Config, + _ *handshake.TokenGenerator, + _ utils.Logger, + _ protocol.VersionNumber, + ) quicSession { + createdSession = true + return sess + } + + p := getInitial(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9}) + phm.EXPECT().GetStatelessResetToken(gomock.Any()) + phm.EXPECT().AddIfNotTaken(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9}, sess).Return(false) + Expect(serv.handlePacketImpl(p)).To(BeFalse()) + Expect(createdSession).To(BeTrue()) + }) + + It("rejects new connection attempts if the accept queue is full", func() { + serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true } + + serv.newSession = func( + _ connection, + runner sessionRunner, + _ protocol.ConnectionID, + _ protocol.ConnectionID, + _ protocol.ConnectionID, + _ protocol.ConnectionID, + _ [16]byte, _ *Config, _ *tls.Config, _ *handshake.TokenGenerator, @@ -348,7 +413,7 @@ var _ = Describe("Server", func() { _ protocol.VersionNumber, ) quicSession { sess := NewMockQuicSession(mockCtrl) - sess.EXPECT().handlePacket(p) + sess.EXPECT().handlePacket(gomock.Any()) sess.EXPECT().run() sess.EXPECT().Context().Return(context.Background()) ctx, cancel := context.WithCancel(context.Background()) @@ -357,21 +422,28 @@ var _ = Describe("Server", func() { return sess } + phm.EXPECT().GetStatelessResetToken(gomock.Any()).Times(protocol.MaxAcceptQueueSize) + phm.EXPECT().AddIfNotTaken(gomock.Any(), gomock.Any()).Return(true).Times(protocol.MaxAcceptQueueSize) + phm.EXPECT().Add(gomock.Any(), gomock.Any()).Times(protocol.MaxAcceptQueueSize) + var wg sync.WaitGroup wg.Add(protocol.MaxAcceptQueueSize) for i := 0; i < protocol.MaxAcceptQueueSize; i++ { go func() { defer GinkgoRecover() defer wg.Done() - serv.handlePacket(p) + serv.handlePacket(getInitialWithRandomDestConnID()) Consistently(conn.dataWritten).ShouldNot(Receive()) }() } wg.Wait() + p := getInitialWithRandomDestConnID() + hdr, _, _, err := wire.ParsePacket(p.data, 0) + Expect(err).ToNot(HaveOccurred()) serv.handlePacket(p) var reject mockPacketConnWrite Eventually(conn.dataWritten).Should(Receive(&reject)) - Expect(reject.to).To(Equal(senderAddr)) + Expect(reject.to).To(Equal(p.remoteAddr)) rejectHdr := parseHeader(reject.data) Expect(rejectHdr.Type).To(Equal(protocol.PacketTypeInitial)) Expect(rejectHdr.Version).To(Equal(hdr.Version)) @@ -381,17 +453,8 @@ var _ = Describe("Server", func() { It("doesn't accept new sessions if they were closed in the mean time", func() { serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true } - senderAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 42} - hdr := &wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeInitial, - SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1}, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, - Version: protocol.VersionTLS, - } - p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) - p.remoteAddr = senderAddr + p := getInitial(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) ctx, cancel := context.WithCancel(context.Background()) sessionCreated := make(chan struct{}) sess := NewMockQuicSession(mockCtrl) @@ -402,6 +465,7 @@ var _ = Describe("Server", func() { _ protocol.ConnectionID, _ protocol.ConnectionID, _ protocol.ConnectionID, + _ [16]byte, _ *Config, _ *tls.Config, _ *handshake.TokenGenerator, @@ -418,6 +482,10 @@ var _ = Describe("Server", func() { return sess } + phm.EXPECT().GetStatelessResetToken(gomock.Any()) + phm.EXPECT().AddIfNotTaken(gomock.Any(), gomock.Any()).Return(true) + phm.EXPECT().Add(gomock.Any(), gomock.Any()) + serv.handlePacket(p) Consistently(conn.dataWritten).ShouldNot(Receive()) Eventually(sessionCreated).Should(BeClosed()) @@ -433,6 +501,8 @@ var _ = Describe("Server", func() { Consistently(done).ShouldNot(BeClosed()) // make the go routine return + phm.EXPECT().CloseServer() + sess.EXPECT().getPerspective().MaxTimes(2) // once for every conn ID Expect(serv.Close()).To(Succeed()) Eventually(done).Should(BeClosed()) }) @@ -498,6 +568,7 @@ var _ = Describe("Server", func() { _ protocol.ConnectionID, _ protocol.ConnectionID, _ protocol.ConnectionID, + _ [16]byte, _ *Config, _ *tls.Config, _ *handshake.TokenGenerator, @@ -509,6 +580,9 @@ var _ = Describe("Server", func() { sess.EXPECT().Context().Return(context.Background()) return sess } + phm.EXPECT().GetStatelessResetToken(gomock.Any()) + phm.EXPECT().AddIfNotTaken(gomock.Any(), gomock.Any()).Return(true) + phm.EXPECT().Add(gomock.Any(), gomock.Any()) serv.createNewSession(&net.UDPAddr{}, nil, nil, nil, nil, protocol.VersionWhatever) Consistently(done).ShouldNot(BeClosed()) cancel() // complete the handshake @@ -546,6 +620,7 @@ var _ = Describe("Server", func() { _ protocol.ConnectionID, _ protocol.ConnectionID, _ protocol.ConnectionID, + _ [16]byte, _ *Config, _ *tls.Config, _ *handshake.TokenGenerator, @@ -567,15 +642,6 @@ var _ = Describe("Server", func() { serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true } senderAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 42} - hdr := &wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeInitial, - SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1}, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, - Version: protocol.VersionTLS, - } - p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) - p.remoteAddr = senderAddr serv.newSession = func( _ connection, runner sessionRunner, @@ -583,6 +649,7 @@ var _ = Describe("Server", func() { _ protocol.ConnectionID, _ protocol.ConnectionID, _ protocol.ConnectionID, + _ [16]byte, _ *Config, _ *tls.Config, _ *handshake.TokenGenerator, @@ -592,7 +659,7 @@ var _ = Describe("Server", func() { ready := make(chan struct{}) close(ready) sess := NewMockQuicSession(mockCtrl) - sess.EXPECT().handlePacket(p) + sess.EXPECT().handlePacket(gomock.Any()) sess.EXPECT().run() sess.EXPECT().earlySessionReady().Return(ready) sess.EXPECT().Context().Return(context.Background()) @@ -605,11 +672,14 @@ var _ = Describe("Server", func() { go func() { defer GinkgoRecover() defer wg.Done() - serv.handlePacket(p) + serv.handlePacket(getInitialWithRandomDestConnID()) Consistently(conn.dataWritten).ShouldNot(Receive()) }() } wg.Wait() + + p := getInitialWithRandomDestConnID() + hdr := parseHeader(p.data) serv.handlePacket(p) var reject mockPacketConnWrite Eventually(conn.dataWritten).Should(Receive(&reject)) @@ -623,17 +693,8 @@ var _ = Describe("Server", func() { It("doesn't accept new sessions if they were closed in the mean time", func() { serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true } - senderAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 42} - hdr := &wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeInitial, - SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1}, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, - Version: protocol.VersionTLS, - } - p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) - p.remoteAddr = senderAddr + p := getInitial(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) ctx, cancel := context.WithCancel(context.Background()) sessionCreated := make(chan struct{}) sess := NewMockQuicSession(mockCtrl) @@ -644,6 +705,7 @@ var _ = Describe("Server", func() { _ protocol.ConnectionID, _ protocol.ConnectionID, _ protocol.ConnectionID, + _ [16]byte, _ *Config, _ *tls.Config, _ *handshake.TokenGenerator, @@ -673,6 +735,7 @@ var _ = Describe("Server", func() { Consistently(done).ShouldNot(BeClosed()) // make the go routine return + sess.EXPECT().getPerspective().MaxTimes(2) // once for every conn ID Expect(serv.Close()).To(Succeed()) Eventually(done).Should(BeClosed()) }) diff --git a/session.go b/session.go index 79c84a7d..92ad35f3 100644 --- a/session.go +++ b/session.go @@ -196,6 +196,7 @@ var newSession = func( clientDestConnID protocol.ConnectionID, destConnID protocol.ConnectionID, srcConnID protocol.ConnectionID, + statelessResetToken [16]byte, conf *Config, tlsConf *tls.Config, tokenGenerator *handshake.TokenGenerator, @@ -239,8 +240,6 @@ var newSession = func( initialStream := newCryptoStream() handshakeStream := newCryptoStream() oneRTTStream := newPostHandshakeCryptoStream(s.framer) - runner.Add(clientDestConnID, s) - token := runner.Add(srcConnID, s) params := &handshake.TransportParameters{ InitialMaxStreamDataBidiLocal: protocol.InitialMaxStreamData, InitialMaxStreamDataBidiRemote: protocol.InitialMaxStreamData, @@ -252,7 +251,7 @@ var newSession = func( MaxAckDelay: protocol.MaxAckDelayInclGranularity, AckDelayExponent: protocol.AckDelayExponent, DisableMigration: true, - StatelessResetToken: &token, + StatelessResetToken: &statelessResetToken, OriginalConnectionID: origDestConnID, ActiveConnectionIDLimit: protocol.MaxActiveConnectionIDs, } diff --git a/session_test.go b/session_test.go index e942c078..0027cf74 100644 --- a/session_test.go +++ b/session_test.go @@ -108,7 +108,6 @@ var _ = Describe("Session", func() { Eventually(areSessionsRunning).Should(BeFalse()) sessionRunner = NewMockSessionRunner(mockCtrl) - sessionRunner.EXPECT().Add(gomock.Any(), gomock.Any()).Times(2) mconn = newMockConnection() tokenGenerator, err := handshake.NewTokenGenerator() Expect(err).ToNot(HaveOccurred()) @@ -119,6 +118,7 @@ var _ = Describe("Session", func() { clientDestConnID, destConnID, srcConnID, + [16]byte{}, populateServerConfig(&Config{}), nil, // tls.Config tokenGenerator,