refactor the map of sessions into a separate struct

This commit is contained in:
Marten Seemann 2018-05-09 08:21:05 +09:00
parent 15da47cf98
commit 9c5986945e
6 changed files with 312 additions and 218 deletions

View file

@ -0,0 +1,78 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/lucas-clemente/quic-go (interfaces: SessionHandler)
// Package quic is a generated GoMock package.
package quic
import (
reflect "reflect"
gomock "github.com/golang/mock/gomock"
protocol "github.com/lucas-clemente/quic-go/internal/protocol"
)
// MockSessionHandler is a mock of SessionHandler interface
type MockSessionHandler struct {
ctrl *gomock.Controller
recorder *MockSessionHandlerMockRecorder
}
// MockSessionHandlerMockRecorder is the mock recorder for MockSessionHandler
type MockSessionHandlerMockRecorder struct {
mock *MockSessionHandler
}
// NewMockSessionHandler creates a new mock instance
func NewMockSessionHandler(ctrl *gomock.Controller) *MockSessionHandler {
mock := &MockSessionHandler{ctrl: ctrl}
mock.recorder = &MockSessionHandlerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockSessionHandler) EXPECT() *MockSessionHandlerMockRecorder {
return m.recorder
}
// Add mocks base method
func (m *MockSessionHandler) Add(arg0 protocol.ConnectionID, arg1 packetHandler) {
m.ctrl.Call(m, "Add", arg0, arg1)
}
// Add indicates an expected call of Add
func (mr *MockSessionHandlerMockRecorder) Add(arg0, arg1 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockSessionHandler)(nil).Add), arg0, arg1)
}
// Close mocks base method
func (m *MockSessionHandler) Close() {
m.ctrl.Call(m, "Close")
}
// Close indicates an expected call of Close
func (mr *MockSessionHandlerMockRecorder) Close() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockSessionHandler)(nil).Close))
}
// Get mocks base method
func (m *MockSessionHandler) Get(arg0 protocol.ConnectionID) (packetHandler, bool) {
ret := m.ctrl.Call(m, "Get", arg0)
ret0, _ := ret[0].(packetHandler)
ret1, _ := ret[1].(bool)
return ret0, ret1
}
// Get indicates an expected call of Get
func (mr *MockSessionHandlerMockRecorder) Get(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockSessionHandler)(nil).Get), arg0)
}
// Remove mocks base method
func (m *MockSessionHandler) Remove(arg0 protocol.ConnectionID) {
m.ctrl.Call(m, "Remove", arg0)
}
// Remove indicates an expected call of Remove
func (mr *MockSessionHandlerMockRecorder) Remove(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Remove", reflect.TypeOf((*MockSessionHandler)(nil).Remove), arg0)
}

View file

@ -13,5 +13,6 @@ package quic
//go:generate sh -c "./mockgen_private.sh quic mock_gquic_aead_test.go github.com/lucas-clemente/quic-go gQUICAEAD GQUICAEAD"
//go:generate sh -c "./mockgen_private.sh quic mock_session_runner_test.go github.com/lucas-clemente/quic-go sessionRunner SessionRunner"
//go:generate sh -c "./mockgen_private.sh quic mock_packet_handler_test.go github.com/lucas-clemente/quic-go packetHandler PacketHandler"
//go:generate sh -c "./mockgen_private.sh quic mock_session_handler_test.go github.com/lucas-clemente/quic-go sessionHandler SessionHandler"
//go:generate sh -c "find . -type f -name 'mock_*_test.go' | xargs sed -i '' 's/quic_go.//g'"
//go:generate sh -c "goimports -w mock*_test.go"

View file

