mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-05 21:27:35 +03:00
parent
150795d702
commit
ef6e8cf1b4
4 changed files with 69 additions and 9 deletions
22
server.go
22
server.go
|
@ -5,6 +5,7 @@ import (
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"errors"
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
|
"sync"
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/crypto"
|
"github.com/lucas-clemente/quic-go/crypto"
|
||||||
"github.com/lucas-clemente/quic-go/handshake"
|
"github.com/lucas-clemente/quic-go/handshake"
|
||||||
|
@ -25,11 +26,12 @@ type Server struct {
|
||||||
signer crypto.Signer
|
signer crypto.Signer
|
||||||
scfg *handshake.ServerConfig
|
scfg *handshake.ServerConfig
|
||||||
|
|
||||||
sessions map[protocol.ConnectionID]PacketHandler
|
sessions map[protocol.ConnectionID]PacketHandler
|
||||||
|
sessionsMutex sync.RWMutex
|
||||||
|
|
||||||
streamCallback StreamCallback
|
streamCallback StreamCallback
|
||||||
|
|
||||||
newSession func(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, streamCallback StreamCallback) PacketHandler
|
newSession func(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, streamCallback StreamCallback, closeCallback CloseCallback) PacketHandler
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewServer makes a new server
|
// NewServer makes a new server
|
||||||
|
@ -101,7 +103,10 @@ func (s *Server) handlePacket(conn *net.UDPConn, remoteAddr *net.UDPAddr, packet
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
s.sessionsMutex.RLock()
|
||||||
session, ok := s.sessions[publicHeader.ConnectionID]
|
session, ok := s.sessions[publicHeader.ConnectionID]
|
||||||
|
s.sessionsMutex.RUnlock()
|
||||||
|
|
||||||
if !ok {
|
if !ok {
|
||||||
utils.Infof("Serving new connection: %d from %v", publicHeader.ConnectionID, remoteAddr)
|
utils.Infof("Serving new connection: %d from %v", publicHeader.ConnectionID, remoteAddr)
|
||||||
session = s.newSession(
|
session = s.newSession(
|
||||||
|
@ -110,14 +115,27 @@ func (s *Server) handlePacket(conn *net.UDPConn, remoteAddr *net.UDPAddr, packet
|
||||||
publicHeader.ConnectionID,
|
publicHeader.ConnectionID,
|
||||||
s.scfg,
|
s.scfg,
|
||||||
s.streamCallback,
|
s.streamCallback,
|
||||||
|
s.closeCallback,
|
||||||
)
|
)
|
||||||
go session.Run()
|
go session.Run()
|
||||||
|
s.sessionsMutex.Lock()
|
||||||
s.sessions[publicHeader.ConnectionID] = session
|
s.sessions[publicHeader.ConnectionID] = session
|
||||||
|
s.sessionsMutex.Unlock()
|
||||||
|
}
|
||||||
|
if session == nil {
|
||||||
|
// Late packet for closed session
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
session.HandlePacket(remoteAddr, publicHeader, r)
|
session.HandlePacket(remoteAddr, publicHeader, r)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Server) closeCallback(session *Session) {
|
||||||
|
s.sessionsMutex.Lock()
|
||||||
|
s.sessions[session.connectionID] = nil
|
||||||
|
s.sessionsMutex.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
func composeVersionNegotiation(connectionID protocol.ConnectionID) []byte {
|
func composeVersionNegotiation(connectionID protocol.ConnectionID) []byte {
|
||||||
fullReply := &bytes.Buffer{}
|
fullReply := &bytes.Buffer{}
|
||||||
responsePublicHeader := PublicHeader{
|
responsePublicHeader := PublicHeader{
|
||||||
|
|
|
@ -5,6 +5,7 @@ import (
|
||||||
"net"
|
"net"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/lucas-clemente/quic-go/crypto"
|
||||||
"github.com/lucas-clemente/quic-go/handshake"
|
"github.com/lucas-clemente/quic-go/handshake"
|
||||||
"github.com/lucas-clemente/quic-go/protocol"
|
"github.com/lucas-clemente/quic-go/protocol"
|
||||||
"github.com/lucas-clemente/quic-go/testdata"
|
"github.com/lucas-clemente/quic-go/testdata"
|
||||||
|
@ -25,7 +26,7 @@ func (s *mockSession) HandlePacket(addr interface{}, publicHeader *PublicHeader,
|
||||||
func (s *mockSession) Run() {
|
func (s *mockSession) Run() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func newMockSession(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, streamCallback StreamCallback) PacketHandler {
|
func newMockSession(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, streamCallback StreamCallback, closeCallback CloseCallback) PacketHandler {
|
||||||
return &mockSession{
|
return &mockSession{
|
||||||
connectionID: connectionID,
|
connectionID: connectionID,
|
||||||
}
|
}
|
||||||
|
@ -120,4 +121,35 @@ var _ = Describe("Server", func() {
|
||||||
err = server.ListenAndServe("localhost:13370")
|
err = server.ListenAndServe("localhost:13370")
|
||||||
Expect(err).To(HaveOccurred())
|
Expect(err).To(HaveOccurred())
|
||||||
})
|
})
|
||||||
|
|
||||||
|
It("closes and deletes sessions", func() {
|
||||||
|
server, err := NewServer(testdata.GetTLSConfig(), nil)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
go func() {
|
||||||
|
defer GinkgoRecover()
|
||||||
|
err := server.ListenAndServe("localhost:13370")
|
||||||
|
Expect(err).To(HaveOccurred())
|
||||||
|
}()
|
||||||
|
|
||||||
|
addr, err := net.ResolveUDPAddr("udp", "localhost:13370")
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
|
// Send an invalid packet
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
conn, err := net.DialUDP("udp", nil, addr)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
pheader := []byte{0x0d, 0xf6, 0x19, 0x86, 0x66, 0x9b, 0x9f, 0xfa, 0x4c, 0x51, 0x30, 0x33, 0x32, 0x01}
|
||||||
|
_, err = conn.Write(append(pheader, (&crypto.NullAEAD{}).Seal(0, pheader, nil)...))
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
|
||||||
|
// The server should now have closed the session, leaving a nil value in the sessions map
|
||||||
|
Expect(server.sessions).To(HaveLen(1))
|
||||||
|
// Expect(server.sessions[0x4cfa9f9b668619f6]).To(BeNil())
|
||||||
|
Expect(server.sessions[0x4cfa9f9b668619f6]).To(BeNil())
|
||||||
|
|
||||||
|
err = server.Close()
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
})
|
||||||
})
|
})
|
||||||
|
|
|
@ -30,11 +30,15 @@ var (
|
||||||
// StreamCallback gets a stream frame and returns a reply frame
|
// StreamCallback gets a stream frame and returns a reply frame
|
||||||
type StreamCallback func(*Session, utils.Stream)
|
type StreamCallback func(*Session, utils.Stream)
|
||||||
|
|
||||||
|
// CloseCallback is called when a session is closed
|
||||||
|
type CloseCallback func(*Session)
|
||||||
|
|
||||||
// A Session is a QUIC session
|
// A Session is a QUIC session
|
||||||
type Session struct {
|
type Session struct {
|
||||||
connectionID protocol.ConnectionID
|
connectionID protocol.ConnectionID
|
||||||
|
|
||||||
streamCallback StreamCallback
|
streamCallback StreamCallback
|
||||||
|
closeCallback CloseCallback
|
||||||
|
|
||||||
conn connection
|
conn connection
|
||||||
|
|
||||||
|
@ -63,12 +67,13 @@ type Session struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewSession makes a new session
|
// NewSession makes a new session
|
||||||
func NewSession(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, streamCallback StreamCallback) PacketHandler {
|
func NewSession(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, streamCallback StreamCallback, closeCallback CloseCallback) PacketHandler {
|
||||||
stopWaitingManager := ackhandler.NewStopWaitingManager()
|
stopWaitingManager := ackhandler.NewStopWaitingManager()
|
||||||
session := &Session{
|
session := &Session{
|
||||||
connectionID: connectionID,
|
connectionID: connectionID,
|
||||||
conn: conn,
|
conn: conn,
|
||||||
streamCallback: streamCallback,
|
streamCallback: streamCallback,
|
||||||
|
closeCallback: closeCallback,
|
||||||
streams: make(map[protocol.StreamID]*stream),
|
streams: make(map[protocol.StreamID]*stream),
|
||||||
sentPacketHandler: ackhandler.NewSentPacketHandler(stopWaitingManager),
|
sentPacketHandler: ackhandler.NewSentPacketHandler(stopWaitingManager),
|
||||||
receivedPacketHandler: ackhandler.NewReceivedPacketHandler(),
|
receivedPacketHandler: ackhandler.NewReceivedPacketHandler(),
|
||||||
|
@ -270,6 +275,8 @@ func (s *Session) Close(e error, sendConnectionClose bool) error {
|
||||||
s.closed = true
|
s.closed = true
|
||||||
s.closeChan <- struct{}{}
|
s.closeChan <- struct{}{}
|
||||||
|
|
||||||
|
s.closeCallback(s)
|
||||||
|
|
||||||
if !sendConnectionClose {
|
if !sendConnectionClose {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -200,6 +200,7 @@ var _ = Describe("Session", func() {
|
||||||
Context("closing", func() {
|
Context("closing", func() {
|
||||||
var (
|
var (
|
||||||
nGoRoutinesBefore int
|
nGoRoutinesBefore int
|
||||||
|
closed bool
|
||||||
)
|
)
|
||||||
|
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
|
@ -208,13 +209,14 @@ var _ = Describe("Session", func() {
|
||||||
signer, err := crypto.NewRSASigner(testdata.GetTLSConfig())
|
signer, err := crypto.NewRSASigner(testdata.GetTLSConfig())
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
scfg := handshake.NewServerConfig(crypto.NewCurve25519KEX(), signer)
|
scfg := handshake.NewServerConfig(crypto.NewCurve25519KEX(), signer)
|
||||||
session = NewSession(conn, 0, 0, scfg, nil).(*Session)
|
session = NewSession(conn, 0, 0, scfg, nil, func(*Session) { closed = true }).(*Session)
|
||||||
go session.Run()
|
go session.Run()
|
||||||
Expect(runtime.NumGoroutine()).To(Equal(nGoRoutinesBefore + 2))
|
Expect(runtime.NumGoroutine()).To(Equal(nGoRoutinesBefore + 2))
|
||||||
})
|
})
|
||||||
|
|
||||||
It("shuts down without error", func() {
|
It("shuts down without error", func() {
|
||||||
session.Close(nil, true)
|
session.Close(nil, true)
|
||||||
|
Expect(closed).To(BeTrue())
|
||||||
time.Sleep(1 * time.Millisecond)
|
time.Sleep(1 * time.Millisecond)
|
||||||
Expect(runtime.NumGoroutine()).To(Equal(nGoRoutinesBefore))
|
Expect(runtime.NumGoroutine()).To(Equal(nGoRoutinesBefore))
|
||||||
})
|
})
|
||||||
|
@ -224,6 +226,7 @@ var _ = Describe("Session", func() {
|
||||||
s, err := session.NewStream(5)
|
s, err := session.NewStream(5)
|
||||||
Expect(err).NotTo(HaveOccurred())
|
Expect(err).NotTo(HaveOccurred())
|
||||||
session.Close(testErr, true)
|
session.Close(testErr, true)
|
||||||
|
Expect(closed).To(BeTrue())
|
||||||
time.Sleep(1 * time.Millisecond)
|
time.Sleep(1 * time.Millisecond)
|
||||||
Expect(runtime.NumGoroutine()).To(Equal(nGoRoutinesBefore))
|
Expect(runtime.NumGoroutine()).To(Equal(nGoRoutinesBefore))
|
||||||
n, err := s.Read([]byte{0})
|
n, err := s.Read([]byte{0})
|
||||||
|
@ -240,7 +243,7 @@ var _ = Describe("Session", func() {
|
||||||
signer, err := crypto.NewRSASigner(testdata.GetTLSConfig())
|
signer, err := crypto.NewRSASigner(testdata.GetTLSConfig())
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
scfg := handshake.NewServerConfig(crypto.NewCurve25519KEX(), signer)
|
scfg := handshake.NewServerConfig(crypto.NewCurve25519KEX(), signer)
|
||||||
session = NewSession(conn, 0, 0, scfg, nil).(*Session)
|
session = NewSession(conn, 0, 0, scfg, nil, nil).(*Session)
|
||||||
})
|
})
|
||||||
|
|
||||||
It("sends ack frames", func() {
|
It("sends ack frames", func() {
|
||||||
|
@ -277,7 +280,7 @@ var _ = Describe("Session", func() {
|
||||||
signer, err := crypto.NewRSASigner(testdata.GetTLSConfig())
|
signer, err := crypto.NewRSASigner(testdata.GetTLSConfig())
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
scfg := handshake.NewServerConfig(crypto.NewCurve25519KEX(), signer)
|
scfg := handshake.NewServerConfig(crypto.NewCurve25519KEX(), signer)
|
||||||
session = NewSession(conn, 0, 0, scfg, nil).(*Session)
|
session = NewSession(conn, 0, 0, scfg, nil, func(*Session) {}).(*Session)
|
||||||
})
|
})
|
||||||
|
|
||||||
It("sends after queuing a stream frame", func() {
|
It("sends after queuing a stream frame", func() {
|
||||||
|
@ -308,7 +311,7 @@ var _ = Describe("Session", func() {
|
||||||
signer, err := crypto.NewRSASigner(testdata.GetTLSConfig())
|
signer, err := crypto.NewRSASigner(testdata.GetTLSConfig())
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
scfg := handshake.NewServerConfig(crypto.NewCurve25519KEX(), signer)
|
scfg := handshake.NewServerConfig(crypto.NewCurve25519KEX(), signer)
|
||||||
session = NewSession(conn, 0, 0, scfg, nil).(*Session)
|
session = NewSession(conn, 0, 0, scfg, nil, func(*Session) {}).(*Session)
|
||||||
s, err := session.NewStream(3)
|
s, err := session.NewStream(3)
|
||||||
Expect(err).NotTo(HaveOccurred())
|
Expect(err).NotTo(HaveOccurred())
|
||||||
err = session.handleStreamFrame(&frames.StreamFrame{
|
err = session.handleStreamFrame(&frames.StreamFrame{
|
||||||
|
@ -327,7 +330,7 @@ var _ = Describe("Session", func() {
|
||||||
signer, err := crypto.NewRSASigner(testdata.GetTLSConfig())
|
signer, err := crypto.NewRSASigner(testdata.GetTLSConfig())
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
scfg := handshake.NewServerConfig(crypto.NewCurve25519KEX(), signer)
|
scfg := handshake.NewServerConfig(crypto.NewCurve25519KEX(), signer)
|
||||||
session = NewSession(conn, 0, 0, scfg, nil).(*Session)
|
session = NewSession(conn, 0, 0, scfg, nil, nil).(*Session)
|
||||||
hdr := &PublicHeader{
|
hdr := &PublicHeader{
|
||||||
PacketNumber: 42,
|
PacketNumber: 42,
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue