From 94046cdb4b2aed7826373616c6bb1a5c708d820f Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Wed, 6 Mar 2019 13:58:32 +0900 Subject: [PATCH] implement sending of stateless resets --- client.go | 3 ++- client_test.go | 22 +++++++++------- interface.go | 2 ++ mock_multiplexer_test.go | 8 +++--- mock_packet_handler_manager_test.go | 38 ++++++++++++++++++++++++++ mock_session_runner_test.go | 24 +++++++++++++++++ multiplexer.go | 27 ++++++++++++++----- multiplexer_test.go | 13 ++++++--- packet_handler_map.go | 41 ++++++++++++++++++++++++++--- packet_handler_map_test.go | 39 +++++++++++++++++++++------ server.go | 13 ++++++--- server_test.go | 12 +++++---- session.go | 3 +++ 13 files changed, 199 insertions(+), 46 deletions(-) diff --git a/client.go b/client.go index dc70677e..9ede391f 100644 --- a/client.go +++ b/client.go @@ -120,7 +120,7 @@ func dialContext( createdPacketConn bool, ) (Session, error) { config = populateClientConfig(config, createdPacketConn) - packetHandlers, err := getMultiplexer().AddConn(pconn, config.ConnectionIDLength) + packetHandlers, err := getMultiplexer().AddConn(pconn, config.ConnectionIDLength, config.StatelessResetKey) if err != nil { return nil, err } @@ -240,6 +240,7 @@ func populateClientConfig(config *Config, createdPacketConn bool) *Config { MaxIncomingStreams: maxIncomingStreams, MaxIncomingUniStreams: maxIncomingUniStreams, KeepAlive: config.KeepAlive, + StatelessResetKey: config.StatelessResetKey, } } diff --git a/client_test.go b/client_test.go index 4a857d06..87bd1a9e 100644 --- a/client_test.go +++ b/client_test.go @@ -127,7 +127,7 @@ var _ = Describe("Client", func() { manager := NewMockPacketHandlerManager(mockCtrl) manager.EXPECT().Add(gomock.Any(), gomock.Any()) manager.EXPECT().Close() - mockMultiplexer.EXPECT().AddConn(gomock.Any(), gomock.Any()).Return(manager, nil) + mockMultiplexer.EXPECT().AddConn(gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil) remoteAddrChan := make(chan string, 1) newClientSession = func( @@ -157,7 +157,7 @@ var _ = Describe("Client", func() { manager := NewMockPacketHandlerManager(mockCtrl) manager.EXPECT().Add(gomock.Any(), gomock.Any()) manager.EXPECT().Close() - mockMultiplexer.EXPECT().AddConn(gomock.Any(), gomock.Any()).Return(manager, nil) + mockMultiplexer.EXPECT().AddConn(gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil) hostnameChan := make(chan string, 1) newClientSession = func( @@ -186,7 +186,7 @@ var _ = Describe("Client", func() { It("returns after the handshake is complete", func() { manager := NewMockPacketHandlerManager(mockCtrl) manager.EXPECT().Add(gomock.Any(), gomock.Any()) - mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any()).Return(manager, nil) + mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any()).Return(manager, nil) run := make(chan struct{}) newClientSession = func( @@ -222,7 +222,7 @@ var _ = Describe("Client", func() { It("returns an error that occurs while waiting for the connection to become secure", func() { manager := NewMockPacketHandlerManager(mockCtrl) manager.EXPECT().Add(gomock.Any(), gomock.Any()) - mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any()).Return(manager, nil) + mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any()).Return(manager, nil) testErr := errors.New("early handshake error") newClientSession = func( @@ -256,7 +256,7 @@ var _ = Describe("Client", func() { It("closes the session when the context is canceled", func() { manager := NewMockPacketHandlerManager(mockCtrl) manager.EXPECT().Add(gomock.Any(), gomock.Any()) - mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any()).Return(manager, nil) + mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any()).Return(manager, nil) sessionRunning := make(chan struct{}) defer close(sessionRunning) @@ -304,7 +304,7 @@ var _ = Describe("Client", func() { manager := NewMockPacketHandlerManager(mockCtrl) manager.EXPECT().Add(connID, gomock.Any()) manager.EXPECT().Retire(connID) - mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any()).Return(manager, nil) + mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any()).Return(manager, nil) var runner sessionRunner sess := NewMockQuicSession(mockCtrl) @@ -345,7 +345,7 @@ var _ = Describe("Client", func() { } manager := NewMockPacketHandlerManager(mockCtrl) - mockMultiplexer.EXPECT().AddConn(gomock.Any(), gomock.Any()).Return(manager, nil) + mockMultiplexer.EXPECT().AddConn(gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil) manager.EXPECT().Add(gomock.Any(), gomock.Any()) var conn connection @@ -401,6 +401,7 @@ var _ = Describe("Client", func() { MaxIncomingStreams: 1234, MaxIncomingUniStreams: 4321, ConnectionIDLength: 13, + StatelessResetKey: []byte("foobar"), } c := populateClientConfig(config, false) Expect(c.HandshakeTimeout).To(Equal(1337 * time.Minute)) @@ -408,11 +409,12 @@ var _ = Describe("Client", func() { Expect(c.MaxIncomingStreams).To(Equal(1234)) Expect(c.MaxIncomingUniStreams).To(Equal(4321)) Expect(c.ConnectionIDLength).To(Equal(13)) + Expect(c.StatelessResetKey).To(Equal([]byte("foobar"))) }) It("errors when the Config contains an invalid version", func() { manager := NewMockPacketHandlerManager(mockCtrl) - mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any()).Return(manager, nil) + mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any()).Return(manager, nil) version := protocol.VersionNumber(0x1234) _, err := Dial(packetConn, nil, "localhost:1234", &tls.Config{}, &Config{Versions: []protocol.VersionNumber{version}}) @@ -456,7 +458,7 @@ var _ = Describe("Client", func() { It("creates new TLS sessions with the right parameters", func() { manager := NewMockPacketHandlerManager(mockCtrl) manager.EXPECT().Add(connID, gomock.Any()) - mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any()).Return(manager, nil) + mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any()).Return(manager, nil) config := &Config{Versions: []protocol.VersionNumber{protocol.VersionTLS}} c := make(chan struct{}) @@ -508,7 +510,7 @@ var _ = Describe("Client", func() { It("returns an error that occurs during version negotiation", func() { manager := NewMockPacketHandlerManager(mockCtrl) manager.EXPECT().Add(connID, gomock.Any()) - mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any()).Return(manager, nil) + mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any()).Return(manager, nil) testErr := errors.New("early handshake error") newClientSession = func( diff --git a/interface.go b/interface.go index 0b80702f..4827c7d9 100644 --- a/interface.go +++ b/interface.go @@ -208,6 +208,8 @@ type Config struct { // If not set, it will default to 100. // If set to a negative value, it doesn't allow any unidirectional streams. MaxIncomingUniStreams int + // The StatelessResetKey is used to generate stateless reset tokens. + StatelessResetKey []byte // KeepAlive defines whether this peer will periodically send a packet to keep the connection alive. KeepAlive bool } diff --git a/mock_multiplexer_test.go b/mock_multiplexer_test.go index 52731b5a..7ca5a254 100644 --- a/mock_multiplexer_test.go +++ b/mock_multiplexer_test.go @@ -35,18 +35,18 @@ func (m *MockMultiplexer) EXPECT() *MockMultiplexerMockRecorder { } // AddConn mocks base method -func (m *MockMultiplexer) AddConn(arg0 net.PacketConn, arg1 int) (packetHandlerManager, error) { +func (m *MockMultiplexer) AddConn(arg0 net.PacketConn, arg1 int, arg2 []byte) (packetHandlerManager, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AddConn", arg0, arg1) + ret := m.ctrl.Call(m, "AddConn", arg0, arg1, arg2) ret0, _ := ret[0].(packetHandlerManager) ret1, _ := ret[1].(error) return ret0, ret1 } // AddConn indicates an expected call of AddConn -func (mr *MockMultiplexerMockRecorder) AddConn(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockMultiplexerMockRecorder) AddConn(arg0, arg1, arg2 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddConn", reflect.TypeOf((*MockMultiplexer)(nil).AddConn), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddConn", reflect.TypeOf((*MockMultiplexer)(nil).AddConn), arg0, arg1, arg2) } // RemoveConn mocks base method diff --git a/mock_packet_handler_manager_test.go b/mock_packet_handler_manager_test.go index 257af2c7..3793d270 100644 --- a/mock_packet_handler_manager_test.go +++ b/mock_packet_handler_manager_test.go @@ -46,6 +46,18 @@ func (mr *MockPacketHandlerManagerMockRecorder) Add(arg0, arg1 interface{}) *gom return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockPacketHandlerManager)(nil).Add), arg0, arg1) } +// AddResetToken mocks base method +func (m *MockPacketHandlerManager) AddResetToken(arg0 [16]byte, arg1 packetHandler) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "AddResetToken", arg0, arg1) +} + +// AddResetToken indicates an expected call of AddResetToken +func (mr *MockPacketHandlerManagerMockRecorder) AddResetToken(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddResetToken", reflect.TypeOf((*MockPacketHandlerManager)(nil).AddResetToken), arg0, arg1) +} + // Close mocks base method func (m *MockPacketHandlerManager) Close() error { m.ctrl.T.Helper() @@ -72,6 +84,20 @@ func (mr *MockPacketHandlerManagerMockRecorder) CloseServer() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseServer", reflect.TypeOf((*MockPacketHandlerManager)(nil).CloseServer)) } +// GetStatelessResetToken mocks base method +func (m *MockPacketHandlerManager) GetStatelessResetToken(arg0 protocol.ConnectionID) [16]byte { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetStatelessResetToken", arg0) + ret0, _ := ret[0].([16]byte) + return ret0 +} + +// GetStatelessResetToken indicates an expected call of GetStatelessResetToken +func (mr *MockPacketHandlerManagerMockRecorder) GetStatelessResetToken(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStatelessResetToken", reflect.TypeOf((*MockPacketHandlerManager)(nil).GetStatelessResetToken), arg0) +} + // Remove mocks base method func (m *MockPacketHandlerManager) Remove(arg0 protocol.ConnectionID) { m.ctrl.T.Helper() @@ -84,6 +110,18 @@ func (mr *MockPacketHandlerManagerMockRecorder) Remove(arg0 interface{}) *gomock return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Remove", reflect.TypeOf((*MockPacketHandlerManager)(nil).Remove), arg0) } +// RemoveResetToken mocks base method +func (m *MockPacketHandlerManager) RemoveResetToken(arg0 [16]byte) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "RemoveResetToken", arg0) +} + +// RemoveResetToken indicates an expected call of RemoveResetToken +func (mr *MockPacketHandlerManagerMockRecorder) RemoveResetToken(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveResetToken", reflect.TypeOf((*MockPacketHandlerManager)(nil).RemoveResetToken), arg0) +} + // Retire mocks base method func (m *MockPacketHandlerManager) Retire(arg0 protocol.ConnectionID) { m.ctrl.T.Helper() diff --git a/mock_session_runner_test.go b/mock_session_runner_test.go index ab88320f..fad0e813 100644 --- a/mock_session_runner_test.go +++ b/mock_session_runner_test.go @@ -34,6 +34,18 @@ func (m *MockSessionRunner) EXPECT() *MockSessionRunnerMockRecorder { return m.recorder } +// AddResetToken mocks base method +func (m *MockSessionRunner) AddResetToken(arg0 [16]byte, arg1 packetHandler) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "AddResetToken", arg0, arg1) +} + +// AddResetToken indicates an expected call of AddResetToken +func (mr *MockSessionRunnerMockRecorder) AddResetToken(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddResetToken", reflect.TypeOf((*MockSessionRunner)(nil).AddResetToken), arg0, arg1) +} + // OnHandshakeComplete mocks base method func (m *MockSessionRunner) OnHandshakeComplete(arg0 Session) { m.ctrl.T.Helper() @@ -58,6 +70,18 @@ func (mr *MockSessionRunnerMockRecorder) Remove(arg0 interface{}) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Remove", reflect.TypeOf((*MockSessionRunner)(nil).Remove), arg0) } +// RemoveResetToken mocks base method +func (m *MockSessionRunner) RemoveResetToken(arg0 [16]byte) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "RemoveResetToken", arg0) +} + +// RemoveResetToken indicates an expected call of RemoveResetToken +func (mr *MockSessionRunnerMockRecorder) RemoveResetToken(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveResetToken", reflect.TypeOf((*MockSessionRunner)(nil).RemoveResetToken), arg0) +} + // Retire mocks base method func (m *MockSessionRunner) Retire(arg0 protocol.ConnectionID) { m.ctrl.T.Helper() diff --git a/multiplexer.go b/multiplexer.go index e8c3b7db..eeffca53 100644 --- a/multiplexer.go +++ b/multiplexer.go @@ -1,6 +1,7 @@ package quic import ( + "bytes" "fmt" "net" "sync" @@ -14,13 +15,14 @@ var ( ) type multiplexer interface { - AddConn(net.PacketConn, int) (packetHandlerManager, error) + AddConn(c net.PacketConn, connIDLen int, statelessResetKey []byte) (packetHandlerManager, error) RemoveConn(net.PacketConn) error } type connManager struct { - connIDLen int - manager packetHandlerManager + connIDLen int + statelessResetKey []byte + manager packetHandlerManager } // The connMultiplexer listens on multiple net.PacketConns and dispatches @@ -29,7 +31,7 @@ type connMultiplexer struct { mutex sync.Mutex conns map[net.PacketConn]connManager - newPacketHandlerManager func(net.PacketConn, int, utils.Logger) packetHandlerManager // so it can be replaced in the tests + newPacketHandlerManager func(net.PacketConn, int, []byte, utils.Logger) packetHandlerManager // so it can be replaced in the tests logger utils.Logger } @@ -47,19 +49,30 @@ func getMultiplexer() multiplexer { return connMuxer } -func (m *connMultiplexer) AddConn(c net.PacketConn, connIDLen int) (packetHandlerManager, error) { +func (m *connMultiplexer) AddConn( + c net.PacketConn, + connIDLen int, + statelessResetKey []byte, +) (packetHandlerManager, error) { m.mutex.Lock() defer m.mutex.Unlock() p, ok := m.conns[c] if !ok { - manager := m.newPacketHandlerManager(c, connIDLen, m.logger) - p = connManager{connIDLen: connIDLen, manager: manager} + manager := m.newPacketHandlerManager(c, connIDLen, statelessResetKey, m.logger) + p = connManager{ + connIDLen: connIDLen, + statelessResetKey: statelessResetKey, + manager: manager, + } m.conns[c] = p } if p.connIDLen != connIDLen { return nil, fmt.Errorf("cannot use %d byte connection IDs on a connection that is already using %d byte connction IDs", connIDLen, p.connIDLen) } + if statelessResetKey != nil && !bytes.Equal(p.statelessResetKey, statelessResetKey) { + return nil, fmt.Errorf("cannot use different stateless reset keys on the same packet conn") + } return p.manager, nil } diff --git a/multiplexer_test.go b/multiplexer_test.go index f50f227f..1b40cf11 100644 --- a/multiplexer_test.go +++ b/multiplexer_test.go @@ -8,16 +8,23 @@ import ( var _ = Describe("Client Multiplexer", func() { It("adds a new packet conn ", func() { conn := newMockPacketConn() - _, err := getMultiplexer().AddConn(conn, 8) + _, err := getMultiplexer().AddConn(conn, 8, nil) Expect(err).ToNot(HaveOccurred()) }) It("errors when adding an existing conn with a different connection ID length", func() { conn := newMockPacketConn() - _, err := getMultiplexer().AddConn(conn, 5) + _, err := getMultiplexer().AddConn(conn, 5, nil) Expect(err).ToNot(HaveOccurred()) - _, err = getMultiplexer().AddConn(conn, 6) + _, err = getMultiplexer().AddConn(conn, 6, nil) Expect(err).To(MatchError("cannot use 6 byte connection IDs on a connection that is already using 5 byte connction IDs")) }) + It("errors when adding an existing conn with a different stateless rest key", func() { + conn := newMockPacketConn() + _, err := getMultiplexer().AddConn(conn, 7, []byte("foobar")) + Expect(err).ToNot(HaveOccurred()) + _, err = getMultiplexer().AddConn(conn, 7, []byte("raboof")) + Expect(err).To(MatchError("cannot use different stateless reset keys on the same packet conn")) + }) }) diff --git a/packet_handler_map.go b/packet_handler_map.go index 0ae65b51..f7913a44 100644 --- a/packet_handler_map.go +++ b/packet_handler_map.go @@ -1,7 +1,11 @@ package quic import ( + "crypto/hmac" + "crypto/rand" + "crypto/sha256" "errors" + "hash" "net" "sync" "time" @@ -30,12 +34,19 @@ type packetHandlerMap struct { deleteRetiredSessionsAfter time.Duration + statelessResetHasher hash.Hash + logger utils.Logger } var _ packetHandlerManager = &packetHandlerMap{} -func newPacketHandlerMap(conn net.PacketConn, connIDLen int, logger utils.Logger) packetHandlerManager { +func newPacketHandlerMap( + conn net.PacketConn, + connIDLen int, + statelessResetKey []byte, + logger utils.Logger, +) packetHandlerManager { m := &packetHandlerMap{ conn: conn, connIDLen: connIDLen, @@ -43,6 +54,7 @@ func newPacketHandlerMap(conn net.PacketConn, connIDLen int, logger utils.Logger handlers: make(map[string]packetHandler), resetTokens: make(map[[16]byte]packetHandler), deleteRetiredSessionsAfter: protocol.RetiredConnectionIDDeleteTimeout, + statelessResetHasher: hmac.New(sha256.New, statelessResetKey), logger: logger, } go m.listen() @@ -194,8 +206,7 @@ func (h *packetHandlerMap) handlePacket( return } if data[0]&0x80 == 0 { - // TODO(#943): send a stateless reset - h.logger.Debugf("received a short header packet with an unexpected connection ID %s", connID) + go h.maybeSendStatelessReset(p, connID) return } if h.server == nil { // no server set @@ -217,8 +228,30 @@ func (h *packetHandlerMap) maybeHandleStatelessReset(data []byte) bool { var token [16]byte copy(token[:], data[len(data)-16:]) if sess, ok := h.resetTokens[token]; ok { - sess.destroy(errors.New("received a stateless reset")) + h.logger.Debugf("Received a stateless retry with token %#x. Closing session.", token) + go sess.destroy(errors.New("received a stateless reset")) return true } return false } + +func (h *packetHandlerMap) GetStatelessResetToken(connID protocol.ConnectionID) [16]byte { + h.statelessResetHasher.Write(connID.Bytes()) + var token [16]byte + copy(token[:], h.statelessResetHasher.Sum(nil)) + h.statelessResetHasher.Reset() + return token +} + +func (h *packetHandlerMap) maybeSendStatelessReset(p *receivedPacket, connID protocol.ConnectionID) { + defer p.buffer.Release() + token := h.GetStatelessResetToken(connID) + h.logger.Debugf("Sending stateless reset to %s (connection ID: %s). Token: %#x", p.remoteAddr, connID, token) + data := make([]byte, 23) + rand.Read(data) + data[0] = (data[0] & 0x7f) | 0x40 + data = append(data, token[:]...) + if _, err := h.conn.WriteTo(data, p.remoteAddr); err != nil { + h.logger.Debugf("Error sending Stateless Reset: %s", err) + } +} diff --git a/packet_handler_map_test.go b/packet_handler_map_test.go index 1b078c1b..86de4df2 100644 --- a/packet_handler_map_test.go +++ b/packet_handler_map_test.go @@ -3,6 +3,7 @@ package quic import ( "bytes" "errors" + "net" "time" "github.com/golang/mock/gomock" @@ -40,7 +41,7 @@ var _ = Describe("Packet Handler Map", func() { BeforeEach(func() { conn = newMockPacketConn() - handler = newPacketHandlerMap(conn, 5, utils.DefaultLogger).(*packetHandlerMap) + handler = newPacketHandlerMap(conn, 5, nil, utils.DefaultLogger).(*packetHandlerMap) }) AfterEach(func() { @@ -163,6 +164,14 @@ var _ = Describe("Packet Handler Map", func() { }) Context("stateless reset handling", func() { + It("generates stateless reset tokens", func() { + connID1 := []byte{0xde, 0xad, 0xbe, 0xef} + connID2 := []byte{0xde, 0xca, 0xfb, 0xad} + token1 := handler.GetStatelessResetToken(connID1) + Expect(handler.GetStatelessResetToken(connID1)).To(Equal(token1)) + Expect(handler.GetStatelessResetToken(connID2)).ToNot(Equal(token1)) + }) + It("handles stateless resets", func() { packetHandler := NewMockPacketHandler(mockCtrl) token := [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} @@ -195,16 +204,30 @@ var _ = Describe("Packet Handler Map", func() { It("deletes reset tokens", func() { handler.deleteRetiredSessionsAfter = scaleDuration(10 * time.Millisecond) connID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0x42} + packetHandler := NewMockPacketHandler(mockCtrl) + handler.Add(connID, packetHandler) token := [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} handler.AddResetToken(token, NewMockPacketHandler(mockCtrl)) handler.RemoveResetToken(token) - handler.handlePacket(nil, nil, getPacket(connID)) - // don't EXPECT any calls to handlePacket of the MockPacketHandler - packet := append([]byte{0x40, 0xde, 0xca, 0xfb, 0xad, 0x99} /* short header packet */, make([]byte, 50)...) - packet = append(packet, token[:]...) - handler.handlePacket(nil, nil, packet) - // don't EXPECT any calls to handlePacket of the MockPacketHandler - Expect(handler.resetTokens).To(BeEmpty()) + packetHandler.EXPECT().handlePacket(gomock.Any()) + p := append([]byte{0x40} /* short header packet */, connID.Bytes()...) + p = append(p, make([]byte, 50)...) + p = append(p, token[:]...) + handler.handlePacket(nil, nil, p) + // destroy() would be called from a separate go routine + // make sure we give it enough time to be called to cause an error here + time.Sleep(scaleDuration(25 * time.Millisecond)) + }) + + It("sends stateless resets", func() { + addr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337} + p := append([]byte{40}, make([]byte, 100)...) + handler.handlePacket(addr, getPacketBuffer(), p) + var reset mockPacketConnWrite + Eventually(conn.dataWritten).Should(Receive(&reset)) + Expect(reset.to).To(Equal(addr)) + Expect(reset.data[0] & 0x80).To(BeZero()) // short header packet + Expect(reset.data).To(HaveLen(protocol.MinStatelessResetSize)) }) }) diff --git a/server.go b/server.go index 4f7b418b..fe2f9388 100644 --- a/server.go +++ b/server.go @@ -36,6 +36,9 @@ type packetHandlerManager interface { Add(protocol.ConnectionID, packetHandler) Retire(protocol.ConnectionID) Remove(protocol.ConnectionID) + AddResetToken([16]byte, packetHandler) + RemoveResetToken([16]byte) + GetStatelessResetToken(protocol.ConnectionID) [16]byte SetServer(unknownPacketHandler) CloseServer() } @@ -55,6 +58,8 @@ type sessionRunner interface { OnHandshakeComplete(Session) Retire(protocol.ConnectionID) Remove(protocol.ConnectionID) + AddResetToken([16]byte, packetHandler) + RemoveResetToken([16]byte) } type runner struct { @@ -143,7 +148,7 @@ func listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (*server, } } - sessionHandler, err := getMultiplexer().AddConn(conn, config.ConnectionIDLength) + sessionHandler, err := getMultiplexer().AddConn(conn, config.ConnectionIDLength, config.StatelessResetKey) if err != nil { return nil, err } @@ -266,6 +271,7 @@ func populateServerConfig(config *Config) *Config { MaxIncomingStreams: maxIncomingStreams, MaxIncomingUniStreams: maxIncomingUniStreams, ConnectionIDLength: connIDLen, + StatelessResetKey: config.StatelessResetKey, } } @@ -341,8 +347,8 @@ func (s *server) handlePacketImpl(p *receivedPacket) bool /* was the packet pass s.logger.Debugf("Error parsing packet: %s", err) return false } + // Short header packets should never end up here in the first place if !hdr.IsLongHeader { - // TODO: send a stateless reset return false } // send a Version Negotiation Packet if the client is speaking a different protocol version @@ -430,8 +436,7 @@ func (s *server) createNewSession( srcConnID protocol.ConnectionID, version protocol.VersionNumber, ) (quicSession, error) { - // TODO(#855): generate a real token - token := [16]byte{42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42} + token := s.sessionHandler.GetStatelessResetToken(srcConnID) params := &handshake.TransportParameters{ InitialMaxStreamDataBidiLocal: protocol.InitialMaxStreamData, InitialMaxStreamDataBidiRemote: protocol.InitialMaxStreamData, diff --git a/server_test.go b/server_test.go index 7313f32b..5fb9d291 100644 --- a/server_test.go +++ b/server_test.go @@ -79,11 +79,12 @@ var _ = Describe("Server", func() { supportedVersions := []protocol.VersionNumber{protocol.VersionTLS} acceptCookie := func(_ net.Addr, _ *Cookie) bool { return true } config := Config{ - Versions: supportedVersions, - AcceptCookie: acceptCookie, - HandshakeTimeout: 1337 * time.Hour, - IdleTimeout: 42 * time.Minute, - KeepAlive: true, + Versions: supportedVersions, + AcceptCookie: acceptCookie, + HandshakeTimeout: 1337 * time.Hour, + IdleTimeout: 42 * time.Minute, + KeepAlive: true, + StatelessResetKey: []byte("foobar"), } ln, err := Listen(conn, tlsConf, &config) Expect(err).ToNot(HaveOccurred()) @@ -94,6 +95,7 @@ var _ = Describe("Server", func() { Expect(server.config.IdleTimeout).To(Equal(42 * time.Minute)) Expect(reflect.ValueOf(server.config.AcceptCookie)).To(Equal(reflect.ValueOf(acceptCookie))) Expect(server.config.KeepAlive).To(BeTrue()) + Expect(server.config.StatelessResetKey).To(Equal([]byte("foobar"))) // stop the listener Expect(ln.Close()).To(Succeed()) }) diff --git a/session.go b/session.go index bd8ba567..74a0f5f5 100644 --- a/session.go +++ b/session.go @@ -927,6 +927,9 @@ func (s *session) processTransportParameters(data []byte) { s.packer.HandleTransportParameters(params) s.frameParser.SetAckDelayExponent(params.AckDelayExponent) s.connFlowController.UpdateSendWindow(params.InitialMaxData) + if params.StatelessResetToken != nil { + s.sessionRunner.AddResetToken(*params.StatelessResetToken, s) + } } func (s *session) processTransportParametersForClient(data []byte) (*handshake.TransportParameters, error) {