mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-03 20:27:35 +03:00
refactor connection error propagation (#4925)
This commit is contained in:
parent
383b634df6
commit
19213b24bc
2 changed files with 57 additions and 56 deletions
|
@ -94,7 +94,6 @@ type connRunner interface {
|
|||
|
||||
type closeError struct {
|
||||
err error
|
||||
remote bool
|
||||
immediate bool
|
||||
}
|
||||
|
||||
|
@ -158,9 +157,9 @@ type connection struct {
|
|||
receivedPackets chan receivedPacket
|
||||
sendingScheduled chan struct{}
|
||||
|
||||
closeOnce sync.Once
|
||||
// closeChan is used to notify the run loop that it should terminate
|
||||
closeChan chan closeError
|
||||
closeChan chan struct{}
|
||||
closeErr atomic.Pointer[closeError]
|
||||
|
||||
ctx context.Context
|
||||
ctxCancel context.CancelCauseFunc
|
||||
|
@ -480,7 +479,7 @@ func (s *connection) preSetup() {
|
|||
)
|
||||
s.framer = newFramer(s.connFlowController)
|
||||
s.receivedPackets = make(chan receivedPacket, protocol.MaxConnUnprocessedPackets)
|
||||
s.closeChan = make(chan closeError, 1)
|
||||
s.closeChan = make(chan struct{}, 1)
|
||||
s.sendingScheduled = make(chan struct{}, 1)
|
||||
s.handshakeCompleteChan = make(chan struct{})
|
||||
|
||||
|
@ -493,9 +492,8 @@ func (s *connection) preSetup() {
|
|||
}
|
||||
|
||||
// run the connection main loop
|
||||
func (s *connection) run() error {
|
||||
var closeErr closeError
|
||||
defer func() { s.ctxCancel(closeErr.err) }()
|
||||
func (s *connection) run() (err error) {
|
||||
defer func() { s.ctxCancel(err) }()
|
||||
|
||||
defer func() {
|
||||
// Drain queued packets that will never be processed.
|
||||
|
@ -536,12 +534,12 @@ func (s *connection) run() error {
|
|||
runLoop:
|
||||
for {
|
||||
if s.framer.QueuedTooManyControlFrames() {
|
||||
closeErr = closeError{err: &qerr.TransportError{ErrorCode: InternalError}}
|
||||
s.setCloseError(&closeError{err: &qerr.TransportError{ErrorCode: InternalError}})
|
||||
break runLoop
|
||||
}
|
||||
// Close immediately if requested
|
||||
select {
|
||||
case closeErr = <-s.closeChan:
|
||||
case <-s.closeChan:
|
||||
break runLoop
|
||||
default:
|
||||
}
|
||||
|
@ -555,7 +553,7 @@ runLoop:
|
|||
for _, p := range queue {
|
||||
processed, err := s.handlePacketImpl(p)
|
||||
if err != nil {
|
||||
closeErr = closeError{err: err}
|
||||
s.setCloseError(&closeError{err: err})
|
||||
break runLoop
|
||||
}
|
||||
if processed {
|
||||
|
@ -566,7 +564,7 @@ runLoop:
|
|||
// If we processed any undecryptable packets, jump to the resetting of the timers directly.
|
||||
if !processedUndecryptablePacket {
|
||||
select {
|
||||
case closeErr = <-s.closeChan:
|
||||
case <-s.closeChan:
|
||||
break runLoop
|
||||
case <-s.timer.Chan():
|
||||
s.timer.SetRead()
|
||||
|
@ -580,14 +578,9 @@ runLoop:
|
|||
wasProcessed, err := s.handlePacketImpl(firstPacket)
|
||||
// Don't set timers and send packets if the packet made us close the connection.
|
||||
if err != nil {
|
||||
closeErr = closeError{err: err}
|
||||
s.setCloseError(&closeError{err: err})
|
||||
break runLoop
|
||||
}
|
||||
select {
|
||||
case closeErr = <-s.closeChan:
|
||||
break runLoop
|
||||
default:
|
||||
}
|
||||
if s.handshakeComplete {
|
||||
// Now process all packets in the receivedPackets channel.
|
||||
// Limit the number of packets to the length of the receivedPackets channel,
|
||||
|
@ -599,17 +592,12 @@ runLoop:
|
|||
case p := <-s.receivedPackets:
|
||||
processed, err := s.handlePacketImpl(p)
|
||||
if err != nil {
|
||||
closeErr = closeError{err: err}
|
||||
s.setCloseError(&closeError{err: err})
|
||||
break runLoop
|
||||
}
|
||||
if processed {
|
||||
wasProcessed = true
|
||||
}
|
||||
select {
|
||||
case closeErr = <-s.closeChan:
|
||||
break runLoop
|
||||
default:
|
||||
}
|
||||
default:
|
||||
break receiveLoop
|
||||
}
|
||||
|
@ -629,7 +617,7 @@ runLoop:
|
|||
// This could cause packets to be retransmitted.
|
||||
// Check it before trying to send packets.
|
||||
if err := s.sentPacketHandler.OnLossDetectionTimeout(now); err != nil {
|
||||
closeErr = closeError{err: err}
|
||||
s.setCloseError(&closeError{err: err})
|
||||
break runLoop
|
||||
}
|
||||
}
|
||||
|
@ -641,13 +629,13 @@ runLoop:
|
|||
s.keepAlivePingSent = true
|
||||
} else if !s.handshakeComplete && now.Sub(s.creationTime) >= s.config.handshakeTimeout() {
|
||||
s.destroyImpl(qerr.ErrHandshakeTimeout)
|
||||
continue
|
||||
break runLoop
|
||||
} else {
|
||||
idleTimeoutStartTime := s.idleTimeoutStartTime()
|
||||
if (!s.handshakeComplete && now.Sub(idleTimeoutStartTime) >= s.config.HandshakeIdleTimeout) ||
|
||||
(s.handshakeComplete && now.After(s.nextIdleTimeoutTime())) {
|
||||
s.destroyImpl(qerr.ErrIdleTimeout)
|
||||
continue
|
||||
break runLoop
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -657,8 +645,13 @@ runLoop:
|
|||
sendQueueAvailable = s.sendQueue.Available()
|
||||
continue
|
||||
}
|
||||
|
||||
if s.closeErr.Load() != nil {
|
||||
break runLoop
|
||||
}
|
||||
|
||||
if err := s.triggerSending(now); err != nil {
|
||||
closeErr = closeError{err: err}
|
||||
s.setCloseError(&closeError{err: err})
|
||||
break runLoop
|
||||
}
|
||||
if s.sendQueue.WouldBlock() {
|
||||
|
@ -668,9 +661,10 @@ runLoop:
|
|||
}
|
||||
}
|
||||
|
||||
closeErr := s.closeErr.Load()
|
||||
s.cryptoStreamHandler.Close()
|
||||
s.sendQueue.Close() // close the send queue before sending the CONNECTION_CLOSE
|
||||
s.handleCloseError(&closeErr)
|
||||
s.handleCloseError(closeErr)
|
||||
if s.tracer != nil && s.tracer.Close != nil {
|
||||
if e := (&errCloseForRecreating{}); !errors.As(closeErr.err, &e) {
|
||||
s.tracer.Close()
|
||||
|
@ -1348,7 +1342,7 @@ func (s *connection) handleFrame(
|
|||
case *wire.AckFrame:
|
||||
err = s.handleAckFrame(frame, encLevel, rcvTime)
|
||||
case *wire.ConnectionCloseFrame:
|
||||
s.handleConnectionCloseFrame(frame)
|
||||
err = s.handleConnectionCloseFrame(frame)
|
||||
case *wire.ResetStreamFrame:
|
||||
err = s.handleResetStreamFrame(frame, rcvTime)
|
||||
case *wire.MaxDataFrame:
|
||||
|
@ -1401,21 +1395,20 @@ func (s *connection) handlePacket(p receivedPacket) {
|
|||
}
|
||||
}
|
||||
|
||||
func (s *connection) handleConnectionCloseFrame(frame *wire.ConnectionCloseFrame) {
|
||||
func (s *connection) handleConnectionCloseFrame(frame *wire.ConnectionCloseFrame) error {
|
||||
if frame.IsApplicationError {
|
||||
s.closeRemote(&qerr.ApplicationError{
|
||||
return &qerr.ApplicationError{
|
||||
Remote: true,
|
||||
ErrorCode: qerr.ApplicationErrorCode(frame.ErrorCode),
|
||||
ErrorMessage: frame.ReasonPhrase,
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
s.closeRemote(&qerr.TransportError{
|
||||
return &qerr.TransportError{
|
||||
Remote: true,
|
||||
ErrorCode: qerr.TransportErrorCode(frame.ErrorCode),
|
||||
FrameType: frame.FrameType,
|
||||
ErrorMessage: frame.ReasonPhrase,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *connection) handleCryptoFrame(frame *wire.CryptoFrame, encLevel protocol.EncryptionLevel, rcvTime time.Time) error {
|
||||
|
@ -1599,11 +1592,17 @@ func (s *connection) handleDatagramFrame(f *wire.DatagramFrame) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (s *connection) setCloseError(e *closeError) {
|
||||
s.closeErr.CompareAndSwap(nil, e)
|
||||
select {
|
||||
case s.closeChan <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// closeLocal closes the connection and send a CONNECTION_CLOSE containing the error
|
||||
func (s *connection) closeLocal(e error) {
|
||||
s.closeOnce.Do(func() {
|
||||
s.closeChan <- closeError{err: e, immediate: false, remote: false}
|
||||
})
|
||||
s.setCloseError(&closeError{err: e, immediate: false})
|
||||
}
|
||||
|
||||
// destroy closes the connection without sending the error on the wire
|
||||
|
@ -1613,16 +1612,7 @@ func (s *connection) destroy(e error) {
|
|||
}
|
||||
|
||||
func (s *connection) destroyImpl(e error) {
|
||||
s.closeOnce.Do(func() {
|
||||
s.closeChan <- closeError{err: e, immediate: true, remote: false}
|
||||
})
|
||||
}
|
||||
|
||||
func (s *connection) closeRemote(e error) {
|
||||
s.closeOnce.Do(func() {
|
||||
s.logger.Errorf("Peer closed connection with error: %s", e)
|
||||
s.closeChan <- closeError{err: e, immediate: true, remote: true}
|
||||
})
|
||||
s.setCloseError(&closeError{err: e, immediate: true})
|
||||
}
|
||||
|
||||
func (s *connection) CloseWithError(code ApplicationErrorCode, desc string) error {
|
||||
|
@ -1668,14 +1658,17 @@ func (s *connection) handleCloseError(closeErr *closeError) {
|
|||
applicationErr *ApplicationError
|
||||
transportErr *TransportError
|
||||
)
|
||||
var isRemoteClose bool
|
||||
switch {
|
||||
case errors.Is(e, qerr.ErrIdleTimeout),
|
||||
errors.Is(e, qerr.ErrHandshakeTimeout),
|
||||
errors.As(e, &statelessResetErr),
|
||||
errors.As(e, &versionNegotiationErr),
|
||||
errors.As(e, &recreateErr),
|
||||
errors.As(e, &applicationErr),
|
||||
errors.As(e, &transportErr):
|
||||
errors.As(e, &recreateErr):
|
||||
case errors.As(e, &applicationErr):
|
||||
isRemoteClose = applicationErr.Remote
|
||||
case errors.As(e, &transportErr):
|
||||
isRemoteClose = transportErr.Remote
|
||||
case closeErr.immediate:
|
||||
e = closeErr.err
|
||||
default:
|
||||
|
@ -1701,7 +1694,7 @@ func (s *connection) handleCloseError(closeErr *closeError) {
|
|||
}
|
||||
|
||||
// If this is a remote close we're done here
|
||||
if closeErr.remote {
|
||||
if isRemoteClose {
|
||||
s.connIDGenerator.ReplaceWithClosed(nil)
|
||||
return
|
||||
}
|
||||
|
|
|
@ -900,13 +900,23 @@ func TestConnectionRemoteClose(t *testing.T) {
|
|||
mockCtrl := gomock.NewController(t)
|
||||
mockStreamManager := NewMockStreamManager(mockCtrl)
|
||||
tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
|
||||
unpacker := NewMockUnpacker(mockCtrl)
|
||||
tc := newServerTestConnection(t,
|
||||
mockCtrl,
|
||||
nil,
|
||||
false,
|
||||
connectionOptStreamManager(mockStreamManager),
|
||||
connectionOptTracer(tr),
|
||||
connectionOptUnpacker(unpacker),
|
||||
)
|
||||
ccf, err := (&wire.ConnectionCloseFrame{
|
||||
ErrorCode: uint64(qerr.StreamLimitError),
|
||||
ReasonPhrase: "foobar",
|
||||
}).Append(nil, protocol.Version1)
|
||||
require.NoError(t, err)
|
||||
unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return(protocol.PacketNumber(1), protocol.PacketNumberLen2, protocol.KeyPhaseBit(0), ccf, nil)
|
||||
tracer.EXPECT().ReceivedShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
|
||||
|
||||
expectedErr := &qerr.TransportError{ErrorCode: qerr.StreamLimitError, Remote: true}
|
||||
tc.connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any())
|
||||
streamErrChan := make(chan error, 1)
|
||||
|
@ -918,10 +928,8 @@ func TestConnectionRemoteClose(t *testing.T) {
|
|||
errChan := make(chan error, 1)
|
||||
go func() { errChan <- tc.conn.run() }()
|
||||
|
||||
tc.conn.handleFrame(&wire.ConnectionCloseFrame{
|
||||
ErrorCode: uint64(qerr.StreamLimitError),
|
||||
ReasonPhrase: "foobar",
|
||||
}, protocol.Encryption1RTT, protocol.ConnectionID{}, time.Now())
|
||||
p := getShortHeaderPacket(t, tc.srcConnID, 1, []byte("encrypted"))
|
||||
tc.conn.handlePacket(receivedPacket{data: p.data, buffer: p.buffer, rcvTime: time.Now()})
|
||||
|
||||
select {
|
||||
case err := <-errChan:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue