rename occurrences of session in client, server and packetHandlerMap

This commit is contained in:
Marten Seemann 2022-03-26 15:39:44 +01:00
parent 86338d3ce0
commit 1ae835d1d8
7 changed files with 185 additions and 185 deletions

View file

@ -14,7 +14,7 @@ import (
) )
type client struct { type client struct {
conn sendConn sconn sendConn
// If the client is created with DialAddr, we create a packet conn. // If the client is created with DialAddr, we create a packet conn.
// If it is started with Dial, we take a packet conn as a parameter. // If it is started with Dial, we take a packet conn as a parameter.
createdPacketConn bool createdPacketConn bool
@ -35,7 +35,7 @@ type client struct {
handshakeChan chan struct{} handshakeChan chan struct{}
session quicConn conn quicConn
tracer logging.ConnectionTracer tracer logging.ConnectionTracer
tracingID uint64 tracingID uint64
@ -49,7 +49,7 @@ var (
) )
// DialAddr establishes a new QUIC connection to a server. // DialAddr establishes a new QUIC connection to a server.
// It uses a new UDP connection and closes this connection when the QUIC session is closed. // It uses a new UDP connection and closes this connection when the QUIC connection is closed.
// The hostname for SNI is taken from the given address. // The hostname for SNI is taken from the given address.
// The tls.Config.CipherSuites allows setting of TLS 1.3 cipher suites. // The tls.Config.CipherSuites allows setting of TLS 1.3 cipher suites.
func DialAddr( func DialAddr(
@ -61,7 +61,7 @@ func DialAddr(
} }
// DialAddrEarly establishes a new 0-RTT QUIC connection to a server. // DialAddrEarly establishes a new 0-RTT QUIC connection to a server.
// It uses a new UDP connection and closes this connection when the QUIC session is closed. // It uses a new UDP connection and closes this connection when the QUIC connection is closed.
// The hostname for SNI is taken from the given address. // The hostname for SNI is taken from the given address.
// The tls.Config.CipherSuites allows setting of TLS 1.3 cipher suites. // The tls.Config.CipherSuites allows setting of TLS 1.3 cipher suites.
func DialAddrEarly( func DialAddrEarly(
@ -80,12 +80,12 @@ func DialAddrEarlyContext(
tlsConf *tls.Config, tlsConf *tls.Config,
config *Config, config *Config,
) (EarlyConnection, error) { ) (EarlyConnection, error) {
sess, err := dialAddrContext(ctx, addr, tlsConf, config, true) conn, err := dialAddrContext(ctx, addr, tlsConf, config, true)
if err != nil { if err != nil {
return nil, err return nil, err
} }
utils.Logger.WithPrefix(utils.DefaultLogger, "client").Debugf("Returning early session") utils.Logger.WithPrefix(utils.DefaultLogger, "client").Debugf("Returning early connection")
return sess, nil return conn, nil
} }
// DialAddrContext establishes a new QUIC connection to a server using the provided context. // DialAddrContext establishes a new QUIC connection to a server using the provided context.
@ -212,12 +212,12 @@ func dialContext(
) )
} }
if c.tracer != nil { if c.tracer != nil {
c.tracer.StartedConnection(c.conn.LocalAddr(), c.conn.RemoteAddr(), c.srcConnID, c.destConnID) c.tracer.StartedConnection(c.sconn.LocalAddr(), c.sconn.RemoteAddr(), c.srcConnID, c.destConnID)
} }
if err := c.dial(ctx); err != nil { if err := c.dial(ctx); err != nil {
return nil, err return nil, err
} }
return c.session, nil return c.conn, nil
} }
func newClient( func newClient(
@ -265,7 +265,7 @@ func newClient(
c := &client{ c := &client{
srcConnID: srcConnID, srcConnID: srcConnID,
destConnID: destConnID, destConnID: destConnID,
conn: newSendPconn(pconn, remoteAddr), sconn: newSendPconn(pconn, remoteAddr),
createdPacketConn: createdPacketConn, createdPacketConn: createdPacketConn,
use0RTT: use0RTT, use0RTT: use0RTT,
tlsConf: tlsConf, tlsConf: tlsConf,
@ -278,10 +278,10 @@ func newClient(
} }
func (c *client) dial(ctx context.Context) error { func (c *client) dial(ctx context.Context) error {
c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", c.tlsConf.ServerName, c.conn.LocalAddr(), c.conn.RemoteAddr(), c.srcConnID, c.destConnID, c.version) c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", c.tlsConf.ServerName, c.sconn.LocalAddr(), c.sconn.RemoteAddr(), c.srcConnID, c.destConnID, c.version)
c.session = newClientSession( c.conn = newClientSession(
c.conn, c.sconn,
c.packetHandlers, c.packetHandlers,
c.destConnID, c.destConnID,
c.srcConnID, c.srcConnID,
@ -295,11 +295,11 @@ func (c *client) dial(ctx context.Context) error {
c.logger, c.logger,
c.version, c.version,
) )
c.packetHandlers.Add(c.srcConnID, c.session) c.packetHandlers.Add(c.srcConnID, c.conn)
errorChan := make(chan error, 1) errorChan := make(chan error, 1)
go func() { go func() {
err := c.session.run() // returns as soon as the session is closed err := c.conn.run() // returns as soon as the connection is closed
if e := (&errCloseForRecreating{}); !errors.As(err, &e) && c.createdPacketConn { if e := (&errCloseForRecreating{}); !errors.As(err, &e) && c.createdPacketConn {
c.packetHandlers.Destroy() c.packetHandlers.Destroy()
@ -308,15 +308,15 @@ func (c *client) dial(ctx context.Context) error {
}() }()
// only set when we're using 0-RTT // only set when we're using 0-RTT
// Otherwise, earlySessionChan will be nil. Receiving from a nil chan blocks forever. // Otherwise, earlyConnChan will be nil. Receiving from a nil chan blocks forever.
var earlySessionChan <-chan struct{} var earlyConnChan <-chan struct{}
if c.use0RTT { if c.use0RTT {
earlySessionChan = c.session.earlySessionReady() earlyConnChan = c.conn.earlySessionReady()
} }
select { select {
case <-ctx.Done(): case <-ctx.Done():
c.session.shutdown() c.conn.shutdown()
return ctx.Err() return ctx.Err()
case err := <-errorChan: case err := <-errorChan:
var recreateErr *errCloseForRecreating var recreateErr *errCloseForRecreating
@ -327,10 +327,10 @@ func (c *client) dial(ctx context.Context) error {
return c.dial(ctx) return c.dial(ctx)
} }
return err return err
case <-earlySessionChan: case <-earlyConnChan:
// ready to send 0-RTT data // ready to send 0-RTT data
return nil return nil
case <-c.session.HandshakeComplete().Done(): case <-c.conn.HandshakeComplete().Done():
// handshake successfully completed // handshake successfully completed
return nil return nil
} }

View file

@ -56,7 +56,7 @@ var _ = Describe("Client", func() {
tr := mocklogging.NewMockTracer(mockCtrl) tr := mocklogging.NewMockTracer(mockCtrl)
tr.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveClient, gomock.Any()).Return(tracer).MaxTimes(1) tr.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveClient, gomock.Any()).Return(tracer).MaxTimes(1)
config = &Config{Tracer: tr, Versions: []protocol.VersionNumber{protocol.VersionTLS}} config = &Config{Tracer: tr, Versions: []protocol.VersionNumber{protocol.VersionTLS}}
Eventually(areSessionsRunning).Should(BeFalse()) Eventually(areConnsRunning).Should(BeFalse())
addr = &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200), Port: 1337} addr = &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200), Port: 1337}
packetConn = NewMockPacketConn(mockCtrl) packetConn = NewMockPacketConn(mockCtrl)
packetConn.EXPECT().LocalAddr().Return(&net.UDPAddr{}).AnyTimes() packetConn.EXPECT().LocalAddr().Return(&net.UDPAddr{}).AnyTimes()
@ -64,7 +64,7 @@ var _ = Describe("Client", func() {
srcConnID: connID, srcConnID: connID,
destConnID: connID, destConnID: connID,
version: protocol.VersionTLS, version: protocol.VersionTLS,
conn: newSendPconn(packetConn, addr), sconn: newSendPconn(packetConn, addr),
tracer: tracer, tracer: tracer,
logger: utils.DefaultLogger, logger: utils.DefaultLogger,
} }
@ -81,10 +81,10 @@ var _ = Describe("Client", func() {
}) })
AfterEach(func() { AfterEach(func() {
if s, ok := cl.session.(*session); ok { if s, ok := cl.conn.(*session); ok {
s.shutdown() s.shutdown()
} }
Eventually(areSessionsRunning).Should(BeFalse()) Eventually(areConnsRunning).Should(BeFalse())
}) })
Context("Dialing", func() { Context("Dialing", func() {
@ -259,7 +259,7 @@ var _ = Describe("Client", func() {
Eventually(run).Should(BeClosed()) Eventually(run).Should(BeClosed())
}) })
It("returns early sessions", func() { It("returns early connections", func() {
manager := NewMockPacketHandlerManager(mockCtrl) manager := NewMockPacketHandlerManager(mockCtrl)
manager.EXPECT().Add(gomock.Any(), gomock.Any()) manager.EXPECT().Add(gomock.Any(), gomock.Any())
mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil) mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil)
@ -345,16 +345,16 @@ var _ = Describe("Client", func() {
Expect(err).To(MatchError(testErr)) Expect(err).To(MatchError(testErr))
}) })
It("closes the session when the context is canceled", func() { It("closes the connection when the context is canceled", func() {
manager := NewMockPacketHandlerManager(mockCtrl) manager := NewMockPacketHandlerManager(mockCtrl)
manager.EXPECT().Add(gomock.Any(), gomock.Any()) manager.EXPECT().Add(gomock.Any(), gomock.Any())
mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil) mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil)
sessionRunning := make(chan struct{}) connRunning := make(chan struct{})
defer close(sessionRunning) defer close(connRunning)
conn := NewMockQuicConn(mockCtrl) conn := NewMockQuicConn(mockCtrl)
conn.EXPECT().run().Do(func() { conn.EXPECT().run().Do(func() {
<-sessionRunning <-connRunning
}) })
conn.EXPECT().HandshakeComplete().Return(context.Background()) conn.EXPECT().HandshakeComplete().Return(context.Background())
newClientSession = func( newClientSession = func(
@ -407,7 +407,7 @@ var _ = Describe("Client", func() {
var sconn sendConn var sconn sendConn
run := make(chan struct{}) run := make(chan struct{})
sessionCreated := make(chan struct{}) connCreated := make(chan struct{})
conn := NewMockQuicConn(mockCtrl) conn := NewMockQuicConn(mockCtrl)
newClientSession = func( newClientSession = func(
connP sendConn, connP sendConn,
@ -425,7 +425,7 @@ var _ = Describe("Client", func() {
_ protocol.VersionNumber, _ protocol.VersionNumber,
) quicConn { ) quicConn {
sconn = connP sconn = connP
close(sessionCreated) close(connCreated)
return conn return conn
} }
conn.EXPECT().run().Do(func() { conn.EXPECT().run().Do(func() {
@ -441,7 +441,7 @@ var _ = Describe("Client", func() {
close(done) close(done)
}() }()
Eventually(sessionCreated).Should(BeClosed()) Eventually(connCreated).Should(BeClosed())
// check that the connection is not closed // check that the connection is not closed
Expect(sconn.Write([]byte("foobar"))).To(Succeed()) Expect(sconn.Write([]byte("foobar"))).To(Succeed())
@ -519,7 +519,7 @@ var _ = Describe("Client", func() {
}) })
}) })
It("creates new sessions with the right parameters", func() { It("creates new connections with the right parameters", func() {
manager := NewMockPacketHandlerManager(mockCtrl) manager := NewMockPacketHandlerManager(mockCtrl)
manager.EXPECT().Add(connID, gomock.Any()) manager.EXPECT().Add(connID, gomock.Any())
mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil) mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil)
@ -562,7 +562,7 @@ var _ = Describe("Client", func() {
Expect(conf.Versions).To(Equal(config.Versions)) Expect(conf.Versions).To(Equal(config.Versions))
}) })
It("creates a new session after version negotiation", func() { It("creates a new connections after version negotiation", func() {
manager := NewMockPacketHandlerManager(mockCtrl) manager := NewMockPacketHandlerManager(mockCtrl)
manager.EXPECT().Add(connID, gomock.Any()).Times(2) manager.EXPECT().Add(connID, gomock.Any()).Times(2)
manager.EXPECT().Destroy() manager.EXPECT().Destroy()

View file

@ -71,8 +71,8 @@ type packetHandlerMap struct {
listening chan struct{} // is closed when listen returns listening chan struct{} // is closed when listen returns
closed bool closed bool
deleteRetiredSessionsAfter time.Duration deleteRetiredConnsAfter time.Duration
zeroRTTQueueDuration time.Duration zeroRTTQueueDuration time.Duration
statelessResetEnabled bool statelessResetEnabled bool
statelessResetMutex sync.Mutex statelessResetMutex sync.Mutex
@ -138,17 +138,17 @@ func newPacketHandlerMap(
return nil, err return nil, err
} }
m := &packetHandlerMap{ m := &packetHandlerMap{
conn: conn, conn: conn,
connIDLen: connIDLen, connIDLen: connIDLen,
listening: make(chan struct{}), listening: make(chan struct{}),
handlers: make(map[string]packetHandlerMapEntry), handlers: make(map[string]packetHandlerMapEntry),
resetTokens: make(map[protocol.StatelessResetToken]packetHandler), resetTokens: make(map[protocol.StatelessResetToken]packetHandler),
deleteRetiredSessionsAfter: protocol.RetiredConnectionIDDeleteTimeout, deleteRetiredConnsAfter: protocol.RetiredConnectionIDDeleteTimeout,
zeroRTTQueueDuration: protocol.Max0RTTQueueingDuration, zeroRTTQueueDuration: protocol.Max0RTTQueueingDuration,
statelessResetEnabled: len(statelessResetKey) > 0, statelessResetEnabled: len(statelessResetKey) > 0,
statelessResetHasher: hmac.New(sha256.New, statelessResetKey), statelessResetHasher: hmac.New(sha256.New, statelessResetKey),
tracer: tracer, tracer: tracer,
logger: logger, logger: logger,
} }
go m.listen() go m.listen()
@ -204,7 +204,7 @@ func (h *packetHandlerMap) AddWithConnID(clientDestConnID, newConnID protocol.Co
var q *zeroRTTQueue var q *zeroRTTQueue
if entry, ok := h.handlers[string(clientDestConnID)]; ok { if entry, ok := h.handlers[string(clientDestConnID)]; ok {
if !entry.is0RTTQueue { if !entry.is0RTTQueue {
h.logger.Debugf("Not adding connection ID %s for a new session, as it already exists.", clientDestConnID) h.logger.Debugf("Not adding connection ID %s for a new connection, as it already exists.", clientDestConnID)
return false return false
} }
q = entry.packetHandler.(*zeroRTTQueue) q = entry.packetHandler.(*zeroRTTQueue)
@ -220,7 +220,7 @@ func (h *packetHandlerMap) AddWithConnID(clientDestConnID, newConnID protocol.Co
} }
h.handlers[string(clientDestConnID)] = packetHandlerMapEntry{packetHandler: sess} h.handlers[string(clientDestConnID)] = packetHandlerMapEntry{packetHandler: sess}
h.handlers[string(newConnID)] = packetHandlerMapEntry{packetHandler: sess} h.handlers[string(newConnID)] = packetHandlerMapEntry{packetHandler: sess}
h.logger.Debugf("Adding connection IDs %s and %s for a new session.", clientDestConnID, newConnID) h.logger.Debugf("Adding connection IDs %s and %s for a new connection.", clientDestConnID, newConnID)
return true return true
} }
@ -232,8 +232,8 @@ func (h *packetHandlerMap) Remove(id protocol.ConnectionID) {
} }
func (h *packetHandlerMap) Retire(id protocol.ConnectionID) { func (h *packetHandlerMap) Retire(id protocol.ConnectionID) {
h.logger.Debugf("Retiring connection ID %s in %s.", id, h.deleteRetiredSessionsAfter) h.logger.Debugf("Retiring connection ID %s in %s.", id, h.deleteRetiredConnsAfter)
time.AfterFunc(h.deleteRetiredSessionsAfter, func() { time.AfterFunc(h.deleteRetiredConnsAfter, func() {
h.mutex.Lock() h.mutex.Lock()
delete(h.handlers, string(id)) delete(h.handlers, string(id))
h.mutex.Unlock() h.mutex.Unlock()
@ -245,14 +245,14 @@ func (h *packetHandlerMap) ReplaceWithClosed(id protocol.ConnectionID, handler p
h.mutex.Lock() h.mutex.Lock()
h.handlers[string(id)] = packetHandlerMapEntry{packetHandler: handler} h.handlers[string(id)] = packetHandlerMapEntry{packetHandler: handler}
h.mutex.Unlock() h.mutex.Unlock()
h.logger.Debugf("Replacing session for connection ID %s with a closed session.", id) h.logger.Debugf("Replacing connection for connection ID %s with a closed connection.", id)
time.AfterFunc(h.deleteRetiredSessionsAfter, func() { time.AfterFunc(h.deleteRetiredConnsAfter, func() {
h.mutex.Lock() h.mutex.Lock()
handler.shutdown() handler.shutdown()
delete(h.handlers, string(id)) delete(h.handlers, string(id))
h.mutex.Unlock() h.mutex.Unlock()
h.logger.Debugf("Removing connection ID %s for a closed session after it has been retired.", id) h.logger.Debugf("Removing connection ID %s for a closed connection after it has been retired.", id)
}) })
} }
@ -297,7 +297,7 @@ func (h *packetHandlerMap) CloseServer() {
} }
// Destroy closes the underlying connection and waits until listen() has returned. // Destroy closes the underlying connection and waits until listen() has returned.
// It does not close active sessions. // It does not close active connections.
func (h *packetHandlerMap) Destroy() error { func (h *packetHandlerMap) Destroy() error {
if err := h.conn.Close(); err != nil { if err := h.conn.Close(); err != nil {
return err return err
@ -371,7 +371,7 @@ func (h *packetHandlerMap) handlePacket(p *receivedPacket) {
entry.packetHandler.handlePacket(p) entry.packetHandler.handlePacket(p)
return return
} }
} else { // existing session } else { // existing connection
entry.packetHandler.handlePacket(p) entry.packetHandler.handlePacket(p)
return return
} }
@ -397,7 +397,7 @@ func (h *packetHandlerMap) handlePacket(p *receivedPacket) {
queue.retireTimer = time.AfterFunc(h.zeroRTTQueueDuration, func() { queue.retireTimer = time.AfterFunc(h.zeroRTTQueueDuration, func() {
h.mutex.Lock() h.mutex.Lock()
defer h.mutex.Unlock() defer h.mutex.Unlock()
// The entry might have been replaced by an actual session. // The entry might have been replaced by an actual connection.
// Only delete it if it's still a 0-RTT queue. // Only delete it if it's still a 0-RTT queue.
if entry, ok := h.handlers[string(connID)]; ok && entry.is0RTTQueue { if entry, ok := h.handlers[string(connID)]; ok && entry.is0RTTQueue {
delete(h.handlers, string(connID)) delete(h.handlers, string(connID))
@ -429,7 +429,7 @@ func (h *packetHandlerMap) maybeHandleStatelessReset(data []byte) bool {
var token protocol.StatelessResetToken var token protocol.StatelessResetToken
copy(token[:], data[len(data)-16:]) copy(token[:], data[len(data)-16:])
if sess, ok := h.resetTokens[token]; ok { if sess, ok := h.resetTokens[token]; ok {
h.logger.Debugf("Received a stateless reset with token %#x. Closing session.", token) h.logger.Debugf("Received a stateless reset with token %#x. Closing connection.", token)
go sess.destroy(&StatelessResetError{Token: token}) go sess.destroy(&StatelessResetError{Token: token})
return true return true
} }

View file

@ -89,12 +89,12 @@ var _ = Describe("Packet Handler Map", func() {
}() }()
testErr := errors.New("test error ") testErr := errors.New("test error ")
sess1 := NewMockPacketHandler(mockCtrl) conn1 := NewMockPacketHandler(mockCtrl)
sess1.EXPECT().destroy(testErr) conn1.EXPECT().destroy(testErr)
sess2 := NewMockPacketHandler(mockCtrl) conn2 := NewMockPacketHandler(mockCtrl)
sess2.EXPECT().destroy(testErr) conn2.EXPECT().destroy(testErr)
handler.Add(protocol.ConnectionID{1, 1, 1, 1}, sess1) handler.Add(protocol.ConnectionID{1, 1, 1, 1}, conn1)
handler.Add(protocol.ConnectionID{2, 2, 2, 2}, sess2) handler.Add(protocol.ConnectionID{2, 2, 2, 2}, conn2)
mockMultiplexer.EXPECT().RemoveConn(gomock.Any()) mockMultiplexer.EXPECT().RemoveConn(gomock.Any())
handler.close(testErr) handler.close(testErr)
close(packetChan) close(packetChan)
@ -103,7 +103,7 @@ var _ = Describe("Packet Handler Map", func() {
Context("other operations", func() { Context("other operations", func() {
AfterEach(func() { AfterEach(func() {
// delete sessions and the server before closing // delete connections and the server before closing
// They might be mock implementations, and we'd have to register the expected calls before otherwise. // They might be mock implementations, and we'd have to register the expected calls before otherwise.
handler.mutex.Lock() handler.mutex.Lock()
for connID := range handler.handlers { for connID := range handler.handlers {
@ -160,8 +160,8 @@ var _ = Describe("Packet Handler Map", func() {
}) })
}) })
It("deletes removed sessions immediately", func() { It("deletes removed connections immediately", func() {
handler.deleteRetiredSessionsAfter = time.Hour handler.deleteRetiredConnsAfter = time.Hour
connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
handler.Add(connID, NewMockPacketHandler(mockCtrl)) handler.Add(connID, NewMockPacketHandler(mockCtrl))
handler.Remove(connID) handler.Remove(connID)
@ -169,19 +169,19 @@ var _ = Describe("Packet Handler Map", func() {
// don't EXPECT any calls to handlePacket of the MockPacketHandler // don't EXPECT any calls to handlePacket of the MockPacketHandler
}) })
It("deletes retired session entries after a wait time", func() { It("deletes retired connection entries after a wait time", func() {
handler.deleteRetiredSessionsAfter = scaleDuration(10 * time.Millisecond) handler.deleteRetiredConnsAfter = scaleDuration(10 * time.Millisecond)
connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
sess := NewMockPacketHandler(mockCtrl) conn := NewMockPacketHandler(mockCtrl)
handler.Add(connID, sess) handler.Add(connID, conn)
handler.Retire(connID) handler.Retire(connID)
time.Sleep(scaleDuration(30 * time.Millisecond)) time.Sleep(scaleDuration(30 * time.Millisecond))
handler.handlePacket(&receivedPacket{data: getPacket(connID)}) handler.handlePacket(&receivedPacket{data: getPacket(connID)})
// don't EXPECT any calls to handlePacket of the MockPacketHandler // don't EXPECT any calls to handlePacket of the MockPacketHandler
}) })
It("passes packets arriving late for closed sessions to that session", func() { It("passes packets arriving late for closed connections to that connection", func() {
handler.deleteRetiredSessionsAfter = time.Hour handler.deleteRetiredConnsAfter = time.Hour
connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
packetHandler := NewMockPacketHandler(mockCtrl) packetHandler := NewMockPacketHandler(mockCtrl)
handled := make(chan struct{}) handled := make(chan struct{})
@ -250,16 +250,16 @@ var _ = Describe("Packet Handler Map", func() {
handler.handlePacket(&receivedPacket{data: p}) handler.handlePacket(&receivedPacket{data: p})
}) })
It("closes all server sessions", func() { It("closes all server connections", func() {
handler.SetServer(NewMockUnknownPacketHandler(mockCtrl)) handler.SetServer(NewMockUnknownPacketHandler(mockCtrl))
clientSess := NewMockPacketHandler(mockCtrl) clientConn := NewMockPacketHandler(mockCtrl)
clientSess.EXPECT().getPerspective().Return(protocol.PerspectiveClient) clientConn.EXPECT().getPerspective().Return(protocol.PerspectiveClient)
serverSess := NewMockPacketHandler(mockCtrl) serverConn := NewMockPacketHandler(mockCtrl)
serverSess.EXPECT().getPerspective().Return(protocol.PerspectiveServer) serverConn.EXPECT().getPerspective().Return(protocol.PerspectiveServer)
serverSess.EXPECT().shutdown() serverConn.EXPECT().shutdown()
handler.Add(protocol.ConnectionID{1, 1, 1, 1}, clientSess) handler.Add(protocol.ConnectionID{1, 1, 1, 1}, clientConn)
handler.Add(protocol.ConnectionID{2, 2, 2, 2}, serverSess) handler.Add(protocol.ConnectionID{2, 2, 2, 2}, serverConn)
handler.CloseServer() handler.CloseServer()
}) })
@ -293,23 +293,23 @@ var _ = Describe("Packet Handler Map", func() {
handler.handlePacket(p1) handler.handlePacket(p1)
handler.handlePacket(p2) handler.handlePacket(p2)
handler.handlePacket(p3) handler.handlePacket(p3)
sess := NewMockPacketHandler(mockCtrl) conn := NewMockPacketHandler(mockCtrl)
done := make(chan struct{}) done := make(chan struct{})
gomock.InOrder( gomock.InOrder(
sess.EXPECT().handlePacket(p1), conn.EXPECT().handlePacket(p1),
sess.EXPECT().handlePacket(p2), conn.EXPECT().handlePacket(p2),
sess.EXPECT().handlePacket(p3).Do(func(packet *receivedPacket) { close(done) }), conn.EXPECT().handlePacket(p3).Do(func(packet *receivedPacket) { close(done) }),
) )
handler.AddWithConnID(connID, protocol.ConnectionID{1, 2, 3, 4}, func() packetHandler { return sess }) handler.AddWithConnID(connID, protocol.ConnectionID{1, 2, 3, 4}, func() packetHandler { return conn })
Eventually(done).Should(BeClosed()) Eventually(done).Should(BeClosed())
}) })
It("directs 0-RTT packets to existing sessions", func() { It("directs 0-RTT packets to existing connections", func() {
connID := protocol.ConnectionID{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88} connID := protocol.ConnectionID{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88}
sess := NewMockPacketHandler(mockCtrl) conn := NewMockPacketHandler(mockCtrl)
handler.AddWithConnID(connID, protocol.ConnectionID{1, 2, 3, 4}, func() packetHandler { return sess }) handler.AddWithConnID(connID, protocol.ConnectionID{1, 2, 3, 4}, func() packetHandler { return conn })
p1 := &receivedPacket{data: getPacketWithPacketType(connID, protocol.PacketType0RTT, 1)} p1 := &receivedPacket{data: getPacketWithPacketType(connID, protocol.PacketType0RTT, 1)}
sess.EXPECT().handlePacket(p1) conn.EXPECT().handlePacket(p1)
handler.handlePacket(p1) handler.handlePacket(p1)
}) })
@ -324,12 +324,12 @@ var _ = Describe("Packet Handler Map", func() {
connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9} connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9}
handler.handlePacket(&receivedPacket{data: getPacketWithPacketType(connID, protocol.PacketType0RTT, 1)}) handler.handlePacket(&receivedPacket{data: getPacketWithPacketType(connID, protocol.PacketType0RTT, 1)})
// Don't EXPECT any handlePacket() calls. // Don't EXPECT any handlePacket() calls.
sess := NewMockPacketHandler(mockCtrl) conn := NewMockPacketHandler(mockCtrl)
handler.AddWithConnID(connID, protocol.ConnectionID{1, 2, 3, 4}, func() packetHandler { return sess }) handler.AddWithConnID(connID, protocol.ConnectionID{1, 2, 3, 4}, func() packetHandler { return conn })
time.Sleep(20 * time.Millisecond) time.Sleep(20 * time.Millisecond)
}) })
It("deletes queues if no session is created for this connection ID", func() { It("deletes queues if no connection is created for this connection ID", func() {
queueDuration := scaleDuration(10 * time.Millisecond) queueDuration := scaleDuration(10 * time.Millisecond)
handler.zeroRTTQueueDuration = queueDuration handler.zeroRTTQueueDuration = queueDuration
@ -350,8 +350,8 @@ var _ = Describe("Packet Handler Map", func() {
// wait a bit. The queue should now already be deleted. // wait a bit. The queue should now already be deleted.
time.Sleep(queueDuration * 3) time.Sleep(queueDuration * 3)
// Don't EXPECT any handlePacket() calls. // Don't EXPECT any handlePacket() calls.
sess := NewMockPacketHandler(mockCtrl) conn := NewMockPacketHandler(mockCtrl)
handler.AddWithConnID(connID, protocol.ConnectionID{1, 2, 3, 4}, func() packetHandler { return sess }) handler.AddWithConnID(connID, protocol.ConnectionID{1, 2, 3, 4}, func() packetHandler { return conn })
time.Sleep(20 * time.Millisecond) time.Sleep(20 * time.Millisecond)
}) })
}) })

102
server.go
View file

@ -56,7 +56,7 @@ type quicConn interface {
type baseServer struct { type baseServer struct {
mutex sync.Mutex mutex sync.Mutex
acceptEarlySessions bool acceptEarlyConns bool
tlsConf *tls.Config tlsConf *tls.Config
config *Config config *Config
@ -68,7 +68,7 @@ type baseServer struct {
tokenGenerator *handshake.TokenGenerator tokenGenerator *handshake.TokenGenerator
sessionHandler packetHandlerManager connHandler packetHandlerManager
receivedPackets chan *receivedPacket receivedPackets chan *receivedPacket
@ -97,8 +97,8 @@ type baseServer struct {
closed bool closed bool
running chan struct{} // closed as soon as run() returns running chan struct{} // closed as soon as run() returns
sessionQueue chan quicConn connQueue chan quicConn
sessionQueueLen int32 // to be used as an atomic connQueueLen int32 // to be used as an atomic
logger utils.Logger logger utils.Logger
} }
@ -123,7 +123,7 @@ func ListenAddr(addr string, tlsConf *tls.Config, config *Config) (Listener, err
return listenAddr(addr, tlsConf, config, false) return listenAddr(addr, tlsConf, config, false)
} }
// ListenAddrEarly works like ListenAddr, but it returns sessions before the handshake completes. // ListenAddrEarly works like ListenAddr, but it returns connections before the handshake completes.
func ListenAddrEarly(addr string, tlsConf *tls.Config, config *Config) (EarlyListener, error) { func ListenAddrEarly(addr string, tlsConf *tls.Config, config *Config) (EarlyListener, error) {
s, err := listenAddr(addr, tlsConf, config, true) s, err := listenAddr(addr, tlsConf, config, true)
if err != nil { if err != nil {
@ -164,7 +164,7 @@ func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener,
return listen(conn, tlsConf, config, false) return listen(conn, tlsConf, config, false)
} }
// ListenEarly works like Listen, but it returns sessions before the handshake completes. // ListenEarly works like Listen, but it returns connections before the handshake completes.
func ListenEarly(conn net.PacketConn, tlsConf *tls.Config, config *Config) (EarlyListener, error) { func ListenEarly(conn net.PacketConn, tlsConf *tls.Config, config *Config) (EarlyListener, error) {
s, err := listen(conn, tlsConf, config, true) s, err := listen(conn, tlsConf, config, true)
if err != nil { if err != nil {
@ -187,7 +187,7 @@ func listen(conn net.PacketConn, tlsConf *tls.Config, config *Config, acceptEarl
} }
} }
sessionHandler, err := getMultiplexer().AddConn(conn, config.ConnectionIDLength, config.StatelessResetKey, config.Tracer) connHandler, err := getMultiplexer().AddConn(conn, config.ConnectionIDLength, config.StatelessResetKey, config.Tracer)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -200,21 +200,21 @@ func listen(conn net.PacketConn, tlsConf *tls.Config, config *Config, acceptEarl
return nil, err return nil, err
} }
s := &baseServer{ s := &baseServer{
conn: c, conn: c,
tlsConf: tlsConf, tlsConf: tlsConf,
config: config, config: config,
tokenGenerator: tokenGenerator, tokenGenerator: tokenGenerator,
sessionHandler: sessionHandler, connHandler: connHandler,
sessionQueue: make(chan quicConn), connQueue: make(chan quicConn),
errorChan: make(chan struct{}), errorChan: make(chan struct{}),
running: make(chan struct{}), running: make(chan struct{}),
receivedPackets: make(chan *receivedPacket, protocol.MaxServerUnprocessedPackets), receivedPackets: make(chan *receivedPacket, protocol.MaxServerUnprocessedPackets),
newSession: newSession, newSession: newSession,
logger: utils.DefaultLogger.WithPrefix("server"), logger: utils.DefaultLogger.WithPrefix("server"),
acceptEarlySessions: acceptEarly, acceptEarlyConns: acceptEarly,
} }
go s.run() go s.run()
sessionHandler.SetServer(s) connHandler.SetServer(s)
s.logger.Debugf("Listening for %s connections on %s", conn.LocalAddr().Network(), conn.LocalAddr().String()) s.logger.Debugf("Listening for %s connections on %s", conn.LocalAddr().Network(), conn.LocalAddr().String())
return s, nil return s, nil
} }
@ -258,8 +258,8 @@ var defaultAcceptToken = func(clientAddr net.Addr, token *Token) bool {
return sourceAddr == token.RemoteAddr return sourceAddr == token.RemoteAddr
} }
// Accept returns sessions that already completed the handshake. // Accept returns connections that already completed the handshake.
// It is only valid if acceptEarlySessions is false. // It is only valid if acceptEarlyConns is false.
func (s *baseServer) Accept(ctx context.Context) (Connection, error) { func (s *baseServer) Accept(ctx context.Context) (Connection, error) {
return s.accept(ctx) return s.accept(ctx)
} }
@ -268,8 +268,8 @@ func (s *baseServer) accept(ctx context.Context) (quicConn, error) {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return nil, ctx.Err() return nil, ctx.Err()
case sess := <-s.sessionQueue: case sess := <-s.connQueue:
atomic.AddInt32(&s.sessionQueueLen, -1) atomic.AddInt32(&s.connQueueLen, -1)
return sess, nil return sess, nil
case <-s.errorChan: case <-s.errorChan:
return nil, s.serverError return nil, s.serverError
@ -294,9 +294,9 @@ func (s *baseServer) Close() error {
s.mutex.Unlock() s.mutex.Unlock()
<-s.running <-s.running
s.sessionHandler.CloseServer() s.connHandler.CloseServer()
if createdPacketConn { if createdPacketConn {
return s.sessionHandler.Destroy() return s.connHandler.Destroy()
} }
return nil return nil
} }
@ -336,7 +336,7 @@ func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* is the buffer s
} }
return false return false
} }
// If we're creating a new session, the packet will be passed to the session. // If we're creating a new connection, the packet will be passed to the connection.
// The header will then be parsed again. // The header will then be parsed again.
hdr, _, _, err := wire.ParsePacket(p.data, s.config.ConnectionIDLength) hdr, _, _, err := wire.ParsePacket(p.data, s.config.ConnectionIDLength)
if err != nil && err != wire.ErrUnsupportedVersion { if err != nil && err != wire.ErrUnsupportedVersion {
@ -436,7 +436,7 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro
return nil return nil
} }
if queueLen := atomic.LoadInt32(&s.sessionQueueLen); queueLen >= protocol.MaxAcceptQueueSize { if queueLen := atomic.LoadInt32(&s.connQueueLen); queueLen >= protocol.MaxAcceptQueueSize {
s.logger.Debugf("Rejecting new connection. Server currently busy. Accept queue length: %d (max %d)", queueLen, protocol.MaxAcceptQueueSize) s.logger.Debugf("Rejecting new connection. Server currently busy. Accept queue length: %d (max %d)", queueLen, protocol.MaxAcceptQueueSize)
go func() { go func() {
defer p.buffer.Release() defer p.buffer.Release()
@ -452,9 +452,9 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro
return err return err
} }
s.logger.Debugf("Changing connection ID to %s.", connID) s.logger.Debugf("Changing connection ID to %s.", connID)
var sess quicConn var conn quicConn
tracingID := nextSessionTracingID() tracingID := nextSessionTracingID()
if added := s.sessionHandler.AddWithConnID(hdr.DestConnectionID, connID, func() packetHandler { if added := s.connHandler.AddWithConnID(hdr.DestConnectionID, connID, func() packetHandler {
var tracer logging.ConnectionTracer var tracer logging.ConnectionTracer
if s.config.Tracer != nil { if s.config.Tracer != nil {
// Use the same connection ID that is passed to the client's GetLogWriter callback. // Use the same connection ID that is passed to the client's GetLogWriter callback.
@ -468,69 +468,69 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro
connID, connID,
) )
} }
sess = s.newSession( conn = s.newSession(
newSendConn(s.conn, p.remoteAddr, p.info), newSendConn(s.conn, p.remoteAddr, p.info),
s.sessionHandler, s.connHandler,
origDestConnID, origDestConnID,
retrySrcConnID, retrySrcConnID,
hdr.DestConnectionID, hdr.DestConnectionID,
hdr.SrcConnectionID, hdr.SrcConnectionID,
connID, connID,
s.sessionHandler.GetStatelessResetToken(connID), s.connHandler.GetStatelessResetToken(connID),
s.config, s.config,
s.tlsConf, s.tlsConf,
s.tokenGenerator, s.tokenGenerator,
s.acceptEarlySessions, s.acceptEarlyConns,
tracer, tracer,
tracingID, tracingID,
s.logger, s.logger,
hdr.Version, hdr.Version,
) )
sess.handlePacket(p) conn.handlePacket(p)
return sess return conn
}); !added { }); !added {
return nil return nil
} }
go sess.run() go conn.run()
go s.handleNewSession(sess) go s.handleNewConn(conn)
if sess == nil { if conn == nil {
p.buffer.Release() p.buffer.Release()
return nil return nil
} }
return nil return nil
} }
func (s *baseServer) handleNewSession(sess quicConn) { func (s *baseServer) handleNewConn(conn quicConn) {
sessCtx := sess.Context() sessCtx := conn.Context()
if s.acceptEarlySessions { if s.acceptEarlyConns {
// wait until the early session is ready (or the handshake fails) // wait until the early connection is ready (or the handshake fails)
select { select {
case <-sess.earlySessionReady(): case <-conn.earlySessionReady():
case <-sessCtx.Done(): case <-sessCtx.Done():
return return
} }
} else { } else {
// wait until the handshake is complete (or fails) // wait until the handshake is complete (or fails)
select { select {
case <-sess.HandshakeComplete().Done(): case <-conn.HandshakeComplete().Done():
case <-sessCtx.Done(): case <-sessCtx.Done():
return return
} }
} }
atomic.AddInt32(&s.sessionQueueLen, 1) atomic.AddInt32(&s.connQueueLen, 1)
select { select {
case s.sessionQueue <- sess: case s.connQueue <- conn:
// blocks until the session is accepted // blocks until the connection is accepted
case <-sessCtx.Done(): case <-sessCtx.Done():
atomic.AddInt32(&s.sessionQueueLen, -1) atomic.AddInt32(&s.connQueueLen, -1)
// don't pass sessions that were already closed to Accept() // don't pass connections that were already closed to Accept()
} }
} }
func (s *baseServer) sendRetry(remoteAddr net.Addr, hdr *wire.Header, info *packetInfo) error { func (s *baseServer) sendRetry(remoteAddr net.Addr, hdr *wire.Header, info *packetInfo) error {
// Log the Initial packet now. // Log the Initial packet now.
// If no Retry is sent, the packet will be logged by the session. // If no Retry is sent, the packet will be logged by the connection.
(&wire.ExtendedHeader{Header: *hdr}).Log(s.logger) (&wire.ExtendedHeader{Header: *hdr}).Log(s.logger)
srcConnID, err := protocol.GenerateConnectionID(s.config.ConnectionIDLength) srcConnID, err := protocol.GenerateConnectionID(s.config.ConnectionIDLength)
if err != nil { if err != nil {

View file

@ -146,7 +146,7 @@ var _ = Describe("Server", func() {
ln, err := Listen(conn, tlsConf, &config) ln, err := Listen(conn, tlsConf, &config)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
server := ln.(*baseServer) server := ln.(*baseServer)
Expect(server.sessionHandler).ToNot(BeNil()) Expect(server.connHandler).ToNot(BeNil())
Expect(server.config.Versions).To(Equal(supportedVersions)) Expect(server.config.Versions).To(Equal(supportedVersions))
Expect(server.config.HandshakeIdleTimeout).To(Equal(1337 * time.Hour)) Expect(server.config.HandshakeIdleTimeout).To(Equal(1337 * time.Hour))
Expect(server.config.MaxIdleTimeout).To(Equal(42 * time.Minute)) Expect(server.config.MaxIdleTimeout).To(Equal(42 * time.Minute))
@ -178,7 +178,7 @@ var _ = Describe("Server", func() {
Expect(err).To(BeAssignableToTypeOf(&net.OpError{})) Expect(err).To(BeAssignableToTypeOf(&net.OpError{}))
}) })
Context("server accepting sessions that completed the handshake", func() { Context("server accepting connections that completed the handshake", func() {
var ( var (
serv *baseServer serv *baseServer
phm *MockPacketHandlerManager phm *MockPacketHandlerManager
@ -191,7 +191,7 @@ var _ = Describe("Server", func() {
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
serv = ln.(*baseServer) serv = ln.(*baseServer)
phm = NewMockPacketHandlerManager(mockCtrl) phm = NewMockPacketHandlerManager(mockCtrl)
serv.sessionHandler = phm serv.connHandler = phm
}) })
AfterEach(func() { AfterEach(func() {
@ -291,7 +291,7 @@ var _ = Describe("Server", func() {
Eventually(done).Should(BeClosed()) Eventually(done).Should(BeClosed())
}) })
It("creates a session when the token is accepted", func() { It("creates a connection when the token is accepted", func() {
serv.config.AcceptToken = func(_ net.Addr, token *Token) bool { return true } serv.config.AcceptToken = func(_ net.Addr, token *Token) bool { return true }
retryToken, err := serv.tokenGenerator.NewRetryToken( retryToken, err := serv.tokenGenerator.NewRetryToken(
&net.UDPAddr{}, &net.UDPAddr{},
@ -363,7 +363,7 @@ var _ = Describe("Server", func() {
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
serv.handlePacket(p) serv.handlePacket(p)
// the Handshake packet is written by the session. // the Handshake packet is written by the connection.
// Make sure there are no Write calls on the packet conn. // Make sure there are no Write calls on the packet conn.
time.Sleep(50 * time.Millisecond) time.Sleep(50 * time.Millisecond)
close(done) close(done)
@ -576,7 +576,7 @@ var _ = Describe("Server", func() {
Eventually(done).Should(BeClosed()) Eventually(done).Should(BeClosed())
}) })
It("creates a session, if no Token is required", func() { It("creates a connection, if no Token is required", func() {
serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true } serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true }
hdr := &wire.Header{ hdr := &wire.Header{
IsLongHeader: true, IsLongHeader: true,
@ -642,7 +642,7 @@ var _ = Describe("Server", func() {
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
serv.handlePacket(p) serv.handlePacket(p)
// the Handshake packet is written by the session // the Handshake packet is written by the connection
// make sure there are no Write calls on the packet conn // make sure there are no Write calls on the packet conn
time.Sleep(50 * time.Millisecond) time.Sleep(50 * time.Millisecond)
close(done) close(done)
@ -661,7 +661,7 @@ var _ = Describe("Server", func() {
tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, gomock.Any()).AnyTimes() tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, gomock.Any()).AnyTimes()
serv.config.AcceptToken = func(net.Addr, *Token) bool { return true } serv.config.AcceptToken = func(net.Addr, *Token) bool { return true }
acceptSession := make(chan struct{}) acceptConn := make(chan struct{})
var counter uint32 // to be used as an atomic, so we query it in Eventually var counter uint32 // to be used as an atomic, so we query it in Eventually
serv.newSession = func( serv.newSession = func(
_ sendConn, _ sendConn,
@ -681,7 +681,7 @@ var _ = Describe("Server", func() {
_ utils.Logger, _ utils.Logger,
_ protocol.VersionNumber, _ protocol.VersionNumber,
) quicConn { ) quicConn {
<-acceptSession <-acceptConn
atomic.AddUint32(&counter, 1) atomic.AddUint32(&counter, 1)
conn := NewMockQuicConn(mockCtrl) conn := NewMockQuicConn(mockCtrl)
conn.EXPECT().handlePacket(gomock.Any()).MaxTimes(1) conn.EXPECT().handlePacket(gomock.Any()).MaxTimes(1)
@ -705,7 +705,7 @@ var _ = Describe("Server", func() {
} }
wg.Wait() wg.Wait()
close(acceptSession) close(acceptConn)
Eventually( Eventually(
func() uint32 { return atomic.LoadUint32(&counter) }, func() uint32 { return atomic.LoadUint32(&counter) },
scaleDuration(100*time.Millisecond), scaleDuration(100*time.Millisecond),
@ -713,9 +713,9 @@ var _ = Describe("Server", func() {
Consistently(func() uint32 { return atomic.LoadUint32(&counter) }).Should(BeEquivalentTo(protocol.MaxServerUnprocessedPackets + 1)) Consistently(func() uint32 { return atomic.LoadUint32(&counter) }).Should(BeEquivalentTo(protocol.MaxServerUnprocessedPackets + 1))
}) })
It("only creates a single session for a duplicate Initial", func() { It("only creates a single connection for a duplicate Initial", func() {
serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true } serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true }
var createdSession bool var createdConn bool
conn := NewMockQuicConn(mockCtrl) conn := NewMockQuicConn(mockCtrl)
serv.newSession = func( serv.newSession = func(
_ sendConn, _ sendConn,
@ -735,14 +735,14 @@ var _ = Describe("Server", func() {
_ utils.Logger, _ utils.Logger,
_ protocol.VersionNumber, _ protocol.VersionNumber,
) quicConn { ) quicConn {
createdSession = true createdConn = true
return conn return conn
} }
p := getInitial(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9}) p := getInitial(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9})
phm.EXPECT().AddWithConnID(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9}, gomock.Any(), gomock.Any()).Return(false) phm.EXPECT().AddWithConnID(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9}, gomock.Any(), gomock.Any()).Return(false)
Expect(serv.handlePacketImpl(p)).To(BeTrue()) Expect(serv.handlePacketImpl(p)).To(BeTrue())
Expect(createdSession).To(BeFalse()) Expect(createdConn).To(BeFalse())
}) })
It("rejects new connection attempts if the accept queue is full", func() { It("rejects new connection attempts if the accept queue is full", func() {
@ -813,12 +813,12 @@ var _ = Describe("Server", func() {
Eventually(done).Should(BeClosed()) Eventually(done).Should(BeClosed())
}) })
It("doesn't accept new sessions if they were closed in the mean time", func() { It("doesn't accept new connections if they were closed in the mean time", func() {
serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true } serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true }
p := getInitial(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) p := getInitial(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
sessionCreated := make(chan struct{}) connCreated := make(chan struct{})
conn := NewMockQuicConn(mockCtrl) conn := NewMockQuicConn(mockCtrl)
serv.newSession = func( serv.newSession = func(
_ sendConn, _ sendConn,
@ -844,7 +844,7 @@ var _ = Describe("Server", func() {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
cancel() cancel()
conn.EXPECT().HandshakeComplete().Return(ctx) conn.EXPECT().HandshakeComplete().Return(ctx)
close(sessionCreated) close(connCreated)
return conn return conn
} }
@ -858,7 +858,7 @@ var _ = Describe("Server", func() {
serv.handlePacket(p) serv.handlePacket(p)
// make sure there are no Write calls on the packet conn // make sure there are no Write calls on the packet conn
time.Sleep(50 * time.Millisecond) time.Sleep(50 * time.Millisecond)
Eventually(sessionCreated).Should(BeClosed()) Eventually(connCreated).Should(BeClosed())
cancel() cancel()
time.Sleep(scaleDuration(200 * time.Millisecond)) time.Sleep(scaleDuration(200 * time.Millisecond))
@ -878,7 +878,7 @@ var _ = Describe("Server", func() {
}) })
}) })
Context("accepting sessions", func() { Context("accepting connections", func() {
It("returns Accept when an error occurs", func() { It("returns Accept when an error occurs", func() {
testErr := errors.New("test err") testErr := errors.New("test err")
@ -918,7 +918,7 @@ var _ = Describe("Server", func() {
Eventually(done).Should(BeClosed()) Eventually(done).Should(BeClosed())
}) })
It("accepts new sessions when the handshake completes", func() { It("accepts new connections when the handshake completes", func() {
conn := NewMockQuicConn(mockCtrl) conn := NewMockQuicConn(mockCtrl)
done := make(chan struct{}) done := make(chan struct{})
@ -973,7 +973,7 @@ var _ = Describe("Server", func() {
}) })
}) })
Context("server accepting sessions that haven't completed the handshake", func() { Context("server accepting connections that haven't completed the handshake", func() {
var ( var (
serv *earlyServer serv *earlyServer
phm *MockPacketHandlerManager phm *MockPacketHandlerManager
@ -984,7 +984,7 @@ var _ = Describe("Server", func() {
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
serv = ln.(*earlyServer) serv = ln.(*earlyServer)
phm = NewMockPacketHandlerManager(mockCtrl) phm = NewMockPacketHandlerManager(mockCtrl)
serv.sessionHandler = phm serv.connHandler = phm
}) })
AfterEach(func() { AfterEach(func() {
@ -992,7 +992,7 @@ var _ = Describe("Server", func() {
serv.Close() serv.Close()
}) })
It("accepts new sessions when they become ready", func() { It("accepts new connections when they become ready", func() {
conn := NewMockQuicConn(mockCtrl) conn := NewMockQuicConn(mockCtrl)
done := make(chan struct{}) done := make(chan struct{})
@ -1086,7 +1086,7 @@ var _ = Describe("Server", func() {
serv.handlePacket(getInitialWithRandomDestConnID()) serv.handlePacket(getInitialWithRandomDestConnID())
} }
Eventually(func() int32 { return atomic.LoadInt32(&serv.sessionQueueLen) }).Should(BeEquivalentTo(protocol.MaxAcceptQueueSize)) Eventually(func() int32 { return atomic.LoadInt32(&serv.connQueueLen) }).Should(BeEquivalentTo(protocol.MaxAcceptQueueSize))
// make sure there are no Write calls on the packet conn // make sure there are no Write calls on the packet conn
time.Sleep(50 * time.Millisecond) time.Sleep(50 * time.Millisecond)
@ -1106,12 +1106,12 @@ var _ = Describe("Server", func() {
Eventually(done).Should(BeClosed()) Eventually(done).Should(BeClosed())
}) })
It("doesn't accept new sessions if they were closed in the mean time", func() { It("doesn't accept new connections if they were closed in the mean time", func() {
serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true } serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true }
p := getInitial(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) p := getInitial(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
sessionCreated := make(chan struct{}) connCreated := make(chan struct{})
conn := NewMockQuicConn(mockCtrl) conn := NewMockQuicConn(mockCtrl)
serv.newSession = func( serv.newSession = func(
_ sendConn, _ sendConn,
@ -1135,7 +1135,7 @@ var _ = Describe("Server", func() {
conn.EXPECT().run() conn.EXPECT().run()
conn.EXPECT().earlySessionReady() conn.EXPECT().earlySessionReady()
conn.EXPECT().Context().Return(ctx) conn.EXPECT().Context().Return(ctx)
close(sessionCreated) close(connCreated)
return conn return conn
} }
@ -1147,7 +1147,7 @@ var _ = Describe("Server", func() {
serv.handlePacket(p) serv.handlePacket(p)
// make sure there are no Write calls on the packet conn // make sure there are no Write calls on the packet conn
time.Sleep(50 * time.Millisecond) time.Sleep(50 * time.Millisecond)
Eventually(sessionCreated).Should(BeClosed()) Eventually(connCreated).Should(BeClosed())
cancel() cancel()
time.Sleep(scaleDuration(200 * time.Millisecond)) time.Sleep(scaleDuration(200 * time.Millisecond))

View file

@ -31,7 +31,7 @@ import (
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
) )
func areSessionsRunning() bool { func areConnsRunning() bool {
var b bytes.Buffer var b bytes.Buffer
pprof.Lookup("goroutine").WriteTo(&b, 1) pprof.Lookup("goroutine").WriteTo(&b, 1)
return strings.Contains(b.String(), "quic-go.(*session).run") return strings.Contains(b.String(), "quic-go.(*session).run")
@ -81,7 +81,7 @@ var _ = Describe("Connection", func() {
} }
BeforeEach(func() { BeforeEach(func() {
Eventually(areSessionsRunning).Should(BeFalse()) Eventually(areConnsRunning).Should(BeFalse())
connRunner = NewMockConnRunner(mockCtrl) connRunner = NewMockConnRunner(mockCtrl)
mconn = NewMockSendConn(mockCtrl) mconn = NewMockSendConn(mockCtrl)
@ -123,7 +123,7 @@ var _ = Describe("Connection", func() {
}) })
AfterEach(func() { AfterEach(func() {
Eventually(areSessionsRunning).Should(BeFalse()) Eventually(areConnsRunning).Should(BeFalse())
}) })
Context("frame handling", func() { Context("frame handling", func() {
@ -424,7 +424,7 @@ var _ = Describe("Connection", func() {
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
runErr <- sess.run() runErr <- sess.run()
}() }()
Eventually(areSessionsRunning).Should(BeTrue()) Eventually(areConnsRunning).Should(BeTrue())
} }
It("shuts down without error", func() { It("shuts down without error", func() {
@ -451,7 +451,7 @@ var _ = Describe("Connection", func() {
tracer.EXPECT().Close(), tracer.EXPECT().Close(),
) )
sess.shutdown() sess.shutdown()
Eventually(areSessionsRunning).Should(BeFalse()) Eventually(areConnsRunning).Should(BeFalse())
Expect(sess.Context().Done()).To(BeClosed()) Expect(sess.Context().Done()).To(BeClosed())
}) })
@ -466,7 +466,7 @@ var _ = Describe("Connection", func() {
tracer.EXPECT().Close() tracer.EXPECT().Close()
sess.shutdown() sess.shutdown()
sess.shutdown() sess.shutdown()
Eventually(areSessionsRunning).Should(BeFalse()) Eventually(areConnsRunning).Should(BeFalse())
Expect(sess.Context().Done()).To(BeClosed()) Expect(sess.Context().Done()).To(BeClosed())
}) })
@ -486,7 +486,7 @@ var _ = Describe("Connection", func() {
tracer.EXPECT().Close(), tracer.EXPECT().Close(),
) )
sess.CloseWithError(0x1337, "test error") sess.CloseWithError(0x1337, "test error")
Eventually(areSessionsRunning).Should(BeFalse()) Eventually(areConnsRunning).Should(BeFalse())
Expect(sess.Context().Done()).To(BeClosed()) Expect(sess.Context().Done()).To(BeClosed())
}) })
@ -507,7 +507,7 @@ var _ = Describe("Connection", func() {
tracer.EXPECT().Close(), tracer.EXPECT().Close(),
) )
sess.closeLocal(expectedErr) sess.closeLocal(expectedErr)
Eventually(areSessionsRunning).Should(BeFalse()) Eventually(areConnsRunning).Should(BeFalse())
Expect(sess.Context().Done()).To(BeClosed()) Expect(sess.Context().Done()).To(BeClosed())
}) })
@ -528,7 +528,7 @@ var _ = Describe("Connection", func() {
tracer.EXPECT().Close(), tracer.EXPECT().Close(),
) )
sess.destroy(testErr) sess.destroy(testErr)
Eventually(areSessionsRunning).Should(BeFalse()) Eventually(areConnsRunning).Should(BeFalse())
expectedRunErr = &qerr.TransportError{ expectedRunErr = &qerr.TransportError{
ErrorCode: qerr.InternalError, ErrorCode: qerr.InternalError,
ErrorMessage: testErr.Error(), ErrorMessage: testErr.Error(),
@ -2449,7 +2449,7 @@ var _ = Describe("Client Connection", func() {
}) })
JustBeforeEach(func() { JustBeforeEach(func() {
Eventually(areSessionsRunning).Should(BeFalse()) Eventually(areConnsRunning).Should(BeFalse())
mconn = NewMockSendConn(mockCtrl) mconn = NewMockSendConn(mockCtrl)
mconn.EXPECT().RemoteAddr().Return(&net.UDPAddr{}).AnyTimes() mconn.EXPECT().RemoteAddr().Return(&net.UDPAddr{}).AnyTimes()