uquic/connection_test.go
2025-01-26 06:01:29 +01:00

2767 lines
99 KiB
Go

package quic
import (
"bytes"
"context"
"crypto/rand"
"crypto/tls"
"errors"
"net"
"net/netip"
"strconv"
"testing"
"time"
"github.com/quic-go/quic-go/internal/ackhandler"
"github.com/quic-go/quic-go/internal/handshake"
"github.com/quic-go/quic-go/internal/mocks"
mockackhandler "github.com/quic-go/quic-go/internal/mocks/ackhandler"
mocklogging "github.com/quic-go/quic-go/internal/mocks/logging"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qerr"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/internal/wire"
"github.com/quic-go/quic-go/logging"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
)
// testConnectionOpt customizes a *connection right after a test helper has
// constructed it, usually by swapping one of its collaborators for a mock.
type testConnectionOpt func(*connection)

// connectionOptCryptoSetup installs cs as the connection's crypto stream handler.
func connectionOptCryptoSetup(cs *mocks.MockCryptoSetup) testConnectionOpt {
	apply := func(c *connection) {
		c.cryptoStreamHandler = cs
	}
	return apply
}

// connectionOptStreamManager installs sm as the connection's streams map.
func connectionOptStreamManager(sm *MockStreamManager) testConnectionOpt {
	apply := func(c *connection) {
		c.streamsMap = sm
	}
	return apply
}

// connectionOptConnFlowController installs cfc as the connection-level flow controller.
func connectionOptConnFlowController(cfc *mocks.MockConnectionFlowController) testConnectionOpt {
	apply := func(c *connection) {
		c.connFlowController = cfc
	}
	return apply
}

// connectionOptTracer installs tr as the connection's tracer.
func connectionOptTracer(tr *logging.ConnectionTracer) testConnectionOpt {
	apply := func(c *connection) {
		c.tracer = tr
	}
	return apply
}

// connectionOptSentPacketHandler installs sph as the connection's sent packet handler.
func connectionOptSentPacketHandler(sph ackhandler.SentPacketHandler) testConnectionOpt {
	apply := func(c *connection) {
		c.sentPacketHandler = sph
	}
	return apply
}

// connectionOptReceivedPacketHandler installs rph as the connection's received packet handler.
func connectionOptReceivedPacketHandler(rph ackhandler.ReceivedPacketHandler) testConnectionOpt {
	apply := func(c *connection) {
		c.receivedPacketHandler = rph
	}
	return apply
}

// connectionOptUnpacker installs u as the connection's packet unpacker.
func connectionOptUnpacker(u unpacker) testConnectionOpt {
	apply := func(c *connection) {
		c.unpacker = u
	}
	return apply
}

// connectionOptSender installs s as the connection's send queue.
func connectionOptSender(s sender) testConnectionOpt {
	apply := func(c *connection) {
		c.sendQueue = s
	}
	return apply
}