@ -6,7 +6,6 @@ import (
"errors"
"fmt"
"net"
"sync"
"time"
"github.com/lucas-clemente/quic-go/internal/crypto"
@ -42,6 +41,13 @@ func (r *runner) removeConnectionID(c protocol.ConnectionID) { r.removeConnectio
var _ sessionRunner = &runner{}
type sessionHandler interface {
Add(protocol.ConnectionID, packetHandler)
Get(protocol.ConnectionID) (packetHandler, bool)
Remove(protocol.ConnectionID)
Close()
}
// A Listener of QUIC
type server struct {
tlsConf *tls.Config
@ -55,9 +61,7 @@ type server struct {
certChain crypto.CertChain
scfg *handshake.ServerConfig
sessionsMutex sync.RWMutex
sessions map[string] /* string(ConnectionID)*/ packetHandler
closed bool
sessionHandler sessionHandler
serverError error
@ -65,9 +69,8 @@ type server struct {
errorChan chan struct{}
sessionRunner sessionRunner
// set as members, so they can be set in the tests
// set as a member, so they can be set in the tests
newSession func(connection, sessionRunner, protocol.VersionNumber, protocol.ConnectionID, *handshake.ServerConfig, *tls.Config, *Config, utils.Logger) (packetHandler, error)
deleteClosedSessionsAfter time.Duration
logger utils.Logger
}
@ -120,9 +123,8 @@ func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener,
config: config,
certChain: certChain,
scfg: scfg,
sessions: map[string]packetHandler{},
newSession: newSession,
deleteClosedSessionsAfter: protocol.ClosedSessionDeleteTimeout,
sessionHandler: newSessionMap(),
sessionQueue: make(chan Session, 5),
errorChan: make(chan struct{}),
supportsTLS: supportsTLS,
@ -142,7 +144,7 @@ func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener,
func (s *server) setup() {
s.sessionRunner = &runner{
onHandshakeCompleteImpl: func(sess packetHandler) { s.sessionQueue <- sess },
removeConnectionIDImpl: s.removeConnection,
removeConnectionIDImpl: s.sessionHandler.Remove,
}
}
@ -165,13 +167,13 @@ func (s *server) setupTLS() error {
case tlsSession := <-sessionChan:
connID := tlsSession.connID
sess := tlsSession.sess
s.sessionsMutex.Lock()
if _, ok := s.sessions[string(connID)]; ok { // drop this session if it already exists
s.sessionsMutex.Unlock()
if _, ok := s.sessionHandler.Get(connID); ok { // drop this session if it already exists
continue
}
s.sessions[string(connID)] = sess
s.sessionsMutex.Unlock()
// TODO(#1003): There's a race condition here.
// If another connection with the same conn ID is added between Get() and Add(), it would be overwritten.
// We can avoid this be using server-chosen connection IDs.
s.sessionHandler.Add(connID, sess)
go sess.run()
}
}
@ -288,27 +290,7 @@ func (s *server) Accept() (Session, error) {
// Close the server
func (s *server) Close() error {
s.sessionsMutex.Lock()
if s.closed {
s.sessionsMutex.Unlock()
return nil
}
s.closed = true
var wg sync.WaitGroup
for _, session := range s.sessions {
if session != nil {
wg.Add(1)
go func(sess packetHandler) {
// session.Close() blocks until the CONNECTION_CLOSE has been sent and the run-loop has stopped
_ = sess.Close(nil)
wg.Done()
}(session)
}
}
s.sessionsMutex.Unlock()
wg.Wait()
s.sessionHandler.Close()
err := s.conn.Close()
<-s.errorChan // wait for serve() to return
return err
@ -359,10 +341,7 @@ func (s *server) handleIETFQUICPacket(hdr *wire.Header, packetData []byte, remot
}
}
s.sessionsMutex.RLock()
session, sessionKnown := s.sessions[string(hdr.DestConnectionID)]
s.sessionsMutex.RUnlock()
session, sessionKnown := s.sessionHandler.Get(hdr.DestConnectionID)
if sessionKnown && session == nil {
// Late packet for closed session
return nil
@ -382,21 +361,18 @@ func (s *server) handleIETFQUICPacket(hdr *wire.Header, packetData []byte, remot
}
func (s *server) handleGQUICPacket(hdr *wire.Header, packetData []byte, remoteAddr net.Addr, rcvTime time.Time) error {
s.sessionsMutex.RLock()
session, sessionKnown := s.sessions[string(hdr.DestConnectionID)]
s.sessionsMutex.RUnlock()
if sessionKnown && session == nil {
// Late packet for closed session
return nil
}
// ignore all Public Reset packets
if hdr.ResetFlag {
s.logger.Infof("Received unexpected Public Reset for connection %s.", hdr.DestConnectionID)
return nil
}
session, sessionKnown := s.sessionHandler.Get(hdr.DestConnectionID)
if sessionKnown && session == nil {
// Late packet for closed session
return nil
}
// If we don't have a session for this connection, and this packet cannot open a new connection, send a Public Reset
// This should only happen after a server restart, when we still receive packets for connections that we lost the state for.
if !sessionKnown && !hdr.VersionFlag {
@ -450,9 +426,7 @@ func (s *server) handleGQUICPacket(hdr *wire.Header, packetData []byte, remoteAd
if err != nil {
return err
}
s.sessionsMutex.Lock()
s.sessions[string(hdr.DestConnectionID)] = session
s.sessionsMutex.Unlock()
s.sessionHandler.Add(hdr.DestConnectionID, session)
go session.run()
}
@ -465,15 +439,3 @@ func (s *server) handleGQUICPacket(hdr *wire.Header, packetData []byte, remoteAd
})
return nil
}
func (s *server) removeConnection(id protocol.ConnectionID) {
s.sessionsMutex.Lock()
s.sessions[string(id)] = nil
s.sessionsMutex.Unlock()
time.AfterFunc(s.deleteClosedSessionsAfter, func() {
s.sessionsMutex.Lock()
delete(s.sessions, string(id))
s.sessionsMutex.Unlock()
})
}

View file

@ -84,9 +84,11 @@ var _ = Describe("Server", func() {
firstPacket []byte // a valid first packet for a new connection with connectionID 0x4cfa9f9b668619f6 (= connID)
connID = protocol.ConnectionID{0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6}
sessions = make([]*MockPacketHandler, 0)
sessionHandler *MockSessionHandler
)
BeforeEach(func() {
sessionHandler = NewMockSessionHandler(mockCtrl)
newMockSession := func(
_ connection,
runner sessionRunner,
@ -105,7 +107,7 @@ var _ = Describe("Server", func() {
return s, nil
}
serv = &server{
sessions: make(map[string]packetHandler),
sessionHandler: sessionHandler,
newSession: newMockSession,
conn: conn,
config: config,
@ -139,11 +141,13 @@ var _ = Describe("Server", func() {
run := make(chan struct{})
s.EXPECT().run().Do(func() { close(run) })
sessions = append(sessions, s)
sessionHandler.EXPECT().Get(connID)
sessionHandler.EXPECT().Add(connID, gomock.Any()).Do(func(_ protocol.ConnectionID, sess packetHandler) {
Expect(sess.(*mockSession).connID).To(Equal(connID))
})
err := serv.handlePacket(nil, firstPacket)
Expect(err).ToNot(HaveOccurred())
Expect(serv.sessions).To(HaveLen(1))
sess := serv.sessions[string(connID)].(*mockSession)
Expect(sess.connID).To(Equal(connID))
Eventually(run).Should(BeClosed())
})
@ -154,51 +158,38 @@ var _ = Describe("Server", func() {
sess.EXPECT().run().Do(func() { close(run) })
err := serv.setupTLS()
Expect(err).ToNot(HaveOccurred())
sessionHandler.EXPECT().Get(connID)
sessionHandler.EXPECT().Add(connID, sess)
serv.serverTLS.sessionChan <- tlsSession{
connID: connID,
sess: sess,
}
Eventually(func() packetHandler {
serv.sessionsMutex.Lock()
defer serv.sessionsMutex.Unlock()
return serv.sessions[string(connID)]
}).Should(Equal(sess))
Eventually(run).Should(BeClosed())
})
It("only accepts one new TLS sessions for one connection ID", func() {
connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
run := make(chan struct{})
sess := NewMockPacketHandler(mockCtrl)
sess.EXPECT().run().Do(func() { close(run) })
sess2 := NewMockPacketHandler(mockCtrl)
err := serv.setupTLS()
Expect(err).ToNot(HaveOccurred())
done := make(chan struct{})
sessionHandler.EXPECT().Get(connID).Return(NewMockPacketHandler(mockCtrl), true).Do(func(protocol.ConnectionID) {
close(done)
})
// don't EXPECT any calls to sessionHandler.Add
serv.serverTLS.sessionChan <- tlsSession{
connID: connID,
sess: sess,
}
Eventually(func() packetHandler {
serv.sessionsMutex.Lock()
defer serv.sessionsMutex.Unlock()
return serv.sessions[string(connID)]
}).Should(Equal(sess))
serv.serverTLS.sessionChan <- tlsSession{
connID: connID,
sess: sess2,
}
Consistently(func() packetHandler {
serv.sessionsMutex.Lock()
defer serv.sessionsMutex.Unlock()
return serv.sessions[string(connID)]
}).Should(Equal(sess))
Eventually(run).Should(BeClosed())
Eventually(done).Should(BeClosed())
})
It("accepts a session once the connection it is forward secure", func() {
s := NewMockPacketHandler(mockCtrl)
s.EXPECT().handlePacket(gomock.Any())
s.EXPECT().run()
run := make(chan struct{})
s.EXPECT().run().Do(func() { close(run) })
sessions = append(sessions, s)
done := make(chan struct{})
go func() {
@ -208,17 +199,19 @@ var _ = Describe("Server", func() {
Expect(sess.(*mockSession).connID).To(Equal(connID))
close(done)
}()
sessionHandler.EXPECT().Get(connID)
sessionHandler.EXPECT().Add(connID, gomock.Any()).Do(func(_ protocol.ConnectionID, sess packetHandler) {
Consistently(done).ShouldNot(BeClosed())
sess.(*mockSession).runner.onHandshakeComplete(sess)
})
err := serv.handlePacket(nil, firstPacket)
Expect(err).ToNot(HaveOccurred())
Expect(serv.sessions).To(HaveLen(1))
Consistently(done).ShouldNot(BeClosed())
sess := serv.sessions[string(connID)].(*mockSession)
sess.runner.onHandshakeComplete(sess)
Eventually(done).Should(BeClosed())
Eventually(run).Should(BeClosed())
})
It("doesn't accept sessions that error during the handshake", func() {
run := make(chan error)
run := make(chan error, 1)
sess := NewMockPacketHandler(mockCtrl)
sess.EXPECT().handlePacket(gomock.Any())
sess.EXPECT().run().DoAndReturn(func() error { return <-run })
@ -229,79 +222,44 @@ var _ = Describe("Server", func() {
serv.Accept()
close(done)
}()
sessionHandler.EXPECT().Get(connID)
sessionHandler.EXPECT().Add(connID, gomock.Any()).Do(func(_ protocol.ConnectionID, sess packetHandler) {
run <- errors.New("handshake error")
})
err := serv.handlePacket(nil, firstPacket)
Expect(err).ToNot(HaveOccurred())
Expect(serv.sessions).To(HaveLen(1))
run <- errors.New("handshake error")
serv.sessions[string(connID)].(*mockSession).runner.removeConnectionID(connID)
Consistently(done).ShouldNot(BeClosed())
// make the go routine return
sessionHandler.EXPECT().Close()
close(serv.errorChan)
serv.Close()
Eventually(done).Should(BeClosed())
})
It("assigns packets to existing sessions", func() {
run := make(chan struct{})
sess := NewMockPacketHandler(mockCtrl)
sess.EXPECT().handlePacket(gomock.Any()).Times(2)
sess.EXPECT().run().Do(func() { close(run) })
sessions = append(sessions, sess)
err := serv.handlePacket(nil, firstPacket)
Expect(err).ToNot(HaveOccurred())
err = serv.handlePacket(nil, []byte{0x08, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6, 0x01})
Expect(err).ToNot(HaveOccurred())
Eventually(run).Should(BeClosed())
})
It("deletes sessions", func() {
serv.deleteClosedSessionsAfter = time.Second // make sure that the nil value for the closed session doesn't get deleted in this test
serv.sessions[string(connID)] = &mockSession{}
serv.removeConnection(connID)
// The server should now have closed the session, leaving a nil value in the sessions map
Consistently(func() map[string]packetHandler { return serv.sessions }).Should(HaveLen(1))
Expect(serv.sessions[string(connID)]).To(BeNil())
})
It("deletes nil session entries after a wait time", func() {
serv.deleteClosedSessionsAfter = 25 * time.Millisecond
serv.sessions[string(connID)] = &mockSession{}
// make session.run() return
serv.removeConnection(connID)
Eventually(func() bool {
serv.sessionsMutex.Lock()
_, ok := serv.sessions[string(connID)]
serv.sessionsMutex.Unlock()
return ok
}).Should(BeFalse())
})
It("closes sessions and the connection when Close is called", func() {
run := make(chan struct{})
sess := NewMockPacketHandler(mockCtrl)
sess.EXPECT().Close(nil)
sess.EXPECT().handlePacket(gomock.Any())
sess.EXPECT().run().Do(func() { close(run) })
sessions = append(sessions, sess)
sessionHandler.EXPECT().Get(connID).Return(sess, true)
err := serv.handlePacket(nil, []byte{0x08, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6, 0x01})
Expect(err).ToNot(HaveOccurred())
})
It("closes the sessionHandler and the connection when Close is called", func() {
go func() {
defer GinkgoRecover()
serv.serve()
}()
err := serv.handlePacket(nil, firstPacket)
Expect(err).ToNot(HaveOccurred())
Eventually(run).Should(BeClosed())
// close the server
sessionHandler.EXPECT().Close().AnyTimes()
Expect(serv.Close()).To(Succeed())
Expect(conn.closed).To(BeTrue())
})
It("ignores packets for closed sessions", func() {
serv.sessions[string(connID)] = nil
err := serv.handlePacket(nil, []byte{0x08, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6, 0x01})
sessionHandler.EXPECT().Get(connID).Return(nil, true)
err := serv.handlePacket(nil, firstPacket)
Expect(err).ToNot(HaveOccurred())
Expect(serv.sessions).To(HaveLen(1))
Expect(serv.sessions[string(connID)]).To(BeNil())
})
It("works if no quic.Config is given", func(done Done) {
@ -327,49 +285,32 @@ var _ = Describe("Server", func() {
Eventually(func() bool { return returned }).Should(BeTrue())
})
It("errors when encountering a connection error", func(done Done) {
It("errors when encountering a connection error", func() {
testErr := errors.New("connection error")
conn.readErr = testErr
go serv.serve()
_, err := serv.Accept()
Expect(err).To(MatchError(testErr))
Expect(serv.Close()).To(Succeed())
close(done)
}, 0.5)
It("closes all sessions when encountering a connection error", func() {
sess := NewMockPacketHandler(mockCtrl)
sess.EXPECT().Close(nil)
serv.sessions[string(connID)] = sess
conn.readErr = errors.New("connection error")
sessionHandler.EXPECT().Close()
done := make(chan struct{})
go func() {
defer GinkgoRecover()
serv.serve()
close(done)
}()
Expect(serv.Close()).To(Succeed())
_, err := serv.Accept()
Expect(err).To(MatchError(testErr))
Eventually(done).Should(BeClosed())
})
It("ignores delayed packets with mismatching versions", func() {
run := make(chan struct{})
sess := NewMockPacketHandler(mockCtrl)
sess.EXPECT().handlePacket(gomock.Any()) // only called once
sess.EXPECT().run().Do(func() { close(run) })
sessions = append(sessions, sess)
err := serv.handlePacket(nil, firstPacket)
Expect(err).ToNot(HaveOccurred())
Eventually(run).Should(BeClosed())
// don't EXPECT any handlePacket() calls to this session
sessionHandler.EXPECT().Get(connID).Return(sess, true)
b := &bytes.Buffer{}
// add an unsupported version
data := []byte{0x09, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6}
utils.BigEndian.WriteUint32(b, uint32(protocol.SupportedVersions[0]+1))
data = append(append(data, b.Bytes()...), 0x01)
err = serv.handlePacket(nil, data)
err := serv.handlePacket(nil, data)
Expect(err).ToNot(HaveOccurred())
// if we didn't ignore the packet, the server would try to send a version negotiation packet, which would make the test panic because it doesn't have a udpConn
Expect(conn.dataWritten.Bytes()).To(BeEmpty())
@ -397,21 +338,12 @@ var _ = Describe("Server", func() {
})
It("cuts packets at the payload length", func() {
run := make(chan struct{})
sess := NewMockPacketHandler(mockCtrl)
gomock.InOrder(
sess.EXPECT().handlePacket(gomock.Any()), // first packet
sess.EXPECT().handlePacket(gomock.Any()).Do(func(packet *receivedPacket) {
Expect(packet.data).To(HaveLen(123))
}),
)
sess.EXPECT().run().Do(func() { close(run) })
sessions = append(sessions, sess)
})
serv.supportsTLS = true
err := serv.handlePacket(nil, firstPacket)
Expect(err).ToNot(HaveOccurred())
Eventually(run).Should(BeClosed())
b := &bytes.Buffer{}
hdr := &wire.Header{
IsLongHeader: true,
@ -422,7 +354,8 @@ var _ = Describe("Server", func() {
Version: versionIETFFrames,
}
Expect(hdr.Write(b, protocol.PerspectiveClient, versionIETFFrames)).To(Succeed())
err = serv.handlePacket(nil, append(b.Bytes(), make([]byte, 456)...))
sessionHandler.EXPECT().Get(connID).Return(sess, true)
err := serv.handlePacket(nil, append(b.Bytes(), make([]byte, 456)...))
Expect(err).ToNot(HaveOccurred())
})
@ -443,18 +376,8 @@ var _ = Describe("Server", func() {
})
It("ignores Public Resets", func() {
run := make(chan struct{})
sess := NewMockPacketHandler(mockCtrl)
sess.EXPECT().handlePacket(gomock.Any()) // called only once
sess.EXPECT().run().Do(func() { close(run) })
sessions = append(sessions, sess)
err := serv.handlePacket(nil, firstPacket)
err := serv.handlePacket(nil, wire.WritePublicReset(connID, 1, 1337))
Expect(err).ToNot(HaveOccurred())
Expect(serv.sessions).To(HaveLen(1))
Eventually(run).Should(BeClosed())
err = serv.handlePacket(nil, wire.WritePublicReset(connID, 1, 1337))
Expect(err).ToNot(HaveOccurred())
Expect(serv.sessions).To(HaveLen(1))
})
It("doesn't try to process a packet after sending a gQUIC Version Negotiation Packet", func() {
@ -470,6 +393,7 @@ var _ = Describe("Server", func() {
hdr.Write(b, protocol.PerspectiveClient, 13 /* not a valid QUIC version */)
b.Write(bytes.Repeat([]byte{0}, protocol.MinClientHelloSize)) // add a fake CHLO
serv.conn = conn
sessionHandler.EXPECT().Get(connID)
err := serv.handlePacket(nil, b.Bytes())
Expect(conn.dataWritten.Bytes()).ToNot(BeEmpty())
Expect(err).ToNot(HaveOccurred())
@ -487,6 +411,7 @@ var _ = Describe("Server", func() {
hdr.Write(b, protocol.PerspectiveClient, 13 /* not a valid QUIC version */)
b.Write(bytes.Repeat([]byte{0}, protocol.MinClientHelloSize-1)) // this packet is 1 byte too small
serv.conn = conn
sessionHandler.EXPECT().Get(connID)
err := serv.handlePacket(udpAddr, b.Bytes())
Expect(err).To(MatchError("dropping small packet with unknown version"))
Expect(conn.dataWritten.Len()).Should(BeZero())
@ -506,8 +431,7 @@ var _ = Describe("Server", func() {
ln, err := Listen(conn, &tls.Config{}, &config)
Expect(err).ToNot(HaveOccurred())
server := ln.(*server)
Expect(server.deleteClosedSessionsAfter).To(Equal(protocol.ClosedSessionDeleteTimeout))
Expect(server.sessions).ToNot(BeNil())
Expect(server.sessionHandler).ToNot(BeNil())
Expect(server.scfg).ToNot(BeNil())
Expect(server.config.Versions).To(Equal(supportedVersions))
Expect(server.config.HandshakeTimeout).To(Equal(1337 * time.Hour))
@ -692,7 +616,6 @@ var _ = Describe("Server", func() {
Eventually(func() int { return conn.dataWritten.Len() }).ShouldNot(BeZero())
Expect(conn.dataWrittenTo).To(Equal(udpAddr))
Expect(conn.dataWritten.Bytes()[0] & 0x02).ToNot(BeZero()) // check that the ResetFlag is set
Expect(ln.(*server).sessions).To(BeEmpty())
})
})

74
session_map.go Normal file
View file

@ -0,0 +1,74 @@
package quic
import (
"sync"
"time"
"github.com/lucas-clemente/quic-go/internal/protocol"
)
type sessionMap struct {
mutex sync.RWMutex
sessions map[string] /* string(ConnectionID)*/ packetHandler
closed bool
deleteClosedSessionsAfter time.Duration
}
var _ sessionHandler = &sessionMap{}
func newSessionMap() sessionHandler {
return &sessionMap{
sessions: make(map[string]packetHandler),
deleteClosedSessionsAfter: protocol.ClosedSessionDeleteTimeout,
}
}
func (h *sessionMap) Get(id protocol.ConnectionID) (packetHandler, bool) {
h.mutex.RLock()
sess, ok := h.sessions[string(id)]
h.mutex.RUnlock()
return sess, ok
}
func (h *sessionMap) Add(id protocol.ConnectionID, sess packetHandler) {
h.mutex.Lock()
h.sessions[string(id)] = sess
h.mutex.Unlock()
}
func (h *sessionMap) Remove(id protocol.ConnectionID) {
h.mutex.Lock()
h.sessions[string(id)] = nil
h.mutex.Unlock()
time.AfterFunc(h.deleteClosedSessionsAfter, func() {
h.mutex.Lock()
delete(h.sessions, string(id))
h.mutex.Unlock()
})
}
func (h *sessionMap) Close() {
h.mutex.Lock()
if h.closed {
h.mutex.Unlock()
return
}
h.closed = true
var wg sync.WaitGroup
for _, session := range h.sessions {
if session != nil {
wg.Add(1)
go func(sess packetHandler) {
// session.Close() blocks until the CONNECTION_CLOSE has been sent and the run-loop has stopped
_ = sess.Close(nil)
wg.Done()
}(session)
}
}
h.mutex.Unlock()
wg.Wait()
}

56
session_map_test.go Normal file
View file

@ -0,0 +1,56 @@
package quic
import (
"time"
"github.com/lucas-clemente/quic-go/internal/protocol"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Session Handler", func() {
var handler *sessionMap
BeforeEach(func() {
handler = newSessionMap().(*sessionMap)
})
It("adds and gets", func() {
connID := protocol.ConnectionID{1, 2, 3, 4, 5}
sess := &mockSession{}
handler.Add(connID, sess)
session, ok := handler.Get(connID)
Expect(ok).To(BeTrue())
Expect(session).To(Equal(sess))
})
It("deletes", func() {
connID := protocol.ConnectionID{1, 2, 3, 4, 5}
handler.Add(connID, &mockSession{})
handler.Remove(connID)
session, ok := handler.Get(connID)
Expect(ok).To(BeTrue())
Expect(session).To(BeNil())
})
It("deletes nil session entries after a wait time", func() {
handler.deleteClosedSessionsAfter = 25 * time.Millisecond
connID := protocol.ConnectionID{1, 2, 3, 4, 5}
handler.Add(connID, &mockSession{})
handler.Remove(connID)
Eventually(func() bool {
_, ok := handler.Get(connID)
return ok
}).Should(BeFalse())
})
It("closes", func() {
sess1 := NewMockPacketHandler(mockCtrl)
sess1.EXPECT().Close(nil)
sess2 := NewMockPacketHandler(mockCtrl)
sess2.EXPECT().Close(nil)
handler.Add(protocol.ConnectionID{1, 1, 1, 1}, sess1)
handler.Add(protocol.ConnectionID{2, 2, 2, 2}, sess2)
handler.Close()
})
})