avoid lock contention when accepting new connections (#4313)

* avoid lock contention when accepting new connections

The server used to hold the packet handler map's lock while creating the
connection struct for a newly accepted connection. This was intended to
make sure that no two connections with the same Destination Connection
ID could be created.

This is a corner case: it can only happen if two Initial packets with
the same Destination Connection ID are received at the same time. If
the second one is received after the first one has already been
processed, it would be routed to the first connection. We don't need to
optimized for this corner case. It's ok to create a new connection in
that case, and immediately close it if this collision is detected.

* only pass 0-RTT to the connection if it was actually accepted
This commit is contained in:
Marten Seemann 2024-02-09 10:34:42 +07:00 committed by GitHub
parent 013949cda3
commit 8e93770dd3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 113 additions and 163 deletions

View file

@ -113,7 +113,7 @@ func (c *PacketHandlerManagerAddResetTokenCall) DoAndReturn(f func(protocol.Stat
} }
// AddWithConnID mocks base method. // AddWithConnID mocks base method.
func (m *MockPacketHandlerManager) AddWithConnID(arg0, arg1 protocol.ConnectionID, arg2 func() (packetHandler, bool)) bool { func (m *MockPacketHandlerManager) AddWithConnID(arg0, arg1 protocol.ConnectionID, arg2 packetHandler) bool {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AddWithConnID", arg0, arg1, arg2) ret := m.ctrl.Call(m, "AddWithConnID", arg0, arg1, arg2)
ret0, _ := ret[0].(bool) ret0, _ := ret[0].(bool)
@ -139,13 +139,13 @@ func (c *PacketHandlerManagerAddWithConnIDCall) Return(arg0 bool) *PacketHandler
} }
// Do rewrite *gomock.Call.Do // Do rewrite *gomock.Call.Do
func (c *PacketHandlerManagerAddWithConnIDCall) Do(f func(protocol.ConnectionID, protocol.ConnectionID, func() (packetHandler, bool)) bool) *PacketHandlerManagerAddWithConnIDCall { func (c *PacketHandlerManagerAddWithConnIDCall) Do(f func(protocol.ConnectionID, protocol.ConnectionID, packetHandler) bool) *PacketHandlerManagerAddWithConnIDCall {
c.Call = c.Call.Do(f) c.Call = c.Call.Do(f)
return c return c
} }
// DoAndReturn rewrite *gomock.Call.DoAndReturn // DoAndReturn rewrite *gomock.Call.DoAndReturn
func (c *PacketHandlerManagerAddWithConnIDCall) DoAndReturn(f func(protocol.ConnectionID, protocol.ConnectionID, func() (packetHandler, bool)) bool) *PacketHandlerManagerAddWithConnIDCall { func (c *PacketHandlerManagerAddWithConnIDCall) DoAndReturn(f func(protocol.ConnectionID, protocol.ConnectionID, packetHandler) bool) *PacketHandlerManagerAddWithConnIDCall {
c.Call = c.Call.DoAndReturn(f) c.Call = c.Call.DoAndReturn(f)
return c return c
} }

View file

@ -129,7 +129,7 @@ func (h *packetHandlerMap) Add(id protocol.ConnectionID, handler packetHandler)
return true return true
} }
func (h *packetHandlerMap) AddWithConnID(clientDestConnID, newConnID protocol.ConnectionID, fn func() (packetHandler, bool)) bool { func (h *packetHandlerMap) AddWithConnID(clientDestConnID, newConnID protocol.ConnectionID, handler packetHandler) bool {
h.mutex.Lock() h.mutex.Lock()
defer h.mutex.Unlock() defer h.mutex.Unlock()
@ -137,12 +137,8 @@ func (h *packetHandlerMap) AddWithConnID(clientDestConnID, newConnID protocol.Co
h.logger.Debugf("Not adding connection ID %s for a new connection, as it already exists.", clientDestConnID) h.logger.Debugf("Not adding connection ID %s for a new connection, as it already exists.", clientDestConnID)
return false return false
} }
conn, ok := fn() h.handlers[clientDestConnID] = handler
if !ok { h.handlers[newConnID] = handler
return false
}
h.handlers[clientDestConnID] = conn
h.handlers[newConnID] = conn
h.logger.Debugf("Adding connection IDs %s and %s for a new connection.", clientDestConnID, newConnID) h.logger.Debugf("Adding connection IDs %s and %s for a new connection.", clientDestConnID, newConnID)
return true return true
} }