// connectionOptHandshakeConfirmed marks the connection's handshake as both
// complete and confirmed.
func connectionOptHandshakeConfirmed() testConnectionOpt {
	apply := func(c *connection) {
		c.handshakeConfirmed = true
		c.handshakeComplete = true
	}
	return apply
}
func connectionOptRTT(rtt time.Duration) testConnectionOpt {
var rttStats utils.RTTStats
rttStats.UpdateRTT(rtt, 0)
return func(conn *connection) { conn.rttStats = &rttStats }
}
// connectionOptRetrySrcConnID sets the connection's retry source connection ID
// to (a pointer to a copy of) rcid.
func connectionOptRetrySrcConnID(rcid protocol.ConnectionID) testConnectionOpt {
	retrySrcConnID := &rcid
	return func(c *connection) {
		c.retrySrcConnID = retrySrcConnID
	}
}
// testConnection bundles a connection under test with the mocks that
// newServerTestConnection / newClientTestConnection wired into it.
type testConnection struct {
// the connection under test
conn *connection
// mock conn runner passed to the connection constructor
connRunner *MockConnRunner
// mock send conn the connection writes to
sendConn *MockSendConn
// mock packer installed as conn.packer (options applied afterwards may not replace it)
packer *MockPacker
// server: the original destination connection ID; client: the destination connection ID
destConnID protocol.ConnectionID
// the connection's own (source) connection ID
srcConnID protocol.ConnectionID
}
// newServerTestConnection constructs a server-side *connection for tests.
//
// The connection writes to a MockSendConn (remote 1.2.3.4:4321, local
// 127.0.0.1:1234, reporting GSO support according to gso) and uses a
// MockPacker. Its original destination and source connection IDs are random
// 6-byte values. If mockCtrl is nil, a new controller bound to t is created.
// If config is nil, a config with path MTU discovery disabled is used. opts
// are applied after construction and after the packer is installed, so they
// can override any of the connection's collaborators.
func newServerTestConnection(
	t *testing.T,
	mockCtrl *gomock.Controller,
	config *Config,
	gso bool,
	opts ...testConnectionOpt,
) *testConnection {
	t.Helper()
	if mockCtrl == nil {
		mockCtrl = gomock.NewController(t)
	}
	remoteAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 4321}
	localAddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1234}
	connRunner := NewMockConnRunner(mockCtrl)
	sendConn := NewMockSendConn(mockCtrl)
	sendConn.EXPECT().capabilities().Return(connCapabilities{GSO: gso}).AnyTimes()
	sendConn.EXPECT().RemoteAddr().Return(remoteAddr).AnyTimes()
	sendConn.EXPECT().LocalAddr().Return(localAddr).AnyTimes()
	packer := NewMockPacker(mockCtrl)

	// 12 random bytes: the first 6 form the original destination connection
	// ID, the last 6 the server's source connection ID.
	b := make([]byte, 12)
	_, err := rand.Read(b)
	require.NoError(t, err)
	origDestConnID := protocol.ParseConnectionID(b[:6])
	srcConnID := protocol.ParseConnectionID(b[6:12])

	// cancel is handed to newConnection rather than deferred here, so the
	// context outlives this helper.
	ctx, cancel := context.WithCancelCause(context.Background())
	if config == nil {
		config = &Config{DisablePathMTUDiscovery: true}
	}
	conn := newConnection(
		ctx,
		cancel,
		sendConn,
		connRunner,
		origDestConnID,
		nil,
		protocol.ConnectionID{},
		protocol.ConnectionID{},
		srcConnID,
		&protocol.DefaultConnectionIDGenerator{},
		newStatelessResetter(nil),
		populateConfig(config),
		&tls.Config{},
		handshake.NewTokenGenerator(handshake.TokenProtectorKey{}),
		false,
		nil,
		utils.DefaultLogger,
		protocol.Version1,
	).(*connection)
	conn.packer = packer
	for _, opt := range opts {
		opt(conn)
	}
	return &testConnection{
		conn:       conn,
		connRunner: connRunner,
		sendConn:   sendConn,
		packer:     packer,
		destConnID: origDestConnID,
		srcConnID:  srcConnID,
	}
}
// newClientTestConnection constructs a client-side *connection for tests.
//
// The connection writes to a MockSendConn (remote 1.2.3.4:4321, local
// 127.0.0.1:1234, no GSO) and uses a MockPacker. It uses a TLS config with
// ServerName "quic-go.net". Its destination and source connection IDs are
// random 6-byte values. enable0RTT is forwarded to newClientConnection. If
// mockCtrl is nil, a new controller bound to t is created. If config is nil, a
// config with path MTU discovery disabled is used. opts are applied after
// construction and after the packer is installed.
func newClientTestConnection(
	t *testing.T,
	mockCtrl *gomock.Controller,
	config *Config,
	enable0RTT bool,
	opts ...testConnectionOpt,
) *testConnection {
	t.Helper()
	if mockCtrl == nil {
		mockCtrl = gomock.NewController(t)
	}
	remoteAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 4321}
	localAddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1234}
	connRunner := NewMockConnRunner(mockCtrl)
	sendConn := NewMockSendConn(mockCtrl)
	sendConn.EXPECT().capabilities().Return(connCapabilities{}).AnyTimes()
	sendConn.EXPECT().RemoteAddr().Return(remoteAddr).AnyTimes()
	sendConn.EXPECT().LocalAddr().Return(localAddr).AnyTimes()
	packer := NewMockPacker(mockCtrl)

	// 12 random bytes: the first 6 form the destination connection ID, the
	// last 6 the client's source connection ID.
	b := make([]byte, 12)
	_, err := rand.Read(b)
	require.NoError(t, err)
	destConnID := protocol.ParseConnectionID(b[:6])
	srcConnID := protocol.ParseConnectionID(b[6:12])

	if config == nil {
		config = &Config{DisablePathMTUDiscovery: true}
	}
	conn := newClientConnection(
		context.Background(),
		sendConn,
		connRunner,
		destConnID,
		srcConnID,
		&protocol.DefaultConnectionIDGenerator{},
		newStatelessResetter(nil),
		populateConfig(config),
		&tls.Config{ServerName: "quic-go.net"},
		0,
		enable0RTT,
		false,
		nil,
		utils.DefaultLogger,
		protocol.Version1,
	).(*connection)
	conn.packer = packer
	for _, opt := range opts {
		opt(conn)
	}
	return &testConnection{
		conn:       conn,
		connRunner: connRunner,
		sendConn:   sendConn,
		packer:     packer,
		destConnID: destConnID,
		srcConnID:  srcConnID,
	}
}
// TestConnectionHandleReceiveStreamFrames checks how frames that target the
// receive side of a stream are dispatched. For STREAM, RESET_STREAM and
// STREAM_DATA_BLOCKED frames, the connection looks the stream up via
// GetOrOpenReceiveStream. STREAM and RESET_STREAM frames are then forwarded to
// the stream. A nil stream is ignored without error, and an error from the
// lookup is returned from handleFrame.
func TestConnectionHandleReceiveStreamFrames(t *testing.T) {
const streamID protocol.StreamID = 5
now := time.Now()
connID := protocol.ConnectionID{}
// one frame of each type, shared by all subtests below
f := &wire.StreamFrame{StreamID: streamID, Data: []byte("foobar")}
rsf := &wire.ResetStreamFrame{StreamID: streamID, ErrorCode: 42, FinalSize: 1337}
sdbf := &wire.StreamDataBlockedFrame{StreamID: streamID, MaximumStreamData: 1337}
t.Run("for existing and new streams", func(t *testing.T) {
mockCtrl := gomock.NewController(t)
streamsMap := NewMockStreamManager(mockCtrl)
tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptStreamManager(streamsMap))
str := NewMockReceiveStreamI(mockCtrl)
// STREAM frame
streamsMap.EXPECT().GetOrOpenReceiveStream(streamID).Return(str, nil)
str.EXPECT().handleStreamFrame(f, now)
require.NoError(t, tc.conn.handleFrame(f, protocol.Encryption1RTT, connID, now))
// RESET_STREAM frame
streamsMap.EXPECT().GetOrOpenReceiveStream(streamID).Return(str, nil)
str.EXPECT().handleResetStreamFrame(rsf, now)
require.NoError(t, tc.conn.handleFrame(rsf, protocol.Encryption1RTT, connID, now))
// STREAM_DATA_BLOCKED frames are not passed to the stream
// (no further expectation is registered on str, so any call on it would fail the test)
streamsMap.EXPECT().GetOrOpenReceiveStream(streamID).Return(str, nil)
require.NoError(t, tc.conn.handleFrame(sdbf, protocol.Encryption1RTT, connID, now))
})
t.Run("for closed streams", func(t *testing.T) {
// NOTE(review): a (nil, nil) lookup result presumably represents a stream that
// was already closed; frames for it must be ignored without error.
mockCtrl := gomock.NewController(t)
streamsMap := NewMockStreamManager(mockCtrl)
tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptStreamManager(streamsMap))
// STREAM frame
streamsMap.EXPECT().GetOrOpenReceiveStream(streamID).Return(nil, nil)
require.NoError(t, tc.conn.handleFrame(f, protocol.Encryption1RTT, connID, now))
// RESET_STREAM frame
streamsMap.EXPECT().GetOrOpenReceiveStream(streamID).Return(nil, nil)
require.NoError(t, tc.conn.handleFrame(rsf, protocol.Encryption1RTT, connID, now))
// STREAM_DATA_BLOCKED frames are not passed to the stream
streamsMap.EXPECT().GetOrOpenReceiveStream(streamID).Return(nil, nil)
require.NoError(t, tc.conn.handleFrame(sdbf, protocol.Encryption1RTT, connID, now))
})
t.Run("for invalid streams", func(t *testing.T) {
// the lookup error must be propagated from handleFrame (checked with ErrorIs)
mockCtrl := gomock.NewController(t)
streamsMap := NewMockStreamManager(mockCtrl)
tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptStreamManager(streamsMap))
testErr := errors.New("test err")
// STREAM frame
streamsMap.EXPECT().GetOrOpenReceiveStream(streamID).Return(nil, testErr)
require.ErrorIs(t, tc.conn.handleFrame(f, protocol.Encryption1RTT, connID, now), testErr)
// RESET_STREAM frame
streamsMap.EXPECT().GetOrOpenReceiveStream(streamID).Return(nil, testErr)
require.ErrorIs(t, tc.conn.handleFrame(rsf, protocol.Encryption1RTT, connID, now), testErr)
// STREAM_DATA_BLOCKED frames are not passed to the stream
streamsMap.EXPECT().GetOrOpenReceiveStream(streamID).Return(nil, testErr)
require.ErrorIs(t, tc.conn.handleFrame(sdbf, protocol.Encryption1RTT, connID, now), testErr)
})
}
// TestConnectionHandleSendStreamFrames checks how frames that target the send
// side of a stream are dispatched. For STOP_SENDING and MAX_STREAM_DATA frames,
// the connection looks the stream up via GetOrOpenSendStream. STOP_SENDING is
// then forwarded to handleStopSendingFrame, and MAX_STREAM_DATA to
// updateSendWindow with the frame's limit. A nil stream is ignored without
// error, and a lookup error is returned from handleFrame.
func TestConnectionHandleSendStreamFrames(t *testing.T) {
const streamID protocol.StreamID = 3
now := time.Now()
connID := protocol.ConnectionID{}
// one frame of each type, shared by all subtests below
ss := &wire.StopSendingFrame{StreamID: streamID, ErrorCode: 42}
msd := &wire.MaxStreamDataFrame{StreamID: streamID, MaximumStreamData: 1337}
t.Run("for existing and new streams", func(t *testing.T) {
mockCtrl := gomock.NewController(t)
streamsMap := NewMockStreamManager(mockCtrl)
tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptStreamManager(streamsMap))
str := NewMockSendStreamI(mockCtrl)
// STOP_SENDING frame
streamsMap.EXPECT().GetOrOpenSendStream(streamID).Return(str, nil)
str.EXPECT().handleStopSendingFrame(ss)
require.NoError(t, tc.conn.handleFrame(ss, protocol.Encryption1RTT, connID, now))
// MAX_STREAM_DATA frame
streamsMap.EXPECT().GetOrOpenSendStream(streamID).Return(str, nil)
str.EXPECT().updateSendWindow(msd.MaximumStreamData)
require.NoError(t, tc.conn.handleFrame(msd, protocol.Encryption1RTT, connID, now))
})
t.Run("for closed streams", func(t *testing.T) {
// NOTE(review): a (nil, nil) lookup result presumably represents a stream that
// was already closed; frames for it must be ignored without error.
mockCtrl := gomock.NewController(t)
streamsMap := NewMockStreamManager(mockCtrl)
tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptStreamManager(streamsMap))
// STOP_SENDING frame
streamsMap.EXPECT().GetOrOpenSendStream(streamID).Return(nil, nil)
require.NoError(t, tc.conn.handleFrame(ss, protocol.Encryption1RTT, connID, now))
// MAX_STREAM_DATA frame
streamsMap.EXPECT().GetOrOpenSendStream(streamID).Return(nil, nil)
require.NoError(t, tc.conn.handleFrame(msd, protocol.Encryption1RTT, connID, now))
})
t.Run("for invalid streams", func(t *testing.T) {
// the lookup error must be propagated from handleFrame (checked with ErrorIs)
mockCtrl := gomock.NewController(t)
streamsMap := NewMockStreamManager(mockCtrl)
tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptStreamManager(streamsMap))
testErr := errors.New("test err")
// STOP_SENDING frame
streamsMap.EXPECT().GetOrOpenSendStream(streamID).Return(nil, testErr)
require.ErrorIs(t, tc.conn.handleFrame(ss, protocol.Encryption1RTT, connID, now), testErr)
// MAX_STREAM_DATA frame
streamsMap.EXPECT().GetOrOpenSendStream(streamID).Return(nil, testErr)
require.ErrorIs(t, tc.conn.handleFrame(msd, protocol.Encryption1RTT, connID, now), testErr)
})
}
// TestConnectionHandleStreamNumFrames checks the frames that carry stream-count
// limits. MAX_STREAMS frames are passed to the streams map's
// HandleMaxStreamsFrame. STREAMS_BLOCKED frames are accepted without error and
// without any call on the streams map (none is expected on the mock).
func TestConnectionHandleStreamNumFrames(t *testing.T) {
	mockCtrl := gomock.NewController(t)
	streamsMap := NewMockStreamManager(mockCtrl)
	tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptStreamManager(streamsMap))
	now := time.Now()
	connID := protocol.ConnectionID{}

	// MAX_STREAMS frame
	msf := &wire.MaxStreamsFrame{Type: protocol.StreamTypeBidi, MaxStreamNum: 10}
	streamsMap.EXPECT().HandleMaxStreamsFrame(msf)
	require.NoError(t, tc.conn.handleFrame(msf, protocol.Encryption1RTT, connID, now))

	// STREAMS_BLOCKED frame: the returned error is asserted as well, so a
	// regression in its handling doesn't go unnoticed.
	sbf := &wire.StreamsBlockedFrame{Type: protocol.StreamTypeBidi, StreamLimit: 1}
	require.NoError(t, tc.conn.handleFrame(sbf, protocol.Encryption1RTT, connID, now))
}
func TestConnectionHandleConnectionFlowControlFrames(t *testing.T) {
mockCtrl := gomock.NewController(t)
connFC := mocks.NewMockConnectionFlowController(mockCtrl)
tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptConnFlowController(connFC))
now := time.Now()
connID := protocol.ConnectionID{}
// MAX_DATA frame
connFC.EXPECT().UpdateSendWindow(protocol.ByteCount(1337))
require.NoError(t, tc.conn.handleFrame(&wire.MaxDataFrame{MaximumData: 1337}, protocol.Encryption1RTT, connID, now))
// DATA_BLOCKED frame
require.NoError(t, tc.conn.handleFrame(&wire.DataBlockedFrame{MaximumData: 1337}, protocol.Encryption1RTT, connID, now))
}
// TestConnectionOpenStreams checks that the connection's four stream-opening
// methods delegate to the streams map and return its stream unchanged.
func TestConnectionOpenStreams(t *testing.T) {
	ctrl := gomock.NewController(t)
	sm := NewMockStreamManager(ctrl)
	tc := newServerTestConnection(t, ctrl, nil, false, connectionOptStreamManager(sm))
	ctx := context.Background()
	expected := NewMockStreamI(ctrl)

	// bidirectional
	sm.EXPECT().OpenStream().Return(expected, nil)
	bidi, err := tc.conn.OpenStream()
	require.NoError(t, err)
	require.Equal(t, expected, bidi)

	// bidirectional, Sync variant
	sm.EXPECT().OpenStreamSync(ctx).Return(expected, nil)
	bidi, err = tc.conn.OpenStreamSync(ctx)
	require.NoError(t, err)
	require.Equal(t, expected, bidi)

	// unidirectional
	sm.EXPECT().OpenUniStream().Return(expected, nil)
	uni, err := tc.conn.OpenUniStream()
	require.NoError(t, err)
	require.Equal(t, expected, uni)

	// unidirectional, Sync variant
	sm.EXPECT().OpenUniStreamSync(ctx).Return(expected, nil)
	uni, err = tc.conn.OpenUniStreamSync(ctx)
	require.NoError(t, err)
	require.Equal(t, expected, uni)
}
// TestConnectionAcceptStreams checks that AcceptStream and AcceptUniStream
// pass the caller's context to the streams map and return its stream unchanged.
func TestConnectionAcceptStreams(t *testing.T) {
	ctrl := gomock.NewController(t)
	sm := NewMockStreamManager(ctrl)
	tc := newServerTestConnection(t, ctrl, nil, false, connectionOptStreamManager(sm))

	acceptCtx, cancelAccept := context.WithTimeout(context.Background(), time.Minute)
	defer cancelAccept()
	expected := NewMockStreamI(ctrl)

	// bidirectional
	sm.EXPECT().AcceptStream(acceptCtx).Return(expected, nil)
	bidi, err := tc.conn.AcceptStream(acceptCtx)
	require.NoError(t, err)
	require.Equal(t, expected, bidi)

	// unidirectional
	sm.EXPECT().AcceptUniStream(acceptCtx).Return(expected, nil)
	uni, err := tc.conn.AcceptUniStream(acceptCtx)
	require.NoError(t, err)
	require.Equal(t, expected, uni)
}
// TestConnectionServerInvalidFrames checks that a server connection rejects
// certain frames in a 1-RTT packet with a PROTOCOL_VIOLATION transport error.
// NEW_TOKEN and HANDSHAKE_DONE are server-only frames (RFC 9000). The
// PATH_RESPONSE is presumably rejected because no PATH_CHALLENGE was sent.
// All subtests share one connection.
func TestConnectionServerInvalidFrames(t *testing.T) {
	ctrl := gomock.NewController(t)
	tc := newServerTestConnection(t, ctrl, nil, false)
	tests := []struct {
		name  string
		frame wire.Frame
	}{
		{name: "NEW_TOKEN", frame: &wire.NewTokenFrame{Token: []byte("foobar")}},
		{name: "HANDSHAKE_DONE", frame: &wire.HandshakeDoneFrame{}},
		{name: "PATH_RESPONSE", frame: &wire.PathResponseFrame{Data: [8]byte{1, 2, 3, 4, 5, 6, 7, 8}}},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			err := tc.conn.handleFrame(tt.frame, protocol.Encryption1RTT, protocol.ConnectionID{}, time.Now())
			require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.ProtocolViolation})
		})
	}
}
// TestConnectionTransportError checks what happens when a running connection is
// closed locally with a transport error (closeLocal). The error must be packed
// into a CONNECTION_CLOSE via PackConnectionClose, and the packed bytes written
// to the send conn. The tracer must see ClosedConnection followed by Close, and
// run must return the same error.
func TestConnectionTransportError(t *testing.T) {
mockCtrl := gomock.NewController(t)
tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptTracer(tr))
errChan := make(chan error, 1)
expectedErr := &qerr.TransportError{
ErrorCode: 1337,
FrameType: 42,
ErrorMessage: "test error",
}
// conn runner cleanup calls (Remove, ReplaceWithClosed) are allowed any number
// of times; this test does not pin which one is used
tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
// the "packed" CONNECTION_CLOSE; exactly these bytes must be written
b := getPacketBuffer()
b.Data = append(b.Data, []byte("connection close")...)
tc.packer.EXPECT().PackConnectionClose(expectedErr, gomock.Any(), protocol.Version1).Return(&coalescedPacket{buffer: b}, nil)
tc.sendConn.EXPECT().Write([]byte("connection close"), gomock.Any(), gomock.Any())
tc.connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any()).AnyTimes()
gomock.InOrder(
tracer.EXPECT().ClosedConnection(expectedErr),
tracer.EXPECT().Close(),
)
go func() { errChan <- tc.conn.run() }()
tc.conn.closeLocal(expectedErr)
select {
case err := <-errChan:
require.ErrorIs(t, err, expectedErr)
case <-time.After(time.Second):
t.Fatal("timeout")
}
// further calls to CloseWithError don't do anything
// (no packer, send conn or tracer expectations remain, so any packing, writing
// or tracing here would fail the test)
tc.conn.CloseWithError(42, "another error")
}
// TestConnectionApplicationClose checks that CloseWithError(1337, "test error")
// on a running connection packs an application CONNECTION_CLOSE via
// PackApplicationClose and writes the packed bytes to the send conn. The tracer
// must see ClosedConnection followed by Close, and run must return the
// corresponding ApplicationError.
func TestConnectionApplicationClose(t *testing.T) {
mockCtrl := gomock.NewController(t)
tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptTracer(tr))
errChan := make(chan error, 1)
// the error CloseWithError(1337, "test error") is expected to produce
expectedErr := &qerr.ApplicationError{
ErrorCode: 1337,
ErrorMessage: "test error",
}
// conn runner cleanup calls (Remove, ReplaceWithClosed) are allowed any number
// of times; this test does not pin which one is used
tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
// the "packed" CONNECTION_CLOSE; exactly these bytes must be written
b := getPacketBuffer()
b.Data = append(b.Data, []byte("connection close")...)
tc.packer.EXPECT().PackApplicationClose(expectedErr, gomock.Any(), protocol.Version1).Return(&coalescedPacket{buffer: b}, nil)
tc.sendConn.EXPECT().Write([]byte("connection close"), gomock.Any(), gomock.Any())
tc.connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any()).AnyTimes()
gomock.InOrder(
tracer.EXPECT().ClosedConnection(expectedErr),
tracer.EXPECT().Close(),
)
go func() { errChan <- tc.conn.run() }()
tc.conn.CloseWithError(1337, "test error")
select {
case err := <-errChan:
require.ErrorIs(t, err, expectedErr)
case <-time.After(time.Second):
t.Fatal("timeout")
}
// further calls to CloseWithError don't do anything
// (no packer, send conn or tracer expectations remain, so any packing, writing
// or tracing here would fail the test)
tc.conn.CloseWithError(42, "another error")
}
// TestConnectionStatelessReset checks that destroying a running connection with
// a StatelessResetError reports that error to the tracer (ClosedConnection,
// then Close) and makes run return it. No packer or send conn expectations are
// registered, so destroy must not pack or send anything.
func TestConnectionStatelessReset(t *testing.T) {
mockCtrl := gomock.NewController(t)
tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptTracer(tr))
errChan := make(chan error, 1)
tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
gomock.InOrder(
tracer.EXPECT().ClosedConnection(&StatelessResetError{}),
tracer.EXPECT().Close(),
)
go func() { errChan <- tc.conn.run() }()
tc.conn.destroy(&StatelessResetError{})
select {
case err := <-errChan:
require.ErrorIs(t, err, &StatelessResetError{})
case <-time.After(time.Second):
t.Fatal("timeout")
}
}
func getLongHeaderPacket(t *testing.T, extHdr *wire.ExtendedHeader, data []byte) receivedPacket {
t.Helper()
b, err := extHdr.Append(nil, protocol.Version1)
require.NoError(t, err)
return receivedPacket{
data: append(b, data...),
buffer: getPacketBuffer(),
rcvTime: time.Now(),
}
}
func getShortHeaderPacket(t *testing.T, connID protocol.ConnectionID, pn protocol.PacketNumber, data []byte) receivedPacket {
t.Helper()
b, err := wire.AppendShortHeader(nil, connID, pn, protocol.PacketNumberLen2, protocol.KeyPhaseOne)
require.NoError(t, err)
return receivedPacket{
data: append(b, data...),
buffer: getPacketBuffer(),
rcvTime: time.Now(),
}
}
// TestConnectionServerInvalidPackets checks that a server connection drops
// packets it must not process. Each case is reported to the tracer via
// DroppedPacket, and handlePacketImpl returns (false, nil), i.e. the packet
// is dropped without an error.
func TestConnectionServerInvalidPackets(t *testing.T) {
// Retry packets are only sent by servers (RFC 9000, Section 17.2.5)
t.Run("Retry", func(t *testing.T) {
mockCtrl := gomock.NewController(t)
tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptTracer(tr))
p := getLongHeaderPacket(t, &wire.ExtendedHeader{Header: wire.Header{
Type: protocol.PacketTypeRetry,
DestConnectionID: tc.conn.origDestConnID,
SrcConnectionID: tc.srcConnID,
Version: tc.conn.version,
Token: []byte("foobar"),
}}, make([]byte, 16) /* Retry integrity tag */)
tracer.EXPECT().DroppedPacket(logging.PacketTypeRetry, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropUnexpectedPacket)
wasProcessed, err := tc.conn.handlePacketImpl(p)
require.NoError(t, err)
require.False(t, wasProcessed)
})
// Version Negotiation packets are only sent by servers (RFC 9000, Section 17.2.1)
t.Run("version negotiation", func(t *testing.T) {
mockCtrl := gomock.NewController(t)
tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptTracer(tr))
b := wire.ComposeVersionNegotiation(
protocol.ArbitraryLenConnectionID(tc.srcConnID.Bytes()),
protocol.ArbitraryLenConnectionID(tc.conn.origDestConnID.Bytes()),
[]Version{Version1},
)
tracer.EXPECT().DroppedPacket(logging.PacketTypeVersionNegotiation, protocol.InvalidPacketNumber, protocol.ByteCount(len(b)), logging.PacketDropUnexpectedPacket)
wasProcessed, err := tc.conn.handlePacketImpl(receivedPacket{data: b, buffer: getPacketBuffer()})
require.NoError(t, err)
require.False(t, wasProcessed)
})
// a long header packet with an unknown version (1234); its type can't be determined
t.Run("unsupported version", func(t *testing.T) {
mockCtrl := gomock.NewController(t)
tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptTracer(tr))
p := getLongHeaderPacket(t, &wire.ExtendedHeader{
Header: wire.Header{Type: protocol.PacketTypeHandshake, Version: 1234},
PacketNumberLen: protocol.PacketNumberLen2,
}, nil)
tracer.EXPECT().DroppedPacket(logging.PacketTypeNotDetermined, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropUnsupportedVersion)
wasProcessed, err := tc.conn.handlePacketImpl(p)
require.NoError(t, err)
require.False(t, wasProcessed)
})
// a Handshake packet whose fixed ("QUIC") bit is cleared fails header parsing
t.Run("invalid header", func(t *testing.T) {
mockCtrl := gomock.NewController(t)
tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptTracer(tr))
p := getLongHeaderPacket(t, &wire.ExtendedHeader{
Header: wire.Header{Type: protocol.PacketTypeHandshake, Version: Version1},
PacketNumberLen: protocol.PacketNumberLen2,
}, nil)
p.data[0] ^= 0x40 // unset the QUIC bit
tracer.EXPECT().DroppedPacket(logging.PacketTypeNotDetermined, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropHeaderParseError)
wasProcessed, err := tc.conn.handlePacketImpl(p)
require.NoError(t, err)
require.False(t, wasProcessed)
})
}
// TestConnectionClientDrop0RTT checks that a client connection drops a 0-RTT
// packet without error and without processing it, reporting it to the tracer
// as an unexpected packet.
func TestConnectionClientDrop0RTT(t *testing.T) {
	ctrl := gomock.NewController(t)
	tr, tracer := mocklogging.NewMockConnectionTracer(ctrl)
	tc := newClientTestConnection(t, ctrl, nil, false, connectionOptTracer(tr))
	zeroRTTPacket := getLongHeaderPacket(t,
		&wire.ExtendedHeader{
			Header: wire.Header{
				Type:    protocol.PacketType0RTT,
				Length:  2,
				Version: protocol.Version1,
			},
			PacketNumberLen: protocol.PacketNumberLen2,
		},
		nil,
	)
	tracer.EXPECT().DroppedPacket(
		logging.PacketType0RTT,
		protocol.InvalidPacketNumber,
		zeroRTTPacket.Size(),
		logging.PacketDropUnexpectedPacket,
	)

	processed, err := tc.conn.handlePacketImpl(zeroRTTPacket)
	require.NoError(t, err)
	require.False(t, processed)
}
// TestConnectionUnpacking checks the receive path for single packets:
//  1. A long header Initial packet is unpacked, registered with the received
//     packet handler and traced. Registration uses the full packet number from
//     the unpacker, the packet's ECN marking and its receive time. The
//     NegotiatedVersion and StartedConnection tracer events are also expected
//     (presumably because it is the first packet the connection handles).
//  2. A packet the received packet handler flags as a potential duplicate is
//     traced as dropped (PacketDropDuplicate) and not processed.
//  3. A short header packet goes through the same steps at the 1-RTT level.
func TestConnectionUnpacking(t *testing.T) {
mockCtrl := gomock.NewController(t)
rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl)
unpacker := NewMockUnpacker(mockCtrl)
tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
tc := newServerTestConnection(t,
mockCtrl,
nil,
false,
connectionOptReceivedPacketHandler(rph),
connectionOptUnpacker(unpacker),
connectionOptTracer(tr),
)
// receive a long header packet
hdr := &wire.ExtendedHeader{
Header: wire.Header{
Type: protocol.PacketTypeInitial,
DestConnectionID: tc.srcConnID,
Version: protocol.Version1,
Length: 1,
},
PacketNumber: 0x37,
PacketNumberLen: protocol.PacketNumberLen1,
}
// the wire header only carries the truncated packet number (0x37); the
// unpacker hands back the full one (0x1337)
unpackedHdr := *hdr
unpackedHdr.PacketNumber = 0x1337
packet := getLongHeaderPacket(t, hdr, nil)
packet.ecn = protocol.ECNCE
// a receive time in the past, so it is distinguishable from "now"
rcvTime := time.Now().Add(-10 * time.Second)
packet.rcvTime = rcvTime
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).Return(&unpackedPacket{
encryptionLevel: protocol.EncryptionInitial,
hdr: &unpackedHdr,
data: []byte{0}, // one PADDING frame
}, nil)
// NOTE(review): the final bool of ReceivedPacket is presumably the
// ack-eliciting flag; a PADDING-only payload is not ack-eliciting.
gomock.InOrder(
rph.EXPECT().IsPotentiallyDuplicate(protocol.PacketNumber(0x1337), protocol.EncryptionInitial),
rph.EXPECT().ReceivedPacket(protocol.PacketNumber(0x1337), protocol.ECNCE, protocol.EncryptionInitial, rcvTime, false),
)
tracer.EXPECT().NegotiatedVersion(gomock.Any(), gomock.Any(), gomock.Any())
tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), gomock.Any(), logging.ECNCE, []logging.Frame{})
wasProcessed, err := tc.conn.handlePacketImpl(packet)
require.NoError(t, err)
require.True(t, wasProcessed)
// every expectation so far must have been consumed before moving on
require.True(t, mockCtrl.Satisfied())
// receive a duplicate of this packet
packet = getLongHeaderPacket(t, hdr, nil)
rph.EXPECT().IsPotentiallyDuplicate(protocol.PacketNumber(0x1337), protocol.EncryptionInitial).Return(true)
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).Return(&unpackedPacket{
encryptionLevel: protocol.EncryptionInitial,
hdr: &unpackedHdr,
data: []byte{0}, // one PADDING frame
}, nil)
tracer.EXPECT().DroppedPacket(logging.PacketTypeInitial, protocol.PacketNumber(0x1337), protocol.ByteCount(len(packet.data)), logging.PacketDropDuplicate)
wasProcessed, err = tc.conn.handlePacketImpl(packet)
require.NoError(t, err)
require.False(t, wasProcessed)
require.True(t, mockCtrl.Satisfied())
// receive a short header packet
// (truncated packet number 0x37 on the wire, 0x1337 after unpacking)
packet = getShortHeaderPacket(t, tc.srcConnID, 0x37, nil)
packet.ecn = protocol.ECT1
packet.rcvTime = rcvTime
gomock.InOrder(
rph.EXPECT().IsPotentiallyDuplicate(protocol.PacketNumber(0x1337), protocol.Encryption1RTT),
rph.EXPECT().ReceivedPacket(protocol.PacketNumber(0x1337), protocol.ECT1, protocol.Encryption1RTT, rcvTime, false),
)
unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return(
protocol.PacketNumber(0x1337), protocol.PacketNumberLen2, protocol.KeyPhaseZero, []byte{0} /* PADDING */, nil,
)
tracer.EXPECT().ReceivedShortHeaderPacket(gomock.Any(), gomock.Any(), logging.ECT1, []logging.Frame{})
wasProcessed, err = tc.conn.handlePacketImpl(packet)
require.NoError(t, err)
require.True(t, wasProcessed)
}
// TestConnectionUnpackCoalescedPacket checks a datagram that holds three
// coalesced long header packets:
//   - an Initial (PADDING only) and a Handshake (PING) packet for this
//     connection are both unpacked, registered with the received packet
//     handler (in order) and traced;
//   - a trailing Handshake packet with a different destination connection ID
//     is dropped with PacketDropUnknownConnectionID.
//
// The test also expects the Initial encryption level to be dropped
// (DroppedEncryptionLevel and DropPackets). Presumably this is triggered by
// the Handshake packet: a server discards its Initial keys once it processes
// a Handshake packet (RFC 9001, Section 4.9.1).
func TestConnectionUnpackCoalescedPacket(t *testing.T) {
mockCtrl := gomock.NewController(t)
rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl)
unpacker := NewMockUnpacker(mockCtrl)
tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
tc := newServerTestConnection(t,
mockCtrl,
nil,
false,
connectionOptReceivedPacketHandler(rph),
connectionOptUnpacker(unpacker),
connectionOptTracer(tr),
)
hdr1 := &wire.ExtendedHeader{
Header: wire.Header{
Type: protocol.PacketTypeInitial,
DestConnectionID: tc.srcConnID,
Version: protocol.Version1,
Length: 1,
},
PacketNumber: 37,
PacketNumberLen: protocol.PacketNumberLen1,
}
hdr2 := &wire.ExtendedHeader{
Header: wire.Header{
Type: protocol.PacketTypeHandshake,
DestConnectionID: tc.srcConnID,
Version: protocol.Version1,
Length: 1,
},
PacketNumber: 38,
PacketNumberLen: protocol.PacketNumberLen1,
}
// add a packet with a different destination connection ID
// (incorrectSrcConnID is used as DestConnectionID and doesn't match tc.srcConnID)
incorrectSrcConnID := protocol.ParseConnectionID([]byte{0xa, 0xb, 0xc})
hdr3 := &wire.ExtendedHeader{
Header: wire.Header{
Type: protocol.PacketTypeHandshake,
DestConnectionID: incorrectSrcConnID,
Version: protocol.Version1,
Length: 1,
},
PacketNumber: 0x42,
PacketNumberLen: protocol.PacketNumberLen1,
}
// full packet numbers the unpacker hands back for the first two packets
unpackedHdr1 := *hdr1
unpackedHdr1.PacketNumber = 1337
unpackedHdr2 := *hdr2
unpackedHdr2.PacketNumber = 1338
// coalesce all three packets into a single datagram
packet := getLongHeaderPacket(t, hdr1, nil)
packet2 := getLongHeaderPacket(t, hdr2, nil)
packet3 := getLongHeaderPacket(t, hdr3, nil)
packet.data = append(packet.data, packet2.data...)
packet.data = append(packet.data, packet3.data...)
packet.ecn = protocol.ECT1
rcvTime := time.Now()
packet.rcvTime = rcvTime
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).Return(&unpackedPacket{
encryptionLevel: protocol.EncryptionInitial,
hdr: &unpackedHdr1,
data: []byte{0}, // one PADDING frame
}, nil)
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).Return(&unpackedPacket{
encryptionLevel: protocol.EncryptionHandshake,
hdr: &unpackedHdr2,
data: []byte{1}, // one PING frame
}, nil)
// both packets share the datagram's ECN marking and receive time; the final
// bool differs (false for PADDING only, true for PING), consistent with it
// being the ack-eliciting flag
gomock.InOrder(
rph.EXPECT().IsPotentiallyDuplicate(protocol.PacketNumber(1337), protocol.EncryptionInitial),
rph.EXPECT().ReceivedPacket(protocol.PacketNumber(1337), protocol.ECT1, protocol.EncryptionInitial, rcvTime, false),
rph.EXPECT().IsPotentiallyDuplicate(protocol.PacketNumber(1338), protocol.EncryptionHandshake),
rph.EXPECT().ReceivedPacket(protocol.PacketNumber(1338), protocol.ECT1, protocol.EncryptionHandshake, rcvTime, true),
)
tracer.EXPECT().NegotiatedVersion(gomock.Any(), gomock.Any(), gomock.Any())
tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
tracer.EXPECT().DroppedEncryptionLevel(protocol.EncryptionInitial)
rph.EXPECT().DropPackets(protocol.EncryptionInitial)
gomock.InOrder(
tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), gomock.Any(), logging.ECT1, []logging.Frame{}),
tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), gomock.Any(), logging.ECT1, []logging.Frame{&wire.PingFrame{}}),
tracer.EXPECT().DroppedPacket(logging.PacketTypeNotDetermined, protocol.InvalidPacketNumber, protocol.ByteCount(len(packet3.data)), logging.PacketDropUnknownConnectionID),
)
wasProcessed, err := tc.conn.handlePacketImpl(packet)
require.NoError(t, err)
require.True(t, wasProcessed)
}
// TestConnectionUnpackFailuresFatal checks that certain unpack errors close the
// connection, and which error run then returns. A transport error from the
// unpacker comes back with its error code unchanged. wire.ErrInvalidReservedBits
// is turned into a PROTOCOL_VIOLATION.
func TestConnectionUnpackFailuresFatal(t *testing.T) {
	for _, tt := range []struct {
		name      string
		unpackErr error
		expected  error
	}{
		{
			name:      "other errors",
			unpackErr: &qerr.TransportError{ErrorCode: qerr.ConnectionIDLimitError},
			expected:  &qerr.TransportError{ErrorCode: qerr.ConnectionIDLimitError},
		},
		{
			name:      "invalid reserved bits",
			unpackErr: wire.ErrInvalidReservedBits,
			expected:  &qerr.TransportError{ErrorCode: qerr.ProtocolViolation},
		},
	} {
		t.Run(tt.name, func(t *testing.T) {
			require.ErrorIs(t, testConnectionUnpackFailureFatal(t, tt.unpackErr), tt.expected)
		})
	}
}
// testConnectionUnpackFailureFatal starts a server test connection whose
// unpacker fails with unpackErr on the next short header packet. It delivers
// one such packet and returns the (non-nil) error the run loop exits with.
// The mocks expect exactly one CONNECTION_CLOSE to be packed
// (PackConnectionClose) and written, and one ReplaceWithClosed call on the
// conn runner. The test fails if run does not return within one second.
func testConnectionUnpackFailureFatal(t *testing.T, unpackErr error) error {
	t.Helper()
	mockCtrl := gomock.NewController(t)
	unpacker := NewMockUnpacker(mockCtrl)
	tc := newServerTestConnection(t,
		mockCtrl,
		nil,
		false,
		connectionOptUnpacker(unpacker),
	)
	tc.connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any())
	unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return(protocol.PacketNumber(0), protocol.PacketNumberLen(0), protocol.KeyPhaseBit(0), nil, unpackErr)
	tc.packer.EXPECT().PackConnectionClose(gomock.Any(), gomock.Any(), protocol.Version1).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil)
	// Register every expectation before the run loop goroutine starts, so it
	// never runs against a partially configured mock.
	tc.sendConn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any())

	errChan := make(chan error, 1)
	go func() { errChan <- tc.conn.run() }()
	tc.conn.handlePacket(getShortHeaderPacket(t, tc.srcConnID, 0x42, nil))
	select {
	case err := <-errChan:
		require.Error(t, err)
		return err
	case <-time.After(time.Second):
		t.Fatal("timeout")
		return nil // unreachable: t.Fatal stops the test goroutine
	}
}
// TestConnectionUnpackFailureDropped checks that some unpack errors only cause
// the packet to be dropped. Each error is reported to the tracer with the
// matching drop reason.
func TestConnectionUnpackFailureDropped(t *testing.T) {
	cases := []struct {
		name   string
		err    error
		reason logging.PacketDropReason
	}{
		{"keys dropped", handshake.ErrKeysDropped, logging.PacketDropKeyUnavailable},
		{"decryption failed", handshake.ErrDecryptionFailed, logging.PacketDropPayloadDecryptError},
		{"header parse error", &headerParseError{err: errors.New("foo")}, logging.PacketDropHeaderParseError},
	}
	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			testConnectionUnpackFailureDropped(t, c.err, c.reason)
		})
	}
}
// testConnectionUnpackFailureDropped starts a server test connection whose
// unpacker fails with unpackErr and delivers one short header packet. It then
// waits (up to one second) for the tracer to report that packet as dropped,
// with an invalid packet number and packetDropReason. Finally it destroys the
// connection and waits (up to one second) for the run loop to exit.
func testConnectionUnpackFailureDropped(t *testing.T, unpackErr error, packetDropReason logging.PacketDropReason) {
	t.Helper()
	mockCtrl := gomock.NewController(t)
	unpacker := NewMockUnpacker(mockCtrl)
	tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
	tc := newServerTestConnection(t,
		mockCtrl,
		nil,
		false,
		connectionOptUnpacker(unpacker),
		connectionOptTracer(tr),
	)
	unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return(protocol.PacketNumber(0), protocol.PacketNumberLen(0), protocol.KeyPhaseBit(0), nil, unpackErr)
	// Register the drop expectation before the run loop goroutine starts, so it
	// never runs against a partially configured tracer mock.
	done := make(chan struct{})
	tracer.EXPECT().DroppedPacket(gomock.Any(), protocol.InvalidPacketNumber, gomock.Any(), packetDropReason).Do(
		func(logging.PacketType, protocol.PacketNumber, protocol.ByteCount, logging.PacketDropReason) {
			close(done)
		},
	)

	errChan := make(chan error, 1)
	go func() { errChan <- tc.conn.run() }()
	tc.conn.handlePacket(getShortHeaderPacket(t, tc.srcConnID, 0x42, nil))
	select {
	case <-done:
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}

	// test teardown
	tracer.EXPECT().ClosedConnection(gomock.Any())
	tracer.EXPECT().Close()
	tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
	tc.conn.destroy(nil)
	select {
	case <-errChan:
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
}
// TestConnectionMaxUnprocessedPackets checks the limit on unprocessed packets.
// After protocol.MaxConnUnprocessedPackets calls to handlePacket on a
// connection whose run loop was never started, the next packet is dropped and
// reported to the tracer with PacketDropDOSPrevention.
func TestConnectionMaxUnprocessedPackets(t *testing.T) {
	ctrl := gomock.NewController(t)
	tr, tracer := mocklogging.NewMockConnectionTracer(ctrl)
	tc := newServerTestConnection(t, ctrl, nil, false, connectionOptTracer(tr))
	dropped := make(chan struct{})

	// Fill up to the limit; none of these calls may block, even though nothing
	// is consuming the packets.
	for n := protocol.PacketNumber(0); n < protocol.MaxConnUnprocessedPackets; n++ {
		tc.conn.handlePacket(receivedPacket{data: []byte("foobar")})
	}

	// The packet over the limit has to be dropped.
	tracer.EXPECT().DroppedPacket(
		logging.PacketTypeNotDetermined,
		protocol.InvalidPacketNumber,
		logging.ByteCount(6),
		logging.PacketDropDOSPrevention,
	).Do(func(logging.PacketType, logging.PacketNumber, logging.ByteCount, logging.PacketDropReason) {
		close(dropped)
	})
	tc.conn.handlePacket(receivedPacket{data: []byte("foobar")})

	select {
	case <-dropped:
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
}
// TestConnectionRemoteClose checks that a CONNECTION_CLOSE frame received from
// the peer closes the connection. The resulting remote TransportError must be
// returned from run, reported to the tracer and used to close the streams.
func TestConnectionRemoteClose(t *testing.T) {
	ctrl := gomock.NewController(t)
	streamManager := NewMockStreamManager(ctrl)
	tracerOpt, mockTracer := mocklogging.NewMockConnectionTracer(ctrl)
	mockUnpacker := NewMockUnpacker(ctrl)
	tc := newServerTestConnection(t, ctrl, nil, false,
		connectionOptStreamManager(streamManager),
		connectionOptTracer(tracerOpt),
		connectionOptUnpacker(mockUnpacker),
	)
	closeFrame, err := (&wire.ConnectionCloseFrame{
		ErrorCode:    uint64(qerr.StreamLimitError),
		ReasonPhrase: "foobar",
	}).Append(nil, protocol.Version1)
	require.NoError(t, err)
	mockUnpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return(
		protocol.PacketNumber(1), protocol.PacketNumberLen2, protocol.KeyPhaseBit(0), closeFrame, nil,
	)
	mockTracer.EXPECT().ReceivedShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
	wantErr := &qerr.TransportError{ErrorCode: qerr.StreamLimitError, Remote: true}
	tc.connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any())
	streamErr := make(chan error, 1)
	streamManager.EXPECT().CloseWithError(gomock.Any()).Do(func(e error) { streamErr <- e })
	tracerErr := make(chan error, 1)
	mockTracer.EXPECT().ClosedConnection(gomock.Any()).Do(func(e error) { tracerErr <- e })
	mockTracer.EXPECT().Close()

	runErr := make(chan error, 1)
	go func() { runErr <- tc.conn.run() }()
	packet := getShortHeaderPacket(t, tc.srcConnID, 1, []byte("encrypted"))
	tc.conn.handlePacket(receivedPacket{data: packet.data, buffer: packet.buffer, rcvTime: time.Now()})

	// receive waits up to one second for an error on ch and fails the test on timeout.
	receive := func(ch <-chan error) error {
		t.Helper()
		select {
		case e := <-ch:
			return e
		case <-time.After(time.Second):
			t.Fatal("timeout")
		}
		return nil
	}
	require.ErrorIs(t, receive(runErr), wantErr)
	require.ErrorIs(t, receive(tracerErr), wantErr)
	require.ErrorIs(t, receive(streamErr), wantErr)
}
// TestConnectionIdleTimeoutDuringHandshake checks that a server connection that
// sees no activity during the handshake returns an IdleTimeoutError once the
// (scaled) 25ms HandshakeIdleTimeout expires. The tracer must observe the close
// before it is closed itself.
func TestConnectionIdleTimeoutDuringHandshake(t *testing.T) {
	ctrl := gomock.NewController(t)
	tracerOpt, mockTracer := mocklogging.NewMockConnectionTracer(ctrl)
	cfg := &Config{HandshakeIdleTimeout: scaleDuration(25 * time.Millisecond)}
	tc := newServerTestConnection(t, ctrl, cfg, false, connectionOptTracer(tracerOpt))
	tc.packer.EXPECT().PackCoalescedPacket(false, gomock.Any(), gomock.Any(), protocol.Version1).AnyTimes()
	tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
	gomock.InOrder(
		mockTracer.EXPECT().ClosedConnection(&IdleTimeoutError{}),
		mockTracer.EXPECT().Close(),
	)

	runErr := make(chan error, 1)
	go func() { runErr <- tc.conn.run() }()
	select {
	case err := <-runErr:
		require.ErrorIs(t, err, &IdleTimeoutError{})
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
}
// TestConnectionHandshakeIdleTimeout checks that a server connection whose
// creation time lies 10s in the past (far beyond the scaled 25ms
// HandshakeIdleTimeout) fails with a HandshakeTimeoutError. The tracer must
// observe the close before it is closed itself.
func TestConnectionHandshakeIdleTimeout(t *testing.T) {
	ctrl := gomock.NewController(t)
	tracerOpt, mockTracer := mocklogging.NewMockConnectionTracer(ctrl)
	cfg := &Config{HandshakeIdleTimeout: scaleDuration(25 * time.Millisecond)}
	backdateCreation := func(c *connection) { c.creationTime = time.Now().Add(-10 * time.Second) }
	tc := newServerTestConnection(t, ctrl, cfg, false,
		connectionOptTracer(tracerOpt),
		backdateCreation,
	)
	tc.packer.EXPECT().PackCoalescedPacket(false, gomock.Any(), gomock.Any(), protocol.Version1).AnyTimes()
	tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
	gomock.InOrder(
		mockTracer.EXPECT().ClosedConnection(&HandshakeTimeoutError{}),
		mockTracer.EXPECT().Close(),
	)

	runErr := make(chan error, 1)
	go func() { runErr <- tc.conn.run() }()
	select {
	case err := <-runErr:
		require.ErrorIs(t, err, &HandshakeTimeoutError{})
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
}
// TestConnectionTransportParameters checks that handleTransportParameters hands
// the peer's parameters to three places: the tracer, the stream manager
// (UpdateLimits) and the connection-level flow controller (UpdateSendWindow
// with InitialMaxData).
func TestConnectionTransportParameters(t *testing.T) {
	ctrl := gomock.NewController(t)
	tracerOpt, mockTracer := mocklogging.NewMockConnectionTracer(ctrl)
	streamManager := NewMockStreamManager(ctrl)
	flowController := mocks.NewMockConnectionFlowController(ctrl)
	tc := newServerTestConnection(t, ctrl, nil, false,
		connectionOptTracer(tracerOpt),
		connectionOptStreamManager(streamManager),
		connectionOptConnFlowController(flowController),
	)
	peerParams := &wire.TransportParameters{
		MaxIdleTimeout:                90 * time.Second,
		InitialMaxStreamDataBidiLocal: 0x5000,
		InitialMaxData:                0x5000,
		ActiveConnectionIDLimit:       3,
		// marshaling always sets it to this value
		MaxUDPPayloadSize:               protocol.MaxPacketBufferSize,
		OriginalDestinationConnectionID: tc.destConnID,
	}
	mockTracer.EXPECT().ReceivedTransportParameters(gomock.Any())
	streamManager.EXPECT().UpdateLimits(peerParams)
	flowController.EXPECT().UpdateSendWindow(peerParams.InitialMaxData)
	require.NoError(t, tc.conn.handleTransportParameters(peerParams))
}
// TestConnectionTransportParameterValidationFailureServer checks that a server
// rejects transport parameters whose initial_source_connection_id doesn't
// match the expected value. The rejection is a TRANSPORT_PARAMETER_ERROR.
func TestConnectionTransportParameterValidationFailureServer(t *testing.T) {
	tc := newServerTestConnection(t, nil, nil, false)
	mismatched := &wire.TransportParameters{
		InitialSourceConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}),
	}
	err := tc.conn.handleTransportParameters(mismatched)
	assert.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.TransportParameterError})
	assert.ErrorContains(t, err, "expected initial_source_connection_id to equal")
}
func TestConnectionTransportParameterValidationFailureClient(t *testing.T) {
t.Run("initial_source_connection_id", func(t *testing.T) {
tc := newClientTestConnection(t, nil, nil, false)
err := tc.conn.handleTransportParameters(&wire.TransportParameters{
InitialSourceConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}),
})
assert.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.TransportParameterError})
assert.ErrorContains(t, err, "expected initial_source_connection_id to equal")
})
t.Run("original_destination_connection_id", func(t *testing.T) {
tc := newClientTestConnection(t, nil, nil, false)
err := tc.conn.handleTransportParameters(&wire.TransportParameters{
InitialSourceConnectionID: tc.destConnID,
OriginalDestinationConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}),
})
assert.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.TransportParameterError})
assert.ErrorContains(t, err, "expected original_destination_connection_id to equal")
})
t.Run("retry_source_connection_id if no retry", func(t *testing.T) {
tc := newClientTestConnection(t, nil, nil, false)
rcid := protocol.ParseConnectionID([]byte{1, 2, 3, 4})
params := &wire.TransportParameters{
InitialSourceConnectionID: tc.destConnID,
OriginalDestinationConnectionID: tc.destConnID,
RetrySourceConnectionID: &rcid,
}
err := tc.conn.handleTransportParameters(params)
assert.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.TransportParameterError})
assert.ErrorContains(t, err, "received retry_source_connection_id, although no Retry was performed")
})
t.Run("retry_source_connection_id missing", func(t *testing.T) {
tc := newClientTestConnection(t,
nil,
nil,
false,
connectionOptRetrySrcConnID(protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef})),
)
params := &wire.TransportParameters{
InitialSourceConnectionID: tc.destConnID,
OriginalDestinationConnectionID: tc.destConnID,
}
err := tc.conn.handleTransportParameters(params)
assert.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.TransportParameterError})
assert.ErrorContains(t, err, "missing retry_source_connection_id")
})
t.Run("retry_source_connection_id incorrect", func(t *testing.T) {
tc := newClientTestConnection(t,
nil,
nil,
false,
connectionOptRetrySrcConnID(protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef})),
)
wrongCID := protocol.ParseConnectionID([]byte{1, 2, 3, 4})
params := &wire.TransportParameters{
InitialSourceConnectionID: tc.destConnID,
OriginalDestinationConnectionID: tc.destConnID,
RetrySourceConnectionID: &wrongCID,
}
err := tc.conn.handleTransportParameters(params)
assert.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.TransportParameterError})
assert.ErrorContains(t, err, "expected retry_source_connection_id to equal")
})
}
// TestConnectionHandshakeServer drives a server connection to handshake
// completion by processing a CRYPTO frame carried in a Handshake packet. It
// then checks that the frames queued for sending include three things: a
// CRYPTO frame carrying the session ticket, a HANDSHAKE_DONE frame and a
// non-empty NEW_TOKEN frame.
func TestConnectionHandshakeServer(t *testing.T) {
	ctrl := gomock.NewController(t)
	cryptoSetup := mocks.NewMockCryptoSetup(ctrl)
	mockUnpacker := NewMockUnpacker(ctrl)
	tc := newServerTestConnection(t, ctrl, nil, false,
		connectionOptCryptoSetup(cryptoSetup),
		connectionOptUnpacker(mockUnpacker),
	)
	// the state transition is driven by processing of a CRYPTO frame
	hdr := &wire.ExtendedHeader{
		Header:          wire.Header{Type: protocol.PacketTypeHandshake, Version: protocol.Version1},
		PacketNumberLen: protocol.PacketNumberLen2,
	}
	cryptoData, err := (&wire.CryptoFrame{Data: []byte("foobar")}).Append(nil, protocol.Version1)
	require.NoError(t, err)
	cryptoSetup.EXPECT().DiscardInitialKeys()
	tc.connRunner.EXPECT().Retire(gomock.Any())
	gomock.InOrder(
		cryptoSetup.EXPECT().StartHandshake(gomock.Any()),
		cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}),
		mockUnpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).Return(
			&unpackedPacket{hdr: hdr, encryptionLevel: protocol.EncryptionHandshake, data: cryptoData}, nil,
		),
		cryptoSetup.EXPECT().HandleMessage([]byte("foobar"), protocol.EncryptionHandshake),
		cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventHandshakeComplete}),
		cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}),
		cryptoSetup.EXPECT().SetHandshakeConfirmed(),
		cryptoSetup.EXPECT().GetSessionTicket().Return([]byte("session ticket"), nil),
	)
	tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(shortHeaderPacket{}, errNothingToPack).AnyTimes()

	runErr := make(chan error, 1)
	go func() { runErr <- tc.conn.run() }()
	packet := getLongHeaderPacket(t, hdr, nil)
	tc.conn.handlePacket(receivedPacket{data: packet.data, buffer: packet.buffer, rcvTime: time.Now()})
	select {
	case <-tc.conn.HandshakeComplete():
	case <-tc.conn.Context().Done():
		t.Fatal("connection context done")
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}

	// inspect everything the framer has queued after handshake completion
	var sawSessionTicket, sawHandshakeDone, sawNewToken bool
	queued, _, _ := tc.conn.framer.Append(nil, nil, protocol.MaxByteCount, time.Now(), protocol.Version1)
	for _, queuedFrame := range queued {
		switch f := queuedFrame.Frame.(type) {
		case *wire.CryptoFrame:
			assert.Equal(t, []byte("session ticket"), f.Data)
			sawSessionTicket = true
		case *wire.HandshakeDoneFrame:
			sawHandshakeDone = true
		case *wire.NewTokenFrame:
			assert.NotEmpty(t, f.Token)
			sawNewToken = true
		}
	}
	assert.True(t, sawSessionTicket)
	assert.True(t, sawHandshakeDone)
	assert.True(t, sawNewToken)

	// test teardown
	cryptoSetup.EXPECT().Close()
	tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
	tc.conn.destroy(nil)
	select {
	case err := <-runErr:
		require.NoError(t, err)
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
}
// TestConnectionHandshakeClient runs the client handshake scenario twice: once
// without and once with a preferred_address in the server's transport
// parameters.
func TestConnectionHandshakeClient(t *testing.T) {
	for _, tt := range []struct {
		name                string
		usePreferredAddress bool
	}{
		{name: "without preferred address", usePreferredAddress: false},
		{name: "with preferred address", usePreferredAddress: true},
	} {
		t.Run(tt.name, func(t *testing.T) {
			testConnectionHandshakeClient(t, tt.usePreferredAddress)
		})
	}
}
// testConnectionHandshakeClient drives a client connection through the handshake
// in three phases:
//  1. It sends a first (Handshake) packet.
//  2. It processes a CRYPTO frame that delivers the server's transport
//     parameters and completes the handshake.
//  3. It confirms the handshake when a HANDSHAKE_DONE frame arrives.
//
// If usePreferredAddress is set, the server's transport parameters carry a
// preferred_address. Its connection ID must then be the next one returned by
// the connection ID manager, and its stateless reset token is added to (and
// later removed from) the connection runner.
func testConnectionHandshakeClient(t *testing.T, usePreferredAddress bool) {
	mockCtrl := gomock.NewController(t)
	cs := mocks.NewMockCryptoSetup(mockCtrl)
	mockUnpacker := NewMockUnpacker(mockCtrl)
	tc := newClientTestConnection(t, mockCtrl, nil, false, connectionOptCryptoSetup(cs), connectionOptUnpacker(mockUnpacker))
	tc.sendConn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
	// the state transition is driven by processing of a CRYPTO frame
	hdr := &wire.ExtendedHeader{
		Header:          wire.Header{Type: protocol.PacketTypeHandshake, Version: protocol.Version1},
		PacketNumberLen: protocol.PacketNumberLen2,
	}
	data, err := (&wire.CryptoFrame{Data: []byte("foobar")}).Append(nil, protocol.Version1)
	require.NoError(t, err)
	tp := &wire.TransportParameters{
		OriginalDestinationConnectionID: tc.destConnID,
		MaxIdleTimeout:                  time.Hour,
	}
	preferredAddressConnID := protocol.ParseConnectionID([]byte{10, 8, 6, 4})
	preferredAddressResetToken := protocol.StatelessResetToken{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}
	if usePreferredAddress {
		tp.PreferredAddress = &wire.PreferredAddress{
			IPv4:                netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), 42),
			IPv6:                netip.AddrPortFrom(netip.AddrFrom16([16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}), 13),
			ConnectionID:        preferredAddressConnID,
			StatelessResetToken: preferredAddressResetToken,
		}
	}
	packedFirstPacket := make(chan struct{})
	gomock.InOrder(
		cs.EXPECT().StartHandshake(gomock.Any()),
		cs.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}),
		tc.packer.EXPECT().PackCoalescedPacket(false, gomock.Any(), gomock.Any(), protocol.Version1).DoAndReturn(
			func(bool, protocol.ByteCount, time.Time, protocol.Version) (*coalescedPacket, error) {
				close(packedFirstPacket)
				return &coalescedPacket{buffer: getPacketBuffer(), longHdrPackets: []*longHeaderPacket{{header: hdr}}}, nil
			},
		),
		// initial keys are dropped when the first handshake packet is sent
		cs.EXPECT().DiscardInitialKeys(),
		mockUnpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).Return(
			&unpackedPacket{hdr: hdr, encryptionLevel: protocol.EncryptionHandshake, data: data}, nil,
		),
		cs.EXPECT().HandleMessage([]byte("foobar"), protocol.EncryptionHandshake),
		cs.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventReceivedTransportParameters, TransportParameters: tp}),
		cs.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventHandshakeComplete}),
		cs.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}),
	)
	// After the first packet, there's no more data to send. This AnyTimes
	// expectation stays active for the rest of the test.
	tc.packer.EXPECT().PackCoalescedPacket(false, gomock.Any(), gomock.Any(), protocol.Version1).Return(nil, nil).AnyTimes()
	errChan := make(chan error, 1)
	go func() { errChan <- tc.conn.run() }()
	select {
	case <-packedFirstPacket:
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
	p := getLongHeaderPacket(t, hdr, nil)
	tc.conn.handlePacket(receivedPacket{data: p.data, buffer: p.buffer, rcvTime: time.Now()})
	select {
	case <-tc.conn.HandshakeComplete():
	case <-tc.conn.Context().Done():
		t.Fatal("connection context done")
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
	require.True(t, mockCtrl.Satisfied())
	// the handshake isn't confirmed until we receive a HANDSHAKE_DONE frame from the server
	data, err = (&wire.HandshakeDoneFrame{}).Append(nil, protocol.Version1)
	require.NoError(t, err)
	done := make(chan struct{})
	gomock.InOrder(
		mockUnpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).Return(
			&unpackedPacket{hdr: hdr, encryptionLevel: protocol.Encryption1RTT, data: data}, nil,
		),
		cs.EXPECT().SetHandshakeConfirmed(),
		tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
			func(*packetBuffer, protocol.ByteCount, time.Time, protocol.Version) (shortHeaderPacket, error) {
				close(done)
				return shortHeaderPacket{}, errNothingToPack
			},
		),
	)
	tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(shortHeaderPacket{}, errNothingToPack).AnyTimes()
	p = getLongHeaderPacket(t, hdr, nil)
	tc.conn.handlePacket(receivedPacket{data: p.data, buffer: p.buffer, rcvTime: time.Now()})
	select {
	case <-done:
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
	if usePreferredAddress {
		tc.connRunner.EXPECT().AddResetToken(preferredAddressResetToken, gomock.Any())
	}
	nextConnID := tc.conn.connIDManager.Get()
	if usePreferredAddress {
		require.Equal(t, preferredAddressConnID, nextConnID)
	}
	// test teardown
	cs.EXPECT().Close()
	tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
	if usePreferredAddress {
		tc.connRunner.EXPECT().RemoveResetToken(preferredAddressResetToken)
	}
	tc.conn.destroy(nil)
	select {
	case err := <-errChan:
		require.NoError(t, err)
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
}
// TestConnection0RTTTransportParameters checks what a client that used 0-RTT
// does when the server's transport parameters reduce a limit compared to the
// parameters restored for 0-RTT (here: MaxBidiStreamNum). The client must close
// the connection with a PROTOCOL_VIOLATION.
func TestConnection0RTTTransportParameters(t *testing.T) {
	mockCtrl := gomock.NewController(t)
	cs := mocks.NewMockCryptoSetup(mockCtrl)
	mockUnpacker := NewMockUnpacker(mockCtrl)
	tc := newClientTestConnection(t, mockCtrl, nil, false, connectionOptCryptoSetup(cs), connectionOptUnpacker(mockUnpacker))
	tc.sendConn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
	// the state transition is driven by processing of a CRYPTO frame
	hdr := &wire.ExtendedHeader{
		Header:          wire.Header{Type: protocol.PacketTypeHandshake, Version: protocol.Version1},
		PacketNumberLen: protocol.PacketNumberLen2,
	}
	data, err := (&wire.CryptoFrame{Data: []byte("foobar")}).Append(nil, protocol.Version1)
	require.NoError(t, err)
	restored := &wire.TransportParameters{
		ActiveConnectionIDLimit:        3,
		InitialMaxData:                 0x5000,
		InitialMaxStreamDataBidiLocal:  0x5000,
		InitialMaxStreamDataBidiRemote: 1000,
		InitialMaxStreamDataUni:        1000,
		MaxBidiStreamNum:               500,
		MaxUniStreamNum:                500,
	}
	// received is a copy of the restored parameters, with one limit reduced
	received := *restored
	received.MaxBidiStreamNum-- // the server is not allowed to reduce the limit
	received.OriginalDestinationConnectionID = tc.destConnID
	packedFirstPacket := make(chan struct{})
	gomock.InOrder(
		cs.EXPECT().StartHandshake(gomock.Any()),
		cs.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventRestoredTransportParameters, TransportParameters: restored}),
		cs.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}),
		tc.packer.EXPECT().PackCoalescedPacket(false, gomock.Any(), gomock.Any(), protocol.Version1).DoAndReturn(
			func(bool, protocol.ByteCount, time.Time, protocol.Version) (*coalescedPacket, error) {
				close(packedFirstPacket)
				return &coalescedPacket{buffer: getPacketBuffer(), longHdrPackets: []*longHeaderPacket{{header: hdr}}}, nil
			},
		),
		// initial keys are dropped when the first handshake packet is sent
		cs.EXPECT().DiscardInitialKeys(),
		mockUnpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).Return(
			&unpackedPacket{hdr: hdr, encryptionLevel: protocol.EncryptionHandshake, data: data}, nil,
		),
		cs.EXPECT().HandleMessage([]byte("foobar"), protocol.EncryptionHandshake),
		cs.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventReceivedTransportParameters, TransportParameters: &received}),
		cs.EXPECT().ConnectionState().Return(handshake.ConnectionState{Used0RTT: true}),
		// no further NextEvent call is expected: the connection is closed at this point
		cs.EXPECT().Close(),
	)
	// after the first packet, there's no more data to send
	tc.packer.EXPECT().PackCoalescedPacket(false, gomock.Any(), gomock.Any(), protocol.Version1).Return(nil, nil).AnyTimes()
	tc.packer.EXPECT().PackConnectionClose(gomock.Any(), gomock.Any(), protocol.Version1).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil)
	tc.connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any())
	errChan := make(chan error, 1)
	go func() { errChan <- tc.conn.run() }()
	select {
	case <-packedFirstPacket:
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
	p := getLongHeaderPacket(t, hdr, nil)
	tc.conn.handlePacket(receivedPacket{data: p.data, buffer: p.buffer, rcvTime: time.Now()})
	select {
	case err := <-errChan:
		require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.ProtocolViolation})
		require.ErrorContains(t, err, "server sent reduced limits after accepting 0-RTT data")
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
}
func TestConnectionReceivePrioritization(t *testing.T) {
t.Run("handshake complete", func(t *testing.T) {
counter := testConnectionReceivePrioritization(t, true)
require.Equal(t, 10, counter)
})
// before handshake completion, we trigger packing of a new packet every time we receive a packet
t.Run("handshake not complete", func(t *testing.T) {
counter := testConnectionReceivePrioritization(t, false)
require.Equal(t, 1, counter)
})
}
// testConnectionReceivePrioritization queues 10 short header packets on a
// running server connection. It returns how many of them were unpacked before
// the connection first attempted to pack a packet. If the handshake is
// complete, the 1-RTT packing path (AppendPacket) is used; otherwise the
// coalesced packing path (PackCoalescedPacket) is used.
func testConnectionReceivePrioritization(t *testing.T, handshakeComplete bool) int {
	ctrl := gomock.NewController(t)
	mockUnpacker := NewMockUnpacker(ctrl)
	opts := []testConnectionOpt{connectionOptUnpacker(mockUnpacker)}
	if handshakeComplete {
		opts = append(opts, connectionOptHandshakeConfirmed())
	}
	tc := newServerTestConnection(t, ctrl, nil, false, opts...)

	var (
		counter     int
		packedFirst bool
	)
	done := make(chan struct{})
	// markPacked records the first packing attempt and signals the test goroutine.
	markPacked := func() {
		if !packedFirst {
			packedFirst = true
			close(done)
		}
	}
	mockUnpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).DoAndReturn(
		func(time.Time, []byte) (protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, []byte, error) {
			if !packedFirst {
				counter++
			}
			return protocol.PacketNumber(counter), protocol.PacketNumberLen2, protocol.KeyPhaseZero, []byte{0, 1} /* PADDING, PING */, nil
		},
	).AnyTimes()
	if handshakeComplete {
		tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
			func(*packetBuffer, protocol.ByteCount, time.Time, protocol.Version) (shortHeaderPacket, error) {
				markPacked()
				return shortHeaderPacket{}, errNothingToPack
			},
		).AnyTimes()
	} else {
		tc.packer.EXPECT().PackCoalescedPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
			func(bool, protocol.ByteCount, time.Time, protocol.Version) (*coalescedPacket, error) {
				markPacked()
				return nil, nil
			},
		).AnyTimes()
	}

	runErr := make(chan error, 1)
	go func() { runErr <- tc.conn.run() }()
	for pn := protocol.PacketNumber(0); pn < 10; pn++ {
		tc.conn.handlePacket(getShortHeaderPacket(t, tc.srcConnID, pn, []byte("foobar")))
	}
	select {
	case <-done:
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}

	// test teardown
	tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
	tc.conn.destroy(nil)
	select {
	case err := <-runErr:
		require.NoError(t, err)
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
	return counter
}
// TestConnectionPacketBuffering checks how Handshake packets that can't be
// unpacked yet (handshake.ErrKeysNotYetAvailable) are handled. They must be
// buffered and reported via the tracer. Once a later packet makes the keys
// available (EventReceivedReadKeys), the buffered packets must be processed
// in the order they were received.
func TestConnectionPacketBuffering(t *testing.T) {
	mockCtrl := gomock.NewController(t)
	mockUnpacker := NewMockUnpacker(mockCtrl)
	cs := mocks.NewMockCryptoSetup(mockCtrl)
	tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
	tc := newServerTestConnection(t,
		mockCtrl,
		nil,
		false,
		connectionOptUnpacker(mockUnpacker),
		connectionOptCryptoSetup(cs),
		connectionOptTracer(tr),
	)
	tracer.EXPECT().NegotiatedVersion(gomock.Any(), gomock.Any(), gomock.Any())
	tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
	tracer.EXPECT().DroppedEncryptionLevel(gomock.Any())
	cs.EXPECT().DiscardInitialKeys()
	hdr1 := wire.ExtendedHeader{
		Header: wire.Header{
			Type:             protocol.PacketTypeHandshake,
			DestConnectionID: tc.srcConnID,
			SrcConnectionID:  tc.destConnID,
			Length:           8,
			Version:          protocol.Version1,
		},
		PacketNumberLen: protocol.PacketNumberLen1,
		PacketNumber:    1,
	}
	hdr2 := hdr1
	hdr2.PacketNumber = 2
	cs.EXPECT().StartHandshake(gomock.Any())
	buffered := make(chan struct{})
	gomock.InOrder(
		cs.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}),
		mockUnpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).Return(nil, handshake.ErrKeysNotYetAvailable),
		tracer.EXPECT().BufferedPacket(logging.PacketTypeHandshake, gomock.Any()),
		mockUnpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).Return(nil, handshake.ErrKeysNotYetAvailable),
		tracer.EXPECT().BufferedPacket(logging.PacketTypeHandshake, gomock.Any()).Do(
			func(logging.PacketType, logging.ByteCount) { close(buffered) },
		),
	)
	// both packets are queued before the connection starts running
	tc.conn.handlePacket(getLongHeaderPacket(t, &hdr1, []byte("packet1")))
	tc.conn.handlePacket(getLongHeaderPacket(t, &hdr2, []byte("packet2")))
	errChan := make(chan error, 1)
	go func() { errChan <- tc.conn.run() }()
	select {
	case <-buffered:
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
	// Now send another packet.
	// In reality, this packet would contain a CRYPTO frame that advances the TLS handshake
	// such that new keys become available.
	var packets []string
	hdr3 := hdr1
	hdr3.PacketNumber = 3
	tc.packer.EXPECT().PackCoalescedPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes()
	unpacked := make(chan struct{})
	cs.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventReceivedReadKeys})
	cs.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent})
	// each callback records the packet's 7-byte payload ("packetN") to check the processing order
	gomock.InOrder(
		// packet 3 contains a CRYPTO frame and triggers the keys to become available
		mockUnpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).DoAndReturn(
			func(_ *wire.Header, data []byte) (*unpackedPacket, error) {
				packets = append(packets, string(data[len(data)-7:]))
				cf := &wire.CryptoFrame{Data: []byte("foobar")}
				b, _ := cf.Append(nil, protocol.Version1)
				return &unpackedPacket{hdr: &hdr3, encryptionLevel: protocol.EncryptionHandshake, data: b}, nil
			},
		),
		cs.EXPECT().HandleMessage(gomock.Any(), gomock.Any()),
		tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()),
		// packet 1 dequeued from the buffer
		mockUnpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).DoAndReturn(
			func(_ *wire.Header, data []byte) (*unpackedPacket, error) {
				packets = append(packets, string(data[len(data)-7:]))
				return &unpackedPacket{hdr: &hdr1, encryptionLevel: protocol.EncryptionHandshake, data: []byte{0} /* PADDING */}, nil
			},
		),
		tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()),
		// packet 2 dequeued from the buffer
		mockUnpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).DoAndReturn(
			func(_ *wire.Header, data []byte) (*unpackedPacket, error) {
				packets = append(packets, string(data[len(data)-7:]))
				close(unpacked)
				return &unpackedPacket{hdr: &hdr2, encryptionLevel: protocol.EncryptionHandshake, data: []byte{0} /* PADDING */}, nil
			},
		),
		tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()),
	)
	tc.conn.handlePacket(getLongHeaderPacket(t, &hdr3, []byte("packet3")))
	select {
	case <-unpacked:
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
	// packet3 triggered the keys to become available
	// packet1 and packet2 are processed from the buffer in order
	require.Equal(t, []string{"packet3", "packet1", "packet2"}, packets)
	// test teardown
	tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
	cs.EXPECT().Close()
	tracer.EXPECT().ClosedConnection(gomock.Any())
	tracer.EXPECT().Close()
	tc.conn.destroy(nil)
	select {
	case err := <-errChan:
		require.NoError(t, err)
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
}
// TestConnectionPacketPacing checks that the connection follows the sent packet
// handler's pacing decisions:
//   - It sends while SendMode returns SendAny.
//   - When pacing limited, it waits until the time returned by TimeUntilSend.
//   - If it is still pacing limited after waking up, it only sends an
//     ACK-only packet.
func TestConnectionPacketPacing(t *testing.T) {
	mockCtrl := gomock.NewController(t)
	sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
	mockSender := NewMockSender(mockCtrl)
	tc := newServerTestConnection(t,
		mockCtrl,
		nil,
		false,
		connectionOptSentPacketHandler(sph),
		connectionOptSender(mockSender),
		connectionOptHandshakeConfirmed(),
		// set a fixed RTT, so that the idle timeout doesn't interfere with this test
		connectionOptRTT(10*time.Second),
	)
	mockSender.EXPECT().Run()
	// step is the (already scaled) pacing delay; don't scale it again
	step := scaleDuration(50 * time.Millisecond)
	sph.EXPECT().GetLossDetectionTimeout().Return(time.Now().Add(time.Hour)).AnyTimes()
	gomock.InOrder(
		// 1. allow 2 packets to be sent
		sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny),
		sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()),
		sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny),
		sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()),
		sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendPacingLimited),
		// 2. become pacing limited for one step
		sph.EXPECT().TimeUntilSend().DoAndReturn(func() time.Time { return time.Now().Add(step) }),
		// 3. send another packet
		sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny),
		sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()),
		sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendPacingLimited),
		// 4. become pacing limited for one step...
		sph.EXPECT().TimeUntilSend().DoAndReturn(func() time.Time { return time.Now().Add(step) }),
		// ... but this time we're still pacing limited when waking up.
		// In this case, we can only send an ACK.
		sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendPacingLimited),
		// 5. stop the test by becoming pacing limited forever
		sph.EXPECT().TimeUntilSend().Return(time.Now().Add(time.Hour)),
		sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()),
	)
	sph.EXPECT().ECNMode(gomock.Any()).AnyTimes()
	for i := 0; i < 3; i++ {
		// relies on per-iteration loop variables (Go 1.22+): each callback sees its own i
		tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), protocol.Version1).DoAndReturn(
			func(buf *packetBuffer, _ protocol.ByteCount, _ time.Time, _ protocol.Version) (shortHeaderPacket, error) {
				buf.Data = append(buf.Data, []byte("packet"+strconv.Itoa(i+1))...)
				return shortHeaderPacket{PacketNumber: protocol.PacketNumber(i + 1)}, nil
			},
		)
	}
	tc.packer.EXPECT().PackAckOnlyPacket(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
		func(_ protocol.ByteCount, _ time.Time, _ protocol.Version) (shortHeaderPacket, *packetBuffer, error) {
			buf := getPacketBuffer()
			buf.Data = []byte("ack")
			return shortHeaderPacket{PacketNumber: 1}, buf, nil
		},
	)
	mockSender.EXPECT().WouldBlock().AnyTimes()
	type sentPacket struct {
		time time.Time
		data []byte
	}
	sendChan := make(chan sentPacket, 10)
	mockSender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(b *packetBuffer, _ uint16, _ protocol.ECN) {
		sendChan <- sentPacket{time: time.Now(), data: b.Data}
	}).Times(4)
	errChan := make(chan error, 1)
	go func() { errChan <- tc.conn.run() }()
	tc.conn.scheduleSending()
	var times []time.Time
	for i := 0; i < 3; i++ {
		select {
		case b := <-sendChan:
			require.Equal(t, []byte("packet"+strconv.Itoa(i+1)), b.data)
			times = append(times, b.time)
		case <-time.After(scaleDuration(time.Second)):
			t.Fatal("timeout")
		}
	}
	select {
	case b := <-sendChan:
		require.Equal(t, []byte("ack"), b.data)
		times = append(times, b.time)
	case <-time.After(scaleDuration(time.Second)):
		t.Fatal("timeout")
	}
	// the first two packets are sent back-to-back, the following ones one step apart
	require.InDelta(t, 0, times[1].Sub(times[0]).Seconds(), scaleDuration(10*time.Millisecond).Seconds())
	require.InDelta(t, step.Seconds(), times[2].Sub(times[1]).Seconds(), scaleDuration(20*time.Millisecond).Seconds())
	require.InDelta(t, step.Seconds(), times[3].Sub(times[2]).Seconds(), scaleDuration(20*time.Millisecond).Seconds())
	time.Sleep(step) // make sure that no more packets are sent
	require.True(t, mockCtrl.Satisfied())
	// test teardown
	mockSender.EXPECT().Close()
	tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
	tc.conn.destroy(nil)
	select {
	case <-sendChan:
		t.Fatal("should not have sent any more packets")
	case err := <-errChan:
		require.NoError(t, err)
	case <-time.After(3 * time.Second):
		t.Fatal("timeout")
	}
}
// TestConnectionIdleTimeout checks the idle timeout after the handshake is
// confirmed. Once the idle timeout received in the peer's transport parameters
// has elapsed since the last packet was sent, run must return an
// IdleTimeoutError.
func TestConnectionIdleTimeout(t *testing.T) {
	mockCtrl := gomock.NewController(t)
	sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
	tc := newServerTestConnection(t,
		mockCtrl,
		&Config{MaxIdleTimeout: time.Second},
		false,
		connectionOptHandshakeConfirmed(),
		connectionOptSentPacketHandler(sph),
		connectionOptRTT(time.Millisecond),
	)
	// the idle timeout is set when the transport parameters are received
	idleTimeout := scaleDuration(50 * time.Millisecond)
	require.NoError(t, tc.conn.handleTransportParameters(&wire.TransportParameters{
		MaxIdleTimeout: idleTimeout,
	}))
	sph.EXPECT().GetLossDetectionTimeout().AnyTimes()
	sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes()
	sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
	sph.EXPECT().ECNMode(gomock.Any()).AnyTimes()
	// written by the run goroutine; read only after run has returned
	var lastSendTime time.Time
	tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
		func(buf *packetBuffer, _ protocol.ByteCount, _ time.Time, _ protocol.Version) (shortHeaderPacket, error) {
			buf.Data = append(buf.Data, []byte("foobar")...)
			lastSendTime = time.Now()
			return shortHeaderPacket{Frames: []ackhandler.Frame{{Frame: &wire.PingFrame{}}}, Length: 6}, nil
		},
	)
	tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(shortHeaderPacket{}, errNothingToPack)
	tc.sendConn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any())
	tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
	errChan := make(chan error, 1)
	go func() { errChan <- tc.conn.run() }()
	tc.conn.scheduleSending()
	select {
	case err := <-errChan:
		require.ErrorIs(t, err, &IdleTimeoutError{})
		require.NotZero(t, lastSendTime)
		require.InDelta(t,
			idleTimeout.Seconds(),
			time.Since(lastSendTime).Seconds(),
			scaleDuration(10*time.Millisecond).Seconds(),
		)
	case <-time.After(3 * time.Second):
		t.Fatal("timeout")
	}
}
// TestConnectionKeepAlive runs the keep-alive scenario twice: with a keep-alive
// period configured (a keep-alive packet is expected), and without one (the
// connection is expected to run into its idle timeout instead).
func TestConnectionKeepAlive(t *testing.T) {
	scenarios := []struct {
		name      string
		keepAlive bool
	}{
		{name: "enabled", keepAlive: true},
		{name: "disabled", keepAlive: false},
	}
	for _, sc := range scenarios {
		t.Run(sc.name, func(t *testing.T) {
			// the keep-alive is expected exactly when it is enabled
			testConnectionKeepAlive(t, sc.keepAlive, sc.keepAlive)
		})
	}
}
// testConnectionKeepAlive receives one short header packet on a server
// connection whose idle timeout (from the peer's transport parameters) is
// 50ms (scaled), and then checks what happens next.
//
// enable sets Config.KeepAlivePeriod to 1s; otherwise it is left at zero.
// expectKeepAlive selects the expected outcome:
//   - true: a second AppendPacket call (taken to be the keep-alive) happens
//     about half the idle timeout after the packet was unpacked; the
//     connection is then destroyed and run must return nil.
//   - false: no further packing happens, and run must return an
//     IdleTimeoutError.
func testConnectionKeepAlive(t *testing.T, enable, expectKeepAlive bool) {
	var keepAlivePeriod time.Duration
	if enable {
		keepAlivePeriod = time.Second
	}
	mockCtrl := gomock.NewController(t)
	unpacker := NewMockUnpacker(mockCtrl)
	tc := newServerTestConnection(t,
		mockCtrl,
		&Config{MaxIdleTimeout: time.Second, KeepAlivePeriod: keepAlivePeriod},
		false,
		connectionOptUnpacker(unpacker),
		connectionOptHandshakeConfirmed(),
		connectionOptRTT(time.Millisecond),
	)
	// the idle timeout is set when the transport parameters are received
	idleTimeout := scaleDuration(50 * time.Millisecond)
	require.NoError(t, tc.conn.handleTransportParameters(&wire.TransportParameters{
		MaxIdleTimeout: idleTimeout,
	}))
	// Receive a packet. This starts the keep-alive timer.
	// The packet is a short header for tc.srcConnID followed by a dummy
	// payload; the unpacker mock below decides what it "contains".
	buf := getPacketBuffer()
	var err error
	buf.Data, err = wire.AppendShortHeader(buf.Data, tc.srcConnID, 1, protocol.PacketNumberLen1, protocol.KeyPhaseZero)
	require.NoError(t, err)
	buf.Data = append(buf.Data, []byte("packet")...)
	errChan := make(chan error, 1)
	go func() { errChan <- tc.conn.run() }()
	// unpackTime and packTime are set from mock callbacks and only read after
	// done has been closed.
	var unpackTime, packTime time.Time
	done := make(chan struct{})
	// NOTE: the closure parameters shadow t and the bytes package; neither is used.
	unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).DoAndReturn(
		func(t time.Time, bytes []byte) (protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, []byte, error) {
			unpackTime = time.Now()
			return protocol.PacketNumber(1), protocol.PacketNumberLen1, protocol.KeyPhaseZero, []byte{0} /* PADDING */, nil
		},
	)
	// the first AppendPacket call finds nothing to send
	tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(shortHeaderPacket{}, errNothingToPack)
	switch expectKeepAlive {
	case true:
		// record the time at which the keep-alive is sent
		tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
			func(buffer *packetBuffer, count protocol.ByteCount, t time.Time, version protocol.Version) (shortHeaderPacket, error) {
				packTime = time.Now()
				close(done)
				return shortHeaderPacket{}, errNothingToPack
			},
		)
		tc.conn.handlePacket(receivedPacket{data: buf.Data, buffer: buf, rcvTime: time.Now()})
		select {
		case <-done:
			// the keep-alive packet should be sent after half the idle timeout
			diff := packTime.Sub(unpackTime)
			require.InDelta(t, diff.Seconds(), idleTimeout.Seconds()/2, scaleDuration(10*time.Millisecond).Seconds())
		case <-time.After(idleTimeout):
			t.Fatal("timeout")
		}
	case false: // if keep-alives are disabled, the connection will run into an idle timeout
		tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
		tc.conn.handlePacket(receivedPacket{data: buf.Data, buffer: buf, rcvTime: time.Now()})
		// wait for one idle timeout before checking run's result below
		select {
		case <-time.After(3 * time.Second):
			t.Fatal("timeout")
		case <-time.After(idleTimeout):
		}
	}
	// test teardown
	// With keep-alives the connection is still running and must be destroyed;
	// otherwise it has already closed itself due to the idle timeout.
	if expectKeepAlive {
		tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
		tc.conn.destroy(nil)
	}
	select {
	case err := <-errChan:
		if expectKeepAlive {
			require.NoError(t, err)
		} else {
			require.ErrorIs(t, err, &IdleTimeoutError{})
		}
	case <-time.After(3 * time.Second):
		t.Fatal("timeout")
	}
}
// TestConnectionACKTimer checks that the connection packs again when the alarm
// reported by the received packet handler (GetAlarmTimeout) fires: after the
// first round of packing, the alarm is set alarmTimeout in the future, and the
// second packet is expected to be packed about alarmTimeout after the first.
func TestConnectionACKTimer(t *testing.T) {
	mockCtrl := gomock.NewController(t)
	rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl)
	sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
	tc := newServerTestConnection(t,
		mockCtrl,
		&Config{MaxIdleTimeout: time.Second},
		false,
		connectionOptHandshakeConfirmed(),
		connectionOptReceivedPacketHandler(rph),
		connectionOptSentPacketHandler(sph),
		// presumably a large RTT keeps RTT-derived timers out of the way — confirm
		connectionOptRTT(10*time.Second),
	)
	alarmTimeout := scaleDuration(50 * time.Millisecond)
	sph.EXPECT().GetLossDetectionTimeout().AnyTimes()
	sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes()
	sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
	sph.EXPECT().ECNMode(gomock.Any()).AnyTimes()
	rph.EXPECT().GetAlarmTimeout().Return(time.Now().Add(time.Hour))
	tc.sendConn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
	// times records when each non-empty packet was packed. It is appended to
	// from the packer mock, and each append happens before the following
	// "nothing to pack" call signals on done, so reading it after receiving
	// from done twice is safe.
	var times []time.Time
	done := make(chan struct{}, 5)
	var calls []any
	for i := 0; i < 2; i++ {
		// pack one PING packet ...
		calls = append(calls, tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
			func(buf *packetBuffer, _ protocol.ByteCount, _ time.Time, _ protocol.Version) (shortHeaderPacket, error) {
				buf.Data = append(buf.Data, []byte("foobar")...)
				times = append(times, time.Now())
				return shortHeaderPacket{Frames: []ackhandler.Frame{{Frame: &wire.PingFrame{}}}, Length: 6}, nil
			},
		))
		// ... then report that there is nothing more to pack
		calls = append(calls, tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
			func(buf *packetBuffer, _ protocol.ByteCount, _ time.Time, _ protocol.Version) (shortHeaderPacket, error) {
				done <- struct{}{}
				return shortHeaderPacket{}, errNothingToPack
			},
		))
		if i == 0 {
			// arm the alarm that is expected to trigger the second round
			calls = append(calls, rph.EXPECT().GetAlarmTimeout().Return(time.Now().Add(alarmTimeout)))
		} else {
			calls = append(calls, rph.EXPECT().GetAlarmTimeout().Return(time.Now().Add(time.Hour)).MaxTimes(1))
		}
	}
	gomock.InOrder(calls...)
	errChan := make(chan error, 1)
	go func() { errChan <- tc.conn.run() }()
	tc.conn.scheduleSending()
	for i := 0; i < 2; i++ {
		select {
		case <-done:
		case <-time.After(3 * time.Second):
			t.Fatal("timeout")
		}
	}
	// require (not assert): times[1] below would panic if fewer than two
	// packets had been recorded.
	require.Len(t, times, 2)
	require.InDelta(t, times[1].Sub(times[0]).Seconds(), alarmTimeout.Seconds(), scaleDuration(10*time.Millisecond).Seconds())
	// test teardown
	tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
	tc.conn.destroy(nil)
	select {
	case err := <-errChan:
		require.NoError(t, err)
	case <-time.After(3 * time.Second):
		t.Fatal("timeout")
	}
}
// Send a GSO batch, until we have no more data to send.
//
// TestConnectionGSOBatch packs four packets of exactly maxPacketSize bytes,
// followed by errNothingToPack, and expects a single Write of their
// concatenation, with uint16(maxPacketSize) as the second Write argument
// (presumably the GSO segment size) and the ECN marking reported by ECNMode.
func TestConnectionGSOBatch(t *testing.T) {
	mockCtrl := gomock.NewController(t)
	sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
	tc := newServerTestConnection(t,
		mockCtrl,
		nil,
		true,
		connectionOptHandshakeConfirmed(),
		connectionOptSentPacketHandler(sph),
	)
	// allow packets to be sent
	sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes()
	sph.EXPECT().TimeUntilSend().Return(time.Time{}).AnyTimes()
	sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
	sph.EXPECT().GetLossDetectionTimeout().Return(time.Time{}).AnyTimes()
	sph.EXPECT().ECNMode(gomock.Any()).Return(protocol.ECT1).AnyTimes()
	maxPacketSize := tc.conn.maxPacketSize()
	var expectedData []byte
	for i := 0; i < 4; i++ {
		// each packet is filled with its own index, so the expected batch
		// reveals the order of the packets
		data := bytes.Repeat([]byte{byte(i)}, int(maxPacketSize))
		expectedData = append(expectedData, data...)
		// NOTE(review): the closure uses the loop variable i, which relies on
		// per-iteration loop variables (Go 1.22+) — confirm the module's go version.
		tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
			func(buffer *packetBuffer, count protocol.ByteCount, t time.Time, version protocol.Version) (shortHeaderPacket, error) {
				buffer.Data = append(buffer.Data, data...)
				return shortHeaderPacket{PacketNumber: protocol.PacketNumber(i)}, nil
			},
		)
	}
	done := make(chan struct{})
	tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(shortHeaderPacket{}, errNothingToPack)
	// all four packets must arrive in one Write
	tc.sendConn.EXPECT().Write(expectedData, uint16(maxPacketSize), protocol.ECT1).DoAndReturn(
		func([]byte, uint16, protocol.ECN) error { close(done); return nil },
	)
	errChan := make(chan error, 1)
	go func() { errChan <- tc.conn.run() }()
	tc.conn.scheduleSending()
	select {
	case <-done:
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
	// test teardown
	tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
	tc.conn.destroy(nil)
	select {
	case err := <-errChan:
		require.NoError(t, err)
	case <-time.After(3 * time.Second):
		t.Fatal("timeout")
	}
}
// Send a GSO batch, until a packet smaller than the maximum size is packed.
//
// TestConnectionGSOBatchPacketSize packs three full-size packets and one of
// maxPacketSize-1 bytes, and expects them in one Write. The next packet
// ("foobar") must arrive in a second, separate Write.
func TestConnectionGSOBatchPacketSize(t *testing.T) {
	mockCtrl := gomock.NewController(t)
	sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
	tc := newServerTestConnection(t,
		mockCtrl,
		nil,
		true,
		connectionOptHandshakeConfirmed(),
		connectionOptSentPacketHandler(sph),
	)
	// allow packets to be sent
	sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes()
	sph.EXPECT().TimeUntilSend().Return(time.Time{}).AnyTimes()
	sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
	sph.EXPECT().GetLossDetectionTimeout().Return(time.Time{}).AnyTimes()
	sph.EXPECT().ECNMode(gomock.Any()).Return(protocol.ECT1).AnyTimes()
	maxPacketSize := tc.conn.maxPacketSize()
	var expectedData []byte
	var calls []any
	for i := 0; i < 4; i++ {
		var data []byte
		// the fourth packet is one byte short of the maximum size
		if i == 3 {
			data = bytes.Repeat([]byte{byte(i)}, int(maxPacketSize-1))
		} else {
			data = bytes.Repeat([]byte{byte(i)}, int(maxPacketSize))
		}
		expectedData = append(expectedData, data...)
		calls = append(calls, tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
			func(buffer *packetBuffer, count protocol.ByteCount, t time.Time, version protocol.Version) (shortHeaderPacket, error) {
				buffer.Data = append(buffer.Data, data...)
				return shortHeaderPacket{PacketNumber: protocol.PacketNumber(10 + i)}, nil
			},
		))
	}
	// The smaller (fourth) packet concluded this GSO batch, but the send loop will immediately start composing the next batch.
	// We therefore send a "foobar", so we can check that we're actually generating two GSO batches.
	calls = append(calls,
		tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
			func(buffer *packetBuffer, count protocol.ByteCount, t time.Time, version protocol.Version) (shortHeaderPacket, error) {
				buffer.Data = append(buffer.Data, []byte("foobar")...)
				return shortHeaderPacket{PacketNumber: protocol.PacketNumber(14)}, nil
			},
		),
	)
	calls = append(calls,
		tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(shortHeaderPacket{}, errNothingToPack),
	)
	gomock.InOrder(calls...)
	done := make(chan struct{})
	// first the four-packet batch, then the "foobar" batch
	gomock.InOrder(
		tc.sendConn.EXPECT().Write(expectedData, uint16(maxPacketSize), protocol.ECT1),
		tc.sendConn.EXPECT().Write([]byte("foobar"), uint16(maxPacketSize), protocol.ECT1).DoAndReturn(
			func([]byte, uint16, protocol.ECN) error { close(done); return nil },
		),
	)
	errChan := make(chan error, 1)
	go func() { errChan <- tc.conn.run() }()
	tc.conn.scheduleSending()
	select {
	case <-done:
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
	// test teardown
	tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
	tc.conn.destroy(nil)
	select {
	case err := <-errChan:
		require.NoError(t, err)
	case <-time.After(3 * time.Second):
		t.Fatal("timeout")
	}
}
// TestConnectionGSOBatchECN checks that a GSO batch is cut short when the ECN
// marking changes. Three full-size packets are packed while ECNMode reports
// ECT1, and the marking switches to ECN-CE while the third one is packed. The
// three packets are expected in one ECT1 Write, followed by the next packet
// ("foobar") in a separate Write carrying ECN-CE.
func TestConnectionGSOBatchECN(t *testing.T) {
	mockCtrl := gomock.NewController(t)
	sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
	tc := newServerTestConnection(t,
		mockCtrl,
		nil,
		true,
		connectionOptHandshakeConfirmed(),
		connectionOptSentPacketHandler(sph),
	)
	// allow packets to be sent
	// ecnMode is the marking reported by ECNMode; the third AppendPacket call
	// below switches it to ECN-CE.
	ecnMode := protocol.ECT1
	sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes()
	sph.EXPECT().TimeUntilSend().Return(time.Time{}).AnyTimes()
	sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
	sph.EXPECT().GetLossDetectionTimeout().Return(time.Time{}).AnyTimes()
	sph.EXPECT().ECNMode(gomock.Any()).DoAndReturn(func(bool) protocol.ECN { return ecnMode }).AnyTimes()
	// Send a GSO batch, until the ECN marking changes.
	var expectedData []byte
	var calls []any
	maxPacketSize := tc.conn.maxPacketSize()
	for i := 0; i < 3; i++ {
		data := bytes.Repeat([]byte{byte(i)}, int(maxPacketSize))
		expectedData = append(expectedData, data...)
		calls = append(calls, tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
			func(buffer *packetBuffer, count protocol.ByteCount, t time.Time, version protocol.Version) (shortHeaderPacket, error) {
				buffer.Data = append(buffer.Data, data...)
				if i == 2 {
					ecnMode = protocol.ECNCE
				}
				return shortHeaderPacket{PacketNumber: protocol.PacketNumber(20 + i)}, nil
			},
		))
	}
	// The change of the ECN marking concludes this GSO batch, and packing continues into a new batch.
	// We therefore pack a "foobar", so we can check that two separate GSO batches are written.
	calls = append(calls,
		tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
			func(buffer *packetBuffer, count protocol.ByteCount, t time.Time, version protocol.Version) (shortHeaderPacket, error) {
				buffer.Data = append(buffer.Data, []byte("foobar")...)
				return shortHeaderPacket{PacketNumber: protocol.PacketNumber(24)}, nil
			},
		),
	)
	calls = append(calls,
		tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(shortHeaderPacket{}, errNothingToPack),
	)
	gomock.InOrder(calls...)
	done := make(chan struct{})
	// first the ECT1 batch, then the ECN-CE batch
	gomock.InOrder(
		tc.sendConn.EXPECT().Write(expectedData, uint16(maxPacketSize), protocol.ECT1),
		tc.sendConn.EXPECT().Write([]byte("foobar"), uint16(maxPacketSize), protocol.ECNCE).DoAndReturn(
			func([]byte, uint16, protocol.ECN) error { close(done); return nil },
		),
	)
	errChan := make(chan error, 1)
	go func() { errChan <- tc.conn.run() }()
	tc.conn.scheduleSending()
	select {
	case <-done:
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
	// test teardown
	tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
	tc.conn.destroy(nil)
	select {
	case err := <-errChan:
		require.NoError(t, err)
	case <-time.After(3 * time.Second):
		t.Fatal("timeout")
	}
}
func TestConnectionPTOProbePackets(t *testing.T) {
t.Run("Initial", func(t *testing.T) {
testConnectionPTOProbePackets(t, protocol.EncryptionInitial)
})
t.Run("Handshake", func(t *testing.T) {
testConnectionPTOProbePackets(t, protocol.EncryptionHandshake)
})
t.Run("1-RTT", func(t *testing.T) {
testConnectionPTOProbePackets(t, protocol.Encryption1RTT)
})
}
// testConnectionPTOProbePackets checks the send path when the sent packet
// handler reports the PTO send mode matching encLevel. The test expects:
//   - one QueueProbePacket(encLevel) call, answered with false;
//   - one MaybePackProbePacket call for encLevel, whose packet is registered
//     with SentPacket and written to the send connection.
//
// The following SendMode call returns SendNone, so nothing else is sent.
func testConnectionPTOProbePackets(t *testing.T, encLevel protocol.EncryptionLevel) {
	mockCtrl := gomock.NewController(t)
	sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
	tc := newServerTestConnection(t,
		mockCtrl,
		nil,
		false,
		connectionOptSentPacketHandler(sph),
	)
	// map the encryption level to the corresponding PTO send mode
	var sendMode ackhandler.SendMode
	switch encLevel {
	case protocol.EncryptionInitial:
		sendMode = ackhandler.SendPTOInitial
	case protocol.EncryptionHandshake:
		sendMode = ackhandler.SendPTOHandshake
	case protocol.Encryption1RTT:
		sendMode = ackhandler.SendPTOAppData
	}
	sph.EXPECT().GetLossDetectionTimeout().AnyTimes()
	sph.EXPECT().TimeUntilSend().AnyTimes()
	// first SendMode call: PTO for encLevel; second: stop sending
	sph.EXPECT().SendMode(gomock.Any()).Return(sendMode)
	sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendNone)
	sph.EXPECT().ECNMode(gomock.Any())
	sph.EXPECT().QueueProbePacket(encLevel).Return(false)
	sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
	tc.packer.EXPECT().MaybePackProbePacket(encLevel, gomock.Any(), gomock.Any(), protocol.Version1).DoAndReturn(
		func(encLevel protocol.EncryptionLevel, maxSize protocol.ByteCount, t time.Time, version protocol.Version) (*coalescedPacket, error) {
			return &coalescedPacket{
				buffer:         getPacketBuffer(),
				shortHdrPacket: &shortHeaderPacket{PacketNumber: 1},
			}, nil
		},
	)
	done := make(chan struct{})
	tc.sendConn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()).Do(
		func([]byte, uint16, protocol.ECN) error { close(done); return nil },
	)
	errChan := make(chan error, 1)
	go func() { errChan <- tc.conn.run() }()
	tc.conn.scheduleSending()
	select {
	case <-done:
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
	// test teardown
	tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
	tc.conn.destroy(nil)
	select {
	case err := <-errChan:
		require.NoError(t, err)
	case <-time.After(3 * time.Second):
		t.Fatal("timeout")
	}
}
// TestConnectionCongestionControl checks that the connection follows the send
// mode reported by the sent packet handler, in three phases:
//   - SendAny: two regular packets are packed and written;
//   - SendAck: only an ack-only packet is attempted (PackAckOnlyPacket);
//   - SendNone: the packer is not called at all.
func TestConnectionCongestionControl(t *testing.T) {
	mockCtrl := gomock.NewController(t)
	sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
	tc := newServerTestConnection(t,
		mockCtrl,
		nil,
		false,
		connectionOptHandshakeConfirmed(),
		connectionOptSentPacketHandler(sph),
		connectionOptRTT(10*time.Second),
	)
	sph.EXPECT().TimeUntilSend().AnyTimes()
	sph.EXPECT().GetLossDetectionTimeout().AnyTimes()
	sph.EXPECT().ECNMode(true).AnyTimes()
	// phase 1: two packets may be sent, then the handler allows at most ACKs
	sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).Times(2)
	sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAck).MaxTimes(1)
	sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(2)
	// Since we're already sending out packets, we don't expect any calls to PackAckOnlyPacket
	for i := 0; i < 2; i++ {
		tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
			func(buffer *packetBuffer, count protocol.ByteCount, t time.Time, version protocol.Version) (shortHeaderPacket, error) {
				buffer.Data = append(buffer.Data, []byte("foobar")...)
				return shortHeaderPacket{PacketNumber: protocol.PacketNumber(i)}, nil
			},
		)
	}
	tc.sendConn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any())
	done1 := make(chan struct{})
	tc.sendConn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()).Do(
		func([]byte, uint16, protocol.ECN) error { close(done1); return nil },
	)
	errChan := make(chan error, 1)
	go func() { errChan <- tc.conn.run() }()
	tc.conn.scheduleSending()
	select {
	case <-done1:
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
	require.True(t, mockCtrl.Satisfied())
	// Now that we're congestion limited, we can only send an ack-only packet
	done2 := make(chan struct{})
	sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAck)
	tc.packer.EXPECT().PackAckOnlyPacket(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
		func(protocol.ByteCount, time.Time, protocol.Version) (shortHeaderPacket, *packetBuffer, error) {
			close(done2)
			return shortHeaderPacket{}, nil, errNothingToPack
		},
	)
	tc.conn.scheduleSending()
	select {
	case <-done2:
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
	require.True(t, mockCtrl.Satisfied())
	// If the send mode is "none", we can't even send an ack-only packet
	sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendNone)
	tc.conn.scheduleSending()
	// No packer expectations are registered for this phase, so a packer call
	// during this sleep would be reported as an unexpected mock call.
	time.Sleep(scaleDuration(10 * time.Millisecond)) // make sure there are no calls to the packer
	// test teardown
	tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
	tc.conn.destroy(nil)
	select {
	case err := <-errChan:
		require.NoError(t, err)
	case <-time.After(3 * time.Second):
		t.Fatal("timeout")
	}
}
// TestConnectionSendQueue runs the "send queue would block" scenario with GSO
// enabled and with GSO disabled.
func TestConnectionSendQueue(t *testing.T) {
	for _, gso := range []bool{true, false} {
		subtest := "without GSO"
		if gso {
			subtest = "with GSO"
		}
		t.Run(subtest, func(t *testing.T) {
			testConnectionSendQueue(t, gso)
		})
	}
}
// testConnectionSendQueue checks that the connection stops packing once the
// send queue reports WouldBlock, then waits on the queue's Available channel,
// and packs again after a value is delivered on that channel.
// enableGSO is passed through to newServerTestConnection.
func testConnectionSendQueue(t *testing.T, enableGSO bool) {
	mockCtrl := gomock.NewController(t)
	sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
	sender := NewMockSender(mockCtrl)
	tc := newServerTestConnection(t,
		mockCtrl,
		nil,
		enableGSO,
		connectionOptSender(sender),
		connectionOptHandshakeConfirmed(),
		connectionOptSentPacketHandler(sph),
	)
	sender.EXPECT().Run().MaxTimes(1)
	// the first WouldBlock call returns false, the next two return true
	sender.EXPECT().WouldBlock()
	sender.EXPECT().WouldBlock().Return(true).Times(2)
	available := make(chan struct{})
	blocked := make(chan struct{})
	// blocked is closed once the connection asks for the Available channel
	sender.EXPECT().Available().DoAndReturn(
		func() <-chan struct{} {
			close(blocked)
			return available
		},
	)
	sph.EXPECT().GetLossDetectionTimeout().AnyTimes()
	sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
	sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes()
	sph.EXPECT().ECNMode(gomock.Any()).AnyTimes()
	// exactly one packet is packed and sent before the queue blocks
	tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(
		shortHeaderPacket{PacketNumber: protocol.PacketNumber(1)}, nil,
	)
	sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any())
	errChan := make(chan error, 1)
	go func() { errChan <- tc.conn.run() }()
	tc.conn.scheduleSending()
	select {
	case <-blocked:
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
	require.True(t, mockCtrl.Satisfied())
	// now make room in the send queue
	sender.EXPECT().WouldBlock().AnyTimes()
	unblocked := make(chan struct{})
	tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
		func(*packetBuffer, protocol.ByteCount, time.Time, protocol.Version) (shortHeaderPacket, error) {
			close(unblocked)
			return shortHeaderPacket{}, errNothingToPack
		},
	)
	available <- struct{}{}
	select {
	case <-unblocked:
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
	// test teardown
	sender.EXPECT().Close()
	tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
	tc.conn.destroy(nil)
	select {
	case err := <-errChan:
		require.NoError(t, err)
	case <-time.After(3 * time.Second):
		t.Fatal("timeout")
	}
}
// getVersionNegotiationPacket composes a Version Negotiation packet from src to
// dest that lists versions, and returns it as a receivedPacket stamped with the
// current time and backed by a fresh packet buffer.
func getVersionNegotiationPacket(src, dest protocol.ConnectionID, versions []protocol.Version) receivedPacket {
	srcID := protocol.ArbitraryLenConnectionID(src.Bytes())
	destID := protocol.ArbitraryLenConnectionID(dest.Bytes())
	payload := wire.ComposeVersionNegotiation(srcID, destID, versions)

	var p receivedPacket
	p.rcvTime = time.Now()
	p.data = payload
	p.buffer = getPacketBuffer()
	return p
}
// TestConnectionVersionNegotiation checks that a client receiving a Version
// Negotiation packet that offers 1234 and Version2 ends run with an
// errCloseForRecreating whose nextVersion is Version2. In order, the tracer
// must see the Version Negotiation packet, then the newly negotiated version,
// and the connection must remove itself from the connection runner.
func TestConnectionVersionNegotiation(t *testing.T) {
	mockCtrl := gomock.NewController(t)
	tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
	tc := newClientTestConnection(t,
		mockCtrl,
		nil,
		false,
		connectionOptTracer(tr),
	)
	tc.packer.EXPECT().PackCoalescedPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes()
	// tracerVersions is captured from the tracer mock and only read after run
	// has returned.
	var tracerVersions []logging.Version
	gomock.InOrder(
		tracer.EXPECT().ReceivedVersionNegotiationPacket(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_, _ protocol.ArbitraryLenConnectionID, versions []logging.Version) {
			tracerVersions = versions
		}),
		tracer.EXPECT().NegotiatedVersion(protocol.Version2, gomock.Any(), gomock.Any()),
		tc.connRunner.EXPECT().Remove(gomock.Any()),
	)
	errChan := make(chan error, 1)
	go func() { errChan <- tc.conn.run() }()
	tc.conn.handlePacket(getVersionNegotiationPacket(
		tc.destConnID,
		tc.srcConnID,
		[]protocol.Version{1234, protocol.Version2},
	))
	select {
	case err := <-errChan:
		var rerr *errCloseForRecreating
		require.ErrorAs(t, err, &rerr)
		// testify's Equal takes (expected, actual)
		require.Equal(t, protocol.Version2, rerr.nextVersion)
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
	require.Contains(t, tracerVersions, protocol.Version(1234))
	require.Contains(t, tracerVersions, protocol.Version2)
}
// TestConnectionVersionNegotiationNoMatch checks that a client configured for
// Version1 only, which receives a Version Negotiation packet offering only
// Version2, ends run with a VersionNegotiationError whose Theirs contains
// Version2. The tracer must see the packet, the closed connection and Close,
// and the connection must remove itself from the connection runner.
func TestConnectionVersionNegotiationNoMatch(t *testing.T) {
	mockCtrl := gomock.NewController(t)
	tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
	tc := newClientTestConnection(t,
		mockCtrl,
		&Config{Versions: []protocol.Version{protocol.Version1}},
		false,
		connectionOptTracer(tr),
	)
	tc.packer.EXPECT().PackCoalescedPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes()
	// tracerVersions is captured from the tracer mock and only read after run
	// has returned.
	var tracerVersions []logging.Version
	tracer.EXPECT().ReceivedVersionNegotiationPacket(gomock.Any(), gomock.Any(), gomock.Any()).Do(
		func(_, _ protocol.ArbitraryLenConnectionID, versions []logging.Version) { tracerVersions = versions },
	)
	tracer.EXPECT().ClosedConnection(gomock.Any())
	tracer.EXPECT().Close()
	tc.connRunner.EXPECT().Remove(gomock.Any())
	errChan := make(chan error, 1)
	go func() { errChan <- tc.conn.run() }()
	tc.conn.handlePacket(getVersionNegotiationPacket(
		tc.destConnID,
		tc.srcConnID,
		[]protocol.Version{protocol.Version2},
	))
	select {
	case err := <-errChan:
		var verr *VersionNegotiationError
		require.ErrorAs(t, err, &verr)
		require.Contains(t, verr.Theirs, protocol.Version2)
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
	require.Contains(t, tracerVersions, protocol.Version2)
}
// TestConnectionVersionNegotiationInvalidPackets checks that a client drops
// two Version Negotiation packets without processing them and without
// returning an error, reporting each drop to the tracer:
//   - one that also lists the version in use (Version1): PacketDropUnexpectedVersion;
//   - one truncated by 2 bytes: PacketDropHeaderParseError.
func TestConnectionVersionNegotiationInvalidPackets(t *testing.T) {
	mockCtrl := gomock.NewController(t)
	tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
	tc := newClientTestConnection(t,
		mockCtrl,
		nil,
		false,
		connectionOptTracer(tr),
	)
	// offers the current version
	tracer.EXPECT().DroppedPacket(logging.PacketTypeVersionNegotiation, gomock.Any(), gomock.Any(), logging.PacketDropUnexpectedVersion)
	vnp := getVersionNegotiationPacket(
		tc.destConnID,
		tc.srcConnID,
		[]protocol.Version{1234, protocol.Version1},
	)
	wasProcessed, err := tc.conn.handlePacketImpl(vnp)
	require.NoError(t, err)
	require.False(t, wasProcessed)
	require.True(t, mockCtrl.Satisfied())
	// unparseable, since it's missing 2 bytes
	tracer.EXPECT().DroppedPacket(logging.PacketTypeVersionNegotiation, gomock.Any(), gomock.Any(), logging.PacketDropHeaderParseError)
	vnp.data = vnp.data[:len(vnp.data)-2]
	wasProcessed, err = tc.conn.handlePacketImpl(vnp)
	require.NoError(t, err)
	require.False(t, wasProcessed)
}
// getRetryPacket builds a Retry packet for QUIC version 1 with the given source
// and destination connection IDs and token. It appends the integrity tag that
// handshake.GetRetryIntegrityTag computes for origDest, and returns the packet
// as a receivedPacket stamped with the current time. It fails t if the header
// cannot be serialized.
func getRetryPacket(t *testing.T, src, dest, origDest protocol.ConnectionID, token []byte) receivedPacket {
	// report require failures at the caller's line, not inside this helper
	t.Helper()
	hdr := wire.Header{
		Type:             protocol.PacketTypeRetry,
		SrcConnectionID:  src,
		DestConnectionID: dest,
		Token:            token,
		Version:          protocol.Version1,
	}
	b, err := (&wire.ExtendedHeader{Header: hdr}).Append(nil, protocol.Version1)
	require.NoError(t, err)
	tag := handshake.GetRetryIntegrityTag(b, origDest, protocol.Version1)
	b = append(b, tag[:]...)
	return receivedPacket{
		rcvTime: time.Now(),
		data:    b,
		buffer:  getPacketBuffer(),
	}
}
// TestConnectionRetryDrops checks that a client drops two kinds of Retry
// packets without processing them and without returning an error:
//   - one whose integrity tag was corrupted (last byte modified):
//     PacketDropPayloadDecryptError;
//   - one whose source connection ID equals the current destination
//     connection ID: PacketDropUnexpectedPacket.
//
// No expectations are registered on the unpacker, so it must not be called.
func TestConnectionRetryDrops(t *testing.T) {
	mockCtrl := gomock.NewController(t)
	tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
	unpacker := NewMockUnpacker(mockCtrl)
	tc := newClientTestConnection(t,
		mockCtrl,
		nil,
		false,
		connectionOptTracer(tr),
		connectionOptUnpacker(unpacker),
	)
	newConnID := protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef})
	// invalid integrity tag
	tracer.EXPECT().DroppedPacket(logging.PacketTypeRetry, gomock.Any(), gomock.Any(), logging.PacketDropPayloadDecryptError)
	retry := getRetryPacket(t, newConnID, tc.srcConnID, tc.destConnID, []byte("foobar"))
	// the integrity tag is the last part of the packet; flip its final byte
	retry.data[len(retry.data)-1]++
	wasProcessed, err := tc.conn.handlePacketImpl(retry)
	require.NoError(t, err)
	require.False(t, wasProcessed)
	require.True(t, mockCtrl.Satisfied())
	// receive a retry that doesn't change the connection ID
	tracer.EXPECT().DroppedPacket(logging.PacketTypeRetry, gomock.Any(), gomock.Any(), logging.PacketDropUnexpectedPacket)
	retry = getRetryPacket(t, tc.destConnID, tc.srcConnID, tc.destConnID, []byte("foobar"))
	wasProcessed, err = tc.conn.handlePacketImpl(retry)
	require.NoError(t, err)
	require.False(t, wasProcessed)
}
// TestConnectionRetryAfterReceivedPacket checks that after a client has
// processed a regular packet (an Initial here), a subsequent Retry packet is
// dropped with PacketDropUnexpectedPacket and not processed.
func TestConnectionRetryAfterReceivedPacket(t *testing.T) {
	mockCtrl := gomock.NewController(t)
	tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
	unpacker := NewMockUnpacker(mockCtrl)
	tc := newClientTestConnection(t,
		mockCtrl,
		nil,
		false,
		connectionOptTracer(tr),
		connectionOptUnpacker(unpacker),
	)
	// receive a regular packet
	tracer.EXPECT().NegotiatedVersion(gomock.Any(), gomock.Any(), gomock.Any())
	tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
	regular := getPacketWithPacketType(t, tc.srcConnID, protocol.PacketTypeInitial, 200)
	unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).Return(
		&unpackedPacket{
			hdr:             &wire.ExtendedHeader{Header: wire.Header{Type: protocol.PacketTypeInitial}},
			encryptionLevel: protocol.EncryptionInitial,
		}, nil,
	)
	wasProcessed, err := tc.conn.handlePacketImpl(receivedPacket{
		data:    regular,
		buffer:  getPacketBuffer(),
		rcvTime: time.Now(),
	})
	require.NoError(t, err)
	require.True(t, wasProcessed)
	// receive a retry
	retry := getRetryPacket(t, tc.destConnID, tc.srcConnID, tc.destConnID, []byte("foobar"))
	tracer.EXPECT().DroppedPacket(logging.PacketTypeRetry, gomock.Any(), gomock.Any(), logging.PacketDropUnexpectedPacket)
	wasProcessed, err = tc.conn.handlePacketImpl(retry)
	require.NoError(t, err)
	require.False(t, wasProcessed)
}
// TestConnectionConnectionIDChanges runs the connection ID update scenario
// with and without a preceding Retry.
func TestConnectionConnectionIDChanges(t *testing.T) {
	variants := []struct {
		label string
		retry bool
	}{
		{label: "with retry", retry: true},
		{label: "without retry", retry: false},
	}
	for _, v := range variants {
		t.Run(v.label, func(t *testing.T) {
			testConnectionConnectionIDChanges(t, v.retry)
		})
	}
}
// testConnectionConnectionIDChanges checks how a client updates the connection
// ID it sends to:
//   - if sendRetry is set, a Retry packet first switches it to retryConnID and
//     its token is handed to the packer (SetToken);
//   - the first Initial from the server switches it to that packet's source
//     connection ID (newConnID);
//   - a second Initial with yet another source connection ID (newConnID2) is
//     dropped with PacketDropUnknownConnectionID and leaves it unchanged.
func testConnectionConnectionIDChanges(t *testing.T, sendRetry bool) {
	// makeInitialPacket serializes hdr and appends hdr.Length minus the packet
	// number length zero bytes (presumably because the Length field also
	// covers the packet number written by Append).
	makeInitialPacket := func(t *testing.T, hdr *wire.ExtendedHeader) []byte {
		t.Helper()
		data, err := hdr.Append(nil, protocol.Version1)
		require.NoError(t, err)
		data = append(data, make([]byte, hdr.Length-protocol.ByteCount(hdr.PacketNumberLen))...)
		return data
	}
	mockCtrl := gomock.NewController(t)
	tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
	unpacker := NewMockUnpacker(mockCtrl)
	tc := newClientTestConnection(t,
		mockCtrl,
		nil,
		false,
		connectionOptTracer(tr),
		connectionOptUnpacker(unpacker),
	)
	dstConnID := tc.destConnID
	// random bytes for three 10-byte connection IDs
	b := make([]byte, 3*10)
	rand.Read(b)
	newConnID := protocol.ParseConnectionID(b[:10])
	newConnID2 := protocol.ParseConnectionID(b[10:20])
	errChan := make(chan error, 1)
	go func() { errChan <- tc.conn.run() }()
	tracer.EXPECT().NegotiatedVersion(gomock.Any(), gomock.Any(), gomock.Any())
	tc.packer.EXPECT().PackCoalescedPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes()
	require.Equal(t, dstConnID, tc.conn.connIDManager.Get())
	var retryConnID protocol.ConnectionID
	if sendRetry {
		retryConnID = protocol.ParseConnectionID(b[20:30])
		// unbuffered: the tracer callback blocks until the test has received the header
		hdrChan := make(chan *wire.Header)
		tracer.EXPECT().ReceivedRetry(gomock.Any()).Do(func(hdr *wire.Header) { hdrChan <- hdr })
		tc.packer.EXPECT().SetToken([]byte("foobar"))
		tc.conn.handlePacket(getRetryPacket(t, retryConnID, tc.srcConnID, tc.destConnID, []byte("foobar")))
		select {
		case hdr := <-hdrChan:
			assert.Equal(t, retryConnID, hdr.SrcConnectionID)
			assert.Equal(t, []byte("foobar"), hdr.Token)
			require.Equal(t, retryConnID, tc.conn.connIDManager.Get())
		case <-time.After(time.Second):
			t.Fatal("timeout")
		}
	}
	// Send the first packet. The server changes the connection ID to newConnID.
	hdr1 := wire.ExtendedHeader{
		Header: wire.Header{
			SrcConnectionID:  newConnID,
			DestConnectionID: tc.srcConnID,
			Type:             protocol.PacketTypeInitial,
			Length:           200,
			Version:          protocol.Version1,
		},
		PacketNumber:    1,
		PacketNumberLen: protocol.PacketNumberLen2,
	}
	// same header, but with a different source connection ID
	hdr2 := hdr1
	hdr2.SrcConnectionID = newConnID2
	receivedFirst := make(chan struct{})
	gomock.InOrder(
		unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).Return(
			&unpackedPacket{
				hdr:             &hdr1,
				encryptionLevel: protocol.EncryptionInitial,
			}, nil,
		),
		tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(
			func(*wire.ExtendedHeader, protocol.ByteCount, protocol.ECN, []logging.Frame) { close(receivedFirst) },
		),
	)
	tc.conn.handlePacket(receivedPacket{data: makeInitialPacket(t, &hdr1), buffer: getPacketBuffer(), rcvTime: time.Now()})
	select {
	case <-receivedFirst:
		require.Equal(t, newConnID, tc.conn.connIDManager.Get())
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
	// Send the second packet. We refuse to accept it, because the connection ID is changed again.
	dropped := make(chan struct{})
	tracer.EXPECT().DroppedPacket(logging.PacketTypeInitial, gomock.Any(), gomock.Any(), logging.PacketDropUnknownConnectionID).Do(
		func(logging.PacketType, protocol.PacketNumber, protocol.ByteCount, logging.PacketDropReason) {
			close(dropped)
		},
	)
	tc.conn.handlePacket(receivedPacket{data: makeInitialPacket(t, &hdr2), buffer: getPacketBuffer(), rcvTime: time.Now()})
	select {
	case <-dropped:
		// the connection ID should not have changed
		require.Equal(t, newConnID, tc.conn.connIDManager.Get())
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
	// test teardown
	tracer.EXPECT().ClosedConnection(gomock.Any())
	tracer.EXPECT().Close()
	tc.connRunner.EXPECT().Remove(gomock.Any())
	tc.conn.destroy(nil)
	select {
	case err := <-errChan:
		require.NoError(t, err)
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
}
// When the connection is closed before sending the first packet,
// we don't send a CONNECTION_CLOSE.
// This can happen if there's something wrong with the tls.Config, and
// crypto/tls refuses to start the handshake.
//
// TestConnectionEarlyClose simulates this by closing the connection from
// within StartHandshake while sentFirstPacket is false, and expects run to
// return an error containing the close reason. No expectations are
// registered on tc.packer, so an attempt to pack a CONNECTION_CLOSE would
// presumably be reported as an unexpected mock call.
func TestConnectionEarlyClose(t *testing.T) {
	mockCtrl := gomock.NewController(t)
	tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
	cryptoSetup := mocks.NewMockCryptoSetup(mockCtrl)
	tc := newClientTestConnection(t,
		mockCtrl,
		nil,
		false,
		connectionOptTracer(tr),
		connectionOptCryptoSetup(cryptoSetup),
	)
	tc.conn.sentFirstPacket = false
	tracer.EXPECT().ClosedConnection(gomock.Any())
	tracer.EXPECT().Close()
	// the handshake start closes the connection before anything was sent
	cryptoSetup.EXPECT().StartHandshake(gomock.Any()).Do(func(context.Context) error {
		tc.conn.closeLocal(errors.New("early error"))
		return nil
	})
	cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent})
	cryptoSetup.EXPECT().Close()
	tc.connRunner.EXPECT().Remove(gomock.Any())
	errChan := make(chan error, 1)
	go func() { errChan <- tc.conn.run() }()
	select {
	case err := <-errChan:
		require.Error(t, err)
		require.ErrorContains(t, err, "early error")
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
}