diff --git a/client.go b/client.go index 6492a697..019e574f 100644 --- a/client.go +++ b/client.go @@ -368,6 +368,8 @@ func (c *client) createNewTLSSession(_ protocol.VersionNumber) { c.version, ) c.mutex.Unlock() + // It's not possible to use the stateless reset token for the client's (first) connection ID, + // since there's no way to securely communicate it to the server. c.packetHandlers.Add(c.srcConnID, c) } diff --git a/client_test.go b/client_test.go index ae9c78ee..bf7961c4 100644 --- a/client_test.go +++ b/client_test.go @@ -338,45 +338,6 @@ var _ = Describe("Client", func() { Eventually(dialed).Should(BeClosed()) }) - It("removes closed sessions from the multiplexer", func() { - manager := NewMockPacketHandlerManager(mockCtrl) - manager.EXPECT().Add(connID, gomock.Any()) - manager.EXPECT().Retire(connID) - mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any()).Return(manager, nil) - - var runner sessionRunner - sess := NewMockQuicSession(mockCtrl) - newClientSession = func( - _ connection, - runnerP sessionRunner, - _ protocol.ConnectionID, - _ protocol.ConnectionID, - _ *Config, - _ *tls.Config, - _ protocol.PacketNumber, - _ protocol.VersionNumber, - _ utils.Logger, - _ protocol.VersionNumber, - ) quicSession { - runner = runnerP - return sess - } - sess.EXPECT().run().Do(func() { - runner.Retire(connID) - }) - sess.EXPECT().HandshakeComplete().Return(context.Background()) - - _, err := DialContext( - context.Background(), - packetConn, - addr, - "localhost:1337", - tlsConf, - &Config{}, - ) - Expect(err).ToNot(HaveOccurred()) - }) - It("closes the connection when it was created by DialAddr", func() { if os.Getenv("APPVEYOR") == "True" { Skip("This test is flaky on AppVeyor.") diff --git a/mock_packet_handler_manager_test.go b/mock_packet_handler_manager_test.go index 9daf8f0c..1172a20f 100644 --- a/mock_packet_handler_manager_test.go +++ b/mock_packet_handler_manager_test.go @@ -35,9 +35,11 @@ func (m *MockPacketHandlerManager) EXPECT() *MockPacketHandlerManagerMockRecorde } // Add mocks base method -func (m *MockPacketHandlerManager) Add(arg0 protocol.ConnectionID, arg1 packetHandler) { +func (m *MockPacketHandlerManager) Add(arg0 protocol.ConnectionID, arg1 packetHandler) [16]byte { m.ctrl.T.Helper() - m.ctrl.Call(m, "Add", arg0, arg1) + ret := m.ctrl.Call(m, "Add", arg0, arg1) + ret0, _ := ret[0].([16]byte) + return ret0 } // Add indicates an expected call of Add @@ -84,20 +86,6 @@ func (mr *MockPacketHandlerManagerMockRecorder) CloseServer() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseServer", reflect.TypeOf((*MockPacketHandlerManager)(nil).CloseServer)) } -// GetStatelessResetToken mocks base method -func (m *MockPacketHandlerManager) GetStatelessResetToken(arg0 protocol.ConnectionID) [16]byte { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetStatelessResetToken", arg0) - ret0, _ := ret[0].([16]byte) - return ret0 -} - -// GetStatelessResetToken indicates an expected call of GetStatelessResetToken -func (mr *MockPacketHandlerManagerMockRecorder) GetStatelessResetToken(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStatelessResetToken", reflect.TypeOf((*MockPacketHandlerManager)(nil).GetStatelessResetToken), arg0) -} - // Remove mocks base method func (m *MockPacketHandlerManager) Remove(arg0 protocol.ConnectionID) { m.ctrl.T.Helper() diff --git a/mock_session_runner_test.go b/mock_session_runner_test.go index f1f85440..709d5701 100644 --- a/mock_session_runner_test.go +++ b/mock_session_runner_test.go @@ -34,6 +34,20 @@ func (m *MockSessionRunner) EXPECT() *MockSessionRunnerMockRecorder { return m.recorder } +// Add mocks base method +func (m *MockSessionRunner) Add(arg0 protocol.ConnectionID, arg1 packetHandler) [16]byte { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Add", arg0, arg1) + ret0, _ := ret[0].([16]byte) + return ret0 +} + +// Add indicates an expected call of Add +func (mr *MockSessionRunnerMockRecorder) Add(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockSessionRunner)(nil).Add), arg0, arg1) +} + // AddResetToken mocks base method func (m *MockSessionRunner) AddResetToken(arg0 [16]byte, arg1 packetHandler) { m.ctrl.T.Helper() @@ -46,20 +60,6 @@ func (mr *MockSessionRunnerMockRecorder) AddResetToken(arg0, arg1 interface{}) * return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddResetToken", reflect.TypeOf((*MockSessionRunner)(nil).AddResetToken), arg0, arg1) } -// GetStatelessResetToken mocks base method -func (m *MockSessionRunner) GetStatelessResetToken(arg0 protocol.ConnectionID) [16]byte { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetStatelessResetToken", arg0) - ret0, _ := ret[0].([16]byte) - return ret0 -} - -// GetStatelessResetToken indicates an expected call of GetStatelessResetToken -func (mr *MockSessionRunnerMockRecorder) GetStatelessResetToken(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStatelessResetToken", reflect.TypeOf((*MockSessionRunner)(nil).GetStatelessResetToken), arg0) -} - // Remove mocks base method func (m *MockSessionRunner) Remove(arg0 protocol.ConnectionID) { m.ctrl.T.Helper() diff --git a/packet_handler_map.go b/packet_handler_map.go index 864d72b6..83085bcb 100644 --- a/packet_handler_map.go +++ b/packet_handler_map.go @@ -95,10 +95,11 @@ func (h *packetHandlerMap) logUsage() { } } -func (h *packetHandlerMap) Add(id protocol.ConnectionID, handler packetHandler) { +func (h *packetHandlerMap) Add(id protocol.ConnectionID, handler packetHandler) [16]byte { h.mutex.Lock() h.handlers[string(id)] = handler h.mutex.Unlock() + return h.getStatelessResetToken(id) } func (h *packetHandlerMap) Remove(id protocol.ConnectionID) { @@ -283,7 +284,7 @@ func (h *packetHandlerMap) maybeHandleStatelessReset(data []byte) bool { return false } -func (h *packetHandlerMap) GetStatelessResetToken(connID protocol.ConnectionID) [16]byte { +func (h *packetHandlerMap) getStatelessResetToken(connID protocol.ConnectionID) [16]byte { var token [16]byte if !h.statelessResetEnabled { // Return a random stateless reset token. @@ -310,7 +311,7 @@ func (h *packetHandlerMap) maybeSendStatelessReset(p *receivedPacket, connID pro if len(p.data) <= protocol.MinStatelessResetSize { return } - token := h.GetStatelessResetToken(connID) + token := h.getStatelessResetToken(connID) h.logger.Debugf("Sending stateless reset to %s (connection ID: %s). Token: %#x", p.remoteAddr, connID, token) data := make([]byte, protocol.MinStatelessResetSize-16, protocol.MinStatelessResetSize) rand.Read(data) diff --git a/packet_handler_map_test.go b/packet_handler_map_test.go index b639247a..2e805e10 100644 --- a/packet_handler_map_test.go +++ b/packet_handler_map_test.go @@ -292,9 +292,9 @@ var _ = Describe("Packet Handler Map", func() { It("generates stateless reset tokens", func() { connID1 := []byte{0xde, 0xad, 0xbe, 0xef} connID2 := []byte{0xde, 0xca, 0xfb, 0xad} - token1 := handler.GetStatelessResetToken(connID1) - Expect(handler.GetStatelessResetToken(connID1)).To(Equal(token1)) - Expect(handler.GetStatelessResetToken(connID2)).ToNot(Equal(token1)) + token1 := handler.Add(connID1, nil) + Expect(handler.Add(connID1, nil)).To(Equal(token1)) + Expect(handler.Add(connID2, nil)).ToNot(Equal(token1)) }) It("sends stateless resets", func() { diff --git a/server.go b/server.go index b5665489..02f5e394 100644 --- a/server.go +++ b/server.go @@ -34,7 +34,6 @@ type unknownPacketHandler interface { type packetHandlerManager interface { io.Closer - Add(protocol.ConnectionID, packetHandler) SetServer(unknownPacketHandler) CloseServer() sessionRunner @@ -364,7 +363,7 @@ func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* was the packet s.logger.Debugf("<- Received Initial packet.") - sess, connID, err := s.handleInitialImpl(p, hdr) + sess, err := s.handleInitialImpl(p, hdr) if err != nil { s.logger.Errorf("Error occurred handling initial packet: %s", err) return false @@ -374,13 +373,12 @@ func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* was the packet } // Don't put the packet buffer back if a new session was created. // The session will handle the packet and take of that. - s.sessionHandler.Add(connID, sess) return true } -func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) (quicSession, protocol.ConnectionID, error) { +func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) (quicSession, error) { if len(hdr.Token) == 0 && hdr.DestConnectionID.Len() < protocol.MinConnectionIDLenInitial { - return nil, nil, errors.New("too short connection ID") + return nil, errors.New("too short connection ID") } var token *Token @@ -400,17 +398,17 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) (qui // Log the Initial packet now. // If no Retry is sent, the packet will be logged by the session. (&wire.ExtendedHeader{Header: *hdr}).Log(s.logger) - return nil, nil, s.sendRetry(p.remoteAddr, hdr) + return nil, s.sendRetry(p.remoteAddr, hdr) } if queueLen := atomic.LoadInt32(&s.sessionQueueLen); queueLen >= protocol.MaxAcceptQueueSize { s.logger.Debugf("Rejecting new connection. Server currently busy. Accept queue length: %d (max %d)", queueLen, protocol.MaxAcceptQueueSize) - return nil, nil, s.sendServerBusy(p.remoteAddr, hdr) + return nil, s.sendServerBusy(p.remoteAddr, hdr) } connID, err := protocol.GenerateConnectionID(s.config.ConnectionIDLength) if err != nil { - return nil, nil, err + return nil, err } s.logger.Debugf("Changing connection ID to %s.", connID) sess := s.createNewSession( @@ -422,7 +420,7 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) (qui hdr.Version, ) sess.handlePacket(p) - return sess, connID, nil + return sess, nil } func (s *baseServer) createNewSession( diff --git a/server_test.go b/server_test.go index a430be96..43479b51 100644 --- a/server_test.go +++ b/server_test.go @@ -433,7 +433,6 @@ var _ = Describe("Server", func() { Consistently(done).ShouldNot(BeClosed()) // make the go routine return - sess.EXPECT().getPerspective() Expect(serv.Close()).To(Succeed()) Eventually(done).Should(BeClosed()) }) @@ -674,7 +673,6 @@ var _ = Describe("Server", func() { Consistently(done).ShouldNot(BeClosed()) // make the go routine return - sess.EXPECT().getPerspective() Expect(serv.Close()).To(Succeed()) Eventually(done).Should(BeClosed()) }) diff --git a/session.go b/session.go index fd6a81ad..63928efa 100644 --- a/session.go +++ b/session.go @@ -73,13 +73,13 @@ func (p *receivedPacket) Clone() *receivedPacket { } type sessionRunner interface { + Add(protocol.ConnectionID, packetHandler) [16]byte Retire(protocol.ConnectionID) Remove(protocol.ConnectionID) ReplaceWithClosed(protocol.ConnectionID, packetHandler) AddResetToken([16]byte, packetHandler) RemoveResetToken([16]byte) RetireResetToken([16]byte) - GetStatelessResetToken(protocol.ConnectionID) [16]byte } type handshakeRunner struct { @@ -220,7 +220,7 @@ var newSession = func( initialStream := newCryptoStream() handshakeStream := newCryptoStream() oneRTTStream := newPostHandshakeCryptoStream(s.framer) - token := s.sessionRunner.GetStatelessResetToken(srcConnID) + token := s.sessionRunner.Add(srcConnID, s) params := &handshake.TransportParameters{ InitialMaxStreamDataBidiLocal: protocol.InitialMaxStreamData, InitialMaxStreamDataBidiRemote: protocol.InitialMaxStreamData, diff --git a/session_test.go b/session_test.go index c74b189b..969d2d64 100644 --- a/session_test.go +++ b/session_test.go @@ -104,7 +104,7 @@ var _ = Describe("Session", func() { Eventually(areSessionsRunning).Should(BeFalse()) sessionRunner = NewMockSessionRunner(mockCtrl) - sessionRunner.EXPECT().GetStatelessResetToken(gomock.Any()) + sessionRunner.EXPECT().Add(gomock.Any(), gomock.Any()) mconn = newMockConnection() tokenGenerator, err := handshake.NewTokenGenerator() Expect(err).ToNot(HaveOccurred())