move 0-RTT queue handling from the packet handler map to the server

This commit is contained in:
Marten Seemann 2023-04-26 10:07:28 +02:00
parent c0b94ee4b0
commit 58487803d3
7 changed files with 367 additions and 217 deletions

View file

@ -7,7 +7,6 @@ import (
"io"
mrand "math/rand"
"net"
"sort"
"sync"
"sync/atomic"
"time"
@ -214,7 +213,7 @@ var _ = Describe("0-RTT", func() {
ln,
proxy.LocalPort(),
clientTLSConf,
&quic.Config{ConnectionIDLength: connIDLen},
getQuicConfig(&quic.Config{ConnectionIDLength: connIDLen}),
PRData,
)
@ -237,8 +236,7 @@ var _ = Describe("0-RTT", func() {
Expect(num0RTT).ToNot(BeZero())
zeroRTTPackets := get0RTTPackets(tracer.getRcvdLongHeaderPackets())
Expect(len(zeroRTTPackets)).To(BeNumerically(">", 10))
sort.Slice(zeroRTTPackets, func(i, j int) bool { return zeroRTTPackets[i] < zeroRTTPackets[j] })
Expect(zeroRTTPackets[0]).To(Equal(protocol.PacketNumber(0)))
Expect(zeroRTTPackets).To(ContainElement(protocol.PacketNumber(0)))
})
}

View file

@ -100,6 +100,21 @@ func (mr *MockPacketHandlerManagerMockRecorder) Destroy() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Destroy", reflect.TypeOf((*MockPacketHandlerManager)(nil).Destroy))
}
// Get mocks base method.
func (m *MockPacketHandlerManager) Get(arg0 protocol.ConnectionID) (packetHandler, bool) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Get", arg0)
ret0, _ := ret[0].(packetHandler)
ret1, _ := ret[1].(bool)
return ret0, ret1
}
// Get indicates an expected call of Get.
func (mr *MockPacketHandlerManagerMockRecorder) Get(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockPacketHandlerManager)(nil).Get), arg0)
}
// GetStatelessResetToken mocks base method.
func (m *MockPacketHandlerManager) GetStatelessResetToken(arg0 protocol.ConnectionID) protocol.StatelessResetToken {
m.ctrl.T.Helper()

View file

@ -48,16 +48,14 @@ type packetHandlerMap struct {
closeQueue chan closePacket
handlers map[protocol.ConnectionID]packetHandler
resetTokens map[protocol.StatelessResetToken] /* stateless reset token */ packetHandler
server unknownPacketHandler
numZeroRTTEntries int
handlers map[protocol.ConnectionID]packetHandler
resetTokens map[protocol.StatelessResetToken] /* stateless reset token */ packetHandler
server unknownPacketHandler
listening chan struct{} // is closed when listen returns
closed bool
deleteRetiredConnsAfter time.Duration
zeroRTTQueueDuration time.Duration
statelessResetEnabled bool
statelessResetMutex sync.Mutex
@ -130,7 +128,6 @@ func newPacketHandlerMap(
handlers: make(map[protocol.ConnectionID]packetHandler),
resetTokens: make(map[protocol.StatelessResetToken]packetHandler),
deleteRetiredConnsAfter: protocol.RetiredConnectionIDDeleteTimeout,
zeroRTTQueueDuration: protocol.Max0RTTQueueingDuration,
closeQueue: make(chan closePacket, 4),
statelessResetEnabled: statelessResetKey != nil,
tracer: tracer,
@ -174,6 +171,14 @@ func (h *packetHandlerMap) logUsage() {
}
}
func (h *packetHandlerMap) Get(id protocol.ConnectionID) (packetHandler, bool) {
h.mutex.Lock()
defer h.mutex.Unlock()
handler, ok := h.handlers[id]
return handler, ok
}
func (h *packetHandlerMap) Add(id protocol.ConnectionID, handler packetHandler) bool /* was added */ {
h.mutex.Lock()
defer h.mutex.Unlock()
@ -191,23 +196,11 @@ func (h *packetHandlerMap) AddWithConnID(clientDestConnID, newConnID protocol.Co
h.mutex.Lock()
defer h.mutex.Unlock()
var q *zeroRTTQueue
if handler, ok := h.handlers[clientDestConnID]; ok {
q, ok = handler.(*zeroRTTQueue)
if !ok {
h.logger.Debugf("Not adding connection ID %s for a new connection, as it already exists.", clientDestConnID)
return false
}
q.retireTimer.Stop()
h.numZeroRTTEntries--
if h.numZeroRTTEntries < 0 {
panic("number of 0-RTT queues < 0")
}
if _, ok := h.handlers[clientDestConnID]; ok {
h.logger.Debugf("Not adding connection ID %s for a new connection, as it already exists.", clientDestConnID)
return false
}
conn := fn()
if q != nil {
q.EnqueueAll(conn)
}
h.handlers[clientDestConnID] = conn
h.handlers[newConnID] = conn
h.logger.Debugf("Adding connection IDs %s and %s for a new connection.", clientDestConnID, newConnID)
@ -395,17 +388,9 @@ func (h *packetHandlerMap) handlePacket(p *receivedPacket) {
if isStatelessReset := h.maybeHandleStatelessReset(p.data); isStatelessReset {
return
}
if handler, ok := h.handlers[connID]; ok {
if ha, ok := handler.(*zeroRTTQueue); ok { // only enqueue 0-RTT packets in the 0-RTT queue
if wire.Is0RTTPacket(p.data) {
ha.handlePacket(p)
return
}
} else { // existing connection
handler.handlePacket(p)
return
}
handler.handlePacket(p)
return
}
if !wire.IsLongHeaderPacket(p.data[0]) {
go h.maybeSendStatelessReset(p, connID)
@ -415,38 +400,6 @@ func (h *packetHandlerMap) handlePacket(p *receivedPacket) {
h.logger.Debugf("received a packet with an unexpected connection ID %s", connID)
return
}
if wire.Is0RTTPacket(p.data) {
if h.numZeroRTTEntries >= protocol.Max0RTTQueues {
if h.tracer != nil {
h.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention)
}
return
}
h.numZeroRTTEntries++
queue := &zeroRTTQueue{queue: make([]*receivedPacket, 0, 8)}
h.handlers[connID] = queue
queue.retireTimer = time.AfterFunc(h.zeroRTTQueueDuration, func() {
h.mutex.Lock()
defer h.mutex.Unlock()
// The entry might have been replaced by an actual connection.
// Only delete it if it's still a 0-RTT queue.
if handler, ok := h.handlers[connID]; ok {
if q, ok := handler.(*zeroRTTQueue); ok {
delete(h.handlers, connID)
h.numZeroRTTEntries--
if h.numZeroRTTEntries < 0 {
panic("number of 0-RTT queues < 0")
}
q.Clear()
if h.logger.Debug() {
h.logger.Debugf("Removing 0-RTT queue for %s.", connID)
}
}
}
})
queue.handlePacket(p)
return
}
h.server.handlePacket(p)
}

View file

@ -271,93 +271,6 @@ var _ = Describe("Packet Handler Map", func() {
})
})
Context("0-RTT", func() {
JustBeforeEach(func() {
handler.zeroRTTQueueDuration = time.Hour
server := NewMockUnknownPacketHandler(mockCtrl)
// we don't expect any calls to server.handlePacket
handler.SetServer(server)
})
It("queues 0-RTT packets", func() {
server := NewMockUnknownPacketHandler(mockCtrl)
// don't EXPECT any calls to server.handlePacket
handler.SetServer(server)
connID := protocol.ParseConnectionID([]byte{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88})
p1 := &receivedPacket{data: getPacketWithPacketType(connID, protocol.PacketType0RTT, 1)}
p2 := &receivedPacket{data: getPacketWithPacketType(connID, protocol.PacketType0RTT, 2)}
p3 := &receivedPacket{data: getPacketWithPacketType(connID, protocol.PacketType0RTT, 3)}
handler.handlePacket(p1)
handler.handlePacket(p2)
handler.handlePacket(p3)
conn := NewMockPacketHandler(mockCtrl)
done := make(chan struct{})
gomock.InOrder(
conn.EXPECT().handlePacket(p1),
conn.EXPECT().handlePacket(p2),
conn.EXPECT().handlePacket(p3).Do(func(packet *receivedPacket) { close(done) }),
)
handler.AddWithConnID(connID, protocol.ParseConnectionID([]byte{1, 2, 3, 4}), func() packetHandler { return conn })
Eventually(done).Should(BeClosed())
})
It("directs 0-RTT packets to existing connections", func() {
connID := protocol.ParseConnectionID([]byte{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88})
conn := NewMockPacketHandler(mockCtrl)
handler.AddWithConnID(connID, protocol.ParseConnectionID([]byte{1, 2, 3, 4}), func() packetHandler { return conn })
p1 := &receivedPacket{data: getPacketWithPacketType(connID, protocol.PacketType0RTT, 1)}
conn.EXPECT().handlePacket(p1)
handler.handlePacket(p1)
})
It("limits the number of 0-RTT queues", func() {
for i := 0; i < protocol.Max0RTTQueues; i++ {
b := make([]byte, 8)
rand.Read(b)
p := &receivedPacket{data: getPacketWithPacketType(
protocol.ParseConnectionID(b),
protocol.PacketType0RTT,
1,
)}
handler.handlePacket(p)
}
// We're already storing the maximum number of queues. This packet will be dropped.
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9})
tracer.EXPECT().DroppedPacket(gomock.Any(), logging.PacketType0RTT, gomock.Any(), logging.PacketDropDOSPrevention)
handler.handlePacket(&receivedPacket{data: getPacketWithPacketType(connID, protocol.PacketType0RTT, 1)})
// Don't EXPECT any handlePacket() calls.
conn := NewMockPacketHandler(mockCtrl)
handler.AddWithConnID(connID, protocol.ParseConnectionID([]byte{1, 2, 3, 4}), func() packetHandler { return conn })
time.Sleep(20 * time.Millisecond)
})
It("deletes queues if no connection is created for this connection ID", func() {
queueDuration := scaleDuration(10 * time.Millisecond)
handler.zeroRTTQueueDuration = queueDuration
server := NewMockUnknownPacketHandler(mockCtrl)
// don't EXPECT any calls to server.handlePacket
handler.SetServer(server)
connID := protocol.ParseConnectionID([]byte{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88})
p1 := &receivedPacket{
data: getPacketWithPacketType(connID, protocol.PacketType0RTT, 1),
buffer: getPacketBuffer(),
}
p2 := &receivedPacket{
data: getPacketWithPacketType(connID, protocol.PacketType0RTT, 2),
buffer: getPacketBuffer(),
}
handler.handlePacket(p1)
handler.handlePacket(p2)
// wait a bit. The queue should now already be deleted.
time.Sleep(queueDuration * 3)
// Don't EXPECT any handlePacket() calls.
conn := NewMockPacketHandler(mockCtrl)
handler.AddWithConnID(connID, protocol.ParseConnectionID([]byte{1, 2, 3, 4}), func() packetHandler { return conn })
time.Sleep(20 * time.Millisecond)
})
})
Context("stateless resets", func() {
BeforeEach(func() {
connIDLen = 5

131
server.go
View file

@ -36,6 +36,7 @@ type unknownPacketHandler interface {
}
type packetHandlerManager interface {
Get(protocol.ConnectionID) (packetHandler, bool)
AddWithConnID(protocol.ConnectionID, protocol.ConnectionID, func() packetHandler) bool
Destroy() error
connRunner
@ -54,6 +55,11 @@ type quicConn interface {
shutdown()
}
type zeroRTTQueue struct {
packets []*receivedPacket
retireTimer *time.Timer
}
// A Listener of QUIC
type baseServer struct {
mutex sync.Mutex
@ -74,6 +80,9 @@ type baseServer struct {
receivedPackets chan *receivedPacket
zeroRTTQueues map[protocol.ConnectionID]*zeroRTTQueue // only initialized if acceptEarlyConns == true
zeroRTTQueueDuration time.Duration
// set as a member, so they can be set in the tests
newConn func(
sendConn,
@ -201,18 +210,22 @@ func listen(conn net.PacketConn, tlsConf *tls.Config, config *Config, acceptEarl
return nil, err
}
s := &baseServer{
conn: c,
tlsConf: tlsConf,
config: config,
tokenGenerator: tokenGenerator,
connHandler: connHandler,
connQueue: make(chan quicConn),
errorChan: make(chan struct{}),
running: make(chan struct{}),
receivedPackets: make(chan *receivedPacket, protocol.MaxServerUnprocessedPackets),
newConn: newConnection,
logger: utils.DefaultLogger.WithPrefix("server"),
acceptEarlyConns: acceptEarly,
conn: c,
tlsConf: tlsConf,
config: config,
zeroRTTQueueDuration: protocol.Max0RTTQueueingDuration,
tokenGenerator: tokenGenerator,
connHandler: connHandler,
connQueue: make(chan quicConn),
errorChan: make(chan struct{}),
running: make(chan struct{}),
receivedPackets: make(chan *receivedPacket, protocol.MaxServerUnprocessedPackets),
newConn: newConnection,
logger: utils.DefaultLogger.WithPrefix("server"),
acceptEarlyConns: acceptEarly,
}
if acceptEarly {
s.zeroRTTQueues = map[protocol.ConnectionID]*zeroRTTQueue{}
}
go s.run()
connHandler.SetServer(s)
@ -344,6 +357,17 @@ func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* is the buffer s
}
return false
}
if wire.Is0RTTPacket(p.data) {
if !s.acceptEarlyConns {
if s.config.Tracer != nil {
s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropUnexpectedPacket)
}
return false
}
return s.handle0RTTPacket(p)
}
// If we're creating a new connection, the packet will be passed to the connection.
// The header will then be parsed again.
hdr, _, _, err := wire.ParsePacket(p.data)
@ -383,6 +407,63 @@ func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* is the buffer s
return true
}
func (s *baseServer) handle0RTTPacket(p *receivedPacket) bool {
connID, err := wire.ParseConnectionID(p.data, 0)
if err != nil {
if s.config.Tracer != nil {
s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropHeaderParseError)
}
return false
}
// check again if we might have a connection now
if handler, ok := s.connHandler.Get(connID); ok {
handler.handlePacket(p)
return true
}
s.mutex.Lock()
defer s.mutex.Unlock()
if q, ok := s.zeroRTTQueues[connID]; ok {
if len(q.packets) >= protocol.Max0RTTQueueLen {
if s.config.Tracer != nil {
s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention)
}
return false
}
q.packets = append(q.packets, p)
return true
}
if len(s.zeroRTTQueues) >= protocol.Max0RTTQueues {
if s.config.Tracer != nil {
s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention)
}
return false
}
queue := &zeroRTTQueue{packets: make([]*receivedPacket, 1, 8)}
queue.packets[0] = p
queue.retireTimer = time.AfterFunc(s.zeroRTTQueueDuration, func() {
s.mutex.Lock()
defer s.mutex.Unlock()
// The entry might have been replaced by an actual connection.
// Only delete it if it's still a 0-RTT queue.
if q, ok := s.zeroRTTQueues[connID]; ok {
for _, p := range q.packets {
// TODO: trace
p.buffer.Release()
}
delete(s.zeroRTTQueues, connID)
if s.logger.Debug() {
s.logger.Debugf("Removing 0-RTT queue for %s.", connID)
}
}
})
s.zeroRTTQueues[connID] = queue
return true
}
// validateToken returns false if:
// - address is invalid
// - token is expired
@ -412,6 +493,14 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro
return errors.New("too short connection ID")
}
// The server queues packets for a while, and we might already have established a connection by now.
// This results in a second check in the connection map.
// That's ok since it's not the hot path (it's only taken by some Initial and 0-RTT packets).
if handler, ok := s.connHandler.Get(hdr.DestConnectionID); ok {
handler.handlePacket(p)
return nil
}
var (
token *handshake.Token
retrySrcConnID *protocol.ConnectionID
@ -429,7 +518,6 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro
}
clientAddrIsValid := s.validateToken(token, p.remoteAddr)
if token != nil && !clientAddrIsValid {
// For invalid and expired non-retry tokens, we don't send an INVALID_TOKEN error.
// We just ignore them, and act as if there was no token on this packet at all.
@ -450,6 +538,10 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro
}
}
if token == nil && s.config.RequireAddressValidation(p.remoteAddr) {
s.mutex.Lock()
delete(s.zeroRTTQueues, hdr.DestConnectionID)
s.mutex.Unlock()
go func() {
defer p.buffer.Release()
if err := s.sendRetry(p.remoteAddr, hdr, p.info); err != nil {
@ -510,8 +602,21 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro
hdr.Version,
)
conn.handlePacket(p)
s.mutex.Lock()
if q, ok := s.zeroRTTQueues[hdr.DestConnectionID]; ok {
q.retireTimer.Stop()
for _, p := range q.packets {
conn.handlePacket(p)
}
delete(s.zeroRTTQueues, hdr.DestConnectionID)
}
s.mutex.Unlock()
return conn
}); !added {
// TODO: don't just drop the packet
// Properly reject the connection attempt.
return nil
}
go conn.run()

