mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-06 13:47:35 +03:00
simplify generation of stateless reset tokens (#4858)
This commit is contained in:
parent
9950b4c687
commit
62947d97f5
16 changed files with 224 additions and 233 deletions
16
client.go
16
client.go
|
@ -23,6 +23,7 @@ type client struct {
|
|||
config *Config
|
||||
|
||||
connIDGenerator ConnectionIDGenerator
|
||||
statelessResetter *statelessResetter
|
||||
srcConnID protocol.ConnectionID
|
||||
destConnID protocol.ConnectionID
|
||||
|
||||
|
@ -137,13 +138,14 @@ func dial(
|
|||
ctx context.Context,
|
||||
conn sendConn,
|
||||
connIDGenerator ConnectionIDGenerator,
|
||||
statelessResetter *statelessResetter,
|
||||
packetHandlers packetHandlerManager,
|
||||
tlsConf *tls.Config,
|
||||
config *Config,
|
||||
onClose func(),
|
||||
use0RTT bool,
|
||||
) (quicConn, error) {
|
||||
c, err := newClient(conn, connIDGenerator, config, tlsConf, onClose, use0RTT)
|
||||
c, err := newClient(conn, connIDGenerator, statelessResetter, config, tlsConf, onClose, use0RTT)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -162,7 +164,15 @@ func dial(
|
|||
return c.conn, nil
|
||||
}
|
||||
|
||||
func newClient(sendConn sendConn, connIDGenerator ConnectionIDGenerator, config *Config, tlsConf *tls.Config, onClose func(), use0RTT bool) (*client, error) {
|
||||
func newClient(
|
||||
sendConn sendConn,
|
||||
connIDGenerator ConnectionIDGenerator,
|
||||
statelessResetter *statelessResetter,
|
||||
config *Config,
|
||||
tlsConf *tls.Config,
|
||||
onClose func(),
|
||||
use0RTT bool,
|
||||
) (*client, error) {
|
||||
srcConnID, err := connIDGenerator.GenerateConnectionID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -173,6 +183,7 @@ func newClient(sendConn sendConn, connIDGenerator ConnectionIDGenerator, config
|
|||
}
|
||||
c := &client{
|
||||
connIDGenerator: connIDGenerator,
|
||||
statelessResetter: statelessResetter,
|
||||
srcConnID: srcConnID,
|
||||
destConnID: destConnID,
|
||||
sendConn: sendConn,
|
||||
|
@ -197,6 +208,7 @@ func (c *client) dial(ctx context.Context) error {
|
|||
c.destConnID,
|
||||
c.srcConnID,
|
||||
c.connIDGenerator,
|
||||
c.statelessResetter,
|
||||
c.config,
|
||||
c.tlsConf,
|
||||
c.initialPacketNumber,
|
||||
|
|
|
@ -33,6 +33,7 @@ var _ = Describe("Client", func() {
|
|||
destConnID protocol.ConnectionID,
|
||||
srcConnID protocol.ConnectionID,
|
||||
connIDGenerator ConnectionIDGenerator,
|
||||
statelessResetToken *statelessResetter,
|
||||
conf *Config,
|
||||
tlsConf *tls.Config,
|
||||
initialPacketNumber protocol.PacketNumber,
|
||||
|
@ -107,6 +108,7 @@ var _ = Describe("Client", func() {
|
|||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ ConnectionIDGenerator,
|
||||
_ *statelessResetter,
|
||||
_ *Config,
|
||||
_ *tls.Config,
|
||||
_ protocol.PacketNumber,
|
||||
|
@ -124,7 +126,15 @@ var _ = Describe("Client", func() {
|
|||
conn.EXPECT().HandshakeComplete().Return(c)
|
||||
return conn
|
||||
}
|
||||
cl, err := newClient(packetConn, &protocol.DefaultConnectionIDGenerator{}, populateConfig(config), tlsConf, nil, false)
|
||||
cl, err := newClient(
|
||||
packetConn,
|
||||
&protocol.DefaultConnectionIDGenerator{},
|
||||
newStatelessResetter(nil),
|
||||
populateConfig(config),
|
||||
tlsConf,
|
||||
nil,
|
||||
false,
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
cl.packetHandlers = manager
|
||||
Expect(cl).ToNot(BeNil())
|
||||
|
@ -144,6 +154,7 @@ var _ = Describe("Client", func() {
|
|||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ ConnectionIDGenerator,
|
||||
_ *statelessResetter,
|
||||
_ *Config,
|
||||
_ *tls.Config,
|
||||
_ protocol.PacketNumber,
|
||||
|
@ -161,7 +172,15 @@ var _ = Describe("Client", func() {
|
|||
return conn
|
||||
}
|
||||
|
||||
cl, err := newClient(packetConn, &protocol.DefaultConnectionIDGenerator{}, populateConfig(config), tlsConf, nil, true)
|
||||
cl, err := newClient(
|
||||
packetConn,
|
||||
&protocol.DefaultConnectionIDGenerator{},
|
||||
newStatelessResetter(nil),
|
||||
populateConfig(config),
|
||||
tlsConf,
|
||||
nil,
|
||||
true,
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
cl.packetHandlers = manager
|
||||
Expect(cl).ToNot(BeNil())
|
||||
|
@ -181,6 +200,7 @@ var _ = Describe("Client", func() {
|
|||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ ConnectionIDGenerator,
|
||||
_ *statelessResetter,
|
||||
_ *Config,
|
||||
_ *tls.Config,
|
||||
_ protocol.PacketNumber,
|
||||
|
@ -197,7 +217,13 @@ var _ = Describe("Client", func() {
|
|||
return conn
|
||||
}
|
||||
var closed bool
|
||||
cl, err := newClient(packetConn, &protocol.DefaultConnectionIDGenerator{}, populateConfig(config), tlsConf, func() { closed = true }, true)
|
||||
cl, err := newClient(
|
||||
packetConn,
|
||||
&protocol.DefaultConnectionIDGenerator{},
|
||||
newStatelessResetter(nil),
|
||||
populateConfig(config), tlsConf, func() { closed = true },
|
||||
true,
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
cl.packetHandlers = manager
|
||||
Expect(cl).ToNot(BeNil())
|
||||
|
@ -266,6 +292,7 @@ var _ = Describe("Client", func() {
|
|||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ ConnectionIDGenerator,
|
||||
_ *statelessResetter,
|
||||
configP *Config,
|
||||
_ *tls.Config,
|
||||
_ protocol.PacketNumber,
|
||||
|
@ -309,6 +336,7 @@ var _ = Describe("Client", func() {
|
|||
_ protocol.ConnectionID,
|
||||
connID protocol.ConnectionID,
|
||||
_ ConnectionIDGenerator,
|
||||
_ *statelessResetter,
|
||||
configP *Config,
|
||||
_ *tls.Config,
|
||||
pn protocol.PacketNumber,
|
||||
|
|
|
@ -16,7 +16,7 @@ type connIDGenerator struct {
|
|||
initialClientDestConnID *protocol.ConnectionID // nil for the client
|
||||
|
||||
addConnectionID func(protocol.ConnectionID)
|
||||
getStatelessResetToken func(protocol.ConnectionID) protocol.StatelessResetToken
|
||||
statelessResetter *statelessResetter
|
||||
removeConnectionID func(protocol.ConnectionID)
|
||||
retireConnectionID func(protocol.ConnectionID)
|
||||
replaceWithClosed func([]protocol.ConnectionID, []byte)
|
||||
|
@ -27,7 +27,7 @@ func newConnIDGenerator(
|
|||
initialConnectionID protocol.ConnectionID,
|
||||
initialClientDestConnID *protocol.ConnectionID, // nil for the client
|
||||
addConnectionID func(protocol.ConnectionID),
|
||||
getStatelessResetToken func(protocol.ConnectionID) protocol.StatelessResetToken,
|
||||
statelessResetter *statelessResetter,
|
||||
removeConnectionID func(protocol.ConnectionID),
|
||||
retireConnectionID func(protocol.ConnectionID),
|
||||
replaceWithClosed func([]protocol.ConnectionID, []byte),
|
||||
|
@ -38,7 +38,7 @@ func newConnIDGenerator(
|
|||
generator: generator,
|
||||
activeSrcConnIDs: make(map[uint64]protocol.ConnectionID),
|
||||
addConnectionID: addConnectionID,
|
||||
getStatelessResetToken: getStatelessResetToken,
|
||||
statelessResetter: statelessResetter,
|
||||
removeConnectionID: removeConnectionID,
|
||||
retireConnectionID: retireConnectionID,
|
||||
replaceWithClosed: replaceWithClosed,
|
||||
|
@ -104,7 +104,7 @@ func (m *connIDGenerator) issueNewConnID() error {
|
|||
m.queueControlFrame(&wire.NewConnectionIDFrame{
|
||||
SequenceNumber: m.highestSeq + 1,
|
||||
ConnectionID: connID,
|
||||
StatelessResetToken: m.getStatelessResetToken(connID),
|
||||
StatelessResetToken: m.statelessResetter.GetStatelessResetToken(connID),
|
||||
})
|
||||
m.highestSeq++
|
||||
return nil
|
||||
|
|
|
@ -19,14 +19,11 @@ var _ = Describe("Connection ID Generator", func() {
|
|||
replacedWithClosed []protocol.ConnectionID
|
||||
queuedFrames []wire.Frame
|
||||
g *connIDGenerator
|
||||
statelessResetter *statelessResetter
|
||||
)
|
||||
initialConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7})
|
||||
initialClientDestConnID := protocol.ParseConnectionID([]byte{0xa, 0xb, 0xc, 0xd, 0xe})
|
||||
|
||||
connIDToToken := func(c protocol.ConnectionID) protocol.StatelessResetToken {
|
||||
b := c.Bytes()[0]
|
||||
return protocol.StatelessResetToken{b, b, b, b, b, b, b, b, b, b, b, b, b, b, b, b}
|
||||
}
|
||||
statelessResetter = newStatelessResetter(nil)
|
||||
|
||||
BeforeEach(func() {
|
||||
addedConnIDs = nil
|
||||
|
@ -38,7 +35,7 @@ var _ = Describe("Connection ID Generator", func() {
|
|||
initialConnID,
|
||||
&initialClientDestConnID,
|
||||
func(c protocol.ConnectionID) { addedConnIDs = append(addedConnIDs, c) },
|
||||
connIDToToken,
|
||||
statelessResetter,
|
||||
func(c protocol.ConnectionID) { removedConnIDs = append(removedConnIDs, c) },
|
||||
func(c protocol.ConnectionID) { retiredConnIDs = append(retiredConnIDs, c) },
|
||||
func(cs []protocol.ConnectionID, _ []byte) { replacedWithClosed = append(replacedWithClosed, cs...) },
|
||||
|
@ -61,7 +58,7 @@ var _ = Describe("Connection ID Generator", func() {
|
|||
nf := f.(*wire.NewConnectionIDFrame)
|
||||
Expect(nf.SequenceNumber).To(BeEquivalentTo(i + 1))
|
||||
Expect(nf.ConnectionID.Len()).To(Equal(7))
|
||||
Expect(nf.StatelessResetToken).To(Equal(connIDToToken(nf.ConnectionID)))
|
||||
Expect(nf.StatelessResetToken).To(Equal(statelessResetter.GetStatelessResetToken(nf.ConnectionID)))
|
||||
}
|
||||
})
|
||||
|
||||
|
|
|
@ -85,7 +85,6 @@ func (p *receivedPacket) Clone() *receivedPacket {
|
|||
|
||||
type connRunner interface {
|
||||
Add(protocol.ConnectionID, packetHandler) bool
|
||||
GetStatelessResetToken(protocol.ConnectionID) protocol.StatelessResetToken
|
||||
Retire(protocol.ConnectionID)
|
||||
Remove(protocol.ConnectionID)
|
||||
ReplaceWithClosed([]protocol.ConnectionID, []byte)
|
||||
|
@ -225,7 +224,7 @@ var newConnection = func(
|
|||
destConnID protocol.ConnectionID,
|
||||
srcConnID protocol.ConnectionID,
|
||||
connIDGenerator ConnectionIDGenerator,
|
||||
statelessResetToken protocol.StatelessResetToken,
|
||||
statelessResetter *statelessResetter,
|
||||
conf *Config,
|
||||
tlsConf *tls.Config,
|
||||
tokenGenerator *handshake.TokenGenerator,
|
||||
|
@ -263,7 +262,7 @@ var newConnection = func(
|
|||
srcConnID,
|
||||
&clientDestConnID,
|
||||
func(connID protocol.ConnectionID) { runner.Add(connID, s) },
|
||||
runner.GetStatelessResetToken,
|
||||
statelessResetter,
|
||||
runner.Remove,
|
||||
runner.Retire,
|
||||
runner.ReplaceWithClosed,
|
||||
|
@ -282,6 +281,7 @@ var newConnection = func(
|
|||
s.logger,
|
||||
)
|
||||
s.maxPayloadSizeEstimate.Store(uint32(estimateMaxPayloadSize(protocol.ByteCount(s.config.InitialPacketSize))))
|
||||
statelessResetToken := statelessResetter.GetStatelessResetToken(srcConnID)
|
||||
params := &wire.TransportParameters{
|
||||
InitialMaxStreamDataBidiLocal: protocol.ByteCount(s.config.InitialStreamReceiveWindow),
|
||||
InitialMaxStreamDataBidiRemote: protocol.ByteCount(s.config.InitialStreamReceiveWindow),
|
||||
|
@ -340,6 +340,7 @@ var newClientConnection = func(
|
|||
destConnID protocol.ConnectionID,
|
||||
srcConnID protocol.ConnectionID,
|
||||
connIDGenerator ConnectionIDGenerator,
|
||||
statelessResetter *statelessResetter,
|
||||
conf *Config,
|
||||
tlsConf *tls.Config,
|
||||
initialPacketNumber protocol.PacketNumber,
|
||||
|
@ -372,7 +373,7 @@ var newClientConnection = func(
|
|||
srcConnID,
|
||||
nil,
|
||||
func(connID protocol.ConnectionID) { runner.Add(connID, s) },
|
||||
runner.GetStatelessResetToken,
|
||||
statelessResetter,
|
||||
runner.Remove,
|
||||
runner.Retire,
|
||||
runner.ReplaceWithClosed,
|
||||
|
|
|
@ -125,7 +125,7 @@ func newServerTestConnection(
|
|||
protocol.ConnectionID{},
|
||||
srcConnID,
|
||||
&protocol.DefaultConnectionIDGenerator{},
|
||||
protocol.StatelessResetToken{},
|
||||
newStatelessResetter(nil),
|
||||
populateConfig(config),
|
||||
&tls.Config{},
|
||||
handshake.NewTokenGenerator(handshake.TokenProtectorKey{}),
|
||||
|
@ -180,6 +180,7 @@ func newClientTestConnection(
|
|||
destConnID,
|
||||
srcConnID,
|
||||
&protocol.DefaultConnectionIDGenerator{},
|
||||
newStatelessResetter(nil),
|
||||
populateConfig(config),
|
||||
&tls.Config{ServerName: "quic-go.net"},
|
||||
0,
|
||||
|
|
|
@ -114,44 +114,6 @@ func (c *MockConnRunnerAddResetTokenCall) DoAndReturn(f func(protocol.StatelessR
|
|||
return c
|
||||
}
|
||||
|
||||
// GetStatelessResetToken mocks base method.
|
||||
func (m *MockConnRunner) GetStatelessResetToken(arg0 protocol.ConnectionID) protocol.StatelessResetToken {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetStatelessResetToken", arg0)
|
||||
ret0, _ := ret[0].(protocol.StatelessResetToken)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetStatelessResetToken indicates an expected call of GetStatelessResetToken.
|
||||
func (mr *MockConnRunnerMockRecorder) GetStatelessResetToken(arg0 any) *MockConnRunnerGetStatelessResetTokenCall {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStatelessResetToken", reflect.TypeOf((*MockConnRunner)(nil).GetStatelessResetToken), arg0)
|
||||
return &MockConnRunnerGetStatelessResetTokenCall{Call: call}
|
||||
}
|
||||
|
||||
// MockConnRunnerGetStatelessResetTokenCall wrap *gomock.Call
|
||||
type MockConnRunnerGetStatelessResetTokenCall struct {
|
||||
*gomock.Call
|
||||
}
|
||||
|
||||
// Return rewrite *gomock.Call.Return
|
||||
func (c *MockConnRunnerGetStatelessResetTokenCall) Return(arg0 protocol.StatelessResetToken) *MockConnRunnerGetStatelessResetTokenCall {
|
||||
c.Call = c.Call.Return(arg0)
|
||||
return c
|
||||
}
|
||||
|
||||
// Do rewrite *gomock.Call.Do
|
||||
func (c *MockConnRunnerGetStatelessResetTokenCall) Do(f func(protocol.ConnectionID) protocol.StatelessResetToken) *MockConnRunnerGetStatelessResetTokenCall {
|
||||
c.Call = c.Call.Do(f)
|
||||
return c
|
||||
}
|
||||
|
||||
// DoAndReturn rewrite *gomock.Call.DoAndReturn
|
||||
func (c *MockConnRunnerGetStatelessResetTokenCall) DoAndReturn(f func(protocol.ConnectionID) protocol.StatelessResetToken) *MockConnRunnerGetStatelessResetTokenCall {
|
||||
c.Call = c.Call.DoAndReturn(f)
|
||||
return c
|
||||
}
|
||||
|
||||
// Remove mocks base method.
|
||||
func (m *MockConnRunner) Remove(arg0 protocol.ConnectionID) {
|
||||
m.ctrl.T.Helper()
|
||||
|
|
|
@ -266,44 +266,6 @@ func (c *MockPacketHandlerManagerGetByResetTokenCall) DoAndReturn(f func(protoco
|
|||
return c
|
||||
}
|
||||
|
||||
// GetStatelessResetToken mocks base method.
|
||||
func (m *MockPacketHandlerManager) GetStatelessResetToken(arg0 protocol.ConnectionID) protocol.StatelessResetToken {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetStatelessResetToken", arg0)
|
||||
ret0, _ := ret[0].(protocol.StatelessResetToken)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetStatelessResetToken indicates an expected call of GetStatelessResetToken.
|
||||
func (mr *MockPacketHandlerManagerMockRecorder) GetStatelessResetToken(arg0 any) *MockPacketHandlerManagerGetStatelessResetTokenCall {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStatelessResetToken", reflect.TypeOf((*MockPacketHandlerManager)(nil).GetStatelessResetToken), arg0)
|
||||
return &MockPacketHandlerManagerGetStatelessResetTokenCall{Call: call}
|
||||
}
|
||||
|
||||
// MockPacketHandlerManagerGetStatelessResetTokenCall wrap *gomock.Call
|
||||
type MockPacketHandlerManagerGetStatelessResetTokenCall struct {
|
||||
*gomock.Call
|
||||
}
|
||||
|
||||
// Return rewrite *gomock.Call.Return
|
||||
func (c *MockPacketHandlerManagerGetStatelessResetTokenCall) Return(arg0 protocol.StatelessResetToken) *MockPacketHandlerManagerGetStatelessResetTokenCall {
|
||||
c.Call = c.Call.Return(arg0)
|
||||
return c
|
||||
}
|
||||
|
||||
// Do rewrite *gomock.Call.Do
|
||||
func (c *MockPacketHandlerManagerGetStatelessResetTokenCall) Do(f func(protocol.ConnectionID) protocol.StatelessResetToken) *MockPacketHandlerManagerGetStatelessResetTokenCall {
|
||||
c.Call = c.Call.Do(f)
|
||||
return c
|
||||
}
|
||||
|
||||
// DoAndReturn rewrite *gomock.Call.DoAndReturn
|
||||
func (c *MockPacketHandlerManagerGetStatelessResetTokenCall) DoAndReturn(f func(protocol.ConnectionID) protocol.StatelessResetToken) *MockPacketHandlerManagerGetStatelessResetTokenCall {
|
||||
c.Call = c.Call.DoAndReturn(f)
|
||||
return c
|
||||
}
|
||||
|
||||
// Remove mocks base method.
|
||||
func (m *MockPacketHandlerManager) Remove(arg0 protocol.ConnectionID) {
|
||||
m.ctrl.T.Helper()
|
||||
|
|
|
@ -1,10 +1,6 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"hash"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
|
@ -56,15 +52,12 @@ type packetHandlerMap struct {
|
|||
|
||||
deleteRetiredConnsAfter time.Duration
|
||||
|
||||
statelessResetMutex sync.Mutex
|
||||
statelessResetHasher hash.Hash
|
||||
|
||||
logger utils.Logger
|
||||
}
|
||||
|
||||
var _ packetHandlerManager = &packetHandlerMap{}
|
||||
|
||||
func newPacketHandlerMap(key *StatelessResetKey, enqueueClosePacket func(closePacket), logger utils.Logger) *packetHandlerMap {
|
||||
func newPacketHandlerMap(enqueueClosePacket func(closePacket), logger utils.Logger) *packetHandlerMap {
|
||||
h := &packetHandlerMap{
|
||||
closeChan: make(chan struct{}),
|
||||
handlers: make(map[protocol.ConnectionID]packetHandler),
|
||||
|
@ -73,9 +66,6 @@ func newPacketHandlerMap(key *StatelessResetKey, enqueueClosePacket func(closePa
|
|||
enqueueClosePacket: enqueueClosePacket,
|
||||
logger: logger,
|
||||
}
|
||||
if key != nil {
|
||||
h.statelessResetHasher = hmac.New(sha256.New, key[:])
|
||||
}
|
||||
if h.logger.Debug() {
|
||||
go h.logUsage()
|
||||
}
|
||||
|
@ -236,20 +226,3 @@ func (h *packetHandlerMap) Close(e error) {
|
|||
h.mutex.Unlock()
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func (h *packetHandlerMap) GetStatelessResetToken(connID protocol.ConnectionID) protocol.StatelessResetToken {
|
||||
var token protocol.StatelessResetToken
|
||||
if h.statelessResetHasher == nil {
|
||||
// Return a random stateless reset token.
|
||||
// This token will be sent in the server's transport parameters.
|
||||
// By using a random token, an off-path attacker won't be able to disrupt the connection.
|
||||
rand.Read(token[:])
|
||||
return token
|
||||
}
|
||||
h.statelessResetMutex.Lock()
|
||||
h.statelessResetHasher.Write(connID.Bytes())
|
||||
copy(token[:], h.statelessResetHasher.Sum(nil))
|
||||
h.statelessResetHasher.Reset()
|
||||
h.statelessResetMutex.Unlock()
|
||||
return token
|
||||
}
|
||||
|
|
|
@ -14,7 +14,7 @@ import (
|
|||
)
|
||||
|
||||
func TestPacketHandlerMapAddAndRemove(t *testing.T) {
|
||||
m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
|
||||
m := newPacketHandlerMap(nil, utils.DefaultLogger)
|
||||
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4})
|
||||
h := &mockPacketHandler{}
|
||||
require.True(t, m.Add(connID, h))
|
||||
|
@ -36,7 +36,7 @@ func TestPacketHandlerMapAddAndRemove(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestPacketHandlerMapAddWithClientChosenConnID(t *testing.T) {
|
||||
m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
|
||||
m := newPacketHandlerMap(nil, utils.DefaultLogger)
|
||||
h := &mockPacketHandler{}
|
||||
|
||||
connID1 := protocol.ParseConnectionID([]byte{1, 2, 3, 4})
|
||||
|
@ -54,7 +54,7 @@ func TestPacketHandlerMapAddWithClientChosenConnID(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestPacketHandlerMapRetire(t *testing.T) {
|
||||
m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
|
||||
m := newPacketHandlerMap(nil, utils.DefaultLogger)
|
||||
dur := scaleDuration(10 * time.Millisecond)
|
||||
m.deleteRetiredConnsAfter = dur
|
||||
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4})
|
||||
|
@ -76,7 +76,7 @@ func TestPacketHandlerMapRetire(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestPacketHandlerMapAddGetRemoveResetTokens(t *testing.T) {
|
||||
m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
|
||||
m := newPacketHandlerMap(nil, utils.DefaultLogger)
|
||||
token := protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf}
|
||||
handler := &mockPacketHandler{}
|
||||
m.AddResetToken(token, handler)
|
||||
|
@ -88,43 +88,12 @@ func TestPacketHandlerMapAddGetRemoveResetTokens(t *testing.T) {
|
|||
require.False(t, ok)
|
||||
}
|
||||
|
||||
func TestPacketHandlerMapGenerateStatelessResetToken(t *testing.T) {
|
||||
t.Run("no key", func(t *testing.T) {
|
||||
m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
|
||||
b := make([]byte, 8)
|
||||
rand.Read(b)
|
||||
connID := protocol.ParseConnectionID(b)
|
||||
tokens := make(map[protocol.StatelessResetToken]struct{})
|
||||
for i := 0; i < 100; i++ {
|
||||
token := m.GetStatelessResetToken(connID)
|
||||
require.NotZero(t, token)
|
||||
if _, ok := tokens[token]; ok {
|
||||
t.Fatalf("token %s already exists", token)
|
||||
}
|
||||
tokens[token] = struct{}{}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("with key", func(t *testing.T) {
|
||||
var key StatelessResetKey
|
||||
rand.Read(key[:])
|
||||
m := newPacketHandlerMap(&key, nil, utils.DefaultLogger)
|
||||
b := make([]byte, 8)
|
||||
rand.Read(b)
|
||||
connID := protocol.ParseConnectionID(b)
|
||||
token := m.GetStatelessResetToken(connID)
|
||||
require.NotZero(t, token)
|
||||
require.Equal(t, token, m.GetStatelessResetToken(connID))
|
||||
// generate a new connection ID
|
||||
rand.Read(b)
|
||||
connID2 := protocol.ParseConnectionID(b)
|
||||
require.NotEqual(t, token, m.GetStatelessResetToken(connID2))
|
||||
})
|
||||
}
|
||||
|
||||
func TestPacketHandlerMapReplaceWithLocalClosed(t *testing.T) {
|
||||
var closePackets []closePacket
|
||||
m := newPacketHandlerMap(nil, func(p closePacket) { closePackets = append(closePackets, p) }, utils.DefaultLogger)
|
||||
m := newPacketHandlerMap(
|
||||
func(p closePacket) { closePackets = append(closePackets, p) },
|
||||
utils.DefaultLogger,
|
||||
)
|
||||
dur := scaleDuration(10 * time.Millisecond)
|
||||
m.deleteRetiredConnsAfter = dur
|
||||
|
||||
|
@ -150,7 +119,10 @@ func TestPacketHandlerMapReplaceWithLocalClosed(t *testing.T) {
|
|||
|
||||
func TestPacketHandlerMapReplaceWithRemoteClosed(t *testing.T) {
|
||||
var closePackets []closePacket
|
||||
m := newPacketHandlerMap(nil, func(p closePacket) { closePackets = append(closePackets, p) }, utils.DefaultLogger)
|
||||
m := newPacketHandlerMap(
|
||||
func(p closePacket) { closePackets = append(closePackets, p) },
|
||||
utils.DefaultLogger,
|
||||
)
|
||||
dur := scaleDuration(50 * time.Millisecond)
|
||||
m.deleteRetiredConnsAfter = dur
|
||||
|
||||
|
@ -173,7 +145,7 @@ func TestPacketHandlerMapReplaceWithRemoteClosed(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestPacketHandlerMapClose(t *testing.T) {
|
||||
m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
|
||||
m := newPacketHandlerMap(nil, utils.DefaultLogger)
|
||||
testErr := errors.New("shutdown")
|
||||
const numConns = 10
|
||||
destroyChan := make(chan error, 2*numConns)
|
||||
|
|
|
@ -73,6 +73,7 @@ type baseServer struct {
|
|||
maxTokenAge time.Duration
|
||||
|
||||
connIDGenerator ConnectionIDGenerator
|
||||
statelessResetter *statelessResetter
|
||||
connHandler packetHandlerManager
|
||||
onClose func()
|
||||
|
||||
|
@ -95,7 +96,7 @@ type baseServer struct {
|
|||
protocol.ConnectionID, /* destination connection ID */
|
||||
protocol.ConnectionID, /* source connection ID */
|
||||
ConnectionIDGenerator,
|
||||
protocol.StatelessResetToken,
|
||||
*statelessResetter,
|
||||
*Config,
|
||||
*tls.Config,
|
||||
*handshake.TokenGenerator,
|
||||
|
@ -248,6 +249,7 @@ func newServer(
|
|||
conn rawConn,
|
||||
connHandler packetHandlerManager,
|
||||
connIDGenerator ConnectionIDGenerator,
|
||||
statelessResetter *statelessResetter,
|
||||
connContext func(context.Context) context.Context,
|
||||
tlsConf *tls.Config,
|
||||
config *Config,
|
||||
|
@ -268,6 +270,7 @@ func newServer(
|
|||
maxTokenAge: maxTokenAge,
|
||||
verifySourceAddress: verifySourceAddress,
|
||||
connIDGenerator: connIDGenerator,
|
||||
statelessResetter: statelessResetter,
|
||||
connHandler: connHandler,
|
||||
connQueue: make(chan quicConn, protocol.MaxAcceptQueueSize),
|
||||
errorChan: make(chan struct{}),
|
||||
|
@ -707,7 +710,7 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
|
|||
hdr.SrcConnectionID,
|
||||
connID,
|
||||
s.connIDGenerator,
|
||||
s.connHandler.GetStatelessResetToken(connID),
|
||||
s.statelessResetter,
|
||||
config,
|
||||
s.tlsConf,
|
||||
s.tokenGenerator,
|
||||
|
|
|
@ -297,7 +297,7 @@ var _ = Describe("Server", func() {
|
|||
destConnID protocol.ConnectionID,
|
||||
srcConnID protocol.ConnectionID,
|
||||
_ ConnectionIDGenerator,
|
||||
tokenP protocol.StatelessResetToken,
|
||||
_ *statelessResetter,
|
||||
_ *Config,
|
||||
_ *tls.Config,
|
||||
_ *handshake.TokenGenerator,
|
||||
|
@ -314,7 +314,6 @@ var _ = Describe("Server", func() {
|
|||
Expect(srcConnID).ToNot(Equal(hdr.DestConnectionID))
|
||||
Expect(srcConnID).ToNot(Equal(hdr.SrcConnectionID))
|
||||
newConnID = srcConnID
|
||||
Expect(tokenP).To(Equal(token))
|
||||
conn.EXPECT().handlePacket(p)
|
||||
conn.EXPECT().run().Do(func() error { close(run); return nil })
|
||||
conn.EXPECT().Context().Return(context.Background())
|
||||
|
@ -322,7 +321,6 @@ var _ = Describe("Server", func() {
|
|||
return conn
|
||||
}
|
||||
phm.EXPECT().Get(connID)
|
||||
phm.EXPECT().GetStatelessResetToken(gomock.Any()).Return(token)
|
||||
phm.EXPECT().AddWithConnID(connID, gomock.Any(), gomock.Any()).DoAndReturn(func(_, cid protocol.ConnectionID, h packetHandler) bool {
|
||||
Expect(cid).To(Equal(newConnID))
|
||||
return true
|
||||
|
@ -500,7 +498,7 @@ var _ = Describe("Server", func() {
|
|||
destConnID protocol.ConnectionID,
|
||||
srcConnID protocol.ConnectionID,
|
||||
_ ConnectionIDGenerator,
|
||||
tokenP protocol.StatelessResetToken,
|
||||
_ *statelessResetter,
|
||||
_ *Config,
|
||||
_ *tls.Config,
|
||||
_ *handshake.TokenGenerator,
|
||||
|
@ -517,7 +515,6 @@ var _ = Describe("Server", func() {
|
|||
Expect(srcConnID).ToNot(Equal(hdr.DestConnectionID))
|
||||
Expect(srcConnID).ToNot(Equal(hdr.SrcConnectionID))
|
||||
newConnID = srcConnID
|
||||
Expect(tokenP).To(Equal(token))
|
||||
conn.EXPECT().handlePacket(p)
|
||||
conn.EXPECT().run().Do(func() error { close(run); return nil })
|
||||
conn.EXPECT().Context().Return(context.Background())
|
||||
|
@ -526,7 +523,6 @@ var _ = Describe("Server", func() {
|
|||
}
|
||||
gomock.InOrder(
|
||||
phm.EXPECT().Get(connID),
|
||||
phm.EXPECT().GetStatelessResetToken(gomock.Any()).Return(token),
|
||||
phm.EXPECT().AddWithConnID(connID, gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, h packetHandler) bool {
|
||||
Expect(c).To(Equal(newConnID))
|
||||
return true
|
||||
|
@ -553,7 +549,6 @@ var _ = Describe("Server", func() {
|
|||
serv.verifySourceAddress = func(net.Addr) bool { return false }
|
||||
|
||||
phm.EXPECT().Get(gomock.Any()).AnyTimes()
|
||||
phm.EXPECT().GetStatelessResetToken(gomock.Any()).AnyTimes()
|
||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true).AnyTimes()
|
||||
|
||||
acceptConn := make(chan struct{})
|
||||
|
@ -569,7 +564,7 @@ var _ = Describe("Server", func() {
|
|||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ ConnectionIDGenerator,
|
||||
_ protocol.StatelessResetToken,
|
||||
_ *statelessResetter,
|
||||
_ *Config,
|
||||
_ *tls.Config,
|
||||
_ *handshake.TokenGenerator,
|
||||
|
@ -625,7 +620,7 @@ var _ = Describe("Server", func() {
|
|||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ ConnectionIDGenerator,
|
||||
_ protocol.StatelessResetToken,
|
||||
_ *statelessResetter,
|
||||
_ *Config,
|
||||
_ *tls.Config,
|
||||
_ *handshake.TokenGenerator,
|
||||
|
@ -645,7 +640,6 @@ var _ = Describe("Server", func() {
|
|||
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9})
|
||||
p := getInitial(connID)
|
||||
phm.EXPECT().Get(connID)
|
||||
phm.EXPECT().GetStatelessResetToken(gomock.Any())
|
||||
phm.EXPECT().AddWithConnID(connID, gomock.Any(), gomock.Any()).Return(false) // connection ID collision
|
||||
Expect(serv.handlePacketImpl(p)).To(BeTrue())
|
||||
Eventually(done).Should(BeClosed())
|
||||
|
@ -657,7 +651,6 @@ var _ = Describe("Server", func() {
|
|||
serv.verifySourceAddress = func(net.Addr) bool { return !limiter.Allow() }
|
||||
|
||||
phm.EXPECT().Get(gomock.Any()).AnyTimes()
|
||||
phm.EXPECT().GetStatelessResetToken(gomock.Any()).AnyTimes()
|
||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true).AnyTimes()
|
||||
|
||||
connChan := make(chan *MockQUICConn, 1)
|
||||
|
@ -675,7 +668,7 @@ var _ = Describe("Server", func() {
|
|||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ ConnectionIDGenerator,
|
||||
_ protocol.StatelessResetToken,
|
||||
_ *statelessResetter,
|
||||
_ *Config,
|
||||
_ *tls.Config,
|
||||
_ *handshake.TokenGenerator,
|
||||
|
@ -737,7 +730,7 @@ var _ = Describe("Server", func() {
|
|||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ ConnectionIDGenerator,
|
||||
_ protocol.StatelessResetToken,
|
||||
_ *statelessResetter,
|
||||
_ *Config,
|
||||
_ *tls.Config,
|
||||
_ *handshake.TokenGenerator,
|
||||
|
@ -769,7 +762,6 @@ var _ = Describe("Server", func() {
|
|||
|
||||
done := make(chan struct{})
|
||||
phm.EXPECT().Get(gomock.Any())
|
||||
phm.EXPECT().GetStatelessResetToken(gomock.Any())
|
||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, _ packetHandler) bool {
|
||||
close(done)
|
||||
return true
|
||||
|
@ -972,7 +964,7 @@ var _ = Describe("Server", func() {
|
|||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ ConnectionIDGenerator,
|
||||
_ protocol.StatelessResetToken,
|
||||
_ *statelessResetter,
|
||||
conf *Config,
|
||||
_ *tls.Config,
|
||||
_ *handshake.TokenGenerator,
|
||||
|
@ -990,7 +982,6 @@ var _ = Describe("Server", func() {
|
|||
return conn
|
||||
}
|
||||
phm.EXPECT().Get(gomock.Any())
|
||||
phm.EXPECT().GetStatelessResetToken(gomock.Any())
|
||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true)
|
||||
serv.handleInitialImpl(
|
||||
receivedPacket{buffer: getPacketBuffer()},
|
||||
|
@ -1040,7 +1031,7 @@ var _ = Describe("Server", func() {
|
|||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ ConnectionIDGenerator,
|
||||
_ protocol.StatelessResetToken,
|
||||
_ *statelessResetter,
|
||||
conf *Config,
|
||||
_ *tls.Config,
|
||||
_ *handshake.TokenGenerator,
|
||||
|
@ -1057,7 +1048,6 @@ var _ = Describe("Server", func() {
|
|||
return conn
|
||||
}
|
||||
phm.EXPECT().Get(gomock.Any())
|
||||
phm.EXPECT().GetStatelessResetToken(gomock.Any())
|
||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true)
|
||||
serv.handleInitialImpl(
|
||||
receivedPacket{buffer: getPacketBuffer()},
|
||||
|
@ -1111,7 +1101,7 @@ var _ = Describe("Server", func() {
|
|||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ ConnectionIDGenerator,
|
||||
_ protocol.StatelessResetToken,
|
||||
_ *statelessResetter,
|
||||
_ *Config,
|
||||
_ *tls.Config,
|
||||
_ *handshake.TokenGenerator,
|
||||
|
@ -1127,7 +1117,6 @@ var _ = Describe("Server", func() {
|
|||
return conn
|
||||
}
|
||||
phm.EXPECT().Get(gomock.Any())
|
||||
phm.EXPECT().GetStatelessResetToken(gomock.Any())
|
||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true)
|
||||
serv.handleInitialImpl(
|
||||
receivedPacket{buffer: getPacketBuffer()},
|
||||
|
@ -1182,7 +1171,7 @@ var _ = Describe("Server", func() {
|
|||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ ConnectionIDGenerator,
|
||||
_ protocol.StatelessResetToken,
|
||||
_ *statelessResetter,
|
||||
_ *Config,
|
||||
_ *tls.Config,
|
||||
_ *handshake.TokenGenerator,
|
||||
|
@ -1198,7 +1187,6 @@ var _ = Describe("Server", func() {
|
|||
return conn
|
||||
}
|
||||
phm.EXPECT().Get(gomock.Any())
|
||||
phm.EXPECT().GetStatelessResetToken(gomock.Any())
|
||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true)
|
||||
serv.baseServer.handleInitialImpl(
|
||||
receivedPacket{buffer: getPacketBuffer()},
|
||||
|
@ -1224,7 +1212,7 @@ var _ = Describe("Server", func() {
|
|||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ ConnectionIDGenerator,
|
||||
_ protocol.StatelessResetToken,
|
||||
_ *statelessResetter,
|
||||
_ *Config,
|
||||
_ *tls.Config,
|
||||
_ *handshake.TokenGenerator,
|
||||
|
@ -1244,7 +1232,6 @@ var _ = Describe("Server", func() {
|
|||
}
|
||||
|
||||
phm.EXPECT().Get(gomock.Any()).AnyTimes()
|
||||
phm.EXPECT().GetStatelessResetToken(gomock.Any()).Times(protocol.MaxAcceptQueueSize)
|
||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true).Times(protocol.MaxAcceptQueueSize)
|
||||
for i := 0; i < protocol.MaxAcceptQueueSize; i++ {
|
||||
conn := NewMockQUICConn(mockCtrl)
|
||||
|
@ -1257,7 +1244,6 @@ var _ = Describe("Server", func() {
|
|||
wg.Add(1)
|
||||
|
||||
rejected := make(chan struct{})
|
||||
phm.EXPECT().GetStatelessResetToken(gomock.Any())
|
||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true)
|
||||
conn := NewMockQUICConn(mockCtrl)
|
||||
conn.EXPECT().closeWithTransportError(ConnectionRefused).Do(func(qerr.TransportErrorCode) {
|
||||
|
@ -1284,7 +1270,7 @@ var _ = Describe("Server", func() {
|
|||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ ConnectionIDGenerator,
|
||||
_ protocol.StatelessResetToken,
|
||||
_ *statelessResetter,
|
||||
_ *Config,
|
||||
_ *tls.Config,
|
||||
_ *handshake.TokenGenerator,
|
||||
|
@ -1302,7 +1288,6 @@ var _ = Describe("Server", func() {
|
|||
}
|
||||
|
||||
phm.EXPECT().Get(gomock.Any())
|
||||
phm.EXPECT().GetStatelessResetToken(gomock.Any())
|
||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true)
|
||||
serv.baseServer.handlePacket(p)
|
||||
// make sure there are no Write calls on the packet conn
|
||||
|
@ -1407,7 +1392,7 @@ var _ = Describe("Server", func() {
|
|||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ ConnectionIDGenerator,
|
||||
_ protocol.StatelessResetToken,
|
||||
_ *statelessResetter,
|
||||
_ *Config,
|
||||
_ *tls.Config,
|
||||
_ *handshake.TokenGenerator,
|
||||
|
@ -1433,7 +1418,6 @@ var _ = Describe("Server", func() {
|
|||
}
|
||||
|
||||
phm.EXPECT().Get(connID)
|
||||
phm.EXPECT().GetStatelessResetToken(gomock.Any())
|
||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true)
|
||||
serv.handlePacket(initial)
|
||||
Eventually(called).Should(BeClosed())
|
||||
|
|
42
stateless_reset.go
Normal file
42
stateless_reset.go
Normal file
|
@ -0,0 +1,42 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"hash"
|
||||
"sync"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
type statelessResetter struct {
|
||||
mx sync.Mutex
|
||||
h hash.Hash
|
||||
}
|
||||
|
||||
// newStatelessRetter creates a new stateless reset generator.
|
||||
// It is valid to use a nil key. In that case, a random key will be used.
|
||||
// This makes is impossible for on-path attackers to shut down established connections.
|
||||
func newStatelessResetter(key *StatelessResetKey) *statelessResetter {
|
||||
var h hash.Hash
|
||||
if key != nil {
|
||||
h = hmac.New(sha256.New, key[:])
|
||||
} else {
|
||||
b := make([]byte, 32)
|
||||
_, _ = rand.Read(b)
|
||||
h = hmac.New(sha256.New, b)
|
||||
}
|
||||
return &statelessResetter{h: h}
|
||||
}
|
||||
|
||||
func (r *statelessResetter) GetStatelessResetToken(connID protocol.ConnectionID) protocol.StatelessResetToken {
|
||||
r.mx.Lock()
|
||||
defer r.mx.Unlock()
|
||||
|
||||
var token protocol.StatelessResetToken
|
||||
r.h.Write(connID.Bytes())
|
||||
copy(token[:], r.h.Sum(nil))
|
||||
r.h.Reset()
|
||||
return token
|
||||
}
|
42
stateless_reset_test.go
Normal file
42
stateless_reset_test.go
Normal file
|
@ -0,0 +1,42 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"testing"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestStatelessResetter(t *testing.T) {
|
||||
t.Run("no key", func(t *testing.T) {
|
||||
r1 := newStatelessResetter(nil)
|
||||
r2 := newStatelessResetter(nil)
|
||||
for i := 0; i < 100; i++ {
|
||||
b := make([]byte, 15)
|
||||
rand.Read(b)
|
||||
connID := protocol.ParseConnectionID(b)
|
||||
t1 := r1.GetStatelessResetToken(connID)
|
||||
t2 := r2.GetStatelessResetToken(connID)
|
||||
require.NotZero(t, t1)
|
||||
require.NotZero(t, t2)
|
||||
require.NotEqual(t, t1, t2)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("with key", func(t *testing.T) {
|
||||
var key StatelessResetKey
|
||||
rand.Read(key[:])
|
||||
m := newStatelessResetter(&key)
|
||||
b := make([]byte, 8)
|
||||
rand.Read(b)
|
||||
connID := protocol.ParseConnectionID(b)
|
||||
token := m.GetStatelessResetToken(connID)
|
||||
require.NotZero(t, token)
|
||||
require.Equal(t, token, m.GetStatelessResetToken(connID))
|
||||
// generate a new connection ID
|
||||
rand.Read(b)
|
||||
connID2 := protocol.ParseConnectionID(b)
|
||||
require.NotEqual(t, token, m.GetStatelessResetToken(connID2))
|
||||
})
|
||||
}
|
19
transport.go
19
transport.go
|
@ -116,6 +116,7 @@ type Transport struct {
|
|||
// Set in init.
|
||||
// If no ConnectionIDGenerator is set, this is set to a default.
|
||||
connIDGenerator ConnectionIDGenerator
|
||||
statelessResetter *statelessResetter
|
||||
|
||||
server *baseServer
|
||||
|
||||
|
@ -183,6 +184,7 @@ func (t *Transport) createServer(tlsConf *tls.Config, conf *Config, allow0RTT bo
|
|||
t.conn,
|
||||
t.handlerMap,
|
||||
t.connIDGenerator,
|
||||
t.statelessResetter,
|
||||
t.ConnContext,
|
||||
tlsConf,
|
||||
conf,
|
||||
|
@ -222,7 +224,17 @@ func (t *Transport) dial(ctx context.Context, addr net.Addr, host string, tlsCon
|
|||
}
|
||||
tlsConf = tlsConf.Clone()
|
||||
setTLSConfigServerName(tlsConf, addr, host)
|
||||
return dial(ctx, newSendConn(t.conn, addr, packetInfo{}, utils.DefaultLogger), t.connIDGenerator, t.handlerMap, tlsConf, conf, onClose, use0RTT)
|
||||
return dial(
|
||||
ctx,
|
||||
newSendConn(t.conn, addr, packetInfo{}, utils.DefaultLogger),
|
||||
t.connIDGenerator,
|
||||
t.statelessResetter,
|
||||
t.handlerMap,
|
||||
tlsConf,
|
||||
conf,
|
||||
onClose,
|
||||
use0RTT,
|
||||
)
|
||||
}
|
||||
|
||||
func (t *Transport) init(allowZeroLengthConnIDs bool) error {
|
||||
|
@ -242,7 +254,7 @@ func (t *Transport) init(allowZeroLengthConnIDs bool) error {
|
|||
t.logger = utils.DefaultLogger // TODO: make this configurable
|
||||
t.conn = conn
|
||||
if t.handlerMap == nil { // allows mocking the handlerMap in tests
|
||||
t.handlerMap = newPacketHandlerMap(t.StatelessResetKey, t.enqueueClosePacket, t.logger)
|
||||
t.handlerMap = newPacketHandlerMap(t.enqueueClosePacket, t.logger)
|
||||
}
|
||||
t.listening = make(chan struct{})
|
||||
|
||||
|
@ -268,6 +280,7 @@ func (t *Transport) init(allowZeroLengthConnIDs bool) error {
|
|||
t.connIDLen = connIDLen
|
||||
t.connIDGenerator = &protocol.DefaultConnectionIDGenerator{ConnLen: t.connIDLen}
|
||||
}
|
||||
t.statelessResetter = newStatelessResetter(t.StatelessResetKey)
|
||||
|
||||
go t.listen(conn)
|
||||
go t.runSendQueue()
|
||||
|
@ -478,7 +491,7 @@ func (t *Transport) sendStatelessReset(p receivedPacket) {
|
|||
t.logger.Errorf("error parsing connection ID on packet from %s: %s", p.remoteAddr, err)
|
||||
return
|
||||
}
|
||||
token := t.handlerMap.GetStatelessResetToken(connID)
|
||||
token := t.statelessResetter.GetStatelessResetToken(connID)
|
||||
t.logger.Debugf("Sending stateless reset to %s (connection ID: %s). Token: %#x", p.remoteAddr, connID, token)
|
||||
data := make([]byte, protocol.MinStatelessResetSize-16, protocol.MinStatelessResetSize)
|
||||
rand.Read(data)
|
||||
|
|
|
@ -254,8 +254,6 @@ func TestTransportStatelessResetSending(t *testing.T) {
|
|||
// but a stateless reset is sent for packets larger than MinStatelessResetSize
|
||||
phm.EXPECT().Get(connID) // no handler for this connection ID
|
||||
phm.EXPECT().GetByResetToken(gomock.Any())
|
||||
token := protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
|
||||
phm.EXPECT().GetStatelessResetToken(connID).Return(token)
|
||||
_, err = conn.WriteTo(append(b, make([]byte, protocol.MinStatelessResetSize-len(b)+1)...), tr.Conn.LocalAddr())
|
||||
require.NoError(t, err)
|
||||
conn.SetReadDeadline(time.Now().Add(time.Second))
|
||||
|
@ -263,7 +261,8 @@ func TestTransportStatelessResetSending(t *testing.T) {
|
|||
n, addr, err := conn.ReadFrom(p)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, addr, tr.Conn.LocalAddr())
|
||||
require.Contains(t, string(p[:n]), string(token[:]))
|
||||
srt := newStatelessResetter(tr.StatelessResetKey).GetStatelessResetToken(connID)
|
||||
require.Contains(t, string(p[:n]), string(srt[:]))
|
||||
}
|
||||
|
||||
func TestTransportDropsUnparseableQUICPackets(t *testing.T) {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue