refactor connection error propagation (#4925)

This commit is contained in:
Marten Seemann 2025-01-26 06:01:29 +01:00 committed by GitHub
parent 383b634df6
commit 19213b24bc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 57 additions and 56 deletions

View file

@ -94,7 +94,6 @@ type connRunner interface {
type closeError struct {
err error
remote bool
immediate bool
}
@ -158,9 +157,9 @@ type connection struct {
receivedPackets chan receivedPacket
sendingScheduled chan struct{}
closeOnce sync.Once
// closeChan is used to notify the run loop that it should terminate
closeChan chan closeError
closeChan chan struct{}
closeErr atomic.Pointer[closeError]
ctx context.Context
ctxCancel context.CancelCauseFunc
@ -480,7 +479,7 @@ func (s *connection) preSetup() {
)
s.framer = newFramer(s.connFlowController)
s.receivedPackets = make(chan receivedPacket, protocol.MaxConnUnprocessedPackets)
s.closeChan = make(chan closeError, 1)
s.closeChan = make(chan struct{}, 1)
s.sendingScheduled = make(chan struct{}, 1)
s.handshakeCompleteChan = make(chan struct{})
@ -493,9 +492,8 @@ func (s *connection) preSetup() {
}
// run the connection main loop
func (s *connection) run() error {
var closeErr closeError
defer func() { s.ctxCancel(closeErr.err) }()
func (s *connection) run() (err error) {
defer func() { s.ctxCancel(err) }()
defer func() {
// Drain queued packets that will never be processed.
@ -536,12 +534,12 @@ func (s *connection) run() error {
runLoop:
for {
if s.framer.QueuedTooManyControlFrames() {
closeErr = closeError{err: &qerr.TransportError{ErrorCode: InternalError}}
s.setCloseError(&closeError{err: &qerr.TransportError{ErrorCode: InternalError}})
break runLoop
}
// Close immediately if requested
select {
case closeErr = <-s.closeChan:
case <-s.closeChan:
break runLoop
default:
}
@ -555,7 +553,7 @@ runLoop:
for _, p := range queue {
processed, err := s.handlePacketImpl(p)
if err != nil {
closeErr = closeError{err: err}
s.setCloseError(&closeError{err: err})
break runLoop
}
if processed {
@ -566,7 +564,7 @@ runLoop:
// If we processed any undecryptable packets, jump to the resetting of the timers directly.
if !processedUndecryptablePacket {
select {
case closeErr = <-s.closeChan:
case <-s.closeChan:
break runLoop
case <-s.timer.Chan():
s.timer.SetRead()
@ -580,14 +578,9 @@ runLoop:
wasProcessed, err := s.handlePacketImpl(firstPacket)
// Don't set timers and send packets if the packet made us close the connection.
if err != nil {
closeErr = closeError{err: err}
s.setCloseError(&closeError{err: err})
break runLoop
}
select {
case closeErr = <-s.closeChan:
break runLoop
default:
}
if s.handshakeComplete {
// Now process all packets in the receivedPackets channel.
// Limit the number of packets to the length of the receivedPackets channel,
@ -599,17 +592,12 @@ runLoop:
case p := <-s.receivedPackets:
processed, err := s.handlePacketImpl(p)
if err != nil {
closeErr = closeError{err: err}
s.setCloseError(&closeError{err: err})
break runLoop
}
if processed {
wasProcessed = true
}
select {
case closeErr = <-s.closeChan:
break runLoop
default:
}
default:
break receiveLoop
}
@ -629,7 +617,7 @@ runLoop:
// This could cause packets to be retransmitted.
// Check it before trying to send packets.
if err := s.sentPacketHandler.OnLossDetectionTimeout(now); err != nil {
closeErr = closeError{err: err}
s.setCloseError(&closeError{err: err})
break runLoop
}
}
@ -641,13 +629,13 @@ runLoop:
s.keepAlivePingSent = true
} else if !s.handshakeComplete && now.Sub(s.creationTime) >= s.config.handshakeTimeout() {
s.destroyImpl(qerr.ErrHandshakeTimeout)
continue
break runLoop
} else {
idleTimeoutStartTime := s.idleTimeoutStartTime()
if (!s.handshakeComplete && now.Sub(idleTimeoutStartTime) >= s.config.HandshakeIdleTimeout) ||
(s.handshakeComplete && now.After(s.nextIdleTimeoutTime())) {
s.destroyImpl(qerr.ErrIdleTimeout)
continue
break runLoop
}
}
@ -657,8 +645,13 @@ runLoop:
sendQueueAvailable = s.sendQueue.Available()
continue
}
if s.closeErr.Load() != nil {
break runLoop
}
if err := s.triggerSending(now); err != nil {
closeErr = closeError{err: err}
s.setCloseError(&closeError{err: err})
break runLoop
}
if s.sendQueue.WouldBlock() {
@ -668,9 +661,10 @@ runLoop:
}
}
closeErr := s.closeErr.Load()
s.cryptoStreamHandler.Close()
s.sendQueue.Close() // close the send queue before sending the CONNECTION_CLOSE
s.handleCloseError(&closeErr)
s.handleCloseError(closeErr)
if s.tracer != nil && s.tracer.Close != nil {
if e := (&errCloseForRecreating{}); !errors.As(closeErr.err, &e) {
s.tracer.Close()
@ -1348,7 +1342,7 @@ func (s *connection) handleFrame(
case *wire.AckFrame:
err = s.handleAckFrame(frame, encLevel, rcvTime)
case *wire.ConnectionCloseFrame:
s.handleConnectionCloseFrame(frame)
err = s.handleConnectionCloseFrame(frame)
case *wire.ResetStreamFrame:
err = s.handleResetStreamFrame(frame, rcvTime)
case *wire.MaxDataFrame:
@ -1401,21 +1395,20 @@ func (s *connection) handlePacket(p receivedPacket) {
}
}
func (s *connection) handleConnectionCloseFrame(frame *wire.ConnectionCloseFrame) {
func (s *connection) handleConnectionCloseFrame(frame *wire.ConnectionCloseFrame) error {
if frame.IsApplicationError {
s.closeRemote(&qerr.ApplicationError{
return &qerr.ApplicationError{
Remote: true,
ErrorCode: qerr.ApplicationErrorCode(frame.ErrorCode),
ErrorMessage: frame.ReasonPhrase,
})
return
}
}
s.closeRemote(&qerr.TransportError{
return &qerr.TransportError{
Remote: true,
ErrorCode: qerr.TransportErrorCode(frame.ErrorCode),
FrameType: frame.FrameType,
ErrorMessage: frame.ReasonPhrase,
})
}
}
func (s *connection) handleCryptoFrame(frame *wire.CryptoFrame, encLevel protocol.EncryptionLevel, rcvTime time.Time) error {
@ -1599,11 +1592,17 @@ func (s *connection) handleDatagramFrame(f *wire.DatagramFrame) error {
return nil
}
func (s *connection) setCloseError(e *closeError) {
s.closeErr.CompareAndSwap(nil, e)
select {
case s.closeChan <- struct{}{}:
default:
}
}
// closeLocal closes the connection and send a CONNECTION_CLOSE containing the error
func (s *connection) closeLocal(e error) {
s.closeOnce.Do(func() {
s.closeChan <- closeError{err: e, immediate: false, remote: false}
})
s.setCloseError(&closeError{err: e, immediate: false})
}
// destroy closes the connection without sending the error on the wire
@ -1613,16 +1612,7 @@ func (s *connection) destroy(e error) {
}
func (s *connection) destroyImpl(e error) {
s.closeOnce.Do(func() {
s.closeChan <- closeError{err: e, immediate: true, remote: false}
})
}
func (s *connection) closeRemote(e error) {
s.closeOnce.Do(func() {
s.logger.Errorf("Peer closed connection with error: %s", e)
s.closeChan <- closeError{err: e, immediate: true, remote: true}
})
s.setCloseError(&closeError{err: e, immediate: true})
}
func (s *connection) CloseWithError(code ApplicationErrorCode, desc string) error {
@ -1668,14 +1658,17 @@ func (s *connection) handleCloseError(closeErr *closeError) {
applicationErr *ApplicationError
transportErr *TransportError
)
var isRemoteClose bool
switch {
case errors.Is(e, qerr.ErrIdleTimeout),
errors.Is(e, qerr.ErrHandshakeTimeout),
errors.As(e, &statelessResetErr),
errors.As(e, &versionNegotiationErr),
errors.As(e, &recreateErr),
errors.As(e, &applicationErr),
errors.As(e, &transportErr):
errors.As(e, &recreateErr):
case errors.As(e, &applicationErr):
isRemoteClose = applicationErr.Remote
case errors.As(e, &transportErr):
isRemoteClose = transportErr.Remote
case closeErr.immediate:
e = closeErr.err
default:
@ -1701,7 +1694,7 @@ func (s *connection) handleCloseError(closeErr *closeError) {
}
// If this is a remote close we're done here
if closeErr.remote {
if isRemoteClose {
s.connIDGenerator.ReplaceWithClosed(nil)
return
}

View file

@ -900,13 +900,23 @@ func TestConnectionRemoteClose(t *testing.T) {
mockCtrl := gomock.NewController(t)
mockStreamManager := NewMockStreamManager(mockCtrl)
tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
unpacker := NewMockUnpacker(mockCtrl)
tc := newServerTestConnection(t,
mockCtrl,
nil,
false,
connectionOptStreamManager(mockStreamManager),
connectionOptTracer(tr),
connectionOptUnpacker(unpacker),
)
ccf, err := (&wire.ConnectionCloseFrame{
ErrorCode: uint64(qerr.StreamLimitError),
ReasonPhrase: "foobar",
}).Append(nil, protocol.Version1)
require.NoError(t, err)
unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return(protocol.PacketNumber(1), protocol.PacketNumberLen2, protocol.KeyPhaseBit(0), ccf, nil)
tracer.EXPECT().ReceivedShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
expectedErr := &qerr.TransportError{ErrorCode: qerr.StreamLimitError, Remote: true}
tc.connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any())
streamErrChan := make(chan error, 1)
@ -918,10 +928,8 @@ func TestConnectionRemoteClose(t *testing.T) {
errChan := make(chan error, 1)
go func() { errChan <- tc.conn.run() }()
tc.conn.handleFrame(&wire.ConnectionCloseFrame{
ErrorCode: uint64(qerr.StreamLimitError),
ReasonPhrase: "foobar",
}, protocol.Encryption1RTT, protocol.ConnectionID{}, time.Now())
p := getShortHeaderPacket(t, tc.srcConnID, 1, []byte("encrypted"))
tc.conn.handlePacket(receivedPacket{data: p.data, buffer: p.buffer, rcvTime: time.Now()})
select {
case err := <-errChan: