mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-04 20:57:36 +03:00
* remove unused bool return value from sentPacketHandler.getPTOTimeAndSpace * ackhandler: implement timer logic for path probing packets Path probe packets are treated differently from regular packets: The new path might have a vastly different RTT than the original path. Path probe packets are declared lost 1s after they are sent. This value can be reduced once we implement proper retransmission logic for lost path probes. * ackhandler: declare path probes lost on OnLossDetectionTimeout
2771 lines
99 KiB
Go
2771 lines
99 KiB
Go
package quic
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto/rand"
|
|
"crypto/tls"
|
|
"errors"
|
|
"net"
|
|
"net/netip"
|
|
"strconv"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/quic-go/quic-go/internal/ackhandler"
|
|
"github.com/quic-go/quic-go/internal/handshake"
|
|
"github.com/quic-go/quic-go/internal/mocks"
|
|
mockackhandler "github.com/quic-go/quic-go/internal/mocks/ackhandler"
|
|
mocklogging "github.com/quic-go/quic-go/internal/mocks/logging"
|
|
"github.com/quic-go/quic-go/internal/protocol"
|
|
"github.com/quic-go/quic-go/internal/qerr"
|
|
"github.com/quic-go/quic-go/internal/utils"
|
|
"github.com/quic-go/quic-go/internal/wire"
|
|
"github.com/quic-go/quic-go/logging"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"go.uber.org/mock/gomock"
|
|
)
|
|
|
|
// testConnectionOpt mutates a freshly constructed *connection. The test
// connection constructors in this file apply every option, in order, after
// the connection has been created.
type testConnectionOpt func(*connection)
|
|
|
|
func connectionOptCryptoSetup(cs *mocks.MockCryptoSetup) testConnectionOpt {
|
|
return func(conn *connection) { conn.cryptoStreamHandler = cs }
|
|
}
|
|
|
|
func connectionOptStreamManager(sm *MockStreamManager) testConnectionOpt {
|
|
return func(conn *connection) { conn.streamsMap = sm }
|
|
}
|
|
|
|
func connectionOptConnFlowController(cfc *mocks.MockConnectionFlowController) testConnectionOpt {
|
|
return func(conn *connection) { conn.connFlowController = cfc }
|
|
}
|
|
|
|
func connectionOptTracer(tr *logging.ConnectionTracer) testConnectionOpt {
|
|
return func(conn *connection) { conn.tracer = tr }
|
|
}
|
|
|
|
func connectionOptSentPacketHandler(sph ackhandler.SentPacketHandler) testConnectionOpt {
|
|
return func(conn *connection) { conn.sentPacketHandler = sph }
|
|
}
|
|
|
|
func connectionOptReceivedPacketHandler(rph ackhandler.ReceivedPacketHandler) testConnectionOpt {
|
|
return func(conn *connection) { conn.receivedPacketHandler = rph }
|
|
}
|
|
|
|
func connectionOptUnpacker(u unpacker) testConnectionOpt {
|
|
return func(conn *connection) { conn.unpacker = u }
|
|
}
|
|
|
|
func connectionOptSender(s sender) testConnectionOpt {
|
|
return func(conn *connection) { conn.sendQueue = s }
|
|
}
|
|
|
|
func connectionOptHandshakeConfirmed() testConnectionOpt {
|
|
return func(conn *connection) {
|
|
conn.handshakeComplete = true
|
|
conn.handshakeConfirmed = true
|
|
}
|
|
}
|
|
|
|
func connectionOptRTT(rtt time.Duration) testConnectionOpt {
|
|
var rttStats utils.RTTStats
|
|
rttStats.UpdateRTT(rtt, 0)
|
|
return func(conn *connection) { conn.rttStats = &rttStats }
|
|
}
|
|
|
|
func connectionOptRetrySrcConnID(rcid protocol.ConnectionID) testConnectionOpt {
|
|
return func(conn *connection) { conn.retrySrcConnID = &rcid }
|
|
}
|
|
|
|
type testConnection struct {
|
|
conn *connection
|
|
connRunner *MockConnRunner
|
|
sendConn *MockSendConn
|
|
packer *MockPacker
|
|
destConnID protocol.ConnectionID
|
|
srcConnID protocol.ConnectionID
|
|
}
|
|
|
|
func newServerTestConnection(
|
|
t *testing.T,
|
|
mockCtrl *gomock.Controller,
|
|
config *Config,
|
|
gso bool,
|
|
opts ...testConnectionOpt,
|
|
) *testConnection {
|
|
if mockCtrl == nil {
|
|
mockCtrl = gomock.NewController(t)
|
|
}
|
|
remoteAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 4321}
|
|
localAddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1234}
|
|
connRunner := NewMockConnRunner(mockCtrl)
|
|
sendConn := NewMockSendConn(mockCtrl)
|
|
sendConn.EXPECT().capabilities().Return(connCapabilities{GSO: gso}).AnyTimes()
|
|
sendConn.EXPECT().RemoteAddr().Return(remoteAddr).AnyTimes()
|
|
sendConn.EXPECT().LocalAddr().Return(localAddr).AnyTimes()
|
|
packer := NewMockPacker(mockCtrl)
|
|
b := make([]byte, 12)
|
|
rand.Read(b)
|
|
origDestConnID := protocol.ParseConnectionID(b[:6])
|
|
srcConnID := protocol.ParseConnectionID(b[6:12])
|
|
ctx, cancel := context.WithCancelCause(context.Background())
|
|
if config == nil {
|
|
config = &Config{DisablePathMTUDiscovery: true}
|
|
}
|
|
conn := newConnection(
|
|
ctx,
|
|
cancel,
|
|
sendConn,
|
|
connRunner,
|
|
origDestConnID,
|
|
nil,
|
|
protocol.ConnectionID{},
|
|
protocol.ConnectionID{},
|
|
srcConnID,
|
|
&protocol.DefaultConnectionIDGenerator{},
|
|
newStatelessResetter(nil),
|
|
populateConfig(config),
|
|
&tls.Config{},
|
|
handshake.NewTokenGenerator(handshake.TokenProtectorKey{}),
|
|
false,
|
|
nil,
|
|
utils.DefaultLogger,
|
|
protocol.Version1,
|
|
).(*connection)
|
|
conn.packer = packer
|
|
for _, opt := range opts {
|
|
opt(conn)
|
|
}
|
|
return &testConnection{
|
|
conn: conn,
|
|
connRunner: connRunner,
|
|
sendConn: sendConn,
|
|
packer: packer,
|
|
destConnID: origDestConnID,
|
|
srcConnID: srcConnID,
|
|
}
|
|
}
|
|
|
|
func newClientTestConnection(
|
|
t *testing.T,
|
|
mockCtrl *gomock.Controller,
|
|
config *Config,
|
|
enable0RTT bool,
|
|
opts ...testConnectionOpt,
|
|
) *testConnection {
|
|
if mockCtrl == nil {
|
|
mockCtrl = gomock.NewController(t)
|
|
}
|
|
remoteAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 4321}
|
|
localAddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1234}
|
|
connRunner := NewMockConnRunner(mockCtrl)
|
|
sendConn := NewMockSendConn(mockCtrl)
|
|
sendConn.EXPECT().capabilities().Return(connCapabilities{}).AnyTimes()
|
|
sendConn.EXPECT().RemoteAddr().Return(remoteAddr).AnyTimes()
|
|
sendConn.EXPECT().LocalAddr().Return(localAddr).AnyTimes()
|
|
packer := NewMockPacker(mockCtrl)
|
|
b := make([]byte, 12)
|
|
rand.Read(b)
|
|
destConnID := protocol.ParseConnectionID(b[:6])
|
|
srcConnID := protocol.ParseConnectionID(b[6:12])
|
|
if config == nil {
|
|
config = &Config{DisablePathMTUDiscovery: true}
|
|
}
|
|
conn := newClientConnection(
|
|
context.Background(),
|
|
sendConn,
|
|
connRunner,
|
|
destConnID,
|
|
srcConnID,
|
|
&protocol.DefaultConnectionIDGenerator{},
|
|
newStatelessResetter(nil),
|
|
populateConfig(config),
|
|
&tls.Config{ServerName: "quic-go.net"},
|
|
0,
|
|
enable0RTT,
|
|
false,
|
|
nil,
|
|
utils.DefaultLogger,
|
|
protocol.Version1,
|
|
).(*connection)
|
|
conn.packer = packer
|
|
for _, opt := range opts {
|
|
opt(conn)
|
|
}
|
|
return &testConnection{
|
|
conn: conn,
|
|
connRunner: connRunner,
|
|
sendConn: sendConn,
|
|
packer: packer,
|
|
destConnID: destConnID,
|
|
srcConnID: srcConnID,
|
|
}
|
|
}
|
|
|
|
func TestConnectionHandleReceiveStreamFrames(t *testing.T) {
|
|
const streamID protocol.StreamID = 5
|
|
now := time.Now()
|
|
connID := protocol.ConnectionID{}
|
|
f := &wire.StreamFrame{StreamID: streamID, Data: []byte("foobar")}
|
|
rsf := &wire.ResetStreamFrame{StreamID: streamID, ErrorCode: 42, FinalSize: 1337}
|
|
sdbf := &wire.StreamDataBlockedFrame{StreamID: streamID, MaximumStreamData: 1337}
|
|
|
|
t.Run("for existing and new streams", func(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
streamsMap := NewMockStreamManager(mockCtrl)
|
|
tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptStreamManager(streamsMap))
|
|
str := NewMockReceiveStreamI(mockCtrl)
|
|
// STREAM frame
|
|
streamsMap.EXPECT().GetOrOpenReceiveStream(streamID).Return(str, nil)
|
|
str.EXPECT().handleStreamFrame(f, now)
|
|
require.NoError(t, tc.conn.handleFrame(f, protocol.Encryption1RTT, connID, now))
|
|
// RESET_STREAM frame
|
|
streamsMap.EXPECT().GetOrOpenReceiveStream(streamID).Return(str, nil)
|
|
str.EXPECT().handleResetStreamFrame(rsf, now)
|
|
require.NoError(t, tc.conn.handleFrame(rsf, protocol.Encryption1RTT, connID, now))
|
|
// STREAM_DATA_BLOCKED frames are not passed to the stream
|
|
streamsMap.EXPECT().GetOrOpenReceiveStream(streamID).Return(str, nil)
|
|
require.NoError(t, tc.conn.handleFrame(sdbf, protocol.Encryption1RTT, connID, now))
|
|
})
|
|
|
|
t.Run("for closed streams", func(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
streamsMap := NewMockStreamManager(mockCtrl)
|
|
tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptStreamManager(streamsMap))
|
|
// STREAM frame
|
|
streamsMap.EXPECT().GetOrOpenReceiveStream(streamID).Return(nil, nil)
|
|
require.NoError(t, tc.conn.handleFrame(f, protocol.Encryption1RTT, connID, now))
|
|
// RESET_STREAM frame
|
|
streamsMap.EXPECT().GetOrOpenReceiveStream(streamID).Return(nil, nil)
|
|
require.NoError(t, tc.conn.handleFrame(rsf, protocol.Encryption1RTT, connID, now))
|
|
// STREAM_DATA_BLOCKED frames are not passed to the stream
|
|
streamsMap.EXPECT().GetOrOpenReceiveStream(streamID).Return(nil, nil)
|
|
require.NoError(t, tc.conn.handleFrame(sdbf, protocol.Encryption1RTT, connID, now))
|
|
})
|
|
|
|
t.Run("for invalid streams", func(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
streamsMap := NewMockStreamManager(mockCtrl)
|
|
tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptStreamManager(streamsMap))
|
|
testErr := errors.New("test err")
|
|
// STREAM frame
|
|
streamsMap.EXPECT().GetOrOpenReceiveStream(streamID).Return(nil, testErr)
|
|
require.ErrorIs(t, tc.conn.handleFrame(f, protocol.Encryption1RTT, connID, now), testErr)
|
|
// RESET_STREAM frame
|
|
streamsMap.EXPECT().GetOrOpenReceiveStream(streamID).Return(nil, testErr)
|
|
require.ErrorIs(t, tc.conn.handleFrame(rsf, protocol.Encryption1RTT, connID, now), testErr)
|
|
// STREAM_DATA_BLOCKED frames are not passed to the stream
|
|
streamsMap.EXPECT().GetOrOpenReceiveStream(streamID).Return(nil, testErr)
|
|
require.ErrorIs(t, tc.conn.handleFrame(sdbf, protocol.Encryption1RTT, connID, now), testErr)
|
|
})
|
|
}
|
|
|
|
func TestConnectionHandleSendStreamFrames(t *testing.T) {
|
|
const streamID protocol.StreamID = 3
|
|
now := time.Now()
|
|
connID := protocol.ConnectionID{}
|
|
ss := &wire.StopSendingFrame{StreamID: streamID, ErrorCode: 42}
|
|
msd := &wire.MaxStreamDataFrame{StreamID: streamID, MaximumStreamData: 1337}
|
|
|
|
t.Run("for existing and new streams", func(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
streamsMap := NewMockStreamManager(mockCtrl)
|
|
tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptStreamManager(streamsMap))
|
|
str := NewMockSendStreamI(mockCtrl)
|
|
// STOP_SENDING frame
|
|
streamsMap.EXPECT().GetOrOpenSendStream(streamID).Return(str, nil)
|
|
str.EXPECT().handleStopSendingFrame(ss)
|
|
require.NoError(t, tc.conn.handleFrame(ss, protocol.Encryption1RTT, connID, now))
|
|
// MAX_STREAM_DATA frame
|
|
streamsMap.EXPECT().GetOrOpenSendStream(streamID).Return(str, nil)
|
|
str.EXPECT().updateSendWindow(msd.MaximumStreamData)
|
|
require.NoError(t, tc.conn.handleFrame(msd, protocol.Encryption1RTT, connID, now))
|
|
})
|
|
|
|
t.Run("for closed streams", func(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
streamsMap := NewMockStreamManager(mockCtrl)
|
|
tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptStreamManager(streamsMap))
|
|
// STOP_SENDING frame
|
|
streamsMap.EXPECT().GetOrOpenSendStream(streamID).Return(nil, nil)
|
|
require.NoError(t, tc.conn.handleFrame(ss, protocol.Encryption1RTT, connID, now))
|
|
// MAX_STREAM_DATA frame
|
|
streamsMap.EXPECT().GetOrOpenSendStream(streamID).Return(nil, nil)
|
|
require.NoError(t, tc.conn.handleFrame(msd, protocol.Encryption1RTT, connID, now))
|
|
})
|
|
|
|
t.Run("for invalid streams", func(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
streamsMap := NewMockStreamManager(mockCtrl)
|
|
tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptStreamManager(streamsMap))
|
|
testErr := errors.New("test err")
|
|
// STOP_SENDING frame
|
|
streamsMap.EXPECT().GetOrOpenSendStream(streamID).Return(nil, testErr)
|
|
require.ErrorIs(t, tc.conn.handleFrame(ss, protocol.Encryption1RTT, connID, now), testErr)
|
|
// MAX_STREAM_DATA frame
|
|
streamsMap.EXPECT().GetOrOpenSendStream(streamID).Return(nil, testErr)
|
|
require.ErrorIs(t, tc.conn.handleFrame(msd, protocol.Encryption1RTT, connID, now), testErr)
|
|
})
|
|
}
|
|
|
|
func TestConnectionHandleStreamNumFrames(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
streamsMap := NewMockStreamManager(mockCtrl)
|
|
tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptStreamManager(streamsMap))
|
|
now := time.Now()
|
|
connID := protocol.ConnectionID{}
|
|
// MAX_STREAMS frame
|
|
msf := &wire.MaxStreamsFrame{Type: protocol.StreamTypeBidi, MaxStreamNum: 10}
|
|
streamsMap.EXPECT().HandleMaxStreamsFrame(msf)
|
|
require.NoError(t, tc.conn.handleFrame(msf, protocol.Encryption1RTT, connID, now))
|
|
// STREAMS_BLOCKED frame
|
|
tc.conn.handleFrame(&wire.StreamsBlockedFrame{Type: protocol.StreamTypeBidi, StreamLimit: 1}, protocol.Encryption1RTT, connID, now)
|
|
}
|
|
|
|
func TestConnectionHandleConnectionFlowControlFrames(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
connFC := mocks.NewMockConnectionFlowController(mockCtrl)
|
|
tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptConnFlowController(connFC))
|
|
now := time.Now()
|
|
connID := protocol.ConnectionID{}
|
|
// MAX_DATA frame
|
|
connFC.EXPECT().UpdateSendWindow(protocol.ByteCount(1337))
|
|
require.NoError(t, tc.conn.handleFrame(&wire.MaxDataFrame{MaximumData: 1337}, protocol.Encryption1RTT, connID, now))
|
|
// DATA_BLOCKED frame
|
|
require.NoError(t, tc.conn.handleFrame(&wire.DataBlockedFrame{MaximumData: 1337}, protocol.Encryption1RTT, connID, now))
|
|
}
|
|
|
|
func TestConnectionOpenStreams(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
streamsMap := NewMockStreamManager(mockCtrl)
|
|
tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptStreamManager(streamsMap))
|
|
|
|
// using OpenStream
|
|
mstr := NewMockStreamI(mockCtrl)
|
|
streamsMap.EXPECT().OpenStream().Return(mstr, nil)
|
|
str, err := tc.conn.OpenStream()
|
|
require.NoError(t, err)
|
|
require.Equal(t, mstr, str)
|
|
|
|
// using OpenStreamSync
|
|
streamsMap.EXPECT().OpenStreamSync(context.Background()).Return(mstr, nil)
|
|
str, err = tc.conn.OpenStreamSync(context.Background())
|
|
require.NoError(t, err)
|
|
require.Equal(t, mstr, str)
|
|
|
|
// using OpenUniStream
|
|
streamsMap.EXPECT().OpenUniStream().Return(mstr, nil)
|
|
ustr, err := tc.conn.OpenUniStream()
|
|
require.NoError(t, err)
|
|
require.Equal(t, mstr, ustr)
|
|
|
|
// using OpenUniStreamSync
|
|
streamsMap.EXPECT().OpenUniStreamSync(context.Background()).Return(mstr, nil)
|
|
ustr, err = tc.conn.OpenUniStreamSync(context.Background())
|
|
require.NoError(t, err)
|
|
require.Equal(t, mstr, ustr)
|
|
}
|
|
|
|
func TestConnectionAcceptStreams(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
streamsMap := NewMockStreamManager(mockCtrl)
|
|
tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptStreamManager(streamsMap))
|
|
|
|
// bidirectional streams
|
|
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
|
|
defer cancel()
|
|
mstr := NewMockStreamI(mockCtrl)
|
|
streamsMap.EXPECT().AcceptStream(ctx).Return(mstr, nil)
|
|
str, err := tc.conn.AcceptStream(ctx)
|
|
require.NoError(t, err)
|
|
require.Equal(t, mstr, str)
|
|
|
|
// unidirectional streams
|
|
streamsMap.EXPECT().AcceptUniStream(ctx).Return(mstr, nil)
|
|
ustr, err := tc.conn.AcceptUniStream(ctx)
|
|
require.NoError(t, err)
|
|
require.Equal(t, mstr, ustr)
|
|
}
|
|
|
|
func TestConnectionServerInvalidFrames(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
tc := newServerTestConnection(t, mockCtrl, nil, false)
|
|
|
|
for _, test := range []struct {
|
|
Name string
|
|
Frame wire.Frame
|
|
}{
|
|
{Name: "NEW_TOKEN", Frame: &wire.NewTokenFrame{Token: []byte("foobar")}},
|
|
{Name: "HANDSHAKE_DONE", Frame: &wire.HandshakeDoneFrame{}},
|
|
{Name: "PATH_RESPONSE", Frame: &wire.PathResponseFrame{Data: [8]byte{1, 2, 3, 4, 5, 6, 7, 8}}},
|
|
} {
|
|
t.Run(test.Name, func(t *testing.T) {
|
|
require.ErrorIs(t,
|
|
tc.conn.handleFrame(test.Frame, protocol.Encryption1RTT, protocol.ConnectionID{}, time.Now()),
|
|
&qerr.TransportError{ErrorCode: qerr.ProtocolViolation},
|
|
)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestConnectionTransportError(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
|
|
tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptTracer(tr))
|
|
errChan := make(chan error, 1)
|
|
expectedErr := &qerr.TransportError{
|
|
ErrorCode: 1337,
|
|
FrameType: 42,
|
|
ErrorMessage: "test error",
|
|
}
|
|
tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
|
|
b := getPacketBuffer()
|
|
b.Data = append(b.Data, []byte("connection close")...)
|
|
tc.packer.EXPECT().PackConnectionClose(expectedErr, gomock.Any(), protocol.Version1).Return(&coalescedPacket{buffer: b}, nil)
|
|
tc.sendConn.EXPECT().Write([]byte("connection close"), gomock.Any(), gomock.Any())
|
|
tc.connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any()).AnyTimes()
|
|
gomock.InOrder(
|
|
tracer.EXPECT().ClosedConnection(expectedErr),
|
|
tracer.EXPECT().Close(),
|
|
)
|
|
|
|
go func() { errChan <- tc.conn.run() }()
|
|
tc.conn.closeLocal(expectedErr)
|
|
|
|
select {
|
|
case err := <-errChan:
|
|
require.ErrorIs(t, err, expectedErr)
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
|
|
// further calls to CloseWithError don't do anything
|
|
tc.conn.CloseWithError(42, "another error")
|
|
}
|
|
|
|
func TestConnectionApplicationClose(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
|
|
tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptTracer(tr))
|
|
errChan := make(chan error, 1)
|
|
expectedErr := &qerr.ApplicationError{
|
|
ErrorCode: 1337,
|
|
ErrorMessage: "test error",
|
|
}
|
|
tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
|
|
b := getPacketBuffer()
|
|
b.Data = append(b.Data, []byte("connection close")...)
|
|
tc.packer.EXPECT().PackApplicationClose(expectedErr, gomock.Any(), protocol.Version1).Return(&coalescedPacket{buffer: b}, nil)
|
|
tc.sendConn.EXPECT().Write([]byte("connection close"), gomock.Any(), gomock.Any())
|
|
tc.connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any()).AnyTimes()
|
|
gomock.InOrder(
|
|
tracer.EXPECT().ClosedConnection(expectedErr),
|
|
tracer.EXPECT().Close(),
|
|
)
|
|
|
|
go func() { errChan <- tc.conn.run() }()
|
|
tc.conn.CloseWithError(1337, "test error")
|
|
|
|
select {
|
|
case err := <-errChan:
|
|
require.ErrorIs(t, err, expectedErr)
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
|
|
// further calls to CloseWithError don't do anything
|
|
tc.conn.CloseWithError(42, "another error")
|
|
}
|
|
|
|
func TestConnectionStatelessReset(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
|
|
tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptTracer(tr))
|
|
errChan := make(chan error, 1)
|
|
tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
|
|
gomock.InOrder(
|
|
tracer.EXPECT().ClosedConnection(&StatelessResetError{}),
|
|
tracer.EXPECT().Close(),
|
|
)
|
|
|
|
go func() { errChan <- tc.conn.run() }()
|
|
tc.conn.destroy(&StatelessResetError{})
|
|
|
|
select {
|
|
case err := <-errChan:
|
|
require.ErrorIs(t, err, &StatelessResetError{})
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
}
|
|
|
|
func getLongHeaderPacket(t *testing.T, extHdr *wire.ExtendedHeader, data []byte) receivedPacket {
|
|
t.Helper()
|
|
b, err := extHdr.Append(nil, protocol.Version1)
|
|
require.NoError(t, err)
|
|
return receivedPacket{
|
|
data: append(b, data...),
|
|
buffer: getPacketBuffer(),
|
|
rcvTime: time.Now(),
|
|
}
|
|
}
|
|
|
|
func getShortHeaderPacket(t *testing.T, connID protocol.ConnectionID, pn protocol.PacketNumber, data []byte) receivedPacket {
|
|
t.Helper()
|
|
b, err := wire.AppendShortHeader(nil, connID, pn, protocol.PacketNumberLen2, protocol.KeyPhaseOne)
|
|
require.NoError(t, err)
|
|
return receivedPacket{
|
|
data: append(b, data...),
|
|
buffer: getPacketBuffer(),
|
|
rcvTime: time.Now(),
|
|
}
|
|
}
|
|
|
|
func TestConnectionServerInvalidPackets(t *testing.T) {
|
|
t.Run("Retry", func(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
|
|
tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptTracer(tr))
|
|
|
|
p := getLongHeaderPacket(t, &wire.ExtendedHeader{Header: wire.Header{
|
|
Type: protocol.PacketTypeRetry,
|
|
DestConnectionID: tc.conn.origDestConnID,
|
|
SrcConnectionID: tc.srcConnID,
|
|
Version: tc.conn.version,
|
|
Token: []byte("foobar"),
|
|
}}, make([]byte, 16) /* Retry integrity tag */)
|
|
tracer.EXPECT().DroppedPacket(logging.PacketTypeRetry, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropUnexpectedPacket)
|
|
wasProcessed, err := tc.conn.handleOnePacket(p)
|
|
require.NoError(t, err)
|
|
require.False(t, wasProcessed)
|
|
})
|
|
|
|
t.Run("version negotiation", func(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
|
|
tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptTracer(tr))
|
|
|
|
b := wire.ComposeVersionNegotiation(
|
|
protocol.ArbitraryLenConnectionID(tc.srcConnID.Bytes()),
|
|
protocol.ArbitraryLenConnectionID(tc.conn.origDestConnID.Bytes()),
|
|
[]Version{Version1},
|
|
)
|
|
tracer.EXPECT().DroppedPacket(logging.PacketTypeVersionNegotiation, protocol.InvalidPacketNumber, protocol.ByteCount(len(b)), logging.PacketDropUnexpectedPacket)
|
|
wasProcessed, err := tc.conn.handleOnePacket(receivedPacket{data: b, buffer: getPacketBuffer()})
|
|
require.NoError(t, err)
|
|
require.False(t, wasProcessed)
|
|
})
|
|
|
|
t.Run("unsupported version", func(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
|
|
tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptTracer(tr))
|
|
|
|
p := getLongHeaderPacket(t, &wire.ExtendedHeader{
|
|
Header: wire.Header{Type: protocol.PacketTypeHandshake, Version: 1234},
|
|
PacketNumberLen: protocol.PacketNumberLen2,
|
|
}, nil)
|
|
tracer.EXPECT().DroppedPacket(logging.PacketTypeNotDetermined, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropUnsupportedVersion)
|
|
wasProcessed, err := tc.conn.handleOnePacket(p)
|
|
require.NoError(t, err)
|
|
require.False(t, wasProcessed)
|
|
})
|
|
|
|
t.Run("invalid header", func(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
|
|
tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptTracer(tr))
|
|
|
|
p := getLongHeaderPacket(t, &wire.ExtendedHeader{
|
|
Header: wire.Header{Type: protocol.PacketTypeHandshake, Version: Version1},
|
|
PacketNumberLen: protocol.PacketNumberLen2,
|
|
}, nil)
|
|
p.data[0] ^= 0x40 // unset the QUIC bit
|
|
tracer.EXPECT().DroppedPacket(logging.PacketTypeNotDetermined, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropHeaderParseError)
|
|
wasProcessed, err := tc.conn.handleOnePacket(p)
|
|
require.NoError(t, err)
|
|
require.False(t, wasProcessed)
|
|
})
|
|
}
|
|
|
|
func TestConnectionClientDrop0RTT(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
|
|
tc := newClientTestConnection(t, mockCtrl, nil, false, connectionOptTracer(tr))
|
|
|
|
p := getLongHeaderPacket(t, &wire.ExtendedHeader{
|
|
Header: wire.Header{Type: protocol.PacketType0RTT, Length: 2, Version: protocol.Version1},
|
|
PacketNumberLen: protocol.PacketNumberLen2,
|
|
}, nil)
|
|
tracer.EXPECT().DroppedPacket(logging.PacketType0RTT, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropUnexpectedPacket)
|
|
wasProcessed, err := tc.conn.handleOnePacket(p)
|
|
require.NoError(t, err)
|
|
require.False(t, wasProcessed)
|
|
}
|
|
|
|
func TestConnectionUnpacking(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl)
|
|
unpacker := NewMockUnpacker(mockCtrl)
|
|
tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
|
|
tc := newServerTestConnection(t,
|
|
mockCtrl,
|
|
nil,
|
|
false,
|
|
connectionOptReceivedPacketHandler(rph),
|
|
connectionOptUnpacker(unpacker),
|
|
connectionOptTracer(tr),
|
|
)
|
|
|
|
// receive a long header packet
|
|
hdr := &wire.ExtendedHeader{
|
|
Header: wire.Header{
|
|
Type: protocol.PacketTypeInitial,
|
|
DestConnectionID: tc.srcConnID,
|
|
Version: protocol.Version1,
|
|
Length: 1,
|
|
},
|
|
PacketNumber: 0x37,
|
|
PacketNumberLen: protocol.PacketNumberLen1,
|
|
}
|
|
unpackedHdr := *hdr
|
|
unpackedHdr.PacketNumber = 0x1337
|
|
packet := getLongHeaderPacket(t, hdr, nil)
|
|
packet.ecn = protocol.ECNCE
|
|
rcvTime := time.Now().Add(-10 * time.Second)
|
|
packet.rcvTime = rcvTime
|
|
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).Return(&unpackedPacket{
|
|
encryptionLevel: protocol.EncryptionInitial,
|
|
hdr: &unpackedHdr,
|
|
data: []byte{0}, // one PADDING frame
|
|
}, nil)
|
|
gomock.InOrder(
|
|
rph.EXPECT().IsPotentiallyDuplicate(protocol.PacketNumber(0x1337), protocol.EncryptionInitial),
|
|
rph.EXPECT().ReceivedPacket(protocol.PacketNumber(0x1337), protocol.ECNCE, protocol.EncryptionInitial, rcvTime, false),
|
|
)
|
|
|
|
tracer.EXPECT().NegotiatedVersion(gomock.Any(), gomock.Any(), gomock.Any())
|
|
tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
|
|
tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), gomock.Any(), logging.ECNCE, []logging.Frame{})
|
|
wasProcessed, err := tc.conn.handleOnePacket(packet)
|
|
require.NoError(t, err)
|
|
require.True(t, wasProcessed)
|
|
require.True(t, mockCtrl.Satisfied())
|
|
|
|
// receive a duplicate of this packet
|
|
packet = getLongHeaderPacket(t, hdr, nil)
|
|
rph.EXPECT().IsPotentiallyDuplicate(protocol.PacketNumber(0x1337), protocol.EncryptionInitial).Return(true)
|
|
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).Return(&unpackedPacket{
|
|
encryptionLevel: protocol.EncryptionInitial,
|
|
hdr: &unpackedHdr,
|
|
data: []byte{0}, // one PADDING frame
|
|
}, nil)
|
|
tracer.EXPECT().DroppedPacket(logging.PacketTypeInitial, protocol.PacketNumber(0x1337), protocol.ByteCount(len(packet.data)), logging.PacketDropDuplicate)
|
|
wasProcessed, err = tc.conn.handleOnePacket(packet)
|
|
require.NoError(t, err)
|
|
require.False(t, wasProcessed)
|
|
require.True(t, mockCtrl.Satisfied())
|
|
|
|
// receive a short header packet
|
|
packet = getShortHeaderPacket(t, tc.srcConnID, 0x37, nil)
|
|
packet.ecn = protocol.ECT1
|
|
packet.rcvTime = rcvTime
|
|
gomock.InOrder(
|
|
rph.EXPECT().IsPotentiallyDuplicate(protocol.PacketNumber(0x1337), protocol.Encryption1RTT),
|
|
rph.EXPECT().ReceivedPacket(protocol.PacketNumber(0x1337), protocol.ECT1, protocol.Encryption1RTT, rcvTime, false),
|
|
)
|
|
unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return(
|
|
protocol.PacketNumber(0x1337), protocol.PacketNumberLen2, protocol.KeyPhaseZero, []byte{0} /* PADDING */, nil,
|
|
)
|
|
tracer.EXPECT().ReceivedShortHeaderPacket(gomock.Any(), gomock.Any(), logging.ECT1, []logging.Frame{})
|
|
wasProcessed, err = tc.conn.handleOnePacket(packet)
|
|
require.NoError(t, err)
|
|
require.True(t, wasProcessed)
|
|
}
|
|
|
|
func TestConnectionUnpackCoalescedPacket(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl)
|
|
unpacker := NewMockUnpacker(mockCtrl)
|
|
tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
|
|
tc := newServerTestConnection(t,
|
|
mockCtrl,
|
|
nil,
|
|
false,
|
|
connectionOptReceivedPacketHandler(rph),
|
|
connectionOptUnpacker(unpacker),
|
|
connectionOptTracer(tr),
|
|
)
|
|
hdr1 := &wire.ExtendedHeader{
|
|
Header: wire.Header{
|
|
Type: protocol.PacketTypeInitial,
|
|
DestConnectionID: tc.srcConnID,
|
|
Version: protocol.Version1,
|
|
Length: 1,
|
|
},
|
|
PacketNumber: 37,
|
|
PacketNumberLen: protocol.PacketNumberLen1,
|
|
}
|
|
hdr2 := &wire.ExtendedHeader{
|
|
Header: wire.Header{
|
|
Type: protocol.PacketTypeHandshake,
|
|
DestConnectionID: tc.srcConnID,
|
|
Version: protocol.Version1,
|
|
Length: 1,
|
|
},
|
|
PacketNumber: 38,
|
|
PacketNumberLen: protocol.PacketNumberLen1,
|
|
}
|
|
// add a packet with a different source connection ID
|
|
incorrectSrcConnID := protocol.ParseConnectionID([]byte{0xa, 0xb, 0xc})
|
|
hdr3 := &wire.ExtendedHeader{
|
|
Header: wire.Header{
|
|
Type: protocol.PacketTypeHandshake,
|
|
DestConnectionID: incorrectSrcConnID,
|
|
Version: protocol.Version1,
|
|
Length: 1,
|
|
},
|
|
PacketNumber: 0x42,
|
|
PacketNumberLen: protocol.PacketNumberLen1,
|
|
}
|
|
unpackedHdr1 := *hdr1
|
|
unpackedHdr1.PacketNumber = 1337
|
|
unpackedHdr2 := *hdr2
|
|
unpackedHdr2.PacketNumber = 1338
|
|
|
|
packet := getLongHeaderPacket(t, hdr1, nil)
|
|
packet2 := getLongHeaderPacket(t, hdr2, nil)
|
|
packet3 := getLongHeaderPacket(t, hdr3, nil)
|
|
packet.data = append(packet.data, packet2.data...)
|
|
packet.data = append(packet.data, packet3.data...)
|
|
packet.ecn = protocol.ECT1
|
|
rcvTime := time.Now()
|
|
packet.rcvTime = rcvTime
|
|
|
|
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).Return(&unpackedPacket{
|
|
encryptionLevel: protocol.EncryptionInitial,
|
|
hdr: &unpackedHdr1,
|
|
data: []byte{0}, // one PADDING frame
|
|
}, nil)
|
|
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).Return(&unpackedPacket{
|
|
encryptionLevel: protocol.EncryptionHandshake,
|
|
hdr: &unpackedHdr2,
|
|
data: []byte{1}, // one PING frame
|
|
}, nil)
|
|
gomock.InOrder(
|
|
rph.EXPECT().IsPotentiallyDuplicate(protocol.PacketNumber(1337), protocol.EncryptionInitial),
|
|
rph.EXPECT().ReceivedPacket(protocol.PacketNumber(1337), protocol.ECT1, protocol.EncryptionInitial, rcvTime, false),
|
|
rph.EXPECT().IsPotentiallyDuplicate(protocol.PacketNumber(1338), protocol.EncryptionHandshake),
|
|
rph.EXPECT().ReceivedPacket(protocol.PacketNumber(1338), protocol.ECT1, protocol.EncryptionHandshake, rcvTime, true),
|
|
)
|
|
tracer.EXPECT().NegotiatedVersion(gomock.Any(), gomock.Any(), gomock.Any())
|
|
tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
|
|
tracer.EXPECT().DroppedEncryptionLevel(protocol.EncryptionInitial)
|
|
rph.EXPECT().DropPackets(protocol.EncryptionInitial)
|
|
gomock.InOrder(
|
|
tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), gomock.Any(), logging.ECT1, []logging.Frame{}),
|
|
tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), gomock.Any(), logging.ECT1, []logging.Frame{&wire.PingFrame{}}),
|
|
tracer.EXPECT().DroppedPacket(logging.PacketTypeNotDetermined, protocol.InvalidPacketNumber, protocol.ByteCount(len(packet3.data)), logging.PacketDropUnknownConnectionID),
|
|
)
|
|
wasProcessed, err := tc.conn.handleOnePacket(packet)
|
|
require.NoError(t, err)
|
|
require.True(t, wasProcessed)
|
|
}
|
|
|
|
func TestConnectionUnpackFailuresFatal(t *testing.T) {
|
|
t.Run("other errors", func(t *testing.T) {
|
|
require.ErrorIs(t,
|
|
testConnectionUnpackFailureFatal(t, &qerr.TransportError{ErrorCode: qerr.ConnectionIDLimitError}),
|
|
&qerr.TransportError{ErrorCode: qerr.ConnectionIDLimitError},
|
|
)
|
|
})
|
|
|
|
t.Run("invalid reserved bits", func(t *testing.T) {
|
|
require.ErrorIs(t,
|
|
testConnectionUnpackFailureFatal(t, wire.ErrInvalidReservedBits),
|
|
&qerr.TransportError{ErrorCode: qerr.ProtocolViolation},
|
|
)
|
|
})
|
|
}
|
|
|
|
func testConnectionUnpackFailureFatal(t *testing.T, unpackErr error) error {
|
|
mockCtrl := gomock.NewController(t)
|
|
unpacker := NewMockUnpacker(mockCtrl)
|
|
tc := newServerTestConnection(t,
|
|
mockCtrl,
|
|
nil,
|
|
false,
|
|
connectionOptUnpacker(unpacker),
|
|
)
|
|
|
|
tc.connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any())
|
|
unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return(protocol.PacketNumber(0), protocol.PacketNumberLen(0), protocol.KeyPhaseBit(0), nil, unpackErr)
|
|
tc.packer.EXPECT().PackConnectionClose(gomock.Any(), gomock.Any(), protocol.Version1).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil)
|
|
errChan := make(chan error, 1)
|
|
go func() { errChan <- tc.conn.run() }()
|
|
|
|
tc.sendConn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any())
|
|
tc.conn.handlePacket(getShortHeaderPacket(t, tc.srcConnID, 0x42, nil))
|
|
|
|
select {
|
|
case err := <-errChan:
|
|
require.Error(t, err)
|
|
return err
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func TestConnectionUnpackFailureDropped(t *testing.T) {
|
|
t.Run("keys dropped", func(t *testing.T) {
|
|
testConnectionUnpackFailureDropped(t, handshake.ErrKeysDropped, logging.PacketDropKeyUnavailable)
|
|
})
|
|
|
|
t.Run("decryption failed", func(t *testing.T) {
|
|
testConnectionUnpackFailureDropped(t, handshake.ErrDecryptionFailed, logging.PacketDropPayloadDecryptError)
|
|
})
|
|
|
|
t.Run("header parse error", func(t *testing.T) {
|
|
testErr := errors.New("foo")
|
|
testConnectionUnpackFailureDropped(t, &headerParseError{err: testErr}, logging.PacketDropHeaderParseError)
|
|
})
|
|
}
|
|
|
|
func testConnectionUnpackFailureDropped(t *testing.T, unpackErr error, packetDropReason logging.PacketDropReason) {
|
|
mockCtrl := gomock.NewController(t)
|
|
unpacker := NewMockUnpacker(mockCtrl)
|
|
tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
|
|
tc := newServerTestConnection(t,
|
|
mockCtrl,
|
|
nil,
|
|
false,
|
|
connectionOptUnpacker(unpacker),
|
|
connectionOptTracer(tr),
|
|
)
|
|
|
|
unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return(protocol.PacketNumber(0), protocol.PacketNumberLen(0), protocol.KeyPhaseBit(0), nil, unpackErr)
|
|
errChan := make(chan error, 1)
|
|
go func() { errChan <- tc.conn.run() }()
|
|
|
|
done := make(chan struct{})
|
|
tracer.EXPECT().DroppedPacket(gomock.Any(), protocol.InvalidPacketNumber, gomock.Any(), packetDropReason).Do(
|
|
func(logging.PacketType, protocol.PacketNumber, protocol.ByteCount, logging.PacketDropReason) {
|
|
close(done)
|
|
},
|
|
)
|
|
tc.conn.handlePacket(getShortHeaderPacket(t, tc.srcConnID, 0x42, nil))
|
|
select {
|
|
case <-done:
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
|
|
// test teardown
|
|
tracer.EXPECT().ClosedConnection(gomock.Any())
|
|
tracer.EXPECT().Close()
|
|
tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
|
|
tc.conn.destroy(nil)
|
|
select {
|
|
case <-errChan:
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
}
|
|
|
|
func TestConnectionMaxUnprocessedPackets(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
|
|
tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptTracer(tr))
|
|
done := make(chan struct{})
|
|
|
|
for i := protocol.PacketNumber(0); i < protocol.MaxConnUnprocessedPackets; i++ {
|
|
// nothing here should block
|
|
tc.conn.handlePacket(receivedPacket{data: []byte("foobar")})
|
|
}
|
|
tracer.EXPECT().DroppedPacket(logging.PacketTypeNotDetermined, protocol.InvalidPacketNumber, logging.ByteCount(6), logging.PacketDropDOSPrevention).Do(func(logging.PacketType, logging.PacketNumber, logging.ByteCount, logging.PacketDropReason) {
|
|
close(done)
|
|
})
|
|
tc.conn.handlePacket(receivedPacket{data: []byte("foobar")})
|
|
select {
|
|
case <-done:
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
}
|
|
|
|
func TestConnectionRemoteClose(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
mockStreamManager := NewMockStreamManager(mockCtrl)
|
|
tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
|
|
unpacker := NewMockUnpacker(mockCtrl)
|
|
tc := newServerTestConnection(t,
|
|
mockCtrl,
|
|
nil,
|
|
false,
|
|
connectionOptStreamManager(mockStreamManager),
|
|
connectionOptTracer(tr),
|
|
connectionOptUnpacker(unpacker),
|
|
)
|
|
ccf, err := (&wire.ConnectionCloseFrame{
|
|
ErrorCode: uint64(qerr.StreamLimitError),
|
|
ReasonPhrase: "foobar",
|
|
}).Append(nil, protocol.Version1)
|
|
require.NoError(t, err)
|
|
unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return(protocol.PacketNumber(1), protocol.PacketNumberLen2, protocol.KeyPhaseBit(0), ccf, nil)
|
|
tracer.EXPECT().ReceivedShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
|
|
|
|
expectedErr := &qerr.TransportError{ErrorCode: qerr.StreamLimitError, Remote: true}
|
|
tc.connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any())
|
|
streamErrChan := make(chan error, 1)
|
|
mockStreamManager.EXPECT().CloseWithError(gomock.Any()).Do(func(e error) { streamErrChan <- e })
|
|
tracerErrChan := make(chan error, 1)
|
|
tracer.EXPECT().ClosedConnection(gomock.Any()).Do(func(e error) { tracerErrChan <- e })
|
|
tracer.EXPECT().Close()
|
|
|
|
errChan := make(chan error, 1)
|
|
go func() { errChan <- tc.conn.run() }()
|
|
|
|
p := getShortHeaderPacket(t, tc.srcConnID, 1, []byte("encrypted"))
|
|
tc.conn.handlePacket(receivedPacket{data: p.data, buffer: p.buffer, rcvTime: time.Now()})
|
|
|
|
select {
|
|
case err := <-errChan:
|
|
require.ErrorIs(t, err, expectedErr)
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
select {
|
|
case err := <-tracerErrChan:
|
|
require.ErrorIs(t, err, expectedErr)
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
select {
|
|
case err := <-streamErrChan:
|
|
require.ErrorIs(t, err, expectedErr)
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
}
|
|
|
|
func TestConnectionIdleTimeoutDuringHandshake(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
|
|
tc := newServerTestConnection(t,
|
|
mockCtrl,
|
|
&Config{HandshakeIdleTimeout: scaleDuration(25 * time.Millisecond)},
|
|
false,
|
|
connectionOptTracer(tr),
|
|
)
|
|
tc.packer.EXPECT().PackCoalescedPacket(false, gomock.Any(), gomock.Any(), protocol.Version1).AnyTimes()
|
|
tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
|
|
gomock.InOrder(
|
|
tracer.EXPECT().ClosedConnection(&IdleTimeoutError{}),
|
|
tracer.EXPECT().Close(),
|
|
)
|
|
errChan := make(chan error, 1)
|
|
go func() { errChan <- tc.conn.run() }()
|
|
select {
|
|
case err := <-errChan:
|
|
require.ErrorIs(t, err, &IdleTimeoutError{})
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
}
|
|
|
|
func TestConnectionHandshakeIdleTimeout(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
|
|
tc := newServerTestConnection(t,
|
|
mockCtrl,
|
|
&Config{HandshakeIdleTimeout: scaleDuration(25 * time.Millisecond)},
|
|
false,
|
|
connectionOptTracer(tr),
|
|
func(c *connection) { c.creationTime = time.Now().Add(-10 * time.Second) },
|
|
)
|
|
tc.packer.EXPECT().PackCoalescedPacket(false, gomock.Any(), gomock.Any(), protocol.Version1).AnyTimes()
|
|
tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
|
|
gomock.InOrder(
|
|
tracer.EXPECT().ClosedConnection(&HandshakeTimeoutError{}),
|
|
tracer.EXPECT().Close(),
|
|
)
|
|
errChan := make(chan error, 1)
|
|
go func() { errChan <- tc.conn.run() }()
|
|
select {
|
|
case err := <-errChan:
|
|
require.ErrorIs(t, err, &HandshakeTimeoutError{})
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
}
|
|
|
|
func TestConnectionTransportParameters(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
|
|
streamManager := NewMockStreamManager(mockCtrl)
|
|
connFC := mocks.NewMockConnectionFlowController(mockCtrl)
|
|
tc := newServerTestConnection(t,
|
|
mockCtrl,
|
|
nil,
|
|
false,
|
|
connectionOptTracer(tr),
|
|
connectionOptStreamManager(streamManager),
|
|
connectionOptConnFlowController(connFC),
|
|
)
|
|
tracer.EXPECT().ReceivedTransportParameters(gomock.Any())
|
|
params := &wire.TransportParameters{
|
|
MaxIdleTimeout: 90 * time.Second,
|
|
InitialMaxStreamDataBidiLocal: 0x5000,
|
|
InitialMaxData: 0x5000,
|
|
ActiveConnectionIDLimit: 3,
|
|
// marshaling always sets it to this value
|
|
MaxUDPPayloadSize: protocol.MaxPacketBufferSize,
|
|
OriginalDestinationConnectionID: tc.destConnID,
|
|
}
|
|
streamManager.EXPECT().UpdateLimits(params)
|
|
connFC.EXPECT().UpdateSendWindow(params.InitialMaxData)
|
|
require.NoError(t, tc.conn.handleTransportParameters(params))
|
|
}
|
|
|
|
func TestConnectionTransportParameterValidationFailureServer(t *testing.T) {
|
|
tc := newServerTestConnection(t, nil, nil, false)
|
|
err := tc.conn.handleTransportParameters(&wire.TransportParameters{
|
|
InitialSourceConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}),
|
|
})
|
|
assert.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.TransportParameterError})
|
|
assert.ErrorContains(t, err, "expected initial_source_connection_id to equal")
|
|
}
|
|
|
|
func TestConnectionTransportParameterValidationFailureClient(t *testing.T) {
|
|
t.Run("initial_source_connection_id", func(t *testing.T) {
|
|
tc := newClientTestConnection(t, nil, nil, false)
|
|
err := tc.conn.handleTransportParameters(&wire.TransportParameters{
|
|
InitialSourceConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}),
|
|
})
|
|
assert.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.TransportParameterError})
|
|
assert.ErrorContains(t, err, "expected initial_source_connection_id to equal")
|
|
})
|
|
|
|
t.Run("original_destination_connection_id", func(t *testing.T) {
|
|
tc := newClientTestConnection(t, nil, nil, false)
|
|
err := tc.conn.handleTransportParameters(&wire.TransportParameters{
|
|
InitialSourceConnectionID: tc.destConnID,
|
|
OriginalDestinationConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}),
|
|
})
|
|
assert.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.TransportParameterError})
|
|
assert.ErrorContains(t, err, "expected original_destination_connection_id to equal")
|
|
})
|
|
|
|
t.Run("retry_source_connection_id if no retry", func(t *testing.T) {
|
|
tc := newClientTestConnection(t, nil, nil, false)
|
|
rcid := protocol.ParseConnectionID([]byte{1, 2, 3, 4})
|
|
params := &wire.TransportParameters{
|
|
InitialSourceConnectionID: tc.destConnID,
|
|
OriginalDestinationConnectionID: tc.destConnID,
|
|
RetrySourceConnectionID: &rcid,
|
|
}
|
|
err := tc.conn.handleTransportParameters(params)
|
|
assert.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.TransportParameterError})
|
|
assert.ErrorContains(t, err, "received retry_source_connection_id, although no Retry was performed")
|
|
})
|
|
|
|
t.Run("retry_source_connection_id missing", func(t *testing.T) {
|
|
tc := newClientTestConnection(t,
|
|
nil,
|
|
nil,
|
|
false,
|
|
connectionOptRetrySrcConnID(protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef})),
|
|
)
|
|
params := &wire.TransportParameters{
|
|
InitialSourceConnectionID: tc.destConnID,
|
|
OriginalDestinationConnectionID: tc.destConnID,
|
|
}
|
|
err := tc.conn.handleTransportParameters(params)
|
|
assert.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.TransportParameterError})
|
|
assert.ErrorContains(t, err, "missing retry_source_connection_id")
|
|
})
|
|
|
|
t.Run("retry_source_connection_id incorrect", func(t *testing.T) {
|
|
tc := newClientTestConnection(t,
|
|
nil,
|
|
nil,
|
|
false,
|
|
connectionOptRetrySrcConnID(protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef})),
|
|
)
|
|
wrongCID := protocol.ParseConnectionID([]byte{1, 2, 3, 4})
|
|
params := &wire.TransportParameters{
|
|
InitialSourceConnectionID: tc.destConnID,
|
|
OriginalDestinationConnectionID: tc.destConnID,
|
|
RetrySourceConnectionID: &wrongCID,
|
|
}
|
|
err := tc.conn.handleTransportParameters(params)
|
|
assert.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.TransportParameterError})
|
|
assert.ErrorContains(t, err, "expected retry_source_connection_id to equal")
|
|
})
|
|
}
|
|
|
|
func TestConnectionHandshakeServer(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
cs := mocks.NewMockCryptoSetup(mockCtrl)
|
|
unpacker := NewMockUnpacker(mockCtrl)
|
|
tc := newServerTestConnection(
|
|
t,
|
|
mockCtrl,
|
|
nil,
|
|
false,
|
|
connectionOptCryptoSetup(cs),
|
|
connectionOptUnpacker(unpacker),
|
|
)
|
|
|
|
// the state transition is driven by processing of a CRYPTO frame
|
|
hdr := &wire.ExtendedHeader{
|
|
Header: wire.Header{Type: protocol.PacketTypeHandshake, Version: protocol.Version1},
|
|
PacketNumberLen: protocol.PacketNumberLen2,
|
|
}
|
|
data, err := (&wire.CryptoFrame{Data: []byte("foobar")}).Append(nil, protocol.Version1)
|
|
require.NoError(t, err)
|
|
|
|
cs.EXPECT().DiscardInitialKeys()
|
|
tc.connRunner.EXPECT().Retire(gomock.Any())
|
|
gomock.InOrder(
|
|
cs.EXPECT().StartHandshake(gomock.Any()),
|
|
cs.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}),
|
|
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).Return(
|
|
&unpackedPacket{hdr: hdr, encryptionLevel: protocol.EncryptionHandshake, data: data}, nil,
|
|
),
|
|
cs.EXPECT().HandleMessage([]byte("foobar"), protocol.EncryptionHandshake),
|
|
cs.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventHandshakeComplete}),
|
|
cs.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}),
|
|
cs.EXPECT().SetHandshakeConfirmed(),
|
|
cs.EXPECT().GetSessionTicket().Return([]byte("session ticket"), nil),
|
|
)
|
|
tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(shortHeaderPacket{}, errNothingToPack).AnyTimes()
|
|
|
|
errChan := make(chan error, 1)
|
|
go func() { errChan <- tc.conn.run() }()
|
|
p := getLongHeaderPacket(t, hdr, nil)
|
|
tc.conn.handlePacket(receivedPacket{data: p.data, buffer: p.buffer, rcvTime: time.Now()})
|
|
|
|
select {
|
|
case <-tc.conn.HandshakeComplete():
|
|
case <-tc.conn.Context().Done():
|
|
t.Fatal("connection context done")
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
|
|
var foundSessionTicket, foundHandshakeDone, foundNewToken bool
|
|
frames, _, _ := tc.conn.framer.Append(nil, nil, protocol.MaxByteCount, time.Now(), protocol.Version1)
|
|
for _, frame := range frames {
|
|
switch f := frame.Frame.(type) {
|
|
case *wire.CryptoFrame:
|
|
assert.Equal(t, []byte("session ticket"), f.Data)
|
|
foundSessionTicket = true
|
|
case *wire.HandshakeDoneFrame:
|
|
foundHandshakeDone = true
|
|
case *wire.NewTokenFrame:
|
|
assert.NotEmpty(t, f.Token)
|
|
foundNewToken = true
|
|
}
|
|
}
|
|
assert.True(t, foundSessionTicket)
|
|
assert.True(t, foundHandshakeDone)
|
|
assert.True(t, foundNewToken)
|
|
|
|
// test teardown
|
|
cs.EXPECT().Close()
|
|
tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
|
|
tc.conn.destroy(nil)
|
|
select {
|
|
case err := <-errChan:
|
|
require.NoError(t, err)
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
}
|
|
|
|
func TestConnectionHandshakeClient(t *testing.T) {
|
|
t.Run("without preferred address", func(t *testing.T) {
|
|
testConnectionHandshakeClient(t, false)
|
|
})
|
|
t.Run("with preferred address", func(t *testing.T) {
|
|
testConnectionHandshakeClient(t, true)
|
|
})
|
|
}
|
|
|
|
func testConnectionHandshakeClient(t *testing.T, usePreferredAddress bool) {
|
|
mockCtrl := gomock.NewController(t)
|
|
cs := mocks.NewMockCryptoSetup(mockCtrl)
|
|
unpacker := NewMockUnpacker(mockCtrl)
|
|
tc := newClientTestConnection(t, mockCtrl, nil, false, connectionOptCryptoSetup(cs), connectionOptUnpacker(unpacker))
|
|
tc.sendConn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
|
|
|
|
// the state transition is driven by processing of a CRYPTO frame
|
|
hdr := &wire.ExtendedHeader{
|
|
Header: wire.Header{Type: protocol.PacketTypeHandshake, Version: protocol.Version1},
|
|
PacketNumberLen: protocol.PacketNumberLen2,
|
|
}
|
|
data, err := (&wire.CryptoFrame{Data: []byte("foobar")}).Append(nil, protocol.Version1)
|
|
require.NoError(t, err)
|
|
|
|
tp := &wire.TransportParameters{
|
|
OriginalDestinationConnectionID: tc.destConnID,
|
|
MaxIdleTimeout: time.Hour,
|
|
}
|
|
preferredAddressConnID := protocol.ParseConnectionID([]byte{10, 8, 6, 4})
|
|
preferredAddressResetToken := protocol.StatelessResetToken{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}
|
|
if usePreferredAddress {
|
|
tp.PreferredAddress = &wire.PreferredAddress{
|
|
IPv4: netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), 42),
|
|
IPv6: netip.AddrPortFrom(netip.AddrFrom16([16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}), 13),
|
|
ConnectionID: preferredAddressConnID,
|
|
StatelessResetToken: preferredAddressResetToken,
|
|
}
|
|
}
|
|
|
|
packedFirstPacket := make(chan struct{})
|
|
gomock.InOrder(
|
|
cs.EXPECT().StartHandshake(gomock.Any()),
|
|
cs.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}),
|
|
tc.packer.EXPECT().PackCoalescedPacket(false, gomock.Any(), gomock.Any(), protocol.Version1).DoAndReturn(
|
|
func(b bool, bc protocol.ByteCount, t time.Time, v protocol.Version) (*coalescedPacket, error) {
|
|
close(packedFirstPacket)
|
|
return &coalescedPacket{buffer: getPacketBuffer(), longHdrPackets: []*longHeaderPacket{{header: hdr}}}, nil
|
|
},
|
|
),
|
|
// initial keys are dropped when the first handshake packet is sent
|
|
cs.EXPECT().DiscardInitialKeys(),
|
|
// no more data to send
|
|
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).Return(
|
|
&unpackedPacket{hdr: hdr, encryptionLevel: protocol.EncryptionHandshake, data: data}, nil,
|
|
),
|
|
cs.EXPECT().HandleMessage([]byte("foobar"), protocol.EncryptionHandshake),
|
|
cs.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventReceivedTransportParameters, TransportParameters: tp}),
|
|
cs.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventHandshakeComplete}),
|
|
cs.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}),
|
|
)
|
|
tc.packer.EXPECT().PackCoalescedPacket(false, gomock.Any(), gomock.Any(), protocol.Version1).Return(nil, nil).AnyTimes()
|
|
|
|
errChan := make(chan error, 1)
|
|
go func() { errChan <- tc.conn.run() }()
|
|
|
|
select {
|
|
case <-packedFirstPacket:
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
|
|
p := getLongHeaderPacket(t, hdr, nil)
|
|
tc.conn.handlePacket(receivedPacket{data: p.data, buffer: p.buffer, rcvTime: time.Now()})
|
|
|
|
select {
|
|
case <-tc.conn.HandshakeComplete():
|
|
case <-tc.conn.Context().Done():
|
|
t.Fatal("connection context done")
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
|
|
require.True(t, mockCtrl.Satisfied())
|
|
// the handshake isn't confirmed until we receive a HANDSHAKE_DONE frame from the server
|
|
|
|
data, err = (&wire.HandshakeDoneFrame{}).Append(nil, protocol.Version1)
|
|
require.NoError(t, err)
|
|
done := make(chan struct{})
|
|
tc.packer.EXPECT().PackCoalescedPacket(false, gomock.Any(), gomock.Any(), protocol.Version1).Return(nil, nil).AnyTimes()
|
|
gomock.InOrder(
|
|
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).Return(
|
|
&unpackedPacket{hdr: hdr, encryptionLevel: protocol.Encryption1RTT, data: data}, nil,
|
|
),
|
|
cs.EXPECT().SetHandshakeConfirmed(),
|
|
tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
|
|
func(buf *packetBuffer, _ protocol.ByteCount, _ time.Time, _ protocol.Version) (shortHeaderPacket, error) {
|
|
close(done)
|
|
return shortHeaderPacket{}, errNothingToPack
|
|
},
|
|
),
|
|
)
|
|
tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(shortHeaderPacket{}, errNothingToPack).AnyTimes()
|
|
p = getLongHeaderPacket(t, hdr, nil)
|
|
tc.conn.handlePacket(receivedPacket{data: p.data, buffer: p.buffer, rcvTime: time.Now()})
|
|
|
|
select {
|
|
case <-done:
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
|
|
if usePreferredAddress {
|
|
tc.connRunner.EXPECT().AddResetToken(preferredAddressResetToken, gomock.Any())
|
|
}
|
|
nextConnID := tc.conn.connIDManager.Get()
|
|
if usePreferredAddress {
|
|
require.Equal(t, preferredAddressConnID, nextConnID)
|
|
}
|
|
|
|
// test teardown
|
|
cs.EXPECT().Close()
|
|
tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
|
|
if usePreferredAddress {
|
|
tc.connRunner.EXPECT().RemoveResetToken(preferredAddressResetToken)
|
|
}
|
|
tc.conn.destroy(nil)
|
|
select {
|
|
case err := <-errChan:
|
|
require.NoError(t, err)
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
}
|
|
|
|
func TestConnection0RTTTransportParameters(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
cs := mocks.NewMockCryptoSetup(mockCtrl)
|
|
unpacker := NewMockUnpacker(mockCtrl)
|
|
tc := newClientTestConnection(t, mockCtrl, nil, false, connectionOptCryptoSetup(cs), connectionOptUnpacker(unpacker))
|
|
tc.sendConn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
|
|
|
|
// the state transition is driven by processing of a CRYPTO frame
|
|
hdr := &wire.ExtendedHeader{
|
|
Header: wire.Header{Type: protocol.PacketTypeHandshake, Version: protocol.Version1},
|
|
PacketNumberLen: protocol.PacketNumberLen2,
|
|
}
|
|
data, err := (&wire.CryptoFrame{Data: []byte("foobar")}).Append(nil, protocol.Version1)
|
|
require.NoError(t, err)
|
|
|
|
restored := &wire.TransportParameters{
|
|
ActiveConnectionIDLimit: 3,
|
|
InitialMaxData: 0x5000,
|
|
InitialMaxStreamDataBidiLocal: 0x5000,
|
|
InitialMaxStreamDataBidiRemote: 1000,
|
|
InitialMaxStreamDataUni: 1000,
|
|
MaxBidiStreamNum: 500,
|
|
MaxUniStreamNum: 500,
|
|
}
|
|
new := *restored
|
|
new.MaxBidiStreamNum-- // the server is not allowed to reduce the limit
|
|
new.OriginalDestinationConnectionID = tc.destConnID
|
|
|
|
packedFirstPacket := make(chan struct{})
|
|
gomock.InOrder(
|
|
cs.EXPECT().StartHandshake(gomock.Any()),
|
|
cs.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventRestoredTransportParameters, TransportParameters: restored}),
|
|
cs.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}),
|
|
tc.packer.EXPECT().PackCoalescedPacket(false, gomock.Any(), gomock.Any(), protocol.Version1).DoAndReturn(
|
|
func(b bool, bc protocol.ByteCount, t time.Time, v protocol.Version) (*coalescedPacket, error) {
|
|
close(packedFirstPacket)
|
|
return &coalescedPacket{buffer: getPacketBuffer(), longHdrPackets: []*longHeaderPacket{{header: hdr}}}, nil
|
|
},
|
|
),
|
|
// initial keys are dropped when the first handshake packet is sent
|
|
cs.EXPECT().DiscardInitialKeys(),
|
|
// no more data to send
|
|
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).Return(
|
|
&unpackedPacket{hdr: hdr, encryptionLevel: protocol.EncryptionHandshake, data: data}, nil,
|
|
),
|
|
cs.EXPECT().HandleMessage([]byte("foobar"), protocol.EncryptionHandshake),
|
|
cs.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventReceivedTransportParameters, TransportParameters: &new}),
|
|
cs.EXPECT().ConnectionState().Return(handshake.ConnectionState{Used0RTT: true}),
|
|
// cs.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}),
|
|
cs.EXPECT().Close(),
|
|
)
|
|
tc.packer.EXPECT().PackCoalescedPacket(false, gomock.Any(), gomock.Any(), protocol.Version1).Return(nil, nil).AnyTimes()
|
|
tc.packer.EXPECT().PackConnectionClose(gomock.Any(), gomock.Any(), protocol.Version1).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil)
|
|
tc.connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any())
|
|
|
|
errChan := make(chan error, 1)
|
|
go func() { errChan <- tc.conn.run() }()
|
|
|
|
select {
|
|
case <-packedFirstPacket:
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
|
|
p := getLongHeaderPacket(t, hdr, nil)
|
|
tc.conn.handlePacket(receivedPacket{data: p.data, buffer: p.buffer, rcvTime: time.Now()})
|
|
|
|
select {
|
|
case err := <-errChan:
|
|
require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.ProtocolViolation})
|
|
require.ErrorContains(t, err, "server sent reduced limits after accepting 0-RTT data")
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
}
|
|
|
|
func TestConnectionReceivePrioritization(t *testing.T) {
|
|
t.Run("handshake complete", func(t *testing.T) {
|
|
events := testConnectionReceivePrioritization(t, true, 5)
|
|
require.Equal(t, []string{"unpack", "unpack", "unpack", "unpack", "unpack", "pack"}, events)
|
|
})
|
|
|
|
// before handshake completion, we trigger packing of a new packet every time we receive a packet
|
|
t.Run("handshake not complete", func(t *testing.T) {
|
|
events := testConnectionReceivePrioritization(t, false, 5)
|
|
require.Equal(t, []string{
|
|
"unpack", "pack",
|
|
"unpack", "pack",
|
|
"unpack", "pack",
|
|
"unpack", "pack",
|
|
"unpack", "pack",
|
|
}, events)
|
|
})
|
|
}
|
|
|
|
func testConnectionReceivePrioritization(t *testing.T, handshakeComplete bool, numPackets int) []string {
|
|
mockCtrl := gomock.NewController(t)
|
|
unpacker := NewMockUnpacker(mockCtrl)
|
|
opts := []testConnectionOpt{connectionOptUnpacker(unpacker)}
|
|
if handshakeComplete {
|
|
opts = append(opts, connectionOptHandshakeConfirmed())
|
|
}
|
|
tc := newServerTestConnection(t, mockCtrl, nil, false, opts...)
|
|
|
|
var events []string
|
|
var counter int
|
|
var testDone bool
|
|
done := make(chan struct{})
|
|
unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).DoAndReturn(
|
|
func(rcvTime time.Time, data []byte) (protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, []byte, error) {
|
|
counter++
|
|
if counter == numPackets {
|
|
testDone = true
|
|
}
|
|
events = append(events, "unpack")
|
|
return protocol.PacketNumber(counter), protocol.PacketNumberLen2, protocol.KeyPhaseZero, []byte{0, 1} /* PADDING, PING */, nil
|
|
},
|
|
).Times(numPackets)
|
|
switch handshakeComplete {
|
|
case false:
|
|
tc.packer.EXPECT().PackCoalescedPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
|
|
func(b bool, bc protocol.ByteCount, t time.Time, v protocol.Version) (*coalescedPacket, error) {
|
|
events = append(events, "pack")
|
|
if testDone {
|
|
close(done)
|
|
}
|
|
return nil, nil
|
|
},
|
|
).AnyTimes()
|
|
case true:
|
|
tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
|
|
func(b *packetBuffer, bc protocol.ByteCount, t time.Time, v protocol.Version) (shortHeaderPacket, error) {
|
|
events = append(events, "pack")
|
|
if testDone {
|
|
close(done)
|
|
}
|
|
return shortHeaderPacket{}, errNothingToPack
|
|
},
|
|
).AnyTimes()
|
|
}
|
|
|
|
for i := range numPackets {
|
|
tc.conn.handlePacket(getShortHeaderPacket(t, tc.srcConnID, protocol.PacketNumber(i), []byte("foobar")))
|
|
}
|
|
|
|
tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
|
|
errChan := make(chan error, 1)
|
|
go func() { errChan <- tc.conn.run() }()
|
|
|
|
select {
|
|
case <-done:
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
|
|
// test teardown
|
|
tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
|
|
tc.conn.destroy(nil)
|
|
select {
|
|
case err := <-errChan:
|
|
require.NoError(t, err)
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
return events
|
|
}
|
|
|
|
func TestConnectionPacketBuffering(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
unpacker := NewMockUnpacker(mockCtrl)
|
|
cs := mocks.NewMockCryptoSetup(mockCtrl)
|
|
tracer, tr := mocklogging.NewMockConnectionTracer(mockCtrl)
|
|
tc := newServerTestConnection(t,
|
|
mockCtrl,
|
|
nil,
|
|
false,
|
|
connectionOptUnpacker(unpacker),
|
|
connectionOptCryptoSetup(cs),
|
|
connectionOptTracer(tracer),
|
|
)
|
|
|
|
tr.EXPECT().NegotiatedVersion(gomock.Any(), gomock.Any(), gomock.Any())
|
|
tr.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
|
|
tr.EXPECT().DroppedEncryptionLevel(gomock.Any())
|
|
cs.EXPECT().DiscardInitialKeys()
|
|
|
|
hdr1 := wire.ExtendedHeader{
|
|
Header: wire.Header{
|
|
Type: protocol.PacketTypeHandshake,
|
|
DestConnectionID: tc.srcConnID,
|
|
SrcConnectionID: tc.destConnID,
|
|
Length: 8,
|
|
Version: protocol.Version1,
|
|
},
|
|
PacketNumberLen: protocol.PacketNumberLen1,
|
|
PacketNumber: 1,
|
|
}
|
|
hdr2 := hdr1
|
|
hdr2.PacketNumber = 2
|
|
cs.EXPECT().StartHandshake(gomock.Any())
|
|
buffered := make(chan struct{})
|
|
gomock.InOrder(
|
|
cs.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}),
|
|
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).Return(nil, handshake.ErrKeysNotYetAvailable),
|
|
tr.EXPECT().BufferedPacket(logging.PacketTypeHandshake, gomock.Any()),
|
|
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).Return(nil, handshake.ErrKeysNotYetAvailable),
|
|
tr.EXPECT().BufferedPacket(logging.PacketTypeHandshake, gomock.Any()).Do(
|
|
func(logging.PacketType, logging.ByteCount) { close(buffered) },
|
|
),
|
|
)
|
|
|
|
tc.conn.handlePacket(getLongHeaderPacket(t, &hdr1, []byte("packet1")))
|
|
tc.conn.handlePacket(getLongHeaderPacket(t, &hdr2, []byte("packet2")))
|
|
|
|
errChan := make(chan error, 1)
|
|
go func() { errChan <- tc.conn.run() }()
|
|
|
|
select {
|
|
case <-buffered:
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
|
|
// Now send another packet.
|
|
// In reality, this packet would contain a CRYPTO frame that advances the TLS handshake
|
|
// such that new keys become available.
|
|
var packets []string
|
|
hdr3 := hdr1
|
|
hdr3.PacketNumber = 3
|
|
tc.packer.EXPECT().PackCoalescedPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes()
|
|
unpacked := make(chan struct{})
|
|
cs.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventReceivedReadKeys})
|
|
cs.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent})
|
|
|
|
gomock.InOrder(
|
|
// packet 3 contains a CRYPTO frame and triggers the keys to become available
|
|
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).DoAndReturn(
|
|
func(hdr *wire.Header, data []byte) (*unpackedPacket, error) {
|
|
packets = append(packets, string(data[len(data)-7:]))
|
|
cf := &wire.CryptoFrame{Data: []byte("foobar")}
|
|
b, _ := cf.Append(nil, protocol.Version1)
|
|
return &unpackedPacket{hdr: &hdr3, encryptionLevel: protocol.EncryptionHandshake, data: b}, nil
|
|
},
|
|
),
|
|
cs.EXPECT().HandleMessage(gomock.Any(), gomock.Any()),
|
|
tr.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()),
|
|
// packet 1 dequeued from the buffer
|
|
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).DoAndReturn(
|
|
func(hdr *wire.Header, data []byte) (*unpackedPacket, error) {
|
|
packets = append(packets, string(data[len(data)-7:]))
|
|
return &unpackedPacket{hdr: &hdr1, encryptionLevel: protocol.EncryptionHandshake, data: []byte{0} /* PADDING */}, nil
|
|
},
|
|
),
|
|
tr.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()),
|
|
// packet 2 dequeued from the buffer
|
|
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).DoAndReturn(
|
|
func(hdr *wire.Header, data []byte) (*unpackedPacket, error) {
|
|
packets = append(packets, string(data[len(data)-7:]))
|
|
close(unpacked)
|
|
return &unpackedPacket{hdr: &hdr2, encryptionLevel: protocol.EncryptionHandshake, data: []byte{0} /* PADDING */}, nil
|
|
},
|
|
),
|
|
tr.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()),
|
|
)
|
|
|
|
tc.conn.handlePacket(getLongHeaderPacket(t, &hdr3, []byte("packet3")))
|
|
|
|
select {
|
|
case <-unpacked:
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
|
|
// packet3 triggered the keys to become available
|
|
// packet1 and packet2 are processed from the buffer in order
|
|
require.Equal(t, []string{"packet3", "packet1", "packet2"}, packets)
|
|
|
|
// test teardown
|
|
tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
|
|
cs.EXPECT().Close()
|
|
tr.EXPECT().ClosedConnection(gomock.Any())
|
|
tr.EXPECT().Close()
|
|
tc.conn.destroy(nil)
|
|
select {
|
|
case err := <-errChan:
|
|
require.NoError(t, err)
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
}
|
|
|
|
func TestConnectionPacketPacing(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
|
|
sender := NewMockSender(mockCtrl)
|
|
|
|
tc := newServerTestConnection(t,
|
|
mockCtrl,
|
|
nil,
|
|
false,
|
|
connectionOptSentPacketHandler(sph),
|
|
connectionOptSender(sender),
|
|
connectionOptHandshakeConfirmed(),
|
|
// set a fixed RTT, so that the idle timeout doesn't interfere with this test
|
|
connectionOptRTT(10*time.Second),
|
|
)
|
|
sender.EXPECT().Run()
|
|
|
|
step := scaleDuration(50 * time.Millisecond)
|
|
|
|
sph.EXPECT().GetLossDetectionTimeout().Return(time.Now().Add(time.Hour)).AnyTimes()
|
|
gomock.InOrder(
|
|
// 1. allow 2 packets to be sent
|
|
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny),
|
|
sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()),
|
|
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny),
|
|
sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()),
|
|
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendPacingLimited),
|
|
// 2. become pacing limited for 25ms
|
|
sph.EXPECT().TimeUntilSend().DoAndReturn(func() time.Time { return time.Now().Add(step) }),
|
|
// 3. send another packet
|
|
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny),
|
|
sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()),
|
|
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendPacingLimited),
|
|
// 4. become pacing limited for 25ms...
|
|
sph.EXPECT().TimeUntilSend().DoAndReturn(func() time.Time { return time.Now().Add(step) }),
|
|
// ... but this time we're still pacing limited when waking up.
|
|
// In this case, we can only send an ACK.
|
|
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendPacingLimited),
|
|
// 5. stop the test by becoming pacing limited forever
|
|
sph.EXPECT().TimeUntilSend().Return(time.Now().Add(time.Hour)),
|
|
sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()),
|
|
)
|
|
sph.EXPECT().ECNMode(gomock.Any()).AnyTimes()
|
|
for i := 0; i < 3; i++ {
|
|
tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), Version1).DoAndReturn(
|
|
func(buf *packetBuffer, _ protocol.ByteCount, _ time.Time, _ protocol.Version) (shortHeaderPacket, error) {
|
|
buf.Data = append(buf.Data, []byte("packet"+strconv.Itoa(i+1))...)
|
|
return shortHeaderPacket{PacketNumber: protocol.PacketNumber(i + 1)}, nil
|
|
},
|
|
)
|
|
}
|
|
tc.packer.EXPECT().PackAckOnlyPacket(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
|
|
func(_ protocol.ByteCount, _ time.Time, _ protocol.Version) (shortHeaderPacket, *packetBuffer, error) {
|
|
buf := getPacketBuffer()
|
|
buf.Data = []byte("ack")
|
|
return shortHeaderPacket{PacketNumber: 1}, buf, nil
|
|
},
|
|
)
|
|
sender.EXPECT().WouldBlock().AnyTimes()
|
|
|
|
type sentPacket struct {
|
|
time time.Time
|
|
data []byte
|
|
}
|
|
sendChan := make(chan sentPacket, 10)
|
|
sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(b *packetBuffer, _ uint16, _ protocol.ECN) {
|
|
sendChan <- sentPacket{time: time.Now(), data: b.Data}
|
|
}).Times(4)
|
|
|
|
errChan := make(chan error, 1)
|
|
go func() { errChan <- tc.conn.run() }()
|
|
tc.conn.scheduleSending()
|
|
|
|
var times []time.Time
|
|
for i := 0; i < 3; i++ {
|
|
select {
|
|
case b := <-sendChan:
|
|
require.Equal(t, []byte("packet"+strconv.Itoa(i+1)), b.data)
|
|
times = append(times, b.time)
|
|
case <-time.After(scaleDuration(time.Second)):
|
|
t.Fatal("timeout")
|
|
}
|
|
}
|
|
select {
|
|
case b := <-sendChan:
|
|
require.Equal(t, []byte("ack"), b.data)
|
|
times = append(times, b.time)
|
|
case <-time.After(scaleDuration(time.Second)):
|
|
t.Fatal("timeout")
|
|
}
|
|
|
|
require.InDelta(t, times[0].Sub(times[1]).Seconds(), 0, scaleDuration(10*time.Millisecond).Seconds())
|
|
require.InDelta(t, times[2].Sub(times[1]).Seconds(), step.Seconds(), scaleDuration(20*time.Millisecond).Seconds())
|
|
require.InDelta(t, times[3].Sub(times[2]).Seconds(), step.Seconds(), scaleDuration(20*time.Millisecond).Seconds())
|
|
|
|
time.Sleep(scaleDuration(step)) // make sure that no more packets are sent
|
|
require.True(t, mockCtrl.Satisfied())
|
|
|
|
// test teardown
|
|
sender.EXPECT().Close()
|
|
tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
|
|
tc.conn.destroy(nil)
|
|
select {
|
|
case <-sendChan:
|
|
t.Fatal("should not have sent any more packets")
|
|
case err := <-errChan:
|
|
require.NoError(t, err)
|
|
case <-time.After(3 * time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
}
|
|
|
|
func TestConnectionIdleTimeout(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
|
|
tc := newServerTestConnection(t,
|
|
mockCtrl,
|
|
&Config{MaxIdleTimeout: time.Second},
|
|
false,
|
|
connectionOptHandshakeConfirmed(),
|
|
connectionOptSentPacketHandler(sph),
|
|
connectionOptRTT(time.Millisecond),
|
|
)
|
|
// the idle timeout is set when the transport parameters are received
|
|
idleTimeout := scaleDuration(50 * time.Millisecond)
|
|
require.NoError(t, tc.conn.handleTransportParameters(&wire.TransportParameters{
|
|
MaxIdleTimeout: idleTimeout,
|
|
}))
|
|
|
|
sph.EXPECT().GetLossDetectionTimeout().AnyTimes()
|
|
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes()
|
|
sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
|
|
sph.EXPECT().ECNMode(gomock.Any()).AnyTimes()
|
|
var lastSendTime time.Time
|
|
tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
|
|
func(buf *packetBuffer, _ protocol.ByteCount, _ time.Time, _ protocol.Version) (shortHeaderPacket, error) {
|
|
buf.Data = append(buf.Data, []byte("foobar")...)
|
|
lastSendTime = time.Now()
|
|
return shortHeaderPacket{Frames: []ackhandler.Frame{{Frame: &wire.PingFrame{}}}, Length: 6}, nil
|
|
},
|
|
)
|
|
tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(shortHeaderPacket{}, errNothingToPack)
|
|
tc.sendConn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any())
|
|
tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
|
|
|
|
errChan := make(chan error, 1)
|
|
go func() { errChan <- tc.conn.run() }()
|
|
tc.conn.scheduleSending()
|
|
|
|
select {
|
|
case err := <-errChan:
|
|
require.ErrorIs(t, err, &IdleTimeoutError{})
|
|
require.NotZero(t, lastSendTime)
|
|
require.InDelta(t,
|
|
time.Since(lastSendTime).Seconds(),
|
|
idleTimeout.Seconds(),
|
|
scaleDuration(10*time.Millisecond).Seconds(),
|
|
)
|
|
case <-time.After(3 * time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
}
|
|
|
|
func TestConnectionKeepAlive(t *testing.T) {
|
|
t.Run("enabled", func(t *testing.T) {
|
|
testConnectionKeepAlive(t, true, true)
|
|
})
|
|
|
|
t.Run("disabled", func(t *testing.T) {
|
|
testConnectionKeepAlive(t, false, false)
|
|
})
|
|
}
|
|
|
|
func testConnectionKeepAlive(t *testing.T, enable, expectKeepAlive bool) {
|
|
var keepAlivePeriod time.Duration
|
|
if enable {
|
|
keepAlivePeriod = time.Second
|
|
}
|
|
|
|
mockCtrl := gomock.NewController(t)
|
|
unpacker := NewMockUnpacker(mockCtrl)
|
|
tc := newServerTestConnection(t,
|
|
mockCtrl,
|
|
&Config{MaxIdleTimeout: time.Second, KeepAlivePeriod: keepAlivePeriod},
|
|
false,
|
|
connectionOptUnpacker(unpacker),
|
|
connectionOptHandshakeConfirmed(),
|
|
connectionOptRTT(time.Millisecond),
|
|
)
|
|
// the idle timeout is set when the transport parameters are received
|
|
idleTimeout := scaleDuration(50 * time.Millisecond)
|
|
require.NoError(t, tc.conn.handleTransportParameters(&wire.TransportParameters{
|
|
MaxIdleTimeout: idleTimeout,
|
|
}))
|
|
|
|
// Receive a packet. This starts the keep-alive timer.
|
|
buf := getPacketBuffer()
|
|
var err error
|
|
buf.Data, err = wire.AppendShortHeader(buf.Data, tc.srcConnID, 1, protocol.PacketNumberLen1, protocol.KeyPhaseZero)
|
|
require.NoError(t, err)
|
|
buf.Data = append(buf.Data, []byte("packet")...)
|
|
|
|
errChan := make(chan error, 1)
|
|
go func() { errChan <- tc.conn.run() }()
|
|
|
|
var unpackTime, packTime time.Time
|
|
done := make(chan struct{})
|
|
unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).DoAndReturn(
|
|
func(t time.Time, bytes []byte) (protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, []byte, error) {
|
|
unpackTime = time.Now()
|
|
return protocol.PacketNumber(1), protocol.PacketNumberLen1, protocol.KeyPhaseZero, []byte{0} /* PADDING */, nil
|
|
},
|
|
)
|
|
tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(shortHeaderPacket{}, errNothingToPack)
|
|
|
|
switch expectKeepAlive {
|
|
case true:
|
|
// record the time of the keep-alive is sent
|
|
tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
|
|
func(buffer *packetBuffer, count protocol.ByteCount, t time.Time, version protocol.Version) (shortHeaderPacket, error) {
|
|
packTime = time.Now()
|
|
close(done)
|
|
return shortHeaderPacket{}, errNothingToPack
|
|
},
|
|
)
|
|
tc.conn.handlePacket(receivedPacket{data: buf.Data, buffer: buf, rcvTime: time.Now()})
|
|
select {
|
|
case <-done:
|
|
// the keep-alive packet should be sent after half the idle timeout
|
|
diff := packTime.Sub(unpackTime)
|
|
require.InDelta(t, diff.Seconds(), idleTimeout.Seconds()/2, scaleDuration(10*time.Millisecond).Seconds())
|
|
case <-time.After(idleTimeout):
|
|
t.Fatal("timeout")
|
|
}
|
|
case false: // if keep-alives are disabled, the connection will run into an idle timeout
|
|
tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
|
|
tc.conn.handlePacket(receivedPacket{data: buf.Data, buffer: buf, rcvTime: time.Now()})
|
|
select {
|
|
case <-time.After(3 * time.Second):
|
|
t.Fatal("timeout")
|
|
case <-time.After(idleTimeout):
|
|
}
|
|
}
|
|
|
|
// test teardown
|
|
if expectKeepAlive {
|
|
tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
|
|
tc.conn.destroy(nil)
|
|
}
|
|
select {
|
|
case err := <-errChan:
|
|
if expectKeepAlive {
|
|
require.NoError(t, err)
|
|
} else {
|
|
require.ErrorIs(t, err, &IdleTimeoutError{})
|
|
}
|
|
case <-time.After(3 * time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
}
|
|
|
|
func TestConnectionACKTimer(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl)
|
|
sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
|
|
tc := newServerTestConnection(t,
|
|
mockCtrl,
|
|
&Config{MaxIdleTimeout: time.Second},
|
|
false,
|
|
connectionOptHandshakeConfirmed(),
|
|
connectionOptReceivedPacketHandler(rph),
|
|
connectionOptSentPacketHandler(sph),
|
|
connectionOptRTT(10*time.Second),
|
|
)
|
|
alarmTimeout := scaleDuration(50 * time.Millisecond)
|
|
|
|
sph.EXPECT().GetLossDetectionTimeout().AnyTimes()
|
|
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes()
|
|
sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
|
|
sph.EXPECT().ECNMode(gomock.Any()).AnyTimes()
|
|
rph.EXPECT().GetAlarmTimeout().Return(time.Now().Add(time.Hour))
|
|
tc.sendConn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
|
|
|
|
var times []time.Time
|
|
done := make(chan struct{}, 5)
|
|
var calls []any
|
|
for i := 0; i < 2; i++ {
|
|
calls = append(calls, tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
|
|
func(buf *packetBuffer, _ protocol.ByteCount, _ time.Time, _ protocol.Version) (shortHeaderPacket, error) {
|
|
buf.Data = append(buf.Data, []byte("foobar")...)
|
|
times = append(times, time.Now())
|
|
return shortHeaderPacket{Frames: []ackhandler.Frame{{Frame: &wire.PingFrame{}}}, Length: 6}, nil
|
|
},
|
|
))
|
|
calls = append(calls, tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
|
|
func(buf *packetBuffer, _ protocol.ByteCount, _ time.Time, _ protocol.Version) (shortHeaderPacket, error) {
|
|
done <- struct{}{}
|
|
return shortHeaderPacket{}, errNothingToPack
|
|
},
|
|
))
|
|
if i == 0 {
|
|
calls = append(calls, rph.EXPECT().GetAlarmTimeout().Return(time.Now().Add(alarmTimeout)))
|
|
} else {
|
|
calls = append(calls, rph.EXPECT().GetAlarmTimeout().Return(time.Now().Add(time.Hour)).MaxTimes(1))
|
|
}
|
|
}
|
|
gomock.InOrder(calls...)
|
|
errChan := make(chan error, 1)
|
|
go func() { errChan <- tc.conn.run() }()
|
|
tc.conn.scheduleSending()
|
|
|
|
for i := 0; i < 2; i++ {
|
|
select {
|
|
case <-done:
|
|
case <-time.After(3 * time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
}
|
|
|
|
assert.Len(t, times, 2)
|
|
require.InDelta(t, times[1].Sub(times[0]).Seconds(), alarmTimeout.Seconds(), scaleDuration(10*time.Millisecond).Seconds())
|
|
|
|
// test teardown
|
|
tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
|
|
tc.conn.destroy(nil)
|
|
select {
|
|
case err := <-errChan:
|
|
require.NoError(t, err)
|
|
case <-time.After(3 * time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
}
|
|
|
|
// Send a GSO batch, until we have no more data to send.
|
|
func TestConnectionGSOBatch(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
|
|
tc := newServerTestConnection(t,
|
|
mockCtrl,
|
|
nil,
|
|
true,
|
|
connectionOptHandshakeConfirmed(),
|
|
connectionOptSentPacketHandler(sph),
|
|
)
|
|
|
|
// allow packets to be sent
|
|
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes()
|
|
sph.EXPECT().TimeUntilSend().Return(time.Time{}).AnyTimes()
|
|
sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
|
|
sph.EXPECT().GetLossDetectionTimeout().Return(time.Time{}).AnyTimes()
|
|
sph.EXPECT().ECNMode(gomock.Any()).Return(protocol.ECT1).AnyTimes()
|
|
|
|
maxPacketSize := tc.conn.maxPacketSize()
|
|
var expectedData []byte
|
|
for i := 0; i < 4; i++ {
|
|
data := bytes.Repeat([]byte{byte(i)}, int(maxPacketSize))
|
|
expectedData = append(expectedData, data...)
|
|
|
|
tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
|
|
func(buffer *packetBuffer, count protocol.ByteCount, t time.Time, version protocol.Version) (shortHeaderPacket, error) {
|
|
buffer.Data = append(buffer.Data, data...)
|
|
return shortHeaderPacket{PacketNumber: protocol.PacketNumber(i)}, nil
|
|
},
|
|
)
|
|
}
|
|
done := make(chan struct{})
|
|
tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(shortHeaderPacket{}, errNothingToPack)
|
|
tc.sendConn.EXPECT().Write(expectedData, uint16(maxPacketSize), protocol.ECT1).DoAndReturn(
|
|
func([]byte, uint16, protocol.ECN) error { close(done); return nil },
|
|
)
|
|
|
|
errChan := make(chan error, 1)
|
|
go func() { errChan <- tc.conn.run() }()
|
|
tc.conn.scheduleSending()
|
|
|
|
select {
|
|
case <-done:
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
|
|
// test teardown
|
|
tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
|
|
tc.conn.destroy(nil)
|
|
select {
|
|
case err := <-errChan:
|
|
require.NoError(t, err)
|
|
case <-time.After(3 * time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
}
|
|
|
|
// Send a GSO batch, until a packet smaller than the maximum size is packed
|
|
func TestConnectionGSOBatchPacketSize(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
|
|
tc := newServerTestConnection(t,
|
|
mockCtrl,
|
|
nil,
|
|
true,
|
|
connectionOptHandshakeConfirmed(),
|
|
connectionOptSentPacketHandler(sph),
|
|
)
|
|
|
|
// allow packets to be sent
|
|
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes()
|
|
sph.EXPECT().TimeUntilSend().Return(time.Time{}).AnyTimes()
|
|
sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
|
|
sph.EXPECT().GetLossDetectionTimeout().Return(time.Time{}).AnyTimes()
|
|
sph.EXPECT().ECNMode(gomock.Any()).Return(protocol.ECT1).AnyTimes()
|
|
|
|
maxPacketSize := tc.conn.maxPacketSize()
|
|
var expectedData []byte
|
|
var calls []any
|
|
for i := 0; i < 4; i++ {
|
|
var data []byte
|
|
if i == 3 {
|
|
data = bytes.Repeat([]byte{byte(i)}, int(maxPacketSize-1))
|
|
} else {
|
|
data = bytes.Repeat([]byte{byte(i)}, int(maxPacketSize))
|
|
}
|
|
expectedData = append(expectedData, data...)
|
|
|
|
calls = append(calls, tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
|
|
func(buffer *packetBuffer, count protocol.ByteCount, t time.Time, version protocol.Version) (shortHeaderPacket, error) {
|
|
buffer.Data = append(buffer.Data, data...)
|
|
return shortHeaderPacket{PacketNumber: protocol.PacketNumber(10 + i)}, nil
|
|
},
|
|
))
|
|
}
|
|
// The smaller (fourth) packet concluded this GSO batch, but the send loop will immediately start composing the next batch.
|
|
// We therefore send a "foobar", so we can check that we're actually generating two GSO batches.
|
|
calls = append(calls,
|
|
tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
|
|
func(buffer *packetBuffer, count protocol.ByteCount, t time.Time, version protocol.Version) (shortHeaderPacket, error) {
|
|
buffer.Data = append(buffer.Data, []byte("foobar")...)
|
|
return shortHeaderPacket{PacketNumber: protocol.PacketNumber(14)}, nil
|
|
},
|
|
),
|
|
)
|
|
calls = append(calls,
|
|
tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(shortHeaderPacket{}, errNothingToPack),
|
|
)
|
|
gomock.InOrder(calls...)
|
|
|
|
done := make(chan struct{})
|
|
gomock.InOrder(
|
|
tc.sendConn.EXPECT().Write(expectedData, uint16(maxPacketSize), protocol.ECT1),
|
|
tc.sendConn.EXPECT().Write([]byte("foobar"), uint16(maxPacketSize), protocol.ECT1).DoAndReturn(
|
|
func([]byte, uint16, protocol.ECN) error { close(done); return nil },
|
|
),
|
|
)
|
|
errChan := make(chan error, 1)
|
|
go func() { errChan <- tc.conn.run() }()
|
|
tc.conn.scheduleSending()
|
|
select {
|
|
case <-done:
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
|
|
// test teardown
|
|
tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
|
|
tc.conn.destroy(nil)
|
|
select {
|
|
case err := <-errChan:
|
|
require.NoError(t, err)
|
|
case <-time.After(3 * time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
}
|
|
|
|
func TestConnectionGSOBatchECN(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
|
|
tc := newServerTestConnection(t,
|
|
mockCtrl,
|
|
nil,
|
|
true,
|
|
connectionOptHandshakeConfirmed(),
|
|
connectionOptSentPacketHandler(sph),
|
|
)
|
|
|
|
// allow packets to be sent
|
|
ecnMode := protocol.ECT1
|
|
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes()
|
|
sph.EXPECT().TimeUntilSend().Return(time.Time{}).AnyTimes()
|
|
sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
|
|
sph.EXPECT().GetLossDetectionTimeout().Return(time.Time{}).AnyTimes()
|
|
sph.EXPECT().ECNMode(gomock.Any()).DoAndReturn(func(bool) protocol.ECN { return ecnMode }).AnyTimes()
|
|
|
|
// 3. Send a GSO batch, until the ECN marking changes.
|
|
var expectedData []byte
|
|
var calls []any
|
|
maxPacketSize := tc.conn.maxPacketSize()
|
|
for i := 0; i < 3; i++ {
|
|
data := bytes.Repeat([]byte{byte(i)}, int(maxPacketSize))
|
|
expectedData = append(expectedData, data...)
|
|
|
|
calls = append(calls, tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
|
|
func(buffer *packetBuffer, count protocol.ByteCount, t time.Time, version protocol.Version) (shortHeaderPacket, error) {
|
|
buffer.Data = append(buffer.Data, data...)
|
|
if i == 2 {
|
|
ecnMode = protocol.ECNCE
|
|
}
|
|
return shortHeaderPacket{PacketNumber: protocol.PacketNumber(20 + i)}, nil
|
|
},
|
|
))
|
|
}
|
|
// The smaller (fourth) packet concluded this GSO batch, but the send loop will immediately start composing the next batch.
|
|
// We therefore send a "foobar", so we can check that we're actually generating two GSO batches.
|
|
calls = append(calls,
|
|
tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
|
|
func(buffer *packetBuffer, count protocol.ByteCount, t time.Time, version protocol.Version) (shortHeaderPacket, error) {
|
|
buffer.Data = append(buffer.Data, []byte("foobar")...)
|
|
return shortHeaderPacket{PacketNumber: protocol.PacketNumber(24)}, nil
|
|
},
|
|
),
|
|
)
|
|
calls = append(calls,
|
|
tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(shortHeaderPacket{}, errNothingToPack),
|
|
)
|
|
gomock.InOrder(calls...)
|
|
|
|
done3 := make(chan struct{})
|
|
tc.sendConn.EXPECT().Write(expectedData, uint16(maxPacketSize), protocol.ECT1)
|
|
tc.sendConn.EXPECT().Write([]byte("foobar"), uint16(maxPacketSize), protocol.ECNCE).DoAndReturn(
|
|
func([]byte, uint16, protocol.ECN) error { close(done3); return nil },
|
|
)
|
|
|
|
errChan := make(chan error, 1)
|
|
go func() { errChan <- tc.conn.run() }()
|
|
tc.conn.scheduleSending()
|
|
|
|
select {
|
|
case <-done3:
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
|
|
// test teardown
|
|
tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
|
|
tc.conn.destroy(nil)
|
|
select {
|
|
case err := <-errChan:
|
|
require.NoError(t, err)
|
|
case <-time.After(3 * time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
}
|
|
|
|
func TestConnectionPTOProbePackets(t *testing.T) {
|
|
t.Run("Initial", func(t *testing.T) {
|
|
testConnectionPTOProbePackets(t, protocol.EncryptionInitial)
|
|
})
|
|
t.Run("Handshake", func(t *testing.T) {
|
|
testConnectionPTOProbePackets(t, protocol.EncryptionHandshake)
|
|
})
|
|
t.Run("1-RTT", func(t *testing.T) {
|
|
testConnectionPTOProbePackets(t, protocol.Encryption1RTT)
|
|
})
|
|
}
|
|
|
|
func testConnectionPTOProbePackets(t *testing.T, encLevel protocol.EncryptionLevel) {
|
|
mockCtrl := gomock.NewController(t)
|
|
sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
|
|
tc := newServerTestConnection(t,
|
|
mockCtrl,
|
|
nil,
|
|
false,
|
|
connectionOptSentPacketHandler(sph),
|
|
)
|
|
|
|
var sendMode ackhandler.SendMode
|
|
switch encLevel {
|
|
case protocol.EncryptionInitial:
|
|
sendMode = ackhandler.SendPTOInitial
|
|
case protocol.EncryptionHandshake:
|
|
sendMode = ackhandler.SendPTOHandshake
|
|
case protocol.Encryption1RTT:
|
|
sendMode = ackhandler.SendPTOAppData
|
|
}
|
|
|
|
sph.EXPECT().GetLossDetectionTimeout().AnyTimes()
|
|
sph.EXPECT().TimeUntilSend().AnyTimes()
|
|
sph.EXPECT().SendMode(gomock.Any()).Return(sendMode)
|
|
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendNone)
|
|
sph.EXPECT().ECNMode(gomock.Any())
|
|
sph.EXPECT().QueueProbePacket(encLevel).Return(false)
|
|
sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
|
|
|
|
tc.packer.EXPECT().MaybePackPTOProbePacket(encLevel, gomock.Any(), gomock.Any(), protocol.Version1).DoAndReturn(
|
|
func(encLevel protocol.EncryptionLevel, maxSize protocol.ByteCount, t time.Time, version protocol.Version) (*coalescedPacket, error) {
|
|
return &coalescedPacket{
|
|
buffer: getPacketBuffer(),
|
|
shortHdrPacket: &shortHeaderPacket{PacketNumber: 1},
|
|
}, nil
|
|
},
|
|
)
|
|
done := make(chan struct{})
|
|
tc.sendConn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()).Do(
|
|
func([]byte, uint16, protocol.ECN) error { close(done); return nil },
|
|
)
|
|
|
|
errChan := make(chan error, 1)
|
|
go func() { errChan <- tc.conn.run() }()
|
|
tc.conn.scheduleSending()
|
|
|
|
select {
|
|
case <-done:
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
|
|
// test teardown
|
|
tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
|
|
tc.conn.destroy(nil)
|
|
select {
|
|
case err := <-errChan:
|
|
require.NoError(t, err)
|
|
case <-time.After(3 * time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
}
|
|
|
|
func TestConnectionCongestionControl(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
|
|
tc := newServerTestConnection(t,
|
|
mockCtrl,
|
|
nil,
|
|
false,
|
|
connectionOptHandshakeConfirmed(),
|
|
connectionOptSentPacketHandler(sph),
|
|
connectionOptRTT(10*time.Second),
|
|
)
|
|
|
|
sph.EXPECT().TimeUntilSend().AnyTimes()
|
|
sph.EXPECT().GetLossDetectionTimeout().AnyTimes()
|
|
sph.EXPECT().ECNMode(true).AnyTimes()
|
|
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).Times(2)
|
|
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAck).MaxTimes(1)
|
|
sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(2)
|
|
// Since we're already sending out packets, we don't expect any calls to PackAckOnlyPacket
|
|
for i := 0; i < 2; i++ {
|
|
tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
|
|
func(buffer *packetBuffer, count protocol.ByteCount, t time.Time, version protocol.Version) (shortHeaderPacket, error) {
|
|
buffer.Data = append(buffer.Data, []byte("foobar")...)
|
|
return shortHeaderPacket{PacketNumber: protocol.PacketNumber(i)}, nil
|
|
},
|
|
)
|
|
}
|
|
tc.sendConn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any())
|
|
done1 := make(chan struct{})
|
|
tc.sendConn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()).Do(
|
|
func([]byte, uint16, protocol.ECN) error { close(done1); return nil },
|
|
)
|
|
|
|
errChan := make(chan error, 1)
|
|
go func() { errChan <- tc.conn.run() }()
|
|
tc.conn.scheduleSending()
|
|
select {
|
|
case <-done1:
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
require.True(t, mockCtrl.Satisfied())
|
|
|
|
// Now that we're congestion limited, we can only send an ack-only packet
|
|
done2 := make(chan struct{})
|
|
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAck)
|
|
tc.packer.EXPECT().PackAckOnlyPacket(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
|
|
func(protocol.ByteCount, time.Time, protocol.Version) (shortHeaderPacket, *packetBuffer, error) {
|
|
close(done2)
|
|
return shortHeaderPacket{}, nil, errNothingToPack
|
|
},
|
|
)
|
|
tc.conn.scheduleSending()
|
|
select {
|
|
case <-done2:
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
require.True(t, mockCtrl.Satisfied())
|
|
|
|
// If the send mode is "none", we can't even send an ack-only packet
|
|
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendNone)
|
|
tc.conn.scheduleSending()
|
|
time.Sleep(scaleDuration(10 * time.Millisecond)) // make sure there are no calls to the packer
|
|
|
|
// test teardown
|
|
tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
|
|
tc.conn.destroy(nil)
|
|
select {
|
|
case err := <-errChan:
|
|
require.NoError(t, err)
|
|
case <-time.After(3 * time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
}
|
|
|
|
func TestConnectionSendQueue(t *testing.T) {
|
|
t.Run("with GSO", func(t *testing.T) {
|
|
testConnectionSendQueue(t, true)
|
|
})
|
|
t.Run("without GSO", func(t *testing.T) {
|
|
testConnectionSendQueue(t, false)
|
|
})
|
|
}
|
|
|
|
func testConnectionSendQueue(t *testing.T, enableGSO bool) {
|
|
mockCtrl := gomock.NewController(t)
|
|
sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
|
|
sender := NewMockSender(mockCtrl)
|
|
tc := newServerTestConnection(t,
|
|
mockCtrl,
|
|
nil,
|
|
enableGSO,
|
|
connectionOptSender(sender),
|
|
connectionOptHandshakeConfirmed(),
|
|
connectionOptSentPacketHandler(sph),
|
|
)
|
|
|
|
sender.EXPECT().Run().MaxTimes(1)
|
|
sender.EXPECT().WouldBlock()
|
|
sender.EXPECT().WouldBlock().Return(true).Times(2)
|
|
available := make(chan struct{})
|
|
blocked := make(chan struct{})
|
|
sender.EXPECT().Available().DoAndReturn(
|
|
func() <-chan struct{} {
|
|
close(blocked)
|
|
return available
|
|
},
|
|
)
|
|
sph.EXPECT().GetLossDetectionTimeout().AnyTimes()
|
|
sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
|
|
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes()
|
|
sph.EXPECT().ECNMode(gomock.Any()).AnyTimes()
|
|
tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(
|
|
shortHeaderPacket{PacketNumber: protocol.PacketNumber(1)}, nil,
|
|
)
|
|
sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any())
|
|
|
|
errChan := make(chan error, 1)
|
|
go func() { errChan <- tc.conn.run() }()
|
|
tc.conn.scheduleSending()
|
|
|
|
select {
|
|
case <-blocked:
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
require.True(t, mockCtrl.Satisfied())
|
|
|
|
// now make room in the send queue
|
|
sender.EXPECT().WouldBlock().AnyTimes()
|
|
unblocked := make(chan struct{})
|
|
tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
|
|
func(*packetBuffer, protocol.ByteCount, time.Time, protocol.Version) (shortHeaderPacket, error) {
|
|
close(unblocked)
|
|
return shortHeaderPacket{}, errNothingToPack
|
|
},
|
|
)
|
|
available <- struct{}{}
|
|
select {
|
|
case <-unblocked:
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
|
|
// test teardown
|
|
sender.EXPECT().Close()
|
|
tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
|
|
tc.conn.destroy(nil)
|
|
select {
|
|
case err := <-errChan:
|
|
require.NoError(t, err)
|
|
case <-time.After(3 * time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
}
|
|
|
|
func getVersionNegotiationPacket(src, dest protocol.ConnectionID, versions []protocol.Version) receivedPacket {
|
|
b := wire.ComposeVersionNegotiation(
|
|
protocol.ArbitraryLenConnectionID(src.Bytes()),
|
|
protocol.ArbitraryLenConnectionID(dest.Bytes()),
|
|
versions,
|
|
)
|
|
return receivedPacket{
|
|
rcvTime: time.Now(),
|
|
data: b,
|
|
buffer: getPacketBuffer(),
|
|
}
|
|
}
|
|
|
|
func TestConnectionVersionNegotiation(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
|
|
tc := newClientTestConnection(t,
|
|
mockCtrl,
|
|
nil,
|
|
false,
|
|
connectionOptTracer(tr),
|
|
)
|
|
|
|
tc.packer.EXPECT().PackCoalescedPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes()
|
|
var tracerVersions []logging.Version
|
|
gomock.InOrder(
|
|
tracer.EXPECT().ReceivedVersionNegotiationPacket(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_, _ protocol.ArbitraryLenConnectionID, versions []logging.Version) {
|
|
tracerVersions = versions
|
|
}),
|
|
tracer.EXPECT().NegotiatedVersion(protocol.Version2, gomock.Any(), gomock.Any()),
|
|
tc.connRunner.EXPECT().Remove(gomock.Any()),
|
|
)
|
|
|
|
errChan := make(chan error, 1)
|
|
go func() { errChan <- tc.conn.run() }()
|
|
tc.conn.handlePacket(getVersionNegotiationPacket(
|
|
tc.destConnID,
|
|
tc.srcConnID,
|
|
[]protocol.Version{1234, protocol.Version2},
|
|
))
|
|
|
|
select {
|
|
case err := <-errChan:
|
|
var rerr *errCloseForRecreating
|
|
require.ErrorAs(t, err, &rerr)
|
|
require.Equal(t, rerr.nextVersion, protocol.Version2)
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
require.Contains(t, tracerVersions, protocol.Version(1234))
|
|
require.Contains(t, tracerVersions, protocol.Version2)
|
|
}
|
|
|
|
func TestConnectionVersionNegotiationNoMatch(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
|
|
tc := newClientTestConnection(t,
|
|
mockCtrl,
|
|
&Config{Versions: []protocol.Version{protocol.Version1}},
|
|
false,
|
|
connectionOptTracer(tr),
|
|
)
|
|
|
|
tc.packer.EXPECT().PackCoalescedPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes()
|
|
var tracerVersions []logging.Version
|
|
tracer.EXPECT().ReceivedVersionNegotiationPacket(gomock.Any(), gomock.Any(), gomock.Any()).Do(
|
|
func(_, _ protocol.ArbitraryLenConnectionID, versions []logging.Version) { tracerVersions = versions },
|
|
)
|
|
tracer.EXPECT().ClosedConnection(gomock.Any())
|
|
tracer.EXPECT().Close()
|
|
tc.connRunner.EXPECT().Remove(gomock.Any())
|
|
|
|
errChan := make(chan error, 1)
|
|
go func() { errChan <- tc.conn.run() }()
|
|
tc.conn.handlePacket(getVersionNegotiationPacket(
|
|
tc.destConnID,
|
|
tc.srcConnID,
|
|
[]protocol.Version{protocol.Version2},
|
|
))
|
|
|
|
select {
|
|
case err := <-errChan:
|
|
var verr *VersionNegotiationError
|
|
require.ErrorAs(t, err, &verr)
|
|
require.Contains(t, verr.Theirs, protocol.Version2)
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
require.Contains(t, tracerVersions, protocol.Version2)
|
|
}
|
|
|
|
func TestConnectionVersionNegotiationInvalidPackets(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
|
|
tc := newClientTestConnection(t,
|
|
mockCtrl,
|
|
nil,
|
|
false,
|
|
connectionOptTracer(tr),
|
|
)
|
|
|
|
// offers the current version
|
|
tracer.EXPECT().DroppedPacket(logging.PacketTypeVersionNegotiation, gomock.Any(), gomock.Any(), logging.PacketDropUnexpectedVersion)
|
|
vnp := getVersionNegotiationPacket(
|
|
tc.destConnID,
|
|
tc.srcConnID,
|
|
[]protocol.Version{1234, protocol.Version1},
|
|
)
|
|
wasProcessed, err := tc.conn.handleOnePacket(vnp)
|
|
require.NoError(t, err)
|
|
require.False(t, wasProcessed)
|
|
require.True(t, mockCtrl.Satisfied())
|
|
|
|
// unparseable, since it's missing 2 bytes
|
|
tracer.EXPECT().DroppedPacket(logging.PacketTypeVersionNegotiation, gomock.Any(), gomock.Any(), logging.PacketDropHeaderParseError)
|
|
vnp.data = vnp.data[:len(vnp.data)-2]
|
|
wasProcessed, err = tc.conn.handleOnePacket(vnp)
|
|
require.NoError(t, err)
|
|
require.False(t, wasProcessed)
|
|
}
|
|
|
|
func getRetryPacket(t *testing.T, src, dest, origDest protocol.ConnectionID, token []byte) receivedPacket {
|
|
hdr := wire.Header{
|
|
Type: protocol.PacketTypeRetry,
|
|
SrcConnectionID: src,
|
|
DestConnectionID: dest,
|
|
Token: token,
|
|
Version: protocol.Version1,
|
|
}
|
|
b, err := (&wire.ExtendedHeader{Header: hdr}).Append(nil, protocol.Version1)
|
|
require.NoError(t, err)
|
|
tag := handshake.GetRetryIntegrityTag(b, origDest, protocol.Version1)
|
|
b = append(b, tag[:]...)
|
|
return receivedPacket{
|
|
rcvTime: time.Now(),
|
|
data: b,
|
|
buffer: getPacketBuffer(),
|
|
}
|
|
}
|
|
|
|
func TestConnectionRetryDrops(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
|
|
unpacker := NewMockUnpacker(mockCtrl)
|
|
tc := newClientTestConnection(t,
|
|
mockCtrl,
|
|
nil,
|
|
false,
|
|
connectionOptTracer(tr),
|
|
connectionOptUnpacker(unpacker),
|
|
)
|
|
|
|
newConnID := protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef})
|
|
|
|
// invalid integrity tag
|
|
tracer.EXPECT().DroppedPacket(logging.PacketTypeRetry, gomock.Any(), gomock.Any(), logging.PacketDropPayloadDecryptError)
|
|
retry := getRetryPacket(t, newConnID, tc.srcConnID, tc.destConnID, []byte("foobar"))
|
|
retry.data[len(retry.data)-1]++
|
|
wasProcessed, err := tc.conn.handleOnePacket(retry)
|
|
require.NoError(t, err)
|
|
require.False(t, wasProcessed)
|
|
require.True(t, mockCtrl.Satisfied())
|
|
|
|
// receive a retry that doesn't change the connection ID
|
|
tracer.EXPECT().DroppedPacket(logging.PacketTypeRetry, gomock.Any(), gomock.Any(), logging.PacketDropUnexpectedPacket)
|
|
retry = getRetryPacket(t, tc.destConnID, tc.srcConnID, tc.destConnID, []byte("foobar"))
|
|
wasProcessed, err = tc.conn.handleOnePacket(retry)
|
|
require.NoError(t, err)
|
|
require.False(t, wasProcessed)
|
|
}
|
|
|
|
func TestConnectionRetryAfterReceivedPacket(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
|
|
unpacker := NewMockUnpacker(mockCtrl)
|
|
tc := newClientTestConnection(t,
|
|
mockCtrl,
|
|
nil,
|
|
false,
|
|
connectionOptTracer(tr),
|
|
connectionOptUnpacker(unpacker),
|
|
)
|
|
|
|
// receive a regular packet
|
|
tracer.EXPECT().NegotiatedVersion(gomock.Any(), gomock.Any(), gomock.Any())
|
|
tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
|
|
regular := getPacketWithPacketType(t, tc.srcConnID, protocol.PacketTypeInitial, 200)
|
|
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).Return(
|
|
&unpackedPacket{
|
|
hdr: &wire.ExtendedHeader{Header: wire.Header{Type: protocol.PacketTypeInitial}},
|
|
encryptionLevel: protocol.EncryptionInitial,
|
|
}, nil,
|
|
)
|
|
wasProcessed, err := tc.conn.handleOnePacket(receivedPacket{
|
|
data: regular,
|
|
buffer: getPacketBuffer(),
|
|
rcvTime: time.Now(),
|
|
})
|
|
require.NoError(t, err)
|
|
require.True(t, wasProcessed)
|
|
|
|
// receive a retry
|
|
retry := getRetryPacket(t, tc.destConnID, tc.srcConnID, tc.destConnID, []byte("foobar"))
|
|
tracer.EXPECT().DroppedPacket(logging.PacketTypeRetry, gomock.Any(), gomock.Any(), logging.PacketDropUnexpectedPacket)
|
|
wasProcessed, err = tc.conn.handleOnePacket(retry)
|
|
require.NoError(t, err)
|
|
require.False(t, wasProcessed)
|
|
}
|
|
|
|
func TestConnectionConnectionIDChanges(t *testing.T) {
|
|
t.Run("with retry", func(t *testing.T) {
|
|
testConnectionConnectionIDChanges(t, true)
|
|
})
|
|
t.Run("without retry", func(t *testing.T) {
|
|
testConnectionConnectionIDChanges(t, false)
|
|
})
|
|
}
|
|
|
|
func testConnectionConnectionIDChanges(t *testing.T, sendRetry bool) {
|
|
makeInitialPacket := func(t *testing.T, hdr *wire.ExtendedHeader) []byte {
|
|
t.Helper()
|
|
data, err := hdr.Append(nil, protocol.Version1)
|
|
require.NoError(t, err)
|
|
data = append(data, make([]byte, hdr.Length-protocol.ByteCount(hdr.PacketNumberLen))...)
|
|
return data
|
|
}
|
|
|
|
mockCtrl := gomock.NewController(t)
|
|
tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
|
|
unpacker := NewMockUnpacker(mockCtrl)
|
|
tc := newClientTestConnection(t,
|
|
mockCtrl,
|
|
nil,
|
|
false,
|
|
connectionOptTracer(tr),
|
|
connectionOptUnpacker(unpacker),
|
|
)
|
|
|
|
dstConnID := tc.destConnID
|
|
b := make([]byte, 3*10)
|
|
rand.Read(b)
|
|
newConnID := protocol.ParseConnectionID(b[:11])
|
|
newConnID2 := protocol.ParseConnectionID(b[11:20])
|
|
|
|
errChan := make(chan error, 1)
|
|
go func() { errChan <- tc.conn.run() }()
|
|
|
|
tracer.EXPECT().NegotiatedVersion(gomock.Any(), gomock.Any(), gomock.Any())
|
|
tc.packer.EXPECT().PackCoalescedPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes()
|
|
|
|
require.Equal(t, dstConnID, tc.conn.connIDManager.Get())
|
|
|
|
var retryConnID protocol.ConnectionID
|
|
if sendRetry {
|
|
retryConnID = protocol.ParseConnectionID(b[20:30])
|
|
hdrChan := make(chan *wire.Header)
|
|
tracer.EXPECT().ReceivedRetry(gomock.Any()).Do(func(hdr *wire.Header) { hdrChan <- hdr })
|
|
tc.packer.EXPECT().SetToken([]byte("foobar"))
|
|
|
|
tc.conn.handlePacket(getRetryPacket(t, retryConnID, tc.srcConnID, tc.destConnID, []byte("foobar")))
|
|
select {
|
|
case hdr := <-hdrChan:
|
|
assert.Equal(t, retryConnID, hdr.SrcConnectionID)
|
|
assert.Equal(t, []byte("foobar"), hdr.Token)
|
|
require.Equal(t, retryConnID, tc.conn.connIDManager.Get())
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
}
|
|
|
|
// Send the first packet. The server changes the connection ID to newConnID.
|
|
hdr1 := wire.ExtendedHeader{
|
|
Header: wire.Header{
|
|
SrcConnectionID: newConnID,
|
|
DestConnectionID: tc.srcConnID,
|
|
Type: protocol.PacketTypeInitial,
|
|
Length: 200,
|
|
Version: protocol.Version1,
|
|
},
|
|
PacketNumber: 1,
|
|
PacketNumberLen: protocol.PacketNumberLen2,
|
|
}
|
|
hdr2 := hdr1
|
|
hdr2.SrcConnectionID = newConnID2
|
|
|
|
receivedFirst := make(chan struct{})
|
|
gomock.InOrder(
|
|
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).Return(
|
|
&unpackedPacket{
|
|
hdr: &hdr1,
|
|
encryptionLevel: protocol.EncryptionInitial,
|
|
}, nil,
|
|
),
|
|
tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(
|
|
func(*wire.ExtendedHeader, protocol.ByteCount, protocol.ECN, []logging.Frame) { close(receivedFirst) },
|
|
),
|
|
)
|
|
|
|
tc.conn.handlePacket(receivedPacket{data: makeInitialPacket(t, &hdr1), buffer: getPacketBuffer(), rcvTime: time.Now()})
|
|
|
|
select {
|
|
case <-receivedFirst:
|
|
require.Equal(t, newConnID, tc.conn.connIDManager.Get())
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
|
|
// Send the second packet. We refuse to accept it, because the connection ID is changed again.
|
|
dropped := make(chan struct{})
|
|
tracer.EXPECT().DroppedPacket(logging.PacketTypeInitial, gomock.Any(), gomock.Any(), logging.PacketDropUnknownConnectionID).Do(
|
|
func(logging.PacketType, protocol.PacketNumber, protocol.ByteCount, logging.PacketDropReason) {
|
|
close(dropped)
|
|
},
|
|
)
|
|
|
|
tc.conn.handlePacket(receivedPacket{data: makeInitialPacket(t, &hdr2), buffer: getPacketBuffer(), rcvTime: time.Now()})
|
|
select {
|
|
case <-dropped:
|
|
// the connection ID should not have changed
|
|
require.Equal(t, newConnID, tc.conn.connIDManager.Get())
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
|
|
// test teardown
|
|
tracer.EXPECT().ClosedConnection(gomock.Any())
|
|
tracer.EXPECT().Close()
|
|
tc.connRunner.EXPECT().Remove(gomock.Any())
|
|
tc.conn.destroy(nil)
|
|
select {
|
|
case err := <-errChan:
|
|
require.NoError(t, err)
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
}
|
|
|
|
// When the connection is closed before sending the first packet,
|
|
// we don't send a CONNECTION_CLOSE.
|
|
// This can happen if there's something wrong the tls.Config, and
|
|
// crypto/tls refuses to start the handshake.
|
|
func TestConnectionEarlyClose(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
|
|
cryptoSetup := mocks.NewMockCryptoSetup(mockCtrl)
|
|
tc := newClientTestConnection(t,
|
|
mockCtrl,
|
|
nil,
|
|
false,
|
|
connectionOptTracer(tr),
|
|
connectionOptCryptoSetup(cryptoSetup),
|
|
)
|
|
|
|
tc.conn.sentFirstPacket = false
|
|
tracer.EXPECT().ClosedConnection(gomock.Any())
|
|
tracer.EXPECT().Close()
|
|
cryptoSetup.EXPECT().StartHandshake(gomock.Any()).Do(func(context.Context) error {
|
|
tc.conn.closeLocal(errors.New("early error"))
|
|
return nil
|
|
})
|
|
cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent})
|
|
cryptoSetup.EXPECT().Close()
|
|
tc.connRunner.EXPECT().Remove(gomock.Any())
|
|
|
|
errChan := make(chan error, 1)
|
|
go func() { errChan <- tc.conn.run() }()
|
|
|
|
select {
|
|
case err := <-errChan:
|
|
require.Error(t, err)
|
|
require.ErrorContains(t, err, "early error")
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
}
|