mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-06 21:57:36 +03:00
retransmit the CONNECTION_CLOSE packet when late packets arrive
This commit is contained in:
parent
5e9e445f5b
commit
9d06b2cfff
6 changed files with 112 additions and 19 deletions
22
internal/utils/atomic_bool.go
Normal file
22
internal/utils/atomic_bool.go
Normal file
|
@ -0,0 +1,22 @@
|
|||
package utils
|
||||
|
||||
import "sync/atomic"
|
||||
|
||||
// An AtomicBool is an atomic bool
|
||||
type AtomicBool struct {
|
||||
v int32
|
||||
}
|
||||
|
||||
// Set sets the value
|
||||
func (a *AtomicBool) Set(value bool) {
|
||||
var n int32
|
||||
if value {
|
||||
n = 1
|
||||
}
|
||||
atomic.StoreInt32(&a.v, n)
|
||||
}
|
||||
|
||||
// Get gets the value
|
||||
func (a *AtomicBool) Get() bool {
|
||||
return atomic.LoadInt32(&a.v) != 0
|
||||
}
|
29
internal/utils/atomic_bool_test.go
Normal file
29
internal/utils/atomic_bool_test.go
Normal file
|
@ -0,0 +1,29 @@
|
|||
package utils
|
||||
|
||||
import (
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Atomic Bool", func() {
|
||||
var a *AtomicBool
|
||||
|
||||
BeforeEach(func() {
|
||||
a = &AtomicBool{}
|
||||
})
|
||||
|
||||
It("has the right default value", func() {
|
||||
Expect(a.Get()).To(BeFalse())
|
||||
})
|
||||
|
||||
It("sets the value to true", func() {
|
||||
a.Set(true)
|
||||
Expect(a.Get()).To(BeTrue())
|
||||
})
|
||||
|
||||
It("sets the value to false", func() {
|
||||
a.Set(true)
|
||||
a.Set(false)
|
||||
Expect(a.Get()).To(BeFalse())
|
||||
})
|
||||
})
|
|
@ -56,10 +56,6 @@ func (h *packetHandlerMap) Remove(id protocol.ConnectionID) {
|
|||
}
|
||||
|
||||
func (h *packetHandlerMap) removeByConnectionIDAsString(id string) {
|
||||
h.mutex.Lock()
|
||||
h.handlers[id] = nil
|
||||
h.mutex.Unlock()
|
||||
|
||||
time.AfterFunc(h.deleteClosedSessionsAfter, func() {
|
||||
h.mutex.Lock()
|
||||
delete(h.handlers, id)
|
||||
|
@ -102,14 +98,12 @@ func (h *packetHandlerMap) close(e error) error {
|
|||
|
||||
var wg sync.WaitGroup
|
||||
for _, handler := range h.handlers {
|
||||
if handler != nil {
|
||||
wg.Add(1)
|
||||
go func(handler packetHandler) {
|
||||
handler.destroy(e)
|
||||
wg.Done()
|
||||
}(handler)
|
||||
}
|
||||
}
|
||||
|
||||
if h.server != nil {
|
||||
h.server.closeWithError(e)
|
||||
|
|
|
@ -88,20 +88,23 @@ var _ = Describe("Packet Handler Map", func() {
|
|||
Expect(err.Error()).To(ContainSubstring("error parsing invariant header:"))
|
||||
})
|
||||
|
||||
It("deletes nil session entries after a wait time", func() {
|
||||
It("deletes closed session entries after a wait time", func() {
|
||||
handler.deleteClosedSessionsAfter = 10 * time.Millisecond
|
||||
connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
|
||||
handler.Add(connID, NewMockPacketHandler(mockCtrl))
|
||||
handler.Remove(connID)
|
||||
Eventually(func() error {
|
||||
return handler.handlePacket(nil, getPacket(connID))
|
||||
}).Should(MatchError("received a packet with an unexpected connection ID 0x0102030405060708"))
|
||||
time.Sleep(30 * time.Millisecond)
|
||||
Expect(handler.handlePacket(nil, getPacket(connID))).To(MatchError("received a packet with an unexpected connection ID 0x0102030405060708"))
|
||||
})
|
||||
|
||||
It("ignores packets arriving late for closed sessions", func() {
|
||||
It("passes packets arriving late for closed sessions to that session", func() {
|
||||
handler.deleteClosedSessionsAfter = time.Hour
|
||||
connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
|
||||
handler.Add(connID, NewMockPacketHandler(mockCtrl))
|
||||
packetHandler := NewMockPacketHandler(mockCtrl)
|
||||
packetHandler.EXPECT().GetVersion().Return(protocol.VersionWhatever)
|
||||
packetHandler.EXPECT().GetPerspective().Return(protocol.PerspectiveClient)
|
||||
packetHandler.EXPECT().handlePacket(gomock.Any())
|
||||
handler.Add(connID, packetHandler)
|
||||
handler.Remove(connID)
|
||||
err := handler.handlePacket(nil, getPacket(connID))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
|
31
session.go
31
session.go
|
@ -94,9 +94,13 @@ type session struct {
|
|||
|
||||
receivedPackets chan *receivedPacket
|
||||
sendingScheduled chan struct{}
|
||||
// closeChan is used to notify the run loop that it should terminate.
|
||||
closeChan chan closeError
|
||||
|
||||
closeOnce sync.Once
|
||||
closed utils.AtomicBool
|
||||
// closeChan is used to notify the run loop that it should terminate
|
||||
closeChan chan closeError
|
||||
connectionClosePacket *packedPacket
|
||||
packetsReceivedAfterClose int
|
||||
|
||||
ctx context.Context
|
||||
ctxCancel context.CancelFunc
|
||||
|
@ -418,6 +422,7 @@ runLoop:
|
|||
if err := s.handleCloseError(closeErr); err != nil {
|
||||
s.logger.Infof("Handling close error failed: %s", err)
|
||||
}
|
||||
s.closed.Set(true)
|
||||
s.logger.Infof("Connection %s closed.", s.srcConnID)
|
||||
s.sessionRunner.removeConnectionID(s.srcConnID)
|
||||
s.cryptoStreamHandler.Close()
|
||||
|
@ -596,6 +601,9 @@ func (s *session) handleFrames(fs []wire.Frame, encLevel protocol.EncryptionLeve
|
|||
|
||||
// handlePacket is called by the server with a new packet
|
||||
func (s *session) handlePacket(p *receivedPacket) {
|
||||
if s.closed.Get() {
|
||||
s.handlePacketAfterClosed(p)
|
||||
}
|
||||
// Discard packets once the amount of queued packets is larger than
|
||||
// the channel size, protocol.MaxSessionUnprocessedPackets
|
||||
select {
|
||||
|
@ -604,6 +612,24 @@ func (s *session) handlePacket(p *receivedPacket) {
|
|||
}
|
||||
}
|
||||
|
||||
func (s *session) handlePacketAfterClosed(p *receivedPacket) {
|
||||
s.packetsReceivedAfterClose++
|
||||
if s.connectionClosePacket == nil {
|
||||
return
|
||||
}
|
||||
// exponential backoff
|
||||
// only send a CONNECTION_CLOSE for the 1st, 2nd, 4th, 8th, 16th, ... packet arriving
|
||||
for n := s.packetsReceivedAfterClose; n > 1; n = n / 2 {
|
||||
if n%2 != 0 {
|
||||
return
|
||||
}
|
||||
}
|
||||
s.logger.Debugf("Received %d packets after sending CONNECTION_CLOSE. Retransmitting.", s.packetsReceivedAfterClose)
|
||||
if err := s.conn.Write(s.connectionClosePacket.raw); err != nil {
|
||||
s.logger.Debugf("Error retransmitting CONNECTION_CLOSE: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *session) handleCryptoFrame(frame *wire.CryptoFrame, encLevel protocol.EncryptionLevel) error {
|
||||
encLevelChanged, err := s.cryptoStreamManager.HandleCryptoFrame(frame, encLevel)
|
||||
if err != nil {
|
||||
|
@ -943,6 +969,7 @@ func (s *session) sendConnectionClose(quicErr *qerr.QuicError) error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.connectionClosePacket = packet
|
||||
s.logPacket(packet)
|
||||
return s.conn.Write(packet.raw)
|
||||
}
|
||||
|
|
|
@ -426,6 +426,24 @@ var _ = Describe("Session", func() {
|
|||
sess.Close()
|
||||
Eventually(returned).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("retransmits the CONNECTION_CLOSE packet if packets are arriving late", func() {
|
||||
streamManager.EXPECT().CloseWithError(gomock.Any())
|
||||
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
|
||||
cryptoSetup.EXPECT().Close()
|
||||
packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{raw: []byte("foobar")}, nil)
|
||||
sess.Close()
|
||||
Expect(mconn.written).To(Receive(Equal([]byte("foobar")))) // receive the CONNECTION_CLOSE
|
||||
Eventually(sess.Context().Done()).Should(BeClosed())
|
||||
for i := 1; i <= 20; i++ {
|
||||
sess.handlePacket(&receivedPacket{})
|
||||
if i == 1 || i == 2 || i == 4 || i == 8 || i == 16 {
|
||||
Expect(mconn.written).To(Receive(Equal([]byte("foobar")))) // receive the CONNECTION_CLOSE
|
||||
} else {
|
||||
Expect(mconn.written).To(HaveLen(0))
|
||||
}
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
Context("receiving packets", func() {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue