mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-03 20:27:35 +03:00
Merge pull request #3788 from quic-go/zero-rtt-queue-server
move 0-RTT queue handling from the packet handler map to the server
This commit is contained in:
commit
2a03c6cd79
7 changed files with 413 additions and 205 deletions
|
@ -7,7 +7,6 @@ import (
|
|||
"io"
|
||||
mrand "math/rand"
|
||||
"net"
|
||||
"sort"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
@ -214,7 +213,7 @@ var _ = Describe("0-RTT", func() {
|
|||
ln,
|
||||
proxy.LocalPort(),
|
||||
clientTLSConf,
|
||||
&quic.Config{ConnectionIDLength: connIDLen},
|
||||
getQuicConfig(&quic.Config{ConnectionIDLength: connIDLen}),
|
||||
PRData,
|
||||
)
|
||||
|
||||
|
@ -237,8 +236,7 @@ var _ = Describe("0-RTT", func() {
|
|||
Expect(num0RTT).ToNot(BeZero())
|
||||
zeroRTTPackets := get0RTTPackets(tracer.getRcvdLongHeaderPackets())
|
||||
Expect(len(zeroRTTPackets)).To(BeNumerically(">", 10))
|
||||
sort.Slice(zeroRTTPackets, func(i, j int) bool { return zeroRTTPackets[i] < zeroRTTPackets[j] })
|
||||
Expect(zeroRTTPackets[0]).To(Equal(protocol.PacketNumber(0)))
|
||||
Expect(zeroRTTPackets).To(ContainElement(protocol.PacketNumber(0)))
|
||||
})
|
||||
}
|
||||
|
||||
|
|
|
@ -100,6 +100,21 @@ func (mr *MockPacketHandlerManagerMockRecorder) Destroy() *gomock.Call {
|
|||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Destroy", reflect.TypeOf((*MockPacketHandlerManager)(nil).Destroy))
|
||||
}
|
||||
|
||||
// Get mocks base method.
|
||||
func (m *MockPacketHandlerManager) Get(arg0 protocol.ConnectionID) (packetHandler, bool) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Get", arg0)
|
||||
ret0, _ := ret[0].(packetHandler)
|
||||
ret1, _ := ret[1].(bool)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Get indicates an expected call of Get.
|
||||
func (mr *MockPacketHandlerManagerMockRecorder) Get(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockPacketHandlerManager)(nil).Get), arg0)
|
||||
}
|
||||
|
||||
// GetStatelessResetToken mocks base method.
|
||||
func (m *MockPacketHandlerManager) GetStatelessResetToken(arg0 protocol.ConnectionID) protocol.StatelessResetToken {
|
||||
m.ctrl.T.Helper()
|
||||
|
|
|
@ -48,16 +48,14 @@ type packetHandlerMap struct {
|
|||
|
||||
closeQueue chan closePacket
|
||||
|
||||
handlers map[protocol.ConnectionID]packetHandler
|
||||
resetTokens map[protocol.StatelessResetToken] /* stateless reset token */ packetHandler
|
||||
server unknownPacketHandler
|
||||
numZeroRTTEntries int
|
||||
handlers map[protocol.ConnectionID]packetHandler
|
||||
resetTokens map[protocol.StatelessResetToken] /* stateless reset token */ packetHandler
|
||||
server unknownPacketHandler
|
||||
|
||||
listening chan struct{} // is closed when listen returns
|
||||
closed bool
|
||||
|
||||
deleteRetiredConnsAfter time.Duration
|
||||
zeroRTTQueueDuration time.Duration
|
||||
|
||||
statelessResetEnabled bool
|
||||
statelessResetMutex sync.Mutex
|
||||
|
@ -130,7 +128,6 @@ func newPacketHandlerMap(
|
|||
handlers: make(map[protocol.ConnectionID]packetHandler),
|
||||
resetTokens: make(map[protocol.StatelessResetToken]packetHandler),
|
||||
deleteRetiredConnsAfter: protocol.RetiredConnectionIDDeleteTimeout,
|
||||
zeroRTTQueueDuration: protocol.Max0RTTQueueingDuration,
|
||||
closeQueue: make(chan closePacket, 4),
|
||||
statelessResetEnabled: statelessResetKey != nil,
|
||||
tracer: tracer,
|
||||
|
@ -174,6 +171,14 @@ func (h *packetHandlerMap) logUsage() {
|
|||
}
|
||||
}
|
||||
|
||||
func (h *packetHandlerMap) Get(id protocol.ConnectionID) (packetHandler, bool) {
|
||||
h.mutex.Lock()
|
||||
defer h.mutex.Unlock()
|
||||
|
||||
handler, ok := h.handlers[id]
|
||||
return handler, ok
|
||||
}
|
||||
|
||||
func (h *packetHandlerMap) Add(id protocol.ConnectionID, handler packetHandler) bool /* was added */ {
|
||||
h.mutex.Lock()
|
||||
defer h.mutex.Unlock()
|
||||
|
@ -191,23 +196,11 @@ func (h *packetHandlerMap) AddWithConnID(clientDestConnID, newConnID protocol.Co
|
|||
h.mutex.Lock()
|
||||
defer h.mutex.Unlock()
|
||||
|
||||
var q *zeroRTTQueue
|
||||
if handler, ok := h.handlers[clientDestConnID]; ok {
|
||||
q, ok = handler.(*zeroRTTQueue)
|
||||
if !ok {
|
||||
h.logger.Debugf("Not adding connection ID %s for a new connection, as it already exists.", clientDestConnID)
|
||||
return false
|
||||
}
|
||||
q.retireTimer.Stop()
|
||||
h.numZeroRTTEntries--
|
||||
if h.numZeroRTTEntries < 0 {
|
||||
panic("number of 0-RTT queues < 0")
|
||||
}
|
||||
if _, ok := h.handlers[clientDestConnID]; ok {
|
||||
h.logger.Debugf("Not adding connection ID %s for a new connection, as it already exists.", clientDestConnID)
|
||||
return false
|
||||
}
|
||||
conn := fn()
|
||||
if q != nil {
|
||||
q.EnqueueAll(conn)
|
||||
}
|
||||
h.handlers[clientDestConnID] = conn
|
||||
h.handlers[newConnID] = conn
|
||||
h.logger.Debugf("Adding connection IDs %s and %s for a new connection.", clientDestConnID, newConnID)
|
||||
|
@ -395,17 +388,9 @@ func (h *packetHandlerMap) handlePacket(p *receivedPacket) {
|
|||
if isStatelessReset := h.maybeHandleStatelessReset(p.data); isStatelessReset {
|
||||
return
|
||||
}
|
||||
|
||||
if handler, ok := h.handlers[connID]; ok {
|
||||
if ha, ok := handler.(*zeroRTTQueue); ok { // only enqueue 0-RTT packets in the 0-RTT queue
|
||||
if wire.Is0RTTPacket(p.data) {
|
||||
ha.handlePacket(p)
|
||||
return
|
||||
}
|
||||
} else { // existing connection
|
||||
handler.handlePacket(p)
|
||||
return
|
||||
}
|
||||
handler.handlePacket(p)
|
||||
return
|
||||
}
|
||||
if !wire.IsLongHeaderPacket(p.data[0]) {
|
||||
go h.maybeSendStatelessReset(p, connID)
|
||||
|
@ -415,38 +400,6 @@ func (h *packetHandlerMap) handlePacket(p *receivedPacket) {
|
|||
h.logger.Debugf("received a packet with an unexpected connection ID %s", connID)
|
||||
return
|
||||
}
|
||||
if wire.Is0RTTPacket(p.data) {
|
||||
if h.numZeroRTTEntries >= protocol.Max0RTTQueues {
|
||||
if h.tracer != nil {
|
||||
h.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention)
|
||||
}
|
||||
return
|
||||
}
|
||||
h.numZeroRTTEntries++
|
||||
queue := &zeroRTTQueue{queue: make([]*receivedPacket, 0, 8)}
|
||||
h.handlers[connID] = queue
|
||||
queue.retireTimer = time.AfterFunc(h.zeroRTTQueueDuration, func() {
|
||||
h.mutex.Lock()
|
||||
defer h.mutex.Unlock()
|
||||
// The entry might have been replaced by an actual connection.
|
||||
// Only delete it if it's still a 0-RTT queue.
|
||||
if handler, ok := h.handlers[connID]; ok {
|
||||
if q, ok := handler.(*zeroRTTQueue); ok {
|
||||
delete(h.handlers, connID)
|
||||
h.numZeroRTTEntries--
|
||||
if h.numZeroRTTEntries < 0 {
|
||||
panic("number of 0-RTT queues < 0")
|
||||
}
|
||||
q.Clear()
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("Removing 0-RTT queue for %s.", connID)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
queue.handlePacket(p)
|
||||
return
|
||||
}
|
||||
h.server.handlePacket(p)
|
||||
}
|
||||
|
||||
|
|
|
@ -271,93 +271,6 @@ var _ = Describe("Packet Handler Map", func() {
|
|||
})
|
||||
})
|
||||
|
||||
Context("0-RTT", func() {
|
||||
JustBeforeEach(func() {
|
||||
handler.zeroRTTQueueDuration = time.Hour
|
||||
server := NewMockUnknownPacketHandler(mockCtrl)
|
||||
// we don't expect any calls to server.handlePacket
|
||||
handler.SetServer(server)
|
||||
})
|
||||
|
||||
It("queues 0-RTT packets", func() {
|
||||
server := NewMockUnknownPacketHandler(mockCtrl)
|
||||
// don't EXPECT any calls to server.handlePacket
|
||||
handler.SetServer(server)
|
||||
connID := protocol.ParseConnectionID([]byte{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88})
|
||||
p1 := &receivedPacket{data: getPacketWithPacketType(connID, protocol.PacketType0RTT, 1)}
|
||||
p2 := &receivedPacket{data: getPacketWithPacketType(connID, protocol.PacketType0RTT, 2)}
|
||||
p3 := &receivedPacket{data: getPacketWithPacketType(connID, protocol.PacketType0RTT, 3)}
|
||||
handler.handlePacket(p1)
|
||||
handler.handlePacket(p2)
|
||||
handler.handlePacket(p3)
|
||||
conn := NewMockPacketHandler(mockCtrl)
|
||||
done := make(chan struct{})
|
||||
gomock.InOrder(
|
||||
conn.EXPECT().handlePacket(p1),
|
||||
conn.EXPECT().handlePacket(p2),
|
||||
conn.EXPECT().handlePacket(p3).Do(func(packet *receivedPacket) { close(done) }),
|
||||
)
|
||||
handler.AddWithConnID(connID, protocol.ParseConnectionID([]byte{1, 2, 3, 4}), func() packetHandler { return conn })
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("directs 0-RTT packets to existing connections", func() {
|
||||
connID := protocol.ParseConnectionID([]byte{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88})
|
||||
conn := NewMockPacketHandler(mockCtrl)
|
||||
handler.AddWithConnID(connID, protocol.ParseConnectionID([]byte{1, 2, 3, 4}), func() packetHandler { return conn })
|
||||
p1 := &receivedPacket{data: getPacketWithPacketType(connID, protocol.PacketType0RTT, 1)}
|
||||
conn.EXPECT().handlePacket(p1)
|
||||
handler.handlePacket(p1)
|
||||
})
|
||||
|
||||
It("limits the number of 0-RTT queues", func() {
|
||||
for i := 0; i < protocol.Max0RTTQueues; i++ {
|
||||
b := make([]byte, 8)
|
||||
rand.Read(b)
|
||||
p := &receivedPacket{data: getPacketWithPacketType(
|
||||
protocol.ParseConnectionID(b),
|
||||
protocol.PacketType0RTT,
|
||||
1,
|
||||
)}
|
||||
handler.handlePacket(p)
|
||||
}
|
||||
// We're already storing the maximum number of queues. This packet will be dropped.
|
||||
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9})
|
||||
tracer.EXPECT().DroppedPacket(gomock.Any(), logging.PacketType0RTT, gomock.Any(), logging.PacketDropDOSPrevention)
|
||||
handler.handlePacket(&receivedPacket{data: getPacketWithPacketType(connID, protocol.PacketType0RTT, 1)})
|
||||
// Don't EXPECT any handlePacket() calls.
|
||||
conn := NewMockPacketHandler(mockCtrl)
|
||||
handler.AddWithConnID(connID, protocol.ParseConnectionID([]byte{1, 2, 3, 4}), func() packetHandler { return conn })
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
})
|
||||
|
||||
It("deletes queues if no connection is created for this connection ID", func() {
|
||||
queueDuration := scaleDuration(10 * time.Millisecond)
|
||||
handler.zeroRTTQueueDuration = queueDuration
|
||||
|
||||
server := NewMockUnknownPacketHandler(mockCtrl)
|
||||
// don't EXPECT any calls to server.handlePacket
|
||||
handler.SetServer(server)
|
||||
connID := protocol.ParseConnectionID([]byte{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88})
|
||||
p1 := &receivedPacket{
|
||||
data: getPacketWithPacketType(connID, protocol.PacketType0RTT, 1),
|
||||
buffer: getPacketBuffer(),
|
||||
}
|
||||
p2 := &receivedPacket{
|
||||
data: getPacketWithPacketType(connID, protocol.PacketType0RTT, 2),
|
||||
buffer: getPacketBuffer(),
|
||||
}
|
||||
handler.handlePacket(p1)
|
||||
handler.handlePacket(p2)
|
||||
// wait a bit. The queue should now already be deleted.
|
||||
time.Sleep(queueDuration * 3)
|
||||
// Don't EXPECT any handlePacket() calls.
|
||||
conn := NewMockPacketHandler(mockCtrl)
|
||||
handler.AddWithConnID(connID, protocol.ParseConnectionID([]byte{1, 2, 3, 4}), func() packetHandler { return conn })
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
})
|
||||
})
|
||||
|
||||
Context("stateless resets", func() {
|
||||
BeforeEach(func() {
|
||||
connIDLen = 5
|
||||
|
|
116
server.go
116
server.go
|
@ -36,6 +36,7 @@ type unknownPacketHandler interface {
|
|||
}
|
||||
|
||||
type packetHandlerManager interface {
|
||||
Get(protocol.ConnectionID) (packetHandler, bool)
|
||||
AddWithConnID(protocol.ConnectionID, protocol.ConnectionID, func() packetHandler) bool
|
||||
Destroy() error
|
||||
connRunner
|
||||
|
@ -54,6 +55,11 @@ type quicConn interface {
|
|||
shutdown()
|
||||
}
|
||||
|
||||
type zeroRTTQueue struct {
|
||||
packets []*receivedPacket
|
||||
expiration time.Time
|
||||
}
|
||||
|
||||
// A Listener of QUIC
|
||||
type baseServer struct {
|
||||
mutex sync.Mutex
|
||||
|
@ -74,6 +80,9 @@ type baseServer struct {
|
|||
|
||||
receivedPackets chan *receivedPacket
|
||||
|
||||
nextZeroRTTCleanup time.Time
|
||||
zeroRTTQueues map[protocol.ConnectionID]*zeroRTTQueue // only initialized if acceptEarlyConns == true
|
||||
|
||||
// set as a member, so they can be set in the tests
|
||||
newConn func(
|
||||
sendConn,
|
||||
|
@ -214,6 +223,9 @@ func listen(conn net.PacketConn, tlsConf *tls.Config, config *Config, acceptEarl
|
|||
logger: utils.DefaultLogger.WithPrefix("server"),
|
||||
acceptEarlyConns: acceptEarly,
|
||||
}
|
||||
if acceptEarly {
|
||||
s.zeroRTTQueues = map[protocol.ConnectionID]*zeroRTTQueue{}
|
||||
}
|
||||
go s.run()
|
||||
connHandler.SetServer(s)
|
||||
s.logger.Debugf("Listening for %s connections on %s", conn.LocalAddr().Network(), conn.LocalAddr().String())
|
||||
|
@ -310,6 +322,10 @@ func (s *baseServer) handlePacket(p *receivedPacket) {
|
|||
}
|
||||
|
||||
func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* is the buffer still in use? */ {
|
||||
if !s.nextZeroRTTCleanup.IsZero() && p.rcvTime.After(s.nextZeroRTTCleanup) {
|
||||
defer s.cleanupZeroRTTQueues(p.rcvTime)
|
||||
}
|
||||
|
||||
if wire.IsVersionNegotiationPacket(p.data) {
|
||||
s.logger.Debugf("Dropping Version Negotiation packet.")
|
||||
if s.config.Tracer != nil {
|
||||
|
@ -344,6 +360,17 @@ func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* is the buffer s
|
|||
}
|
||||
return false
|
||||
}
|
||||
|
||||
if wire.Is0RTTPacket(p.data) {
|
||||
if !s.acceptEarlyConns {
|
||||
if s.config.Tracer != nil {
|
||||
s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropUnexpectedPacket)
|
||||
}
|
||||
return false
|
||||
}
|
||||
return s.handle0RTTPacket(p)
|
||||
}
|
||||
|
||||
// If we're creating a new connection, the packet will be passed to the connection.
|
||||
// The header will then be parsed again.
|
||||
hdr, _, _, err := wire.ParsePacket(p.data)
|
||||
|
@ -383,6 +410,74 @@ func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* is the buffer s
|
|||
return true
|
||||
}
|
||||
|
||||
func (s *baseServer) handle0RTTPacket(p *receivedPacket) bool {
|
||||
connID, err := wire.ParseConnectionID(p.data, 0)
|
||||
if err != nil {
|
||||
if s.config.Tracer != nil {
|
||||
s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropHeaderParseError)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// check again if we might have a connection now
|
||||
if handler, ok := s.connHandler.Get(connID); ok {
|
||||
handler.handlePacket(p)
|
||||
return true
|
||||
}
|
||||
|
||||
if q, ok := s.zeroRTTQueues[connID]; ok {
|
||||
if len(q.packets) >= protocol.Max0RTTQueueLen {
|
||||
if s.config.Tracer != nil {
|
||||
s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention)
|
||||
}
|
||||
return false
|
||||
}
|
||||
q.packets = append(q.packets, p)
|
||||
return true
|
||||
}
|
||||
|
||||
if len(s.zeroRTTQueues) >= protocol.Max0RTTQueues {
|
||||
if s.config.Tracer != nil {
|
||||
s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention)
|
||||
}
|
||||
return false
|
||||
}
|
||||
queue := &zeroRTTQueue{packets: make([]*receivedPacket, 1, 8)}
|
||||
queue.packets[0] = p
|
||||
expiration := p.rcvTime.Add(protocol.Max0RTTQueueingDuration)
|
||||
queue.expiration = expiration
|
||||
if s.nextZeroRTTCleanup.IsZero() || s.nextZeroRTTCleanup.After(expiration) {
|
||||
s.nextZeroRTTCleanup = expiration
|
||||
}
|
||||
s.zeroRTTQueues[connID] = queue
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *baseServer) cleanupZeroRTTQueues(now time.Time) {
|
||||
// Iterate over all queues to find those that are expired.
|
||||
// This is ok since we're placing a pretty low limit on the number of queues.
|
||||
var nextCleanup time.Time
|
||||
for connID, q := range s.zeroRTTQueues {
|
||||
if q.expiration.After(now) {
|
||||
if nextCleanup.IsZero() || nextCleanup.After(q.expiration) {
|
||||
nextCleanup = q.expiration
|
||||
}
|
||||
continue
|
||||
}
|
||||
for _, p := range q.packets {
|
||||
if s.config.Tracer != nil {
|
||||
s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention)
|
||||
}
|
||||
p.buffer.Release()
|
||||
}
|
||||
delete(s.zeroRTTQueues, connID)
|
||||
if s.logger.Debug() {
|
||||
s.logger.Debugf("Removing 0-RTT queue for %s.", connID)
|
||||
}
|
||||
}
|
||||
s.nextZeroRTTCleanup = nextCleanup
|
||||
}
|
||||
|
||||
// validateToken returns false if:
|
||||
// - address is invalid
|
||||
// - token is expired
|
||||
|
@ -412,6 +507,14 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro
|
|||
return errors.New("too short connection ID")
|
||||
}
|
||||
|
||||
// The server queues packets for a while, and we might already have established a connection by now.
|
||||
// This results in a second check in the connection map.
|
||||
// That's ok since it's not the hot path (it's only taken by some Initial and 0-RTT packets).
|
||||
if handler, ok := s.connHandler.Get(hdr.DestConnectionID); ok {
|
||||
handler.handlePacket(p)
|
||||
return nil
|
||||
}
|
||||
|
||||
var (
|
||||
token *handshake.Token
|
||||
retrySrcConnID *protocol.ConnectionID
|
||||
|
@ -429,7 +532,6 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro
|
|||
}
|
||||
|
||||
clientAddrIsValid := s.validateToken(token, p.remoteAddr)
|
||||
|
||||
if token != nil && !clientAddrIsValid {
|
||||
// For invalid and expired non-retry tokens, we don't send an INVALID_TOKEN error.
|
||||
// We just ignore them, and act as if there was no token on this packet at all.
|
||||
|
@ -450,6 +552,8 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro
|
|||
}
|
||||
}
|
||||
if token == nil && s.config.RequireAddressValidation(p.remoteAddr) {
|
||||
// Retry invalidates all 0-RTT packets sent.
|
||||
delete(s.zeroRTTQueues, hdr.DestConnectionID)
|
||||
go func() {
|
||||
defer p.buffer.Release()
|
||||
if err := s.sendRetry(p.remoteAddr, hdr, p.info); err != nil {
|
||||
|
@ -510,8 +614,18 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro
|
|||
hdr.Version,
|
||||
)
|
||||
conn.handlePacket(p)
|
||||
|
||||
if q, ok := s.zeroRTTQueues[hdr.DestConnectionID]; ok {
|
||||
for _, p := range q.packets {
|
||||
conn.handlePacket(p)
|
||||
}
|
||||
delete(s.zeroRTTQueues, hdr.DestConnectionID)
|
||||
}
|
||||
|
||||
return conn
|
||||
}); !added {
|
||||
// TODO: don't just drop the packet
|
||||
// Properly reject the connection attempt.
|
||||
return nil
|
||||
}
|
||||
go conn.run()
|
||||
|
|
281
server_test.go
281
server_test.go
|
@ -215,8 +215,7 @@ var _ = Describe("Server", func() {
|
|||
Type: protocol.PacketTypeInitial,
|
||||
DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}),
|
||||
Version: serv.config.Versions[0],
|
||||
}, make([]byte, protocol.MinInitialPacketSize-100),
|
||||
)
|
||||
}, make([]byte, protocol.MinInitialPacketSize-100))
|
||||
tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket)
|
||||
serv.handlePacket(p)
|
||||
// make sure there are no Write calls on the packet conn
|
||||
|
@ -234,6 +233,21 @@ var _ = Describe("Server", func() {
|
|||
time.Sleep(50 * time.Millisecond)
|
||||
})
|
||||
|
||||
It("passes packets to existing connections", func() {
|
||||
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
|
||||
p := getPacket(&wire.Header{
|
||||
Type: protocol.PacketTypeInitial,
|
||||
DestConnectionID: connID,
|
||||
Version: serv.config.Versions[0],
|
||||
}, make([]byte, protocol.MinInitialPacketSize))
|
||||
conn := NewMockPacketHandler(mockCtrl)
|
||||
phm.EXPECT().Get(connID).Return(conn, true)
|
||||
handled := make(chan struct{})
|
||||
conn.EXPECT().handlePacket(p).Do(func(*receivedPacket) { close(handled) })
|
||||
serv.handlePacket(p)
|
||||
Eventually(handled).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("creates a connection when the token is accepted", func() {
|
||||
serv.config.RequireAddressValidation = func(net.Addr) bool { return true }
|
||||
raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
|
||||
|
@ -243,10 +257,11 @@ var _ = Describe("Server", func() {
|
|||
protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})
|
||||
hdr := &wire.Header{
|
||||
Type: protocol.PacketTypeInitial,
|
||||
SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
|
||||
DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}),
|
||||
DestConnectionID: connID,
|
||||
Version: protocol.Version1,
|
||||
Token: retryToken,
|
||||
}
|
||||
|
@ -257,7 +272,9 @@ var _ = Describe("Server", func() {
|
|||
rand.Read(token[:])
|
||||
|
||||
var newConnID protocol.ConnectionID
|
||||
phm.EXPECT().AddWithConnID(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, fn func() packetHandler) bool {
|
||||
|
||||
phm.EXPECT().Get(connID)
|
||||
phm.EXPECT().AddWithConnID(connID, gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, fn func() packetHandler) bool {
|
||||
newConnID = c
|
||||
phm.EXPECT().GetStatelessResetToken(gomock.Any()).DoAndReturn(func(c protocol.ConnectionID) protocol.StatelessResetToken {
|
||||
newConnID = c
|
||||
|
@ -408,11 +425,12 @@ var _ = Describe("Server", func() {
|
|||
})
|
||||
|
||||
It("replies with a Retry packet, if a token is required", func() {
|
||||
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})
|
||||
serv.config.RequireAddressValidation = func(net.Addr) bool { return true }
|
||||
hdr := &wire.Header{
|
||||
Type: protocol.PacketTypeInitial,
|
||||
SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
|
||||
DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}),
|
||||
DestConnectionID: connID,
|
||||
Version: protocol.Version1,
|
||||
}
|
||||
packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
|
||||
|
@ -435,15 +453,17 @@ var _ = Describe("Server", func() {
|
|||
Expect(b[len(b)-16:]).To(Equal(handshake.GetRetryIntegrityTag(b[:len(b)-16], hdr.DestConnectionID, hdr.Version)[:]))
|
||||
return len(b), nil
|
||||
})
|
||||
phm.EXPECT().Get(connID)
|
||||
serv.handlePacket(packet)
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("creates a connection, if no token is required", func() {
|
||||
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})
|
||||
hdr := &wire.Header{
|
||||
Type: protocol.PacketTypeInitial,
|
||||
SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
|
||||
DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}),
|
||||
DestConnectionID: connID,
|
||||
Version: protocol.Version1,
|
||||
}
|
||||
p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
|
||||
|
@ -452,15 +472,18 @@ var _ = Describe("Server", func() {
|
|||
rand.Read(token[:])
|
||||
|
||||
var newConnID protocol.ConnectionID
|
||||
phm.EXPECT().AddWithConnID(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, fn func() packetHandler) bool {
|
||||
newConnID = c
|
||||
phm.EXPECT().GetStatelessResetToken(gomock.Any()).DoAndReturn(func(c protocol.ConnectionID) protocol.StatelessResetToken {
|
||||
gomock.InOrder(
|
||||
phm.EXPECT().Get(connID),
|
||||
phm.EXPECT().AddWithConnID(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, fn func() packetHandler) bool {
|
||||
newConnID = c
|
||||
return token
|
||||
})
|
||||
fn()
|
||||
return true
|
||||
})
|
||||
phm.EXPECT().GetStatelessResetToken(gomock.Any()).DoAndReturn(func(c protocol.ConnectionID) protocol.StatelessResetToken {
|
||||
newConnID = c
|
||||
return token
|
||||
})
|
||||
fn()
|
||||
return true
|
||||
}),
|
||||
)
|
||||
tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}))
|
||||
|
||||
conn := NewMockQUICConn(mockCtrl)
|
||||
|
@ -513,6 +536,7 @@ var _ = Describe("Server", func() {
|
|||
})
|
||||
|
||||
It("drops packets if the receive queue is full", func() {
|
||||
phm.EXPECT().Get(gomock.Any()).AnyTimes()
|
||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
|
||||
phm.EXPECT().GetStatelessResetToken(gomock.Any())
|
||||
fn()
|
||||
|
@ -597,8 +621,10 @@ var _ = Describe("Server", func() {
|
|||
return conn
|
||||
}
|
||||
|
||||
p := getInitial(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9}))
|
||||
phm.EXPECT().AddWithConnID(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9}), gomock.Any(), gomock.Any()).Return(false)
|
||||
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9})
|
||||
p := getInitial(connID)
|
||||
phm.EXPECT().Get(connID)
|
||||
phm.EXPECT().AddWithConnID(connID, gomock.Any(), gomock.Any()).Return(false)
|
||||
Expect(serv.handlePacketImpl(p)).To(BeTrue())
|
||||
Expect(createdConn).To(BeFalse())
|
||||
})
|
||||
|
@ -632,6 +658,7 @@ var _ = Describe("Server", func() {
|
|||
return conn
|
||||
}
|
||||
|
||||
phm.EXPECT().Get(gomock.Any()).Times(protocol.MaxAcceptQueueSize + 1)
|
||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
|
||||
phm.EXPECT().GetStatelessResetToken(gomock.Any())
|
||||
fn()
|
||||
|
@ -702,6 +729,7 @@ var _ = Describe("Server", func() {
|
|||
return conn
|
||||
}
|
||||
|
||||
phm.EXPECT().Get(gomock.Any())
|
||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
|
||||
phm.EXPECT().GetStatelessResetToken(gomock.Any())
|
||||
fn()
|
||||
|
@ -765,6 +793,7 @@ var _ = Describe("Server", func() {
|
|||
tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1)
|
||||
|
||||
done := make(chan struct{})
|
||||
phm.EXPECT().Get(gomock.Any())
|
||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_, _ protocol.ConnectionID, _ func() packetHandler) { close(done) })
|
||||
serv.handlePacket(packet)
|
||||
Eventually(done).Should(BeClosed())
|
||||
|
@ -801,6 +830,7 @@ var _ = Describe("Server", func() {
|
|||
checkInvalidToken(b, hdr)
|
||||
return len(b), nil
|
||||
})
|
||||
phm.EXPECT().Get(gomock.Any())
|
||||
serv.handlePacket(packet)
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
@ -837,6 +867,7 @@ var _ = Describe("Server", func() {
|
|||
checkInvalidToken(b, hdr)
|
||||
return len(b), nil
|
||||
})
|
||||
phm.EXPECT().Get(gomock.Any())
|
||||
serv.handlePacket(packet)
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
@ -864,6 +895,7 @@ var _ = Describe("Server", func() {
|
|||
Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry))
|
||||
return len(b), nil
|
||||
})
|
||||
phm.EXPECT().Get(gomock.Any())
|
||||
serv.handlePacket(packet)
|
||||
// make sure there are no Write calls on the packet conn
|
||||
Eventually(done).Should(BeClosed())
|
||||
|
@ -893,6 +925,7 @@ var _ = Describe("Server", func() {
|
|||
defer close(done)
|
||||
return len(b), nil
|
||||
})
|
||||
phm.EXPECT().Get(gomock.Any())
|
||||
serv.handlePacket(packet)
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
@ -913,6 +946,7 @@ var _ = Describe("Server", func() {
|
|||
packet.remoteAddr = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
|
||||
done := make(chan struct{})
|
||||
tracer.EXPECT().DroppedPacket(packet.remoteAddr, logging.PacketTypeInitial, packet.Size(), logging.PacketDropPayloadDecryptError).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { close(done) })
|
||||
phm.EXPECT().Get(gomock.Any())
|
||||
serv.handlePacket(packet)
|
||||
// make sure there are no Write calls on the packet conn
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
@ -997,6 +1031,7 @@ var _ = Describe("Server", func() {
|
|||
conn.EXPECT().Context().Return(context.Background())
|
||||
return conn
|
||||
}
|
||||
phm.EXPECT().Get(gomock.Any())
|
||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
|
||||
phm.EXPECT().GetStatelessResetToken(gomock.Any())
|
||||
fn()
|
||||
|
@ -1070,6 +1105,7 @@ var _ = Describe("Server", func() {
|
|||
conn.EXPECT().Context().Return(context.Background())
|
||||
return conn
|
||||
}
|
||||
phm.EXPECT().Get(gomock.Any())
|
||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
|
||||
phm.EXPECT().GetStatelessResetToken(gomock.Any())
|
||||
fn()
|
||||
|
@ -1115,6 +1151,7 @@ var _ = Describe("Server", func() {
|
|||
return conn
|
||||
}
|
||||
|
||||
phm.EXPECT().Get(gomock.Any()).AnyTimes()
|
||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
|
||||
phm.EXPECT().GetStatelessResetToken(gomock.Any())
|
||||
fn()
|
||||
|
@ -1175,6 +1212,7 @@ var _ = Describe("Server", func() {
|
|||
return conn
|
||||
}
|
||||
|
||||
phm.EXPECT().Get(gomock.Any())
|
||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
|
||||
phm.EXPECT().GetStatelessResetToken(gomock.Any())
|
||||
fn()
|
||||
|
@ -1202,4 +1240,215 @@ var _ = Describe("Server", func() {
|
|||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
})
|
||||
|
||||
Context("0-RTT", func() {
|
||||
var (
|
||||
serv *earlyServer
|
||||
phm *MockPacketHandlerManager
|
||||
tracer *mocklogging.MockTracer
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
tracer = mocklogging.NewMockTracer(mockCtrl)
|
||||
ln, err := ListenEarly(conn, tlsConf, &Config{Tracer: tracer})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
serv = ln.(*earlyServer)
|
||||
phm = NewMockPacketHandlerManager(mockCtrl)
|
||||
serv.connHandler = phm
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
phm.EXPECT().CloseServer().MaxTimes(1)
|
||||
serv.Close()
|
||||
})
|
||||
|
||||
It("passes packets to existing connections", func() {
|
||||
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
|
||||
p := getPacket(&wire.Header{
|
||||
Type: protocol.PacketType0RTT,
|
||||
DestConnectionID: connID,
|
||||
Version: serv.config.Versions[0],
|
||||
}, make([]byte, 100))
|
||||
conn := NewMockPacketHandler(mockCtrl)
|
||||
phm.EXPECT().Get(connID).Return(conn, true)
|
||||
handled := make(chan struct{})
|
||||
conn.EXPECT().handlePacket(p).Do(func(*receivedPacket) { close(handled) })
|
||||
serv.handlePacket(p)
|
||||
Eventually(handled).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("queues 0-RTT packets, up to Max0RTTQueueSize", func() {
|
||||
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
|
||||
|
||||
var zeroRTTPackets []*receivedPacket
|
||||
|
||||
for i := 0; i < protocol.Max0RTTQueueLen; i++ {
|
||||
p := getPacket(&wire.Header{
|
||||
Type: protocol.PacketType0RTT,
|
||||
DestConnectionID: connID,
|
||||
Version: serv.config.Versions[0],
|
||||
}, make([]byte, 100+i))
|
||||
phm.EXPECT().Get(connID)
|
||||
serv.handlePacket(p)
|
||||
zeroRTTPackets = append(zeroRTTPackets, p)
|
||||
}
|
||||
|
||||
// send one more packet, this one should be dropped
|
||||
p := getPacket(&wire.Header{
|
||||
Type: protocol.PacketType0RTT,
|
||||
DestConnectionID: connID,
|
||||
Version: serv.config.Versions[0],
|
||||
}, make([]byte, 200))
|
||||
phm.EXPECT().Get(connID)
|
||||
tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention)
|
||||
serv.handlePacket(p)
|
||||
|
||||
initial := getPacket(&wire.Header{
|
||||
Type: protocol.PacketTypeInitial,
|
||||
DestConnectionID: connID,
|
||||
Version: serv.config.Versions[0],
|
||||
}, make([]byte, protocol.MinInitialPacketSize))
|
||||
called := make(chan struct{})
|
||||
serv.newConn = func(
|
||||
_ sendConn,
|
||||
_ connRunner,
|
||||
_ protocol.ConnectionID,
|
||||
_ *protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.StatelessResetToken,
|
||||
_ *Config,
|
||||
_ *tls.Config,
|
||||
_ *handshake.TokenGenerator,
|
||||
_ bool,
|
||||
_ logging.ConnectionTracer,
|
||||
_ uint64,
|
||||
_ utils.Logger,
|
||||
_ protocol.VersionNumber,
|
||||
) quicConn {
|
||||
conn := NewMockQUICConn(mockCtrl)
|
||||
var calls []*gomock.Call
|
||||
calls = append(calls, conn.EXPECT().handlePacket(initial))
|
||||
for _, p := range zeroRTTPackets {
|
||||
calls = append(calls, conn.EXPECT().handlePacket(p))
|
||||
}
|
||||
gomock.InOrder(calls...)
|
||||
conn.EXPECT().run()
|
||||
conn.EXPECT().earlyConnReady()
|
||||
conn.EXPECT().Context().Return(context.Background())
|
||||
close(called)
|
||||
return conn
|
||||
}
|
||||
|
||||
tracer.EXPECT().TracerForConnection(gomock.Any(), gomock.Any(), gomock.Any())
|
||||
phm.EXPECT().Get(connID)
|
||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
|
||||
phm.EXPECT().GetStatelessResetToken(gomock.Any())
|
||||
fn()
|
||||
return true
|
||||
})
|
||||
serv.handlePacket(initial)
|
||||
Eventually(called).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("limits the number of queues", func() {
|
||||
for i := 0; i < protocol.Max0RTTQueues; i++ {
|
||||
b := make([]byte, 16)
|
||||
rand.Read(b)
|
||||
connID := protocol.ParseConnectionID(b)
|
||||
p := getPacket(&wire.Header{
|
||||
Type: protocol.PacketType0RTT,
|
||||
DestConnectionID: connID,
|
||||
Version: serv.config.Versions[0],
|
||||
}, make([]byte, 100+i))
|
||||
phm.EXPECT().Get(connID)
|
||||
serv.handlePacket(p)
|
||||
}
|
||||
|
||||
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
|
||||
p := getPacket(&wire.Header{
|
||||
Type: protocol.PacketType0RTT,
|
||||
DestConnectionID: connID,
|
||||
Version: serv.config.Versions[0],
|
||||
}, make([]byte, 200))
|
||||
phm.EXPECT().Get(connID)
|
||||
dropped := make(chan struct{})
|
||||
tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) {
|
||||
close(dropped)
|
||||
})
|
||||
serv.handlePacket(p)
|
||||
Eventually(dropped).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("drops queues after a while", func() {
|
||||
now := time.Now()
|
||||
|
||||
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
|
||||
p := getPacket(&wire.Header{
|
||||
Type: protocol.PacketType0RTT,
|
||||
DestConnectionID: connID,
|
||||
Version: serv.config.Versions[0],
|
||||
}, make([]byte, 200))
|
||||
p.rcvTime = now
|
||||
|
||||
connID2 := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 9})
|
||||
p2Time := now.Add(protocol.Max0RTTQueueingDuration / 2)
|
||||
p2 := getPacket(&wire.Header{
|
||||
Type: protocol.PacketType0RTT,
|
||||
DestConnectionID: connID2,
|
||||
Version: serv.config.Versions[0],
|
||||
}, make([]byte, 300))
|
||||
p2.rcvTime = p2Time // doesn't trigger the cleanup of the first packet
|
||||
|
||||
dropped1 := make(chan struct{})
|
||||
dropped2 := make(chan struct{})
|
||||
// need to register the call before handling the packet to avoid race condition
|
||||
gomock.InOrder(
|
||||
tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) {
|
||||
close(dropped1)
|
||||
}),
|
||||
tracer.EXPECT().DroppedPacket(p2.remoteAddr, logging.PacketType0RTT, p2.Size(), logging.PacketDropDOSPrevention).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) {
|
||||
close(dropped2)
|
||||
}),
|
||||
)
|
||||
|
||||
phm.EXPECT().Get(connID)
|
||||
serv.handlePacket(p)
|
||||
|
||||
// There's no cleanup Go routine.
|
||||
// Cleanup is triggered when new packets are received.
|
||||
|
||||
phm.EXPECT().Get(connID2)
|
||||
serv.handlePacket(p2)
|
||||
// make sure no cleanup is executed
|
||||
Consistently(dropped1, 50*time.Millisecond).ShouldNot(BeClosed())
|
||||
|
||||
// There's no cleanup Go routine.
|
||||
// Cleanup is triggered when new packets are received.
|
||||
connID3 := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 0})
|
||||
p3 := getPacket(&wire.Header{
|
||||
Type: protocol.PacketType0RTT,
|
||||
DestConnectionID: connID3,
|
||||
Version: serv.config.Versions[0],
|
||||
}, make([]byte, 200))
|
||||
p3.rcvTime = now.Add(protocol.Max0RTTQueueingDuration + time.Nanosecond) // now triggers the cleanup
|
||||
phm.EXPECT().Get(connID3)
|
||||
serv.handlePacket(p3)
|
||||
Eventually(dropped1).Should(BeClosed())
|
||||
Consistently(dropped2, 50*time.Millisecond).ShouldNot(BeClosed())
|
||||
|
||||
// make sure the second packet is also cleaned up
|
||||
connID4 := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 1})
|
||||
p4 := getPacket(&wire.Header{
|
||||
Type: protocol.PacketType0RTT,
|
||||
DestConnectionID: connID4,
|
||||
Version: serv.config.Versions[0],
|
||||
}, make([]byte, 200))
|
||||
p4.rcvTime = p2Time.Add(protocol.Max0RTTQueueingDuration + time.Nanosecond) // now triggers the cleanup
|
||||
phm.EXPECT().Get(connID4)
|
||||
serv.handlePacket(p4)
|
||||
Eventually(dropped2).Should(BeClosed())
|
||||
})
|
||||
})
|
||||
})
|
||||
|
|
|
@ -1,34 +0,0 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
type zeroRTTQueue struct {
|
||||
queue []*receivedPacket
|
||||
retireTimer *time.Timer
|
||||
}
|
||||
|
||||
var _ packetHandler = &zeroRTTQueue{}
|
||||
|
||||
func (h *zeroRTTQueue) handlePacket(p *receivedPacket) {
|
||||
if len(h.queue) < protocol.Max0RTTQueueLen {
|
||||
h.queue = append(h.queue, p)
|
||||
}
|
||||
}
|
||||
func (h *zeroRTTQueue) shutdown() {}
|
||||
func (h *zeroRTTQueue) destroy(error) {}
|
||||
func (h *zeroRTTQueue) getPerspective() protocol.Perspective { return protocol.PerspectiveClient }
|
||||
func (h *zeroRTTQueue) EnqueueAll(sess packetHandler) {
|
||||
for _, p := range h.queue {
|
||||
sess.handlePacket(p)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *zeroRTTQueue) Clear() {
|
||||
for _, p := range h.queue {
|
||||
p.buffer.Release()
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue