make server delete sessions when they are closed

fixes #46
This commit is contained in:
Lucas Clemente 2016-05-04 16:34:08 +02:00
parent 150795d702
commit ef6e8cf1b4
4 changed files with 69 additions and 9 deletions

View file

@ -5,6 +5,7 @@ import (
"crypto/tls" "crypto/tls"
"errors" "errors"
"net" "net"
"sync"
"github.com/lucas-clemente/quic-go/crypto" "github.com/lucas-clemente/quic-go/crypto"
"github.com/lucas-clemente/quic-go/handshake" "github.com/lucas-clemente/quic-go/handshake"
@ -25,11 +26,12 @@ type Server struct {
signer crypto.Signer signer crypto.Signer
scfg *handshake.ServerConfig scfg *handshake.ServerConfig
sessions map[protocol.ConnectionID]PacketHandler sessions map[protocol.ConnectionID]PacketHandler
sessionsMutex sync.RWMutex
streamCallback StreamCallback streamCallback StreamCallback
newSession func(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, streamCallback StreamCallback) PacketHandler newSession func(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, streamCallback StreamCallback, closeCallback CloseCallback) PacketHandler
} }
// NewServer makes a new server // NewServer makes a new server
@ -101,7 +103,10 @@ func (s *Server) handlePacket(conn *net.UDPConn, remoteAddr *net.UDPAddr, packet
return nil return nil
} }
s.sessionsMutex.RLock()
session, ok := s.sessions[publicHeader.ConnectionID] session, ok := s.sessions[publicHeader.ConnectionID]
s.sessionsMutex.RUnlock()
if !ok { if !ok {
utils.Infof("Serving new connection: %d from %v", publicHeader.ConnectionID, remoteAddr) utils.Infof("Serving new connection: %d from %v", publicHeader.ConnectionID, remoteAddr)
session = s.newSession( session = s.newSession(
@ -110,14 +115,27 @@ func (s *Server) handlePacket(conn *net.UDPConn, remoteAddr *net.UDPAddr, packet
publicHeader.ConnectionID, publicHeader.ConnectionID,
s.scfg, s.scfg,
s.streamCallback, s.streamCallback,
s.closeCallback,
) )
go session.Run() go session.Run()
s.sessionsMutex.Lock()
s.sessions[publicHeader.ConnectionID] = session s.sessions[publicHeader.ConnectionID] = session
s.sessionsMutex.Unlock()
}
if session == nil {
// Late packet for closed session
return nil
} }
session.HandlePacket(remoteAddr, publicHeader, r) session.HandlePacket(remoteAddr, publicHeader, r)
return nil return nil
} }
func (s *Server) closeCallback(session *Session) {
s.sessionsMutex.Lock()
s.sessions[session.connectionID] = nil
s.sessionsMutex.Unlock()
}
func composeVersionNegotiation(connectionID protocol.ConnectionID) []byte { func composeVersionNegotiation(connectionID protocol.ConnectionID) []byte {
fullReply := &bytes.Buffer{} fullReply := &bytes.Buffer{}
responsePublicHeader := PublicHeader{ responsePublicHeader := PublicHeader{

View file

@ -5,6 +5,7 @@ import (
"net" "net"
"time" "time"
"github.com/lucas-clemente/quic-go/crypto"
"github.com/lucas-clemente/quic-go/handshake" "github.com/lucas-clemente/quic-go/handshake"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/testdata" "github.com/lucas-clemente/quic-go/testdata"
@ -25,7 +26,7 @@ func (s *mockSession) HandlePacket(addr interface{}, publicHeader *PublicHeader,
func (s *mockSession) Run() { func (s *mockSession) Run() {
} }
func newMockSession(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, streamCallback StreamCallback) PacketHandler { func newMockSession(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, streamCallback StreamCallback, closeCallback CloseCallback) PacketHandler {
return &mockSession{ return &mockSession{
connectionID: connectionID, connectionID: connectionID,
} }
@ -120,4 +121,35 @@ var _ = Describe("Server", func() {
err = server.ListenAndServe("localhost:13370") err = server.ListenAndServe("localhost:13370")
Expect(err).To(HaveOccurred()) Expect(err).To(HaveOccurred())
}) })
It("closes and deletes sessions", func() {
server, err := NewServer(testdata.GetTLSConfig(), nil)
Expect(err).ToNot(HaveOccurred())
go func() {
defer GinkgoRecover()
err := server.ListenAndServe("localhost:13370")
Expect(err).To(HaveOccurred())
}()
addr, err := net.ResolveUDPAddr("udp", "localhost:13370")
Expect(err).ToNot(HaveOccurred())
// Send an invalid packet
time.Sleep(10 * time.Millisecond)
conn, err := net.DialUDP("udp", nil, addr)
Expect(err).ToNot(HaveOccurred())
pheader := []byte{0x0d, 0xf6, 0x19, 0x86, 0x66, 0x9b, 0x9f, 0xfa, 0x4c, 0x51, 0x30, 0x33, 0x32, 0x01}
_, err = conn.Write(append(pheader, (&crypto.NullAEAD{}).Seal(0, pheader, nil)...))
Expect(err).ToNot(HaveOccurred())
time.Sleep(10 * time.Millisecond)
// The server should now have closed the session, leaving a nil value in the sessions map
Expect(server.sessions).To(HaveLen(1))
// Expect(server.sessions[0x4cfa9f9b668619f6]).To(BeNil())
Expect(server.sessions[0x4cfa9f9b668619f6]).To(BeNil())
err = server.Close()
Expect(err).ToNot(HaveOccurred())
})
}) })

View file

@ -30,11 +30,15 @@ var (
// StreamCallback gets a stream frame and returns a reply frame // StreamCallback gets a stream frame and returns a reply frame
type StreamCallback func(*Session, utils.Stream) type StreamCallback func(*Session, utils.Stream)
// CloseCallback is called when a session is closed
type CloseCallback func(*Session)
// A Session is a QUIC session // A Session is a QUIC session
type Session struct { type Session struct {
connectionID protocol.ConnectionID connectionID protocol.ConnectionID
streamCallback StreamCallback streamCallback StreamCallback
closeCallback CloseCallback
conn connection conn connection
@ -63,12 +67,13 @@ type Session struct {
} }
// NewSession makes a new session // NewSession makes a new session
func NewSession(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, streamCallback StreamCallback) PacketHandler { func NewSession(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, streamCallback StreamCallback, closeCallback CloseCallback) PacketHandler {
stopWaitingManager := ackhandler.NewStopWaitingManager() stopWaitingManager := ackhandler.NewStopWaitingManager()
session := &Session{ session := &Session{
connectionID: connectionID, connectionID: connectionID,
conn: conn, conn: conn,
streamCallback: streamCallback, streamCallback: streamCallback,
closeCallback: closeCallback,
streams: make(map[protocol.StreamID]*stream), streams: make(map[protocol.StreamID]*stream),
sentPacketHandler: ackhandler.NewSentPacketHandler(stopWaitingManager), sentPacketHandler: ackhandler.NewSentPacketHandler(stopWaitingManager),
receivedPacketHandler: ackhandler.NewReceivedPacketHandler(), receivedPacketHandler: ackhandler.NewReceivedPacketHandler(),
@ -270,6 +275,8 @@ func (s *Session) Close(e error, sendConnectionClose bool) error {
s.closed = true s.closed = true
s.closeChan <- struct{}{} s.closeChan <- struct{}{}
s.closeCallback(s)
if !sendConnectionClose { if !sendConnectionClose {
return nil return nil
} }

View file

@ -200,6 +200,7 @@ var _ = Describe("Session", func() {
Context("closing", func() { Context("closing", func() {
var ( var (
nGoRoutinesBefore int nGoRoutinesBefore int
closed bool
) )
BeforeEach(func() { BeforeEach(func() {
@ -208,13 +209,14 @@ var _ = Describe("Session", func() {
signer, err := crypto.NewRSASigner(testdata.GetTLSConfig()) signer, err := crypto.NewRSASigner(testdata.GetTLSConfig())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
scfg := handshake.NewServerConfig(crypto.NewCurve25519KEX(), signer) scfg := handshake.NewServerConfig(crypto.NewCurve25519KEX(), signer)
session = NewSession(conn, 0, 0, scfg, nil).(*Session) session = NewSession(conn, 0, 0, scfg, nil, func(*Session) { closed = true }).(*Session)
go session.Run() go session.Run()
Expect(runtime.NumGoroutine()).To(Equal(nGoRoutinesBefore + 2)) Expect(runtime.NumGoroutine()).To(Equal(nGoRoutinesBefore + 2))
}) })
It("shuts down without error", func() { It("shuts down without error", func() {
session.Close(nil, true) session.Close(nil, true)
Expect(closed).To(BeTrue())
time.Sleep(1 * time.Millisecond) time.Sleep(1 * time.Millisecond)
Expect(runtime.NumGoroutine()).To(Equal(nGoRoutinesBefore)) Expect(runtime.NumGoroutine()).To(Equal(nGoRoutinesBefore))
}) })
@ -224,6 +226,7 @@ var _ = Describe("Session", func() {
s, err := session.NewStream(5) s, err := session.NewStream(5)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
session.Close(testErr, true) session.Close(testErr, true)
Expect(closed).To(BeTrue())
time.Sleep(1 * time.Millisecond) time.Sleep(1 * time.Millisecond)
Expect(runtime.NumGoroutine()).To(Equal(nGoRoutinesBefore)) Expect(runtime.NumGoroutine()).To(Equal(nGoRoutinesBefore))
n, err := s.Read([]byte{0}) n, err := s.Read([]byte{0})
@ -240,7 +243,7 @@ var _ = Describe("Session", func() {
signer, err := crypto.NewRSASigner(testdata.GetTLSConfig()) signer, err := crypto.NewRSASigner(testdata.GetTLSConfig())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
scfg := handshake.NewServerConfig(crypto.NewCurve25519KEX(), signer) scfg := handshake.NewServerConfig(crypto.NewCurve25519KEX(), signer)
session = NewSession(conn, 0, 0, scfg, nil).(*Session) session = NewSession(conn, 0, 0, scfg, nil, nil).(*Session)
}) })
It("sends ack frames", func() { It("sends ack frames", func() {
@ -277,7 +280,7 @@ var _ = Describe("Session", func() {
signer, err := crypto.NewRSASigner(testdata.GetTLSConfig()) signer, err := crypto.NewRSASigner(testdata.GetTLSConfig())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
scfg := handshake.NewServerConfig(crypto.NewCurve25519KEX(), signer) scfg := handshake.NewServerConfig(crypto.NewCurve25519KEX(), signer)
session = NewSession(conn, 0, 0, scfg, nil).(*Session) session = NewSession(conn, 0, 0, scfg, nil, func(*Session) {}).(*Session)
}) })
It("sends after queuing a stream frame", func() { It("sends after queuing a stream frame", func() {
@ -308,7 +311,7 @@ var _ = Describe("Session", func() {
signer, err := crypto.NewRSASigner(testdata.GetTLSConfig()) signer, err := crypto.NewRSASigner(testdata.GetTLSConfig())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
scfg := handshake.NewServerConfig(crypto.NewCurve25519KEX(), signer) scfg := handshake.NewServerConfig(crypto.NewCurve25519KEX(), signer)
session = NewSession(conn, 0, 0, scfg, nil).(*Session) session = NewSession(conn, 0, 0, scfg, nil, func(*Session) {}).(*Session)
s, err := session.NewStream(3) s, err := session.NewStream(3)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
err = session.handleStreamFrame(&frames.StreamFrame{ err = session.handleStreamFrame(&frames.StreamFrame{
@ -327,7 +330,7 @@ var _ = Describe("Session", func() {
signer, err := crypto.NewRSASigner(testdata.GetTLSConfig()) signer, err := crypto.NewRSASigner(testdata.GetTLSConfig())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
scfg := handshake.NewServerConfig(crypto.NewCurve25519KEX(), signer) scfg := handshake.NewServerConfig(crypto.NewCurve25519KEX(), signer)
session = NewSession(conn, 0, 0, scfg, nil).(*Session) session = NewSession(conn, 0, 0, scfg, nil, nil).(*Session)
hdr := &PublicHeader{ hdr := &PublicHeader{
PacketNumber: 42, PacketNumber: 42,
} }