mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-05 05:07:36 +03:00
refactor the map of sessions into a separate struct
This commit is contained in:
parent
15da47cf98
commit
9c5986945e
6 changed files with 312 additions and 218 deletions
78
mock_session_handler_test.go
Normal file
78
mock_session_handler_test.go
Normal file
|
@ -0,0 +1,78 @@
|
||||||
|
// Code generated by MockGen. DO NOT EDIT.
|
||||||
|
// Source: github.com/lucas-clemente/quic-go (interfaces: SessionHandler)
|
||||||
|
|
||||||
|
// Package quic is a generated GoMock package.
|
||||||
|
package quic
|
||||||
|
|
||||||
|
import (
|
||||||
|
reflect "reflect"
|
||||||
|
|
||||||
|
gomock "github.com/golang/mock/gomock"
|
||||||
|
protocol "github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MockSessionHandler is a mock of SessionHandler interface
|
||||||
|
type MockSessionHandler struct {
|
||||||
|
ctrl *gomock.Controller
|
||||||
|
recorder *MockSessionHandlerMockRecorder
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockSessionHandlerMockRecorder is the mock recorder for MockSessionHandler
|
||||||
|
type MockSessionHandlerMockRecorder struct {
|
||||||
|
mock *MockSessionHandler
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMockSessionHandler creates a new mock instance
|
||||||
|
func NewMockSessionHandler(ctrl *gomock.Controller) *MockSessionHandler {
|
||||||
|
mock := &MockSessionHandler{ctrl: ctrl}
|
||||||
|
mock.recorder = &MockSessionHandlerMockRecorder{mock}
|
||||||
|
return mock
|
||||||
|
}
|
||||||
|
|
||||||
|
// EXPECT returns an object that allows the caller to indicate expected use
|
||||||
|
func (m *MockSessionHandler) EXPECT() *MockSessionHandlerMockRecorder {
|
||||||
|
return m.recorder
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add mocks base method
|
||||||
|
func (m *MockSessionHandler) Add(arg0 protocol.ConnectionID, arg1 packetHandler) {
|
||||||
|
m.ctrl.Call(m, "Add", arg0, arg1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add indicates an expected call of Add
|
||||||
|
func (mr *MockSessionHandlerMockRecorder) Add(arg0, arg1 interface{}) *gomock.Call {
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockSessionHandler)(nil).Add), arg0, arg1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close mocks base method
|
||||||
|
func (m *MockSessionHandler) Close() {
|
||||||
|
m.ctrl.Call(m, "Close")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close indicates an expected call of Close
|
||||||
|
func (mr *MockSessionHandlerMockRecorder) Close() *gomock.Call {
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockSessionHandler)(nil).Close))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get mocks base method
|
||||||
|
func (m *MockSessionHandler) Get(arg0 protocol.ConnectionID) (packetHandler, bool) {
|
||||||
|
ret := m.ctrl.Call(m, "Get", arg0)
|
||||||
|
ret0, _ := ret[0].(packetHandler)
|
||||||
|
ret1, _ := ret[1].(bool)
|
||||||
|
return ret0, ret1
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get indicates an expected call of Get
|
||||||
|
func (mr *MockSessionHandlerMockRecorder) Get(arg0 interface{}) *gomock.Call {
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockSessionHandler)(nil).Get), arg0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove mocks base method
|
||||||
|
func (m *MockSessionHandler) Remove(arg0 protocol.ConnectionID) {
|
||||||
|
m.ctrl.Call(m, "Remove", arg0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove indicates an expected call of Remove
|
||||||
|
func (mr *MockSessionHandlerMockRecorder) Remove(arg0 interface{}) *gomock.Call {
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Remove", reflect.TypeOf((*MockSessionHandler)(nil).Remove), arg0)
|
||||||
|
}
|
|
@ -13,5 +13,6 @@ package quic
|
||||||
//go:generate sh -c "./mockgen_private.sh quic mock_gquic_aead_test.go github.com/lucas-clemente/quic-go gQUICAEAD GQUICAEAD"
|
//go:generate sh -c "./mockgen_private.sh quic mock_gquic_aead_test.go github.com/lucas-clemente/quic-go gQUICAEAD GQUICAEAD"
|
||||||
//go:generate sh -c "./mockgen_private.sh quic mock_session_runner_test.go github.com/lucas-clemente/quic-go sessionRunner SessionRunner"
|
//go:generate sh -c "./mockgen_private.sh quic mock_session_runner_test.go github.com/lucas-clemente/quic-go sessionRunner SessionRunner"
|
||||||
//go:generate sh -c "./mockgen_private.sh quic mock_packet_handler_test.go github.com/lucas-clemente/quic-go packetHandler PacketHandler"
|
//go:generate sh -c "./mockgen_private.sh quic mock_packet_handler_test.go github.com/lucas-clemente/quic-go packetHandler PacketHandler"
|
||||||
|
//go:generate sh -c "./mockgen_private.sh quic mock_session_handler_test.go github.com/lucas-clemente/quic-go sessionHandler SessionHandler"
|
||||||
//go:generate sh -c "find . -type f -name 'mock_*_test.go' | xargs sed -i '' 's/quic_go.//g'"
|
//go:generate sh -c "find . -type f -name 'mock_*_test.go' | xargs sed -i '' 's/quic_go.//g'"
|
||||||
//go:generate sh -c "goimports -w mock*_test.go"
|
//go:generate sh -c "goimports -w mock*_test.go"
|
||||||
|
|
88
server.go
88
server.go
|
@ -6,7 +6,6 @@ import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/crypto"
|
"github.com/lucas-clemente/quic-go/internal/crypto"
|
||||||
|
@ -42,6 +41,13 @@ func (r *runner) removeConnectionID(c protocol.ConnectionID) { r.removeConnectio
|
||||||
|
|
||||||
var _ sessionRunner = &runner{}
|
var _ sessionRunner = &runner{}
|
||||||
|
|
||||||
|
type sessionHandler interface {
|
||||||
|
Add(protocol.ConnectionID, packetHandler)
|
||||||
|
Get(protocol.ConnectionID) (packetHandler, bool)
|
||||||
|
Remove(protocol.ConnectionID)
|
||||||
|
Close()
|
||||||
|
}
|
||||||
|
|
||||||
// A Listener of QUIC
|
// A Listener of QUIC
|
||||||
type server struct {
|
type server struct {
|
||||||
tlsConf *tls.Config
|
tlsConf *tls.Config
|
||||||
|
@ -55,9 +61,7 @@ type server struct {
|
||||||
certChain crypto.CertChain
|
certChain crypto.CertChain
|
||||||
scfg *handshake.ServerConfig
|
scfg *handshake.ServerConfig
|
||||||
|
|
||||||
sessionsMutex sync.RWMutex
|
sessionHandler sessionHandler
|
||||||
sessions map[string] /* string(ConnectionID)*/ packetHandler
|
|
||||||
closed bool
|
|
||||||
|
|
||||||
serverError error
|
serverError error
|
||||||
|
|
||||||
|
@ -65,9 +69,8 @@ type server struct {
|
||||||
errorChan chan struct{}
|
errorChan chan struct{}
|
||||||
|
|
||||||
sessionRunner sessionRunner
|
sessionRunner sessionRunner
|
||||||
// set as members, so they can be set in the tests
|
// set as a member, so they can be set in the tests
|
||||||
newSession func(connection, sessionRunner, protocol.VersionNumber, protocol.ConnectionID, *handshake.ServerConfig, *tls.Config, *Config, utils.Logger) (packetHandler, error)
|
newSession func(connection, sessionRunner, protocol.VersionNumber, protocol.ConnectionID, *handshake.ServerConfig, *tls.Config, *Config, utils.Logger) (packetHandler, error)
|
||||||
deleteClosedSessionsAfter time.Duration
|
|
||||||
|
|
||||||
logger utils.Logger
|
logger utils.Logger
|
||||||
}
|
}
|
||||||
|
@ -120,9 +123,8 @@ func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener,
|
||||||
config: config,
|
config: config,
|
||||||
certChain: certChain,
|
certChain: certChain,
|
||||||
scfg: scfg,
|
scfg: scfg,
|
||||||
sessions: map[string]packetHandler{},
|
|
||||||
newSession: newSession,
|
newSession: newSession,
|
||||||
deleteClosedSessionsAfter: protocol.ClosedSessionDeleteTimeout,
|
sessionHandler: newSessionMap(),
|
||||||
sessionQueue: make(chan Session, 5),
|
sessionQueue: make(chan Session, 5),
|
||||||
errorChan: make(chan struct{}),
|
errorChan: make(chan struct{}),
|
||||||
supportsTLS: supportsTLS,
|
supportsTLS: supportsTLS,
|
||||||
|
@ -142,7 +144,7 @@ func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener,
|
||||||
func (s *server) setup() {
|
func (s *server) setup() {
|
||||||
s.sessionRunner = &runner{
|
s.sessionRunner = &runner{
|
||||||
onHandshakeCompleteImpl: func(sess packetHandler) { s.sessionQueue <- sess },
|
onHandshakeCompleteImpl: func(sess packetHandler) { s.sessionQueue <- sess },
|
||||||
removeConnectionIDImpl: s.removeConnection,
|
removeConnectionIDImpl: s.sessionHandler.Remove,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -165,13 +167,13 @@ func (s *server) setupTLS() error {
|
||||||
case tlsSession := <-sessionChan:
|
case tlsSession := <-sessionChan:
|
||||||
connID := tlsSession.connID
|
connID := tlsSession.connID
|
||||||
sess := tlsSession.sess
|
sess := tlsSession.sess
|
||||||
s.sessionsMutex.Lock()
|
if _, ok := s.sessionHandler.Get(connID); ok { // drop this session if it already exists
|
||||||
if _, ok := s.sessions[string(connID)]; ok { // drop this session if it already exists
|
|
||||||
s.sessionsMutex.Unlock()
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
s.sessions[string(connID)] = sess
|
// TODO(#1003): There's a race condition here.
|
||||||
s.sessionsMutex.Unlock()
|
// If another connection with the same conn ID is added between Get() and Add(), it would be overwritten.
|
||||||
|
// We can avoid this be using server-chosen connection IDs.
|
||||||
|
s.sessionHandler.Add(connID, sess)
|
||||||
go sess.run()
|
go sess.run()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -288,27 +290,7 @@ func (s *server) Accept() (Session, error) {
|
||||||
|
|
||||||
// Close the server
|
// Close the server
|
||||||
func (s *server) Close() error {
|
func (s *server) Close() error {
|
||||||
s.sessionsMutex.Lock()
|
s.sessionHandler.Close()
|
||||||
if s.closed {
|
|
||||||
s.sessionsMutex.Unlock()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
s.closed = true
|
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
for _, session := range s.sessions {
|
|
||||||
if session != nil {
|
|
||||||
wg.Add(1)
|
|
||||||
go func(sess packetHandler) {
|
|
||||||
// session.Close() blocks until the CONNECTION_CLOSE has been sent and the run-loop has stopped
|
|
||||||
_ = sess.Close(nil)
|
|
||||||
wg.Done()
|
|
||||||
}(session)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
s.sessionsMutex.Unlock()
|
|
||||||
wg.Wait()
|
|
||||||
|
|
||||||
err := s.conn.Close()
|
err := s.conn.Close()
|
||||||
<-s.errorChan // wait for serve() to return
|
<-s.errorChan // wait for serve() to return
|
||||||
return err
|
return err
|
||||||
|
@ -359,10 +341,7 @@ func (s *server) handleIETFQUICPacket(hdr *wire.Header, packetData []byte, remot
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
s.sessionsMutex.RLock()
|
session, sessionKnown := s.sessionHandler.Get(hdr.DestConnectionID)
|
||||||
session, sessionKnown := s.sessions[string(hdr.DestConnectionID)]
|
|
||||||
s.sessionsMutex.RUnlock()
|
|
||||||
|
|
||||||
if sessionKnown && session == nil {
|
if sessionKnown && session == nil {
|
||||||
// Late packet for closed session
|
// Late packet for closed session
|
||||||
return nil
|
return nil
|
||||||
|
@ -382,21 +361,18 @@ func (s *server) handleIETFQUICPacket(hdr *wire.Header, packetData []byte, remot
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *server) handleGQUICPacket(hdr *wire.Header, packetData []byte, remoteAddr net.Addr, rcvTime time.Time) error {
|
func (s *server) handleGQUICPacket(hdr *wire.Header, packetData []byte, remoteAddr net.Addr, rcvTime time.Time) error {
|
||||||
s.sessionsMutex.RLock()
|
|
||||||
session, sessionKnown := s.sessions[string(hdr.DestConnectionID)]
|
|
||||||
s.sessionsMutex.RUnlock()
|
|
||||||
|
|
||||||
if sessionKnown && session == nil {
|
|
||||||
// Late packet for closed session
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ignore all Public Reset packets
|
// ignore all Public Reset packets
|
||||||
if hdr.ResetFlag {
|
if hdr.ResetFlag {
|
||||||
s.logger.Infof("Received unexpected Public Reset for connection %s.", hdr.DestConnectionID)
|
s.logger.Infof("Received unexpected Public Reset for connection %s.", hdr.DestConnectionID)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
session, sessionKnown := s.sessionHandler.Get(hdr.DestConnectionID)
|
||||||
|
if sessionKnown && session == nil {
|
||||||
|
// Late packet for closed session
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// If we don't have a session for this connection, and this packet cannot open a new connection, send a Public Reset
|
// If we don't have a session for this connection, and this packet cannot open a new connection, send a Public Reset
|
||||||
// This should only happen after a server restart, when we still receive packets for connections that we lost the state for.
|
// This should only happen after a server restart, when we still receive packets for connections that we lost the state for.
|
||||||
if !sessionKnown && !hdr.VersionFlag {
|
if !sessionKnown && !hdr.VersionFlag {
|
||||||
|
@ -450,9 +426,7 @@ func (s *server) handleGQUICPacket(hdr *wire.Header, packetData []byte, remoteAd
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
s.sessionsMutex.Lock()
|
s.sessionHandler.Add(hdr.DestConnectionID, session)
|
||||||
s.sessions[string(hdr.DestConnectionID)] = session
|
|
||||||
s.sessionsMutex.Unlock()
|
|
||||||
|
|
||||||
go session.run()
|
go session.run()
|
||||||
}
|
}
|
||||||
|
@ -465,15 +439,3 @@ func (s *server) handleGQUICPacket(hdr *wire.Header, packetData []byte, remoteAd
|
||||||
})
|
})
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *server) removeConnection(id protocol.ConnectionID) {
|
|
||||||
s.sessionsMutex.Lock()
|
|
||||||
s.sessions[string(id)] = nil
|
|
||||||
s.sessionsMutex.Unlock()
|
|
||||||
|
|
||||||
time.AfterFunc(s.deleteClosedSessionsAfter, func() {
|
|
||||||
s.sessionsMutex.Lock()
|
|
||||||
delete(s.sessions, string(id))
|
|
||||||
s.sessionsMutex.Unlock()
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
187
server_test.go
187
server_test.go
|
@ -84,9 +84,11 @@ var _ = Describe("Server", func() {
|
||||||
firstPacket []byte // a valid first packet for a new connection with connectionID 0x4cfa9f9b668619f6 (= connID)
|
firstPacket []byte // a valid first packet for a new connection with connectionID 0x4cfa9f9b668619f6 (= connID)
|
||||||
connID = protocol.ConnectionID{0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6}
|
connID = protocol.ConnectionID{0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6}
|
||||||
sessions = make([]*MockPacketHandler, 0)
|
sessions = make([]*MockPacketHandler, 0)
|
||||||
|
sessionHandler *MockSessionHandler
|
||||||
)
|
)
|
||||||
|
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
|
sessionHandler = NewMockSessionHandler(mockCtrl)
|
||||||
newMockSession := func(
|
newMockSession := func(
|
||||||
_ connection,
|
_ connection,
|
||||||
runner sessionRunner,
|
runner sessionRunner,
|
||||||
|
@ -105,7 +107,7 @@ var _ = Describe("Server", func() {
|
||||||
return s, nil
|
return s, nil
|
||||||
}
|
}
|
||||||
serv = &server{
|
serv = &server{
|
||||||
sessions: make(map[string]packetHandler),
|
sessionHandler: sessionHandler,
|
||||||
newSession: newMockSession,
|
newSession: newMockSession,
|
||||||
conn: conn,
|
conn: conn,
|
||||||
config: config,
|
config: config,
|
||||||
|
@ -139,11 +141,13 @@ var _ = Describe("Server", func() {
|
||||||
run := make(chan struct{})
|
run := make(chan struct{})
|
||||||
s.EXPECT().run().Do(func() { close(run) })
|
s.EXPECT().run().Do(func() { close(run) })
|
||||||
sessions = append(sessions, s)
|
sessions = append(sessions, s)
|
||||||
|
|
||||||
|
sessionHandler.EXPECT().Get(connID)
|
||||||
|
sessionHandler.EXPECT().Add(connID, gomock.Any()).Do(func(_ protocol.ConnectionID, sess packetHandler) {
|
||||||
|
Expect(sess.(*mockSession).connID).To(Equal(connID))
|
||||||
|
})
|
||||||
err := serv.handlePacket(nil, firstPacket)
|
err := serv.handlePacket(nil, firstPacket)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(serv.sessions).To(HaveLen(1))
|
|
||||||
sess := serv.sessions[string(connID)].(*mockSession)
|
|
||||||
Expect(sess.connID).To(Equal(connID))
|
|
||||||
Eventually(run).Should(BeClosed())
|
Eventually(run).Should(BeClosed())
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -154,51 +158,38 @@ var _ = Describe("Server", func() {
|
||||||
sess.EXPECT().run().Do(func() { close(run) })
|
sess.EXPECT().run().Do(func() { close(run) })
|
||||||
err := serv.setupTLS()
|
err := serv.setupTLS()
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
sessionHandler.EXPECT().Get(connID)
|
||||||
|
sessionHandler.EXPECT().Add(connID, sess)
|
||||||
serv.serverTLS.sessionChan <- tlsSession{
|
serv.serverTLS.sessionChan <- tlsSession{
|
||||||
connID: connID,
|
connID: connID,
|
||||||
sess: sess,
|
sess: sess,
|
||||||
}
|
}
|
||||||
Eventually(func() packetHandler {
|
|
||||||
serv.sessionsMutex.Lock()
|
|
||||||
defer serv.sessionsMutex.Unlock()
|
|
||||||
return serv.sessions[string(connID)]
|
|
||||||
}).Should(Equal(sess))
|
|
||||||
Eventually(run).Should(BeClosed())
|
Eventually(run).Should(BeClosed())
|
||||||
})
|
})
|
||||||
|
|
||||||
It("only accepts one new TLS sessions for one connection ID", func() {
|
It("only accepts one new TLS sessions for one connection ID", func() {
|
||||||
connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
|
connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
|
||||||
run := make(chan struct{})
|
|
||||||
sess := NewMockPacketHandler(mockCtrl)
|
sess := NewMockPacketHandler(mockCtrl)
|
||||||
sess.EXPECT().run().Do(func() { close(run) })
|
|
||||||
sess2 := NewMockPacketHandler(mockCtrl)
|
|
||||||
err := serv.setupTLS()
|
err := serv.setupTLS()
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
sessionHandler.EXPECT().Get(connID).Return(NewMockPacketHandler(mockCtrl), true).Do(func(protocol.ConnectionID) {
|
||||||
|
close(done)
|
||||||
|
})
|
||||||
|
// don't EXPECT any calls to sessionHandler.Add
|
||||||
serv.serverTLS.sessionChan <- tlsSession{
|
serv.serverTLS.sessionChan <- tlsSession{
|
||||||
connID: connID,
|
connID: connID,
|
||||||
sess: sess,
|
sess: sess,
|
||||||
}
|
}
|
||||||
Eventually(func() packetHandler {
|
Eventually(done).Should(BeClosed())
|
||||||
serv.sessionsMutex.Lock()
|
|
||||||
defer serv.sessionsMutex.Unlock()
|
|
||||||
return serv.sessions[string(connID)]
|
|
||||||
}).Should(Equal(sess))
|
|
||||||
serv.serverTLS.sessionChan <- tlsSession{
|
|
||||||
connID: connID,
|
|
||||||
sess: sess2,
|
|
||||||
}
|
|
||||||
Consistently(func() packetHandler {
|
|
||||||
serv.sessionsMutex.Lock()
|
|
||||||
defer serv.sessionsMutex.Unlock()
|
|
||||||
return serv.sessions[string(connID)]
|
|
||||||
}).Should(Equal(sess))
|
|
||||||
Eventually(run).Should(BeClosed())
|
|
||||||
})
|
})
|
||||||
|
|
||||||
It("accepts a session once the connection it is forward secure", func() {
|
It("accepts a session once the connection it is forward secure", func() {
|
||||||
s := NewMockPacketHandler(mockCtrl)
|
s := NewMockPacketHandler(mockCtrl)
|
||||||
s.EXPECT().handlePacket(gomock.Any())
|
s.EXPECT().handlePacket(gomock.Any())
|
||||||
s.EXPECT().run()
|
run := make(chan struct{})
|
||||||
|
s.EXPECT().run().Do(func() { close(run) })
|
||||||
sessions = append(sessions, s)
|
sessions = append(sessions, s)
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
go func() {
|
go func() {
|
||||||
|
@ -208,17 +199,19 @@ var _ = Describe("Server", func() {
|
||||||
Expect(sess.(*mockSession).connID).To(Equal(connID))
|
Expect(sess.(*mockSession).connID).To(Equal(connID))
|
||||||
close(done)
|
close(done)
|
||||||
}()
|
}()
|
||||||
|
sessionHandler.EXPECT().Get(connID)
|
||||||
|
sessionHandler.EXPECT().Add(connID, gomock.Any()).Do(func(_ protocol.ConnectionID, sess packetHandler) {
|
||||||
|
Consistently(done).ShouldNot(BeClosed())
|
||||||
|
sess.(*mockSession).runner.onHandshakeComplete(sess)
|
||||||
|
})
|
||||||
err := serv.handlePacket(nil, firstPacket)
|
err := serv.handlePacket(nil, firstPacket)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(serv.sessions).To(HaveLen(1))
|
|
||||||
Consistently(done).ShouldNot(BeClosed())
|
|
||||||
sess := serv.sessions[string(connID)].(*mockSession)
|
|
||||||
sess.runner.onHandshakeComplete(sess)
|
|
||||||
Eventually(done).Should(BeClosed())
|
Eventually(done).Should(BeClosed())
|
||||||
|
Eventually(run).Should(BeClosed())
|
||||||
})
|
})
|
||||||
|
|
||||||
It("doesn't accept sessions that error during the handshake", func() {
|
It("doesn't accept sessions that error during the handshake", func() {
|
||||||
run := make(chan error)
|
run := make(chan error, 1)
|
||||||
sess := NewMockPacketHandler(mockCtrl)
|
sess := NewMockPacketHandler(mockCtrl)
|
||||||
sess.EXPECT().handlePacket(gomock.Any())
|
sess.EXPECT().handlePacket(gomock.Any())
|
||||||
sess.EXPECT().run().DoAndReturn(func() error { return <-run })
|
sess.EXPECT().run().DoAndReturn(func() error { return <-run })
|
||||||
|
@ -229,79 +222,44 @@ var _ = Describe("Server", func() {
|
||||||
serv.Accept()
|
serv.Accept()
|
||||||
close(done)
|
close(done)
|
||||||
}()
|
}()
|
||||||
|
sessionHandler.EXPECT().Get(connID)
|
||||||
|
sessionHandler.EXPECT().Add(connID, gomock.Any()).Do(func(_ protocol.ConnectionID, sess packetHandler) {
|
||||||
|
run <- errors.New("handshake error")
|
||||||
|
})
|
||||||
err := serv.handlePacket(nil, firstPacket)
|
err := serv.handlePacket(nil, firstPacket)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(serv.sessions).To(HaveLen(1))
|
|
||||||
run <- errors.New("handshake error")
|
|
||||||
serv.sessions[string(connID)].(*mockSession).runner.removeConnectionID(connID)
|
|
||||||
Consistently(done).ShouldNot(BeClosed())
|
Consistently(done).ShouldNot(BeClosed())
|
||||||
// make the go routine return
|
// make the go routine return
|
||||||
|
sessionHandler.EXPECT().Close()
|
||||||
close(serv.errorChan)
|
close(serv.errorChan)
|
||||||
serv.Close()
|
serv.Close()
|
||||||
Eventually(done).Should(BeClosed())
|
Eventually(done).Should(BeClosed())
|
||||||
})
|
})
|
||||||
|
|
||||||
It("assigns packets to existing sessions", func() {
|
It("assigns packets to existing sessions", func() {
|
||||||
run := make(chan struct{})
|
|
||||||
sess := NewMockPacketHandler(mockCtrl)
|
sess := NewMockPacketHandler(mockCtrl)
|
||||||
sess.EXPECT().handlePacket(gomock.Any()).Times(2)
|
|
||||||
sess.EXPECT().run().Do(func() { close(run) })
|
|
||||||
sessions = append(sessions, sess)
|
|
||||||
|
|
||||||
err := serv.handlePacket(nil, firstPacket)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
err = serv.handlePacket(nil, []byte{0x08, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6, 0x01})
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Eventually(run).Should(BeClosed())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("deletes sessions", func() {
|
|
||||||
serv.deleteClosedSessionsAfter = time.Second // make sure that the nil value for the closed session doesn't get deleted in this test
|
|
||||||
serv.sessions[string(connID)] = &mockSession{}
|
|
||||||
serv.removeConnection(connID)
|
|
||||||
// The server should now have closed the session, leaving a nil value in the sessions map
|
|
||||||
Consistently(func() map[string]packetHandler { return serv.sessions }).Should(HaveLen(1))
|
|
||||||
Expect(serv.sessions[string(connID)]).To(BeNil())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("deletes nil session entries after a wait time", func() {
|
|
||||||
serv.deleteClosedSessionsAfter = 25 * time.Millisecond
|
|
||||||
serv.sessions[string(connID)] = &mockSession{}
|
|
||||||
// make session.run() return
|
|
||||||
serv.removeConnection(connID)
|
|
||||||
Eventually(func() bool {
|
|
||||||
serv.sessionsMutex.Lock()
|
|
||||||
_, ok := serv.sessions[string(connID)]
|
|
||||||
serv.sessionsMutex.Unlock()
|
|
||||||
return ok
|
|
||||||
}).Should(BeFalse())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("closes sessions and the connection when Close is called", func() {
|
|
||||||
run := make(chan struct{})
|
|
||||||
sess := NewMockPacketHandler(mockCtrl)
|
|
||||||
sess.EXPECT().Close(nil)
|
|
||||||
sess.EXPECT().handlePacket(gomock.Any())
|
sess.EXPECT().handlePacket(gomock.Any())
|
||||||
sess.EXPECT().run().Do(func() { close(run) })
|
|
||||||
sessions = append(sessions, sess)
|
sessionHandler.EXPECT().Get(connID).Return(sess, true)
|
||||||
|
err := serv.handlePacket(nil, []byte{0x08, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6, 0x01})
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("closes the sessionHandler and the connection when Close is called", func() {
|
||||||
go func() {
|
go func() {
|
||||||
defer GinkgoRecover()
|
defer GinkgoRecover()
|
||||||
serv.serve()
|
serv.serve()
|
||||||
}()
|
}()
|
||||||
err := serv.handlePacket(nil, firstPacket)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Eventually(run).Should(BeClosed())
|
|
||||||
// close the server
|
// close the server
|
||||||
|
sessionHandler.EXPECT().Close().AnyTimes()
|
||||||
Expect(serv.Close()).To(Succeed())
|
Expect(serv.Close()).To(Succeed())
|
||||||
Expect(conn.closed).To(BeTrue())
|
Expect(conn.closed).To(BeTrue())
|
||||||
})
|
})
|
||||||
|
|
||||||
It("ignores packets for closed sessions", func() {
|
It("ignores packets for closed sessions", func() {
|
||||||
serv.sessions[string(connID)] = nil
|
sessionHandler.EXPECT().Get(connID).Return(nil, true)
|
||||||
err := serv.handlePacket(nil, []byte{0x08, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6, 0x01})
|
err := serv.handlePacket(nil, firstPacket)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(serv.sessions).To(HaveLen(1))
|
|
||||||
Expect(serv.sessions[string(connID)]).To(BeNil())
|
|
||||||
})
|
})
|
||||||
|
|
||||||
It("works if no quic.Config is given", func(done Done) {
|
It("works if no quic.Config is given", func(done Done) {
|
||||||
|
@ -327,49 +285,32 @@ var _ = Describe("Server", func() {
|
||||||
Eventually(func() bool { return returned }).Should(BeTrue())
|
Eventually(func() bool { return returned }).Should(BeTrue())
|
||||||
})
|
})
|
||||||
|
|
||||||
It("errors when encountering a connection error", func(done Done) {
|
It("errors when encountering a connection error", func() {
|
||||||
testErr := errors.New("connection error")
|
testErr := errors.New("connection error")
|
||||||
conn.readErr = testErr
|
conn.readErr = testErr
|
||||||
go serv.serve()
|
sessionHandler.EXPECT().Close()
|
||||||
_, err := serv.Accept()
|
|
||||||
Expect(err).To(MatchError(testErr))
|
|
||||||
Expect(serv.Close()).To(Succeed())
|
|
||||||
close(done)
|
|
||||||
}, 0.5)
|
|
||||||
|
|
||||||
It("closes all sessions when encountering a connection error", func() {
|
|
||||||
sess := NewMockPacketHandler(mockCtrl)
|
|
||||||
sess.EXPECT().Close(nil)
|
|
||||||
serv.sessions[string(connID)] = sess
|
|
||||||
|
|
||||||
conn.readErr = errors.New("connection error")
|
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
go func() {
|
go func() {
|
||||||
defer GinkgoRecover()
|
defer GinkgoRecover()
|
||||||
serv.serve()
|
serv.serve()
|
||||||
close(done)
|
close(done)
|
||||||
}()
|
}()
|
||||||
Expect(serv.Close()).To(Succeed())
|
_, err := serv.Accept()
|
||||||
|
Expect(err).To(MatchError(testErr))
|
||||||
Eventually(done).Should(BeClosed())
|
Eventually(done).Should(BeClosed())
|
||||||
})
|
})
|
||||||
|
|
||||||
It("ignores delayed packets with mismatching versions", func() {
|
It("ignores delayed packets with mismatching versions", func() {
|
||||||
run := make(chan struct{})
|
|
||||||
sess := NewMockPacketHandler(mockCtrl)
|
sess := NewMockPacketHandler(mockCtrl)
|
||||||
sess.EXPECT().handlePacket(gomock.Any()) // only called once
|
// don't EXPECT any handlePacket() calls to this session
|
||||||
sess.EXPECT().run().Do(func() { close(run) })
|
sessionHandler.EXPECT().Get(connID).Return(sess, true)
|
||||||
sessions = append(sessions, sess)
|
|
||||||
|
|
||||||
err := serv.handlePacket(nil, firstPacket)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Eventually(run).Should(BeClosed())
|
|
||||||
|
|
||||||
b := &bytes.Buffer{}
|
b := &bytes.Buffer{}
|
||||||
// add an unsupported version
|
// add an unsupported version
|
||||||
data := []byte{0x09, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6}
|
data := []byte{0x09, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6}
|
||||||
utils.BigEndian.WriteUint32(b, uint32(protocol.SupportedVersions[0]+1))
|
utils.BigEndian.WriteUint32(b, uint32(protocol.SupportedVersions[0]+1))
|
||||||
data = append(append(data, b.Bytes()...), 0x01)
|
data = append(append(data, b.Bytes()...), 0x01)
|
||||||
err = serv.handlePacket(nil, data)
|
err := serv.handlePacket(nil, data)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
// if we didn't ignore the packet, the server would try to send a version negotiation packet, which would make the test panic because it doesn't have a udpConn
|
// if we didn't ignore the packet, the server would try to send a version negotiation packet, which would make the test panic because it doesn't have a udpConn
|
||||||
Expect(conn.dataWritten.Bytes()).To(BeEmpty())
|
Expect(conn.dataWritten.Bytes()).To(BeEmpty())
|
||||||
|
@ -397,21 +338,12 @@ var _ = Describe("Server", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
It("cuts packets at the payload length", func() {
|
It("cuts packets at the payload length", func() {
|
||||||
run := make(chan struct{})
|
|
||||||
sess := NewMockPacketHandler(mockCtrl)
|
sess := NewMockPacketHandler(mockCtrl)
|
||||||
gomock.InOrder(
|
|
||||||
sess.EXPECT().handlePacket(gomock.Any()), // first packet
|
|
||||||
sess.EXPECT().handlePacket(gomock.Any()).Do(func(packet *receivedPacket) {
|
sess.EXPECT().handlePacket(gomock.Any()).Do(func(packet *receivedPacket) {
|
||||||
Expect(packet.data).To(HaveLen(123))
|
Expect(packet.data).To(HaveLen(123))
|
||||||
}),
|
})
|
||||||
)
|
|
||||||
sess.EXPECT().run().Do(func() { close(run) })
|
|
||||||
sessions = append(sessions, sess)
|
|
||||||
|
|
||||||
serv.supportsTLS = true
|
serv.supportsTLS = true
|
||||||
err := serv.handlePacket(nil, firstPacket)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Eventually(run).Should(BeClosed())
|
|
||||||
b := &bytes.Buffer{}
|
b := &bytes.Buffer{}
|
||||||
hdr := &wire.Header{
|
hdr := &wire.Header{
|
||||||
IsLongHeader: true,
|
IsLongHeader: true,
|
||||||
|
@ -422,7 +354,8 @@ var _ = Describe("Server", func() {
|
||||||
Version: versionIETFFrames,
|
Version: versionIETFFrames,
|
||||||
}
|
}
|
||||||
Expect(hdr.Write(b, protocol.PerspectiveClient, versionIETFFrames)).To(Succeed())
|
Expect(hdr.Write(b, protocol.PerspectiveClient, versionIETFFrames)).To(Succeed())
|
||||||
err = serv.handlePacket(nil, append(b.Bytes(), make([]byte, 456)...))
|
sessionHandler.EXPECT().Get(connID).Return(sess, true)
|
||||||
|
err := serv.handlePacket(nil, append(b.Bytes(), make([]byte, 456)...))
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -443,18 +376,8 @@ var _ = Describe("Server", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
It("ignores Public Resets", func() {
|
It("ignores Public Resets", func() {
|
||||||
run := make(chan struct{})
|
err := serv.handlePacket(nil, wire.WritePublicReset(connID, 1, 1337))
|
||||||
sess := NewMockPacketHandler(mockCtrl)
|
|
||||||
sess.EXPECT().handlePacket(gomock.Any()) // called only once
|
|
||||||
sess.EXPECT().run().Do(func() { close(run) })
|
|
||||||
sessions = append(sessions, sess)
|
|
||||||
err := serv.handlePacket(nil, firstPacket)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(serv.sessions).To(HaveLen(1))
|
|
||||||
Eventually(run).Should(BeClosed())
|
|
||||||
err = serv.handlePacket(nil, wire.WritePublicReset(connID, 1, 1337))
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(serv.sessions).To(HaveLen(1))
|
|
||||||
})
|
})
|
||||||
|
|
||||||
It("doesn't try to process a packet after sending a gQUIC Version Negotiation Packet", func() {
|
It("doesn't try to process a packet after sending a gQUIC Version Negotiation Packet", func() {
|
||||||
|
@ -470,6 +393,7 @@ var _ = Describe("Server", func() {
|
||||||
hdr.Write(b, protocol.PerspectiveClient, 13 /* not a valid QUIC version */)
|
hdr.Write(b, protocol.PerspectiveClient, 13 /* not a valid QUIC version */)
|
||||||
b.Write(bytes.Repeat([]byte{0}, protocol.MinClientHelloSize)) // add a fake CHLO
|
b.Write(bytes.Repeat([]byte{0}, protocol.MinClientHelloSize)) // add a fake CHLO
|
||||||
serv.conn = conn
|
serv.conn = conn
|
||||||
|
sessionHandler.EXPECT().Get(connID)
|
||||||
err := serv.handlePacket(nil, b.Bytes())
|
err := serv.handlePacket(nil, b.Bytes())
|
||||||
Expect(conn.dataWritten.Bytes()).ToNot(BeEmpty())
|
Expect(conn.dataWritten.Bytes()).ToNot(BeEmpty())
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
@ -487,6 +411,7 @@ var _ = Describe("Server", func() {
|
||||||
hdr.Write(b, protocol.PerspectiveClient, 13 /* not a valid QUIC version */)
|
hdr.Write(b, protocol.PerspectiveClient, 13 /* not a valid QUIC version */)
|
||||||
b.Write(bytes.Repeat([]byte{0}, protocol.MinClientHelloSize-1)) // this packet is 1 byte too small
|
b.Write(bytes.Repeat([]byte{0}, protocol.MinClientHelloSize-1)) // this packet is 1 byte too small
|
||||||
serv.conn = conn
|
serv.conn = conn
|
||||||
|
sessionHandler.EXPECT().Get(connID)
|
||||||
err := serv.handlePacket(udpAddr, b.Bytes())
|
err := serv.handlePacket(udpAddr, b.Bytes())
|
||||||
Expect(err).To(MatchError("dropping small packet with unknown version"))
|
Expect(err).To(MatchError("dropping small packet with unknown version"))
|
||||||
Expect(conn.dataWritten.Len()).Should(BeZero())
|
Expect(conn.dataWritten.Len()).Should(BeZero())
|
||||||
|
@ -506,8 +431,7 @@ var _ = Describe("Server", func() {
|
||||||
ln, err := Listen(conn, &tls.Config{}, &config)
|
ln, err := Listen(conn, &tls.Config{}, &config)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
server := ln.(*server)
|
server := ln.(*server)
|
||||||
Expect(server.deleteClosedSessionsAfter).To(Equal(protocol.ClosedSessionDeleteTimeout))
|
Expect(server.sessionHandler).ToNot(BeNil())
|
||||||
Expect(server.sessions).ToNot(BeNil())
|
|
||||||
Expect(server.scfg).ToNot(BeNil())
|
Expect(server.scfg).ToNot(BeNil())
|
||||||
Expect(server.config.Versions).To(Equal(supportedVersions))
|
Expect(server.config.Versions).To(Equal(supportedVersions))
|
||||||
Expect(server.config.HandshakeTimeout).To(Equal(1337 * time.Hour))
|
Expect(server.config.HandshakeTimeout).To(Equal(1337 * time.Hour))
|
||||||
|
@ -692,7 +616,6 @@ var _ = Describe("Server", func() {
|
||||||
Eventually(func() int { return conn.dataWritten.Len() }).ShouldNot(BeZero())
|
Eventually(func() int { return conn.dataWritten.Len() }).ShouldNot(BeZero())
|
||||||
Expect(conn.dataWrittenTo).To(Equal(udpAddr))
|
Expect(conn.dataWrittenTo).To(Equal(udpAddr))
|
||||||
Expect(conn.dataWritten.Bytes()[0] & 0x02).ToNot(BeZero()) // check that the ResetFlag is set
|
Expect(conn.dataWritten.Bytes()[0] & 0x02).ToNot(BeZero()) // check that the ResetFlag is set
|
||||||
Expect(ln.(*server).sessions).To(BeEmpty())
|
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
74
session_map.go
Normal file
74
session_map.go
Normal file
|
@ -0,0 +1,74 @@
|
||||||
|
package quic
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
|
)
|
||||||
|
|
||||||
|
type sessionMap struct {
|
||||||
|
mutex sync.RWMutex
|
||||||
|
|
||||||
|
sessions map[string] /* string(ConnectionID)*/ packetHandler
|
||||||
|
closed bool
|
||||||
|
|
||||||
|
deleteClosedSessionsAfter time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ sessionHandler = &sessionMap{}
|
||||||
|
|
||||||
|
func newSessionMap() sessionHandler {
|
||||||
|
return &sessionMap{
|
||||||
|
sessions: make(map[string]packetHandler),
|
||||||
|
deleteClosedSessionsAfter: protocol.ClosedSessionDeleteTimeout,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *sessionMap) Get(id protocol.ConnectionID) (packetHandler, bool) {
|
||||||
|
h.mutex.RLock()
|
||||||
|
sess, ok := h.sessions[string(id)]
|
||||||
|
h.mutex.RUnlock()
|
||||||
|
return sess, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *sessionMap) Add(id protocol.ConnectionID, sess packetHandler) {
|
||||||
|
h.mutex.Lock()
|
||||||
|
h.sessions[string(id)] = sess
|
||||||
|
h.mutex.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *sessionMap) Remove(id protocol.ConnectionID) {
|
||||||
|
h.mutex.Lock()
|
||||||
|
h.sessions[string(id)] = nil
|
||||||
|
h.mutex.Unlock()
|
||||||
|
|
||||||
|
time.AfterFunc(h.deleteClosedSessionsAfter, func() {
|
||||||
|
h.mutex.Lock()
|
||||||
|
delete(h.sessions, string(id))
|
||||||
|
h.mutex.Unlock()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *sessionMap) Close() {
|
||||||
|
h.mutex.Lock()
|
||||||
|
if h.closed {
|
||||||
|
h.mutex.Unlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.closed = true
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
for _, session := range h.sessions {
|
||||||
|
if session != nil {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(sess packetHandler) {
|
||||||
|
// session.Close() blocks until the CONNECTION_CLOSE has been sent and the run-loop has stopped
|
||||||
|
_ = sess.Close(nil)
|
||||||
|
wg.Done()
|
||||||
|
}(session)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
h.mutex.Unlock()
|
||||||
|
wg.Wait()
|
||||||
|
}
|
56
session_map_test.go
Normal file
56
session_map_test.go
Normal file
|
@ -0,0 +1,56 @@
|
||||||
|
package quic
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
|
. "github.com/onsi/ginkgo"
|
||||||
|
. "github.com/onsi/gomega"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ = Describe("Session Handler", func() {
|
||||||
|
var handler *sessionMap
|
||||||
|
|
||||||
|
BeforeEach(func() {
|
||||||
|
handler = newSessionMap().(*sessionMap)
|
||||||
|
})
|
||||||
|
|
||||||
|
It("adds and gets", func() {
|
||||||
|
connID := protocol.ConnectionID{1, 2, 3, 4, 5}
|
||||||
|
sess := &mockSession{}
|
||||||
|
handler.Add(connID, sess)
|
||||||
|
session, ok := handler.Get(connID)
|
||||||
|
Expect(ok).To(BeTrue())
|
||||||
|
Expect(session).To(Equal(sess))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("deletes", func() {
|
||||||
|
connID := protocol.ConnectionID{1, 2, 3, 4, 5}
|
||||||
|
handler.Add(connID, &mockSession{})
|
||||||
|
handler.Remove(connID)
|
||||||
|
session, ok := handler.Get(connID)
|
||||||
|
Expect(ok).To(BeTrue())
|
||||||
|
Expect(session).To(BeNil())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("deletes nil session entries after a wait time", func() {
|
||||||
|
handler.deleteClosedSessionsAfter = 25 * time.Millisecond
|
||||||
|
connID := protocol.ConnectionID{1, 2, 3, 4, 5}
|
||||||
|
handler.Add(connID, &mockSession{})
|
||||||
|
handler.Remove(connID)
|
||||||
|
Eventually(func() bool {
|
||||||
|
_, ok := handler.Get(connID)
|
||||||
|
return ok
|
||||||
|
}).Should(BeFalse())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("closes", func() {
|
||||||
|
sess1 := NewMockPacketHandler(mockCtrl)
|
||||||
|
sess1.EXPECT().Close(nil)
|
||||||
|
sess2 := NewMockPacketHandler(mockCtrl)
|
||||||
|
sess2.EXPECT().Close(nil)
|
||||||
|
handler.Add(protocol.ConnectionID{1, 1, 1, 1}, sess1)
|
||||||
|
handler.Add(protocol.ConnectionID{2, 2, 2, 2}, sess2)
|
||||||
|
handler.Close()
|
||||||
|
})
|
||||||
|
})
|
Loading…
Add table
Add a link
Reference in a new issue