mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-03 20:27:35 +03:00
use a single Go routine to send copies of CONNECTION_CLOSE packets
This commit is contained in:
parent
c3ab9c4ea9
commit
b659414495
9 changed files with 102 additions and 145 deletions
|
@ -2,7 +2,7 @@ package quic
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"math/bits"
|
"math/bits"
|
||||||
"sync"
|
"net"
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||||
|
@ -12,85 +12,38 @@ import (
|
||||||
// When receiving packets for such a connection, we need to retransmit the packet containing the CONNECTION_CLOSE frame,
|
// When receiving packets for such a connection, we need to retransmit the packet containing the CONNECTION_CLOSE frame,
|
||||||
// with an exponential backoff.
|
// with an exponential backoff.
|
||||||
type closedLocalConn struct {
|
type closedLocalConn struct {
|
||||||
conn sendConn
|
counter uint32
|
||||||
connClosePacket []byte
|
|
||||||
|
|
||||||
closeOnce sync.Once
|
|
||||||
closeChan chan struct{} // is closed when the connection is closed or destroyed
|
|
||||||
|
|
||||||
receivedPackets chan *receivedPacket
|
|
||||||
counter uint64 // number of packets received
|
|
||||||
|
|
||||||
perspective protocol.Perspective
|
perspective protocol.Perspective
|
||||||
|
logger utils.Logger
|
||||||
|
|
||||||
logger utils.Logger
|
sendPacket func(net.Addr, *packetInfo)
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ packetHandler = &closedLocalConn{}
|
var _ packetHandler = &closedLocalConn{}
|
||||||
|
|
||||||
// newClosedLocalConn creates a new closedLocalConn and runs it.
|
// newClosedLocalConn creates a new closedLocalConn and runs it.
|
||||||
func newClosedLocalConn(
|
func newClosedLocalConn(sendPacket func(net.Addr, *packetInfo), pers protocol.Perspective, logger utils.Logger) packetHandler {
|
||||||
conn sendConn,
|
return &closedLocalConn{
|
||||||
connClosePacket []byte,
|
sendPacket: sendPacket,
|
||||||
perspective protocol.Perspective,
|
perspective: pers,
|
||||||
logger utils.Logger,
|
logger: logger,
|
||||||
) packetHandler {
|
|
||||||
s := &closedLocalConn{
|
|
||||||
conn: conn,
|
|
||||||
connClosePacket: connClosePacket,
|
|
||||||
perspective: perspective,
|
|
||||||
logger: logger,
|
|
||||||
closeChan: make(chan struct{}),
|
|
||||||
receivedPackets: make(chan *receivedPacket, 64),
|
|
||||||
}
|
|
||||||
go s.run()
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *closedLocalConn) run() {
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case p := <-s.receivedPackets:
|
|
||||||
s.handlePacketImpl(p)
|
|
||||||
case <-s.closeChan:
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *closedLocalConn) handlePacket(p *receivedPacket) {
|
func (c *closedLocalConn) handlePacket(p *receivedPacket) {
|
||||||
select {
|
c.counter++
|
||||||
case s.receivedPackets <- p:
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *closedLocalConn) handlePacketImpl(_ *receivedPacket) {
|
|
||||||
s.counter++
|
|
||||||
// exponential backoff
|
// exponential backoff
|
||||||
// only send a CONNECTION_CLOSE for the 1st, 2nd, 4th, 8th, 16th, ... packet arriving
|
// only send a CONNECTION_CLOSE for the 1st, 2nd, 4th, 8th, 16th, ... packet arriving
|
||||||
if bits.OnesCount64(s.counter) != 1 {
|
if bits.OnesCount32(c.counter) != 1 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
s.logger.Debugf("Received %d packets after sending CONNECTION_CLOSE. Retransmitting.", s.counter)
|
c.logger.Debugf("Received %d packets after sending CONNECTION_CLOSE. Retransmitting.", c.counter)
|
||||||
if err := s.conn.Write(s.connClosePacket); err != nil {
|
c.sendPacket(p.remoteAddr, p.info)
|
||||||
s.logger.Debugf("Error retransmitting CONNECTION_CLOSE: %s", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *closedLocalConn) shutdown() {
|
func (c *closedLocalConn) shutdown() {}
|
||||||
s.destroy(nil)
|
func (c *closedLocalConn) destroy(error) {}
|
||||||
}
|
func (c *closedLocalConn) getPerspective() protocol.Perspective { return c.perspective }
|
||||||
|
|
||||||
func (s *closedLocalConn) destroy(error) {
|
|
||||||
s.closeOnce.Do(func() {
|
|
||||||
close(s.closeChan)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *closedLocalConn) getPerspective() protocol.Perspective {
|
|
||||||
return s.perspective
|
|
||||||
}
|
|
||||||
|
|
||||||
// A closedRemoteConn is a connection that was closed remotely.
|
// A closedRemoteConn is a connection that was closed remotely.
|
||||||
// For such a connection, we might receive reordered packets that were sent before the CONNECTION_CLOSE.
|
// For such a connection, we might receive reordered packets that were sent before the CONNECTION_CLOSE.
|
||||||
|
|
|
@ -1,10 +1,8 @@
|
||||||
package quic
|
package quic
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"net"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/golang/mock/gomock"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||||
|
|
||||||
|
@ -13,44 +11,28 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
var _ = Describe("Closed local connection", func() {
|
var _ = Describe("Closed local connection", func() {
|
||||||
var (
|
|
||||||
conn packetHandler
|
|
||||||
mconn *MockSendConn
|
|
||||||
)
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
mconn = NewMockSendConn(mockCtrl)
|
|
||||||
conn = newClosedLocalConn(mconn, []byte("close"), protocol.PerspectiveClient, utils.DefaultLogger)
|
|
||||||
})
|
|
||||||
|
|
||||||
AfterEach(func() {
|
|
||||||
Eventually(areClosedConnsRunning).Should(BeFalse())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("tells its perspective", func() {
|
It("tells its perspective", func() {
|
||||||
|
conn := newClosedLocalConn(nil, protocol.PerspectiveClient, utils.DefaultLogger)
|
||||||
Expect(conn.getPerspective()).To(Equal(protocol.PerspectiveClient))
|
Expect(conn.getPerspective()).To(Equal(protocol.PerspectiveClient))
|
||||||
// stop the connection
|
// stop the connection
|
||||||
conn.shutdown()
|
conn.shutdown()
|
||||||
})
|
})
|
||||||
|
|
||||||
It("repeats the packet containing the CONNECTION_CLOSE frame", func() {
|
It("repeats the packet containing the CONNECTION_CLOSE frame", func() {
|
||||||
written := make(chan []byte)
|
written := make(chan net.Addr, 1)
|
||||||
mconn.EXPECT().Write(gomock.Any()).Do(func(p []byte) { written <- p }).AnyTimes()
|
conn := newClosedLocalConn(
|
||||||
|
func(addr net.Addr, _ *packetInfo) { written <- addr },
|
||||||
|
protocol.PerspectiveClient,
|
||||||
|
utils.DefaultLogger,
|
||||||
|
)
|
||||||
|
addr := &net.UDPAddr{IP: net.IPv4(127, 1, 2, 3), Port: 1337}
|
||||||
for i := 1; i <= 20; i++ {
|
for i := 1; i <= 20; i++ {
|
||||||
conn.handlePacket(&receivedPacket{})
|
conn.handlePacket(&receivedPacket{remoteAddr: addr})
|
||||||
if i == 1 || i == 2 || i == 4 || i == 8 || i == 16 {
|
if i == 1 || i == 2 || i == 4 || i == 8 || i == 16 {
|
||||||
Eventually(written).Should(Receive(Equal([]byte("close")))) // receive the CONNECTION_CLOSE
|
Expect(written).To(Receive(Equal(addr))) // receive the CONNECTION_CLOSE
|
||||||
} else {
|
} else {
|
||||||
Consistently(written, 10*time.Millisecond).Should(HaveLen(0))
|
Expect(written).ToNot(Receive())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// stop the connection
|
|
||||||
conn.shutdown()
|
|
||||||
})
|
|
||||||
|
|
||||||
It("destroys connections", func() {
|
|
||||||
Eventually(areClosedConnsRunning).Should(BeTrue())
|
|
||||||
conn.destroy(errors.New("destroy"))
|
|
||||||
Eventually(areClosedConnsRunning).Should(BeFalse())
|
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
|
@ -20,7 +20,7 @@ type connIDGenerator struct {
|
||||||
getStatelessResetToken func(protocol.ConnectionID) protocol.StatelessResetToken
|
getStatelessResetToken func(protocol.ConnectionID) protocol.StatelessResetToken
|
||||||
removeConnectionID func(protocol.ConnectionID)
|
removeConnectionID func(protocol.ConnectionID)
|
||||||
retireConnectionID func(protocol.ConnectionID)
|
retireConnectionID func(protocol.ConnectionID)
|
||||||
replaceWithClosed func([]protocol.ConnectionID, packetHandler)
|
replaceWithClosed func([]protocol.ConnectionID, protocol.Perspective, []byte)
|
||||||
queueControlFrame func(wire.Frame)
|
queueControlFrame func(wire.Frame)
|
||||||
|
|
||||||
version protocol.VersionNumber
|
version protocol.VersionNumber
|
||||||
|
@ -33,7 +33,7 @@ func newConnIDGenerator(
|
||||||
getStatelessResetToken func(protocol.ConnectionID) protocol.StatelessResetToken,
|
getStatelessResetToken func(protocol.ConnectionID) protocol.StatelessResetToken,
|
||||||
removeConnectionID func(protocol.ConnectionID),
|
removeConnectionID func(protocol.ConnectionID),
|
||||||
retireConnectionID func(protocol.ConnectionID),
|
retireConnectionID func(protocol.ConnectionID),
|
||||||
replaceWithClosed func([]protocol.ConnectionID, packetHandler),
|
replaceWithClosed func([]protocol.ConnectionID, protocol.Perspective, []byte),
|
||||||
queueControlFrame func(wire.Frame),
|
queueControlFrame func(wire.Frame),
|
||||||
version protocol.VersionNumber,
|
version protocol.VersionNumber,
|
||||||
) *connIDGenerator {
|
) *connIDGenerator {
|
||||||
|
@ -130,7 +130,7 @@ func (m *connIDGenerator) RemoveAll() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *connIDGenerator) ReplaceWithClosed(handler packetHandler) {
|
func (m *connIDGenerator) ReplaceWithClosed(pers protocol.Perspective, connClose []byte) {
|
||||||
connIDs := make([]protocol.ConnectionID, 0, len(m.activeSrcConnIDs)+1)
|
connIDs := make([]protocol.ConnectionID, 0, len(m.activeSrcConnIDs)+1)
|
||||||
if m.initialClientDestConnID != nil {
|
if m.initialClientDestConnID != nil {
|
||||||
connIDs = append(connIDs, m.initialClientDestConnID)
|
connIDs = append(connIDs, m.initialClientDestConnID)
|
||||||
|
@ -138,5 +138,5 @@ func (m *connIDGenerator) ReplaceWithClosed(handler packetHandler) {
|
||||||
for _, connID := range m.activeSrcConnIDs {
|
for _, connID := range m.activeSrcConnIDs {
|
||||||
connIDs = append(connIDs, connID)
|
connIDs = append(connIDs, connID)
|
||||||
}
|
}
|
||||||
m.replaceWithClosed(connIDs, handler)
|
m.replaceWithClosed(connIDs, pers, connClose)
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,7 +16,7 @@ var _ = Describe("Connection ID Generator", func() {
|
||||||
addedConnIDs []protocol.ConnectionID
|
addedConnIDs []protocol.ConnectionID
|
||||||
retiredConnIDs []protocol.ConnectionID
|
retiredConnIDs []protocol.ConnectionID
|
||||||
removedConnIDs []protocol.ConnectionID
|
removedConnIDs []protocol.ConnectionID
|
||||||
replacedWithClosed map[string]packetHandler
|
replacedWithClosed []protocol.ConnectionID
|
||||||
queuedFrames []wire.Frame
|
queuedFrames []wire.Frame
|
||||||
g *connIDGenerator
|
g *connIDGenerator
|
||||||
)
|
)
|
||||||
|
@ -32,7 +32,7 @@ var _ = Describe("Connection ID Generator", func() {
|
||||||
retiredConnIDs = nil
|
retiredConnIDs = nil
|
||||||
removedConnIDs = nil
|
removedConnIDs = nil
|
||||||
queuedFrames = nil
|
queuedFrames = nil
|
||||||
replacedWithClosed = make(map[string]packetHandler)
|
replacedWithClosed = nil
|
||||||
g = newConnIDGenerator(
|
g = newConnIDGenerator(
|
||||||
initialConnID,
|
initialConnID,
|
||||||
initialClientDestConnID,
|
initialClientDestConnID,
|
||||||
|
@ -40,10 +40,8 @@ var _ = Describe("Connection ID Generator", func() {
|
||||||
connIDToToken,
|
connIDToToken,
|
||||||
func(c protocol.ConnectionID) { removedConnIDs = append(removedConnIDs, c) },
|
func(c protocol.ConnectionID) { removedConnIDs = append(removedConnIDs, c) },
|
||||||
func(c protocol.ConnectionID) { retiredConnIDs = append(retiredConnIDs, c) },
|
func(c protocol.ConnectionID) { retiredConnIDs = append(retiredConnIDs, c) },
|
||||||
func(cs []protocol.ConnectionID, h packetHandler) {
|
func(cs []protocol.ConnectionID, _ protocol.Perspective, _ []byte) {
|
||||||
for _, c := range cs {
|
replacedWithClosed = append(replacedWithClosed, cs...)
|
||||||
replacedWithClosed[string(c)] = h
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
func(f wire.Frame) { queuedFrames = append(queuedFrames, f) },
|
func(f wire.Frame) { queuedFrames = append(queuedFrames, f) },
|
||||||
protocol.VersionDraft29,
|
protocol.VersionDraft29,
|
||||||
|
@ -178,14 +176,13 @@ var _ = Describe("Connection ID Generator", func() {
|
||||||
It("replaces with a closed connection for all connection IDs", func() {
|
It("replaces with a closed connection for all connection IDs", func() {
|
||||||
Expect(g.SetMaxActiveConnIDs(5)).To(Succeed())
|
Expect(g.SetMaxActiveConnIDs(5)).To(Succeed())
|
||||||
Expect(queuedFrames).To(HaveLen(4))
|
Expect(queuedFrames).To(HaveLen(4))
|
||||||
sess := NewMockPacketHandler(mockCtrl)
|
g.ReplaceWithClosed(protocol.PerspectiveClient, []byte("foobar"))
|
||||||
g.ReplaceWithClosed(sess)
|
|
||||||
Expect(replacedWithClosed).To(HaveLen(6)) // initial conn ID, initial client dest conn id, and newly issued ones
|
Expect(replacedWithClosed).To(HaveLen(6)) // initial conn ID, initial client dest conn id, and newly issued ones
|
||||||
Expect(replacedWithClosed).To(HaveKeyWithValue(string(initialClientDestConnID), sess))
|
Expect(replacedWithClosed).To(ContainElement(initialClientDestConnID))
|
||||||
Expect(replacedWithClosed).To(HaveKeyWithValue(string(initialConnID), sess))
|
Expect(replacedWithClosed).To(ContainElement(initialConnID))
|
||||||
for _, f := range queuedFrames {
|
for _, f := range queuedFrames {
|
||||||
nf := f.(*wire.NewConnectionIDFrame)
|
nf := f.(*wire.NewConnectionIDFrame)
|
||||||
Expect(replacedWithClosed).To(HaveKeyWithValue(string(nf.ConnectionID), sess))
|
Expect(replacedWithClosed).To(ContainElement(nf.ConnectionID))
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
|
@ -95,7 +95,7 @@ type connRunner interface {
|
||||||
GetStatelessResetToken(protocol.ConnectionID) protocol.StatelessResetToken
|
GetStatelessResetToken(protocol.ConnectionID) protocol.StatelessResetToken
|
||||||
Retire(protocol.ConnectionID)
|
Retire(protocol.ConnectionID)
|
||||||
Remove(protocol.ConnectionID)
|
Remove(protocol.ConnectionID)
|
||||||
ReplaceWithClosed([]protocol.ConnectionID, packetHandler)
|
ReplaceWithClosed([]protocol.ConnectionID, protocol.Perspective, []byte)
|
||||||
AddResetToken(protocol.StatelessResetToken, packetHandler)
|
AddResetToken(protocol.StatelessResetToken, packetHandler)
|
||||||
RemoveResetToken(protocol.StatelessResetToken)
|
RemoveResetToken(protocol.StatelessResetToken)
|
||||||
}
|
}
|
||||||
|
@ -1521,7 +1521,7 @@ func (s *connection) handleCloseError(closeErr *closeError) {
|
||||||
|
|
||||||
// If this is a remote close we're done here
|
// If this is a remote close we're done here
|
||||||
if closeErr.remote {
|
if closeErr.remote {
|
||||||
s.connIDGenerator.ReplaceWithClosed(newClosedRemoteConn(s.perspective))
|
s.connIDGenerator.ReplaceWithClosed(s.perspective, nil)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if closeErr.immediate {
|
if closeErr.immediate {
|
||||||
|
@ -1538,8 +1538,7 @@ func (s *connection) handleCloseError(closeErr *closeError) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.logger.Debugf("Error sending CONNECTION_CLOSE: %s", err)
|
s.logger.Debugf("Error sending CONNECTION_CLOSE: %s", err)
|
||||||
}
|
}
|
||||||
cs := newClosedLocalConn(s.conn, connClosePacket, s.perspective, s.logger)
|
s.connIDGenerator.ReplaceWithClosed(s.perspective, connClosePacket)
|
||||||
s.connIDGenerator.ReplaceWithClosed(cs)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *connection) dropEncryptionLevel(encLevel protocol.EncryptionLevel) {
|
func (s *connection) dropEncryptionLevel(encLevel protocol.EncryptionLevel) {
|
||||||
|
|
|
@ -37,12 +37,6 @@ func areConnsRunning() bool {
|
||||||
return strings.Contains(b.String(), "quic-go.(*connection).run")
|
return strings.Contains(b.String(), "quic-go.(*connection).run")
|
||||||
}
|
}
|
||||||
|
|
||||||
func areClosedConnsRunning() bool {
|
|
||||||
var b bytes.Buffer
|
|
||||||
pprof.Lookup("goroutine").WriteTo(&b, 1)
|
|
||||||
return strings.Contains(b.String(), "quic-go.(*closedLocalConn).run")
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ = Describe("Connection", func() {
|
var _ = Describe("Connection", func() {
|
||||||
var (
|
var (
|
||||||
conn *connection
|
conn *connection
|
||||||
|
@ -72,14 +66,11 @@ var _ = Describe("Connection", func() {
|
||||||
}
|
}
|
||||||
|
|
||||||
expectReplaceWithClosed := func() {
|
expectReplaceWithClosed := func() {
|
||||||
connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any()).Do(func(connIDs []protocol.ConnectionID, s packetHandler) {
|
connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(connIDs []protocol.ConnectionID, _ protocol.Perspective, _ []byte) {
|
||||||
Expect(connIDs).To(ContainElement(srcConnID))
|
Expect(connIDs).To(ContainElement(srcConnID))
|
||||||
if len(connIDs) > 1 {
|
if len(connIDs) > 1 {
|
||||||
Expect(connIDs).To(ContainElement(clientDestConnID))
|
Expect(connIDs).To(ContainElement(clientDestConnID))
|
||||||
}
|
}
|
||||||
Expect(s).To(BeAssignableToTypeOf(&closedLocalConn{}))
|
|
||||||
s.shutdown()
|
|
||||||
Eventually(areClosedConnsRunning).Should(BeFalse())
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -333,9 +324,8 @@ var _ = Describe("Connection", func() {
|
||||||
ErrorMessage: "foobar",
|
ErrorMessage: "foobar",
|
||||||
}
|
}
|
||||||
streamManager.EXPECT().CloseWithError(expectedErr)
|
streamManager.EXPECT().CloseWithError(expectedErr)
|
||||||
connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any()).Do(func(connIDs []protocol.ConnectionID, s packetHandler) {
|
connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(connIDs []protocol.ConnectionID, _ protocol.Perspective, _ []byte) {
|
||||||
Expect(connIDs).To(ConsistOf(clientDestConnID, srcConnID))
|
Expect(connIDs).To(ConsistOf(clientDestConnID, srcConnID))
|
||||||
Expect(s).To(BeAssignableToTypeOf(&closedRemoteConn{}))
|
|
||||||
})
|
})
|
||||||
cryptoSetup.EXPECT().Close()
|
cryptoSetup.EXPECT().Close()
|
||||||
gomock.InOrder(
|
gomock.InOrder(
|
||||||
|
@ -362,9 +352,8 @@ var _ = Describe("Connection", func() {
|
||||||
ErrorMessage: "foobar",
|
ErrorMessage: "foobar",
|
||||||
}
|
}
|
||||||
streamManager.EXPECT().CloseWithError(testErr)
|
streamManager.EXPECT().CloseWithError(testErr)
|
||||||
connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any()).Do(func(connIDs []protocol.ConnectionID, s packetHandler) {
|
connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(connIDs []protocol.ConnectionID, _ protocol.Perspective, _ []byte) {
|
||||||
Expect(connIDs).To(ConsistOf(clientDestConnID, srcConnID))
|
Expect(connIDs).To(ConsistOf(clientDestConnID, srcConnID))
|
||||||
Expect(s).To(BeAssignableToTypeOf(&closedRemoteConn{}))
|
|
||||||
})
|
})
|
||||||
cryptoSetup.EXPECT().Close()
|
cryptoSetup.EXPECT().Close()
|
||||||
gomock.InOrder(
|
gomock.InOrder(
|
||||||
|
@ -564,7 +553,7 @@ var _ = Describe("Connection", func() {
|
||||||
runConn()
|
runConn()
|
||||||
cryptoSetup.EXPECT().Close()
|
cryptoSetup.EXPECT().Close()
|
||||||
streamManager.EXPECT().CloseWithError(gomock.Any())
|
streamManager.EXPECT().CloseWithError(gomock.Any())
|
||||||
connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any()).AnyTimes()
|
connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
|
||||||
buf := &bytes.Buffer{}
|
buf := &bytes.Buffer{}
|
||||||
hdr := &wire.ExtendedHeader{
|
hdr := &wire.ExtendedHeader{
|
||||||
Header: wire.Header{DestConnectionID: srcConnID},
|
Header: wire.Header{DestConnectionID: srcConnID},
|
||||||
|
@ -2432,10 +2421,7 @@ var _ = Describe("Client Connection", func() {
|
||||||
}
|
}
|
||||||
|
|
||||||
expectReplaceWithClosed := func() {
|
expectReplaceWithClosed := func() {
|
||||||
connRunner.EXPECT().ReplaceWithClosed([]protocol.ConnectionID{srcConnID}, gomock.Any()).Do(func(_ []protocol.ConnectionID, s packetHandler) {
|
connRunner.EXPECT().ReplaceWithClosed([]protocol.ConnectionID{srcConnID}, gomock.Any(), gomock.Any())
|
||||||
s.shutdown()
|
|
||||||
Eventually(areClosedConnsRunning).Should(BeFalse())
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
|
@ -2766,10 +2752,7 @@ var _ = Describe("Client Connection", func() {
|
||||||
|
|
||||||
expectClose := func(applicationClose bool) {
|
expectClose := func(applicationClose bool) {
|
||||||
if !closed {
|
if !closed {
|
||||||
connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any()).Do(func(_ []protocol.ConnectionID, s packetHandler) {
|
connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any(), gomock.Any())
|
||||||
Expect(s).To(BeAssignableToTypeOf(&closedLocalConn{}))
|
|
||||||
s.shutdown()
|
|
||||||
})
|
|
||||||
if applicationClose {
|
if applicationClose {
|
||||||
packer.EXPECT().PackApplicationClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil).MaxTimes(1)
|
packer.EXPECT().PackApplicationClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil).MaxTimes(1)
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -99,15 +99,15 @@ func (mr *MockConnRunnerMockRecorder) RemoveResetToken(arg0 interface{}) *gomock
|
||||||
}
|
}
|
||||||
|
|
||||||
// ReplaceWithClosed mocks base method.
|
// ReplaceWithClosed mocks base method.
|
||||||
func (m *MockConnRunner) ReplaceWithClosed(arg0 []protocol.ConnectionID, arg1 packetHandler) {
|
func (m *MockConnRunner) ReplaceWithClosed(arg0 []protocol.ConnectionID, arg1 protocol.Perspective, arg2 []byte) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
m.ctrl.Call(m, "ReplaceWithClosed", arg0, arg1)
|
m.ctrl.Call(m, "ReplaceWithClosed", arg0, arg1, arg2)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ReplaceWithClosed indicates an expected call of ReplaceWithClosed.
|
// ReplaceWithClosed indicates an expected call of ReplaceWithClosed.
|
||||||
func (mr *MockConnRunnerMockRecorder) ReplaceWithClosed(arg0, arg1 interface{}) *gomock.Call {
|
func (mr *MockConnRunnerMockRecorder) ReplaceWithClosed(arg0, arg1, arg2 interface{}) *gomock.Call {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReplaceWithClosed", reflect.TypeOf((*MockConnRunner)(nil).ReplaceWithClosed), arg0, arg1)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReplaceWithClosed", reflect.TypeOf((*MockConnRunner)(nil).ReplaceWithClosed), arg0, arg1, arg2)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Retire mocks base method.
|
// Retire mocks base method.
|
||||||
|
|
|
@ -139,15 +139,15 @@ func (mr *MockPacketHandlerManagerMockRecorder) RemoveResetToken(arg0 interface{
|
||||||
}
|
}
|
||||||
|
|
||||||
// ReplaceWithClosed mocks base method.
|
// ReplaceWithClosed mocks base method.
|
||||||
func (m *MockPacketHandlerManager) ReplaceWithClosed(arg0 []protocol.ConnectionID, arg1 packetHandler) {
|
func (m *MockPacketHandlerManager) ReplaceWithClosed(arg0 []protocol.ConnectionID, arg1 protocol.Perspective, arg2 []byte) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
m.ctrl.Call(m, "ReplaceWithClosed", arg0, arg1)
|
m.ctrl.Call(m, "ReplaceWithClosed", arg0, arg1, arg2)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ReplaceWithClosed indicates an expected call of ReplaceWithClosed.
|
// ReplaceWithClosed indicates an expected call of ReplaceWithClosed.
|
||||||
func (mr *MockPacketHandlerManagerMockRecorder) ReplaceWithClosed(arg0, arg1 interface{}) *gomock.Call {
|
func (mr *MockPacketHandlerManagerMockRecorder) ReplaceWithClosed(arg0, arg1, arg2 interface{}) *gomock.Call {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReplaceWithClosed", reflect.TypeOf((*MockPacketHandlerManager)(nil).ReplaceWithClosed), arg0, arg1)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReplaceWithClosed", reflect.TypeOf((*MockPacketHandlerManager)(nil).ReplaceWithClosed), arg0, arg1, arg2)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Retire mocks base method.
|
// Retire mocks base method.
|
||||||
|
|
|
@ -30,6 +30,12 @@ type rawConn interface {
|
||||||
io.Closer
|
io.Closer
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type closePacket struct {
|
||||||
|
payload []byte
|
||||||
|
addr net.Addr
|
||||||
|
info *packetInfo
|
||||||
|
}
|
||||||
|
|
||||||
// The packetHandlerMap stores packetHandlers, identified by connection ID.
|
// The packetHandlerMap stores packetHandlers, identified by connection ID.
|
||||||
// It is used:
|
// It is used:
|
||||||
// * by the server to store connections
|
// * by the server to store connections
|
||||||
|
@ -40,6 +46,8 @@ type packetHandlerMap struct {
|
||||||
conn rawConn
|
conn rawConn
|
||||||
connIDLen int
|
connIDLen int
|
||||||
|
|
||||||
|
closeQueue chan closePacket
|
||||||
|
|
||||||
handlers map[string] /* string(ConnectionID)*/ packetHandler
|
handlers map[string] /* string(ConnectionID)*/ packetHandler
|
||||||
resetTokens map[protocol.StatelessResetToken] /* stateless reset token */ packetHandler
|
resetTokens map[protocol.StatelessResetToken] /* stateless reset token */ packetHandler
|
||||||
server unknownPacketHandler
|
server unknownPacketHandler
|
||||||
|
@ -123,12 +131,14 @@ func newPacketHandlerMap(
|
||||||
resetTokens: make(map[protocol.StatelessResetToken]packetHandler),
|
resetTokens: make(map[protocol.StatelessResetToken]packetHandler),
|
||||||
deleteRetiredConnsAfter: protocol.RetiredConnectionIDDeleteTimeout,
|
deleteRetiredConnsAfter: protocol.RetiredConnectionIDDeleteTimeout,
|
||||||
zeroRTTQueueDuration: protocol.Max0RTTQueueingDuration,
|
zeroRTTQueueDuration: protocol.Max0RTTQueueingDuration,
|
||||||
|
closeQueue: make(chan closePacket, 4),
|
||||||
statelessResetEnabled: len(statelessResetKey) > 0,
|
statelessResetEnabled: len(statelessResetKey) > 0,
|
||||||
statelessResetHasher: hmac.New(sha256.New, statelessResetKey),
|
statelessResetHasher: hmac.New(sha256.New, statelessResetKey),
|
||||||
tracer: tracer,
|
tracer: tracer,
|
||||||
logger: logger,
|
logger: logger,
|
||||||
}
|
}
|
||||||
go m.listen()
|
go m.listen()
|
||||||
|
go m.runCloseQueue()
|
||||||
|
|
||||||
if logger.Debug() {
|
if logger.Debug() {
|
||||||
go m.logUsage()
|
go m.logUsage()
|
||||||
|
@ -219,7 +229,29 @@ func (h *packetHandlerMap) Retire(id protocol.ConnectionID) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *packetHandlerMap) ReplaceWithClosed(ids []protocol.ConnectionID, handler packetHandler) {
|
// ReplaceWithClosed is called when a connection is closed.
|
||||||
|
// Depending on which side closed the connection, we need to:
|
||||||
|
// * remote close: absorb delayed packets
|
||||||
|
// * local close: retransmit the CONNECTION_CLOSE packet, in case it was lost
|
||||||
|
func (h *packetHandlerMap) ReplaceWithClosed(ids []protocol.ConnectionID, pers protocol.Perspective, connClosePacket []byte) {
|
||||||
|
var handler packetHandler
|
||||||
|
if connClosePacket != nil {
|
||||||
|
handler = newClosedLocalConn(
|
||||||
|
func(addr net.Addr, info *packetInfo) {
|
||||||
|
select {
|
||||||
|
case h.closeQueue <- closePacket{payload: connClosePacket, addr: addr, info: info}:
|
||||||
|
default:
|
||||||
|
// Oops, we're backlogged.
|
||||||
|
// Just drop the packet, sending CONNECTION_CLOSE copies is best effort anyway.
|
||||||
|
}
|
||||||
|
},
|
||||||
|
pers,
|
||||||
|
h.logger,
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
handler = newClosedRemoteConn(pers)
|
||||||
|
}
|
||||||
|
|
||||||
h.mutex.Lock()
|
h.mutex.Lock()
|
||||||
for _, id := range ids {
|
for _, id := range ids {
|
||||||
h.handlers[string(id)] = handler
|
h.handlers[string(id)] = handler
|
||||||
|
@ -238,6 +270,17 @@ func (h *packetHandlerMap) ReplaceWithClosed(ids []protocol.ConnectionID, handle
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *packetHandlerMap) runCloseQueue() {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-h.listening:
|
||||||
|
return
|
||||||
|
case p := <-h.closeQueue:
|
||||||
|
h.conn.WritePacket(p.payload, p.addr, p.info.OOB())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (h *packetHandlerMap) AddResetToken(token protocol.StatelessResetToken, handler packetHandler) {
|
func (h *packetHandlerMap) AddResetToken(token protocol.StatelessResetToken, handler packetHandler) {
|
||||||
h.mutex.Lock()
|
h.mutex.Lock()
|
||||||
h.resetTokens[token] = handler
|
h.resetTokens[token] = handler
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue