mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-04 12:47:36 +03:00
2767 lines
99 KiB
Go
2767 lines
99 KiB
Go
package quic
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto/rand"
|
|
"crypto/tls"
|
|
"errors"
|
|
"net"
|
|
"net/netip"
|
|
"strconv"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/quic-go/quic-go/internal/ackhandler"
|
|
"github.com/quic-go/quic-go/internal/handshake"
|
|
"github.com/quic-go/quic-go/internal/mocks"
|
|
mockackhandler "github.com/quic-go/quic-go/internal/mocks/ackhandler"
|
|
mocklogging "github.com/quic-go/quic-go/internal/mocks/logging"
|
|
"github.com/quic-go/quic-go/internal/protocol"
|
|
"github.com/quic-go/quic-go/internal/qerr"
|
|
"github.com/quic-go/quic-go/internal/utils"
|
|
"github.com/quic-go/quic-go/internal/wire"
|
|
"github.com/quic-go/quic-go/logging"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"go.uber.org/mock/gomock"
|
|
)
|
|
|
|
// testConnectionOpt mutates a *connection created by one of the test
// constructors, typically by swapping one of its components for a mock.
type testConnectionOpt func(conn *connection)
|
|
|
|
func connectionOptCryptoSetup(cs *mocks.MockCryptoSetup) testConnectionOpt {
|
|
return func(conn *connection) { conn.cryptoStreamHandler = cs }
|
|
}
|
|
|
|
func connectionOptStreamManager(sm *MockStreamManager) testConnectionOpt {
|
|
return func(conn *connection) { conn.streamsMap = sm }
|
|
}
|
|
|
|
func connectionOptConnFlowController(cfc *mocks.MockConnectionFlowController) testConnectionOpt {
|
|
return func(conn *connection) { conn.connFlowController = cfc }
|
|
}
|
|
|
|
func connectionOptTracer(tr *logging.ConnectionTracer) testConnectionOpt {
|
|
return func(conn *connection) { conn.tracer = tr }
|
|
}
|
|
|
|
func connectionOptSentPacketHandler(sph ackhandler.SentPacketHandler) testConnectionOpt {
|
|
return func(conn *connection) { conn.sentPacketHandler = sph }
|
|
}
|
|
|
|
func connectionOptReceivedPacketHandler(rph ackhandler.ReceivedPacketHandler) testConnectionOpt {
|
|
return func(conn *connection) { conn.receivedPacketHandler = rph }
|
|
}
|
|
|
|
func connectionOptUnpacker(u unpacker) testConnectionOpt {
|
|
return func(conn *connection) { conn.unpacker = u }
|
|
}
|
|
|
|
func connectionOptSender(s sender) testConnectionOpt {
|
|
return func(conn *connection) { conn.sendQueue = s }
|
|
}
|
|
|
|
func connectionOptHandshakeConfirmed() testConnectionOpt {
|
|
return func(conn *connection) {
|
|
conn.handshakeComplete = true
|
|
conn.handshakeConfirmed = true
|
|
}
|
|
}
|
|
|
|
func connectionOptRTT(rtt time.Duration) testConnectionOpt {
|
|
var rttStats utils.RTTStats
|
|
rttStats.UpdateRTT(rtt, 0)
|
|
return func(conn *connection) { conn.rttStats = &rttStats }
|
|
}
|
|
|
|
func connectionOptRetrySrcConnID(rcid protocol.ConnectionID) testConnectionOpt {
|
|
return func(conn *connection) { conn.retrySrcConnID = &rcid }
|
|
}
|
|
|
|
type testConnection struct {
|
|
conn *connection
|
|
connRunner *MockConnRunner
|
|
sendConn *MockSendConn
|
|
packer *MockPacker
|
|
destConnID protocol.ConnectionID
|
|
srcConnID protocol.ConnectionID
|
|
}
|
|
|
|
func newServerTestConnection(
|
|
t *testing.T,
|
|
mockCtrl *gomock.Controller,
|
|
config *Config,
|
|
gso bool,
|
|
opts ...testConnectionOpt,
|
|
) *testConnection {
|
|
if mockCtrl == nil {
|
|
mockCtrl = gomock.NewController(t)
|
|
}
|
|
remoteAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 4321}
|
|
localAddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1234}
|
|
connRunner := NewMockConnRunner(mockCtrl)
|
|
sendConn := NewMockSendConn(mockCtrl)
|
|
sendConn.EXPECT().capabilities().Return(connCapabilities{GSO: gso}).AnyTimes()
|
|
sendConn.EXPECT().RemoteAddr().Return(remoteAddr).AnyTimes()
|
|
sendConn.EXPECT().LocalAddr().Return(localAddr).AnyTimes()
|
|
packer := NewMockPacker(mockCtrl)
|
|
b := make([]byte, 12)
|
|
rand.Read(b)
|
|
origDestConnID := protocol.ParseConnectionID(b[:6])
|
|
srcConnID := protocol.ParseConnectionID(b[6:12])
|
|
ctx, cancel := context.WithCancelCause(context.Background())
|
|
if config == nil {
|
|
config = &Config{DisablePathMTUDiscovery: true}
|
|
}
|
|
conn := newConnection(
|
|
ctx,
|
|
cancel,
|
|
sendConn,
|
|
connRunner,
|
|
origDestConnID,
|
|
nil,
|
|
protocol.ConnectionID{},
|
|
protocol.ConnectionID{},
|
|
srcConnID,
|
|
&protocol.DefaultConnectionIDGenerator{},
|
|
newStatelessResetter(nil),
|
|
populateConfig(config),
|
|
&tls.Config{},
|
|
handshake.NewTokenGenerator(handshake.TokenProtectorKey{}),
|
|
false,
|
|
nil,
|
|
utils.DefaultLogger,
|
|
protocol.Version1,
|
|
).(*connection)
|
|
conn.packer = packer
|
|
for _, opt := range opts {
|
|
opt(conn)
|
|
}
|
|
return &testConnection{
|
|
conn: conn,
|
|
connRunner: connRunner,
|
|
sendConn: sendConn,
|
|
packer: packer,
|
|
destConnID: origDestConnID,
|
|
srcConnID: srcConnID,
|
|
}
|
|
}
|
|
|
|
func newClientTestConnection(
|
|
t *testing.T,
|
|
mockCtrl *gomock.Controller,
|
|
config *Config,
|
|
enable0RTT bool,
|
|
opts ...testConnectionOpt,
|
|
) *testConnection {
|
|
if mockCtrl == nil {
|
|
mockCtrl = gomock.NewController(t)
|
|
}
|
|
remoteAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 4321}
|
|
localAddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1234}
|
|
connRunner := NewMockConnRunner(mockCtrl)
|
|
sendConn := NewMockSendConn(mockCtrl)
|
|
sendConn.EXPECT().capabilities().Return(connCapabilities{}).AnyTimes()
|
|
sendConn.EXPECT().RemoteAddr().Return(remoteAddr).AnyTimes()
|
|
sendConn.EXPECT().LocalAddr().Return(localAddr).AnyTimes()
|
|
packer := NewMockPacker(mockCtrl)
|
|
b := make([]byte, 12)
|
|
rand.Read(b)
|
|
destConnID := protocol.ParseConnectionID(b[:6])
|
|
srcConnID := protocol.ParseConnectionID(b[6:12])
|
|
if config == nil {
|
|
config = &Config{DisablePathMTUDiscovery: true}
|
|
}
|
|
conn := newClientConnection(
|
|
context.Background(),
|
|
sendConn,
|
|
connRunner,
|
|
destConnID,
|
|
srcConnID,
|
|
&protocol.DefaultConnectionIDGenerator{},
|
|
newStatelessResetter(nil),
|
|
populateConfig(config),
|
|
&tls.Config{ServerName: "quic-go.net"},
|
|
0,
|
|
enable0RTT,
|
|
false,
|
|
nil,
|
|
utils.DefaultLogger,
|
|
protocol.Version1,
|
|
).(*connection)
|
|
conn.packer = packer
|
|
for _, opt := range opts {
|
|
opt(conn)
|
|
}
|
|
return &testConnection{
|
|
conn: conn,
|
|
connRunner: connRunner,
|
|
sendConn: sendConn,
|
|
packer: packer,
|
|
destConnID: destConnID,
|
|
srcConnID: srcConnID,
|
|
}
|
|
}
|
|
|
|
func TestConnectionHandleReceiveStreamFrames(t *testing.T) {
|
|
const streamID protocol.StreamID = 5
|
|
now := time.Now()
|
|
connID := protocol.ConnectionID{}
|
|
f := &wire.StreamFrame{StreamID: streamID, Data: []byte("foobar")}
|
|
rsf := &wire.ResetStreamFrame{StreamID: streamID, ErrorCode: 42, FinalSize: 1337}
|
|
sdbf := &wire.StreamDataBlockedFrame{StreamID: streamID, MaximumStreamData: 1337}
|
|
|
|
t.Run("for existing and new streams", func(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
streamsMap := NewMockStreamManager(mockCtrl)
|
|
tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptStreamManager(streamsMap))
|
|
str := NewMockReceiveStreamI(mockCtrl)
|
|
// STREAM frame
|
|
streamsMap.EXPECT().GetOrOpenReceiveStream(streamID).Return(str, nil)
|
|
str.EXPECT().handleStreamFrame(f, now)
|
|
require.NoError(t, tc.conn.handleFrame(f, protocol.Encryption1RTT, connID, now))
|
|
// RESET_STREAM frame
|
|
streamsMap.EXPECT().GetOrOpenReceiveStream(streamID).Return(str, nil)
|
|
str.EXPECT().handleResetStreamFrame(rsf, now)
|
|
require.NoError(t, tc.conn.handleFrame(rsf, protocol.Encryption1RTT, connID, now))
|
|
// STREAM_DATA_BLOCKED frames are not passed to the stream
|
|
streamsMap.EXPECT().GetOrOpenReceiveStream(streamID).Return(str, nil)
|
|
require.NoError(t, tc.conn.handleFrame(sdbf, protocol.Encryption1RTT, connID, now))
|
|
})
|
|
|
|
t.Run("for closed streams", func(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
streamsMap := NewMockStreamManager(mockCtrl)
|
|
tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptStreamManager(streamsMap))
|
|
// STREAM frame
|
|
streamsMap.EXPECT().GetOrOpenReceiveStream(streamID).Return(nil, nil)
|
|
require.NoError(t, tc.conn.handleFrame(f, protocol.Encryption1RTT, connID, now))
|
|
// RESET_STREAM frame
|
|
streamsMap.EXPECT().GetOrOpenReceiveStream(streamID).Return(nil, nil)
|
|
require.NoError(t, tc.conn.handleFrame(rsf, protocol.Encryption1RTT, connID, now))
|
|
// STREAM_DATA_BLOCKED frames are not passed to the stream
|
|
streamsMap.EXPECT().GetOrOpenReceiveStream(streamID).Return(nil, nil)
|
|
require.NoError(t, tc.conn.handleFrame(sdbf, protocol.Encryption1RTT, connID, now))
|
|
})
|
|
|
|
t.Run("for invalid streams", func(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
streamsMap := NewMockStreamManager(mockCtrl)
|
|
tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptStreamManager(streamsMap))
|
|
testErr := errors.New("test err")
|
|
// STREAM frame
|
|
streamsMap.EXPECT().GetOrOpenReceiveStream(streamID).Return(nil, testErr)
|
|
require.ErrorIs(t, tc.conn.handleFrame(f, protocol.Encryption1RTT, connID, now), testErr)
|
|
// RESET_STREAM frame
|
|
streamsMap.EXPECT().GetOrOpenReceiveStream(streamID).Return(nil, testErr)
|
|
require.ErrorIs(t, tc.conn.handleFrame(rsf, protocol.Encryption1RTT, connID, now), testErr)
|
|
// STREAM_DATA_BLOCKED frames are not passed to the stream
|
|
streamsMap.EXPECT().GetOrOpenReceiveStream(streamID).Return(nil, testErr)
|
|
require.ErrorIs(t, tc.conn.handleFrame(sdbf, protocol.Encryption1RTT, connID, now), testErr)
|
|
})
|
|
}
|
|
|
|
func TestConnectionHandleSendStreamFrames(t *testing.T) {
|
|
const streamID protocol.StreamID = 3
|
|
now := time.Now()
|
|
connID := protocol.ConnectionID{}
|
|
ss := &wire.StopSendingFrame{StreamID: streamID, ErrorCode: 42}
|
|
msd := &wire.MaxStreamDataFrame{StreamID: streamID, MaximumStreamData: 1337}
|
|
|
|
t.Run("for existing and new streams", func(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
streamsMap := NewMockStreamManager(mockCtrl)
|
|
tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptStreamManager(streamsMap))
|
|
str := NewMockSendStreamI(mockCtrl)
|
|
// STOP_SENDING frame
|
|
streamsMap.EXPECT().GetOrOpenSendStream(streamID).Return(str, nil)
|
|
str.EXPECT().handleStopSendingFrame(ss)
|
|
require.NoError(t, tc.conn.handleFrame(ss, protocol.Encryption1RTT, connID, now))
|
|
// MAX_STREAM_DATA frame
|
|
streamsMap.EXPECT().GetOrOpenSendStream(streamID).Return(str, nil)
|
|
str.EXPECT().updateSendWindow(msd.MaximumStreamData)
|
|
require.NoError(t, tc.conn.handleFrame(msd, protocol.Encryption1RTT, connID, now))
|
|
})
|
|
|
|
t.Run("for closed streams", func(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
streamsMap := NewMockStreamManager(mockCtrl)
|
|
tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptStreamManager(streamsMap))
|
|
// STOP_SENDING frame
|
|
streamsMap.EXPECT().GetOrOpenSendStream(streamID).Return(nil, nil)
|
|
require.NoError(t, tc.conn.handleFrame(ss, protocol.Encryption1RTT, connID, now))
|
|
// MAX_STREAM_DATA frame
|
|
streamsMap.EXPECT().GetOrOpenSendStream(streamID).Return(nil, nil)
|
|
require.NoError(t, tc.conn.handleFrame(msd, protocol.Encryption1RTT, connID, now))
|
|
})
|
|
|
|
t.Run("for invalid streams", func(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
streamsMap := NewMockStreamManager(mockCtrl)
|
|
tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptStreamManager(streamsMap))
|
|
testErr := errors.New("test err")
|
|
// STOP_SENDING frame
|
|
streamsMap.EXPECT().GetOrOpenSendStream(streamID).Return(nil, testErr)
|
|
require.ErrorIs(t, tc.conn.handleFrame(ss, protocol.Encryption1RTT, connID, now), testErr)
|
|
// MAX_STREAM_DATA frame
|
|
streamsMap.EXPECT().GetOrOpenSendStream(streamID).Return(nil, testErr)
|
|
require.ErrorIs(t, tc.conn.handleFrame(msd, protocol.Encryption1RTT, connID, now), testErr)
|
|
})
|
|
}
|
|
|
|
func TestConnectionHandleStreamNumFrames(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
streamsMap := NewMockStreamManager(mockCtrl)
|
|
tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptStreamManager(streamsMap))
|
|
now := time.Now()
|
|
connID := protocol.ConnectionID{}
|
|
// MAX_STREAMS frame
|
|
msf := &wire.MaxStreamsFrame{Type: protocol.StreamTypeBidi, MaxStreamNum: 10}
|
|
streamsMap.EXPECT().HandleMaxStreamsFrame(msf)
|
|
require.NoError(t, tc.conn.handleFrame(msf, protocol.Encryption1RTT, connID, now))
|
|
// STREAMS_BLOCKED frame
|
|
tc.conn.handleFrame(&wire.StreamsBlockedFrame{Type: protocol.StreamTypeBidi, StreamLimit: 1}, protocol.Encryption1RTT, connID, now)
|
|
}
|
|
|
|
func TestConnectionHandleConnectionFlowControlFrames(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
connFC := mocks.NewMockConnectionFlowController(mockCtrl)
|
|
tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptConnFlowController(connFC))
|
|
now := time.Now()
|
|
connID := protocol.ConnectionID{}
|
|
// MAX_DATA frame
|
|
connFC.EXPECT().UpdateSendWindow(protocol.ByteCount(1337))
|
|
require.NoError(t, tc.conn.handleFrame(&wire.MaxDataFrame{MaximumData: 1337}, protocol.Encryption1RTT, connID, now))
|
|
// DATA_BLOCKED frame
|
|
require.NoError(t, tc.conn.handleFrame(&wire.DataBlockedFrame{MaximumData: 1337}, protocol.Encryption1RTT, connID, now))
|
|
}
|
|
|
|
func TestConnectionOpenStreams(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
streamsMap := NewMockStreamManager(mockCtrl)
|
|
tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptStreamManager(streamsMap))
|
|
|
|
// using OpenStream
|
|
mstr := NewMockStreamI(mockCtrl)
|
|
streamsMap.EXPECT().OpenStream().Return(mstr, nil)
|
|
str, err := tc.conn.OpenStream()
|
|
require.NoError(t, err)
|
|
require.Equal(t, mstr, str)
|
|
|
|
// using OpenStreamSync
|
|
streamsMap.EXPECT().OpenStreamSync(context.Background()).Return(mstr, nil)
|
|
str, err = tc.conn.OpenStreamSync(context.Background())
|
|
require.NoError(t, err)
|
|
require.Equal(t, mstr, str)
|
|
|
|
// using OpenUniStream
|
|
streamsMap.EXPECT().OpenUniStream().Return(mstr, nil)
|
|
ustr, err := tc.conn.OpenUniStream()
|
|
require.NoError(t, err)
|
|
require.Equal(t, mstr, ustr)
|
|
|
|
// using OpenUniStreamSync
|
|
streamsMap.EXPECT().OpenUniStreamSync(context.Background()).Return(mstr, nil)
|
|
ustr, err = tc.conn.OpenUniStreamSync(context.Background())
|
|
require.NoError(t, err)
|
|
require.Equal(t, mstr, ustr)
|
|
}
|
|
|
|
func TestConnectionAcceptStreams(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
streamsMap := NewMockStreamManager(mockCtrl)
|
|
tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptStreamManager(streamsMap))
|
|
|
|
// bidirectional streams
|
|
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
|
|
defer cancel()
|
|
mstr := NewMockStreamI(mockCtrl)
|
|
streamsMap.EXPECT().AcceptStream(ctx).Return(mstr, nil)
|
|
str, err := tc.conn.AcceptStream(ctx)
|
|
require.NoError(t, err)
|
|
require.Equal(t, mstr, str)
|
|
|
|
// unidirectional streams
|
|
streamsMap.EXPECT().AcceptUniStream(ctx).Return(mstr, nil)
|
|
ustr, err := tc.conn.AcceptUniStream(ctx)
|
|
require.NoError(t, err)
|
|
require.Equal(t, mstr, ustr)
|
|
}
|
|
|
|
func TestConnectionServerInvalidFrames(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
tc := newServerTestConnection(t, mockCtrl, nil, false)
|
|
|
|
for _, test := range []struct {
|
|
Name string
|
|
Frame wire.Frame
|
|
}{
|
|
{Name: "NEW_TOKEN", Frame: &wire.NewTokenFrame{Token: []byte("foobar")}},
|
|
{Name: "HANDSHAKE_DONE", Frame: &wire.HandshakeDoneFrame{}},
|
|
{Name: "PATH_RESPONSE", Frame: &wire.PathResponseFrame{Data: [8]byte{1, 2, 3, 4, 5, 6, 7, 8}}},
|
|
} {
|
|
t.Run(test.Name, func(t *testing.T) {
|
|
require.ErrorIs(t,
|
|
tc.conn.handleFrame(test.Frame, protocol.Encryption1RTT, protocol.ConnectionID{}, time.Now()),
|
|
&qerr.TransportError{ErrorCode: qerr.ProtocolViolation},
|
|
)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestConnectionTransportError(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
|
|
tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptTracer(tr))
|
|
errChan := make(chan error, 1)
|
|
expectedErr := &qerr.TransportError{
|
|
ErrorCode: 1337,
|
|
FrameType: 42,
|
|
ErrorMessage: "test error",
|
|
}
|
|
tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
|
|
b := getPacketBuffer()
|
|
b.Data = append(b.Data, []byte("connection close")...)
|
|
tc.packer.EXPECT().PackConnectionClose(expectedErr, gomock.Any(), protocol.Version1).Return(&coalescedPacket{buffer: b}, nil)
|
|
tc.sendConn.EXPECT().Write([]byte("connection close"), gomock.Any(), gomock.Any())
|
|
tc.connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any()).AnyTimes()
|
|
gomock.InOrder(
|
|
tracer.EXPECT().ClosedConnection(expectedErr),
|
|
tracer.EXPECT().Close(),
|
|
)
|
|
|
|
go func() { errChan <- tc.conn.run() }()
|
|
tc.conn.closeLocal(expectedErr)
|
|
|
|
select {
|
|
case err := <-errChan:
|
|
require.ErrorIs(t, err, expectedErr)
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
|
|
// further calls to CloseWithError don't do anything
|
|
tc.conn.CloseWithError(42, "another error")
|
|
}
|
|
|
|
func TestConnectionApplicationClose(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
|
|
tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptTracer(tr))
|
|
errChan := make(chan error, 1)
|
|
expectedErr := &qerr.ApplicationError{
|
|
ErrorCode: 1337,
|
|
ErrorMessage: "test error",
|
|
}
|
|
tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
|
|
b := getPacketBuffer()
|
|
b.Data = append(b.Data, []byte("connection close")...)
|
|
tc.packer.EXPECT().PackApplicationClose(expectedErr, gomock.Any(), protocol.Version1).Return(&coalescedPacket{buffer: b}, nil)
|
|
tc.sendConn.EXPECT().Write([]byte("connection close"), gomock.Any(), gomock.Any())
|
|
tc.connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any()).AnyTimes()
|
|
gomock.InOrder(
|
|
tracer.EXPECT().ClosedConnection(expectedErr),
|
|
tracer.EXPECT().Close(),
|
|
)
|
|
|
|
go func() { errChan <- tc.conn.run() }()
|
|
tc.conn.CloseWithError(1337, "test error")
|
|
|
|
select {
|
|
case err := <-errChan:
|
|
require.ErrorIs(t, err, expectedErr)
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
|
|
// further calls to CloseWithError don't do anything
|
|
tc.conn.CloseWithError(42, "another error")
|
|
}
|
|
|
|
func TestConnectionStatelessReset(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
|
|
tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptTracer(tr))
|
|
errChan := make(chan error, 1)
|
|
tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
|
|
gomock.InOrder(
|
|
tracer.EXPECT().ClosedConnection(&StatelessResetError{}),
|
|
tracer.EXPECT().Close(),
|
|
)
|
|
|
|
go func() { errChan <- tc.conn.run() }()
|
|
tc.conn.destroy(&StatelessResetError{})
|
|
|
|
select {
|
|
case err := <-errChan:
|
|
require.ErrorIs(t, err, &StatelessResetError{})
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
}
|
|
|
|
func getLongHeaderPacket(t *testing.T, extHdr *wire.ExtendedHeader, data []byte) receivedPacket {
|
|
t.Helper()
|
|
b, err := extHdr.Append(nil, protocol.Version1)
|
|
require.NoError(t, err)
|
|
return receivedPacket{
|
|
data: append(b, data...),
|
|
buffer: getPacketBuffer(),
|
|
rcvTime: time.Now(),
|
|
}
|
|
}
|
|
|
|
func getShortHeaderPacket(t *testing.T, connID protocol.ConnectionID, pn protocol.PacketNumber, data []byte) receivedPacket {
|
|
t.Helper()
|
|
b, err := wire.AppendShortHeader(nil, connID, pn, protocol.PacketNumberLen2, protocol.KeyPhaseOne)
|
|
require.NoError(t, err)
|
|
return receivedPacket{
|
|
data: append(b, data...),
|
|
buffer: getPacketBuffer(),
|
|
rcvTime: time.Now(),
|
|
}
|
|
}
|
|
|
|
func TestConnectionServerInvalidPackets(t *testing.T) {
|
|
t.Run("Retry", func(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
|
|
tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptTracer(tr))
|
|
|
|
p := getLongHeaderPacket(t, &wire.ExtendedHeader{Header: wire.Header{
|
|
Type: protocol.PacketTypeRetry,
|
|
DestConnectionID: tc.conn.origDestConnID,
|
|
SrcConnectionID: tc.srcConnID,
|
|
Version: tc.conn.version,
|
|
Token: []byte("foobar"),
|
|
}}, make([]byte, 16) /* Retry integrity tag */)
|
|
tracer.EXPECT().DroppedPacket(logging.PacketTypeRetry, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropUnexpectedPacket)
|
|
wasProcessed, err := tc.conn.handlePacketImpl(p)
|
|
require.NoError(t, err)
|
|
require.False(t, wasProcessed)
|
|
})
|
|
|
|
t.Run("version negotiation", func(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
|
|
tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptTracer(tr))
|
|
|
|
b := wire.ComposeVersionNegotiation(
|
|
protocol.ArbitraryLenConnectionID(tc.srcConnID.Bytes()),
|
|
protocol.ArbitraryLenConnectionID(tc.conn.origDestConnID.Bytes()),
|
|
[]Version{Version1},
|
|
)
|
|
tracer.EXPECT().DroppedPacket(logging.PacketTypeVersionNegotiation, protocol.InvalidPacketNumber, protocol.ByteCount(len(b)), logging.PacketDropUnexpectedPacket)
|
|
wasProcessed, err := tc.conn.handlePacketImpl(receivedPacket{data: b, buffer: getPacketBuffer()})
|
|
require.NoError(t, err)
|
|
require.False(t, wasProcessed)
|
|
})
|
|
|
|
t.Run("unsupported version", func(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
|
|
tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptTracer(tr))
|
|
|
|
p := getLongHeaderPacket(t, &wire.ExtendedHeader{
|
|
Header: wire.Header{Type: protocol.PacketTypeHandshake, Version: 1234},
|
|
PacketNumberLen: protocol.PacketNumberLen2,
|
|
}, nil)
|
|
tracer.EXPECT().DroppedPacket(logging.PacketTypeNotDetermined, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropUnsupportedVersion)
|
|
wasProcessed, err := tc.conn.handlePacketImpl(p)
|
|
require.NoError(t, err)
|
|
require.False(t, wasProcessed)
|
|
})
|
|
|
|
t.Run("invalid header", func(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
|
|
tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptTracer(tr))
|
|
|
|
p := getLongHeaderPacket(t, &wire.ExtendedHeader{
|
|
Header: wire.Header{Type: protocol.PacketTypeHandshake, Version: Version1},
|
|
PacketNumberLen: protocol.PacketNumberLen2,
|
|
}, nil)
|
|
p.data[0] ^= 0x40 // unset the QUIC bit
|
|
tracer.EXPECT().DroppedPacket(logging.PacketTypeNotDetermined, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropHeaderParseError)
|
|
wasProcessed, err := tc.conn.handlePacketImpl(p)
|
|
require.NoError(t, err)
|
|
require.False(t, wasProcessed)
|
|
})
|
|
}
|
|
|
|
func TestConnectionClientDrop0RTT(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
|
|
tc := newClientTestConnection(t, mockCtrl, nil, false, connectionOptTracer(tr))
|
|
|
|
p := getLongHeaderPacket(t, &wire.ExtendedHeader{
|
|
Header: wire.Header{Type: protocol.PacketType0RTT, Length: 2, Version: protocol.Version1},
|
|
PacketNumberLen: protocol.PacketNumberLen2,
|
|
}, nil)
|
|
tracer.EXPECT().DroppedPacket(logging.PacketType0RTT, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropUnexpectedPacket)
|
|
wasProcessed, err := tc.conn.handlePacketImpl(p)
|
|
require.NoError(t, err)
|
|
require.False(t, wasProcessed)
|
|
}
|
|
|
|
func TestConnectionUnpacking(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl)
|
|
unpacker := NewMockUnpacker(mockCtrl)
|
|
tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
|
|
tc := newServerTestConnection(t,
|
|
mockCtrl,
|
|
nil,
|
|
false,
|
|
connectionOptReceivedPacketHandler(rph),
|
|
connectionOptUnpacker(unpacker),
|
|
connectionOptTracer(tr),
|
|
)
|
|
|
|
// receive a long header packet
|
|
hdr := &wire.ExtendedHeader{
|
|
Header: wire.Header{
|
|
Type: protocol.PacketTypeInitial,
|
|
DestConnectionID: tc.srcConnID,
|
|
Version: protocol.Version1,
|
|
Length: 1,
|
|
},
|
|
PacketNumber: 0x37,
|
|
PacketNumberLen: protocol.PacketNumberLen1,
|
|
}
|
|
unpackedHdr := *hdr
|
|
unpackedHdr.PacketNumber = 0x1337
|
|
packet := getLongHeaderPacket(t, hdr, nil)
|
|
packet.ecn = protocol.ECNCE
|
|
rcvTime := time.Now().Add(-10 * time.Second)
|
|
packet.rcvTime = rcvTime
|
|
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).Return(&unpackedPacket{
|
|
encryptionLevel: protocol.EncryptionInitial,
|
|
hdr: &unpackedHdr,
|
|
data: []byte{0}, // one PADDING frame
|
|
}, nil)
|
|
gomock.InOrder(
|
|
rph.EXPECT().IsPotentiallyDuplicate(protocol.PacketNumber(0x1337), protocol.EncryptionInitial),
|
|
rph.EXPECT().ReceivedPacket(protocol.PacketNumber(0x1337), protocol.ECNCE, protocol.EncryptionInitial, rcvTime, false),
|
|
)
|
|
|
|
tracer.EXPECT().NegotiatedVersion(gomock.Any(), gomock.Any(), gomock.Any())
|
|
tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
|
|
tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), gomock.Any(), logging.ECNCE, []logging.Frame{})
|
|
wasProcessed, err := tc.conn.handlePacketImpl(packet)
|
|
require.NoError(t, err)
|
|
require.True(t, wasProcessed)
|
|
require.True(t, mockCtrl.Satisfied())
|
|
|
|
// receive a duplicate of this packet
|
|
packet = getLongHeaderPacket(t, hdr, nil)
|
|
rph.EXPECT().IsPotentiallyDuplicate(protocol.PacketNumber(0x1337), protocol.EncryptionInitial).Return(true)
|
|
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).Return(&unpackedPacket{
|
|
encryptionLevel: protocol.EncryptionInitial,
|
|
hdr: &unpackedHdr,
|
|
data: []byte{0}, // one PADDING frame
|
|
}, nil)
|
|
tracer.EXPECT().DroppedPacket(logging.PacketTypeInitial, protocol.PacketNumber(0x1337), protocol.ByteCount(len(packet.data)), logging.PacketDropDuplicate)
|
|
wasProcessed, err = tc.conn.handlePacketImpl(packet)
|
|
require.NoError(t, err)
|
|
require.False(t, wasProcessed)
|
|
require.True(t, mockCtrl.Satisfied())
|
|
|
|
// receive a short header packet
|
|
packet = getShortHeaderPacket(t, tc.srcConnID, 0x37, nil)
|
|
packet.ecn = protocol.ECT1
|
|
packet.rcvTime = rcvTime
|
|
gomock.InOrder(
|
|
rph.EXPECT().IsPotentiallyDuplicate(protocol.PacketNumber(0x1337), protocol.Encryption1RTT),
|
|
rph.EXPECT().ReceivedPacket(protocol.PacketNumber(0x1337), protocol.ECT1, protocol.Encryption1RTT, rcvTime, false),
|
|
)
|
|
unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return(
|
|
protocol.PacketNumber(0x1337), protocol.PacketNumberLen2, protocol.KeyPhaseZero, []byte{0} /* PADDING */, nil,
|
|
)
|
|
tracer.EXPECT().ReceivedShortHeaderPacket(gomock.Any(), gomock.Any(), logging.ECT1, []logging.Frame{})
|
|
wasProcessed, err = tc.conn.handlePacketImpl(packet)
|
|
require.NoError(t, err)
|
|
require.True(t, wasProcessed)
|
|
}
|
|
|
|
func TestConnectionUnpackCoalescedPacket(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl)
|
|
unpacker := NewMockUnpacker(mockCtrl)
|
|
tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
|
|
tc := newServerTestConnection(t,
|
|
mockCtrl,
|
|
nil,
|
|
false,
|
|
connectionOptReceivedPacketHandler(rph),
|
|
connectionOptUnpacker(unpacker),
|
|
connectionOptTracer(tr),
|
|
)
|
|
hdr1 := &wire.ExtendedHeader{
|
|
Header: wire.Header{
|
|
Type: protocol.PacketTypeInitial,
|
|
DestConnectionID: tc.srcConnID,
|
|
Version: protocol.Version1,
|
|
Length: 1,
|
|
},
|
|
PacketNumber: 37,
|
|
PacketNumberLen: protocol.PacketNumberLen1,
|
|
}
|
|
hdr2 := &wire.ExtendedHeader{
|
|
Header: wire.Header{
|
|
Type: protocol.PacketTypeHandshake,
|
|
DestConnectionID: tc.srcConnID,
|
|
Version: protocol.Version1,
|
|
Length: 1,
|
|
},
|
|
PacketNumber: 38,
|
|
PacketNumberLen: protocol.PacketNumberLen1,
|
|
}
|
|
// add a packet with a different source connection ID
|
|
incorrectSrcConnID := protocol.ParseConnectionID([]byte{0xa, 0xb, 0xc})
|
|
hdr3 := &wire.ExtendedHeader{
|
|
Header: wire.Header{
|
|
Type: protocol.PacketTypeHandshake,
|
|
DestConnectionID: incorrectSrcConnID,
|
|
Version: protocol.Version1,
|
|
Length: 1,
|
|
},
|
|
PacketNumber: 0x42,
|
|
PacketNumberLen: protocol.PacketNumberLen1,
|
|
}
|
|
unpackedHdr1 := *hdr1
|
|
unpackedHdr1.PacketNumber = 1337
|
|
unpackedHdr2 := *hdr2
|
|
unpackedHdr2.PacketNumber = 1338
|
|
|
|
packet := getLongHeaderPacket(t, hdr1, nil)
|
|
packet2 := getLongHeaderPacket(t, hdr2, nil)
|
|
packet3 := getLongHeaderPacket(t, hdr3, nil)
|
|
packet.data = append(packet.data, packet2.data...)
|
|
packet.data = append(packet.data, packet3.data...)
|
|
packet.ecn = protocol.ECT1
|
|
rcvTime := time.Now()
|
|
packet.rcvTime = rcvTime
|
|
|
|
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).Return(&unpackedPacket{
|
|
encryptionLevel: protocol.EncryptionInitial,
|
|
hdr: &unpackedHdr1,
|
|
data: []byte{0}, // one PADDING frame
|
|
}, nil)
|
|
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).Return(&unpackedPacket{
|
|
encryptionLevel: protocol.EncryptionHandshake,
|
|
hdr: &unpackedHdr2,
|
|
data: []byte{1}, // one PING frame
|
|
}, nil)
|
|
gomock.InOrder(
|
|
rph.EXPECT().IsPotentiallyDuplicate(protocol.PacketNumber(1337), protocol.EncryptionInitial),
|
|
rph.EXPECT().ReceivedPacket(protocol.PacketNumber(1337), protocol.ECT1, protocol.EncryptionInitial, rcvTime, false),
|
|
rph.EXPECT().IsPotentiallyDuplicate(protocol.PacketNumber(1338), protocol.EncryptionHandshake),
|
|
rph.EXPECT().ReceivedPacket(protocol.PacketNumber(1338), protocol.ECT1, protocol.EncryptionHandshake, rcvTime, true),
|
|
)
|
|
tracer.EXPECT().NegotiatedVersion(gomock.Any(), gomock.Any(), gomock.Any())
|
|
tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
|
|
tracer.EXPECT().DroppedEncryptionLevel(protocol.EncryptionInitial)
|
|
rph.EXPECT().DropPackets(protocol.EncryptionInitial)
|
|
gomock.InOrder(
|
|
tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), gomock.Any(), logging.ECT1, []logging.Frame{}),
|
|
tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), gomock.Any(), logging.ECT1, []logging.Frame{&wire.PingFrame{}}),
|
|
tracer.EXPECT().DroppedPacket(logging.PacketTypeNotDetermined, protocol.InvalidPacketNumber, protocol.ByteCount(len(packet3.data)), logging.PacketDropUnknownConnectionID),
|
|
)
|
|
wasProcessed, err := tc.conn.handlePacketImpl(packet)
|
|
require.NoError(t, err)
|
|
require.True(t, wasProcessed)
|
|
}
|
|
|
|
func TestConnectionUnpackFailuresFatal(t *testing.T) {
|
|
t.Run("other errors", func(t *testing.T) {
|
|
require.ErrorIs(t,
|
|
testConnectionUnpackFailureFatal(t, &qerr.TransportError{ErrorCode: qerr.ConnectionIDLimitError}),
|
|
&qerr.TransportError{ErrorCode: qerr.ConnectionIDLimitError},
|
|
)
|
|
})
|
|
|
|
t.Run("invalid reserved bits", func(t *testing.T) {
|
|
require.ErrorIs(t,
|
|
testConnectionUnpackFailureFatal(t, wire.ErrInvalidReservedBits),
|
|
&qerr.TransportError{ErrorCode: qerr.ProtocolViolation},
|
|
)
|
|
})
|
|
}
|
|
|
|
func testConnectionUnpackFailureFatal(t *testing.T, unpackErr error) error {
|
|
mockCtrl := gomock.NewController(t)
|
|
unpacker := NewMockUnpacker(mockCtrl)
|
|
tc := newServerTestConnection(t,
|
|
mockCtrl,
|
|
nil,
|
|
false,
|
|
connectionOptUnpacker(unpacker),
|
|
)
|
|
|
|
tc.connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any())
|
|
unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return(protocol.PacketNumber(0), protocol.PacketNumberLen(0), protocol.KeyPhaseBit(0), nil, unpackErr)
|
|
tc.packer.EXPECT().PackConnectionClose(gomock.Any(), gomock.Any(), protocol.Version1).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil)
|
|
errChan := make(chan error, 1)
|
|
go func() { errChan <- tc.conn.run() }()
|
|
|
|
tc.sendConn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any())
|
|
tc.conn.handlePacket(getShortHeaderPacket(t, tc.srcConnID, 0x42, nil))
|
|
|
|
select {
|
|
case err := <-errChan:
|
|
require.Error(t, err)
|
|
return err
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func TestConnectionUnpackFailureDropped(t *testing.T) {
|
|
t.Run("keys dropped", func(t *testing.T) {
|
|
testConnectionUnpackFailureDropped(t, handshake.ErrKeysDropped, logging.PacketDropKeyUnavailable)
|
|
})
|
|
|
|
t.Run("decryption failed", func(t *testing.T) {
|
|
testConnectionUnpackFailureDropped(t, handshake.ErrDecryptionFailed, logging.PacketDropPayloadDecryptError)
|
|
})
|
|
|
|
t.Run("header parse error", func(t *testing.T) {
|
|
testErr := errors.New("foo")
|
|
testConnectionUnpackFailureDropped(t, &headerParseError{err: testErr}, logging.PacketDropHeaderParseError)
|
|
})
|
|
}
|
|
|
|
func testConnectionUnpackFailureDropped(t *testing.T, unpackErr error, packetDropReason logging.PacketDropReason) {
|
|
mockCtrl := gomock.NewController(t)
|
|
unpacker := NewMockUnpacker(mockCtrl)
|
|
tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
|
|
tc := newServerTestConnection(t,
|
|
mockCtrl,
|
|
nil,
|
|
false,
|
|
connectionOptUnpacker(unpacker),
|
|
connectionOptTracer(tr),
|
|
)
|
|
|
|
unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return(protocol.PacketNumber(0), protocol.PacketNumberLen(0), protocol.KeyPhaseBit(0), nil, unpackErr)
|
|
errChan := make(chan error, 1)
|
|
go func() { errChan <- tc.conn.run() }()
|
|
|
|
done := make(chan struct{})
|
|
tracer.EXPECT().DroppedPacket(gomock.Any(), protocol.InvalidPacketNumber, gomock.Any(), packetDropReason).Do(
|
|
func(logging.PacketType, protocol.PacketNumber, protocol.ByteCount, logging.PacketDropReason) {
|
|
close(done)
|
|
},
|
|
)
|
|
tc.conn.handlePacket(getShortHeaderPacket(t, tc.srcConnID, 0x42, nil))
|
|
select {
|
|
case <-done:
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
|
|
// test teardown
|
|
tracer.EXPECT().ClosedConnection(gomock.Any())
|
|
tracer.EXPECT().Close()
|
|
tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
|
|
tc.conn.destroy(nil)
|
|
select {
|
|
case <-errChan:
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
}
|
|
|
|
func TestConnectionMaxUnprocessedPackets(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
|
|
tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptTracer(tr))
|
|
done := make(chan struct{})
|
|
|
|
for i := protocol.PacketNumber(0); i < protocol.MaxConnUnprocessedPackets; i++ {
|
|
// nothing here should block
|
|
tc.conn.handlePacket(receivedPacket{data: []byte("foobar")})
|
|
}
|
|
tracer.EXPECT().DroppedPacket(logging.PacketTypeNotDetermined, protocol.InvalidPacketNumber, logging.ByteCount(6), logging.PacketDropDOSPrevention).Do(func(logging.PacketType, logging.PacketNumber, logging.ByteCount, logging.PacketDropReason) {
|
|
close(done)
|
|
})
|
|
tc.conn.handlePacket(receivedPacket{data: []byte("foobar")})
|
|
select {
|
|
case <-done:
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
}
|
|
|
|
func TestConnectionRemoteClose(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
mockStreamManager := NewMockStreamManager(mockCtrl)
|
|
tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
|
|
unpacker := NewMockUnpacker(mockCtrl)
|
|
tc := newServerTestConnection(t,
|
|
mockCtrl,
|
|
nil,
|
|
false,
|
|
connectionOptStreamManager(mockStreamManager),
|
|
connectionOptTracer(tr),
|
|
connectionOptUnpacker(unpacker),
|
|
)
|
|
ccf, err := (&wire.ConnectionCloseFrame{
|
|
ErrorCode: uint64(qerr.StreamLimitError),
|
|
ReasonPhrase: "foobar",
|
|
}).Append(nil, protocol.Version1)
|
|
require.NoError(t, err)
|
|
unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return(protocol.PacketNumber(1), protocol.PacketNumberLen2, protocol.KeyPhaseBit(0), ccf, nil)
|
|
tracer.EXPECT().ReceivedShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
|
|
|
|
expectedErr := &qerr.TransportError{ErrorCode: qerr.StreamLimitError, Remote: true}
|
|
tc.connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any())
|
|
streamErrChan := make(chan error, 1)
|
|
mockStreamManager.EXPECT().CloseWithError(gomock.Any()).Do(func(e error) { streamErrChan <- e })
|
|
tracerErrChan := make(chan error, 1)
|
|
tracer.EXPECT().ClosedConnection(gomock.Any()).Do(func(e error) { tracerErrChan <- e })
|
|
tracer.EXPECT().Close()
|
|
|
|
errChan := make(chan error, 1)
|
|
go func() { errChan <- tc.conn.run() }()
|
|
|
|
p := getShortHeaderPacket(t, tc.srcConnID, 1, []byte("encrypted"))
|
|
tc.conn.handlePacket(receivedPacket{data: p.data, buffer: p.buffer, rcvTime: time.Now()})
|
|
|
|
select {
|
|
case err := <-errChan:
|
|
require.ErrorIs(t, err, expectedErr)
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
select {
|
|
case err := <-tracerErrChan:
|
|
require.ErrorIs(t, err, expectedErr)
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
select {
|
|
case err := <-streamErrChan:
|
|
require.ErrorIs(t, err, expectedErr)
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
}
|
|
|
|
func TestConnectionIdleTimeoutDuringHandshake(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
|
|
tc := newServerTestConnection(t,
|
|
mockCtrl,
|
|
&Config{HandshakeIdleTimeout: scaleDuration(25 * time.Millisecond)},
|
|
false,
|
|
connectionOptTracer(tr),
|
|
)
|
|
tc.packer.EXPECT().PackCoalescedPacket(false, gomock.Any(), gomock.Any(), protocol.Version1).AnyTimes()
|
|
tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
|
|
gomock.InOrder(
|
|
tracer.EXPECT().ClosedConnection(&IdleTimeoutError{}),
|
|
tracer.EXPECT().Close(),
|
|
)
|
|
errChan := make(chan error, 1)
|
|
go func() { errChan <- tc.conn.run() }()
|
|
select {
|
|
case err := <-errChan:
|
|
require.ErrorIs(t, err, &IdleTimeoutError{})
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
}
|
|
|
|
func TestConnectionHandshakeIdleTimeout(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
|
|
tc := newServerTestConnection(t,
|
|
mockCtrl,
|
|
&Config{HandshakeIdleTimeout: scaleDuration(25 * time.Millisecond)},
|
|
false,
|
|
connectionOptTracer(tr),
|
|
func(c *connection) { c.creationTime = time.Now().Add(-10 * time.Second) },
|
|
)
|
|
tc.packer.EXPECT().PackCoalescedPacket(false, gomock.Any(), gomock.Any(), protocol.Version1).AnyTimes()
|
|
tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
|
|
gomock.InOrder(
|
|
tracer.EXPECT().ClosedConnection(&HandshakeTimeoutError{}),
|
|
tracer.EXPECT().Close(),
|
|
)
|
|
errChan := make(chan error, 1)
|
|
go func() { errChan <- tc.conn.run() }()
|
|
select {
|
|
case err := <-errChan:
|
|
require.ErrorIs(t, err, &HandshakeTimeoutError{})
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
}
|
|
|
|
func TestConnectionTransportParameters(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
|
|
streamManager := NewMockStreamManager(mockCtrl)
|
|
connFC := mocks.NewMockConnectionFlowController(mockCtrl)
|
|
tc := newServerTestConnection(t,
|
|
mockCtrl,
|
|
nil,
|
|
false,
|
|
connectionOptTracer(tr),
|
|
connectionOptStreamManager(streamManager),
|
|
connectionOptConnFlowController(connFC),
|
|
)
|
|
tracer.EXPECT().ReceivedTransportParameters(gomock.Any())
|
|
params := &wire.TransportParameters{
|
|
MaxIdleTimeout: 90 * time.Second,
|
|
InitialMaxStreamDataBidiLocal: 0x5000,
|
|
InitialMaxData: 0x5000,
|
|
ActiveConnectionIDLimit: 3,
|
|
// marshaling always sets it to this value
|
|
MaxUDPPayloadSize: protocol.MaxPacketBufferSize,
|
|
OriginalDestinationConnectionID: tc.destConnID,
|
|
}
|
|
streamManager.EXPECT().UpdateLimits(params)
|
|
connFC.EXPECT().UpdateSendWindow(params.InitialMaxData)
|
|
require.NoError(t, tc.conn.handleTransportParameters(params))
|
|
}
|
|
|
|
func TestConnectionTransportParameterValidationFailureServer(t *testing.T) {
|
|
tc := newServerTestConnection(t, nil, nil, false)
|
|
err := tc.conn.handleTransportParameters(&wire.TransportParameters{
|
|
InitialSourceConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}),
|
|
})
|
|
assert.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.TransportParameterError})
|
|
assert.ErrorContains(t, err, "expected initial_source_connection_id to equal")
|
|
}
|
|
|
|
func TestConnectionTransportParameterValidationFailureClient(t *testing.T) {
|
|
t.Run("initial_source_connection_id", func(t *testing.T) {
|
|
tc := newClientTestConnection(t, nil, nil, false)
|
|
err := tc.conn.handleTransportParameters(&wire.TransportParameters{
|
|
InitialSourceConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}),
|
|
})
|
|
assert.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.TransportParameterError})
|
|
assert.ErrorContains(t, err, "expected initial_source_connection_id to equal")
|
|
})
|
|
|
|
t.Run("original_destination_connection_id", func(t *testing.T) {
|
|
tc := newClientTestConnection(t, nil, nil, false)
|
|
err := tc.conn.handleTransportParameters(&wire.TransportParameters{
|
|
InitialSourceConnectionID: tc.destConnID,
|
|
OriginalDestinationConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}),
|
|
})
|
|
assert.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.TransportParameterError})
|
|
assert.ErrorContains(t, err, "expected original_destination_connection_id to equal")
|
|
})
|
|
|
|
t.Run("retry_source_connection_id if no retry", func(t *testing.T) {
|
|
tc := newClientTestConnection(t, nil, nil, false)
|
|
rcid := protocol.ParseConnectionID([]byte{1, 2, 3, 4})
|
|
params := &wire.TransportParameters{
|
|
InitialSourceConnectionID: tc.destConnID,
|
|
OriginalDestinationConnectionID: tc.destConnID,
|
|
RetrySourceConnectionID: &rcid,
|
|
}
|
|
err := tc.conn.handleTransportParameters(params)
|
|
assert.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.TransportParameterError})
|
|
assert.ErrorContains(t, err, "received retry_source_connection_id, although no Retry was performed")
|
|
})
|
|
|
|
t.Run("retry_source_connection_id missing", func(t *testing.T) {
|
|
tc := newClientTestConnection(t,
|
|
nil,
|
|
nil,
|
|
false,
|
|
connectionOptRetrySrcConnID(protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef})),
|
|
)
|
|
params := &wire.TransportParameters{
|
|
InitialSourceConnectionID: tc.destConnID,
|
|
OriginalDestinationConnectionID: tc.destConnID,
|
|
}
|
|
err := tc.conn.handleTransportParameters(params)
|
|
assert.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.TransportParameterError})
|
|
assert.ErrorContains(t, err, "missing retry_source_connection_id")
|
|
})
|
|
|
|
t.Run("retry_source_connection_id incorrect", func(t *testing.T) {
|
|
tc := newClientTestConnection(t,
|
|
nil,
|
|
nil,
|
|
false,
|
|
connectionOptRetrySrcConnID(protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef})),
|
|
)
|
|
wrongCID := protocol.ParseConnectionID([]byte{1, 2, 3, 4})
|
|
params := &wire.TransportParameters{
|
|
InitialSourceConnectionID: tc.destConnID,
|
|
OriginalDestinationConnectionID: tc.destConnID,
|
|
RetrySourceConnectionID: &wrongCID,
|
|
}
|
|
err := tc.conn.handleTransportParameters(params)
|
|
assert.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.TransportParameterError})
|
|
assert.ErrorContains(t, err, "expected retry_source_connection_id to equal")
|
|
})
|
|
}
|
|
|
|
func TestConnectionHandshakeServer(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
cs := mocks.NewMockCryptoSetup(mockCtrl)
|
|
unpacker := NewMockUnpacker(mockCtrl)
|
|
tc := newServerTestConnection(
|
|
t,
|
|
mockCtrl,
|
|
nil,
|
|
false,
|
|
connectionOptCryptoSetup(cs),
|
|
connectionOptUnpacker(unpacker),
|
|
)
|
|
|
|
// the state transition is driven by processing of a CRYPTO frame
|
|
hdr := &wire.ExtendedHeader{
|
|
Header: wire.Header{Type: protocol.PacketTypeHandshake, Version: protocol.Version1},
|
|
PacketNumberLen: protocol.PacketNumberLen2,
|
|
}
|
|
data, err := (&wire.CryptoFrame{Data: []byte("foobar")}).Append(nil, protocol.Version1)
|
|
require.NoError(t, err)
|
|
|
|
cs.EXPECT().DiscardInitialKeys()
|
|
tc.connRunner.EXPECT().Retire(gomock.Any())
|
|
gomock.InOrder(
|
|
cs.EXPECT().StartHandshake(gomock.Any()),
|
|
cs.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}),
|
|
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).Return(
|
|
&unpackedPacket{hdr: hdr, encryptionLevel: protocol.EncryptionHandshake, data: data}, nil,
|
|
),
|
|
cs.EXPECT().HandleMessage([]byte("foobar"), protocol.EncryptionHandshake),
|
|
cs.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventHandshakeComplete}),
|
|
cs.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}),
|
|
cs.EXPECT().SetHandshakeConfirmed(),
|
|
cs.EXPECT().GetSessionTicket().Return([]byte("session ticket"), nil),
|
|
)
|
|
tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(shortHeaderPacket{}, errNothingToPack).AnyTimes()
|
|
|
|
errChan := make(chan error, 1)
|
|
go func() { errChan <- tc.conn.run() }()
|
|
p := getLongHeaderPacket(t, hdr, nil)
|
|
tc.conn.handlePacket(receivedPacket{data: p.data, buffer: p.buffer, rcvTime: time.Now()})
|
|
|
|
select {
|
|
case <-tc.conn.HandshakeComplete():
|
|
case <-tc.conn.Context().Done():
|
|
t.Fatal("connection context done")
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
|
|
var foundSessionTicket, foundHandshakeDone, foundNewToken bool
|
|
frames, _, _ := tc.conn.framer.Append(nil, nil, protocol.MaxByteCount, time.Now(), protocol.Version1)
|
|
for _, frame := range frames {
|
|
switch f := frame.Frame.(type) {
|
|
case *wire.CryptoFrame:
|
|
assert.Equal(t, []byte("session ticket"), f.Data)
|
|
foundSessionTicket = true
|
|
case *wire.HandshakeDoneFrame:
|
|
foundHandshakeDone = true
|
|
case *wire.NewTokenFrame:
|
|
assert.NotEmpty(t, f.Token)
|
|
foundNewToken = true
|
|
}
|
|
}
|
|
assert.True(t, foundSessionTicket)
|
|
assert.True(t, foundHandshakeDone)
|
|
assert.True(t, foundNewToken)
|
|
|
|
// test teardown
|
|
cs.EXPECT().Close()
|
|
tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
|
|
tc.conn.destroy(nil)
|
|
select {
|
|
case err := <-errChan:
|
|
require.NoError(t, err)
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
}
|
|
|
|
func TestConnectionHandshakeClient(t *testing.T) {
|
|
t.Run("without preferred address", func(t *testing.T) {
|
|
testConnectionHandshakeClient(t, false)
|
|
})
|
|
t.Run("with preferred address", func(t *testing.T) {
|
|
testConnectionHandshakeClient(t, true)
|
|
})
|
|
}
|
|
|
|
func testConnectionHandshakeClient(t *testing.T, usePreferredAddress bool) {
|
|
mockCtrl := gomock.NewController(t)
|
|
cs := mocks.NewMockCryptoSetup(mockCtrl)
|
|
unpacker := NewMockUnpacker(mockCtrl)
|
|
tc := newClientTestConnection(t, mockCtrl, nil, false, connectionOptCryptoSetup(cs), connectionOptUnpacker(unpacker))
|
|
tc.sendConn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
|
|
|
|
// the state transition is driven by processing of a CRYPTO frame
|
|
hdr := &wire.ExtendedHeader{
|
|
Header: wire.Header{Type: protocol.PacketTypeHandshake, Version: protocol.Version1},
|
|
PacketNumberLen: protocol.PacketNumberLen2,
|
|
}
|
|
data, err := (&wire.CryptoFrame{Data: []byte("foobar")}).Append(nil, protocol.Version1)
|
|
require.NoError(t, err)
|
|
|
|
tp := &wire.TransportParameters{
|
|
OriginalDestinationConnectionID: tc.destConnID,
|
|
MaxIdleTimeout: time.Hour,
|
|
}
|
|
preferredAddressConnID := protocol.ParseConnectionID([]byte{10, 8, 6, 4})
|
|
preferredAddressResetToken := protocol.StatelessResetToken{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}
|
|
if usePreferredAddress {
|
|
tp.PreferredAddress = &wire.PreferredAddress{
|
|
IPv4: netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), 42),
|
|
IPv6: netip.AddrPortFrom(netip.AddrFrom16([16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}), 13),
|
|
ConnectionID: preferredAddressConnID,
|
|
StatelessResetToken: preferredAddressResetToken,
|
|
}
|
|
}
|
|
|
|
packedFirstPacket := make(chan struct{})
|
|
gomock.InOrder(
|
|
cs.EXPECT().StartHandshake(gomock.Any()),
|
|
cs.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}),
|
|
tc.packer.EXPECT().PackCoalescedPacket(false, gomock.Any(), gomock.Any(), protocol.Version1).DoAndReturn(
|
|
func(b bool, bc protocol.ByteCount, t time.Time, v protocol.Version) (*coalescedPacket, error) {
|
|
close(packedFirstPacket)
|
|
return &coalescedPacket{buffer: getPacketBuffer(), longHdrPackets: []*longHeaderPacket{{header: hdr}}}, nil
|
|
},
|
|
),
|
|
// initial keys are dropped when the first handshake packet is sent
|
|
cs.EXPECT().DiscardInitialKeys(),
|
|
// no more data to send
|
|
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).Return(
|
|
&unpackedPacket{hdr: hdr, encryptionLevel: protocol.EncryptionHandshake, data: data}, nil,
|
|
),
|
|
cs.EXPECT().HandleMessage([]byte("foobar"), protocol.EncryptionHandshake),
|
|
cs.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventReceivedTransportParameters, TransportParameters: tp}),
|
|
cs.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventHandshakeComplete}),
|
|
cs.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}),
|
|
)
|
|
tc.packer.EXPECT().PackCoalescedPacket(false, gomock.Any(), gomock.Any(), protocol.Version1).Return(nil, nil).AnyTimes()
|
|
|
|
errChan := make(chan error, 1)
|
|
go func() { errChan <- tc.conn.run() }()
|
|
|
|
select {
|
|
case <-packedFirstPacket:
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
|
|
p := getLongHeaderPacket(t, hdr, nil)
|
|
tc.conn.handlePacket(receivedPacket{data: p.data, buffer: p.buffer, rcvTime: time.Now()})
|
|
|
|
select {
|
|
case <-tc.conn.HandshakeComplete():
|
|
case <-tc.conn.Context().Done():
|
|
t.Fatal("connection context done")
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
|
|
require.True(t, mockCtrl.Satisfied())
|
|
// the handshake isn't confirmed until we receive a HANDSHAKE_DONE frame from the server
|
|
|
|
data, err = (&wire.HandshakeDoneFrame{}).Append(nil, protocol.Version1)
|
|
require.NoError(t, err)
|
|
done := make(chan struct{})
|
|
tc.packer.EXPECT().PackCoalescedPacket(false, gomock.Any(), gomock.Any(), protocol.Version1).Return(nil, nil).AnyTimes()
|
|
gomock.InOrder(
|
|
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).Return(
|
|
&unpackedPacket{hdr: hdr, encryptionLevel: protocol.Encryption1RTT, data: data}, nil,
|
|
),
|
|
cs.EXPECT().SetHandshakeConfirmed(),
|
|
tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
|
|
func(buf *packetBuffer, _ protocol.ByteCount, _ time.Time, _ protocol.Version) (shortHeaderPacket, error) {
|
|
close(done)
|
|
return shortHeaderPacket{}, errNothingToPack
|
|
},
|
|
),
|
|
)
|
|
tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(shortHeaderPacket{}, errNothingToPack).AnyTimes()
|
|
p = getLongHeaderPacket(t, hdr, nil)
|
|
tc.conn.handlePacket(receivedPacket{data: p.data, buffer: p.buffer, rcvTime: time.Now()})
|
|
|
|
select {
|
|
case <-done:
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
|
|
if usePreferredAddress {
|
|
tc.connRunner.EXPECT().AddResetToken(preferredAddressResetToken, gomock.Any())
|
|
}
|
|
nextConnID := tc.conn.connIDManager.Get()
|
|
if usePreferredAddress {
|
|
require.Equal(t, preferredAddressConnID, nextConnID)
|
|
}
|
|
|
|
// test teardown
|
|
cs.EXPECT().Close()
|
|
tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
|
|
if usePreferredAddress {
|
|
tc.connRunner.EXPECT().RemoveResetToken(preferredAddressResetToken)
|
|
}
|
|
tc.conn.destroy(nil)
|
|
select {
|
|
case err := <-errChan:
|
|
require.NoError(t, err)
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
}
|
|
|
|
func TestConnection0RTTTransportParameters(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
cs := mocks.NewMockCryptoSetup(mockCtrl)
|
|
unpacker := NewMockUnpacker(mockCtrl)
|
|
tc := newClientTestConnection(t, mockCtrl, nil, false, connectionOptCryptoSetup(cs), connectionOptUnpacker(unpacker))
|
|
tc.sendConn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
|
|
|
|
// the state transition is driven by processing of a CRYPTO frame
|
|
hdr := &wire.ExtendedHeader{
|
|
Header: wire.Header{Type: protocol.PacketTypeHandshake, Version: protocol.Version1},
|
|
PacketNumberLen: protocol.PacketNumberLen2,
|
|
}
|
|
data, err := (&wire.CryptoFrame{Data: []byte("foobar")}).Append(nil, protocol.Version1)
|
|
require.NoError(t, err)
|
|
|
|
restored := &wire.TransportParameters{
|
|
ActiveConnectionIDLimit: 3,
|
|
InitialMaxData: 0x5000,
|
|
InitialMaxStreamDataBidiLocal: 0x5000,
|
|
InitialMaxStreamDataBidiRemote: 1000,
|
|
InitialMaxStreamDataUni: 1000,
|
|
MaxBidiStreamNum: 500,
|
|
MaxUniStreamNum: 500,
|
|
}
|
|
new := *restored
|
|
new.MaxBidiStreamNum-- // the server is not allowed to reduce the limit
|
|
new.OriginalDestinationConnectionID = tc.destConnID
|
|
|
|
packedFirstPacket := make(chan struct{})
|
|
gomock.InOrder(
|
|
cs.EXPECT().StartHandshake(gomock.Any()),
|
|
cs.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventRestoredTransportParameters, TransportParameters: restored}),
|
|
cs.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}),
|
|
tc.packer.EXPECT().PackCoalescedPacket(false, gomock.Any(), gomock.Any(), protocol.Version1).DoAndReturn(
|
|
func(b bool, bc protocol.ByteCount, t time.Time, v protocol.Version) (*coalescedPacket, error) {
|
|
close(packedFirstPacket)
|
|
return &coalescedPacket{buffer: getPacketBuffer(), longHdrPackets: []*longHeaderPacket{{header: hdr}}}, nil
|
|
},
|
|
),
|
|
// initial keys are dropped when the first handshake packet is sent
|
|
cs.EXPECT().DiscardInitialKeys(),
|
|
// no more data to send
|
|
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).Return(
|
|
&unpackedPacket{hdr: hdr, encryptionLevel: protocol.EncryptionHandshake, data: data}, nil,
|
|
),
|
|
cs.EXPECT().HandleMessage([]byte("foobar"), protocol.EncryptionHandshake),
|
|
cs.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventReceivedTransportParameters, TransportParameters: &new}),
|
|
cs.EXPECT().ConnectionState().Return(handshake.ConnectionState{Used0RTT: true}),
|
|
// cs.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}),
|
|
cs.EXPECT().Close(),
|
|
)
|
|
tc.packer.EXPECT().PackCoalescedPacket(false, gomock.Any(), gomock.Any(), protocol.Version1).Return(nil, nil).AnyTimes()
|
|
tc.packer.EXPECT().PackConnectionClose(gomock.Any(), gomock.Any(), protocol.Version1).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil)
|
|
tc.connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any())
|
|
|
|
errChan := make(chan error, 1)
|
|
go func() { errChan <- tc.conn.run() }()
|
|
|
|
select {
|
|
case <-packedFirstPacket:
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
|
|
p := getLongHeaderPacket(t, hdr, nil)
|
|
tc.conn.handlePacket(receivedPacket{data: p.data, buffer: p.buffer, rcvTime: time.Now()})
|
|
|
|
select {
|
|
case err := <-errChan:
|
|
require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.ProtocolViolation})
|
|
require.ErrorContains(t, err, "server sent reduced limits after accepting 0-RTT data")
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
}
|
|
|
|
func TestConnectionReceivePrioritization(t *testing.T) {
|
|
t.Run("handshake complete", func(t *testing.T) {
|
|
counter := testConnectionReceivePrioritization(t, true)
|
|
require.Equal(t, 10, counter)
|
|
})
|
|
|
|
// before handshake completion, we trigger packing of a new packet every time we receive a packet
|
|
t.Run("handshake not complete", func(t *testing.T) {
|
|
counter := testConnectionReceivePrioritization(t, false)
|
|
require.Equal(t, 1, counter)
|
|
})
|
|
}
|
|
|
|
func testConnectionReceivePrioritization(t *testing.T, handshakeComplete bool) int {
|
|
mockCtrl := gomock.NewController(t)
|
|
unpacker := NewMockUnpacker(mockCtrl)
|
|
opts := []testConnectionOpt{connectionOptUnpacker(unpacker)}
|
|
if handshakeComplete {
|
|
opts = append(opts, connectionOptHandshakeConfirmed())
|
|
}
|
|
tc := newServerTestConnection(t,
|
|
mockCtrl,
|
|
nil,
|
|
false,
|
|
opts...,
|
|
)
|
|
|
|
var counter int
|
|
var packedFirst bool
|
|
done := make(chan struct{})
|
|
unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).DoAndReturn(
|
|
func(rcvTime time.Time, data []byte) (protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, []byte, error) {
|
|
if !packedFirst {
|
|
counter++
|
|
}
|
|
return protocol.PacketNumber(counter), protocol.PacketNumberLen2, protocol.KeyPhaseZero, []byte{0, 1} /* PADDING, PING */, nil
|
|
},
|
|
).AnyTimes()
|
|
switch handshakeComplete {
|
|
case false:
|
|
tc.packer.EXPECT().PackCoalescedPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
|
|
func(b bool, bc protocol.ByteCount, t time.Time, v protocol.Version) (*coalescedPacket, error) {
|
|
if !packedFirst {
|
|
packedFirst = true
|
|
close(done)
|
|
}
|
|
return nil, nil
|
|
},
|
|
).AnyTimes()
|
|
case true:
|
|
tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
|
|
func(b *packetBuffer, bc protocol.ByteCount, t time.Time, v protocol.Version) (shortHeaderPacket, error) {
|
|
if !packedFirst {
|
|
packedFirst = true
|
|
close(done)
|
|
}
|
|
return shortHeaderPacket{}, errNothingToPack
|
|
},
|
|
).AnyTimes()
|
|
}
|
|
|
|
errChan := make(chan error, 1)
|
|
go func() { errChan <- tc.conn.run() }()
|
|
|
|
for i := 0; i < 10; i++ {
|
|
tc.conn.handlePacket(getShortHeaderPacket(t, tc.srcConnID, protocol.PacketNumber(i), []byte("foobar")))
|
|
}
|
|
|
|
select {
|
|
case <-done:
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
|
|
// test teardown
|
|
tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
|
|
tc.conn.destroy(nil)
|
|
select {
|
|
case err := <-errChan:
|
|
require.NoError(t, err)
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
|
|
return counter
|
|
}
|
|
|
|
func TestConnectionPacketBuffering(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
unpacker := NewMockUnpacker(mockCtrl)
|
|
cs := mocks.NewMockCryptoSetup(mockCtrl)
|
|
tracer, tr := mocklogging.NewMockConnectionTracer(mockCtrl)
|
|
tc := newServerTestConnection(t,
|
|
mockCtrl,
|
|
nil,
|
|
false,
|
|
connectionOptUnpacker(unpacker),
|
|
connectionOptCryptoSetup(cs),
|
|
connectionOptTracer(tracer),
|
|
)
|
|
|
|
tr.EXPECT().NegotiatedVersion(gomock.Any(), gomock.Any(), gomock.Any())
|
|
tr.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
|
|
tr.EXPECT().DroppedEncryptionLevel(gomock.Any())
|
|
cs.EXPECT().DiscardInitialKeys()
|
|
|
|
hdr1 := wire.ExtendedHeader{
|
|
Header: wire.Header{
|
|
Type: protocol.PacketTypeHandshake,
|
|
DestConnectionID: tc.srcConnID,
|
|
SrcConnectionID: tc.destConnID,
|
|
Length: 8,
|
|
Version: protocol.Version1,
|
|
},
|
|
PacketNumberLen: protocol.PacketNumberLen1,
|
|
PacketNumber: 1,
|
|
}
|
|
hdr2 := hdr1
|
|
hdr2.PacketNumber = 2
|
|
cs.EXPECT().StartHandshake(gomock.Any())
|
|
buffered := make(chan struct{})
|
|
gomock.InOrder(
|
|
cs.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}),
|
|
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).Return(nil, handshake.ErrKeysNotYetAvailable),
|
|
tr.EXPECT().BufferedPacket(logging.PacketTypeHandshake, gomock.Any()),
|
|
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).Return(nil, handshake.ErrKeysNotYetAvailable),
|
|
tr.EXPECT().BufferedPacket(logging.PacketTypeHandshake, gomock.Any()).Do(
|
|
func(logging.PacketType, logging.ByteCount) { close(buffered) },
|
|
),
|
|
)
|
|
|
|
tc.conn.handlePacket(getLongHeaderPacket(t, &hdr1, []byte("packet1")))
|
|
tc.conn.handlePacket(getLongHeaderPacket(t, &hdr2, []byte("packet2")))
|
|
|
|
errChan := make(chan error, 1)
|
|
go func() { errChan <- tc.conn.run() }()
|
|
|
|
select {
|
|
case <-buffered:
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
|
|
// Now send another packet.
|
|
// In reality, this packet would contain a CRYPTO frame that advances the TLS handshake
|
|
// such that new keys become available.
|
|
var packets []string
|
|
hdr3 := hdr1
|
|
hdr3.PacketNumber = 3
|
|
tc.packer.EXPECT().PackCoalescedPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes()
|
|
unpacked := make(chan struct{})
|
|
cs.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventReceivedReadKeys})
|
|
cs.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent})
|
|
|
|
gomock.InOrder(
|
|
// packet 3 contains a CRYPTO frame and triggers the keys to become available
|
|
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).DoAndReturn(
|
|
func(hdr *wire.Header, data []byte) (*unpackedPacket, error) {
|
|
packets = append(packets, string(data[len(data)-7:]))
|
|
cf := &wire.CryptoFrame{Data: []byte("foobar")}
|
|
b, _ := cf.Append(nil, protocol.Version1)
|
|
return &unpackedPacket{hdr: &hdr3, encryptionLevel: protocol.EncryptionHandshake, data: b}, nil
|
|
},
|
|
),
|
|
cs.EXPECT().HandleMessage(gomock.Any(), gomock.Any()),
|
|
tr.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()),
|
|
// packet 1 dequeued from the buffer
|
|
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).DoAndReturn(
|
|
func(hdr *wire.Header, data []byte) (*unpackedPacket, error) {
|
|
packets = append(packets, string(data[len(data)-7:]))
|
|
return &unpackedPacket{hdr: &hdr1, encryptionLevel: protocol.EncryptionHandshake, data: []byte{0} /* PADDING */}, nil
|
|
},
|
|
),
|
|
tr.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()),
|
|
// packet 2 dequeued from the buffer
|
|
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).DoAndReturn(
|
|
func(hdr *wire.Header, data []byte) (*unpackedPacket, error) {
|
|
packets = append(packets, string(data[len(data)-7:]))
|
|
close(unpacked)
|
|
return &unpackedPacket{hdr: &hdr2, encryptionLevel: protocol.EncryptionHandshake, data: []byte{0} /* PADDING */}, nil
|
|
},
|
|
),
|
|
tr.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()),
|
|
)
|
|
|
|
tc.conn.handlePacket(getLongHeaderPacket(t, &hdr3, []byte("packet3")))
|
|
|
|
select {
|
|
case <-unpacked:
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
|
|
// packet3 triggered the keys to become available
|
|
// packet1 and packet2 are processed from the buffer in order
|
|
require.Equal(t, []string{"packet3", "packet1", "packet2"}, packets)
|
|
|
|
// test teardown
|
|
tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
|
|
cs.EXPECT().Close()
|
|
tr.EXPECT().ClosedConnection(gomock.Any())
|
|
tr.EXPECT().Close()
|
|
tc.conn.destroy(nil)
|
|
select {
|
|
case err := <-errChan:
|
|
require.NoError(t, err)
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
}
|
|
|
|
func TestConnectionPacketPacing(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
|
|
sender := NewMockSender(mockCtrl)
|
|
|
|
tc := newServerTestConnection(t,
|
|
mockCtrl,
|
|
nil,
|
|
false,
|
|
connectionOptSentPacketHandler(sph),
|
|
connectionOptSender(sender),
|
|
connectionOptHandshakeConfirmed(),
|
|
// set a fixed RTT, so that the idle timeout doesn't interfere with this test
|
|
connectionOptRTT(10*time.Second),
|
|
)
|
|
sender.EXPECT().Run()
|
|
|
|
step := scaleDuration(50 * time.Millisecond)
|
|
|
|
sph.EXPECT().GetLossDetectionTimeout().Return(time.Now().Add(time.Hour)).AnyTimes()
|
|
gomock.InOrder(
|
|
// 1. allow 2 packets to be sent
|
|
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny),
|
|
sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()),
|
|
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny),
|
|
sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()),
|
|
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendPacingLimited),
|
|
// 2. become pacing limited for 25ms
|
|
sph.EXPECT().TimeUntilSend().DoAndReturn(func() time.Time { return time.Now().Add(step) }),
|
|
// 3. send another packet
|
|
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny),
|
|
sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()),
|
|
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendPacingLimited),
|
|
// 4. become pacing limited for 25ms...
|
|
sph.EXPECT().TimeUntilSend().DoAndReturn(func() time.Time { return time.Now().Add(step) }),
|
|
// ... but this time we're still pacing limited when waking up.
|
|
// In this case, we can only send an ACK.
|
|
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendPacingLimited),
|
|
// 5. stop the test by becoming pacing limited forever
|
|
sph.EXPECT().TimeUntilSend().Return(time.Now().Add(time.Hour)),
|
|
sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()),
|
|
)
|
|
sph.EXPECT().ECNMode(gomock.Any()).AnyTimes()
|
|
for i := 0; i < 3; i++ {
|
|
tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), Version1).DoAndReturn(
|
|
func(buf *packetBuffer, _ protocol.ByteCount, _ time.Time, _ protocol.Version) (shortHeaderPacket, error) {
|
|
buf.Data = append(buf.Data, []byte("packet"+strconv.Itoa(i+1))...)
|
|
return shortHeaderPacket{PacketNumber: protocol.PacketNumber(i + 1)}, nil
|
|
},
|
|
)
|
|
}
|
|
tc.packer.EXPECT().PackAckOnlyPacket(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
|
|
func(_ protocol.ByteCount, _ time.Time, _ protocol.Version) (shortHeaderPacket, *packetBuffer, error) {
|
|
buf := getPacketBuffer()
|
|
buf.Data = []byte("ack")
|
|
return shortHeaderPacket{PacketNumber: 1}, buf, nil
|
|
},
|
|
)
|
|
sender.EXPECT().WouldBlock().AnyTimes()
|
|
|
|
type sentPacket struct {
|
|
time time.Time
|
|
data []byte
|
|
}
|
|
sendChan := make(chan sentPacket, 10)
|
|
sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(b *packetBuffer, _ uint16, _ protocol.ECN) {
|
|
sendChan <- sentPacket{time: time.Now(), data: b.Data}
|
|
}).Times(4)
|
|
|
|
errChan := make(chan error, 1)
|
|
go func() { errChan <- tc.conn.run() }()
|
|
tc.conn.scheduleSending()
|
|
|
|
var times []time.Time
|
|
for i := 0; i < 3; i++ {
|
|
select {
|
|
case b := <-sendChan:
|
|
require.Equal(t, []byte("packet"+strconv.Itoa(i+1)), b.data)
|
|
times = append(times, b.time)
|
|
case <-time.After(scaleDuration(time.Second)):
|
|
t.Fatal("timeout")
|
|
}
|
|
}
|
|
select {
|
|
case b := <-sendChan:
|
|
require.Equal(t, []byte("ack"), b.data)
|
|
times = append(times, b.time)
|
|
case <-time.After(scaleDuration(time.Second)):
|
|
t.Fatal("timeout")
|
|
}
|
|
|
|
require.InDelta(t, times[0].Sub(times[1]).Seconds(), 0, scaleDuration(10*time.Millisecond).Seconds())
|
|
require.InDelta(t, times[2].Sub(times[1]).Seconds(), step.Seconds(), scaleDuration(20*time.Millisecond).Seconds())
|
|
require.InDelta(t, times[3].Sub(times[2]).Seconds(), step.Seconds(), scaleDuration(20*time.Millisecond).Seconds())
|
|
|
|
time.Sleep(scaleDuration(step)) // make sure that no more packets are sent
|
|
require.True(t, mockCtrl.Satisfied())
|
|
|
|
// test teardown
|
|
sender.EXPECT().Close()
|
|
tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
|
|
tc.conn.destroy(nil)
|
|
select {
|
|
case <-sendChan:
|
|
t.Fatal("should not have sent any more packets")
|
|
case err := <-errChan:
|
|
require.NoError(t, err)
|
|
case <-time.After(3 * time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
}
|
|
|
|
func TestConnectionIdleTimeout(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
|
|
tc := newServerTestConnection(t,
|
|
mockCtrl,
|
|
&Config{MaxIdleTimeout: time.Second},
|
|
false,
|
|
connectionOptHandshakeConfirmed(),
|
|
connectionOptSentPacketHandler(sph),
|
|
connectionOptRTT(time.Millisecond),
|
|
)
|
|
// the idle timeout is set when the transport parameters are received
|
|
idleTimeout := scaleDuration(50 * time.Millisecond)
|
|
require.NoError(t, tc.conn.handleTransportParameters(&wire.TransportParameters{
|
|
MaxIdleTimeout: idleTimeout,
|
|
}))
|
|
|
|
sph.EXPECT().GetLossDetectionTimeout().AnyTimes()
|
|
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes()
|
|
sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
|
|
sph.EXPECT().ECNMode(gomock.Any()).AnyTimes()
|
|
var lastSendTime time.Time
|
|
tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
|
|
func(buf *packetBuffer, _ protocol.ByteCount, _ time.Time, _ protocol.Version) (shortHeaderPacket, error) {
|
|
buf.Data = append(buf.Data, []byte("foobar")...)
|
|
lastSendTime = time.Now()
|
|
return shortHeaderPacket{Frames: []ackhandler.Frame{{Frame: &wire.PingFrame{}}}, Length: 6}, nil
|
|
},
|
|
)
|
|
tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(shortHeaderPacket{}, errNothingToPack)
|
|
tc.sendConn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any())
|
|
tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
|
|
|
|
errChan := make(chan error, 1)
|
|
go func() { errChan <- tc.conn.run() }()
|
|
tc.conn.scheduleSending()
|
|
|
|
select {
|
|
case err := <-errChan:
|
|
require.ErrorIs(t, err, &IdleTimeoutError{})
|
|
require.NotZero(t, lastSendTime)
|
|
require.InDelta(t,
|
|
time.Since(lastSendTime).Seconds(),
|
|
idleTimeout.Seconds(),
|
|
scaleDuration(10*time.Millisecond).Seconds(),
|
|
)
|
|
case <-time.After(3 * time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
}
|
|
|
|
func TestConnectionKeepAlive(t *testing.T) {
|
|
t.Run("enabled", func(t *testing.T) {
|
|
testConnectionKeepAlive(t, true, true)
|
|
})
|
|
|
|
t.Run("disabled", func(t *testing.T) {
|
|
testConnectionKeepAlive(t, false, false)
|
|
})
|
|
}
|
|
|
|
func testConnectionKeepAlive(t *testing.T, enable, expectKeepAlive bool) {
|
|
var keepAlivePeriod time.Duration
|
|
if enable {
|
|
keepAlivePeriod = time.Second
|
|
}
|
|
|
|
mockCtrl := gomock.NewController(t)
|
|
unpacker := NewMockUnpacker(mockCtrl)
|
|
tc := newServerTestConnection(t,
|
|
mockCtrl,
|
|
&Config{MaxIdleTimeout: time.Second, KeepAlivePeriod: keepAlivePeriod},
|
|
false,
|
|
connectionOptUnpacker(unpacker),
|
|
connectionOptHandshakeConfirmed(),
|
|
connectionOptRTT(time.Millisecond),
|
|
)
|
|
// the idle timeout is set when the transport parameters are received
|
|
idleTimeout := scaleDuration(50 * time.Millisecond)
|
|
require.NoError(t, tc.conn.handleTransportParameters(&wire.TransportParameters{
|
|
MaxIdleTimeout: idleTimeout,
|
|
}))
|
|
|
|
// Receive a packet. This starts the keep-alive timer.
|
|
buf := getPacketBuffer()
|
|
var err error
|
|
buf.Data, err = wire.AppendShortHeader(buf.Data, tc.srcConnID, 1, protocol.PacketNumberLen1, protocol.KeyPhaseZero)
|
|
require.NoError(t, err)
|
|
buf.Data = append(buf.Data, []byte("packet")...)
|
|
|
|
errChan := make(chan error, 1)
|
|
go func() { errChan <- tc.conn.run() }()
|
|
|
|
var unpackTime, packTime time.Time
|
|
done := make(chan struct{})
|
|
unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).DoAndReturn(
|
|
func(t time.Time, bytes []byte) (protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, []byte, error) {
|
|
unpackTime = time.Now()
|
|
return protocol.PacketNumber(1), protocol.PacketNumberLen1, protocol.KeyPhaseZero, []byte{0} /* PADDING */, nil
|
|
},
|
|
)
|
|
tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(shortHeaderPacket{}, errNothingToPack)
|
|
|
|
switch expectKeepAlive {
|
|
case true:
|
|
// record the time of the keep-alive is sent
|
|
tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
|
|
func(buffer *packetBuffer, count protocol.ByteCount, t time.Time, version protocol.Version) (shortHeaderPacket, error) {
|
|
packTime = time.Now()
|
|
close(done)
|
|
return shortHeaderPacket{}, errNothingToPack
|
|
},
|
|
)
|
|
tc.conn.handlePacket(receivedPacket{data: buf.Data, buffer: buf, rcvTime: time.Now()})
|
|
select {
|
|
case <-done:
|
|
// the keep-alive packet should be sent after half the idle timeout
|
|
diff := packTime.Sub(unpackTime)
|
|
require.InDelta(t, diff.Seconds(), idleTimeout.Seconds()/2, scaleDuration(10*time.Millisecond).Seconds())
|
|
case <-time.After(idleTimeout):
|
|
t.Fatal("timeout")
|
|
}
|
|
case false: // if keep-alives are disabled, the connection will run into an idle timeout
|
|
tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
|
|
tc.conn.handlePacket(receivedPacket{data: buf.Data, buffer: buf, rcvTime: time.Now()})
|
|
select {
|
|
case <-time.After(3 * time.Second):
|
|
t.Fatal("timeout")
|
|
case <-time.After(idleTimeout):
|
|
}
|
|
}
|
|
|
|
// test teardown
|
|
if expectKeepAlive {
|
|
tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
|
|
tc.conn.destroy(nil)
|
|
}
|
|
select {
|
|
case err := <-errChan:
|
|
if expectKeepAlive {
|
|
require.NoError(t, err)
|
|
} else {
|
|
require.ErrorIs(t, err, &IdleTimeoutError{})
|
|
}
|
|
case <-time.After(3 * time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
}
|
|
|
|
func TestConnectionACKTimer(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl)
|
|
sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
|
|
tc := newServerTestConnection(t,
|
|
mockCtrl,
|
|
&Config{MaxIdleTimeout: time.Second},
|
|
false,
|
|
connectionOptHandshakeConfirmed(),
|
|
connectionOptReceivedPacketHandler(rph),
|
|
connectionOptSentPacketHandler(sph),
|
|
connectionOptRTT(10*time.Second),
|
|
)
|
|
alarmTimeout := scaleDuration(50 * time.Millisecond)
|
|
|
|
sph.EXPECT().GetLossDetectionTimeout().AnyTimes()
|
|
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes()
|
|
sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
|
|
sph.EXPECT().ECNMode(gomock.Any()).AnyTimes()
|
|
rph.EXPECT().GetAlarmTimeout().Return(time.Now().Add(time.Hour))
|
|
tc.sendConn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
|
|
|
|
var times []time.Time
|
|
done := make(chan struct{}, 5)
|
|
var calls []any
|
|
for i := 0; i < 2; i++ {
|
|
calls = append(calls, tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
|
|
func(buf *packetBuffer, _ protocol.ByteCount, _ time.Time, _ protocol.Version) (shortHeaderPacket, error) {
|
|
buf.Data = append(buf.Data, []byte("foobar")...)
|
|
times = append(times, time.Now())
|
|
return shortHeaderPacket{Frames: []ackhandler.Frame{{Frame: &wire.PingFrame{}}}, Length: 6}, nil
|
|
},
|
|
))
|
|
calls = append(calls, tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
|
|
func(buf *packetBuffer, _ protocol.ByteCount, _ time.Time, _ protocol.Version) (shortHeaderPacket, error) {
|
|
done <- struct{}{}
|
|
return shortHeaderPacket{}, errNothingToPack
|
|
},
|
|
))
|
|
if i == 0 {
|
|
calls = append(calls, rph.EXPECT().GetAlarmTimeout().Return(time.Now().Add(alarmTimeout)))
|
|
} else {
|
|
calls = append(calls, rph.EXPECT().GetAlarmTimeout().Return(time.Now().Add(time.Hour)).MaxTimes(1))
|
|
}
|
|
}
|
|
gomock.InOrder(calls...)
|
|
errChan := make(chan error, 1)
|
|
go func() { errChan <- tc.conn.run() }()
|
|
tc.conn.scheduleSending()
|
|
|
|
for i := 0; i < 2; i++ {
|
|
select {
|
|
case <-done:
|
|
case <-time.After(3 * time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
}
|
|
|
|
assert.Len(t, times, 2)
|
|
require.InDelta(t, times[1].Sub(times[0]).Seconds(), alarmTimeout.Seconds(), scaleDuration(10*time.Millisecond).Seconds())
|
|
|
|
// test teardown
|
|
tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
|
|
tc.conn.destroy(nil)
|
|
select {
|
|
case err := <-errChan:
|
|
require.NoError(t, err)
|
|
case <-time.After(3 * time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
}
|
|
|
|
// Send a GSO batch, until we have no more data to send.
|
|
func TestConnectionGSOBatch(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
|
|
tc := newServerTestConnection(t,
|
|
mockCtrl,
|
|
nil,
|
|
true,
|
|
connectionOptHandshakeConfirmed(),
|
|
connectionOptSentPacketHandler(sph),
|
|
)
|
|
|
|
// allow packets to be sent
|
|
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes()
|
|
sph.EXPECT().TimeUntilSend().Return(time.Time{}).AnyTimes()
|
|
sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
|
|
sph.EXPECT().GetLossDetectionTimeout().Return(time.Time{}).AnyTimes()
|
|
sph.EXPECT().ECNMode(gomock.Any()).Return(protocol.ECT1).AnyTimes()
|
|
|
|
maxPacketSize := tc.conn.maxPacketSize()
|
|
var expectedData []byte
|
|
for i := 0; i < 4; i++ {
|
|
data := bytes.Repeat([]byte{byte(i)}, int(maxPacketSize))
|
|
expectedData = append(expectedData, data...)
|
|
|
|
tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
|
|
func(buffer *packetBuffer, count protocol.ByteCount, t time.Time, version protocol.Version) (shortHeaderPacket, error) {
|
|
buffer.Data = append(buffer.Data, data...)
|
|
return shortHeaderPacket{PacketNumber: protocol.PacketNumber(i)}, nil
|
|
},
|
|
)
|
|
}
|
|
done := make(chan struct{})
|
|
tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(shortHeaderPacket{}, errNothingToPack)
|
|
tc.sendConn.EXPECT().Write(expectedData, uint16(maxPacketSize), protocol.ECT1).DoAndReturn(
|
|
func([]byte, uint16, protocol.ECN) error { close(done); return nil },
|
|
)
|
|
|
|
errChan := make(chan error, 1)
|
|
go func() { errChan <- tc.conn.run() }()
|
|
tc.conn.scheduleSending()
|
|
|
|
select {
|
|
case <-done:
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
|
|
// test teardown
|
|
tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
|
|
tc.conn.destroy(nil)
|
|
select {
|
|
case err := <-errChan:
|
|
require.NoError(t, err)
|
|
case <-time.After(3 * time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
}
|
|
|
|
// Send a GSO batch, until a packet smaller than the maximum size is packed
|
|
func TestConnectionGSOBatchPacketSize(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
|
|
tc := newServerTestConnection(t,
|
|
mockCtrl,
|
|
nil,
|
|
true,
|
|
connectionOptHandshakeConfirmed(),
|
|
connectionOptSentPacketHandler(sph),
|
|
)
|
|
|
|
// allow packets to be sent
|
|
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes()
|
|
sph.EXPECT().TimeUntilSend().Return(time.Time{}).AnyTimes()
|
|
sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
|
|
sph.EXPECT().GetLossDetectionTimeout().Return(time.Time{}).AnyTimes()
|
|
sph.EXPECT().ECNMode(gomock.Any()).Return(protocol.ECT1).AnyTimes()
|
|
|
|
maxPacketSize := tc.conn.maxPacketSize()
|
|
var expectedData []byte
|
|
var calls []any
|
|
for i := 0; i < 4; i++ {
|
|
var data []byte
|
|
if i == 3 {
|
|
data = bytes.Repeat([]byte{byte(i)}, int(maxPacketSize-1))
|
|
} else {
|
|
data = bytes.Repeat([]byte{byte(i)}, int(maxPacketSize))
|
|
}
|
|
expectedData = append(expectedData, data...)
|
|
|
|
calls = append(calls, tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
|
|
func(buffer *packetBuffer, count protocol.ByteCount, t time.Time, version protocol.Version) (shortHeaderPacket, error) {
|
|
buffer.Data = append(buffer.Data, data...)
|
|
return shortHeaderPacket{PacketNumber: protocol.PacketNumber(10 + i)}, nil
|
|
},
|
|
))
|
|
}
|
|
// The smaller (fourth) packet concluded this GSO batch, but the send loop will immediately start composing the next batch.
|
|
// We therefore send a "foobar", so we can check that we're actually generating two GSO batches.
|
|
calls = append(calls,
|
|
tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
|
|
func(buffer *packetBuffer, count protocol.ByteCount, t time.Time, version protocol.Version) (shortHeaderPacket, error) {
|
|
buffer.Data = append(buffer.Data, []byte("foobar")...)
|
|
return shortHeaderPacket{PacketNumber: protocol.PacketNumber(14)}, nil
|
|
},
|
|
),
|
|
)
|
|
calls = append(calls,
|
|
tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(shortHeaderPacket{}, errNothingToPack),
|
|
)
|
|
gomock.InOrder(calls...)
|
|
|
|
done := make(chan struct{})
|
|
gomock.InOrder(
|
|
tc.sendConn.EXPECT().Write(expectedData, uint16(maxPacketSize), protocol.ECT1),
|
|
tc.sendConn.EXPECT().Write([]byte("foobar"), uint16(maxPacketSize), protocol.ECT1).DoAndReturn(
|
|
func([]byte, uint16, protocol.ECN) error { close(done); return nil },
|
|
),
|
|
)
|
|
errChan := make(chan error, 1)
|
|
go func() { errChan <- tc.conn.run() }()
|
|
tc.conn.scheduleSending()
|
|
select {
|
|
case <-done:
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
|
|
// test teardown
|
|
tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
|
|
tc.conn.destroy(nil)
|
|
select {
|
|
case err := <-errChan:
|
|
require.NoError(t, err)
|
|
case <-time.After(3 * time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
}
|
|
|
|
func TestConnectionGSOBatchECN(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
|
|
tc := newServerTestConnection(t,
|
|
mockCtrl,
|
|
nil,
|
|
true,
|
|
connectionOptHandshakeConfirmed(),
|
|
connectionOptSentPacketHandler(sph),
|
|
)
|
|
|
|
// allow packets to be sent
|
|
ecnMode := protocol.ECT1
|
|
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes()
|
|
sph.EXPECT().TimeUntilSend().Return(time.Time{}).AnyTimes()
|
|
sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
|
|
sph.EXPECT().GetLossDetectionTimeout().Return(time.Time{}).AnyTimes()
|
|
sph.EXPECT().ECNMode(gomock.Any()).DoAndReturn(func(bool) protocol.ECN { return ecnMode }).AnyTimes()
|
|
|
|
// 3. Send a GSO batch, until the ECN marking changes.
|
|
var expectedData []byte
|
|
var calls []any
|
|
maxPacketSize := tc.conn.maxPacketSize()
|
|
for i := 0; i < 3; i++ {
|
|
data := bytes.Repeat([]byte{byte(i)}, int(maxPacketSize))
|
|
expectedData = append(expectedData, data...)
|
|
|
|
calls = append(calls, tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
|
|
func(buffer *packetBuffer, count protocol.ByteCount, t time.Time, version protocol.Version) (shortHeaderPacket, error) {
|
|
buffer.Data = append(buffer.Data, data...)
|
|
if i == 2 {
|
|
ecnMode = protocol.ECNCE
|
|
}
|
|
return shortHeaderPacket{PacketNumber: protocol.PacketNumber(20 + i)}, nil
|
|
},
|
|
))
|
|
}
|
|
// The smaller (fourth) packet concluded this GSO batch, but the send loop will immediately start composing the next batch.
|
|
// We therefore send a "foobar", so we can check that we're actually generating two GSO batches.
|
|
calls = append(calls,
|
|
tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
|
|
func(buffer *packetBuffer, count protocol.ByteCount, t time.Time, version protocol.Version) (shortHeaderPacket, error) {
|
|
buffer.Data = append(buffer.Data, []byte("foobar")...)
|
|
return shortHeaderPacket{PacketNumber: protocol.PacketNumber(24)}, nil
|
|
},
|
|
),
|
|
)
|
|
calls = append(calls,
|
|
tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(shortHeaderPacket{}, errNothingToPack),
|
|
)
|
|
gomock.InOrder(calls...)
|
|
|
|
done3 := make(chan struct{})
|
|
tc.sendConn.EXPECT().Write(expectedData, uint16(maxPacketSize), protocol.ECT1)
|
|
tc.sendConn.EXPECT().Write([]byte("foobar"), uint16(maxPacketSize), protocol.ECNCE).DoAndReturn(
|
|
func([]byte, uint16, protocol.ECN) error { close(done3); return nil },
|
|
)
|
|
|
|
errChan := make(chan error, 1)
|
|
go func() { errChan <- tc.conn.run() }()
|
|
tc.conn.scheduleSending()
|
|
|
|
select {
|
|
case <-done3:
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
|
|
// test teardown
|
|
tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
|
|
tc.conn.destroy(nil)
|
|
select {
|
|
case err := <-errChan:
|
|
require.NoError(t, err)
|
|
case <-time.After(3 * time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
}
|
|
|
|
func TestConnectionPTOProbePackets(t *testing.T) {
|
|
t.Run("Initial", func(t *testing.T) {
|
|
testConnectionPTOProbePackets(t, protocol.EncryptionInitial)
|
|
})
|
|
t.Run("Handshake", func(t *testing.T) {
|
|
testConnectionPTOProbePackets(t, protocol.EncryptionHandshake)
|
|
})
|
|
t.Run("1-RTT", func(t *testing.T) {
|
|
testConnectionPTOProbePackets(t, protocol.Encryption1RTT)
|
|
})
|
|
}
|
|
|
|
func testConnectionPTOProbePackets(t *testing.T, encLevel protocol.EncryptionLevel) {
|
|
mockCtrl := gomock.NewController(t)
|
|
sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
|
|
tc := newServerTestConnection(t,
|
|
mockCtrl,
|
|
nil,
|
|
false,
|
|
connectionOptSentPacketHandler(sph),
|
|
)
|
|
|
|
var sendMode ackhandler.SendMode
|
|
switch encLevel {
|
|
case protocol.EncryptionInitial:
|
|
sendMode = ackhandler.SendPTOInitial
|
|
case protocol.EncryptionHandshake:
|
|
sendMode = ackhandler.SendPTOHandshake
|
|
case protocol.Encryption1RTT:
|
|
sendMode = ackhandler.SendPTOAppData
|
|
}
|
|
|
|
sph.EXPECT().GetLossDetectionTimeout().AnyTimes()
|
|
sph.EXPECT().TimeUntilSend().AnyTimes()
|
|
sph.EXPECT().SendMode(gomock.Any()).Return(sendMode)
|
|
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendNone)
|
|
sph.EXPECT().ECNMode(gomock.Any())
|
|
sph.EXPECT().QueueProbePacket(encLevel).Return(false)
|
|
sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
|
|
|
|
tc.packer.EXPECT().MaybePackProbePacket(encLevel, gomock.Any(), gomock.Any(), protocol.Version1).DoAndReturn(
|
|
func(encLevel protocol.EncryptionLevel, maxSize protocol.ByteCount, t time.Time, version protocol.Version) (*coalescedPacket, error) {
|
|
return &coalescedPacket{
|
|
buffer: getPacketBuffer(),
|
|
shortHdrPacket: &shortHeaderPacket{PacketNumber: 1},
|
|
}, nil
|
|
},
|
|
)
|
|
done := make(chan struct{})
|
|
tc.sendConn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()).Do(
|
|
func([]byte, uint16, protocol.ECN) error { close(done); return nil },
|
|
)
|
|
|
|
errChan := make(chan error, 1)
|
|
go func() { errChan <- tc.conn.run() }()
|
|
tc.conn.scheduleSending()
|
|
|
|
select {
|
|
case <-done:
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
|
|
// test teardown
|
|
tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
|
|
tc.conn.destroy(nil)
|
|
select {
|
|
case err := <-errChan:
|
|
require.NoError(t, err)
|
|
case <-time.After(3 * time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
}
|
|
|
|
func TestConnectionCongestionControl(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
|
|
tc := newServerTestConnection(t,
|
|
mockCtrl,
|
|
nil,
|
|
false,
|
|
connectionOptHandshakeConfirmed(),
|
|
connectionOptSentPacketHandler(sph),
|
|
connectionOptRTT(10*time.Second),
|
|
)
|
|
|
|
sph.EXPECT().TimeUntilSend().AnyTimes()
|
|
sph.EXPECT().GetLossDetectionTimeout().AnyTimes()
|
|
sph.EXPECT().ECNMode(true).AnyTimes()
|
|
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).Times(2)
|
|
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAck).MaxTimes(1)
|
|
sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(2)
|
|
// Since we're already sending out packets, we don't expect any calls to PackAckOnlyPacket
|
|
for i := 0; i < 2; i++ {
|
|
tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
|
|
func(buffer *packetBuffer, count protocol.ByteCount, t time.Time, version protocol.Version) (shortHeaderPacket, error) {
|
|
buffer.Data = append(buffer.Data, []byte("foobar")...)
|
|
return shortHeaderPacket{PacketNumber: protocol.PacketNumber(i)}, nil
|
|
},
|
|
)
|
|
}
|
|
tc.sendConn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any())
|
|
done1 := make(chan struct{})
|
|
tc.sendConn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()).Do(
|
|
func([]byte, uint16, protocol.ECN) error { close(done1); return nil },
|
|
)
|
|
|
|
errChan := make(chan error, 1)
|
|
go func() { errChan <- tc.conn.run() }()
|
|
tc.conn.scheduleSending()
|
|
select {
|
|
case <-done1:
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
require.True(t, mockCtrl.Satisfied())
|
|
|
|
// Now that we're congestion limited, we can only send an ack-only packet
|
|
done2 := make(chan struct{})
|
|
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAck)
|
|
tc.packer.EXPECT().PackAckOnlyPacket(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
|
|
func(protocol.ByteCount, time.Time, protocol.Version) (shortHeaderPacket, *packetBuffer, error) {
|
|
close(done2)
|
|
return shortHeaderPacket{}, nil, errNothingToPack
|
|
},
|
|
)
|
|
tc.conn.scheduleSending()
|
|
select {
|
|
case <-done2:
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
require.True(t, mockCtrl.Satisfied())
|
|
|
|
// If the send mode is "none", we can't even send an ack-only packet
|
|
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendNone)
|
|
tc.conn.scheduleSending()
|
|
time.Sleep(scaleDuration(10 * time.Millisecond)) // make sure there are no calls to the packer
|
|
|
|
// test teardown
|
|
tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
|
|
tc.conn.destroy(nil)
|
|
select {
|
|
case err := <-errChan:
|
|
require.NoError(t, err)
|
|
case <-time.After(3 * time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
}
|
|
|
|
func TestConnectionSendQueue(t *testing.T) {
|
|
t.Run("with GSO", func(t *testing.T) {
|
|
testConnectionSendQueue(t, true)
|
|
})
|
|
t.Run("without GSO", func(t *testing.T) {
|
|
testConnectionSendQueue(t, false)
|
|
})
|
|
}
|
|
|
|
func testConnectionSendQueue(t *testing.T, enableGSO bool) {
|
|
mockCtrl := gomock.NewController(t)
|
|
sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
|
|
sender := NewMockSender(mockCtrl)
|
|
tc := newServerTestConnection(t,
|
|
mockCtrl,
|
|
nil,
|
|
enableGSO,
|
|
connectionOptSender(sender),
|
|
connectionOptHandshakeConfirmed(),
|
|
connectionOptSentPacketHandler(sph),
|
|
)
|
|
|
|
sender.EXPECT().Run().MaxTimes(1)
|
|
sender.EXPECT().WouldBlock()
|
|
sender.EXPECT().WouldBlock().Return(true).Times(2)
|
|
available := make(chan struct{})
|
|
blocked := make(chan struct{})
|
|
sender.EXPECT().Available().DoAndReturn(
|
|
func() <-chan struct{} {
|
|
close(blocked)
|
|
return available
|
|
},
|
|
)
|
|
sph.EXPECT().GetLossDetectionTimeout().AnyTimes()
|
|
sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
|
|
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes()
|
|
sph.EXPECT().ECNMode(gomock.Any()).AnyTimes()
|
|
tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(
|
|
shortHeaderPacket{PacketNumber: protocol.PacketNumber(1)}, nil,
|
|
)
|
|
sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any())
|
|
|
|
errChan := make(chan error, 1)
|
|
go func() { errChan <- tc.conn.run() }()
|
|
tc.conn.scheduleSending()
|
|
|
|
select {
|
|
case <-blocked:
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
require.True(t, mockCtrl.Satisfied())
|
|
|
|
// now make room in the send queue
|
|
sender.EXPECT().WouldBlock().AnyTimes()
|
|
unblocked := make(chan struct{})
|
|
tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
|
|
func(*packetBuffer, protocol.ByteCount, time.Time, protocol.Version) (shortHeaderPacket, error) {
|
|
close(unblocked)
|
|
return shortHeaderPacket{}, errNothingToPack
|
|
},
|
|
)
|
|
available <- struct{}{}
|
|
select {
|
|
case <-unblocked:
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
|
|
// test teardown
|
|
sender.EXPECT().Close()
|
|
tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
|
|
tc.conn.destroy(nil)
|
|
select {
|
|
case err := <-errChan:
|
|
require.NoError(t, err)
|
|
case <-time.After(3 * time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
}
|
|
|
|
func getVersionNegotiationPacket(src, dest protocol.ConnectionID, versions []protocol.Version) receivedPacket {
|
|
b := wire.ComposeVersionNegotiation(
|
|
protocol.ArbitraryLenConnectionID(src.Bytes()),
|
|
protocol.ArbitraryLenConnectionID(dest.Bytes()),
|
|
versions,
|
|
)
|
|
return receivedPacket{
|
|
rcvTime: time.Now(),
|
|
data: b,
|
|
buffer: getPacketBuffer(),
|
|
}
|
|
}
|
|
|
|
func TestConnectionVersionNegotiation(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
|
|
tc := newClientTestConnection(t,
|
|
mockCtrl,
|
|
nil,
|
|
false,
|
|
connectionOptTracer(tr),
|
|
)
|
|
|
|
tc.packer.EXPECT().PackCoalescedPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes()
|
|
var tracerVersions []logging.Version
|
|
gomock.InOrder(
|
|
tracer.EXPECT().ReceivedVersionNegotiationPacket(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_, _ protocol.ArbitraryLenConnectionID, versions []logging.Version) {
|
|
tracerVersions = versions
|
|
}),
|
|
tracer.EXPECT().NegotiatedVersion(protocol.Version2, gomock.Any(), gomock.Any()),
|
|
tc.connRunner.EXPECT().Remove(gomock.Any()),
|
|
)
|
|
|
|
errChan := make(chan error, 1)
|
|
go func() { errChan <- tc.conn.run() }()
|
|
tc.conn.handlePacket(getVersionNegotiationPacket(
|
|
tc.destConnID,
|
|
tc.srcConnID,
|
|
[]protocol.Version{1234, protocol.Version2},
|
|
))
|
|
|
|
select {
|
|
case err := <-errChan:
|
|
var rerr *errCloseForRecreating
|
|
require.ErrorAs(t, err, &rerr)
|
|
require.Equal(t, rerr.nextVersion, protocol.Version2)
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
require.Contains(t, tracerVersions, protocol.Version(1234))
|
|
require.Contains(t, tracerVersions, protocol.Version2)
|
|
}
|
|
|
|
func TestConnectionVersionNegotiationNoMatch(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
|
|
tc := newClientTestConnection(t,
|
|
mockCtrl,
|
|
&Config{Versions: []protocol.Version{protocol.Version1}},
|
|
false,
|
|
connectionOptTracer(tr),
|
|
)
|
|
|
|
tc.packer.EXPECT().PackCoalescedPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes()
|
|
var tracerVersions []logging.Version
|
|
tracer.EXPECT().ReceivedVersionNegotiationPacket(gomock.Any(), gomock.Any(), gomock.Any()).Do(
|
|
func(_, _ protocol.ArbitraryLenConnectionID, versions []logging.Version) { tracerVersions = versions },
|
|
)
|
|
tracer.EXPECT().ClosedConnection(gomock.Any())
|
|
tracer.EXPECT().Close()
|
|
tc.connRunner.EXPECT().Remove(gomock.Any())
|
|
|
|
errChan := make(chan error, 1)
|
|
go func() { errChan <- tc.conn.run() }()
|
|
tc.conn.handlePacket(getVersionNegotiationPacket(
|
|
tc.destConnID,
|
|
tc.srcConnID,
|
|
[]protocol.Version{protocol.Version2},
|
|
))
|
|
|
|
select {
|
|
case err := <-errChan:
|
|
var verr *VersionNegotiationError
|
|
require.ErrorAs(t, err, &verr)
|
|
require.Contains(t, verr.Theirs, protocol.Version2)
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
require.Contains(t, tracerVersions, protocol.Version2)
|
|
}
|
|
|
|
func TestConnectionVersionNegotiationInvalidPackets(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
|
|
tc := newClientTestConnection(t,
|
|
mockCtrl,
|
|
nil,
|
|
false,
|
|
connectionOptTracer(tr),
|
|
)
|
|
|
|
// offers the current version
|
|
tracer.EXPECT().DroppedPacket(logging.PacketTypeVersionNegotiation, gomock.Any(), gomock.Any(), logging.PacketDropUnexpectedVersion)
|
|
vnp := getVersionNegotiationPacket(
|
|
tc.destConnID,
|
|
tc.srcConnID,
|
|
[]protocol.Version{1234, protocol.Version1},
|
|
)
|
|
wasProcessed, err := tc.conn.handlePacketImpl(vnp)
|
|
require.NoError(t, err)
|
|
require.False(t, wasProcessed)
|
|
require.True(t, mockCtrl.Satisfied())
|
|
|
|
// unparseable, since it's missing 2 bytes
|
|
tracer.EXPECT().DroppedPacket(logging.PacketTypeVersionNegotiation, gomock.Any(), gomock.Any(), logging.PacketDropHeaderParseError)
|
|
vnp.data = vnp.data[:len(vnp.data)-2]
|
|
wasProcessed, err = tc.conn.handlePacketImpl(vnp)
|
|
require.NoError(t, err)
|
|
require.False(t, wasProcessed)
|
|
}
|
|
|
|
func getRetryPacket(t *testing.T, src, dest, origDest protocol.ConnectionID, token []byte) receivedPacket {
|
|
hdr := wire.Header{
|
|
Type: protocol.PacketTypeRetry,
|
|
SrcConnectionID: src,
|
|
DestConnectionID: dest,
|
|
Token: token,
|
|
Version: protocol.Version1,
|
|
}
|
|
b, err := (&wire.ExtendedHeader{Header: hdr}).Append(nil, protocol.Version1)
|
|
require.NoError(t, err)
|
|
tag := handshake.GetRetryIntegrityTag(b, origDest, protocol.Version1)
|
|
b = append(b, tag[:]...)
|
|
return receivedPacket{
|
|
rcvTime: time.Now(),
|
|
data: b,
|
|
buffer: getPacketBuffer(),
|
|
}
|
|
}
|
|
|
|
func TestConnectionRetryDrops(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
|
|
unpacker := NewMockUnpacker(mockCtrl)
|
|
tc := newClientTestConnection(t,
|
|
mockCtrl,
|
|
nil,
|
|
false,
|
|
connectionOptTracer(tr),
|
|
connectionOptUnpacker(unpacker),
|
|
)
|
|
|
|
newConnID := protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef})
|
|
|
|
// invalid integrity tag
|
|
tracer.EXPECT().DroppedPacket(logging.PacketTypeRetry, gomock.Any(), gomock.Any(), logging.PacketDropPayloadDecryptError)
|
|
retry := getRetryPacket(t, newConnID, tc.srcConnID, tc.destConnID, []byte("foobar"))
|
|
retry.data[len(retry.data)-1]++
|
|
wasProcessed, err := tc.conn.handlePacketImpl(retry)
|
|
require.NoError(t, err)
|
|
require.False(t, wasProcessed)
|
|
require.True(t, mockCtrl.Satisfied())
|
|
|
|
// receive a retry that doesn't change the connection ID
|
|
tracer.EXPECT().DroppedPacket(logging.PacketTypeRetry, gomock.Any(), gomock.Any(), logging.PacketDropUnexpectedPacket)
|
|
retry = getRetryPacket(t, tc.destConnID, tc.srcConnID, tc.destConnID, []byte("foobar"))
|
|
wasProcessed, err = tc.conn.handlePacketImpl(retry)
|
|
require.NoError(t, err)
|
|
require.False(t, wasProcessed)
|
|
}
|
|
|
|
func TestConnectionRetryAfterReceivedPacket(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
|
|
unpacker := NewMockUnpacker(mockCtrl)
|
|
tc := newClientTestConnection(t,
|
|
mockCtrl,
|
|
nil,
|
|
false,
|
|
connectionOptTracer(tr),
|
|
connectionOptUnpacker(unpacker),
|
|
)
|
|
|
|
// receive a regular packet
|
|
tracer.EXPECT().NegotiatedVersion(gomock.Any(), gomock.Any(), gomock.Any())
|
|
tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
|
|
regular := getPacketWithPacketType(t, tc.srcConnID, protocol.PacketTypeInitial, 200)
|
|
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).Return(
|
|
&unpackedPacket{
|
|
hdr: &wire.ExtendedHeader{Header: wire.Header{Type: protocol.PacketTypeInitial}},
|
|
encryptionLevel: protocol.EncryptionInitial,
|
|
}, nil,
|
|
)
|
|
wasProcessed, err := tc.conn.handlePacketImpl(receivedPacket{
|
|
data: regular,
|
|
buffer: getPacketBuffer(),
|
|
rcvTime: time.Now(),
|
|
})
|
|
require.NoError(t, err)
|
|
require.True(t, wasProcessed)
|
|
|
|
// receive a retry
|
|
retry := getRetryPacket(t, tc.destConnID, tc.srcConnID, tc.destConnID, []byte("foobar"))
|
|
tracer.EXPECT().DroppedPacket(logging.PacketTypeRetry, gomock.Any(), gomock.Any(), logging.PacketDropUnexpectedPacket)
|
|
wasProcessed, err = tc.conn.handlePacketImpl(retry)
|
|
require.NoError(t, err)
|
|
require.False(t, wasProcessed)
|
|
}
|
|
|
|
func TestConnectionConnectionIDChanges(t *testing.T) {
|
|
t.Run("with retry", func(t *testing.T) {
|
|
testConnectionConnectionIDChanges(t, true)
|
|
})
|
|
t.Run("without retry", func(t *testing.T) {
|
|
testConnectionConnectionIDChanges(t, false)
|
|
})
|
|
}
|
|
|
|
func testConnectionConnectionIDChanges(t *testing.T, sendRetry bool) {
|
|
makeInitialPacket := func(t *testing.T, hdr *wire.ExtendedHeader) []byte {
|
|
t.Helper()
|
|
data, err := hdr.Append(nil, protocol.Version1)
|
|
require.NoError(t, err)
|
|
data = append(data, make([]byte, hdr.Length-protocol.ByteCount(hdr.PacketNumberLen))...)
|
|
return data
|
|
}
|
|
|
|
mockCtrl := gomock.NewController(t)
|
|
tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
|
|
unpacker := NewMockUnpacker(mockCtrl)
|
|
tc := newClientTestConnection(t,
|
|
mockCtrl,
|
|
nil,
|
|
false,
|
|
connectionOptTracer(tr),
|
|
connectionOptUnpacker(unpacker),
|
|
)
|
|
|
|
dstConnID := tc.destConnID
|
|
b := make([]byte, 3*10)
|
|
rand.Read(b)
|
|
newConnID := protocol.ParseConnectionID(b[:11])
|
|
newConnID2 := protocol.ParseConnectionID(b[11:20])
|
|
|
|
errChan := make(chan error, 1)
|
|
go func() { errChan <- tc.conn.run() }()
|
|
|
|
tracer.EXPECT().NegotiatedVersion(gomock.Any(), gomock.Any(), gomock.Any())
|
|
tc.packer.EXPECT().PackCoalescedPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes()
|
|
|
|
require.Equal(t, dstConnID, tc.conn.connIDManager.Get())
|
|
|
|
var retryConnID protocol.ConnectionID
|
|
if sendRetry {
|
|
retryConnID = protocol.ParseConnectionID(b[20:30])
|
|
hdrChan := make(chan *wire.Header)
|
|
tracer.EXPECT().ReceivedRetry(gomock.Any()).Do(func(hdr *wire.Header) { hdrChan <- hdr })
|
|
tc.packer.EXPECT().SetToken([]byte("foobar"))
|
|
|
|
tc.conn.handlePacket(getRetryPacket(t, retryConnID, tc.srcConnID, tc.destConnID, []byte("foobar")))
|
|
select {
|
|
case hdr := <-hdrChan:
|
|
assert.Equal(t, retryConnID, hdr.SrcConnectionID)
|
|
assert.Equal(t, []byte("foobar"), hdr.Token)
|
|
require.Equal(t, retryConnID, tc.conn.connIDManager.Get())
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
}
|
|
|
|
// Send the first packet. The server changes the connection ID to newConnID.
|
|
hdr1 := wire.ExtendedHeader{
|
|
Header: wire.Header{
|
|
SrcConnectionID: newConnID,
|
|
DestConnectionID: tc.srcConnID,
|
|
Type: protocol.PacketTypeInitial,
|
|
Length: 200,
|
|
Version: protocol.Version1,
|
|
},
|
|
PacketNumber: 1,
|
|
PacketNumberLen: protocol.PacketNumberLen2,
|
|
}
|
|
hdr2 := hdr1
|
|
hdr2.SrcConnectionID = newConnID2
|
|
|
|
receivedFirst := make(chan struct{})
|
|
gomock.InOrder(
|
|
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).Return(
|
|
&unpackedPacket{
|
|
hdr: &hdr1,
|
|
encryptionLevel: protocol.EncryptionInitial,
|
|
}, nil,
|
|
),
|
|
tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(
|
|
func(*wire.ExtendedHeader, protocol.ByteCount, protocol.ECN, []logging.Frame) { close(receivedFirst) },
|
|
),
|
|
)
|
|
|
|
tc.conn.handlePacket(receivedPacket{data: makeInitialPacket(t, &hdr1), buffer: getPacketBuffer(), rcvTime: time.Now()})
|
|
|
|
select {
|
|
case <-receivedFirst:
|
|
require.Equal(t, newConnID, tc.conn.connIDManager.Get())
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
|
|
// Send the second packet. We refuse to accept it, because the connection ID is changed again.
|
|
dropped := make(chan struct{})
|
|
tracer.EXPECT().DroppedPacket(logging.PacketTypeInitial, gomock.Any(), gomock.Any(), logging.PacketDropUnknownConnectionID).Do(
|
|
func(logging.PacketType, protocol.PacketNumber, protocol.ByteCount, logging.PacketDropReason) {
|
|
close(dropped)
|
|
},
|
|
)
|
|
|
|
tc.conn.handlePacket(receivedPacket{data: makeInitialPacket(t, &hdr2), buffer: getPacketBuffer(), rcvTime: time.Now()})
|
|
select {
|
|
case <-dropped:
|
|
// the connection ID should not have changed
|
|
require.Equal(t, newConnID, tc.conn.connIDManager.Get())
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
|
|
// test teardown
|
|
tracer.EXPECT().ClosedConnection(gomock.Any())
|
|
tracer.EXPECT().Close()
|
|
tc.connRunner.EXPECT().Remove(gomock.Any())
|
|
tc.conn.destroy(nil)
|
|
select {
|
|
case err := <-errChan:
|
|
require.NoError(t, err)
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
}
|
|
|
|
// When the connection is closed before sending the first packet,
|
|
// we don't send a CONNECTION_CLOSE.
|
|
// This can happen if there's something wrong the tls.Config, and
|
|
// crypto/tls refuses to start the handshake.
|
|
func TestConnectionEarlyClose(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
|
|
cryptoSetup := mocks.NewMockCryptoSetup(mockCtrl)
|
|
tc := newClientTestConnection(t,
|
|
mockCtrl,
|
|
nil,
|
|
false,
|
|
connectionOptTracer(tr),
|
|
connectionOptCryptoSetup(cryptoSetup),
|
|
)
|
|
|
|
tc.conn.sentFirstPacket = false
|
|
tracer.EXPECT().ClosedConnection(gomock.Any())
|
|
tracer.EXPECT().Close()
|
|
cryptoSetup.EXPECT().StartHandshake(gomock.Any()).Do(func(context.Context) error {
|
|
tc.conn.closeLocal(errors.New("early error"))
|
|
return nil
|
|
})
|
|
cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent})
|
|
cryptoSetup.EXPECT().Close()
|
|
tc.connRunner.EXPECT().Remove(gomock.Any())
|
|
|
|
errChan := make(chan error, 1)
|
|
go func() { errChan <- tc.conn.run() }()
|
|
|
|
select {
|
|
case err := <-errChan:
|
|
require.Error(t, err)
|
|
require.ErrorContains(t, err, "early error")
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
}
|