View file

@ -215,8 +215,7 @@ var _ = Describe("Server", func() {
Type: protocol.PacketTypeInitial,
DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}),
Version: serv.config.Versions[0],
}, make([]byte, protocol.MinInitialPacketSize-100),
)
}, make([]byte, protocol.MinInitialPacketSize-100))
tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket)
serv.handlePacket(p)
// make sure there are no Write calls on the packet conn
@ -234,6 +233,21 @@ var _ = Describe("Server", func() {
time.Sleep(50 * time.Millisecond)
})
It("passes packets to existing connections", func() {
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
p := getPacket(&wire.Header{
Type: protocol.PacketTypeInitial,
DestConnectionID: connID,
Version: serv.config.Versions[0],
}, make([]byte, protocol.MinInitialPacketSize))
conn := NewMockPacketHandler(mockCtrl)
phm.EXPECT().Get(connID).Return(conn, true)
handled := make(chan struct{})
conn.EXPECT().handlePacket(p).Do(func(*receivedPacket) { close(handled) })
serv.handlePacket(p)
Eventually(handled).Should(BeClosed())
})
It("creates a connection when the token is accepted", func() {
serv.config.RequireAddressValidation = func(net.Addr) bool { return true }
raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
@ -243,10 +257,11 @@ var _ = Describe("Server", func() {
protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}),
)
Expect(err).ToNot(HaveOccurred())
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})
hdr := &wire.Header{
Type: protocol.PacketTypeInitial,
SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}),
DestConnectionID: connID,
Version: protocol.Version1,
Token: retryToken,
}
@ -257,7 +272,9 @@ var _ = Describe("Server", func() {
rand.Read(token[:])
var newConnID protocol.ConnectionID
phm.EXPECT().AddWithConnID(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, fn func() packetHandler) bool {
phm.EXPECT().Get(connID)
phm.EXPECT().AddWithConnID(connID, gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, fn func() packetHandler) bool {
newConnID = c
phm.EXPECT().GetStatelessResetToken(gomock.Any()).DoAndReturn(func(c protocol.ConnectionID) protocol.StatelessResetToken {
newConnID = c
@ -408,11 +425,12 @@ var _ = Describe("Server", func() {
})
It("replies with a Retry packet, if a token is required", func() {
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})
serv.config.RequireAddressValidation = func(net.Addr) bool { return true }
hdr := &wire.Header{
Type: protocol.PacketTypeInitial,
SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}),
DestConnectionID: connID,
Version: protocol.Version1,
}
packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
@ -435,15 +453,17 @@ var _ = Describe("Server", func() {
Expect(b[len(b)-16:]).To(Equal(handshake.GetRetryIntegrityTag(b[:len(b)-16], hdr.DestConnectionID, hdr.Version)[:]))
return len(b), nil
})
phm.EXPECT().Get(connID)
serv.handlePacket(packet)
Eventually(done).Should(BeClosed())
})
It("creates a connection, if no token is required", func() {
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})
hdr := &wire.Header{
Type: protocol.PacketTypeInitial,
SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}),
DestConnectionID: connID,
Version: protocol.Version1,
}
p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
@ -452,15 +472,18 @@ var _ = Describe("Server", func() {
rand.Read(token[:])
var newConnID protocol.ConnectionID
phm.EXPECT().AddWithConnID(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, fn func() packetHandler) bool {
newConnID = c
phm.EXPECT().GetStatelessResetToken(gomock.Any()).DoAndReturn(func(c protocol.ConnectionID) protocol.StatelessResetToken {
gomock.InOrder(
phm.EXPECT().Get(connID),
phm.EXPECT().AddWithConnID(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, fn func() packetHandler) bool {
newConnID = c
return token
})
fn()
return true
})
phm.EXPECT().GetStatelessResetToken(gomock.Any()).DoAndReturn(func(c protocol.ConnectionID) protocol.StatelessResetToken {
newConnID = c
return token
})
fn()
return true
}),
)
tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}))
conn := NewMockQUICConn(mockCtrl)
@ -513,6 +536,7 @@ var _ = Describe("Server", func() {
})
It("drops packets if the receive queue is full", func() {
phm.EXPECT().Get(gomock.Any()).AnyTimes()
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
phm.EXPECT().GetStatelessResetToken(gomock.Any())
fn()
@ -597,8 +621,10 @@ var _ = Describe("Server", func() {
return conn
}
p := getInitial(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9}))
phm.EXPECT().AddWithConnID(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9}), gomock.Any(), gomock.Any()).Return(false)
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9})
p := getInitial(connID)
phm.EXPECT().Get(connID)
phm.EXPECT().AddWithConnID(connID, gomock.Any(), gomock.Any()).Return(false)
Expect(serv.handlePacketImpl(p)).To(BeTrue())
Expect(createdConn).To(BeFalse())
})
@ -632,6 +658,7 @@ var _ = Describe("Server", func() {
return conn
}
phm.EXPECT().Get(gomock.Any()).Times(protocol.MaxAcceptQueueSize + 1)
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
phm.EXPECT().GetStatelessResetToken(gomock.Any())
fn()
@ -702,6 +729,7 @@ var _ = Describe("Server", func() {
return conn
}
phm.EXPECT().Get(gomock.Any())
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
phm.EXPECT().GetStatelessResetToken(gomock.Any())
fn()
@ -765,6 +793,7 @@ var _ = Describe("Server", func() {
tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1)
done := make(chan struct{})
phm.EXPECT().Get(gomock.Any())
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_, _ protocol.ConnectionID, _ func() packetHandler) { close(done) })
serv.handlePacket(packet)
Eventually(done).Should(BeClosed())
@ -801,6 +830,7 @@ var _ = Describe("Server", func() {
checkInvalidToken(b, hdr)
return len(b), nil
})
phm.EXPECT().Get(gomock.Any())
serv.handlePacket(packet)
Eventually(done).Should(BeClosed())
})
@ -837,6 +867,7 @@ var _ = Describe("Server", func() {
checkInvalidToken(b, hdr)
return len(b), nil
})
phm.EXPECT().Get(gomock.Any())
serv.handlePacket(packet)
Eventually(done).Should(BeClosed())
})
@ -864,6 +895,7 @@ var _ = Describe("Server", func() {
Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry))
return len(b), nil
})
phm.EXPECT().Get(gomock.Any())
serv.handlePacket(packet)
// make sure there are no Write calls on the packet conn
Eventually(done).Should(BeClosed())
@ -893,6 +925,7 @@ var _ = Describe("Server", func() {
defer close(done)
return len(b), nil
})
phm.EXPECT().Get(gomock.Any())
serv.handlePacket(packet)
Eventually(done).Should(BeClosed())
})
@ -913,6 +946,7 @@ var _ = Describe("Server", func() {
packet.remoteAddr = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
done := make(chan struct{})
tracer.EXPECT().DroppedPacket(packet.remoteAddr, logging.PacketTypeInitial, packet.Size(), logging.PacketDropPayloadDecryptError).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { close(done) })
phm.EXPECT().Get(gomock.Any())
serv.handlePacket(packet)
// make sure there are no Write calls on the packet conn
time.Sleep(50 * time.Millisecond)
@ -997,6 +1031,7 @@ var _ = Describe("Server", func() {
conn.EXPECT().Context().Return(context.Background())
return conn
}
phm.EXPECT().Get(gomock.Any())
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
phm.EXPECT().GetStatelessResetToken(gomock.Any())
fn()
@ -1070,6 +1105,7 @@ var _ = Describe("Server", func() {
conn.EXPECT().Context().Return(context.Background())
return conn
}
phm.EXPECT().Get(gomock.Any())
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
phm.EXPECT().GetStatelessResetToken(gomock.Any())
fn()
@ -1115,6 +1151,7 @@ var _ = Describe("Server", func() {
return conn
}
phm.EXPECT().Get(gomock.Any()).AnyTimes()
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
phm.EXPECT().GetStatelessResetToken(gomock.Any())
fn()
@ -1175,6 +1212,7 @@ var _ = Describe("Server", func() {
return conn
}
phm.EXPECT().Get(gomock.Any())
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
phm.EXPECT().GetStatelessResetToken(gomock.Any())
fn()
@ -1202,4 +1240,166 @@ var _ = Describe("Server", func() {
Eventually(done).Should(BeClosed())
})
})
Context("0-RTT", func() {
var (
serv *earlyServer
phm *MockPacketHandlerManager
tracer *mocklogging.MockTracer
)
BeforeEach(func() {
tracer = mocklogging.NewMockTracer(mockCtrl)
ln, err := ListenEarly(conn, tlsConf, &Config{Tracer: tracer})
Expect(err).ToNot(HaveOccurred())
serv = ln.(*earlyServer)
phm = NewMockPacketHandlerManager(mockCtrl)
serv.connHandler = phm
serv.zeroRTTQueueDuration = time.Hour
})
AfterEach(func() {
phm.EXPECT().CloseServer().MaxTimes(1)
serv.Close()
})
It("passes packets to existing connections", func() {
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
p := getPacket(&wire.Header{
Type: protocol.PacketType0RTT,
DestConnectionID: connID,
Version: serv.config.Versions[0],
}, make([]byte, 100))
conn := NewMockPacketHandler(mockCtrl)
phm.EXPECT().Get(connID).Return(conn, true)
handled := make(chan struct{})
conn.EXPECT().handlePacket(p).Do(func(*receivedPacket) { close(handled) })
serv.handlePacket(p)
Eventually(handled).Should(BeClosed())
})
It("queues 0-RTT packets, up to Max0RTTQueueSize", func() {
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
var zeroRTTPackets []*receivedPacket
for i := 0; i < protocol.Max0RTTQueueLen; i++ {
p := getPacket(&wire.Header{
Type: protocol.PacketType0RTT,
DestConnectionID: connID,
Version: serv.config.Versions[0],
}, make([]byte, 100+i))
phm.EXPECT().Get(connID)
serv.handlePacket(p)
zeroRTTPackets = append(zeroRTTPackets, p)
}
// send one more packet, this one should be dropped
p := getPacket(&wire.Header{
Type: protocol.PacketType0RTT,
DestConnectionID: connID,
Version: serv.config.Versions[0],
}, make([]byte, 200))
phm.EXPECT().Get(connID)
tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention)
serv.handlePacket(p)
initial := getPacket(&wire.Header{
Type: protocol.PacketTypeInitial,
DestConnectionID: connID,
Version: serv.config.Versions[0],
}, make([]byte, protocol.MinInitialPacketSize))
called := make(chan struct{})
serv.newConn = func(
_ sendConn,
_ connRunner,
_ protocol.ConnectionID,
_ *protocol.ConnectionID,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ protocol.StatelessResetToken,
_ *Config,
_ *tls.Config,
_ *handshake.TokenGenerator,
_ bool,
_ logging.ConnectionTracer,
_ uint64,
_ utils.Logger,
_ protocol.VersionNumber,
) quicConn {
conn := NewMockQUICConn(mockCtrl)
var calls []*gomock.Call
calls = append(calls, conn.EXPECT().handlePacket(initial))
for _, p := range zeroRTTPackets {
calls = append(calls, conn.EXPECT().handlePacket(p))
}
gomock.InOrder(calls...)
conn.EXPECT().run()
conn.EXPECT().earlyConnReady()
conn.EXPECT().Context().Return(context.Background())
close(called)
return conn
}
tracer.EXPECT().TracerForConnection(gomock.Any(), gomock.Any(), gomock.Any())
phm.EXPECT().Get(connID)
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
phm.EXPECT().GetStatelessResetToken(gomock.Any())
fn()
return true
})
serv.handlePacket(initial)
Eventually(called).Should(BeClosed())
})
It("limits the number of queues", func() {
for i := 0; i < protocol.Max0RTTQueues; i++ {
b := make([]byte, 16)
rand.Read(b)
connID := protocol.ParseConnectionID(b)
p := getPacket(&wire.Header{
Type: protocol.PacketType0RTT,
DestConnectionID: connID,
Version: serv.config.Versions[0],
}, make([]byte, 100+i))
phm.EXPECT().Get(connID)
serv.handlePacket(p)
}
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
p := getPacket(&wire.Header{
Type: protocol.PacketType0RTT,
DestConnectionID: connID,
Version: serv.config.Versions[0],
}, make([]byte, 200))
phm.EXPECT().Get(connID)
dropped := make(chan struct{})
tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) {
close(dropped)
})
serv.handlePacket(p)
Eventually(dropped).Should(BeClosed())
})
It("drops queues after a while", func() {
serv.zeroRTTQueueDuration = time.Millisecond
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
p := getPacket(&wire.Header{
Type: protocol.PacketType0RTT,
DestConnectionID: connID,
Version: serv.config.Versions[0],
}, make([]byte, 200))
got := make(chan struct{})
phm.EXPECT().Get(connID).Do(func(protocol.ConnectionID) { close(got) })
serv.handlePacket(p)
<-got
// TODO: use the tracer call
Eventually(func() int {
serv.mutex.Lock()
defer serv.mutex.Unlock()
return len(serv.zeroRTTQueues)
}).Should(BeZero())
})
})
})

View file

@ -1,34 +0,0 @@
package quic
import (
"time"
"github.com/quic-go/quic-go/internal/protocol"
)
type zeroRTTQueue struct {
queue []*receivedPacket
retireTimer *time.Timer
}
var _ packetHandler = &zeroRTTQueue{}
func (h *zeroRTTQueue) handlePacket(p *receivedPacket) {
if len(h.queue) < protocol.Max0RTTQueueLen {
h.queue = append(h.queue, p)
}
}
func (h *zeroRTTQueue) shutdown() {}
func (h *zeroRTTQueue) destroy(error) {}
func (h *zeroRTTQueue) getPerspective() protocol.Perspective { return protocol.PerspectiveClient }
func (h *zeroRTTQueue) EnqueueAll(sess packetHandler) {
for _, p := range h.queue {
sess.handlePacket(p)
}
}
func (h *zeroRTTQueue) Clear() {
for _, p := range h.queue {
p.buffer.Release()
}
}