diff --git a/client.go b/client.go index 1c5654f6..1782cde8 100644 --- a/client.go +++ b/client.go @@ -22,9 +22,10 @@ type client struct { tlsConf *tls.Config config *Config - connIDGenerator ConnectionIDGenerator - srcConnID protocol.ConnectionID - destConnID protocol.ConnectionID + connIDGenerator ConnectionIDGenerator + statelessResetter *statelessResetter + srcConnID protocol.ConnectionID + destConnID protocol.ConnectionID initialPacketNumber protocol.PacketNumber hasNegotiatedVersion bool @@ -137,13 +138,14 @@ func dial( ctx context.Context, conn sendConn, connIDGenerator ConnectionIDGenerator, + statelessResetter *statelessResetter, packetHandlers packetHandlerManager, tlsConf *tls.Config, config *Config, onClose func(), use0RTT bool, ) (quicConn, error) { - c, err := newClient(conn, connIDGenerator, config, tlsConf, onClose, use0RTT) + c, err := newClient(conn, connIDGenerator, statelessResetter, config, tlsConf, onClose, use0RTT) if err != nil { return nil, err } @@ -162,7 +164,15 @@ func dial( return c.conn, nil } -func newClient(sendConn sendConn, connIDGenerator ConnectionIDGenerator, config *Config, tlsConf *tls.Config, onClose func(), use0RTT bool) (*client, error) { +func newClient( + sendConn sendConn, + connIDGenerator ConnectionIDGenerator, + statelessResetter *statelessResetter, + config *Config, + tlsConf *tls.Config, + onClose func(), + use0RTT bool, +) (*client, error) { srcConnID, err := connIDGenerator.GenerateConnectionID() if err != nil { return nil, err @@ -172,17 +182,18 @@ func newClient(sendConn sendConn, connIDGenerator ConnectionIDGenerator, config return nil, err } c := &client{ - connIDGenerator: connIDGenerator, - srcConnID: srcConnID, - destConnID: destConnID, - sendConn: sendConn, - use0RTT: use0RTT, - onClose: onClose, - tlsConf: tlsConf, - config: config, - version: config.Versions[0], - handshakeChan: make(chan struct{}), - logger: utils.DefaultLogger.WithPrefix("client"), + connIDGenerator: connIDGenerator, + statelessResetter: statelessResetter, + srcConnID: srcConnID, + destConnID: destConnID, + sendConn: sendConn, + use0RTT: use0RTT, + onClose: onClose, + tlsConf: tlsConf, + config: config, + version: config.Versions[0], + handshakeChan: make(chan struct{}), + logger: utils.DefaultLogger.WithPrefix("client"), } return c, nil } @@ -197,6 +208,7 @@ func (c *client) dial(ctx context.Context) error { c.destConnID, c.srcConnID, c.connIDGenerator, + c.statelessResetter, c.config, c.tlsConf, c.initialPacketNumber, diff --git a/client_test.go b/client_test.go index 2bfe3cb9..8714647a 100644 --- a/client_test.go +++ b/client_test.go @@ -33,6 +33,7 @@ var _ = Describe("Client", func() { destConnID protocol.ConnectionID, srcConnID protocol.ConnectionID, connIDGenerator ConnectionIDGenerator, + statelessResetToken *statelessResetter, conf *Config, tlsConf *tls.Config, initialPacketNumber protocol.PacketNumber, @@ -107,6 +108,7 @@ var _ = Describe("Client", func() { _ protocol.ConnectionID, _ protocol.ConnectionID, _ ConnectionIDGenerator, + _ *statelessResetter, _ *Config, _ *tls.Config, _ protocol.PacketNumber, @@ -124,7 +126,15 @@ var _ = Describe("Client", func() { conn.EXPECT().HandshakeComplete().Return(c) return conn } - cl, err := newClient(packetConn, &protocol.DefaultConnectionIDGenerator{}, populateConfig(config), tlsConf, nil, false) + cl, err := newClient( + packetConn, + &protocol.DefaultConnectionIDGenerator{}, + newStatelessResetter(nil), + populateConfig(config), + tlsConf, + nil, + false, + ) Expect(err).ToNot(HaveOccurred()) cl.packetHandlers = manager Expect(cl).ToNot(BeNil()) @@ -144,6 +154,7 @@ var _ = Describe("Client", func() { _ protocol.ConnectionID, _ protocol.ConnectionID, _ ConnectionIDGenerator, + _ *statelessResetter, _ *Config, _ *tls.Config, _ protocol.PacketNumber, @@ -161,7 +172,15 @@ var _ = Describe("Client", func() { return conn } - cl, err := newClient(packetConn, &protocol.DefaultConnectionIDGenerator{}, populateConfig(config), tlsConf, nil, true) + cl, err := newClient( + packetConn, + &protocol.DefaultConnectionIDGenerator{}, + newStatelessResetter(nil), + populateConfig(config), + tlsConf, + nil, + true, + ) Expect(err).ToNot(HaveOccurred()) cl.packetHandlers = manager Expect(cl).ToNot(BeNil()) @@ -181,6 +200,7 @@ var _ = Describe("Client", func() { _ protocol.ConnectionID, _ protocol.ConnectionID, _ ConnectionIDGenerator, + _ *statelessResetter, _ *Config, _ *tls.Config, _ protocol.PacketNumber, @@ -197,7 +217,13 @@ var _ = Describe("Client", func() { return conn } var closed bool - cl, err := newClient(packetConn, &protocol.DefaultConnectionIDGenerator{}, populateConfig(config), tlsConf, func() { closed = true }, true) + cl, err := newClient( + packetConn, + &protocol.DefaultConnectionIDGenerator{}, + newStatelessResetter(nil), + populateConfig(config), tlsConf, func() { closed = true }, + true, + ) Expect(err).ToNot(HaveOccurred()) cl.packetHandlers = manager Expect(cl).ToNot(BeNil()) @@ -266,6 +292,7 @@ var _ = Describe("Client", func() { _ protocol.ConnectionID, _ protocol.ConnectionID, _ ConnectionIDGenerator, + _ *statelessResetter, configP *Config, _ *tls.Config, _ protocol.PacketNumber, @@ -309,6 +336,7 @@ var _ = Describe("Client", func() { _ protocol.ConnectionID, connID protocol.ConnectionID, _ ConnectionIDGenerator, + _ *statelessResetter, configP *Config, _ *tls.Config, pn protocol.PacketNumber, diff --git a/conn_id_generator.go b/conn_id_generator.go index d7be6540..c309c2cd 100644 --- a/conn_id_generator.go +++ b/conn_id_generator.go @@ -15,19 +15,19 @@ type connIDGenerator struct { activeSrcConnIDs map[uint64]protocol.ConnectionID initialClientDestConnID *protocol.ConnectionID // nil for the client - addConnectionID func(protocol.ConnectionID) - getStatelessResetToken func(protocol.ConnectionID) protocol.StatelessResetToken - removeConnectionID func(protocol.ConnectionID) - retireConnectionID func(protocol.ConnectionID) - replaceWithClosed func([]protocol.ConnectionID, []byte) - queueControlFrame func(wire.Frame) + addConnectionID func(protocol.ConnectionID) + statelessResetter *statelessResetter + removeConnectionID func(protocol.ConnectionID) + retireConnectionID func(protocol.ConnectionID) + replaceWithClosed func([]protocol.ConnectionID, []byte) + queueControlFrame func(wire.Frame) } func newConnIDGenerator( initialConnectionID protocol.ConnectionID, initialClientDestConnID *protocol.ConnectionID, // nil for the client addConnectionID func(protocol.ConnectionID), - getStatelessResetToken func(protocol.ConnectionID) protocol.StatelessResetToken, + statelessResetter *statelessResetter, removeConnectionID func(protocol.ConnectionID), retireConnectionID func(protocol.ConnectionID), replaceWithClosed func([]protocol.ConnectionID, []byte), @@ -35,14 +35,14 @@ func newConnIDGenerator( generator ConnectionIDGenerator, ) *connIDGenerator { m := &connIDGenerator{ - generator: generator, - activeSrcConnIDs: make(map[uint64]protocol.ConnectionID), - addConnectionID: addConnectionID, - getStatelessResetToken: getStatelessResetToken, - removeConnectionID: removeConnectionID, - retireConnectionID: retireConnectionID, - replaceWithClosed: replaceWithClosed, - queueControlFrame: queueControlFrame, + generator: generator, + activeSrcConnIDs: make(map[uint64]protocol.ConnectionID), + addConnectionID: addConnectionID, + statelessResetter: statelessResetter, + removeConnectionID: removeConnectionID, + retireConnectionID: retireConnectionID, + replaceWithClosed: replaceWithClosed, + queueControlFrame: queueControlFrame, } m.activeSrcConnIDs[0] = initialConnectionID m.initialClientDestConnID = initialClientDestConnID @@ -104,7 +104,7 @@ func (m *connIDGenerator) issueNewConnID() error { m.queueControlFrame(&wire.NewConnectionIDFrame{ SequenceNumber: m.highestSeq + 1, ConnectionID: connID, - StatelessResetToken: m.getStatelessResetToken(connID), + StatelessResetToken: m.statelessResetter.GetStatelessResetToken(connID), }) m.highestSeq++ return nil diff --git a/conn_id_generator_test.go b/conn_id_generator_test.go index c438158e..05b99865 100644 --- a/conn_id_generator_test.go +++ b/conn_id_generator_test.go @@ -19,14 +19,11 @@ var _ = Describe("Connection ID Generator", func() { replacedWithClosed []protocol.ConnectionID queuedFrames []wire.Frame g *connIDGenerator + statelessResetter *statelessResetter ) initialConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7}) initialClientDestConnID := protocol.ParseConnectionID([]byte{0xa, 0xb, 0xc, 0xd, 0xe}) - - connIDToToken := func(c protocol.ConnectionID) protocol.StatelessResetToken { - b := c.Bytes()[0] - return protocol.StatelessResetToken{b, b, b, b, b, b, b, b, b, b, b, b, b, b, b, b} - } + statelessResetter = newStatelessResetter(nil) BeforeEach(func() { addedConnIDs = nil @@ -38,7 +35,7 @@ var _ = Describe("Connection ID Generator", func() { initialConnID, &initialClientDestConnID, func(c protocol.ConnectionID) { addedConnIDs = append(addedConnIDs, c) }, - connIDToToken, + statelessResetter, func(c protocol.ConnectionID) { removedConnIDs = append(removedConnIDs, c) }, func(c protocol.ConnectionID) { retiredConnIDs = append(retiredConnIDs, c) }, func(cs []protocol.ConnectionID, _ []byte) { replacedWithClosed = append(replacedWithClosed, cs...) }, @@ -61,7 +58,7 @@ var _ = Describe("Connection ID Generator", func() { nf := f.(*wire.NewConnectionIDFrame) Expect(nf.SequenceNumber).To(BeEquivalentTo(i + 1)) Expect(nf.ConnectionID.Len()).To(Equal(7)) - Expect(nf.StatelessResetToken).To(Equal(connIDToToken(nf.ConnectionID))) + Expect(nf.StatelessResetToken).To(Equal(statelessResetter.GetStatelessResetToken(nf.ConnectionID))) } }) diff --git a/connection.go b/connection.go index 4885bffb..4c25fc88 100644 --- a/connection.go +++ b/connection.go @@ -85,7 +85,6 @@ func (p *receivedPacket) Clone() *receivedPacket { type connRunner interface { Add(protocol.ConnectionID, packetHandler) bool - GetStatelessResetToken(protocol.ConnectionID) protocol.StatelessResetToken Retire(protocol.ConnectionID) Remove(protocol.ConnectionID) ReplaceWithClosed([]protocol.ConnectionID, []byte) @@ -225,7 +224,7 @@ var newConnection = func( destConnID protocol.ConnectionID, srcConnID protocol.ConnectionID, connIDGenerator ConnectionIDGenerator, - statelessResetToken protocol.StatelessResetToken, + statelessResetter *statelessResetter, conf *Config, tlsConf *tls.Config, tokenGenerator *handshake.TokenGenerator, @@ -263,7 +262,7 @@ var newConnection = func( srcConnID, &clientDestConnID, func(connID protocol.ConnectionID) { runner.Add(connID, s) }, - runner.GetStatelessResetToken, + statelessResetter, runner.Remove, runner.Retire, runner.ReplaceWithClosed, @@ -282,6 +281,7 @@ var newConnection = func( s.logger, ) s.maxPayloadSizeEstimate.Store(uint32(estimateMaxPayloadSize(protocol.ByteCount(s.config.InitialPacketSize)))) + statelessResetToken := statelessResetter.GetStatelessResetToken(srcConnID) params := &wire.TransportParameters{ InitialMaxStreamDataBidiLocal: protocol.ByteCount(s.config.InitialStreamReceiveWindow), InitialMaxStreamDataBidiRemote: protocol.ByteCount(s.config.InitialStreamReceiveWindow), @@ -340,6 +340,7 @@ var newClientConnection = func( destConnID protocol.ConnectionID, srcConnID protocol.ConnectionID, connIDGenerator ConnectionIDGenerator, + statelessResetter *statelessResetter, conf *Config, tlsConf *tls.Config, initialPacketNumber protocol.PacketNumber, @@ -372,7 +373,7 @@ var newClientConnection = func( srcConnID, nil, func(connID protocol.ConnectionID) { runner.Add(connID, s) }, - runner.GetStatelessResetToken, + statelessResetter, runner.Remove, runner.Retire, runner.ReplaceWithClosed, diff --git a/connection_test.go b/connection_test.go index 5da85810..d4ae04aa 100644 --- a/connection_test.go +++ b/connection_test.go @@ -125,7 +125,7 @@ func newServerTestConnection( protocol.ConnectionID{}, srcConnID, &protocol.DefaultConnectionIDGenerator{}, - protocol.StatelessResetToken{}, + newStatelessResetter(nil), populateConfig(config), &tls.Config{}, handshake.NewTokenGenerator(handshake.TokenProtectorKey{}), @@ -180,6 +180,7 @@ func newClientTestConnection( destConnID, srcConnID, &protocol.DefaultConnectionIDGenerator{}, + newStatelessResetter(nil), populateConfig(config), &tls.Config{ServerName: "quic-go.net"}, 0, diff --git a/mock_conn_runner_test.go b/mock_conn_runner_test.go index ef1e170a..3db6a9b9 100644 --- a/mock_conn_runner_test.go +++ b/mock_conn_runner_test.go @@ -114,44 +114,6 @@ func (c *MockConnRunnerAddResetTokenCall) DoAndReturn(f func(protocol.StatelessR return c } -// GetStatelessResetToken mocks base method. -func (m *MockConnRunner) GetStatelessResetToken(arg0 protocol.ConnectionID) protocol.StatelessResetToken { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetStatelessResetToken", arg0) - ret0, _ := ret[0].(protocol.StatelessResetToken) - return ret0 -} - -// GetStatelessResetToken indicates an expected call of GetStatelessResetToken. -func (mr *MockConnRunnerMockRecorder) GetStatelessResetToken(arg0 any) *MockConnRunnerGetStatelessResetTokenCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStatelessResetToken", reflect.TypeOf((*MockConnRunner)(nil).GetStatelessResetToken), arg0) - return &MockConnRunnerGetStatelessResetTokenCall{Call: call} -} - -// MockConnRunnerGetStatelessResetTokenCall wrap *gomock.Call -type MockConnRunnerGetStatelessResetTokenCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockConnRunnerGetStatelessResetTokenCall) Return(arg0 protocol.StatelessResetToken) *MockConnRunnerGetStatelessResetTokenCall { - c.Call = c.Call.Return(arg0) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockConnRunnerGetStatelessResetTokenCall) Do(f func(protocol.ConnectionID) protocol.StatelessResetToken) *MockConnRunnerGetStatelessResetTokenCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockConnRunnerGetStatelessResetTokenCall) DoAndReturn(f func(protocol.ConnectionID) protocol.StatelessResetToken) *MockConnRunnerGetStatelessResetTokenCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - // Remove mocks base method. func (m *MockConnRunner) Remove(arg0 protocol.ConnectionID) { m.ctrl.T.Helper() diff --git a/mock_packet_handler_manager_test.go b/mock_packet_handler_manager_test.go index f99d909c..a51e590b 100644 --- a/mock_packet_handler_manager_test.go +++ b/mock_packet_handler_manager_test.go @@ -266,44 +266,6 @@ func (c *MockPacketHandlerManagerGetByResetTokenCall) DoAndReturn(f func(protoco return c } -// GetStatelessResetToken mocks base method. -func (m *MockPacketHandlerManager) GetStatelessResetToken(arg0 protocol.ConnectionID) protocol.StatelessResetToken { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetStatelessResetToken", arg0) - ret0, _ := ret[0].(protocol.StatelessResetToken) - return ret0 -} - -// GetStatelessResetToken indicates an expected call of GetStatelessResetToken. -func (mr *MockPacketHandlerManagerMockRecorder) GetStatelessResetToken(arg0 any) *MockPacketHandlerManagerGetStatelessResetTokenCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStatelessResetToken", reflect.TypeOf((*MockPacketHandlerManager)(nil).GetStatelessResetToken), arg0) - return &MockPacketHandlerManagerGetStatelessResetTokenCall{Call: call} -} - -// MockPacketHandlerManagerGetStatelessResetTokenCall wrap *gomock.Call -type MockPacketHandlerManagerGetStatelessResetTokenCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockPacketHandlerManagerGetStatelessResetTokenCall) Return(arg0 protocol.StatelessResetToken) *MockPacketHandlerManagerGetStatelessResetTokenCall { - c.Call = c.Call.Return(arg0) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockPacketHandlerManagerGetStatelessResetTokenCall) Do(f func(protocol.ConnectionID) protocol.StatelessResetToken) *MockPacketHandlerManagerGetStatelessResetTokenCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockPacketHandlerManagerGetStatelessResetTokenCall) DoAndReturn(f func(protocol.ConnectionID) protocol.StatelessResetToken) *MockPacketHandlerManagerGetStatelessResetTokenCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - // Remove mocks base method. func (m *MockPacketHandlerManager) Remove(arg0 protocol.ConnectionID) { m.ctrl.T.Helper() diff --git a/packet_handler_map.go b/packet_handler_map.go index 7840202c..84841984 100644 --- a/packet_handler_map.go +++ b/packet_handler_map.go @@ -1,10 +1,6 @@ package quic import ( - "crypto/hmac" - "crypto/rand" - "crypto/sha256" - "hash" "io" "net" "sync" @@ -56,15 +52,12 @@ type packetHandlerMap struct { deleteRetiredConnsAfter time.Duration - statelessResetMutex sync.Mutex - statelessResetHasher hash.Hash - logger utils.Logger } var _ packetHandlerManager = &packetHandlerMap{} -func newPacketHandlerMap(key *StatelessResetKey, enqueueClosePacket func(closePacket), logger utils.Logger) *packetHandlerMap { +func newPacketHandlerMap(enqueueClosePacket func(closePacket), logger utils.Logger) *packetHandlerMap { h := &packetHandlerMap{ closeChan: make(chan struct{}), handlers: make(map[protocol.ConnectionID]packetHandler), @@ -73,9 +66,6 @@ func newPacketHandlerMap(key *StatelessResetKey, enqueueClosePacket func(closePa enqueueClosePacket: enqueueClosePacket, logger: logger, } - if key != nil { - h.statelessResetHasher = hmac.New(sha256.New, key[:]) - } if h.logger.Debug() { go h.logUsage() } @@ -236,20 +226,3 @@ func (h *packetHandlerMap) Close(e error) { h.mutex.Unlock() wg.Wait() } - -func (h *packetHandlerMap) GetStatelessResetToken(connID protocol.ConnectionID) protocol.StatelessResetToken { - var token protocol.StatelessResetToken - if h.statelessResetHasher == nil { - // Return a random stateless reset token. - // This token will be sent in the server's transport parameters. - // By using a random token, an off-path attacker won't be able to disrupt the connection. - rand.Read(token[:]) - return token - } - h.statelessResetMutex.Lock() - h.statelessResetHasher.Write(connID.Bytes()) - copy(token[:], h.statelessResetHasher.Sum(nil)) - h.statelessResetHasher.Reset() - h.statelessResetMutex.Unlock() - return token -} diff --git a/packet_handler_map_test.go b/packet_handler_map_test.go index 0b916a9a..353e2bff 100644 --- a/packet_handler_map_test.go +++ b/packet_handler_map_test.go @@ -14,7 +14,7 @@ import ( ) func TestPacketHandlerMapAddAndRemove(t *testing.T) { - m := newPacketHandlerMap(nil, nil, utils.DefaultLogger) + m := newPacketHandlerMap(nil, utils.DefaultLogger) connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4}) h := &mockPacketHandler{} require.True(t, m.Add(connID, h)) @@ -36,7 +36,7 @@ func TestPacketHandlerMapAddAndRemove(t *testing.T) { } func TestPacketHandlerMapAddWithClientChosenConnID(t *testing.T) { - m := newPacketHandlerMap(nil, nil, utils.DefaultLogger) + m := newPacketHandlerMap(nil, utils.DefaultLogger) h := &mockPacketHandler{} connID1 := protocol.ParseConnectionID([]byte{1, 2, 3, 4}) @@ -54,7 +54,7 @@ func TestPacketHandlerMapAddWithClientChosenConnID(t *testing.T) { } func TestPacketHandlerMapRetire(t *testing.T) { - m := newPacketHandlerMap(nil, nil, utils.DefaultLogger) + m := newPacketHandlerMap(nil, utils.DefaultLogger) dur := scaleDuration(10 * time.Millisecond) m.deleteRetiredConnsAfter = dur connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4}) @@ -76,7 +76,7 @@ func TestPacketHandlerMapRetire(t *testing.T) { } func TestPacketHandlerMapAddGetRemoveResetTokens(t *testing.T) { - m := newPacketHandlerMap(nil, nil, utils.DefaultLogger) + m := newPacketHandlerMap(nil, utils.DefaultLogger) token := protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf} handler := &mockPacketHandler{} m.AddResetToken(token, handler) @@ -88,43 +88,12 @@ func TestPacketHandlerMapAddGetRemoveResetTokens(t *testing.T) { require.False(t, ok) } -func TestPacketHandlerMapGenerateStatelessResetToken(t *testing.T) { - t.Run("no key", func(t *testing.T) { - m := newPacketHandlerMap(nil, nil, utils.DefaultLogger) - b := make([]byte, 8) - rand.Read(b) - connID := protocol.ParseConnectionID(b) - tokens := make(map[protocol.StatelessResetToken]struct{}) - for i := 0; i < 100; i++ { - token := m.GetStatelessResetToken(connID) - require.NotZero(t, token) - if _, ok := tokens[token]; ok { - t.Fatalf("token %s already exists", token) - } - tokens[token] = struct{}{} - } - }) - - t.Run("with key", func(t *testing.T) { - var key StatelessResetKey - rand.Read(key[:]) - m := newPacketHandlerMap(&key, nil, utils.DefaultLogger) - b := make([]byte, 8) - rand.Read(b) - connID := protocol.ParseConnectionID(b) - token := m.GetStatelessResetToken(connID) - require.NotZero(t, token) - require.Equal(t, token, m.GetStatelessResetToken(connID)) - // generate a new connection ID - rand.Read(b) - connID2 := protocol.ParseConnectionID(b) - require.NotEqual(t, token, m.GetStatelessResetToken(connID2)) - }) -} - func TestPacketHandlerMapReplaceWithLocalClosed(t *testing.T) { var closePackets []closePacket - m := newPacketHandlerMap(nil, func(p closePacket) { closePackets = append(closePackets, p) }, utils.DefaultLogger) + m := newPacketHandlerMap( + func(p closePacket) { closePackets = append(closePackets, p) }, + utils.DefaultLogger, + ) dur := scaleDuration(10 * time.Millisecond) m.deleteRetiredConnsAfter = dur @@ -150,7 +119,10 @@ func TestPacketHandlerMapReplaceWithLocalClosed(t *testing.T) { func TestPacketHandlerMapReplaceWithRemoteClosed(t *testing.T) { var closePackets []closePacket - m := newPacketHandlerMap(nil, func(p closePacket) { closePackets = append(closePackets, p) }, utils.DefaultLogger) + m := newPacketHandlerMap( + func(p closePacket) { closePackets = append(closePackets, p) }, + utils.DefaultLogger, + ) dur := scaleDuration(50 * time.Millisecond) m.deleteRetiredConnsAfter = dur @@ -173,7 +145,7 @@ func TestPacketHandlerMapReplaceWithRemoteClosed(t *testing.T) { } func TestPacketHandlerMapClose(t *testing.T) { - m := newPacketHandlerMap(nil, nil, utils.DefaultLogger) + m := newPacketHandlerMap(nil, utils.DefaultLogger) testErr := errors.New("shutdown") const numConns = 10 destroyChan := make(chan error, 2*numConns) diff --git a/server.go b/server.go index cece797b..2bb821ab 100644 --- a/server.go +++ b/server.go @@ -72,9 +72,10 @@ type baseServer struct { tokenGenerator *handshake.TokenGenerator maxTokenAge time.Duration - connIDGenerator ConnectionIDGenerator - connHandler packetHandlerManager - onClose func() + connIDGenerator ConnectionIDGenerator + statelessResetter *statelessResetter + connHandler packetHandlerManager + onClose func() receivedPackets chan receivedPacket @@ -95,7 +96,7 @@ type baseServer struct { protocol.ConnectionID, /* destination connection ID */ protocol.ConnectionID, /* source connection ID */ ConnectionIDGenerator, - protocol.StatelessResetToken, + *statelessResetter, *Config, *tls.Config, *handshake.TokenGenerator, @@ -248,6 +249,7 @@ func newServer( conn rawConn, connHandler packetHandlerManager, connIDGenerator ConnectionIDGenerator, + statelessResetter *statelessResetter, connContext func(context.Context) context.Context, tlsConf *tls.Config, config *Config, @@ -268,6 +270,7 @@ func newServer( maxTokenAge: maxTokenAge, verifySourceAddress: verifySourceAddress, connIDGenerator: connIDGenerator, + statelessResetter: statelessResetter, connHandler: connHandler, connQueue: make(chan quicConn, protocol.MaxAcceptQueueSize), errorChan: make(chan struct{}), @@ -707,7 +710,7 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error hdr.SrcConnectionID, connID, s.connIDGenerator, - s.connHandler.GetStatelessResetToken(connID), + s.statelessResetter, config, s.tlsConf, s.tokenGenerator, diff --git a/server_test.go b/server_test.go index be93ee20..64a16c51 100644 --- a/server_test.go +++ b/server_test.go @@ -297,7 +297,7 @@ var _ = Describe("Server", func() { destConnID protocol.ConnectionID, srcConnID protocol.ConnectionID, _ ConnectionIDGenerator, - tokenP protocol.StatelessResetToken, + _ *statelessResetter, _ *Config, _ *tls.Config, _ *handshake.TokenGenerator, @@ -314,7 +314,6 @@ var _ = Describe("Server", func() { Expect(srcConnID).ToNot(Equal(hdr.DestConnectionID)) Expect(srcConnID).ToNot(Equal(hdr.SrcConnectionID)) newConnID = srcConnID - Expect(tokenP).To(Equal(token)) conn.EXPECT().handlePacket(p) conn.EXPECT().run().Do(func() error { close(run); return nil }) conn.EXPECT().Context().Return(context.Background()) @@ -322,7 +321,6 @@ var _ = Describe("Server", func() { return conn } phm.EXPECT().Get(connID) - phm.EXPECT().GetStatelessResetToken(gomock.Any()).Return(token) phm.EXPECT().AddWithConnID(connID, gomock.Any(), gomock.Any()).DoAndReturn(func(_, cid protocol.ConnectionID, h packetHandler) bool { Expect(cid).To(Equal(newConnID)) return true @@ -500,7 +498,7 @@ var _ = Describe("Server", func() { destConnID protocol.ConnectionID, srcConnID protocol.ConnectionID, _ ConnectionIDGenerator, - tokenP protocol.StatelessResetToken, + _ *statelessResetter, _ *Config, _ *tls.Config, _ *handshake.TokenGenerator, @@ -517,7 +515,6 @@ var _ = Describe("Server", func() { Expect(srcConnID).ToNot(Equal(hdr.DestConnectionID)) Expect(srcConnID).ToNot(Equal(hdr.SrcConnectionID)) newConnID = srcConnID - Expect(tokenP).To(Equal(token)) conn.EXPECT().handlePacket(p) conn.EXPECT().run().Do(func() error { close(run); return nil }) conn.EXPECT().Context().Return(context.Background()) @@ -526,7 +523,6 @@ var _ = Describe("Server", func() { } gomock.InOrder( phm.EXPECT().Get(connID), - phm.EXPECT().GetStatelessResetToken(gomock.Any()).Return(token), phm.EXPECT().AddWithConnID(connID, gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, h packetHandler) bool { Expect(c).To(Equal(newConnID)) return true @@ -553,7 +549,6 @@ var _ = Describe("Server", func() { serv.verifySourceAddress = func(net.Addr) bool { return false } phm.EXPECT().Get(gomock.Any()).AnyTimes() - phm.EXPECT().GetStatelessResetToken(gomock.Any()).AnyTimes() phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true).AnyTimes() acceptConn := make(chan struct{}) @@ -569,7 +564,7 @@ var _ = Describe("Server", func() { _ protocol.ConnectionID, _ protocol.ConnectionID, _ ConnectionIDGenerator, - _ protocol.StatelessResetToken, + _ *statelessResetter, _ *Config, _ *tls.Config, _ *handshake.TokenGenerator, @@ -625,7 +620,7 @@ var _ = Describe("Server", func() { _ protocol.ConnectionID, _ protocol.ConnectionID, _ ConnectionIDGenerator, - _ protocol.StatelessResetToken, + _ *statelessResetter, _ *Config, _ *tls.Config, _ *handshake.TokenGenerator, @@ -645,7 +640,6 @@ var _ = Describe("Server", func() { connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9}) p := getInitial(connID) phm.EXPECT().Get(connID) - phm.EXPECT().GetStatelessResetToken(gomock.Any()) phm.EXPECT().AddWithConnID(connID, gomock.Any(), gomock.Any()).Return(false) // connection ID collision Expect(serv.handlePacketImpl(p)).To(BeTrue()) Eventually(done).Should(BeClosed()) @@ -657,7 +651,6 @@ var _ = Describe("Server", func() { serv.verifySourceAddress = func(net.Addr) bool { return !limiter.Allow() } phm.EXPECT().Get(gomock.Any()).AnyTimes() - phm.EXPECT().GetStatelessResetToken(gomock.Any()).AnyTimes() phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true).AnyTimes() connChan := make(chan *MockQUICConn, 1) @@ -675,7 +668,7 @@ var _ = Describe("Server", func() { _ protocol.ConnectionID, _ protocol.ConnectionID, _ ConnectionIDGenerator, - _ protocol.StatelessResetToken, + _ *statelessResetter, _ *Config, _ *tls.Config, _ *handshake.TokenGenerator, @@ -737,7 +730,7 @@ var _ = Describe("Server", func() { _ protocol.ConnectionID, _ protocol.ConnectionID, _ ConnectionIDGenerator, - _ protocol.StatelessResetToken, + _ *statelessResetter, _ *Config, _ *tls.Config, _ *handshake.TokenGenerator, @@ -769,7 +762,6 @@ var _ = Describe("Server", func() { done := make(chan struct{}) phm.EXPECT().Get(gomock.Any()) - phm.EXPECT().GetStatelessResetToken(gomock.Any()) phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, _ packetHandler) bool { close(done) return true @@ -972,7 +964,7 @@ var _ = Describe("Server", func() { _ protocol.ConnectionID, _ protocol.ConnectionID, _ ConnectionIDGenerator, - _ protocol.StatelessResetToken, + _ *statelessResetter, conf *Config, _ *tls.Config, _ *handshake.TokenGenerator, @@ -990,7 +982,6 @@ var _ = Describe("Server", func() { return conn } phm.EXPECT().Get(gomock.Any()) - phm.EXPECT().GetStatelessResetToken(gomock.Any()) phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true) serv.handleInitialImpl( receivedPacket{buffer: getPacketBuffer()}, @@ -1040,7 +1031,7 @@ var _ = Describe("Server", func() { _ protocol.ConnectionID, _ protocol.ConnectionID, _ ConnectionIDGenerator, - _ protocol.StatelessResetToken, + _ *statelessResetter, conf *Config, _ *tls.Config, _ *handshake.TokenGenerator, @@ -1057,7 +1048,6 @@ var _ = Describe("Server", func() { return conn } phm.EXPECT().Get(gomock.Any()) - phm.EXPECT().GetStatelessResetToken(gomock.Any()) phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true) serv.handleInitialImpl( receivedPacket{buffer: getPacketBuffer()}, @@ -1111,7 +1101,7 @@ var _ = Describe("Server", func() { _ protocol.ConnectionID, _ protocol.ConnectionID, _ ConnectionIDGenerator, - _ protocol.StatelessResetToken, + _ *statelessResetter, _ *Config, _ *tls.Config, _ *handshake.TokenGenerator, @@ -1127,7 +1117,6 @@ var _ = Describe("Server", func() { return conn } phm.EXPECT().Get(gomock.Any()) - phm.EXPECT().GetStatelessResetToken(gomock.Any()) phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true) serv.handleInitialImpl( receivedPacket{buffer: getPacketBuffer()}, @@ -1182,7 +1171,7 @@ var _ = Describe("Server", func() { _ protocol.ConnectionID, _ protocol.ConnectionID, _ ConnectionIDGenerator, - _ protocol.StatelessResetToken, + _ *statelessResetter, _ *Config, _ *tls.Config, _ *handshake.TokenGenerator, @@ -1198,7 +1187,6 @@ var _ = Describe("Server", func() { return conn } phm.EXPECT().Get(gomock.Any()) - phm.EXPECT().GetStatelessResetToken(gomock.Any()) phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true) serv.baseServer.handleInitialImpl( receivedPacket{buffer: getPacketBuffer()}, @@ -1224,7 +1212,7 @@ var _ = Describe("Server", func() { _ protocol.ConnectionID, _ protocol.ConnectionID, _ ConnectionIDGenerator, - _ protocol.StatelessResetToken, + _ *statelessResetter, _ *Config, _ *tls.Config, _ *handshake.TokenGenerator, @@ -1244,7 +1232,6 @@ var _ = Describe("Server", func() { } phm.EXPECT().Get(gomock.Any()).AnyTimes() - phm.EXPECT().GetStatelessResetToken(gomock.Any()).Times(protocol.MaxAcceptQueueSize) phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true).Times(protocol.MaxAcceptQueueSize) for i := 0; i < protocol.MaxAcceptQueueSize; i++ { conn := NewMockQUICConn(mockCtrl) @@ -1257,7 +1244,6 @@ var _ = Describe("Server", func() { wg.Add(1) rejected := make(chan struct{}) - phm.EXPECT().GetStatelessResetToken(gomock.Any()) phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true) conn := NewMockQUICConn(mockCtrl) conn.EXPECT().closeWithTransportError(ConnectionRefused).Do(func(qerr.TransportErrorCode) { @@ -1284,7 +1270,7 @@ var _ = Describe("Server", func() { _ protocol.ConnectionID, _ protocol.ConnectionID, _ ConnectionIDGenerator, - _ protocol.StatelessResetToken, + _ *statelessResetter, _ *Config, _ *tls.Config, _ *handshake.TokenGenerator, @@ -1302,7 +1288,6 @@ var _ = Describe("Server", func() { } phm.EXPECT().Get(gomock.Any()) - phm.EXPECT().GetStatelessResetToken(gomock.Any()) phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true) serv.baseServer.handlePacket(p) // make sure there are no Write calls on the packet conn @@ -1407,7 +1392,7 @@ var _ = Describe("Server", func() { _ protocol.ConnectionID, _ protocol.ConnectionID, _ ConnectionIDGenerator, - _ protocol.StatelessResetToken, + _ *statelessResetter, _ *Config, _ *tls.Config, _ *handshake.TokenGenerator, @@ -1433,7 +1418,6 @@ var _ = Describe("Server", func() { } phm.EXPECT().Get(connID) - phm.EXPECT().GetStatelessResetToken(gomock.Any()) phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true) serv.handlePacket(initial) Eventually(called).Should(BeClosed()) diff --git a/stateless_reset.go b/stateless_reset.go new file mode 100644 index 00000000..cd0059a5 --- /dev/null +++ b/stateless_reset.go @@ -0,0 +1,42 @@ +package quic + +import ( + "crypto/hmac" + "crypto/rand" + "crypto/sha256" + "hash" + "sync" + + "github.com/quic-go/quic-go/internal/protocol" +) + +type statelessResetter struct { + mx sync.Mutex + h hash.Hash +} + +// newStatelessRetter creates a new stateless reset generator. +// It is valid to use a nil key. In that case, a random key will be used. +// This makes is impossible for on-path attackers to shut down established connections. +func newStatelessResetter(key *StatelessResetKey) *statelessResetter { + var h hash.Hash + if key != nil { + h = hmac.New(sha256.New, key[:]) + } else { + b := make([]byte, 32) + _, _ = rand.Read(b) + h = hmac.New(sha256.New, b) + } + return &statelessResetter{h: h} +} + +func (r *statelessResetter) GetStatelessResetToken(connID protocol.ConnectionID) protocol.StatelessResetToken { + r.mx.Lock() + defer r.mx.Unlock() + + var token protocol.StatelessResetToken + r.h.Write(connID.Bytes()) + copy(token[:], r.h.Sum(nil)) + r.h.Reset() + return token +} diff --git a/stateless_reset_test.go b/stateless_reset_test.go new file mode 100644 index 00000000..39c5f2f3 --- /dev/null +++ b/stateless_reset_test.go @@ -0,0 +1,42 @@ +package quic + +import ( + "crypto/rand" + "testing" + + "github.com/quic-go/quic-go/internal/protocol" + "github.com/stretchr/testify/require" +) + +func TestStatelessResetter(t *testing.T) { + t.Run("no key", func(t *testing.T) { + r1 := newStatelessResetter(nil) + r2 := newStatelessResetter(nil) + for i := 0; i < 100; i++ { + b := make([]byte, 15) + rand.Read(b) + connID := protocol.ParseConnectionID(b) + t1 := r1.GetStatelessResetToken(connID) + t2 := r2.GetStatelessResetToken(connID) + require.NotZero(t, t1) + require.NotZero(t, t2) + require.NotEqual(t, t1, t2) + } + }) + + t.Run("with key", func(t *testing.T) { + var key StatelessResetKey + rand.Read(key[:]) + m := newStatelessResetter(&key) + b := make([]byte, 8) + rand.Read(b) + connID := protocol.ParseConnectionID(b) + token := m.GetStatelessResetToken(connID) + require.NotZero(t, token) + require.Equal(t, token, m.GetStatelessResetToken(connID)) + // generate a new connection ID + rand.Read(b) + connID2 := protocol.ParseConnectionID(b) + require.NotEqual(t, token, m.GetStatelessResetToken(connID2)) + }) +} diff --git a/transport.go b/transport.go index dd69ded5..32867550 100644 --- a/transport.go +++ b/transport.go @@ -115,7 +115,8 @@ type Transport struct { connIDLen int // Set in init. // If no ConnectionIDGenerator is set, this is set to a default. - connIDGenerator ConnectionIDGenerator + connIDGenerator ConnectionIDGenerator + statelessResetter *statelessResetter server *baseServer @@ -183,6 +184,7 @@ func (t *Transport) createServer(tlsConf *tls.Config, conf *Config, allow0RTT bo t.conn, t.handlerMap, t.connIDGenerator, + t.statelessResetter, t.ConnContext, tlsConf, conf, @@ -222,7 +224,17 @@ func (t *Transport) dial(ctx context.Context, addr net.Addr, host string, tlsCon } tlsConf = tlsConf.Clone() setTLSConfigServerName(tlsConf, addr, host) - return dial(ctx, newSendConn(t.conn, addr, packetInfo{}, utils.DefaultLogger), t.connIDGenerator, t.handlerMap, tlsConf, conf, onClose, use0RTT) + return dial( + ctx, + newSendConn(t.conn, addr, packetInfo{}, utils.DefaultLogger), + t.connIDGenerator, + t.statelessResetter, + t.handlerMap, + tlsConf, + conf, + onClose, + use0RTT, + ) } func (t *Transport) init(allowZeroLengthConnIDs bool) error { @@ -242,7 +254,7 @@ func (t *Transport) init(allowZeroLengthConnIDs bool) error { t.logger = utils.DefaultLogger // TODO: make this configurable t.conn = conn if t.handlerMap == nil { // allows mocking the handlerMap in tests - t.handlerMap = newPacketHandlerMap(t.StatelessResetKey, t.enqueueClosePacket, t.logger) + t.handlerMap = newPacketHandlerMap(t.enqueueClosePacket, t.logger) } t.listening = make(chan struct{}) @@ -268,6 +280,7 @@ func (t *Transport) init(allowZeroLengthConnIDs bool) error { t.connIDLen = connIDLen t.connIDGenerator = &protocol.DefaultConnectionIDGenerator{ConnLen: t.connIDLen} } + t.statelessResetter = newStatelessResetter(t.StatelessResetKey) go t.listen(conn) go t.runSendQueue() @@ -478,7 +491,7 @@ func (t *Transport) sendStatelessReset(p receivedPacket) { t.logger.Errorf("error parsing connection ID on packet from %s: %s", p.remoteAddr, err) return } - token := t.handlerMap.GetStatelessResetToken(connID) + token := t.statelessResetter.GetStatelessResetToken(connID) t.logger.Debugf("Sending stateless reset to %s (connection ID: %s). Token: %#x", p.remoteAddr, connID, token) data := make([]byte, protocol.MinStatelessResetSize-16, protocol.MinStatelessResetSize) rand.Read(data) diff --git a/transport_test.go b/transport_test.go index ccbebd34..4b4f54b4 100644 --- a/transport_test.go +++ b/transport_test.go @@ -254,8 +254,6 @@ func TestTransportStatelessResetSending(t *testing.T) { // but a stateless reset is sent for packets larger than MinStatelessResetSize phm.EXPECT().Get(connID) // no handler for this connection ID phm.EXPECT().GetByResetToken(gomock.Any()) - token := protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} - phm.EXPECT().GetStatelessResetToken(connID).Return(token) _, err = conn.WriteTo(append(b, make([]byte, protocol.MinStatelessResetSize-len(b)+1)...), tr.Conn.LocalAddr()) require.NoError(t, err) conn.SetReadDeadline(time.Now().Add(time.Second)) @@ -263,7 +261,8 @@ func TestTransportStatelessResetSending(t *testing.T) { n, addr, err := conn.ReadFrom(p) require.NoError(t, err) require.Equal(t, addr, tr.Conn.LocalAddr()) - require.Contains(t, string(p[:n]), string(token[:])) + srt := newStatelessResetter(tr.StatelessResetKey).GetStatelessResetToken(connID) + require.Contains(t, string(p[:n]), string(srt[:])) } func TestTransportDropsUnparseableQUICPackets(t *testing.T) {