mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-04 12:47:36 +03:00
refactor the map of sessions into a separate struct
This commit is contained in:
parent
15da47cf98
commit
9c5986945e
6 changed files with 312 additions and 218 deletions
78
mock_session_handler_test.go
Normal file
78
mock_session_handler_test.go
Normal file
|
@ -0,0 +1,78 @@
|
|||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/lucas-clemente/quic-go (interfaces: SessionHandler)
|
||||
|
||||
// Package quic is a generated GoMock package.
|
||||
package quic
|
||||
|
||||
import (
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
protocol "github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
// MockSessionHandler is a mock of SessionHandler interface
|
||||
type MockSessionHandler struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockSessionHandlerMockRecorder
|
||||
}
|
||||
|
||||
// MockSessionHandlerMockRecorder is the mock recorder for MockSessionHandler
|
||||
type MockSessionHandlerMockRecorder struct {
|
||||
mock *MockSessionHandler
|
||||
}
|
||||
|
||||
// NewMockSessionHandler creates a new mock instance
|
||||
func NewMockSessionHandler(ctrl *gomock.Controller) *MockSessionHandler {
|
||||
mock := &MockSessionHandler{ctrl: ctrl}
|
||||
mock.recorder = &MockSessionHandlerMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
func (m *MockSessionHandler) EXPECT() *MockSessionHandlerMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Add mocks base method
|
||||
func (m *MockSessionHandler) Add(arg0 protocol.ConnectionID, arg1 packetHandler) {
|
||||
m.ctrl.Call(m, "Add", arg0, arg1)
|
||||
}
|
||||
|
||||
// Add indicates an expected call of Add
|
||||
func (mr *MockSessionHandlerMockRecorder) Add(arg0, arg1 interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockSessionHandler)(nil).Add), arg0, arg1)
|
||||
}
|
||||
|
||||
// Close mocks base method
|
||||
func (m *MockSessionHandler) Close() {
|
||||
m.ctrl.Call(m, "Close")
|
||||
}
|
||||
|
||||
// Close indicates an expected call of Close
|
||||
func (mr *MockSessionHandlerMockRecorder) Close() *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockSessionHandler)(nil).Close))
|
||||
}
|
||||
|
||||
// Get mocks base method
|
||||
func (m *MockSessionHandler) Get(arg0 protocol.ConnectionID) (packetHandler, bool) {
|
||||
ret := m.ctrl.Call(m, "Get", arg0)
|
||||
ret0, _ := ret[0].(packetHandler)
|
||||
ret1, _ := ret[1].(bool)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Get indicates an expected call of Get
|
||||
func (mr *MockSessionHandlerMockRecorder) Get(arg0 interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockSessionHandler)(nil).Get), arg0)
|
||||
}
|
||||
|
||||
// Remove mocks base method
|
||||
func (m *MockSessionHandler) Remove(arg0 protocol.ConnectionID) {
|
||||
m.ctrl.Call(m, "Remove", arg0)
|
||||
}
|
||||
|
||||
// Remove indicates an expected call of Remove
|
||||
func (mr *MockSessionHandlerMockRecorder) Remove(arg0 interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Remove", reflect.TypeOf((*MockSessionHandler)(nil).Remove), arg0)
|
||||
}
|
|
@ -13,5 +13,6 @@ package quic
|
|||
//go:generate sh -c "./mockgen_private.sh quic mock_gquic_aead_test.go github.com/lucas-clemente/quic-go gQUICAEAD GQUICAEAD"
|
||||
//go:generate sh -c "./mockgen_private.sh quic mock_session_runner_test.go github.com/lucas-clemente/quic-go sessionRunner SessionRunner"
|
||||
//go:generate sh -c "./mockgen_private.sh quic mock_packet_handler_test.go github.com/lucas-clemente/quic-go packetHandler PacketHandler"
|
||||
//go:generate sh -c "./mockgen_private.sh quic mock_session_handler_test.go github.com/lucas-clemente/quic-go sessionHandler SessionHandler"
|
||||
//go:generate sh -c "find . -type f -name 'mock_*_test.go' | xargs sed -i '' 's/quic_go.//g'"
|
||||
//go:generate sh -c "goimports -w mock*_test.go"
|
||||
|
|
110
server.go
110
server.go
|
@ -6,7 +6,6 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/crypto"
|
||||
|
@ -42,6 +41,13 @@ func (r *runner) removeConnectionID(c protocol.ConnectionID) { r.removeConnectio
|
|||
|
||||
var _ sessionRunner = &runner{}
|
||||
|
||||
type sessionHandler interface {
|
||||
Add(protocol.ConnectionID, packetHandler)
|
||||
Get(protocol.ConnectionID) (packetHandler, bool)
|
||||
Remove(protocol.ConnectionID)
|
||||
Close()
|
||||
}
|
||||
|
||||
// A Listener of QUIC
|
||||
type server struct {
|
||||
tlsConf *tls.Config
|
||||
|
@ -55,9 +61,7 @@ type server struct {
|
|||
certChain crypto.CertChain
|
||||
scfg *handshake.ServerConfig
|
||||
|
||||
sessionsMutex sync.RWMutex
|
||||
sessions map[string] /* string(ConnectionID)*/ packetHandler
|
||||
closed bool
|
||||
sessionHandler sessionHandler
|
||||
|
||||
serverError error
|
||||
|
||||
|
@ -65,9 +69,8 @@ type server struct {
|
|||
errorChan chan struct{}
|
||||
|
||||
sessionRunner sessionRunner
|
||||
// set as members, so they can be set in the tests
|
||||
newSession func(connection, sessionRunner, protocol.VersionNumber, protocol.ConnectionID, *handshake.ServerConfig, *tls.Config, *Config, utils.Logger) (packetHandler, error)
|
||||
deleteClosedSessionsAfter time.Duration
|
||||
// set as a member, so they can be set in the tests
|
||||
newSession func(connection, sessionRunner, protocol.VersionNumber, protocol.ConnectionID, *handshake.ServerConfig, *tls.Config, *Config, utils.Logger) (packetHandler, error)
|
||||
|
||||
logger utils.Logger
|
||||
}
|
||||
|
@ -115,18 +118,17 @@ func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener,
|
|||
}
|
||||
|
||||
s := &server{
|
||||
conn: conn,
|
||||
tlsConf: tlsConf,
|
||||
config: config,
|
||||
certChain: certChain,
|
||||
scfg: scfg,
|
||||
sessions: map[string]packetHandler{},
|
||||
newSession: newSession,
|
||||
deleteClosedSessionsAfter: protocol.ClosedSessionDeleteTimeout,
|
||||
sessionQueue: make(chan Session, 5),
|
||||
errorChan: make(chan struct{}),
|
||||
supportsTLS: supportsTLS,
|
||||
logger: utils.DefaultLogger.WithPrefix("server"),
|
||||
conn: conn,
|
||||
tlsConf: tlsConf,
|
||||
config: config,
|
||||
certChain: certChain,
|
||||
scfg: scfg,
|
||||
newSession: newSession,
|
||||
sessionHandler: newSessionMap(),
|
||||
sessionQueue: make(chan Session, 5),
|
||||
errorChan: make(chan struct{}),
|
||||
supportsTLS: supportsTLS,
|
||||
logger: utils.DefaultLogger.WithPrefix("server"),
|
||||
}
|
||||
s.setup()
|
||||
if supportsTLS {
|
||||
|
@ -142,7 +144,7 @@ func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener,
|
|||
func (s *server) setup() {
|
||||
s.sessionRunner = &runner{
|
||||
onHandshakeCompleteImpl: func(sess packetHandler) { s.sessionQueue <- sess },
|
||||
removeConnectionIDImpl: s.removeConnection,
|
||||
removeConnectionIDImpl: s.sessionHandler.Remove,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -165,13 +167,13 @@ func (s *server) setupTLS() error {
|
|||
case tlsSession := <-sessionChan:
|
||||
connID := tlsSession.connID
|
||||
sess := tlsSession.sess
|
||||
s.sessionsMutex.Lock()
|
||||
if _, ok := s.sessions[string(connID)]; ok { // drop this session if it already exists
|
||||
s.sessionsMutex.Unlock()
|
||||
if _, ok := s.sessionHandler.Get(connID); ok { // drop this session if it already exists
|
||||
continue
|
||||
}
|
||||
s.sessions[string(connID)] = sess
|
||||
s.sessionsMutex.Unlock()
|
||||
// TODO(#1003): There's a race condition here.
|
||||
// If another connection with the same conn ID is added between Get() and Add(), it would be overwritten.
|
||||
// We can avoid this be using server-chosen connection IDs.
|
||||
s.sessionHandler.Add(connID, sess)
|
||||
go sess.run()
|
||||
}
|
||||
}
|
||||
|
@ -288,27 +290,7 @@ func (s *server) Accept() (Session, error) {
|
|||
|
||||
// Close the server
|
||||
func (s *server) Close() error {
|
||||
s.sessionsMutex.Lock()
|
||||
if s.closed {
|
||||
s.sessionsMutex.Unlock()
|
||||
return nil
|
||||
}
|
||||
s.closed = true
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for _, session := range s.sessions {
|
||||
if session != nil {
|
||||
wg.Add(1)
|
||||
go func(sess packetHandler) {
|
||||
// session.Close() blocks until the CONNECTION_CLOSE has been sent and the run-loop has stopped
|
||||
_ = sess.Close(nil)
|
||||
wg.Done()
|
||||
}(session)
|
||||
}
|
||||
}
|
||||
s.sessionsMutex.Unlock()
|
||||
wg.Wait()
|
||||
|
||||
s.sessionHandler.Close()
|
||||
err := s.conn.Close()
|
||||
<-s.errorChan // wait for serve() to return
|
||||
return err
|
||||
|
@ -359,10 +341,7 @@ func (s *server) handleIETFQUICPacket(hdr *wire.Header, packetData []byte, remot
|
|||
}
|
||||
}
|
||||
|
||||
s.sessionsMutex.RLock()
|
||||
session, sessionKnown := s.sessions[string(hdr.DestConnectionID)]
|
||||
s.sessionsMutex.RUnlock()
|
||||
|
||||
session, sessionKnown := s.sessionHandler.Get(hdr.DestConnectionID)
|
||||
if sessionKnown && session == nil {
|
||||
// Late packet for closed session
|
||||
return nil
|
||||
|
@ -382,21 +361,18 @@ func (s *server) handleIETFQUICPacket(hdr *wire.Header, packetData []byte, remot
|
|||
}
|
||||
|
||||
func (s *server) handleGQUICPacket(hdr *wire.Header, packetData []byte, remoteAddr net.Addr, rcvTime time.Time) error {
|
||||
s.sessionsMutex.RLock()
|
||||
session, sessionKnown := s.sessions[string(hdr.DestConnectionID)]
|
||||
s.sessionsMutex.RUnlock()
|
||||
|
||||
if sessionKnown && session == nil {
|
||||
// Late packet for closed session
|
||||
return nil
|
||||
}
|
||||
|
||||
// ignore all Public Reset packets
|
||||
if hdr.ResetFlag {
|
||||
s.logger.Infof("Received unexpected Public Reset for connection %s.", hdr.DestConnectionID)
|
||||
return nil
|
||||
}
|
||||
|
||||
session, sessionKnown := s.sessionHandler.Get(hdr.DestConnectionID)
|
||||
if sessionKnown && session == nil {
|
||||
// Late packet for closed session
|
||||
return nil
|
||||
}
|
||||
|
||||
// If we don't have a session for this connection, and this packet cannot open a new connection, send a Public Reset
|
||||
// This should only happen after a server restart, when we still receive packets for connections that we lost the state for.
|
||||
if !sessionKnown && !hdr.VersionFlag {
|
||||
|
@ -450,9 +426,7 @@ func (s *server) handleGQUICPacket(hdr *wire.Header, packetData []byte, remoteAd
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.sessionsMutex.Lock()
|
||||
s.sessions[string(hdr.DestConnectionID)] = session
|
||||
s.sessionsMutex.Unlock()
|
||||
s.sessionHandler.Add(hdr.DestConnectionID, session)
|
||||
|
||||
go session.run()
|
||||
}
|
||||
|
@ -465,15 +439,3 @@ func (s *server) handleGQUICPacket(hdr *wire.Header, packetData []byte, remoteAd
|
|||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *server) removeConnection(id protocol.ConnectionID) {
|
||||
s.sessionsMutex.Lock()
|
||||
s.sessions[string(id)] = nil
|
||||
s.sessionsMutex.Unlock()
|
||||
|
||||
time.AfterFunc(s.deleteClosedSessionsAfter, func() {
|
||||
s.sessionsMutex.Lock()
|
||||
delete(s.sessions, string(id))
|
||||
s.sessionsMutex.Unlock()
|
||||
})
|
||||
}
|
||||
|
|
211
server_test.go
211
server_test.go
|
@ -80,13 +80,15 @@ var _ = Describe("Server", func() {
|
|||
|
||||
Context("with mock session", func() {
|
||||
var (
|
||||
serv *server
|
||||
firstPacket []byte // a valid first packet for a new connection with connectionID 0x4cfa9f9b668619f6 (= connID)
|
||||
connID = protocol.ConnectionID{0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6}
|
||||
sessions = make([]*MockPacketHandler, 0)
|
||||
serv *server
|
||||
firstPacket []byte // a valid first packet for a new connection with connectionID 0x4cfa9f9b668619f6 (= connID)
|
||||
connID = protocol.ConnectionID{0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6}
|
||||
sessions = make([]*MockPacketHandler, 0)
|
||||
sessionHandler *MockSessionHandler
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
sessionHandler = NewMockSessionHandler(mockCtrl)
|
||||
newMockSession := func(
|
||||
_ connection,
|
||||
runner sessionRunner,
|
||||
|
@ -105,13 +107,13 @@ var _ = Describe("Server", func() {
|
|||
return s, nil
|
||||
}
|
||||
serv = &server{
|
||||
sessions: make(map[string]packetHandler),
|
||||
newSession: newMockSession,
|
||||
conn: conn,
|
||||
config: config,
|
||||
sessionQueue: make(chan Session, 5),
|
||||
errorChan: make(chan struct{}),
|
||||
logger: utils.DefaultLogger,
|
||||
sessionHandler: sessionHandler,
|
||||
newSession: newMockSession,
|
||||
conn: conn,
|
||||
config: config,
|
||||
sessionQueue: make(chan Session, 5),
|
||||
errorChan: make(chan struct{}),
|
||||
logger: utils.DefaultLogger,
|
||||
}
|
||||
serv.setup()
|
||||
b := &bytes.Buffer{}
|
||||
|
@ -139,11 +141,13 @@ var _ = Describe("Server", func() {
|
|||
run := make(chan struct{})
|
||||
s.EXPECT().run().Do(func() { close(run) })
|
||||
sessions = append(sessions, s)
|
||||
|
||||
sessionHandler.EXPECT().Get(connID)
|
||||
sessionHandler.EXPECT().Add(connID, gomock.Any()).Do(func(_ protocol.ConnectionID, sess packetHandler) {
|
||||
Expect(sess.(*mockSession).connID).To(Equal(connID))
|
||||
})
|
||||
err := serv.handlePacket(nil, firstPacket)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(serv.sessions).To(HaveLen(1))
|
||||
sess := serv.sessions[string(connID)].(*mockSession)
|
||||
Expect(sess.connID).To(Equal(connID))
|
||||
Eventually(run).Should(BeClosed())
|
||||
})
|
||||
|
||||
|
@ -154,51 +158,38 @@ var _ = Describe("Server", func() {
|
|||
sess.EXPECT().run().Do(func() { close(run) })
|
||||
err := serv.setupTLS()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
sessionHandler.EXPECT().Get(connID)
|
||||
sessionHandler.EXPECT().Add(connID, sess)
|
||||
serv.serverTLS.sessionChan <- tlsSession{
|
||||
connID: connID,
|
||||
sess: sess,
|
||||
}
|
||||
Eventually(func() packetHandler {
|
||||
serv.sessionsMutex.Lock()
|
||||
defer serv.sessionsMutex.Unlock()
|
||||
return serv.sessions[string(connID)]
|
||||
}).Should(Equal(sess))
|
||||
Eventually(run).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("only accepts one new TLS sessions for one connection ID", func() {
|
||||
connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
|
||||
run := make(chan struct{})
|
||||
sess := NewMockPacketHandler(mockCtrl)
|
||||
sess.EXPECT().run().Do(func() { close(run) })
|
||||
sess2 := NewMockPacketHandler(mockCtrl)
|
||||
err := serv.setupTLS()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
done := make(chan struct{})
|
||||
sessionHandler.EXPECT().Get(connID).Return(NewMockPacketHandler(mockCtrl), true).Do(func(protocol.ConnectionID) {
|
||||
close(done)
|
||||
})
|
||||
// don't EXPECT any calls to sessionHandler.Add
|
||||
serv.serverTLS.sessionChan <- tlsSession{
|
||||
connID: connID,
|
||||
sess: sess,
|
||||
}
|
||||
Eventually(func() packetHandler {
|
||||
serv.sessionsMutex.Lock()
|
||||
defer serv.sessionsMutex.Unlock()
|
||||
return serv.sessions[string(connID)]
|
||||
}).Should(Equal(sess))
|
||||
serv.serverTLS.sessionChan <- tlsSession{
|
||||
connID: connID,
|
||||
sess: sess2,
|
||||
}
|
||||
Consistently(func() packetHandler {
|
||||
serv.sessionsMutex.Lock()
|
||||
defer serv.sessionsMutex.Unlock()
|
||||
return serv.sessions[string(connID)]
|
||||
}).Should(Equal(sess))
|
||||
Eventually(run).Should(BeClosed())
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("accepts a session once the connection it is forward secure", func() {
|
||||
s := NewMockPacketHandler(mockCtrl)
|
||||
s.EXPECT().handlePacket(gomock.Any())
|
||||
s.EXPECT().run()
|
||||
run := make(chan struct{})
|
||||
s.EXPECT().run().Do(func() { close(run) })
|
||||
sessions = append(sessions, s)
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
|
@ -208,17 +199,19 @@ var _ = Describe("Server", func() {
|
|||
Expect(sess.(*mockSession).connID).To(Equal(connID))
|
||||
close(done)
|
||||
}()
|
||||
sessionHandler.EXPECT().Get(connID)
|
||||
sessionHandler.EXPECT().Add(connID, gomock.Any()).Do(func(_ protocol.ConnectionID, sess packetHandler) {
|
||||
Consistently(done).ShouldNot(BeClosed())
|
||||
sess.(*mockSession).runner.onHandshakeComplete(sess)
|
||||
})
|
||||
err := serv.handlePacket(nil, firstPacket)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(serv.sessions).To(HaveLen(1))
|
||||
Consistently(done).ShouldNot(BeClosed())
|
||||
sess := serv.sessions[string(connID)].(*mockSession)
|
||||
sess.runner.onHandshakeComplete(sess)
|
||||
Eventually(done).Should(BeClosed())
|
||||
Eventually(run).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("doesn't accept sessions that error during the handshake", func() {
|
||||
run := make(chan error)
|
||||
run := make(chan error, 1)
|
||||
sess := NewMockPacketHandler(mockCtrl)
|
||||
sess.EXPECT().handlePacket(gomock.Any())
|
||||
sess.EXPECT().run().DoAndReturn(func() error { return <-run })
|
||||
|
@ -229,79 +222,44 @@ var _ = Describe("Server", func() {
|
|||
serv.Accept()
|
||||
close(done)
|
||||
}()
|
||||
sessionHandler.EXPECT().Get(connID)
|
||||
sessionHandler.EXPECT().Add(connID, gomock.Any()).Do(func(_ protocol.ConnectionID, sess packetHandler) {
|
||||
run <- errors.New("handshake error")
|
||||
})
|
||||
err := serv.handlePacket(nil, firstPacket)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(serv.sessions).To(HaveLen(1))
|
||||
run <- errors.New("handshake error")
|
||||
serv.sessions[string(connID)].(*mockSession).runner.removeConnectionID(connID)
|
||||
Consistently(done).ShouldNot(BeClosed())
|
||||
// make the go routine return
|
||||
sessionHandler.EXPECT().Close()
|
||||
close(serv.errorChan)
|
||||
serv.Close()
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("assigns packets to existing sessions", func() {
|
||||
run := make(chan struct{})
|
||||
sess := NewMockPacketHandler(mockCtrl)
|
||||
sess.EXPECT().handlePacket(gomock.Any()).Times(2)
|
||||
sess.EXPECT().run().Do(func() { close(run) })
|
||||
sessions = append(sessions, sess)
|
||||
|
||||
err := serv.handlePacket(nil, firstPacket)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
err = serv.handlePacket(nil, []byte{0x08, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6, 0x01})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Eventually(run).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("deletes sessions", func() {
|
||||
serv.deleteClosedSessionsAfter = time.Second // make sure that the nil value for the closed session doesn't get deleted in this test
|
||||
serv.sessions[string(connID)] = &mockSession{}
|
||||
serv.removeConnection(connID)
|
||||
// The server should now have closed the session, leaving a nil value in the sessions map
|
||||
Consistently(func() map[string]packetHandler { return serv.sessions }).Should(HaveLen(1))
|
||||
Expect(serv.sessions[string(connID)]).To(BeNil())
|
||||
})
|
||||
|
||||
It("deletes nil session entries after a wait time", func() {
|
||||
serv.deleteClosedSessionsAfter = 25 * time.Millisecond
|
||||
serv.sessions[string(connID)] = &mockSession{}
|
||||
// make session.run() return
|
||||
serv.removeConnection(connID)
|
||||
Eventually(func() bool {
|
||||
serv.sessionsMutex.Lock()
|
||||
_, ok := serv.sessions[string(connID)]
|
||||
serv.sessionsMutex.Unlock()
|
||||
return ok
|
||||
}).Should(BeFalse())
|
||||
})
|
||||
|
||||
It("closes sessions and the connection when Close is called", func() {
|
||||
run := make(chan struct{})
|
||||
sess := NewMockPacketHandler(mockCtrl)
|
||||
sess.EXPECT().Close(nil)
|
||||
sess.EXPECT().handlePacket(gomock.Any())
|
||||
sess.EXPECT().run().Do(func() { close(run) })
|
||||
sessions = append(sessions, sess)
|
||||
|
||||
sessionHandler.EXPECT().Get(connID).Return(sess, true)
|
||||
err := serv.handlePacket(nil, []byte{0x08, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6, 0x01})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
It("closes the sessionHandler and the connection when Close is called", func() {
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
serv.serve()
|
||||
}()
|
||||
err := serv.handlePacket(nil, firstPacket)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Eventually(run).Should(BeClosed())
|
||||
// close the server
|
||||
sessionHandler.EXPECT().Close().AnyTimes()
|
||||
Expect(serv.Close()).To(Succeed())
|
||||
Expect(conn.closed).To(BeTrue())
|
||||
})
|
||||
|
||||
It("ignores packets for closed sessions", func() {
|
||||
serv.sessions[string(connID)] = nil
|
||||
err := serv.handlePacket(nil, []byte{0x08, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6, 0x01})
|
||||
sessionHandler.EXPECT().Get(connID).Return(nil, true)
|
||||
err := serv.handlePacket(nil, firstPacket)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(serv.sessions).To(HaveLen(1))
|
||||
Expect(serv.sessions[string(connID)]).To(BeNil())
|
||||
})
|
||||
|
||||
It("works if no quic.Config is given", func(done Done) {
|
||||
|
@ -327,49 +285,32 @@ var _ = Describe("Server", func() {
|
|||
Eventually(func() bool { return returned }).Should(BeTrue())
|
||||
})
|
||||
|
||||
It("errors when encountering a connection error", func(done Done) {
|
||||
It("errors when encountering a connection error", func() {
|
||||
testErr := errors.New("connection error")
|
||||
conn.readErr = testErr
|
||||
go serv.serve()
|
||||
_, err := serv.Accept()
|
||||
Expect(err).To(MatchError(testErr))
|
||||
Expect(serv.Close()).To(Succeed())
|
||||
close(done)
|
||||
}, 0.5)
|
||||
|
||||
It("closes all sessions when encountering a connection error", func() {
|
||||
sess := NewMockPacketHandler(mockCtrl)
|
||||
sess.EXPECT().Close(nil)
|
||||
serv.sessions[string(connID)] = sess
|
||||
|
||||
conn.readErr = errors.New("connection error")
|
||||
sessionHandler.EXPECT().Close()
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
serv.serve()
|
||||
close(done)
|
||||
}()
|
||||
Expect(serv.Close()).To(Succeed())
|
||||
_, err := serv.Accept()
|
||||
Expect(err).To(MatchError(testErr))
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("ignores delayed packets with mismatching versions", func() {
|
||||
run := make(chan struct{})
|
||||
sess := NewMockPacketHandler(mockCtrl)
|
||||
sess.EXPECT().handlePacket(gomock.Any()) // only called once
|
||||
sess.EXPECT().run().Do(func() { close(run) })
|
||||
sessions = append(sessions, sess)
|
||||
|
||||
err := serv.handlePacket(nil, firstPacket)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Eventually(run).Should(BeClosed())
|
||||
// don't EXPECT any handlePacket() calls to this session
|
||||
sessionHandler.EXPECT().Get(connID).Return(sess, true)
|
||||
|
||||
b := &bytes.Buffer{}
|
||||
// add an unsupported version
|
||||
data := []byte{0x09, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6}
|
||||
utils.BigEndian.WriteUint32(b, uint32(protocol.SupportedVersions[0]+1))
|
||||
data = append(append(data, b.Bytes()...), 0x01)
|
||||
err = serv.handlePacket(nil, data)
|
||||
err := serv.handlePacket(nil, data)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
// if we didn't ignore the packet, the server would try to send a version negotiation packet, which would make the test panic because it doesn't have a udpConn
|
||||
Expect(conn.dataWritten.Bytes()).To(BeEmpty())
|
||||
|
@ -397,21 +338,12 @@ var _ = Describe("Server", func() {
|
|||
})
|
||||
|
||||
It("cuts packets at the payload length", func() {
|
||||
run := make(chan struct{})
|
||||
sess := NewMockPacketHandler(mockCtrl)
|
||||
gomock.InOrder(
|
||||
sess.EXPECT().handlePacket(gomock.Any()), // first packet
|
||||
sess.EXPECT().handlePacket(gomock.Any()).Do(func(packet *receivedPacket) {
|
||||
Expect(packet.data).To(HaveLen(123))
|
||||
}),
|
||||
)
|
||||
sess.EXPECT().run().Do(func() { close(run) })
|
||||
sessions = append(sessions, sess)
|
||||
sess.EXPECT().handlePacket(gomock.Any()).Do(func(packet *receivedPacket) {
|
||||
Expect(packet.data).To(HaveLen(123))
|
||||
})
|
||||
|
||||
serv.supportsTLS = true
|
||||
err := serv.handlePacket(nil, firstPacket)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Eventually(run).Should(BeClosed())
|
||||
b := &bytes.Buffer{}
|
||||
hdr := &wire.Header{
|
||||
IsLongHeader: true,
|
||||
|
@ -422,7 +354,8 @@ var _ = Describe("Server", func() {
|
|||
Version: versionIETFFrames,
|
||||
}
|
||||
Expect(hdr.Write(b, protocol.PerspectiveClient, versionIETFFrames)).To(Succeed())
|
||||
err = serv.handlePacket(nil, append(b.Bytes(), make([]byte, 456)...))
|
||||
sessionHandler.EXPECT().Get(connID).Return(sess, true)
|
||||
err := serv.handlePacket(nil, append(b.Bytes(), make([]byte, 456)...))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
|
@ -443,18 +376,8 @@ var _ = Describe("Server", func() {
|
|||
})
|
||||
|
||||
It("ignores Public Resets", func() {
|
||||
run := make(chan struct{})
|
||||
sess := NewMockPacketHandler(mockCtrl)
|
||||
sess.EXPECT().handlePacket(gomock.Any()) // called only once
|
||||
sess.EXPECT().run().Do(func() { close(run) })
|
||||
sessions = append(sessions, sess)
|
||||
err := serv.handlePacket(nil, firstPacket)
|
||||
err := serv.handlePacket(nil, wire.WritePublicReset(connID, 1, 1337))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(serv.sessions).To(HaveLen(1))
|
||||
Eventually(run).Should(BeClosed())
|
||||
err = serv.handlePacket(nil, wire.WritePublicReset(connID, 1, 1337))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(serv.sessions).To(HaveLen(1))
|
||||
})
|
||||
|
||||
It("doesn't try to process a packet after sending a gQUIC Version Negotiation Packet", func() {
|
||||
|
@ -470,6 +393,7 @@ var _ = Describe("Server", func() {
|
|||
hdr.Write(b, protocol.PerspectiveClient, 13 /* not a valid QUIC version */)
|
||||
b.Write(bytes.Repeat([]byte{0}, protocol.MinClientHelloSize)) // add a fake CHLO
|
||||
serv.conn = conn
|
||||
sessionHandler.EXPECT().Get(connID)
|
||||
err := serv.handlePacket(nil, b.Bytes())
|
||||
Expect(conn.dataWritten.Bytes()).ToNot(BeEmpty())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
@ -487,6 +411,7 @@ var _ = Describe("Server", func() {
|
|||
hdr.Write(b, protocol.PerspectiveClient, 13 /* not a valid QUIC version */)
|
||||
b.Write(bytes.Repeat([]byte{0}, protocol.MinClientHelloSize-1)) // this packet is 1 byte too small
|
||||
serv.conn = conn
|
||||
sessionHandler.EXPECT().Get(connID)
|
||||
err := serv.handlePacket(udpAddr, b.Bytes())
|
||||
Expect(err).To(MatchError("dropping small packet with unknown version"))
|
||||
Expect(conn.dataWritten.Len()).Should(BeZero())
|
||||
|
@ -506,8 +431,7 @@ var _ = Describe("Server", func() {
|
|||
ln, err := Listen(conn, &tls.Config{}, &config)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
server := ln.(*server)
|
||||
Expect(server.deleteClosedSessionsAfter).To(Equal(protocol.ClosedSessionDeleteTimeout))
|
||||
Expect(server.sessions).ToNot(BeNil())
|
||||
Expect(server.sessionHandler).ToNot(BeNil())
|
||||
Expect(server.scfg).ToNot(BeNil())
|
||||
Expect(server.config.Versions).To(Equal(supportedVersions))
|
||||
Expect(server.config.HandshakeTimeout).To(Equal(1337 * time.Hour))
|
||||
|
@ -692,7 +616,6 @@ var _ = Describe("Server", func() {
|
|||
Eventually(func() int { return conn.dataWritten.Len() }).ShouldNot(BeZero())
|
||||
Expect(conn.dataWrittenTo).To(Equal(udpAddr))
|
||||
Expect(conn.dataWritten.Bytes()[0] & 0x02).ToNot(BeZero()) // check that the ResetFlag is set
|
||||
Expect(ln.(*server).sessions).To(BeEmpty())
|
||||
})
|
||||
})
|
||||
|
||||
|
|
74
session_map.go
Normal file
74
session_map.go
Normal file
|
@ -0,0 +1,74 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
type sessionMap struct {
|
||||
mutex sync.RWMutex
|
||||
|
||||
sessions map[string] /* string(ConnectionID)*/ packetHandler
|
||||
closed bool
|
||||
|
||||
deleteClosedSessionsAfter time.Duration
|
||||
}
|
||||
|
||||
var _ sessionHandler = &sessionMap{}
|
||||
|
||||
func newSessionMap() sessionHandler {
|
||||
return &sessionMap{
|
||||
sessions: make(map[string]packetHandler),
|
||||
deleteClosedSessionsAfter: protocol.ClosedSessionDeleteTimeout,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *sessionMap) Get(id protocol.ConnectionID) (packetHandler, bool) {
|
||||
h.mutex.RLock()
|
||||
sess, ok := h.sessions[string(id)]
|
||||
h.mutex.RUnlock()
|
||||
return sess, ok
|
||||
}
|
||||
|
||||
func (h *sessionMap) Add(id protocol.ConnectionID, sess packetHandler) {
|
||||
h.mutex.Lock()
|
||||
h.sessions[string(id)] = sess
|
||||
h.mutex.Unlock()
|
||||
}
|
||||
|
||||
func (h *sessionMap) Remove(id protocol.ConnectionID) {
|
||||
h.mutex.Lock()
|
||||
h.sessions[string(id)] = nil
|
||||
h.mutex.Unlock()
|
||||
|
||||
time.AfterFunc(h.deleteClosedSessionsAfter, func() {
|
||||
h.mutex.Lock()
|
||||
delete(h.sessions, string(id))
|
||||
h.mutex.Unlock()
|
||||
})
|
||||
}
|
||||
|
||||
func (h *sessionMap) Close() {
|
||||
h.mutex.Lock()
|
||||
if h.closed {
|
||||
h.mutex.Unlock()
|
||||
return
|
||||
}
|
||||
h.closed = true
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for _, session := range h.sessions {
|
||||
if session != nil {
|
||||
wg.Add(1)
|
||||
go func(sess packetHandler) {
|
||||
// session.Close() blocks until the CONNECTION_CLOSE has been sent and the run-loop has stopped
|
||||
_ = sess.Close(nil)
|
||||
wg.Done()
|
||||
}(session)
|
||||
}
|
||||
}
|
||||
h.mutex.Unlock()
|
||||
wg.Wait()
|
||||
}
|
56
session_map_test.go
Normal file
56
session_map_test.go
Normal file
|
@ -0,0 +1,56 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Session Handler", func() {
|
||||
var handler *sessionMap
|
||||
|
||||
BeforeEach(func() {
|
||||
handler = newSessionMap().(*sessionMap)
|
||||
})
|
||||
|
||||
It("adds and gets", func() {
|
||||
connID := protocol.ConnectionID{1, 2, 3, 4, 5}
|
||||
sess := &mockSession{}
|
||||
handler.Add(connID, sess)
|
||||
session, ok := handler.Get(connID)
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(session).To(Equal(sess))
|
||||
})
|
||||
|
||||
It("deletes", func() {
|
||||
connID := protocol.ConnectionID{1, 2, 3, 4, 5}
|
||||
handler.Add(connID, &mockSession{})
|
||||
handler.Remove(connID)
|
||||
session, ok := handler.Get(connID)
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(session).To(BeNil())
|
||||
})
|
||||
|
||||
It("deletes nil session entries after a wait time", func() {
|
||||
handler.deleteClosedSessionsAfter = 25 * time.Millisecond
|
||||
connID := protocol.ConnectionID{1, 2, 3, 4, 5}
|
||||
handler.Add(connID, &mockSession{})
|
||||
handler.Remove(connID)
|
||||
Eventually(func() bool {
|
||||
_, ok := handler.Get(connID)
|
||||
return ok
|
||||
}).Should(BeFalse())
|
||||
})
|
||||
|
||||
It("closes", func() {
|
||||
sess1 := NewMockPacketHandler(mockCtrl)
|
||||
sess1.EXPECT().Close(nil)
|
||||
sess2 := NewMockPacketHandler(mockCtrl)
|
||||
sess2.EXPECT().Close(nil)
|
||||
handler.Add(protocol.ConnectionID{1, 1, 1, 1}, sess1)
|
||||
handler.Add(protocol.ConnectionID{2, 2, 2, 2}, sess2)
|
||||
handler.Close()
|
||||
})
|
||||
})
|
Loading…
Add table
Add a link
Reference in a new issue