diff --git a/connection.go b/connection.go index aa5cec91..5afc7c6b 100644 --- a/connection.go +++ b/connection.go @@ -2096,7 +2096,7 @@ func (s *connection) sendProbePacket(sendMode ackhandler.SendMode, now time.Time break } var err error - packet, err = s.packer.MaybePackProbePacket(encLevel, s.maxPacketSize(), now, s.version) + packet, err = s.packer.MaybePackPTOProbePacket(encLevel, s.maxPacketSize(), now, s.version) if err != nil { return err } @@ -2107,7 +2107,7 @@ func (s *connection) sendProbePacket(sendMode ackhandler.SendMode, now time.Time if packet == nil { s.retransmissionQueue.AddPing(encLevel) var err error - packet, err = s.packer.MaybePackProbePacket(encLevel, s.maxPacketSize(), now, s.version) + packet, err = s.packer.MaybePackPTOProbePacket(encLevel, s.maxPacketSize(), now, s.version) if err != nil { return err } diff --git a/connection_test.go b/connection_test.go index 16cdafe4..331fe0c7 100644 --- a/connection_test.go +++ b/connection_test.go @@ -2199,7 +2199,7 @@ func testConnectionPTOProbePackets(t *testing.T, encLevel protocol.EncryptionLev sph.EXPECT().QueueProbePacket(encLevel).Return(false) sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) - tc.packer.EXPECT().MaybePackProbePacket(encLevel, gomock.Any(), gomock.Any(), protocol.Version1).DoAndReturn( + tc.packer.EXPECT().MaybePackPTOProbePacket(encLevel, gomock.Any(), gomock.Any(), protocol.Version1).DoAndReturn( func(encLevel protocol.EncryptionLevel, maxSize protocol.ByteCount, t time.Time, version protocol.Version) (*coalescedPacket, error) { return &coalescedPacket{ buffer: getPacketBuffer(), diff --git a/mock_packer_test.go b/mock_packer_test.go index 674a560e..3bdfb05e 100644 --- a/mock_packer_test.go +++ b/mock_packer_test.go @@ -82,41 +82,41 @@ func (c *MockPackerAppendPacketCall) DoAndReturn(f func(*packetBuffer, protocol. return c } -// MaybePackProbePacket mocks base method. -func (m *MockPacker) MaybePackProbePacket(arg0 protocol.EncryptionLevel, arg1 protocol.ByteCount, arg2 time.Time, arg3 protocol.Version) (*coalescedPacket, error) { +// MaybePackPTOProbePacket mocks base method. +func (m *MockPacker) MaybePackPTOProbePacket(arg0 protocol.EncryptionLevel, arg1 protocol.ByteCount, arg2 time.Time, arg3 protocol.Version) (*coalescedPacket, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "MaybePackProbePacket", arg0, arg1, arg2, arg3) + ret := m.ctrl.Call(m, "MaybePackPTOProbePacket", arg0, arg1, arg2, arg3) ret0, _ := ret[0].(*coalescedPacket) ret1, _ := ret[1].(error) return ret0, ret1 } -// MaybePackProbePacket indicates an expected call of MaybePackProbePacket. -func (mr *MockPackerMockRecorder) MaybePackProbePacket(arg0, arg1, arg2, arg3 any) *MockPackerMaybePackProbePacketCall { +// MaybePackPTOProbePacket indicates an expected call of MaybePackPTOProbePacket. +func (mr *MockPackerMockRecorder) MaybePackPTOProbePacket(arg0, arg1, arg2, arg3 any) *MockPackerMaybePackPTOProbePacketCall { mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MaybePackProbePacket", reflect.TypeOf((*MockPacker)(nil).MaybePackProbePacket), arg0, arg1, arg2, arg3) - return &MockPackerMaybePackProbePacketCall{Call: call} + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MaybePackPTOProbePacket", reflect.TypeOf((*MockPacker)(nil).MaybePackPTOProbePacket), arg0, arg1, arg2, arg3) + return &MockPackerMaybePackPTOProbePacketCall{Call: call} } -// MockPackerMaybePackProbePacketCall wrap *gomock.Call -type MockPackerMaybePackProbePacketCall struct { +// MockPackerMaybePackPTOProbePacketCall wrap *gomock.Call +type MockPackerMaybePackPTOProbePacketCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return -func (c *MockPackerMaybePackProbePacketCall) Return(arg0 *coalescedPacket, arg1 error) *MockPackerMaybePackProbePacketCall { +func (c *MockPackerMaybePackPTOProbePacketCall) Return(arg0 *coalescedPacket, arg1 error) *MockPackerMaybePackPTOProbePacketCall { c.Call = c.Call.Return(arg0, arg1) return c } // Do rewrite *gomock.Call.Do -func (c *MockPackerMaybePackProbePacketCall) Do(f func(protocol.EncryptionLevel, protocol.ByteCount, time.Time, protocol.Version) (*coalescedPacket, error)) *MockPackerMaybePackProbePacketCall { +func (c *MockPackerMaybePackPTOProbePacketCall) Do(f func(protocol.EncryptionLevel, protocol.ByteCount, time.Time, protocol.Version) (*coalescedPacket, error)) *MockPackerMaybePackPTOProbePacketCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockPackerMaybePackProbePacketCall) DoAndReturn(f func(protocol.EncryptionLevel, protocol.ByteCount, time.Time, protocol.Version) (*coalescedPacket, error)) *MockPackerMaybePackProbePacketCall { +func (c *MockPackerMaybePackPTOProbePacketCall) DoAndReturn(f func(protocol.EncryptionLevel, protocol.ByteCount, time.Time, protocol.Version) (*coalescedPacket, error)) *MockPackerMaybePackPTOProbePacketCall { c.Call = c.Call.DoAndReturn(f) return c } @@ -318,6 +318,46 @@ func (c *MockPackerPackMTUProbePacketCall) DoAndReturn(f func(ackhandler.Frame, return c } +// PackPathProbePacket mocks base method. +func (m *MockPacker) PackPathProbePacket(arg0 protocol.ConnectionID, arg1 ackhandler.Frame, arg2 protocol.Version) (shortHeaderPacket, *packetBuffer, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "PackPathProbePacket", arg0, arg1, arg2) + ret0, _ := ret[0].(shortHeaderPacket) + ret1, _ := ret[1].(*packetBuffer) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// PackPathProbePacket indicates an expected call of PackPathProbePacket. +func (mr *MockPackerMockRecorder) PackPathProbePacket(arg0, arg1, arg2 any) *MockPackerPackPathProbePacketCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PackPathProbePacket", reflect.TypeOf((*MockPacker)(nil).PackPathProbePacket), arg0, arg1, arg2) + return &MockPackerPackPathProbePacketCall{Call: call} +} + +// MockPackerPackPathProbePacketCall wrap *gomock.Call +type MockPackerPackPathProbePacketCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockPackerPackPathProbePacketCall) Return(arg0 shortHeaderPacket, arg1 *packetBuffer, arg2 error) *MockPackerPackPathProbePacketCall { + c.Call = c.Call.Return(arg0, arg1, arg2) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockPackerPackPathProbePacketCall) Do(f func(protocol.ConnectionID, ackhandler.Frame, protocol.Version) (shortHeaderPacket, *packetBuffer, error)) *MockPackerPackPathProbePacketCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockPackerPackPathProbePacketCall) DoAndReturn(f func(protocol.ConnectionID, ackhandler.Frame, protocol.Version) (shortHeaderPacket, *packetBuffer, error)) *MockPackerPackPathProbePacketCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + // SetToken mocks base method. func (m *MockPacker) SetToken(arg0 []byte) { m.ctrl.T.Helper() diff --git a/packet_packer.go b/packet_packer.go index 7724b503..720f1958 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -22,9 +22,10 @@ type packer interface { PackCoalescedPacket(onlyAck bool, maxPacketSize protocol.ByteCount, now time.Time, v protocol.Version) (*coalescedPacket, error) PackAckOnlyPacket(maxPacketSize protocol.ByteCount, now time.Time, v protocol.Version) (shortHeaderPacket, *packetBuffer, error) AppendPacket(buf *packetBuffer, maxPacketSize protocol.ByteCount, now time.Time, v protocol.Version) (shortHeaderPacket, error) - MaybePackProbePacket(protocol.EncryptionLevel, protocol.ByteCount, time.Time, protocol.Version) (*coalescedPacket, error) + MaybePackPTOProbePacket(protocol.EncryptionLevel, protocol.ByteCount, time.Time, protocol.Version) (*coalescedPacket, error) PackConnectionClose(*qerr.TransportError, protocol.ByteCount, protocol.Version) (*coalescedPacket, error) PackApplicationClose(*qerr.ApplicationError, protocol.ByteCount, protocol.Version) (*coalescedPacket, error) + PackPathProbePacket(protocol.ConnectionID, ackhandler.Frame, protocol.Version) (shortHeaderPacket, *packetBuffer, error) PackMTUProbePacket(ping ackhandler.Frame, size protocol.ByteCount, v protocol.Version) (shortHeaderPacket, *packetBuffer, error) SetToken([]byte) @@ -57,6 +58,7 @@ type shortHeaderPacket struct { Ack *wire.AckFrame Length protocol.ByteCount IsPathMTUProbePacket bool + IsPathProbePacket bool // used for logging DestConnID protocol.ConnectionID @@ -269,17 +271,17 @@ func (p *packetPacker) packConnectionClose( if sealers[i] == nil { continue } - var paddingLen protocol.ByteCount - if encLevel == protocol.EncryptionInitial { - paddingLen = p.initialPaddingLen(payloads[i].frames, size, maxPacketSize) - } if encLevel == protocol.Encryption1RTT { - shp, err := p.appendShortHeaderPacket(buffer, connID, oneRTTPacketNumber, oneRTTPacketNumberLen, keyPhase, payloads[i], paddingLen, maxPacketSize, sealers[i], false, v) + shp, err := p.appendShortHeaderPacket(buffer, connID, oneRTTPacketNumber, oneRTTPacketNumberLen, keyPhase, payloads[i], 0, maxPacketSize, sealers[i], false, v) if err != nil { return nil, err } packet.shortHdrPacket = &shp } else { + var paddingLen protocol.ByteCount + if encLevel == protocol.EncryptionInitial { + paddingLen = p.initialPaddingLen(payloads[i].frames, size, maxPacketSize) + } longHdrPacket, err := p.appendLongHeaderPacket(buffer, hdrs[i], payloads[i], paddingLen, encLevel, sealers[i], v) if err != nil { return nil, err @@ -707,7 +709,7 @@ func (p *packetPacker) composeNextPacket( return pl } -func (p *packetPacker) MaybePackProbePacket( +func (p *packetPacker) MaybePackPTOProbePacket( encLevel protocol.EncryptionLevel, maxPacketSize protocol.ByteCount, now time.Time, @@ -792,6 +794,26 @@ func (p *packetPacker) PackMTUProbePacket(ping ackhandler.Frame, size protocol.B return packet, buffer, err } +func (p *packetPacker) PackPathProbePacket(connID protocol.ConnectionID, f ackhandler.Frame, v protocol.Version) (shortHeaderPacket, *packetBuffer, error) { + pn, pnLen := p.pnManager.PeekPacketNumber(protocol.Encryption1RTT) + buf := getPacketBuffer() + s, err := p.cryptoSetup.Get1RTTSealer() + if err != nil { + return shortHeaderPacket{}, nil, err + } + payload := payload{ + frames: []ackhandler.Frame{f}, + length: f.Frame.Length(v), + } + padding := protocol.MinInitialPacketSize - p.shortHeaderPacketLength(connID, pnLen, payload) - protocol.ByteCount(s.Overhead()) + packet, err := p.appendShortHeaderPacket(buf, connID, pn, pnLen, s.KeyPhase(), payload, padding, protocol.MinInitialPacketSize, s, false, v) + if err != nil { + return shortHeaderPacket{}, nil, err + } + packet.IsPathProbePacket = true + return packet, buf, err +} + func (p *packetPacker) getLongHeader(encLevel protocol.EncryptionLevel, v protocol.Version) *wire.ExtendedHeader { pn, pnLen := p.pnManager.PeekPacketNumber(encLevel) hdr := &wire.ExtendedHeader{ diff --git a/packet_packer_test.go b/packet_packer_test.go index 89f332cb..4c252517 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -809,7 +809,7 @@ func testPackProbePacket(t *testing.T, encLevel protocol.EncryptionLevel, perspe tp.pnManager.EXPECT().PeekPacketNumber(encLevel).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) tp.pnManager.EXPECT().PopPacketNumber(encLevel).Return(protocol.PacketNumber(0x42)) - p, err := tp.packer.MaybePackProbePacket(encLevel, maxPacketSize, time.Now(), protocol.Version1) + p, err := tp.packer.MaybePackPTOProbePacket(encLevel, maxPacketSize, time.Now(), protocol.Version1) require.NoError(t, err) require.NotNil(t, p) require.Len(t, p.longHdrPackets, 1) @@ -838,7 +838,7 @@ func TestPackProbePacketNothingToSend(t *testing.T) { tp.sealingManager.EXPECT().GetInitialSealer().Return(newMockShortHeaderSealer(mockCtrl), nil) tp.ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial, gomock.Any(), true) - p, err := tp.packer.MaybePackProbePacket(protocol.EncryptionInitial, protocol.MaxByteCount, time.Now(), protocol.Version1) + p, err := tp.packer.MaybePackPTOProbePacket(protocol.EncryptionInitial, protocol.MaxByteCount, time.Now(), protocol.Version1) require.NoError(t, err) require.Nil(t, p) } @@ -861,7 +861,7 @@ func TestPack1RTTProbePacket(t *testing.T) { }, ) - p, err := tp.packer.MaybePackProbePacket(protocol.Encryption1RTT, maxPacketSize, time.Now(), protocol.Version1) + p, err := tp.packer.MaybePackPTOProbePacket(protocol.Encryption1RTT, maxPacketSize, time.Now(), protocol.Version1) require.NoError(t, err) require.NotNil(t, p) require.True(t, p.IsOnlyShortHeaderPacket()) @@ -882,7 +882,7 @@ func TestPackProbePacketNothingToPack(t *testing.T) { tp.ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, gomock.Any(), true) tp.framer.EXPECT().HasData() - packet, err := tp.packer.MaybePackProbePacket(protocol.Encryption1RTT, protocol.MaxByteCount, time.Now(), protocol.Version1) + packet, err := tp.packer.MaybePackPTOProbePacket(protocol.Encryption1RTT, protocol.MaxByteCount, time.Now(), protocol.Version1) require.NoError(t, err) require.Nil(t, packet) } @@ -905,4 +905,27 @@ func TestPackMTUProbePacket(t *testing.T) { require.Equal(t, protocol.PacketNumber(0x43), p.PacketNumber) require.Len(t, buffer.Data, int(probePacketSize)) require.True(t, p.IsPathMTUProbePacket) + require.False(t, p.IsPathProbePacket) +} + +func TestPackPathProbePacket(t *testing.T) { + mockCtrl := gomock.NewController(t) + tp := newTestPacketPacker(t, mockCtrl, protocol.PerspectiveServer) + tp.sealingManager.EXPECT().Get1RTTSealer().Return(newMockShortHeaderSealer(mockCtrl), nil) + tp.pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x43), protocol.PacketNumberLen2) + tp.pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x43)) + + p, buf, err := tp.packer.PackPathProbePacket( + protocol.ParseConnectionID([]byte{1, 2, 3, 4}), + ackhandler.Frame{Frame: &wire.PathChallengeFrame{Data: [8]byte{1, 2, 3, 4, 5, 6, 7, 8}}}, + protocol.Version1, + ) + require.NoError(t, err) + require.Equal(t, protocol.PacketNumber(0x43), p.PacketNumber) + require.Nil(t, p.Ack) + require.Empty(t, p.StreamFrames) + require.Equal(t, &wire.PathChallengeFrame{Data: [8]byte{1, 2, 3, 4, 5, 6, 7, 8}}, p.Frames[0].Frame) + require.Len(t, buf.Data, protocol.MinInitialPacketSize) + require.True(t, p.IsPathProbePacket) + require.False(t, p.IsPathMTUProbePacket) }