use a single Go routine to send copies of CONNECTION_CLOSE packets

This commit is contained in:
Marten Seemann 2022-08-21 15:13:50 +03:00
parent c3ab9c4ea9
commit b659414495
9 changed files with 102 additions and 145 deletions

View file

@ -2,7 +2,7 @@ package quic
import (
"math/bits"
"sync"
"net"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
@ -12,85 +12,38 @@ import (
// When receiving packets for such a connection, we need to retransmit the packet containing the CONNECTION_CLOSE frame,
// with an exponential backoff.
type closedLocalConn struct {
conn sendConn
connClosePacket []byte
closeOnce sync.Once
closeChan chan struct{} // is closed when the connection is closed or destroyed
receivedPackets chan *receivedPacket
counter uint64 // number of packets received
counter uint32
perspective protocol.Perspective
logger utils.Logger
sendPacket func(net.Addr, *packetInfo)
}
var _ packetHandler = &closedLocalConn{}
// newClosedLocalConn creates a new closedLocalConn and runs it.
func newClosedLocalConn(
conn sendConn,
connClosePacket []byte,
perspective protocol.Perspective,
logger utils.Logger,
) packetHandler {
s := &closedLocalConn{
conn: conn,
connClosePacket: connClosePacket,
perspective: perspective,
func newClosedLocalConn(sendPacket func(net.Addr, *packetInfo), pers protocol.Perspective, logger utils.Logger) packetHandler {
return &closedLocalConn{
sendPacket: sendPacket,
perspective: pers,
logger: logger,
closeChan: make(chan struct{}),
receivedPackets: make(chan *receivedPacket, 64),
}
go s.run()
return s
}
func (s *closedLocalConn) run() {
for {
select {
case p := <-s.receivedPackets:
s.handlePacketImpl(p)
case <-s.closeChan:
return
}
}
}
func (s *closedLocalConn) handlePacket(p *receivedPacket) {
select {
case s.receivedPackets <- p:
default:
}
}
func (s *closedLocalConn) handlePacketImpl(_ *receivedPacket) {
s.counter++
func (c *closedLocalConn) handlePacket(p *receivedPacket) {
c.counter++
// exponential backoff
// only send a CONNECTION_CLOSE for the 1st, 2nd, 4th, 8th, 16th, ... packet arriving
if bits.OnesCount64(s.counter) != 1 {
if bits.OnesCount32(c.counter) != 1 {
return
}
s.logger.Debugf("Received %d packets after sending CONNECTION_CLOSE. Retransmitting.", s.counter)
if err := s.conn.Write(s.connClosePacket); err != nil {
s.logger.Debugf("Error retransmitting CONNECTION_CLOSE: %s", err)
}
c.logger.Debugf("Received %d packets after sending CONNECTION_CLOSE. Retransmitting.", c.counter)
c.sendPacket(p.remoteAddr, p.info)
}
func (s *closedLocalConn) shutdown() {
s.destroy(nil)
}
func (s *closedLocalConn) destroy(error) {
s.closeOnce.Do(func() {
close(s.closeChan)
})
}
func (s *closedLocalConn) getPerspective() protocol.Perspective {
return s.perspective
}
func (c *closedLocalConn) shutdown() {}
func (c *closedLocalConn) destroy(error) {}
func (c *closedLocalConn) getPerspective() protocol.Perspective { return c.perspective }
// A closedRemoteConn is a connection that was closed remotely.
// For such a connection, we might receive reordered packets that were sent before the CONNECTION_CLOSE.

View file

@ -1,10 +1,8 @@
package quic
import (
"errors"
"time"
"net"
"github.com/golang/mock/gomock"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
@ -13,44 +11,28 @@ import (
)
var _ = Describe("Closed local connection", func() {
var (
conn packetHandler
mconn *MockSendConn
)
BeforeEach(func() {
mconn = NewMockSendConn(mockCtrl)
conn = newClosedLocalConn(mconn, []byte("close"), protocol.PerspectiveClient, utils.DefaultLogger)
})
AfterEach(func() {
Eventually(areClosedConnsRunning).Should(BeFalse())
})
It("tells its perspective", func() {
conn := newClosedLocalConn(nil, protocol.PerspectiveClient, utils.DefaultLogger)
Expect(conn.getPerspective()).To(Equal(protocol.PerspectiveClient))
// stop the connection
conn.shutdown()
})
It("repeats the packet containing the CONNECTION_CLOSE frame", func() {
written := make(chan []byte)
mconn.EXPECT().Write(gomock.Any()).Do(func(p []byte) { written <- p }).AnyTimes()
written := make(chan net.Addr, 1)
conn := newClosedLocalConn(
func(addr net.Addr, _ *packetInfo) { written <- addr },
protocol.PerspectiveClient,
utils.DefaultLogger,
)
addr := &net.UDPAddr{IP: net.IPv4(127, 1, 2, 3), Port: 1337}
for i := 1; i <= 20; i++ {
conn.handlePacket(&receivedPacket{})
conn.handlePacket(&receivedPacket{remoteAddr: addr})
if i == 1 || i == 2 || i == 4 || i == 8 || i == 16 {
Eventually(written).Should(Receive(Equal([]byte("close")))) // receive the CONNECTION_CLOSE
Expect(written).To(Receive(Equal(addr))) // receive the CONNECTION_CLOSE
} else {
Consistently(written, 10*time.Millisecond).Should(HaveLen(0))
Expect(written).ToNot(Receive())
}
}
// stop the connection
conn.shutdown()
})
It("destroys connections", func() {
Eventually(areClosedConnsRunning).Should(BeTrue())
conn.destroy(errors.New("destroy"))
Eventually(areClosedConnsRunning).Should(BeFalse())
})
})

View file

@ -20,7 +20,7 @@ type connIDGenerator struct {
getStatelessResetToken func(protocol.ConnectionID) protocol.StatelessResetToken
removeConnectionID func(protocol.ConnectionID)
retireConnectionID func(protocol.ConnectionID)
replaceWithClosed func([]protocol.ConnectionID, packetHandler)
replaceWithClosed func([]protocol.ConnectionID, protocol.Perspective, []byte)
queueControlFrame func(wire.Frame)
version protocol.VersionNumber
@ -33,7 +33,7 @@ func newConnIDGenerator(
getStatelessResetToken func(protocol.ConnectionID) protocol.StatelessResetToken,
removeConnectionID func(protocol.ConnectionID),
retireConnectionID func(protocol.ConnectionID),
replaceWithClosed func([]protocol.ConnectionID, packetHandler),
replaceWithClosed func([]protocol.ConnectionID, protocol.Perspective, []byte),
queueControlFrame func(wire.Frame),
version protocol.VersionNumber,
) *connIDGenerator {
@ -130,7 +130,7 @@ func (m *connIDGenerator) RemoveAll() {
}
}
func (m *connIDGenerator) ReplaceWithClosed(handler packetHandler) {
func (m *connIDGenerator) ReplaceWithClosed(pers protocol.Perspective, connClose []byte) {
connIDs := make([]protocol.ConnectionID, 0, len(m.activeSrcConnIDs)+1)
if m.initialClientDestConnID != nil {
connIDs = append(connIDs, m.initialClientDestConnID)
@ -138,5 +138,5 @@ func (m *connIDGenerator) ReplaceWithClosed(handler packetHandler) {
for _, connID := range m.activeSrcConnIDs {
connIDs = append(connIDs, connID)
}
m.replaceWithClosed(connIDs, handler)
m.replaceWithClosed(connIDs, pers, connClose)
}

View file

@ -16,7 +16,7 @@ var _ = Describe("Connection ID Generator", func() {
addedConnIDs []protocol.ConnectionID
retiredConnIDs []protocol.ConnectionID
removedConnIDs []protocol.ConnectionID
replacedWithClosed map[string]packetHandler
replacedWithClosed []protocol.ConnectionID
queuedFrames []wire.Frame
g *connIDGenerator
)
@ -32,7 +32,7 @@ var _ = Describe("Connection ID Generator", func() {
retiredConnIDs = nil
removedConnIDs = nil
queuedFrames = nil
replacedWithClosed = make(map[string]packetHandler)
replacedWithClosed = nil
g = newConnIDGenerator(
initialConnID,
initialClientDestConnID,
@ -40,10 +40,8 @@ var _ = Describe("Connection ID Generator", func() {
connIDToToken,
func(c protocol.ConnectionID) { removedConnIDs = append(removedConnIDs, c) },
func(c protocol.ConnectionID) { retiredConnIDs = append(retiredConnIDs, c) },
func(cs []protocol.ConnectionID, h packetHandler) {
for _, c := range cs {
replacedWithClosed[string(c)] = h
}
func(cs []protocol.ConnectionID, _ protocol.Perspective, _ []byte) {
replacedWithClosed = append(replacedWithClosed, cs...)
},
func(f wire.Frame) { queuedFrames = append(queuedFrames, f) },
protocol.VersionDraft29,
@ -178,14 +176,13 @@ var _ = Describe("Connection ID Generator", func() {
It("replaces with a closed connection for all connection IDs", func() {
Expect(g.SetMaxActiveConnIDs(5)).To(Succeed())
Expect(queuedFrames).To(HaveLen(4))
sess := NewMockPacketHandler(mockCtrl)
g.ReplaceWithClosed(sess)
g.ReplaceWithClosed(protocol.PerspectiveClient, []byte("foobar"))
Expect(replacedWithClosed).To(HaveLen(6)) // initial conn ID, initial client dest conn id, and newly issued ones
Expect(replacedWithClosed).To(HaveKeyWithValue(string(initialClientDestConnID), sess))
Expect(replacedWithClosed).To(HaveKeyWithValue(string(initialConnID), sess))
Expect(replacedWithClosed).To(ContainElement(initialClientDestConnID))
Expect(replacedWithClosed).To(ContainElement(initialConnID))
for _, f := range queuedFrames {
nf := f.(*wire.NewConnectionIDFrame)
Expect(replacedWithClosed).To(HaveKeyWithValue(string(nf.ConnectionID), sess))
Expect(replacedWithClosed).To(ContainElement(nf.ConnectionID))
}
})
})

View file

@ -95,7 +95,7 @@ type connRunner interface {
GetStatelessResetToken(protocol.ConnectionID) protocol.StatelessResetToken
Retire(protocol.ConnectionID)
Remove(protocol.ConnectionID)
ReplaceWithClosed([]protocol.ConnectionID, packetHandler)
ReplaceWithClosed([]protocol.ConnectionID, protocol.Perspective, []byte)
AddResetToken(protocol.StatelessResetToken, packetHandler)
RemoveResetToken(protocol.StatelessResetToken)
}
@ -1521,7 +1521,7 @@ func (s *connection) handleCloseError(closeErr *closeError) {
// If this is a remote close we're done here
if closeErr.remote {
s.connIDGenerator.ReplaceWithClosed(newClosedRemoteConn(s.perspective))
s.connIDGenerator.ReplaceWithClosed(s.perspective, nil)
return
}
if closeErr.immediate {
@ -1538,8 +1538,7 @@ func (s *connection) handleCloseError(closeErr *closeError) {
if err != nil {
s.logger.Debugf("Error sending CONNECTION_CLOSE: %s", err)
}
cs := newClosedLocalConn(s.conn, connClosePacket, s.perspective, s.logger)
s.connIDGenerator.ReplaceWithClosed(cs)
s.connIDGenerator.ReplaceWithClosed(s.perspective, connClosePacket)
}
func (s *connection) dropEncryptionLevel(encLevel protocol.EncryptionLevel) {

View file

@ -37,12 +37,6 @@ func areConnsRunning() bool {
return strings.Contains(b.String(), "quic-go.(*connection).run")
}
func areClosedConnsRunning() bool {
var b bytes.Buffer
pprof.Lookup("goroutine").WriteTo(&b, 1)
return strings.Contains(b.String(), "quic-go.(*closedLocalConn).run")
}
var _ = Describe("Connection", func() {
var (
conn *connection
@ -72,14 +66,11 @@ var _ = Describe("Connection", func() {
}
expectReplaceWithClosed := func() {
connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any()).Do(func(connIDs []protocol.ConnectionID, s packetHandler) {
connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(connIDs []protocol.ConnectionID, _ protocol.Perspective, _ []byte) {
Expect(connIDs).To(ContainElement(srcConnID))
if len(connIDs) > 1 {
Expect(connIDs).To(ContainElement(clientDestConnID))
}
Expect(s).To(BeAssignableToTypeOf(&closedLocalConn{}))
s.shutdown()
Eventually(areClosedConnsRunning).Should(BeFalse())
})
}
@ -333,9 +324,8 @@ var _ = Describe("Connection", func() {
ErrorMessage: "foobar",
}
streamManager.EXPECT().CloseWithError(expectedErr)
connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any()).Do(func(connIDs []protocol.ConnectionID, s packetHandler) {
connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(connIDs []protocol.ConnectionID, _ protocol.Perspective, _ []byte) {
Expect(connIDs).To(ConsistOf(clientDestConnID, srcConnID))
Expect(s).To(BeAssignableToTypeOf(&closedRemoteConn{}))
})
cryptoSetup.EXPECT().Close()
gomock.InOrder(
@ -362,9 +352,8 @@ var _ = Describe("Connection", func() {
ErrorMessage: "foobar",
}
streamManager.EXPECT().CloseWithError(testErr)
connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any()).Do(func(connIDs []protocol.ConnectionID, s packetHandler) {
connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(connIDs []protocol.ConnectionID, _ protocol.Perspective, _ []byte) {
Expect(connIDs).To(ConsistOf(clientDestConnID, srcConnID))
Expect(s).To(BeAssignableToTypeOf(&closedRemoteConn{}))
})
cryptoSetup.EXPECT().Close()
gomock.InOrder(
@ -564,7 +553,7 @@ var _ = Describe("Connection", func() {
runConn()
cryptoSetup.EXPECT().Close()
streamManager.EXPECT().CloseWithError(gomock.Any())
connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any()).AnyTimes()
connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
buf := &bytes.Buffer{}
hdr := &wire.ExtendedHeader{
Header: wire.Header{DestConnectionID: srcConnID},
@ -2432,10 +2421,7 @@ var _ = Describe("Client Connection", func() {
}
expectReplaceWithClosed := func() {
connRunner.EXPECT().ReplaceWithClosed([]protocol.ConnectionID{srcConnID}, gomock.Any()).Do(func(_ []protocol.ConnectionID, s packetHandler) {
s.shutdown()
Eventually(areClosedConnsRunning).Should(BeFalse())
})
connRunner.EXPECT().ReplaceWithClosed([]protocol.ConnectionID{srcConnID}, gomock.Any(), gomock.Any())
}
BeforeEach(func() {
@ -2766,10 +2752,7 @@ var _ = Describe("Client Connection", func() {
expectClose := func(applicationClose bool) {
if !closed {
connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any()).Do(func(_ []protocol.ConnectionID, s packetHandler) {
Expect(s).To(BeAssignableToTypeOf(&closedLocalConn{}))
s.shutdown()
})
connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any(), gomock.Any())
if applicationClose {
packer.EXPECT().PackApplicationClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil).MaxTimes(1)
} else {

View file

@ -99,15 +99,15 @@ func (mr *MockConnRunnerMockRecorder) RemoveResetToken(arg0 interface{}) *gomock
}
// ReplaceWithClosed mocks base method.
func (m *MockConnRunner) ReplaceWithClosed(arg0 []protocol.ConnectionID, arg1 packetHandler) {
func (m *MockConnRunner) ReplaceWithClosed(arg0 []protocol.ConnectionID, arg1 protocol.Perspective, arg2 []byte) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "ReplaceWithClosed", arg0, arg1)
m.ctrl.Call(m, "ReplaceWithClosed", arg0, arg1, arg2)
}
// ReplaceWithClosed indicates an expected call of ReplaceWithClosed.
func (mr *MockConnRunnerMockRecorder) ReplaceWithClosed(arg0, arg1 interface{}) *gomock.Call {
func (mr *MockConnRunnerMockRecorder) ReplaceWithClosed(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReplaceWithClosed", reflect.TypeOf((*MockConnRunner)(nil).ReplaceWithClosed), arg0, arg1)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReplaceWithClosed", reflect.TypeOf((*MockConnRunner)(nil).ReplaceWithClosed), arg0, arg1, arg2)
}
// Retire mocks base method.

View file

@ -139,15 +139,15 @@ func (mr *MockPacketHandlerManagerMockRecorder) RemoveResetToken(arg0 interface{
}
// ReplaceWithClosed mocks base method.
func (m *MockPacketHandlerManager) ReplaceWithClosed(arg0 []protocol.ConnectionID, arg1 packetHandler) {
func (m *MockPacketHandlerManager) ReplaceWithClosed(arg0 []protocol.ConnectionID, arg1 protocol.Perspective, arg2 []byte) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "ReplaceWithClosed", arg0, arg1)
m.ctrl.Call(m, "ReplaceWithClosed", arg0, arg1, arg2)
}
// ReplaceWithClosed indicates an expected call of ReplaceWithClosed.
func (mr *MockPacketHandlerManagerMockRecorder) ReplaceWithClosed(arg0, arg1 interface{}) *gomock.Call {
func (mr *MockPacketHandlerManagerMockRecorder) ReplaceWithClosed(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReplaceWithClosed", reflect.TypeOf((*MockPacketHandlerManager)(nil).ReplaceWithClosed), arg0, arg1)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReplaceWithClosed", reflect.TypeOf((*MockPacketHandlerManager)(nil).ReplaceWithClosed), arg0, arg1, arg2)
}
// Retire mocks base method.

View file

@ -30,6 +30,12 @@ type rawConn interface {
io.Closer
}
type closePacket struct {
payload []byte
addr net.Addr
info *packetInfo
}
// The packetHandlerMap stores packetHandlers, identified by connection ID.
// It is used:
// * by the server to store connections
@ -40,6 +46,8 @@ type packetHandlerMap struct {
conn rawConn
connIDLen int
closeQueue chan closePacket
handlers map[string] /* string(ConnectionID)*/ packetHandler
resetTokens map[protocol.StatelessResetToken] /* stateless reset token */ packetHandler
server unknownPacketHandler
@ -123,12 +131,14 @@ func newPacketHandlerMap(
resetTokens: make(map[protocol.StatelessResetToken]packetHandler),
deleteRetiredConnsAfter: protocol.RetiredConnectionIDDeleteTimeout,
zeroRTTQueueDuration: protocol.Max0RTTQueueingDuration,
closeQueue: make(chan closePacket, 4),
statelessResetEnabled: len(statelessResetKey) > 0,
statelessResetHasher: hmac.New(sha256.New, statelessResetKey),
tracer: tracer,
logger: logger,
}
go m.listen()
go m.runCloseQueue()
if logger.Debug() {
go m.logUsage()
@ -219,7 +229,29 @@ func (h *packetHandlerMap) Retire(id protocol.ConnectionID) {
})
}
func (h *packetHandlerMap) ReplaceWithClosed(ids []protocol.ConnectionID, handler packetHandler) {
// ReplaceWithClosed is called when a connection is closed.
// Depending on which side closed the connection, we need to:
// * remote close: absorb delayed packets
// * local close: retransmit the CONNECTION_CLOSE packet, in case it was lost
func (h *packetHandlerMap) ReplaceWithClosed(ids []protocol.ConnectionID, pers protocol.Perspective, connClosePacket []byte) {
var handler packetHandler
if connClosePacket != nil {
handler = newClosedLocalConn(
func(addr net.Addr, info *packetInfo) {
select {
case h.closeQueue <- closePacket{payload: connClosePacket, addr: addr, info: info}:
default:
// Oops, we're backlogged.
// Just drop the packet, sending CONNECTION_CLOSE copies is best effort anyway.
}
},
pers,
h.logger,
)
} else {
handler = newClosedRemoteConn(pers)
}
h.mutex.Lock()
for _, id := range ids {
h.handlers[string(id)] = handler
@ -238,6 +270,17 @@ func (h *packetHandlerMap) ReplaceWithClosed(ids []protocol.ConnectionID, handle
})
}
func (h *packetHandlerMap) runCloseQueue() {
for {
select {
case <-h.listening:
return
case p := <-h.closeQueue:
h.conn.WritePacket(p.payload, p.addr, p.info.OOB())
}
}
}
func (h *packetHandlerMap) AddResetToken(token protocol.StatelessResetToken, handler packetHandler) {
h.mutex.Lock()
h.resetTokens[token] = handler