simplify generation of stateless reset tokens (#4858)

This commit is contained in:
Marten Seemann 2025-01-11 01:52:59 -08:00 committed by GitHub
parent 9950b4c687
commit 62947d97f5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 224 additions and 233 deletions

View file

@ -23,6 +23,7 @@ type client struct {
config *Config
connIDGenerator ConnectionIDGenerator
statelessResetter *statelessResetter
srcConnID protocol.ConnectionID
destConnID protocol.ConnectionID
@ -137,13 +138,14 @@ func dial(
ctx context.Context,
conn sendConn,
connIDGenerator ConnectionIDGenerator,
statelessResetter *statelessResetter,
packetHandlers packetHandlerManager,
tlsConf *tls.Config,
config *Config,
onClose func(),
use0RTT bool,
) (quicConn, error) {
c, err := newClient(conn, connIDGenerator, config, tlsConf, onClose, use0RTT)
c, err := newClient(conn, connIDGenerator, statelessResetter, config, tlsConf, onClose, use0RTT)
if err != nil {
return nil, err
}
@ -162,7 +164,15 @@ func dial(
return c.conn, nil
}
func newClient(sendConn sendConn, connIDGenerator ConnectionIDGenerator, config *Config, tlsConf *tls.Config, onClose func(), use0RTT bool) (*client, error) {
func newClient(
sendConn sendConn,
connIDGenerator ConnectionIDGenerator,
statelessResetter *statelessResetter,
config *Config,
tlsConf *tls.Config,
onClose func(),
use0RTT bool,
) (*client, error) {
srcConnID, err := connIDGenerator.GenerateConnectionID()
if err != nil {
return nil, err
@ -173,6 +183,7 @@ func newClient(sendConn sendConn, connIDGenerator ConnectionIDGenerator, config
}
c := &client{
connIDGenerator: connIDGenerator,
statelessResetter: statelessResetter,
srcConnID: srcConnID,
destConnID: destConnID,
sendConn: sendConn,
@ -197,6 +208,7 @@ func (c *client) dial(ctx context.Context) error {
c.destConnID,
c.srcConnID,
c.connIDGenerator,
c.statelessResetter,
c.config,
c.tlsConf,
c.initialPacketNumber,

View file

@ -33,6 +33,7 @@ var _ = Describe("Client", func() {
destConnID protocol.ConnectionID,
srcConnID protocol.ConnectionID,
connIDGenerator ConnectionIDGenerator,
statelessResetToken *statelessResetter,
conf *Config,
tlsConf *tls.Config,
initialPacketNumber protocol.PacketNumber,
@ -107,6 +108,7 @@ var _ = Describe("Client", func() {
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ ConnectionIDGenerator,
_ *statelessResetter,
_ *Config,
_ *tls.Config,
_ protocol.PacketNumber,
@ -124,7 +126,15 @@ var _ = Describe("Client", func() {
conn.EXPECT().HandshakeComplete().Return(c)
return conn
}
cl, err := newClient(packetConn, &protocol.DefaultConnectionIDGenerator{}, populateConfig(config), tlsConf, nil, false)
cl, err := newClient(
packetConn,
&protocol.DefaultConnectionIDGenerator{},
newStatelessResetter(nil),
populateConfig(config),
tlsConf,
nil,
false,
)
Expect(err).ToNot(HaveOccurred())
cl.packetHandlers = manager
Expect(cl).ToNot(BeNil())
@ -144,6 +154,7 @@ var _ = Describe("Client", func() {
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ ConnectionIDGenerator,
_ *statelessResetter,
_ *Config,
_ *tls.Config,
_ protocol.PacketNumber,
@ -161,7 +172,15 @@ var _ = Describe("Client", func() {
return conn
}
cl, err := newClient(packetConn, &protocol.DefaultConnectionIDGenerator{}, populateConfig(config), tlsConf, nil, true)
cl, err := newClient(
packetConn,
&protocol.DefaultConnectionIDGenerator{},
newStatelessResetter(nil),
populateConfig(config),
tlsConf,
nil,
true,
)
Expect(err).ToNot(HaveOccurred())
cl.packetHandlers = manager
Expect(cl).ToNot(BeNil())
@ -181,6 +200,7 @@ var _ = Describe("Client", func() {
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ ConnectionIDGenerator,
_ *statelessResetter,
_ *Config,
_ *tls.Config,
_ protocol.PacketNumber,
@ -197,7 +217,13 @@ var _ = Describe("Client", func() {
return conn
}
var closed bool
cl, err := newClient(packetConn, &protocol.DefaultConnectionIDGenerator{}, populateConfig(config), tlsConf, func() { closed = true }, true)
cl, err := newClient(
packetConn,
&protocol.DefaultConnectionIDGenerator{},
newStatelessResetter(nil),
populateConfig(config), tlsConf, func() { closed = true },
true,
)
Expect(err).ToNot(HaveOccurred())
cl.packetHandlers = manager
Expect(cl).ToNot(BeNil())
@ -266,6 +292,7 @@ var _ = Describe("Client", func() {
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ ConnectionIDGenerator,
_ *statelessResetter,
configP *Config,
_ *tls.Config,
_ protocol.PacketNumber,
@ -309,6 +336,7 @@ var _ = Describe("Client", func() {
_ protocol.ConnectionID,
connID protocol.ConnectionID,
_ ConnectionIDGenerator,
_ *statelessResetter,
configP *Config,
_ *tls.Config,
pn protocol.PacketNumber,

View file

@ -16,7 +16,7 @@ type connIDGenerator struct {
initialClientDestConnID *protocol.ConnectionID // nil for the client
addConnectionID func(protocol.ConnectionID)
getStatelessResetToken func(protocol.ConnectionID) protocol.StatelessResetToken
statelessResetter *statelessResetter
removeConnectionID func(protocol.ConnectionID)
retireConnectionID func(protocol.ConnectionID)
replaceWithClosed func([]protocol.ConnectionID, []byte)
@ -27,7 +27,7 @@ func newConnIDGenerator(
initialConnectionID protocol.ConnectionID,
initialClientDestConnID *protocol.ConnectionID, // nil for the client
addConnectionID func(protocol.ConnectionID),
getStatelessResetToken func(protocol.ConnectionID) protocol.StatelessResetToken,
statelessResetter *statelessResetter,
removeConnectionID func(protocol.ConnectionID),
retireConnectionID func(protocol.ConnectionID),
replaceWithClosed func([]protocol.ConnectionID, []byte),
@ -38,7 +38,7 @@ func newConnIDGenerator(
generator: generator,
activeSrcConnIDs: make(map[uint64]protocol.ConnectionID),
addConnectionID: addConnectionID,
getStatelessResetToken: getStatelessResetToken,
statelessResetter: statelessResetter,
removeConnectionID: removeConnectionID,
retireConnectionID: retireConnectionID,
replaceWithClosed: replaceWithClosed,
@ -104,7 +104,7 @@ func (m *connIDGenerator) issueNewConnID() error {
m.queueControlFrame(&wire.NewConnectionIDFrame{
SequenceNumber: m.highestSeq + 1,
ConnectionID: connID,
StatelessResetToken: m.getStatelessResetToken(connID),
StatelessResetToken: m.statelessResetter.GetStatelessResetToken(connID),
})
m.highestSeq++
return nil

View file

@ -19,14 +19,11 @@ var _ = Describe("Connection ID Generator", func() {
replacedWithClosed []protocol.ConnectionID
queuedFrames []wire.Frame
g *connIDGenerator
statelessResetter *statelessResetter
)
initialConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7})
initialClientDestConnID := protocol.ParseConnectionID([]byte{0xa, 0xb, 0xc, 0xd, 0xe})
connIDToToken := func(c protocol.ConnectionID) protocol.StatelessResetToken {
b := c.Bytes()[0]
return protocol.StatelessResetToken{b, b, b, b, b, b, b, b, b, b, b, b, b, b, b, b}
}
statelessResetter = newStatelessResetter(nil)
BeforeEach(func() {
addedConnIDs = nil
@ -38,7 +35,7 @@ var _ = Describe("Connection ID Generator", func() {
initialConnID,
&initialClientDestConnID,
func(c protocol.ConnectionID) { addedConnIDs = append(addedConnIDs, c) },
connIDToToken,
statelessResetter,
func(c protocol.ConnectionID) { removedConnIDs = append(removedConnIDs, c) },
func(c protocol.ConnectionID) { retiredConnIDs = append(retiredConnIDs, c) },
func(cs []protocol.ConnectionID, _ []byte) { replacedWithClosed = append(replacedWithClosed, cs...) },
@ -61,7 +58,7 @@ var _ = Describe("Connection ID Generator", func() {
nf := f.(*wire.NewConnectionIDFrame)
Expect(nf.SequenceNumber).To(BeEquivalentTo(i + 1))
Expect(nf.ConnectionID.Len()).To(Equal(7))
Expect(nf.StatelessResetToken).To(Equal(connIDToToken(nf.ConnectionID)))
Expect(nf.StatelessResetToken).To(Equal(statelessResetter.GetStatelessResetToken(nf.ConnectionID)))
}
})

View file

@ -85,7 +85,6 @@ func (p *receivedPacket) Clone() *receivedPacket {
type connRunner interface {
Add(protocol.ConnectionID, packetHandler) bool
GetStatelessResetToken(protocol.ConnectionID) protocol.StatelessResetToken
Retire(protocol.ConnectionID)
Remove(protocol.ConnectionID)
ReplaceWithClosed([]protocol.ConnectionID, []byte)
@ -225,7 +224,7 @@ var newConnection = func(
destConnID protocol.ConnectionID,
srcConnID protocol.ConnectionID,
connIDGenerator ConnectionIDGenerator,
statelessResetToken protocol.StatelessResetToken,
statelessResetter *statelessResetter,
conf *Config,
tlsConf *tls.Config,
tokenGenerator *handshake.TokenGenerator,
@ -263,7 +262,7 @@ var newConnection = func(
srcConnID,
&clientDestConnID,
func(connID protocol.ConnectionID) { runner.Add(connID, s) },
runner.GetStatelessResetToken,
statelessResetter,
runner.Remove,
runner.Retire,
runner.ReplaceWithClosed,
@ -282,6 +281,7 @@ var newConnection = func(
s.logger,
)
s.maxPayloadSizeEstimate.Store(uint32(estimateMaxPayloadSize(protocol.ByteCount(s.config.InitialPacketSize))))
statelessResetToken := statelessResetter.GetStatelessResetToken(srcConnID)
params := &wire.TransportParameters{
InitialMaxStreamDataBidiLocal: protocol.ByteCount(s.config.InitialStreamReceiveWindow),
InitialMaxStreamDataBidiRemote: protocol.ByteCount(s.config.InitialStreamReceiveWindow),
@ -340,6 +340,7 @@ var newClientConnection = func(
destConnID protocol.ConnectionID,
srcConnID protocol.ConnectionID,
connIDGenerator ConnectionIDGenerator,
statelessResetter *statelessResetter,
conf *Config,
tlsConf *tls.Config,
initialPacketNumber protocol.PacketNumber,
@ -372,7 +373,7 @@ var newClientConnection = func(
srcConnID,
nil,
func(connID protocol.ConnectionID) { runner.Add(connID, s) },
runner.GetStatelessResetToken,
statelessResetter,
runner.Remove,
runner.Retire,
runner.ReplaceWithClosed,

View file

@ -125,7 +125,7 @@ func newServerTestConnection(
protocol.ConnectionID{},
srcConnID,
&protocol.DefaultConnectionIDGenerator{},
protocol.StatelessResetToken{},
newStatelessResetter(nil),
populateConfig(config),
&tls.Config{},
handshake.NewTokenGenerator(handshake.TokenProtectorKey{}),
@ -180,6 +180,7 @@ func newClientTestConnection(
destConnID,
srcConnID,
&protocol.DefaultConnectionIDGenerator{},
newStatelessResetter(nil),
populateConfig(config),
&tls.Config{ServerName: "quic-go.net"},
0,

View file

@ -114,44 +114,6 @@ func (c *MockConnRunnerAddResetTokenCall) DoAndReturn(f func(protocol.StatelessR
return c
}
// GetStatelessResetToken mocks base method.
func (m *MockConnRunner) GetStatelessResetToken(arg0 protocol.ConnectionID) protocol.StatelessResetToken {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetStatelessResetToken", arg0)
ret0, _ := ret[0].(protocol.StatelessResetToken)
return ret0
}
// GetStatelessResetToken indicates an expected call of GetStatelessResetToken.
func (mr *MockConnRunnerMockRecorder) GetStatelessResetToken(arg0 any) *MockConnRunnerGetStatelessResetTokenCall {
mr.mock.ctrl.T.Helper()
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStatelessResetToken", reflect.TypeOf((*MockConnRunner)(nil).GetStatelessResetToken), arg0)
return &MockConnRunnerGetStatelessResetTokenCall{Call: call}
}
// MockConnRunnerGetStatelessResetTokenCall wrap *gomock.Call
type MockConnRunnerGetStatelessResetTokenCall struct {
*gomock.Call
}
// Return rewrite *gomock.Call.Return
func (c *MockConnRunnerGetStatelessResetTokenCall) Return(arg0 protocol.StatelessResetToken) *MockConnRunnerGetStatelessResetTokenCall {
c.Call = c.Call.Return(arg0)
return c
}
// Do rewrite *gomock.Call.Do
func (c *MockConnRunnerGetStatelessResetTokenCall) Do(f func(protocol.ConnectionID) protocol.StatelessResetToken) *MockConnRunnerGetStatelessResetTokenCall {
c.Call = c.Call.Do(f)
return c
}
// DoAndReturn rewrite *gomock.Call.DoAndReturn
func (c *MockConnRunnerGetStatelessResetTokenCall) DoAndReturn(f func(protocol.ConnectionID) protocol.StatelessResetToken) *MockConnRunnerGetStatelessResetTokenCall {
c.Call = c.Call.DoAndReturn(f)
return c
}
// Remove mocks base method.
func (m *MockConnRunner) Remove(arg0 protocol.ConnectionID) {
m.ctrl.T.Helper()

View file

@ -266,44 +266,6 @@ func (c *MockPacketHandlerManagerGetByResetTokenCall) DoAndReturn(f func(protoco
return c
}
// GetStatelessResetToken mocks base method.
func (m *MockPacketHandlerManager) GetStatelessResetToken(arg0 protocol.ConnectionID) protocol.StatelessResetToken {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetStatelessResetToken", arg0)
ret0, _ := ret[0].(protocol.StatelessResetToken)
return ret0
}
// GetStatelessResetToken indicates an expected call of GetStatelessResetToken.
func (mr *MockPacketHandlerManagerMockRecorder) GetStatelessResetToken(arg0 any) *MockPacketHandlerManagerGetStatelessResetTokenCall {
mr.mock.ctrl.T.Helper()
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStatelessResetToken", reflect.TypeOf((*MockPacketHandlerManager)(nil).GetStatelessResetToken), arg0)
return &MockPacketHandlerManagerGetStatelessResetTokenCall{Call: call}
}
// MockPacketHandlerManagerGetStatelessResetTokenCall wrap *gomock.Call
type MockPacketHandlerManagerGetStatelessResetTokenCall struct {
*gomock.Call
}
// Return rewrite *gomock.Call.Return
func (c *MockPacketHandlerManagerGetStatelessResetTokenCall) Return(arg0 protocol.StatelessResetToken) *MockPacketHandlerManagerGetStatelessResetTokenCall {
c.Call = c.Call.Return(arg0)
return c
}
// Do rewrite *gomock.Call.Do
func (c *MockPacketHandlerManagerGetStatelessResetTokenCall) Do(f func(protocol.ConnectionID) protocol.StatelessResetToken) *MockPacketHandlerManagerGetStatelessResetTokenCall {
c.Call = c.Call.Do(f)
return c
}
// DoAndReturn rewrite *gomock.Call.DoAndReturn
func (c *MockPacketHandlerManagerGetStatelessResetTokenCall) DoAndReturn(f func(protocol.ConnectionID) protocol.StatelessResetToken) *MockPacketHandlerManagerGetStatelessResetTokenCall {
c.Call = c.Call.DoAndReturn(f)
return c
}
// Remove mocks base method.
func (m *MockPacketHandlerManager) Remove(arg0 protocol.ConnectionID) {
m.ctrl.T.Helper()

View file

@ -1,10 +1,6 @@
package quic
import (
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"hash"
"io"
"net"
"sync"
@ -56,15 +52,12 @@ type packetHandlerMap struct {
deleteRetiredConnsAfter time.Duration
statelessResetMutex sync.Mutex
statelessResetHasher hash.Hash
logger utils.Logger
}
var _ packetHandlerManager = &packetHandlerMap{}
func newPacketHandlerMap(key *StatelessResetKey, enqueueClosePacket func(closePacket), logger utils.Logger) *packetHandlerMap {
func newPacketHandlerMap(enqueueClosePacket func(closePacket), logger utils.Logger) *packetHandlerMap {
h := &packetHandlerMap{
closeChan: make(chan struct{}),
handlers: make(map[protocol.ConnectionID]packetHandler),
@ -73,9 +66,6 @@ func newPacketHandlerMap(key *StatelessResetKey, enqueueClosePacket func(closePa
enqueueClosePacket: enqueueClosePacket,
logger: logger,
}
if key != nil {
h.statelessResetHasher = hmac.New(sha256.New, key[:])
}
if h.logger.Debug() {
go h.logUsage()
}
@ -236,20 +226,3 @@ func (h *packetHandlerMap) Close(e error) {
h.mutex.Unlock()
wg.Wait()
}
func (h *packetHandlerMap) GetStatelessResetToken(connID protocol.ConnectionID) protocol.StatelessResetToken {
var token protocol.StatelessResetToken
if h.statelessResetHasher == nil {
// Return a random stateless reset token.
// This token will be sent in the server's transport parameters.
// By using a random token, an off-path attacker won't be able to disrupt the connection.
rand.Read(token[:])
return token
}
h.statelessResetMutex.Lock()
h.statelessResetHasher.Write(connID.Bytes())
copy(token[:], h.statelessResetHasher.Sum(nil))
h.statelessResetHasher.Reset()
h.statelessResetMutex.Unlock()
return token
}

View file

@ -14,7 +14,7 @@ import (
)
func TestPacketHandlerMapAddAndRemove(t *testing.T) {
m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
m := newPacketHandlerMap(nil, utils.DefaultLogger)
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4})
h := &mockPacketHandler{}
require.True(t, m.Add(connID, h))
@ -36,7 +36,7 @@ func TestPacketHandlerMapAddAndRemove(t *testing.T) {
}
func TestPacketHandlerMapAddWithClientChosenConnID(t *testing.T) {
m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
m := newPacketHandlerMap(nil, utils.DefaultLogger)
h := &mockPacketHandler{}
connID1 := protocol.ParseConnectionID([]byte{1, 2, 3, 4})
@ -54,7 +54,7 @@ func TestPacketHandlerMapAddWithClientChosenConnID(t *testing.T) {
}
func TestPacketHandlerMapRetire(t *testing.T) {
m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
m := newPacketHandlerMap(nil, utils.DefaultLogger)
dur := scaleDuration(10 * time.Millisecond)
m.deleteRetiredConnsAfter = dur
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4})
@ -76,7 +76,7 @@ func TestPacketHandlerMapRetire(t *testing.T) {
}
func TestPacketHandlerMapAddGetRemoveResetTokens(t *testing.T) {
m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
m := newPacketHandlerMap(nil, utils.DefaultLogger)
token := protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf}
handler := &mockPacketHandler{}
m.AddResetToken(token, handler)
@ -88,43 +88,12 @@ func TestPacketHandlerMapAddGetRemoveResetTokens(t *testing.T) {
require.False(t, ok)
}
func TestPacketHandlerMapGenerateStatelessResetToken(t *testing.T) {
t.Run("no key", func(t *testing.T) {
m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
b := make([]byte, 8)
rand.Read(b)
connID := protocol.ParseConnectionID(b)
tokens := make(map[protocol.StatelessResetToken]struct{})
for i := 0; i < 100; i++ {
token := m.GetStatelessResetToken(connID)
require.NotZero(t, token)
if _, ok := tokens[token]; ok {
t.Fatalf("token %s already exists", token)
}
tokens[token] = struct{}{}
}
})
t.Run("with key", func(t *testing.T) {
var key StatelessResetKey
rand.Read(key[:])
m := newPacketHandlerMap(&key, nil, utils.DefaultLogger)
b := make([]byte, 8)
rand.Read(b)
connID := protocol.ParseConnectionID(b)
token := m.GetStatelessResetToken(connID)
require.NotZero(t, token)
require.Equal(t, token, m.GetStatelessResetToken(connID))
// generate a new connection ID
rand.Read(b)
connID2 := protocol.ParseConnectionID(b)
require.NotEqual(t, token, m.GetStatelessResetToken(connID2))
})
}
func TestPacketHandlerMapReplaceWithLocalClosed(t *testing.T) {
var closePackets []closePacket
m := newPacketHandlerMap(nil, func(p closePacket) { closePackets = append(closePackets, p) }, utils.DefaultLogger)
m := newPacketHandlerMap(
func(p closePacket) { closePackets = append(closePackets, p) },
utils.DefaultLogger,
)
dur := scaleDuration(10 * time.Millisecond)
m.deleteRetiredConnsAfter = dur
@ -150,7 +119,10 @@ func TestPacketHandlerMapReplaceWithLocalClosed(t *testing.T) {
func TestPacketHandlerMapReplaceWithRemoteClosed(t *testing.T) {
var closePackets []closePacket
m := newPacketHandlerMap(nil, func(p closePacket) { closePackets = append(closePackets, p) }, utils.DefaultLogger)
m := newPacketHandlerMap(
func(p closePacket) { closePackets = append(closePackets, p) },
utils.DefaultLogger,
)
dur := scaleDuration(50 * time.Millisecond)
m.deleteRetiredConnsAfter = dur
@ -173,7 +145,7 @@ func TestPacketHandlerMapReplaceWithRemoteClosed(t *testing.T) {
}
func TestPacketHandlerMapClose(t *testing.T) {
m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
m := newPacketHandlerMap(nil, utils.DefaultLogger)
testErr := errors.New("shutdown")
const numConns = 10
destroyChan := make(chan error, 2*numConns)

View file

@ -73,6 +73,7 @@ type baseServer struct {
maxTokenAge time.Duration
connIDGenerator ConnectionIDGenerator
statelessResetter *statelessResetter
connHandler packetHandlerManager
onClose func()
@ -95,7 +96,7 @@ type baseServer struct {
protocol.ConnectionID, /* destination connection ID */
protocol.ConnectionID, /* source connection ID */
ConnectionIDGenerator,
protocol.StatelessResetToken,
*statelessResetter,
*Config,
*tls.Config,
*handshake.TokenGenerator,
@ -248,6 +249,7 @@ func newServer(
conn rawConn,
connHandler packetHandlerManager,
connIDGenerator ConnectionIDGenerator,
statelessResetter *statelessResetter,
connContext func(context.Context) context.Context,
tlsConf *tls.Config,
config *Config,
@ -268,6 +270,7 @@ func newServer(
maxTokenAge: maxTokenAge,
verifySourceAddress: verifySourceAddress,
connIDGenerator: connIDGenerator,
statelessResetter: statelessResetter,
connHandler: connHandler,
connQueue: make(chan quicConn, protocol.MaxAcceptQueueSize),
errorChan: make(chan struct{}),
@ -707,7 +710,7 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
hdr.SrcConnectionID,
connID,
s.connIDGenerator,
s.connHandler.GetStatelessResetToken(connID),
s.statelessResetter,
config,
s.tlsConf,
s.tokenGenerator,

View file

@ -297,7 +297,7 @@ var _ = Describe("Server", func() {
destConnID protocol.ConnectionID,
srcConnID protocol.ConnectionID,
_ ConnectionIDGenerator,
tokenP protocol.StatelessResetToken,
_ *statelessResetter,
_ *Config,
_ *tls.Config,
_ *handshake.TokenGenerator,
@ -314,7 +314,6 @@ var _ = Describe("Server", func() {
Expect(srcConnID).ToNot(Equal(hdr.DestConnectionID))
Expect(srcConnID).ToNot(Equal(hdr.SrcConnectionID))
newConnID = srcConnID
Expect(tokenP).To(Equal(token))
conn.EXPECT().handlePacket(p)
conn.EXPECT().run().Do(func() error { close(run); return nil })
conn.EXPECT().Context().Return(context.Background())
@ -322,7 +321,6 @@ var _ = Describe("Server", func() {
return conn
}
phm.EXPECT().Get(connID)
phm.EXPECT().GetStatelessResetToken(gomock.Any()).Return(token)
phm.EXPECT().AddWithConnID(connID, gomock.Any(), gomock.Any()).DoAndReturn(func(_, cid protocol.ConnectionID, h packetHandler) bool {
Expect(cid).To(Equal(newConnID))
return true
@ -500,7 +498,7 @@ var _ = Describe("Server", func() {
destConnID protocol.ConnectionID,
srcConnID protocol.ConnectionID,
_ ConnectionIDGenerator,
tokenP protocol.StatelessResetToken,
_ *statelessResetter,
_ *Config,
_ *tls.Config,
_ *handshake.TokenGenerator,
@ -517,7 +515,6 @@ var _ = Describe("Server", func() {
Expect(srcConnID).ToNot(Equal(hdr.DestConnectionID))
Expect(srcConnID).ToNot(Equal(hdr.SrcConnectionID))
newConnID = srcConnID
Expect(tokenP).To(Equal(token))
conn.EXPECT().handlePacket(p)
conn.EXPECT().run().Do(func() error { close(run); return nil })
conn.EXPECT().Context().Return(context.Background())
@ -526,7 +523,6 @@ var _ = Describe("Server", func() {
}
gomock.InOrder(
phm.EXPECT().Get(connID),
phm.EXPECT().GetStatelessResetToken(gomock.Any()).Return(token),
phm.EXPECT().AddWithConnID(connID, gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, h packetHandler) bool {
Expect(c).To(Equal(newConnID))
return true
@ -553,7 +549,6 @@ var _ = Describe("Server", func() {
serv.verifySourceAddress = func(net.Addr) bool { return false }
phm.EXPECT().Get(gomock.Any()).AnyTimes()
phm.EXPECT().GetStatelessResetToken(gomock.Any()).AnyTimes()
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true).AnyTimes()
acceptConn := make(chan struct{})
@ -569,7 +564,7 @@ var _ = Describe("Server", func() {
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ ConnectionIDGenerator,
_ protocol.StatelessResetToken,
_ *statelessResetter,
_ *Config,
_ *tls.Config,
_ *handshake.TokenGenerator,
@ -625,7 +620,7 @@ var _ = Describe("Server", func() {
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ ConnectionIDGenerator,
_ protocol.StatelessResetToken,
_ *statelessResetter,
_ *Config,
_ *tls.Config,
_ *handshake.TokenGenerator,
@ -645,7 +640,6 @@ var _ = Describe("Server", func() {
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9})
p := getInitial(connID)
phm.EXPECT().Get(connID)
phm.EXPECT().GetStatelessResetToken(gomock.Any())
phm.EXPECT().AddWithConnID(connID, gomock.Any(), gomock.Any()).Return(false) // connection ID collision
Expect(serv.handlePacketImpl(p)).To(BeTrue())
Eventually(done).Should(BeClosed())
@ -657,7 +651,6 @@ var _ = Describe("Server", func() {
serv.verifySourceAddress = func(net.Addr) bool { return !limiter.Allow() }
phm.EXPECT().Get(gomock.Any()).AnyTimes()
phm.EXPECT().GetStatelessResetToken(gomock.Any()).AnyTimes()
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true).AnyTimes()
connChan := make(chan *MockQUICConn, 1)
@ -675,7 +668,7 @@ var _ = Describe("Server", func() {
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ ConnectionIDGenerator,
_ protocol.StatelessResetToken,
_ *statelessResetter,
_ *Config,
_ *tls.Config,
_ *handshake.TokenGenerator,
@ -737,7 +730,7 @@ var _ = Describe("Server", func() {
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ ConnectionIDGenerator,
_ protocol.StatelessResetToken,
_ *statelessResetter,
_ *Config,
_ *tls.Config,
_ *handshake.TokenGenerator,
@ -769,7 +762,6 @@ var _ = Describe("Server", func() {
done := make(chan struct{})
phm.EXPECT().Get(gomock.Any())
phm.EXPECT().GetStatelessResetToken(gomock.Any())
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, _ packetHandler) bool {
close(done)
return true
@ -972,7 +964,7 @@ var _ = Describe("Server", func() {
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ ConnectionIDGenerator,
_ protocol.StatelessResetToken,
_ *statelessResetter,
conf *Config,
_ *tls.Config,
_ *handshake.TokenGenerator,
@ -990,7 +982,6 @@ var _ = Describe("Server", func() {
return conn
}
phm.EXPECT().Get(gomock.Any())
phm.EXPECT().GetStatelessResetToken(gomock.Any())
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true)
serv.handleInitialImpl(
receivedPacket{buffer: getPacketBuffer()},
@ -1040,7 +1031,7 @@ var _ = Describe("Server", func() {
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ ConnectionIDGenerator,
_ protocol.StatelessResetToken,
_ *statelessResetter,
conf *Config,
_ *tls.Config,
_ *handshake.TokenGenerator,
@ -1057,7 +1048,6 @@ var _ = Describe("Server", func() {
return conn
}
phm.EXPECT().Get(gomock.Any())
phm.EXPECT().GetStatelessResetToken(gomock.Any())
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true)
serv.handleInitialImpl(
receivedPacket{buffer: getPacketBuffer()},
@ -1111,7 +1101,7 @@ var _ = Describe("Server", func() {
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ ConnectionIDGenerator,
_ protocol.StatelessResetToken,
_ *statelessResetter,
_ *Config,
_ *tls.Config,
_ *handshake.TokenGenerator,
@ -1127,7 +1117,6 @@ var _ = Describe("Server", func() {
return conn
}
phm.EXPECT().Get(gomock.Any())
phm.EXPECT().GetStatelessResetToken(gomock.Any())
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true)
serv.handleInitialImpl(
receivedPacket{buffer: getPacketBuffer()},
@ -1182,7 +1171,7 @@ var _ = Describe("Server", func() {
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ ConnectionIDGenerator,
_ protocol.StatelessResetToken,
_ *statelessResetter,
_ *Config,
_ *tls.Config,
_ *handshake.TokenGenerator,
@ -1198,7 +1187,6 @@ var _ = Describe("Server", func() {
return conn
}
phm.EXPECT().Get(gomock.Any())
phm.EXPECT().GetStatelessResetToken(gomock.Any())
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true)
serv.baseServer.handleInitialImpl(
receivedPacket{buffer: getPacketBuffer()},
@ -1224,7 +1212,7 @@ var _ = Describe("Server", func() {
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ ConnectionIDGenerator,
_ protocol.StatelessResetToken,
_ *statelessResetter,
_ *Config,
_ *tls.Config,
_ *handshake.TokenGenerator,
@ -1244,7 +1232,6 @@ var _ = Describe("Server", func() {
}
phm.EXPECT().Get(gomock.Any()).AnyTimes()
phm.EXPECT().GetStatelessResetToken(gomock.Any()).Times(protocol.MaxAcceptQueueSize)
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true).Times(protocol.MaxAcceptQueueSize)
for i := 0; i < protocol.MaxAcceptQueueSize; i++ {
conn := NewMockQUICConn(mockCtrl)
@ -1257,7 +1244,6 @@ var _ = Describe("Server", func() {
wg.Add(1)
rejected := make(chan struct{})
phm.EXPECT().GetStatelessResetToken(gomock.Any())
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true)
conn := NewMockQUICConn(mockCtrl)
conn.EXPECT().closeWithTransportError(ConnectionRefused).Do(func(qerr.TransportErrorCode) {
@ -1284,7 +1270,7 @@ var _ = Describe("Server", func() {
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ ConnectionIDGenerator,
_ protocol.StatelessResetToken,
_ *statelessResetter,
_ *Config,
_ *tls.Config,
_ *handshake.TokenGenerator,
@ -1302,7 +1288,6 @@ var _ = Describe("Server", func() {
}
phm.EXPECT().Get(gomock.Any())
phm.EXPECT().GetStatelessResetToken(gomock.Any())
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true)
serv.baseServer.handlePacket(p)
// make sure there are no Write calls on the packet conn
@ -1407,7 +1392,7 @@ var _ = Describe("Server", func() {
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ ConnectionIDGenerator,
_ protocol.StatelessResetToken,
_ *statelessResetter,
_ *Config,
_ *tls.Config,
_ *handshake.TokenGenerator,
@ -1433,7 +1418,6 @@ var _ = Describe("Server", func() {
}
phm.EXPECT().Get(connID)
phm.EXPECT().GetStatelessResetToken(gomock.Any())
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true)
serv.handlePacket(initial)
Eventually(called).Should(BeClosed())

42
stateless_reset.go Normal file
View file

@ -0,0 +1,42 @@
package quic
import (
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"hash"
"sync"
"github.com/quic-go/quic-go/internal/protocol"
)
type statelessResetter struct {
mx sync.Mutex
h hash.Hash
}
// newStatelessRetter creates a new stateless reset generator.
// It is valid to use a nil key. In that case, a random key will be used.
// This makes is impossible for on-path attackers to shut down established connections.
func newStatelessResetter(key *StatelessResetKey) *statelessResetter {
var h hash.Hash
if key != nil {
h = hmac.New(sha256.New, key[:])
} else {
b := make([]byte, 32)
_, _ = rand.Read(b)
h = hmac.New(sha256.New, b)
}
return &statelessResetter{h: h}
}
func (r *statelessResetter) GetStatelessResetToken(connID protocol.ConnectionID) protocol.StatelessResetToken {
r.mx.Lock()
defer r.mx.Unlock()
var token protocol.StatelessResetToken
r.h.Write(connID.Bytes())
copy(token[:], r.h.Sum(nil))
r.h.Reset()
return token
}

42
stateless_reset_test.go Normal file
View file

@ -0,0 +1,42 @@
package quic
import (
"crypto/rand"
"testing"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/stretchr/testify/require"
)
func TestStatelessResetter(t *testing.T) {
t.Run("no key", func(t *testing.T) {
r1 := newStatelessResetter(nil)
r2 := newStatelessResetter(nil)
for i := 0; i < 100; i++ {
b := make([]byte, 15)
rand.Read(b)
connID := protocol.ParseConnectionID(b)
t1 := r1.GetStatelessResetToken(connID)
t2 := r2.GetStatelessResetToken(connID)
require.NotZero(t, t1)
require.NotZero(t, t2)
require.NotEqual(t, t1, t2)
}
})
t.Run("with key", func(t *testing.T) {
var key StatelessResetKey
rand.Read(key[:])
m := newStatelessResetter(&key)
b := make([]byte, 8)
rand.Read(b)
connID := protocol.ParseConnectionID(b)
token := m.GetStatelessResetToken(connID)
require.NotZero(t, token)
require.Equal(t, token, m.GetStatelessResetToken(connID))
// generate a new connection ID
rand.Read(b)
connID2 := protocol.ParseConnectionID(b)
require.NotEqual(t, token, m.GetStatelessResetToken(connID2))
})
}

View file

@ -116,6 +116,7 @@ type Transport struct {
// Set in init.
// If no ConnectionIDGenerator is set, this is set to a default.
connIDGenerator ConnectionIDGenerator
statelessResetter *statelessResetter
server *baseServer
@ -183,6 +184,7 @@ func (t *Transport) createServer(tlsConf *tls.Config, conf *Config, allow0RTT bo
t.conn,
t.handlerMap,
t.connIDGenerator,
t.statelessResetter,
t.ConnContext,
tlsConf,
conf,
@ -222,7 +224,17 @@ func (t *Transport) dial(ctx context.Context, addr net.Addr, host string, tlsCon
}
tlsConf = tlsConf.Clone()
setTLSConfigServerName(tlsConf, addr, host)
return dial(ctx, newSendConn(t.conn, addr, packetInfo{}, utils.DefaultLogger), t.connIDGenerator, t.handlerMap, tlsConf, conf, onClose, use0RTT)
return dial(
ctx,
newSendConn(t.conn, addr, packetInfo{}, utils.DefaultLogger),
t.connIDGenerator,
t.statelessResetter,
t.handlerMap,
tlsConf,
conf,
onClose,
use0RTT,
)
}
func (t *Transport) init(allowZeroLengthConnIDs bool) error {
@ -242,7 +254,7 @@ func (t *Transport) init(allowZeroLengthConnIDs bool) error {
t.logger = utils.DefaultLogger // TODO: make this configurable
t.conn = conn
if t.handlerMap == nil { // allows mocking the handlerMap in tests
t.handlerMap = newPacketHandlerMap(t.StatelessResetKey, t.enqueueClosePacket, t.logger)
t.handlerMap = newPacketHandlerMap(t.enqueueClosePacket, t.logger)
}
t.listening = make(chan struct{})
@ -268,6 +280,7 @@ func (t *Transport) init(allowZeroLengthConnIDs bool) error {
t.connIDLen = connIDLen
t.connIDGenerator = &protocol.DefaultConnectionIDGenerator{ConnLen: t.connIDLen}
}
t.statelessResetter = newStatelessResetter(t.StatelessResetKey)
go t.listen(conn)
go t.runSendQueue()
@ -478,7 +491,7 @@ func (t *Transport) sendStatelessReset(p receivedPacket) {
t.logger.Errorf("error parsing connection ID on packet from %s: %s", p.remoteAddr, err)
return
}
token := t.handlerMap.GetStatelessResetToken(connID)
token := t.statelessResetter.GetStatelessResetToken(connID)
t.logger.Debugf("Sending stateless reset to %s (connection ID: %s). Token: %#x", p.remoteAddr, connID, token)
data := make([]byte, protocol.MinStatelessResetSize-16, protocol.MinStatelessResetSize)
rand.Read(data)

View file

@ -254,8 +254,6 @@ func TestTransportStatelessResetSending(t *testing.T) {
// but a stateless reset is sent for packets larger than MinStatelessResetSize
phm.EXPECT().Get(connID) // no handler for this connection ID
phm.EXPECT().GetByResetToken(gomock.Any())
token := protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
phm.EXPECT().GetStatelessResetToken(connID).Return(token)
_, err = conn.WriteTo(append(b, make([]byte, protocol.MinStatelessResetSize-len(b)+1)...), tr.Conn.LocalAddr())
require.NoError(t, err)
conn.SetReadDeadline(time.Now().Add(time.Second))
@ -263,7 +261,8 @@ func TestTransportStatelessResetSending(t *testing.T) {
n, addr, err := conn.ReadFrom(p)
require.NoError(t, err)
require.Equal(t, addr, tr.Conn.LocalAddr())
require.Contains(t, string(p[:n]), string(token[:]))
srt := newStatelessResetter(tr.StatelessResetKey).GetStatelessResetToken(connID)
require.Contains(t, string(p[:n]), string(srt[:]))
}
func TestTransportDropsUnparseableQUICPackets(t *testing.T) {