refactor packet handling into a separate function (#4926)

This commit is contained in:
Marten Seemann 2025-01-26 06:17:40 +01:00 committed by GitHub
parent 19213b24bc
commit ac25c646ad
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 49 additions and 41 deletions

View file

@ -551,7 +551,7 @@ runLoop:
queue := s.undecryptablePacketsToProcess
s.undecryptablePacketsToProcess = nil
for _, p := range queue {
processed, err := s.handlePacketImpl(p)
processed, err := s.handleOnePacket(p)
if err != nil {
s.setCloseError(&closeError{err: err})
break runLoop
@ -575,34 +575,11 @@ runLoop:
// nothing to see here.
case <-sendQueueAvailable:
case firstPacket := <-s.receivedPackets:
wasProcessed, err := s.handlePacketImpl(firstPacket)
// Don't set timers and send packets if the packet made us close the connection.
wasProcessed, err := s.handlePackets(firstPacket)
if err != nil {
s.setCloseError(&closeError{err: err})
break runLoop
}
if s.handshakeComplete {
// Now process all packets in the receivedPackets channel.
// Limit the number of packets to the length of the receivedPackets channel,
// so we eventually get a chance to send out an ACK when receiving a lot of packets.
numPackets := len(s.receivedPackets)
receiveLoop:
for i := 0; i < numPackets; i++ {
select {
case p := <-s.receivedPackets:
processed, err := s.handlePacketImpl(p)
if err != nil {
s.setCloseError(&closeError{err: err})
break runLoop
}
if processed {
wasProcessed = true
}
default:
break receiveLoop
}
}
}
// Only reset the timers if this packet was actually processed.
// This avoids modifying any state when handling undecryptable packets,
// which could be injected by an attacker.
@ -807,7 +784,38 @@ func (s *connection) handleHandshakeConfirmed(now time.Time) error {
return nil
}
func (s *connection) handlePacketImpl(rp receivedPacket) (wasProcessed bool, _ error) {
func (s *connection) handlePackets(firstPacket receivedPacket) (wasProcessed bool, _ error) {
wasProcessed, err := s.handleOnePacket(firstPacket)
if err != nil {
return false, err
}
// only process a single packet at a time before handshake completion
if !s.handshakeComplete {
return wasProcessed, nil
}
// Now process all packets in the receivedPackets channel.
// Limit the number of packets to the length of the receivedPackets channel,
// so we eventually get a chance to send out an ACK when receiving a lot of packets.
numPackets := len(s.receivedPackets)
receiveLoop:
for i := 0; i < numPackets; i++ {
select {
case p := <-s.receivedPackets:
processed, err := s.handleOnePacket(p)
if err != nil {
return false, err
}
if processed {
wasProcessed = true
}
default:
break receiveLoop
}
}
return wasProcessed, nil
}
func (s *connection) handleOnePacket(rp receivedPacket) (wasProcessed bool, _ error) {
s.sentPacketHandler.ReceivedBytes(rp.Size(), rp.rcvTime)
if wire.IsVersionNegotiationPacket(rp.data) {

View file

@ -537,7 +537,7 @@ func TestConnectionServerInvalidPackets(t *testing.T) {
Token: []byte("foobar"),
}}, make([]byte, 16) /* Retry integrity tag */)
tracer.EXPECT().DroppedPacket(logging.PacketTypeRetry, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropUnexpectedPacket)
wasProcessed, err := tc.conn.handlePacketImpl(p)
wasProcessed, err := tc.conn.handleOnePacket(p)
require.NoError(t, err)
require.False(t, wasProcessed)
})
@ -553,7 +553,7 @@ func TestConnectionServerInvalidPackets(t *testing.T) {
[]Version{Version1},
)
tracer.EXPECT().DroppedPacket(logging.PacketTypeVersionNegotiation, protocol.InvalidPacketNumber, protocol.ByteCount(len(b)), logging.PacketDropUnexpectedPacket)
wasProcessed, err := tc.conn.handlePacketImpl(receivedPacket{data: b, buffer: getPacketBuffer()})
wasProcessed, err := tc.conn.handleOnePacket(receivedPacket{data: b, buffer: getPacketBuffer()})
require.NoError(t, err)
require.False(t, wasProcessed)
})
@ -568,7 +568,7 @@ func TestConnectionServerInvalidPackets(t *testing.T) {
PacketNumberLen: protocol.PacketNumberLen2,
}, nil)
tracer.EXPECT().DroppedPacket(logging.PacketTypeNotDetermined, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropUnsupportedVersion)
wasProcessed, err := tc.conn.handlePacketImpl(p)
wasProcessed, err := tc.conn.handleOnePacket(p)
require.NoError(t, err)
require.False(t, wasProcessed)
})
@ -584,7 +584,7 @@ func TestConnectionServerInvalidPackets(t *testing.T) {
}, nil)
p.data[0] ^= 0x40 // unset the QUIC bit
tracer.EXPECT().DroppedPacket(logging.PacketTypeNotDetermined, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropHeaderParseError)
wasProcessed, err := tc.conn.handlePacketImpl(p)
wasProcessed, err := tc.conn.handleOnePacket(p)
require.NoError(t, err)
require.False(t, wasProcessed)
})
@ -600,7 +600,7 @@ func TestConnectionClientDrop0RTT(t *testing.T) {
PacketNumberLen: protocol.PacketNumberLen2,
}, nil)
tracer.EXPECT().DroppedPacket(logging.PacketType0RTT, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropUnexpectedPacket)
wasProcessed, err := tc.conn.handlePacketImpl(p)
wasProcessed, err := tc.conn.handleOnePacket(p)
require.NoError(t, err)
require.False(t, wasProcessed)
}
@ -649,7 +649,7 @@ func TestConnectionUnpacking(t *testing.T) {
tracer.EXPECT().NegotiatedVersion(gomock.Any(), gomock.Any(), gomock.Any())
tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), gomock.Any(), logging.ECNCE, []logging.Frame{})
wasProcessed, err := tc.conn.handlePacketImpl(packet)
wasProcessed, err := tc.conn.handleOnePacket(packet)
require.NoError(t, err)
require.True(t, wasProcessed)
require.True(t, mockCtrl.Satisfied())
@ -663,7 +663,7 @@ func TestConnectionUnpacking(t *testing.T) {
data: []byte{0}, // one PADDING frame
}, nil)
tracer.EXPECT().DroppedPacket(logging.PacketTypeInitial, protocol.PacketNumber(0x1337), protocol.ByteCount(len(packet.data)), logging.PacketDropDuplicate)
wasProcessed, err = tc.conn.handlePacketImpl(packet)
wasProcessed, err = tc.conn.handleOnePacket(packet)
require.NoError(t, err)
require.False(t, wasProcessed)
require.True(t, mockCtrl.Satisfied())
@ -680,7 +680,7 @@ func TestConnectionUnpacking(t *testing.T) {
protocol.PacketNumber(0x1337), protocol.PacketNumberLen2, protocol.KeyPhaseZero, []byte{0} /* PADDING */, nil,
)
tracer.EXPECT().ReceivedShortHeaderPacket(gomock.Any(), gomock.Any(), logging.ECT1, []logging.Frame{})
wasProcessed, err = tc.conn.handlePacketImpl(packet)
wasProcessed, err = tc.conn.handleOnePacket(packet)
require.NoError(t, err)
require.True(t, wasProcessed)
}
@ -769,7 +769,7 @@ func TestConnectionUnpackCoalescedPacket(t *testing.T) {
tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), gomock.Any(), logging.ECT1, []logging.Frame{&wire.PingFrame{}}),
tracer.EXPECT().DroppedPacket(logging.PacketTypeNotDetermined, protocol.InvalidPacketNumber, protocol.ByteCount(len(packet3.data)), logging.PacketDropUnknownConnectionID),
)
wasProcessed, err := tc.conn.handlePacketImpl(packet)
wasProcessed, err := tc.conn.handleOnePacket(packet)
require.NoError(t, err)
require.True(t, wasProcessed)
}
@ -2498,7 +2498,7 @@ func TestConnectionVersionNegotiationInvalidPackets(t *testing.T) {
tc.srcConnID,
[]protocol.Version{1234, protocol.Version1},
)
wasProcessed, err := tc.conn.handlePacketImpl(vnp)
wasProcessed, err := tc.conn.handleOnePacket(vnp)
require.NoError(t, err)
require.False(t, wasProcessed)
require.True(t, mockCtrl.Satisfied())
@ -2506,7 +2506,7 @@ func TestConnectionVersionNegotiationInvalidPackets(t *testing.T) {
// unparseable, since it's missing 2 bytes
tracer.EXPECT().DroppedPacket(logging.PacketTypeVersionNegotiation, gomock.Any(), gomock.Any(), logging.PacketDropHeaderParseError)
vnp.data = vnp.data[:len(vnp.data)-2]
wasProcessed, err = tc.conn.handlePacketImpl(vnp)
wasProcessed, err = tc.conn.handleOnePacket(vnp)
require.NoError(t, err)
require.False(t, wasProcessed)
}
@ -2548,7 +2548,7 @@ func TestConnectionRetryDrops(t *testing.T) {
tracer.EXPECT().DroppedPacket(logging.PacketTypeRetry, gomock.Any(), gomock.Any(), logging.PacketDropPayloadDecryptError)
retry := getRetryPacket(t, newConnID, tc.srcConnID, tc.destConnID, []byte("foobar"))
retry.data[len(retry.data)-1]++
wasProcessed, err := tc.conn.handlePacketImpl(retry)
wasProcessed, err := tc.conn.handleOnePacket(retry)
require.NoError(t, err)
require.False(t, wasProcessed)
require.True(t, mockCtrl.Satisfied())
@ -2556,7 +2556,7 @@ func TestConnectionRetryDrops(t *testing.T) {
// receive a retry that doesn't change the connection ID
tracer.EXPECT().DroppedPacket(logging.PacketTypeRetry, gomock.Any(), gomock.Any(), logging.PacketDropUnexpectedPacket)
retry = getRetryPacket(t, tc.destConnID, tc.srcConnID, tc.destConnID, []byte("foobar"))
wasProcessed, err = tc.conn.handlePacketImpl(retry)
wasProcessed, err = tc.conn.handleOnePacket(retry)
require.NoError(t, err)
require.False(t, wasProcessed)
}
@ -2583,7 +2583,7 @@ func TestConnectionRetryAfterReceivedPacket(t *testing.T) {
encryptionLevel: protocol.EncryptionInitial,
}, nil,
)
wasProcessed, err := tc.conn.handlePacketImpl(receivedPacket{
wasProcessed, err := tc.conn.handleOnePacket(receivedPacket{
data: regular,
buffer: getPacketBuffer(),
rcvTime: time.Now(),
@ -2594,7 +2594,7 @@ func TestConnectionRetryAfterReceivedPacket(t *testing.T) {
// receive a retry
retry := getRetryPacket(t, tc.destConnID, tc.srcConnID, tc.destConnID, []byte("foobar"))
tracer.EXPECT().DroppedPacket(logging.PacketTypeRetry, gomock.Any(), gomock.Any(), logging.PacketDropUnexpectedPacket)
wasProcessed, err = tc.conn.handlePacketImpl(retry)
wasProcessed, err = tc.conn.handleOnePacket(retry)
require.NoError(t, err)
require.False(t, wasProcessed)
}