simplify handling of packet unpacking errors (#4924)

This commit is contained in:
Marten Seemann 2025-01-26 05:43:26 +01:00 committed by GitHub
parent f20b823154
commit 383b634df6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 125 additions and 76 deletions

View file

@ -536,7 +536,8 @@ func (s *connection) run() error {
runLoop:
for {
if s.framer.QueuedTooManyControlFrames() {
s.closeLocal(&qerr.TransportError{ErrorCode: InternalError})
closeErr = closeError{err: &qerr.TransportError{ErrorCode: InternalError}}
break runLoop
}
// Close immediately if requested
select {
@ -552,14 +553,13 @@ runLoop:
queue := s.undecryptablePacketsToProcess
s.undecryptablePacketsToProcess = nil
for _, p := range queue {
if processed := s.handlePacketImpl(p); processed {
processedUndecryptablePacket = true
}
// Don't set timers and send packets if the packet made us close the connection.
select {
case closeErr = <-s.closeChan:
processed, err := s.handlePacketImpl(p)
if err != nil {
closeErr = closeError{err: err}
break runLoop
default:
}
if processed {
processedUndecryptablePacket = true
}
}
}
@ -577,8 +577,12 @@ runLoop:
// nothing to see here.
case <-sendQueueAvailable:
case firstPacket := <-s.receivedPackets:
wasProcessed := s.handlePacketImpl(firstPacket)
wasProcessed, err := s.handlePacketImpl(firstPacket)
// Don't set timers and send packets if the packet made us close the connection.
if err != nil {
closeErr = closeError{err: err}
break runLoop
}
select {
case closeErr = <-s.closeChan:
break runLoop
@ -593,7 +597,12 @@ runLoop:
for i := 0; i < numPackets; i++ {
select {
case p := <-s.receivedPackets:
if processed := s.handlePacketImpl(p); processed {
processed, err := s.handlePacketImpl(p)
if err != nil {
closeErr = closeError{err: err}
break runLoop
}
if processed {
wasProcessed = true
}
select {
@ -620,7 +629,8 @@ runLoop:
// This could cause packets to be retransmitted.
// Check it before trying to send packets.
if err := s.sentPacketHandler.OnLossDetectionTimeout(now); err != nil {
s.closeLocal(err)
closeErr = closeError{err: err}
break runLoop
}
}
@ -648,7 +658,8 @@ runLoop:
continue
}
if err := s.triggerSending(now); err != nil {
s.closeLocal(err)
closeErr = closeError{err: err}
break runLoop
}
if s.sendQueue.WouldBlock() {
sendQueueAvailable = s.sendQueue.Available()
@ -802,17 +813,16 @@ func (s *connection) handleHandshakeConfirmed(now time.Time) error {
return nil
}
func (s *connection) handlePacketImpl(rp receivedPacket) bool {
func (s *connection) handlePacketImpl(rp receivedPacket) (wasProcessed bool, _ error) {
s.sentPacketHandler.ReceivedBytes(rp.Size(), rp.rcvTime)
if wire.IsVersionNegotiationPacket(rp.data) {
s.handleVersionNegotiationPacket(rp)
return false
return false, nil
}
var counter uint8
var lastConnID protocol.ConnectionID
var processed bool
data := rp.data
p := rp
for len(data) > 0 {
@ -872,26 +882,34 @@ func (s *connection) handlePacketImpl(rp receivedPacket) bool {
p.data = packetData
if wasProcessed := s.handleLongHeaderPacket(p, hdr); wasProcessed {
processed = true
processed, err := s.handleLongHeaderPacket(p, hdr)
if err != nil {
return false, err
}
if processed {
wasProcessed = true
}
data = rest
} else {
if counter > 0 {
p.buffer.Split()
}
if wasProcessed := s.handleShortHeaderPacket(p); wasProcessed {
processed = true
processed, err := s.handleShortHeaderPacket(p)
if err != nil {
return false, err
}
if processed {
wasProcessed = true
}
break
}
}
p.buffer.MaybeRelease()
return processed
return wasProcessed, nil
}
func (s *connection) handleShortHeaderPacket(p receivedPacket) bool {
func (s *connection) handleShortHeaderPacket(p receivedPacket) (wasProcessed bool, _ error) {
var wasQueued bool
defer func() {
@ -904,12 +922,12 @@ func (s *connection) handleShortHeaderPacket(p receivedPacket) bool {
destConnID, err := wire.ParseConnectionID(p.data, s.srcConnIDLen)
if err != nil {
s.tracer.DroppedPacket(logging.PacketType1RTT, protocol.InvalidPacketNumber, protocol.ByteCount(len(p.data)), logging.PacketDropHeaderParseError)
return false
return false, nil
}
pn, pnLen, keyPhase, data, err := s.unpacker.UnpackShortHeader(p.rcvTime, p.data)
if err != nil {
wasQueued = s.handleUnpackError(err, p, logging.PacketType1RTT)
return false
wasQueued, err = s.handleUnpackError(err, p, logging.PacketType1RTT)
return false, err
}
if s.logger.Debug() {
@ -922,7 +940,7 @@ func (s *connection) handleShortHeaderPacket(p receivedPacket) bool {
if s.tracer != nil && s.tracer.DroppedPacket != nil {
s.tracer.DroppedPacket(logging.PacketType1RTT, pn, p.Size(), logging.PacketDropDuplicate)
}
return false
return false, nil
}
var log func([]logging.Frame)
@ -942,13 +960,12 @@ func (s *connection) handleShortHeaderPacket(p receivedPacket) bool {
}
}
if err := s.handleUnpackedShortHeaderPacket(destConnID, pn, data, p.ecn, p.rcvTime, log); err != nil {
s.closeLocal(err)
return false
return false, err
}
return true
return true, nil
}
func (s *connection) handleLongHeaderPacket(p receivedPacket, hdr *wire.Header) bool /* was the packet successfully processed */ {
func (s *connection) handleLongHeaderPacket(p receivedPacket, hdr *wire.Header) (wasProcessed bool, _ error) {
var wasQueued bool
defer func() {
@ -959,7 +976,7 @@ func (s *connection) handleLongHeaderPacket(p receivedPacket, hdr *wire.Header)
}()
if hdr.Type == protocol.PacketTypeRetry {
return s.handleRetryPacket(hdr, p.data, p.rcvTime)
return s.handleRetryPacket(hdr, p.data, p.rcvTime), nil
}
// The server can change the source connection ID with the first Handshake packet.
@ -969,20 +986,20 @@ func (s *connection) handleLongHeaderPacket(p receivedPacket, hdr *wire.Header)
s.tracer.DroppedPacket(logging.PacketTypeInitial, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropUnknownConnectionID)
}
s.logger.Debugf("Dropping Initial packet (%d bytes) with unexpected source connection ID: %s (expected %s)", p.Size(), hdr.SrcConnectionID, s.handshakeDestConnID)
return false
return false, nil
}
// drop 0-RTT packets, if we are a client
if s.perspective == protocol.PerspectiveClient && hdr.Type == protocol.PacketType0RTT {
if s.tracer != nil && s.tracer.DroppedPacket != nil {
s.tracer.DroppedPacket(logging.PacketType0RTT, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropUnexpectedPacket)
}
return false
return false, nil
}
packet, err := s.unpacker.UnpackLongHeader(hdr, p.data)
if err != nil {
wasQueued = s.handleUnpackError(err, p, logging.PacketTypeFromHeader(hdr))
return false
wasQueued, err = s.handleUnpackError(err, p, logging.PacketTypeFromHeader(hdr))
return false, err
}
if s.logger.Debug() {
@ -995,39 +1012,40 @@ func (s *connection) handleLongHeaderPacket(p receivedPacket, hdr *wire.Header)
if s.tracer != nil && s.tracer.DroppedPacket != nil {
s.tracer.DroppedPacket(logging.PacketTypeFromHeader(hdr), pn, p.Size(), logging.PacketDropDuplicate)
}
return false
return false, nil
}
if err := s.handleUnpackedLongHeaderPacket(packet, p.ecn, p.rcvTime, p.Size()); err != nil {
s.closeLocal(err)
return false
return false, err
}
return true
return true, nil
}
func (s *connection) handleUnpackError(err error, p receivedPacket, pt logging.PacketType) (wasQueued bool) {
func (s *connection) handleUnpackError(err error, p receivedPacket, pt logging.PacketType) (wasQueued bool, _ error) {
switch err {
case handshake.ErrKeysDropped:
if s.tracer != nil && s.tracer.DroppedPacket != nil {
s.tracer.DroppedPacket(pt, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropKeyUnavailable)
}
s.logger.Debugf("Dropping %s packet (%d bytes) because we already dropped the keys.", pt, p.Size())
return false, nil
case handshake.ErrKeysNotYetAvailable:
// Sealer for this encryption level not yet available.
// Try again later.
s.tryQueueingUndecryptablePacket(p, pt)
return true
return true, nil
case wire.ErrInvalidReservedBits:
s.closeLocal(&qerr.TransportError{
return false, &qerr.TransportError{
ErrorCode: qerr.ProtocolViolation,
ErrorMessage: err.Error(),
})
}
case handshake.ErrDecryptionFailed:
// This might be a packet injected by an attacker. Drop it.
if s.tracer != nil && s.tracer.DroppedPacket != nil {
s.tracer.DroppedPacket(pt, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropPayloadDecryptError)
}
s.logger.Debugf("Dropping %s packet (%d bytes) that could not be unpacked. Error: %s", pt, p.Size(), err)
return false, nil
default:
var headerErr *headerParseError
if errors.As(err, &headerErr) {
@ -1036,13 +1054,12 @@ func (s *connection) handleUnpackError(err error, p receivedPacket, pt logging.P
s.tracer.DroppedPacket(pt, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropHeaderParseError)
}
s.logger.Debugf("Dropping %s packet (%d bytes) for which we couldn't unpack the header. Error: %s", pt, p.Size(), err)
} else {
// This is an error returned by the AEAD (other than ErrDecryptionFailed).
// For example, a PROTOCOL_VIOLATION due to key updates.
s.closeLocal(err)
return false, nil
}
// This is an error returned by the AEAD (other than ErrDecryptionFailed).
// For example, a PROTOCOL_VIOLATION due to key updates.
return false, err
}
return false
}
func (s *connection) handleRetryPacket(hdr *wire.Header, data []byte, rcvTime time.Time) bool /* was this a valid Retry */ {
@ -1585,11 +1602,6 @@ func (s *connection) handleDatagramFrame(f *wire.DatagramFrame) error {
// closeLocal closes the connection and send a CONNECTION_CLOSE containing the error
func (s *connection) closeLocal(e error) {
s.closeOnce.Do(func() {
if e == nil {
s.logger.Infof("Closing connection.")
} else {
s.logger.Errorf("Closing connection with error: %s", e)
}
s.closeChan <- closeError{err: e, immediate: false, remote: false}
})
}
@ -1602,11 +1614,6 @@ func (s *connection) destroy(e error) {
func (s *connection) destroyImpl(e error) {
s.closeOnce.Do(func() {
if nerr, ok := e.(net.Error); ok && nerr.Timeout() {
s.logger.Errorf("Destroying connection: %s", e)
} else {
s.logger.Errorf("Destroying connection with error: %s", e)
}
s.closeChan <- closeError{err: e, immediate: true, remote: false}
})
}
@ -1633,13 +1640,25 @@ func (s *connection) closeWithTransportError(code TransportErrorCode) {
}
func (s *connection) handleCloseError(closeErr *closeError) {
if closeErr.immediate {
if nerr, ok := closeErr.err.(net.Error); ok && nerr.Timeout() {
s.logger.Errorf("Destroying connection: %s", closeErr.err)
} else {
s.logger.Errorf("Destroying connection with error: %s", closeErr.err)
}
} else {
if closeErr.err == nil {
s.logger.Infof("Closing connection.")
} else {
s.logger.Errorf("Closing connection with error: %s", closeErr.err)
}
}
e := closeErr.err
if e == nil {
e = &qerr.ApplicationError{}
} else {
defer func() {
closeErr.err = e
}()
defer func() { closeErr.err = e }()
}
var (

View file

@ -537,7 +537,9 @@ func TestConnectionServerInvalidPackets(t *testing.T) {
Token: []byte("foobar"),
}}, make([]byte, 16) /* Retry integrity tag */)
tracer.EXPECT().DroppedPacket(logging.PacketTypeRetry, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropUnexpectedPacket)
require.False(t, tc.conn.handlePacketImpl(p))
wasProcessed, err := tc.conn.handlePacketImpl(p)
require.NoError(t, err)
require.False(t, wasProcessed)
})
t.Run("version negotiation", func(t *testing.T) {
@ -551,7 +553,9 @@ func TestConnectionServerInvalidPackets(t *testing.T) {
[]Version{Version1},
)
tracer.EXPECT().DroppedPacket(logging.PacketTypeVersionNegotiation, protocol.InvalidPacketNumber, protocol.ByteCount(len(b)), logging.PacketDropUnexpectedPacket)
require.False(t, tc.conn.handlePacketImpl(receivedPacket{data: b, buffer: getPacketBuffer()}))
wasProcessed, err := tc.conn.handlePacketImpl(receivedPacket{data: b, buffer: getPacketBuffer()})
require.NoError(t, err)
require.False(t, wasProcessed)
})
t.Run("unsupported version", func(t *testing.T) {
@ -564,7 +568,9 @@ func TestConnectionServerInvalidPackets(t *testing.T) {
PacketNumberLen: protocol.PacketNumberLen2,
}, nil)
tracer.EXPECT().DroppedPacket(logging.PacketTypeNotDetermined, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropUnsupportedVersion)
require.False(t, tc.conn.handlePacketImpl(p))
wasProcessed, err := tc.conn.handlePacketImpl(p)
require.NoError(t, err)
require.False(t, wasProcessed)
})
t.Run("invalid header", func(t *testing.T) {
@ -578,7 +584,9 @@ func TestConnectionServerInvalidPackets(t *testing.T) {
}, nil)
p.data[0] ^= 0x40 // unset the QUIC bit
tracer.EXPECT().DroppedPacket(logging.PacketTypeNotDetermined, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropHeaderParseError)
require.False(t, tc.conn.handlePacketImpl(p))
wasProcessed, err := tc.conn.handlePacketImpl(p)
require.NoError(t, err)
require.False(t, wasProcessed)
})
}
@ -592,7 +600,9 @@ func TestConnectionClientDrop0RTT(t *testing.T) {
PacketNumberLen: protocol.PacketNumberLen2,
}, nil)
tracer.EXPECT().DroppedPacket(logging.PacketType0RTT, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropUnexpectedPacket)
require.False(t, tc.conn.handlePacketImpl(p))
wasProcessed, err := tc.conn.handlePacketImpl(p)
require.NoError(t, err)
require.False(t, wasProcessed)
}
func TestConnectionUnpacking(t *testing.T) {
@ -639,7 +649,9 @@ func TestConnectionUnpacking(t *testing.T) {
tracer.EXPECT().NegotiatedVersion(gomock.Any(), gomock.Any(), gomock.Any())
tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), gomock.Any(), logging.ECNCE, []logging.Frame{})
require.True(t, tc.conn.handlePacketImpl(packet))
wasProcessed, err := tc.conn.handlePacketImpl(packet)
require.NoError(t, err)
require.True(t, wasProcessed)
require.True(t, mockCtrl.Satisfied())
// receive a duplicate of this packet
@ -651,7 +663,9 @@ func TestConnectionUnpacking(t *testing.T) {
data: []byte{0}, // one PADDING frame
}, nil)
tracer.EXPECT().DroppedPacket(logging.PacketTypeInitial, protocol.PacketNumber(0x1337), protocol.ByteCount(len(packet.data)), logging.PacketDropDuplicate)
require.False(t, tc.conn.handlePacketImpl(packet))
wasProcessed, err = tc.conn.handlePacketImpl(packet)
require.NoError(t, err)
require.False(t, wasProcessed)
require.True(t, mockCtrl.Satisfied())
// receive a short header packet
@ -666,7 +680,9 @@ func TestConnectionUnpacking(t *testing.T) {
protocol.PacketNumber(0x1337), protocol.PacketNumberLen2, protocol.KeyPhaseZero, []byte{0} /* PADDING */, nil,
)
tracer.EXPECT().ReceivedShortHeaderPacket(gomock.Any(), gomock.Any(), logging.ECT1, []logging.Frame{})
require.True(t, tc.conn.handlePacketImpl(packet))
wasProcessed, err = tc.conn.handlePacketImpl(packet)
require.NoError(t, err)
require.True(t, wasProcessed)
}
func TestConnectionUnpackCoalescedPacket(t *testing.T) {
@ -753,7 +769,9 @@ func TestConnectionUnpackCoalescedPacket(t *testing.T) {
tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), gomock.Any(), logging.ECT1, []logging.Frame{&wire.PingFrame{}}),
tracer.EXPECT().DroppedPacket(logging.PacketTypeNotDetermined, protocol.InvalidPacketNumber, protocol.ByteCount(len(packet3.data)), logging.PacketDropUnknownConnectionID),
)
require.True(t, tc.conn.handlePacketImpl(packet))
wasProcessed, err := tc.conn.handlePacketImpl(packet)
require.NoError(t, err)
require.True(t, wasProcessed)
}
func TestConnectionUnpackFailuresFatal(t *testing.T) {
@ -2472,13 +2490,17 @@ func TestConnectionVersionNegotiationInvalidPackets(t *testing.T) {
tc.srcConnID,
[]protocol.Version{1234, protocol.Version1},
)
require.False(t, tc.conn.handlePacketImpl(vnp))
wasProcessed, err := tc.conn.handlePacketImpl(vnp)
require.NoError(t, err)
require.False(t, wasProcessed)
require.True(t, mockCtrl.Satisfied())
// unparseable, since it's missing 2 bytes
tracer.EXPECT().DroppedPacket(logging.PacketTypeVersionNegotiation, gomock.Any(), gomock.Any(), logging.PacketDropHeaderParseError)
vnp.data = vnp.data[:len(vnp.data)-2]
require.False(t, tc.conn.handlePacketImpl(vnp))
wasProcessed, err = tc.conn.handlePacketImpl(vnp)
require.NoError(t, err)
require.False(t, wasProcessed)
}
func getRetryPacket(t *testing.T, src, dest, origDest protocol.ConnectionID, token []byte) receivedPacket {
@ -2518,13 +2540,17 @@ func TestConnectionRetryDrops(t *testing.T) {
tracer.EXPECT().DroppedPacket(logging.PacketTypeRetry, gomock.Any(), gomock.Any(), logging.PacketDropPayloadDecryptError)
retry := getRetryPacket(t, newConnID, tc.srcConnID, tc.destConnID, []byte("foobar"))
retry.data[len(retry.data)-1]++
require.False(t, tc.conn.handlePacketImpl(retry))
wasProcessed, err := tc.conn.handlePacketImpl(retry)
require.NoError(t, err)
require.False(t, wasProcessed)
require.True(t, mockCtrl.Satisfied())
// receive a retry that doesn't change the connection ID
tracer.EXPECT().DroppedPacket(logging.PacketTypeRetry, gomock.Any(), gomock.Any(), logging.PacketDropUnexpectedPacket)
retry = getRetryPacket(t, tc.destConnID, tc.srcConnID, tc.destConnID, []byte("foobar"))
require.False(t, tc.conn.handlePacketImpl(retry))
wasProcessed, err = tc.conn.handlePacketImpl(retry)
require.NoError(t, err)
require.False(t, wasProcessed)
}
func TestConnectionRetryAfterReceivedPacket(t *testing.T) {
@ -2549,16 +2575,20 @@ func TestConnectionRetryAfterReceivedPacket(t *testing.T) {
encryptionLevel: protocol.EncryptionInitial,
}, nil,
)
require.True(t, tc.conn.handlePacketImpl(receivedPacket{
wasProcessed, err := tc.conn.handlePacketImpl(receivedPacket{
data: regular,
buffer: getPacketBuffer(),
rcvTime: time.Now(),
}))
})
require.NoError(t, err)
require.True(t, wasProcessed)
// receive a retry
retry := getRetryPacket(t, tc.destConnID, tc.srcConnID, tc.destConnID, []byte("foobar"))
tracer.EXPECT().DroppedPacket(logging.PacketTypeRetry, gomock.Any(), gomock.Any(), logging.PacketDropUnexpectedPacket)
require.False(t, tc.conn.handlePacketImpl(retry))
wasProcessed, err = tc.conn.handlePacketImpl(retry)
require.NoError(t, err)
require.False(t, wasProcessed)
}
func TestConnectionConnectionIDChanges(t *testing.T) {