View file

@ -59,18 +59,12 @@ var _ = Describe("Packet Handler Map", func() {
It("adds newly to-be-constructed handlers", func() { It("adds newly to-be-constructed handlers", func() {
m := newPacketHandlerMap(nil, nil, utils.DefaultLogger) m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
var called bool
connID1 := protocol.ParseConnectionID([]byte{1, 2, 3, 4}) connID1 := protocol.ParseConnectionID([]byte{1, 2, 3, 4})
connID2 := protocol.ParseConnectionID([]byte{4, 3, 2, 1}) connID2 := protocol.ParseConnectionID([]byte{4, 3, 2, 1})
Expect(m.AddWithConnID(connID1, connID2, func() (packetHandler, bool) { h := NewMockPacketHandler(mockCtrl)
called = true Expect(m.AddWithConnID(connID1, connID2, h)).To(BeTrue())
return NewMockPacketHandler(mockCtrl), true // collision of the destination connection ID, this handler should not be added
})).To(BeTrue()) Expect(m.AddWithConnID(connID1, protocol.ParseConnectionID([]byte{1, 2, 3}), nil)).To(BeFalse())
Expect(called).To(BeTrue())
Expect(m.AddWithConnID(connID1, protocol.ParseConnectionID([]byte{1, 2, 3}), func() (packetHandler, bool) {
Fail("didn't expect the constructor to be executed")
return nil, false
})).To(BeFalse())
}) })
It("adds, gets and removes reset tokens", func() { It("adds, gets and removes reset tokens", func() {

111
server.go
View file

@ -32,7 +32,7 @@ type packetHandler interface {
type packetHandlerManager interface { type packetHandlerManager interface {
Get(protocol.ConnectionID) (packetHandler, bool) Get(protocol.ConnectionID) (packetHandler, bool)
GetByResetToken(protocol.StatelessResetToken) (packetHandler, bool) GetByResetToken(protocol.StatelessResetToken) (packetHandler, bool)
AddWithConnID(protocol.ConnectionID, protocol.ConnectionID, func() (packetHandler, bool)) bool AddWithConnID(destConnID, newConnID protocol.ConnectionID, h packetHandler) bool
Close(error) Close(error)
connRunner connRunner
} }
@ -636,63 +636,68 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
s.logger.Debugf("Changing connection ID to %s.", connID) s.logger.Debugf("Changing connection ID to %s.", connID)
var conn quicConn var conn quicConn
tracingID := nextConnTracingID() tracingID := nextConnTracingID()
if added := s.connHandler.AddWithConnID(hdr.DestConnectionID, connID, func() (packetHandler, bool) { config := s.config
config := s.config if s.config.GetConfigForClient != nil {
if s.config.GetConfigForClient != nil { conf, err := s.config.GetConfigForClient(&ClientHelloInfo{RemoteAddr: p.remoteAddr})
conf, err := s.config.GetConfigForClient(&ClientHelloInfo{RemoteAddr: p.remoteAddr}) if err != nil {
if err != nil { s.logger.Debugf("Rejecting new connection due to GetConfigForClient callback")
s.logger.Debugf("Rejecting new connection due to GetConfigForClient callback")
return nil, false
}
config = populateConfig(conf)
}
var tracer *logging.ConnectionTracer
if config.Tracer != nil {
// Use the same connection ID that is passed to the client's GetLogWriter callback.
connID := hdr.DestConnectionID
if origDestConnID.Len() > 0 {
connID = origDestConnID
}
tracer = config.Tracer(context.WithValue(context.Background(), ConnectionTracingKey, tracingID), protocol.PerspectiveServer, connID)
}
conn = s.newConn(
newSendConn(s.conn, p.remoteAddr, p.info, s.logger),
s.connHandler,
origDestConnID,
retrySrcConnID,
hdr.DestConnectionID,
hdr.SrcConnectionID,
connID,
s.connIDGenerator,
s.connHandler.GetStatelessResetToken(connID),
config,
s.tlsConf,
s.tokenGenerator,
clientAddrValidated,
tracer,
tracingID,
s.logger,
hdr.Version,
)
conn.handlePacket(p)
if q, ok := s.zeroRTTQueues[hdr.DestConnectionID]; ok {
for _, p := range q.packets {
conn.handlePacket(p)
}
delete(s.zeroRTTQueues, hdr.DestConnectionID) delete(s.zeroRTTQueues, hdr.DestConnectionID)
select {
case s.connectionRefusedQueue <- rejectedPacket{receivedPacket: p, hdr: hdr}:
default:
// drop packet if we can't send out the CONNECTION_REFUSED fast enough
p.buffer.Release()
}
return nil
} }
config = populateConfig(conf)
return conn, true }
}); !added { var tracer *logging.ConnectionTracer
select { if config.Tracer != nil {
case s.connectionRefusedQueue <- rejectedPacket{receivedPacket: p, hdr: hdr}: // Use the same connection ID that is passed to the client's GetLogWriter callback.
default: connID := hdr.DestConnectionID
// drop packet if we can't send out the CONNECTION_REFUSED fast enough if origDestConnID.Len() > 0 {
p.buffer.Release() connID = origDestConnID
} }
tracer = config.Tracer(context.WithValue(context.Background(), ConnectionTracingKey, tracingID), protocol.PerspectiveServer, connID)
}
conn = s.newConn(
newSendConn(s.conn, p.remoteAddr, p.info, s.logger),
s.connHandler,
origDestConnID,
retrySrcConnID,
hdr.DestConnectionID,
hdr.SrcConnectionID,
connID,
s.connIDGenerator,
s.connHandler.GetStatelessResetToken(connID),
config,
s.tlsConf,
s.tokenGenerator,
clientAddrValidated,
tracer,
tracingID,
s.logger,
hdr.Version,
)
conn.handlePacket(p)
// Adding the connection will fail if the client's chosen Destination Connection ID is already in use.
// This is very unlikely: Even if an attacker chooses a connection ID that's already in use,
// under normal circumstances the packet would just be routed to that connection.
// The only time this collision will occur if we receive the two Initial packets at the same time.
if added := s.connHandler.AddWithConnID(hdr.DestConnectionID, connID, conn); !added {
delete(s.zeroRTTQueues, hdr.DestConnectionID)
conn.closeWithTransportError(qerr.ConnectionRefused)
return nil return nil
} }
// Pass queued 0-RTT to the newly established connection.
if q, ok := s.zeroRTTQueues[hdr.DestConnectionID]; ok {
for _, p := range q.packets {
conn.handlePacket(p)
}
delete(s.zeroRTTQueues, hdr.DestConnectionID)
}
if clientAddrValidated { if clientAddrValidated {
s.numHandshakesValidated.Add(1) s.numHandshakesValidated.Add(1)
} else { } else {

View file

@ -282,17 +282,6 @@ var _ = Describe("Server", func() {
rand.Read(token[:]) rand.Read(token[:])
var newConnID protocol.ConnectionID var newConnID protocol.ConnectionID
phm.EXPECT().Get(connID)
phm.EXPECT().AddWithConnID(connID, gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
newConnID = c
phm.EXPECT().GetStatelessResetToken(gomock.Any()).DoAndReturn(func(c protocol.ConnectionID) protocol.StatelessResetToken {
newConnID = c
return token
})
_, ok := fn()
return ok
})
conn := NewMockQUICConn(mockCtrl) conn := NewMockQUICConn(mockCtrl)
serv.newConn = func( serv.newConn = func(
_ sendConn, _ sendConn,
@ -320,7 +309,7 @@ var _ = Describe("Server", func() {
// make sure we're using a server-generated connection ID // make sure we're using a server-generated connection ID
Expect(srcConnID).ToNot(Equal(hdr.DestConnectionID)) Expect(srcConnID).ToNot(Equal(hdr.DestConnectionID))
Expect(srcConnID).ToNot(Equal(hdr.SrcConnectionID)) Expect(srcConnID).ToNot(Equal(hdr.SrcConnectionID))
Expect(srcConnID).To(Equal(newConnID)) newConnID = srcConnID
Expect(tokenP).To(Equal(token)) Expect(tokenP).To(Equal(token))
conn.EXPECT().handlePacket(p) conn.EXPECT().handlePacket(p)
conn.EXPECT().run().Do(func() error { close(run); return nil }) conn.EXPECT().run().Do(func() error { close(run); return nil })
@ -328,6 +317,12 @@ var _ = Describe("Server", func() {
conn.EXPECT().HandshakeComplete().Return(make(chan struct{})) conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
return conn return conn
} }
phm.EXPECT().Get(connID)
phm.EXPECT().GetStatelessResetToken(gomock.Any()).Return(token)
phm.EXPECT().AddWithConnID(connID, gomock.Any(), gomock.Any()).DoAndReturn(func(_, cid protocol.ConnectionID, h packetHandler) bool {
Expect(cid).To(Equal(newConnID))
return true
})
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
@ -483,19 +478,6 @@ var _ = Describe("Server", func() {
rand.Read(token[:]) rand.Read(token[:])
var newConnID protocol.ConnectionID var newConnID protocol.ConnectionID
gomock.InOrder(
phm.EXPECT().Get(connID),
phm.EXPECT().AddWithConnID(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
newConnID = c
phm.EXPECT().GetStatelessResetToken(gomock.Any()).DoAndReturn(func(c protocol.ConnectionID) protocol.StatelessResetToken {
newConnID = c
return token
})
_, ok := fn()
return ok
}),
)
conn := NewMockQUICConn(mockCtrl) conn := NewMockQUICConn(mockCtrl)
serv.newConn = func( serv.newConn = func(
_ sendConn, _ sendConn,
@ -523,7 +505,7 @@ var _ = Describe("Server", func() {
// make sure we're using a server-generated connection ID // make sure we're using a server-generated connection ID
Expect(srcConnID).ToNot(Equal(hdr.DestConnectionID)) Expect(srcConnID).ToNot(Equal(hdr.DestConnectionID))
Expect(srcConnID).ToNot(Equal(hdr.SrcConnectionID)) Expect(srcConnID).ToNot(Equal(hdr.SrcConnectionID))
Expect(srcConnID).To(Equal(newConnID)) newConnID = srcConnID
Expect(tokenP).To(Equal(token)) Expect(tokenP).To(Equal(token))
conn.EXPECT().handlePacket(p) conn.EXPECT().handlePacket(p)
conn.EXPECT().run().Do(func() error { close(run); return nil }) conn.EXPECT().run().Do(func() error { close(run); return nil })
@ -531,6 +513,14 @@ var _ = Describe("Server", func() {
conn.EXPECT().HandshakeComplete().Return(make(chan struct{})) conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
return conn return conn
} }
gomock.InOrder(
phm.EXPECT().Get(connID),
phm.EXPECT().GetStatelessResetToken(gomock.Any()).Return(token),
phm.EXPECT().AddWithConnID(connID, gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, h packetHandler) bool {
Expect(c).To(Equal(newConnID))
return true
}),
)
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
@ -553,11 +543,8 @@ var _ = Describe("Server", func() {
serv.maxNumHandshakesUnvalidated = 10000 serv.maxNumHandshakesUnvalidated = 10000
phm.EXPECT().Get(gomock.Any()).AnyTimes() phm.EXPECT().Get(gomock.Any()).AnyTimes()
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool { phm.EXPECT().GetStatelessResetToken(gomock.Any()).AnyTimes()
phm.EXPECT().GetStatelessResetToken(gomock.Any()) phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true).AnyTimes()
_, ok := fn()
return ok
}).AnyTimes()
acceptConn := make(chan struct{}) acceptConn := make(chan struct{})
var counter atomic.Uint32 var counter atomic.Uint32
@ -614,7 +601,7 @@ var _ = Describe("Server", func() {
Consistently(func() uint32 { return counter.Load() }).Should(BeEquivalentTo(protocol.MaxServerUnprocessedPackets + 1)) Consistently(func() uint32 { return counter.Load() }).Should(BeEquivalentTo(protocol.MaxServerUnprocessedPackets + 1))
}) })
It("only creates a single connection for a duplicate Initial", func() { PIt("only creates a single connection for a duplicate Initial", func() {
var createdConn bool var createdConn bool
serv.newConn = func( serv.newConn = func(
_ sendConn, _ sendConn,
@ -642,7 +629,7 @@ var _ = Describe("Server", func() {
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9}) connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9})
p := getInitial(connID) p := getInitial(connID)
phm.EXPECT().Get(connID) phm.EXPECT().Get(connID)
phm.EXPECT().AddWithConnID(connID, gomock.Any(), gomock.Any()).Return(false) phm.EXPECT().AddWithConnID(connID, gomock.Any(), gomock.Any()).Return(false) // connection ID collision
tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
done := make(chan struct{}) done := make(chan struct{})
conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).Do(func([]byte, net.Addr) (int, error) { close(done); return 0, nil }) conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).Do(func([]byte, net.Addr) (int, error) { close(done); return 0, nil })
@ -657,11 +644,8 @@ var _ = Describe("Server", func() {
serv.maxNumHandshakesUnvalidated = limit serv.maxNumHandshakesUnvalidated = limit
phm.EXPECT().Get(gomock.Any()).AnyTimes() phm.EXPECT().Get(gomock.Any()).AnyTimes()
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool { phm.EXPECT().GetStatelessResetToken(gomock.Any()).AnyTimes()
phm.EXPECT().GetStatelessResetToken(gomock.Any()) phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true).AnyTimes()
_, ok := fn()
return ok
}).AnyTimes()
handshakeChan := make(chan struct{}) handshakeChan := make(chan struct{})
connChan := make(chan *MockQUICConn, 1) connChan := make(chan *MockQUICConn, 1)
@ -739,11 +723,8 @@ var _ = Describe("Server", func() {
serv.maxNumHandshakesUnvalidated = limit // same limit, but we check that we send CONNECTION_REFUSED and not Retry serv.maxNumHandshakesUnvalidated = limit // same limit, but we check that we send CONNECTION_REFUSED and not Retry
phm.EXPECT().Get(gomock.Any()).AnyTimes() phm.EXPECT().Get(gomock.Any()).AnyTimes()
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool { phm.EXPECT().GetStatelessResetToken(gomock.Any()).AnyTimes()
phm.EXPECT().GetStatelessResetToken(gomock.Any()) phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true).AnyTimes()
_, ok := fn()
return ok
}).AnyTimes()
handshakeChan := make(chan struct{}) handshakeChan := make(chan struct{})
connChan := make(chan *MockQUICConn, 1) connChan := make(chan *MockQUICConn, 1)
@ -841,10 +822,12 @@ var _ = Describe("Server", func() {
done := make(chan struct{}) done := make(chan struct{})
phm.EXPECT().Get(gomock.Any()) phm.EXPECT().Get(gomock.Any())
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_, _ protocol.ConnectionID, _ func() (packetHandler, bool)) bool { phm.EXPECT().GetStatelessResetToken(gomock.Any())
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, _ packetHandler) bool {
close(done) close(done)
return false return true
}) })
phm.EXPECT().Remove(gomock.Any()).AnyTimes()
serv.handlePacket(packet) serv.handlePacket(packet)
Eventually(done).Should(BeClosed()) Eventually(done).Should(BeClosed())
}) })
@ -1059,11 +1042,8 @@ var _ = Describe("Server", func() {
return conn return conn
} }
phm.EXPECT().Get(gomock.Any()) phm.EXPECT().Get(gomock.Any())
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool { phm.EXPECT().GetStatelessResetToken(gomock.Any())
phm.EXPECT().GetStatelessResetToken(gomock.Any()) phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true)
_, ok := fn()
return ok
})
serv.handleInitialImpl( serv.handleInitialImpl(
receivedPacket{buffer: getPacketBuffer()}, receivedPacket{buffer: getPacketBuffer()},
&wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})}, &wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})},
@ -1128,11 +1108,8 @@ var _ = Describe("Server", func() {
return conn return conn
} }
phm.EXPECT().Get(gomock.Any()) phm.EXPECT().Get(gomock.Any())
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool { phm.EXPECT().GetStatelessResetToken(gomock.Any())
phm.EXPECT().GetStatelessResetToken(gomock.Any()) phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true)
_, ok := fn()
return ok
})
serv.handleInitialImpl( serv.handleInitialImpl(
receivedPacket{buffer: getPacketBuffer()}, receivedPacket{buffer: getPacketBuffer()},
&wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})}, &wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})},
@ -1146,10 +1123,6 @@ var _ = Describe("Server", func() {
serv.config = populateConfig(&Config{GetConfigForClient: func(*ClientHelloInfo) (*Config, error) { return nil, errors.New("rejected") }}) serv.config = populateConfig(&Config{GetConfigForClient: func(*ClientHelloInfo) (*Config, error) { return nil, errors.New("rejected") }})
phm.EXPECT().Get(gomock.Any()) phm.EXPECT().Get(gomock.Any())
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
_, ok := fn()
return ok
})
done := make(chan struct{}) done := make(chan struct{})
tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).DoAndReturn(func(b []byte, _ net.Addr) (int, error) { conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
@ -1204,11 +1177,8 @@ var _ = Describe("Server", func() {
return conn return conn
} }
phm.EXPECT().Get(gomock.Any()) phm.EXPECT().Get(gomock.Any())
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool { phm.EXPECT().GetStatelessResetToken(gomock.Any())
phm.EXPECT().GetStatelessResetToken(gomock.Any()) phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true)
_, ok := fn()
return ok
})
serv.handleInitialImpl( serv.handleInitialImpl(
receivedPacket{buffer: getPacketBuffer()}, receivedPacket{buffer: getPacketBuffer()},
&wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})}, &wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})},
@ -1277,11 +1247,8 @@ var _ = Describe("Server", func() {
return conn return conn
} }
phm.EXPECT().Get(gomock.Any()) phm.EXPECT().Get(gomock.Any())
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool { phm.EXPECT().GetStatelessResetToken(gomock.Any())
phm.EXPECT().GetStatelessResetToken(gomock.Any()) phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true)
_, ok := fn()
return ok
})
serv.baseServer.handleInitialImpl( serv.baseServer.handleInitialImpl(
receivedPacket{buffer: getPacketBuffer()}, receivedPacket{buffer: getPacketBuffer()},
&wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})}, &wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})},
@ -1326,11 +1293,8 @@ var _ = Describe("Server", func() {
} }
phm.EXPECT().Get(gomock.Any()).AnyTimes() phm.EXPECT().Get(gomock.Any()).AnyTimes()
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool { phm.EXPECT().GetStatelessResetToken(gomock.Any()).Times(protocol.MaxAcceptQueueSize)
phm.EXPECT().GetStatelessResetToken(gomock.Any()) phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true).Times(protocol.MaxAcceptQueueSize)
_, ok := fn()
return ok
}).Times(protocol.MaxAcceptQueueSize)
for i := 0; i < protocol.MaxAcceptQueueSize; i++ { for i := 0; i < protocol.MaxAcceptQueueSize; i++ {
conn := NewMockQUICConn(mockCtrl) conn := NewMockQUICConn(mockCtrl)
connChan <- conn connChan <- conn
@ -1339,11 +1303,8 @@ var _ = Describe("Server", func() {
Eventually(serv.baseServer.connQueue).Should(HaveLen(protocol.MaxAcceptQueueSize)) Eventually(serv.baseServer.connQueue).Should(HaveLen(protocol.MaxAcceptQueueSize))
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool { phm.EXPECT().GetStatelessResetToken(gomock.Any())
phm.EXPECT().GetStatelessResetToken(gomock.Any()) phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true)
_, ok := fn()
return ok
})
conn := NewMockQUICConn(mockCtrl) conn := NewMockQUICConn(mockCtrl)
conn.EXPECT().closeWithTransportError(ConnectionRefused) conn.EXPECT().closeWithTransportError(ConnectionRefused)
connChan <- conn connChan <- conn
@ -1384,11 +1345,8 @@ var _ = Describe("Server", func() {
} }
phm.EXPECT().Get(gomock.Any()) phm.EXPECT().Get(gomock.Any())
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool { phm.EXPECT().GetStatelessResetToken(gomock.Any())
phm.EXPECT().GetStatelessResetToken(gomock.Any()) phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true)
_, ok := fn()
return ok
})
serv.baseServer.handlePacket(p) serv.baseServer.handlePacket(p)
// make sure there are no Write calls on the packet conn // make sure there are no Write calls on the packet conn
time.Sleep(50 * time.Millisecond) time.Sleep(50 * time.Millisecond)
@ -1432,7 +1390,7 @@ var _ = Describe("Server", func() {
AfterEach(func() { AfterEach(func() {
tracer.EXPECT().Close() tracer.EXPECT().Close()
tr.Close() Expect(tr.Close()).To(Succeed())
}) })
It("passes packets to existing connections", func() { It("passes packets to existing connections", func() {
@ -1518,11 +1476,8 @@ var _ = Describe("Server", func() {
} }
phm.EXPECT().Get(connID) phm.EXPECT().Get(connID)
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool { phm.EXPECT().GetStatelessResetToken(gomock.Any())
phm.EXPECT().GetStatelessResetToken(gomock.Any()) phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true)
_, ok := fn()
return ok
})
serv.handlePacket(initial) serv.handlePacket(initial)
Eventually(called).Should(BeClosed()) Eventually(called).Should(BeClosed())
}) })