mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-04 12:47:36 +03:00
implement a Listener.Accept() method
This commit is contained in:
parent
0bd3b61e6a
commit
30a0211243
6 changed files with 104 additions and 121 deletions
|
@ -6,6 +6,7 @@ import (
|
|||
"fmt"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/protocol"
|
||||
|
@ -27,33 +28,32 @@ var _ = Describe("Benchmarks", func() {
|
|||
Measure("transferring a file", func(b Benchmarker) {
|
||||
rand.Read(data) // no need to check for an error. math.Rand.Read never errors
|
||||
|
||||
// start the server
|
||||
sconf := &Config{
|
||||
TLSConfig: testdata.GetTLSConfig(),
|
||||
ConnState: func(sess Session, cs ConnState) {
|
||||
if cs != ConnStateForwardSecure {
|
||||
return
|
||||
}
|
||||
var ln Listener
|
||||
|
||||
defer GinkgoRecover()
|
||||
str, err := sess.OpenStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = str.Write(data)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
err = str.Close()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
},
|
||||
}
|
||||
ln, err := ListenAddr("localhost:0", sconf)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
// Serve will error as soon as ln is closed. Ignore all errors here
|
||||
go ln.Serve()
|
||||
serverAddr := make(chan net.Addr)
|
||||
// start the server
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
var err error
|
||||
ln, err = ListenAddr("localhost:0", &Config{TLSConfig: testdata.GetTLSConfig()})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
serverAddr <- ln.Addr()
|
||||
sess, err := ln.Accept()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
str, err := sess.OpenStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = str.Write(data)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
err = str.Close()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}()
|
||||
|
||||
// start the client
|
||||
cconf := &Config{
|
||||
conf := &Config{
|
||||
TLSConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
}
|
||||
sess, err := DialAddr(ln.Addr().String(), cconf)
|
||||
addr := <-serverAddr
|
||||
sess, err := DialAddr(addr.String(), conf)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
str, err := sess.AcceptStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
|
|
@ -33,26 +33,22 @@ func main() {
|
|||
func echoServer() error {
|
||||
cfgServer := &quic.Config{
|
||||
TLSConfig: generateTLSConfig(),
|
||||
ConnState: func(sess quic.Session, cs quic.ConnState) {
|
||||
// Ignore unless the handshake is finished
|
||||
if cs != quic.ConnStateForwardSecure {
|
||||
return
|
||||
}
|
||||
go func() {
|
||||
stream, err := sess.AcceptStream()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
// Echo through the loggingWriter
|
||||
go io.Copy(loggingWriter{stream}, stream)
|
||||
}()
|
||||
},
|
||||
}
|
||||
listener, err := quic.ListenAddr(addr, cfgServer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return listener.Serve()
|
||||
sess, err := listener.Accept()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
stream, err := sess.AcceptStream()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
// Echo through the loggingWriter
|
||||
_, err = io.Copy(loggingWriter{stream}, stream)
|
||||
return err
|
||||
}
|
||||
|
||||
func clientMain() error {
|
||||
|
|
|
@ -85,13 +85,7 @@ func (s *Server) serveImpl(tlsConfig *tls.Config, conn *net.UDPConn) error {
|
|||
|
||||
config := quic.Config{
|
||||
TLSConfig: tlsConfig,
|
||||
ConnState: func(session quic.Session, connState quic.ConnState) {
|
||||
sess := session.(streamCreator)
|
||||
if connState == quic.ConnStateVersionNegotiated {
|
||||
s.handleHeaderStream(sess)
|
||||
}
|
||||
},
|
||||
Versions: protocol.SupportedVersions,
|
||||
Versions: protocol.SupportedVersions,
|
||||
}
|
||||
|
||||
var ln quic.Listener
|
||||
|
@ -107,7 +101,14 @@ func (s *Server) serveImpl(tlsConfig *tls.Config, conn *net.UDPConn) error {
|
|||
}
|
||||
s.listener = ln
|
||||
s.listenerMutex.Unlock()
|
||||
return ln.Serve()
|
||||
|
||||
for {
|
||||
sess, err := ln.Accept()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
go s.handleHeaderStream(sess.(streamCreator))
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleHeaderStream(session streamCreator) {
|
||||
|
|
|
@ -79,6 +79,6 @@ type Listener interface {
|
|||
Close() error
|
||||
// Addr returns the local network addr that the server is listening on.
|
||||
Addr() net.Addr
|
||||
// Serve starts the main server loop, and blocks until a network error occurs or the server is closed.
|
||||
Serve() error
|
||||
// Accept returns new sessions. It should be called in a loop.
|
||||
Accept() (Session, error)
|
||||
}
|
||||
|
|
44
server.go
44
server.go
|
@ -34,6 +34,10 @@ type server struct {
|
|||
sessionsMutex sync.RWMutex
|
||||
deleteClosedSessionsAfter time.Duration
|
||||
|
||||
serverError error
|
||||
sessionQueue chan Session
|
||||
errorChan chan struct{}
|
||||
|
||||
newSession func(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, cryptoChangeCallback cryptoChangeCallback, config *Config) (packetHandler, error)
|
||||
}
|
||||
|
||||
|
@ -66,7 +70,7 @@ func Listen(conn net.PacketConn, config *Config) (Listener, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
return &server{
|
||||
s := &server{
|
||||
conn: conn,
|
||||
config: populateServerConfig(config),
|
||||
certChain: certChain,
|
||||
|
@ -74,7 +78,11 @@ func Listen(conn net.PacketConn, config *Config) (Listener, error) {
|
|||
sessions: map[protocol.ConnectionID]packetHandler{},
|
||||
newSession: newSession,
|
||||
deleteClosedSessionsAfter: protocol.ClosedSessionDeleteTimeout,
|
||||
}, nil
|
||||
sessionQueue: make(chan Session, 5),
|
||||
errorChan: make(chan struct{}),
|
||||
}
|
||||
go s.serve()
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func populateServerConfig(config *Config) *Config {
|
||||
|
@ -85,13 +93,12 @@ func populateServerConfig(config *Config) *Config {
|
|||
|
||||
return &Config{
|
||||
TLSConfig: config.TLSConfig,
|
||||
ConnState: config.ConnState,
|
||||
Versions: versions,
|
||||
}
|
||||
}
|
||||
|
||||
// Listen listens on an existing PacketConn
|
||||
func (s *server) Serve() error {
|
||||
// serve listens on an existing PacketConn
|
||||
func (s *server) serve() {
|
||||
for {
|
||||
data := getPacketBuffer()
|
||||
data = data[:protocol.MaxReceivePacketSize]
|
||||
|
@ -99,14 +106,27 @@ func (s *server) Serve() error {
|
|||
// If it does, we only read a truncated packet, which will then end up undecryptable
|
||||
n, remoteAddr, err := s.conn.ReadFrom(data)
|
||||
if err != nil {
|
||||
s.serverError = err
|
||||
close(s.errorChan)
|
||||
_ = s.Close()
|
||||
return err
|
||||
return
|
||||
}
|
||||
data = data[:n]
|
||||
if err := s.handlePacket(s.conn, remoteAddr, data); err != nil {
|
||||
utils.Errorf("error handling packet: %s", err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Accept returns newly openend sessions
|
||||
func (s *server) Accept() (Session, error) {
|
||||
var sess Session
|
||||
select {
|
||||
case sess = <-s.sessionQueue:
|
||||
return sess, nil
|
||||
case <-s.errorChan:
|
||||
return nil, s.serverError
|
||||
}
|
||||
}
|
||||
|
||||
// Close the server
|
||||
|
@ -212,10 +232,6 @@ func (s *server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet
|
|||
|
||||
s.removeConnection(hdr.ConnectionID)
|
||||
}()
|
||||
|
||||
if s.config.ConnState != nil {
|
||||
go s.config.ConnState(session, ConnStateVersionNegotiated)
|
||||
}
|
||||
}
|
||||
if session == nil {
|
||||
// Late packet for closed session
|
||||
|
@ -231,14 +247,8 @@ func (s *server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet
|
|||
}
|
||||
|
||||
func (s *server) cryptoChangeCallback(session Session, isForwardSecure bool) {
|
||||
var state ConnState
|
||||
if isForwardSecure {
|
||||
state = ConnStateForwardSecure
|
||||
} else {
|
||||
state = ConnStateSecure
|
||||
}
|
||||
if s.config.ConnState != nil {
|
||||
go s.config.ConnState(session, state)
|
||||
s.sessionQueue <- session
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -87,10 +87,12 @@ var _ = Describe("Server", func() {
|
|||
|
||||
BeforeEach(func() {
|
||||
serv = &server{
|
||||
sessions: make(map[protocol.ConnectionID]packetHandler),
|
||||
newSession: newMockSession,
|
||||
conn: conn,
|
||||
config: config,
|
||||
sessions: make(map[protocol.ConnectionID]packetHandler),
|
||||
newSession: newMockSession,
|
||||
conn: conn,
|
||||
config: config,
|
||||
sessionQueue: make(chan Session, 5),
|
||||
errorChan: make(chan struct{}),
|
||||
}
|
||||
b := &bytes.Buffer{}
|
||||
utils.WriteUint32(b, protocol.VersionNumberToTag(protocol.SupportedVersions[0]))
|
||||
|
@ -115,56 +117,29 @@ var _ = Describe("Server", func() {
|
|||
})
|
||||
|
||||
It("creates new sessions", func() {
|
||||
var connStateCalled bool
|
||||
var connStateStatus ConnState
|
||||
var connStateSession Session
|
||||
config.ConnState = func(s Session, state ConnState) {
|
||||
connStateStatus = state
|
||||
connStateSession = s
|
||||
connStateCalled = true
|
||||
}
|
||||
err := serv.handlePacket(nil, nil, firstPacket)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(serv.sessions).To(HaveLen(1))
|
||||
sess := serv.sessions[connID].(*mockSession)
|
||||
Expect(sess.connectionID).To(Equal(connID))
|
||||
Expect(sess.packetCount).To(Equal(1))
|
||||
Eventually(func() bool { return connStateCalled }).Should(BeTrue())
|
||||
Expect(connStateSession).To(Equal(sess))
|
||||
Expect(connStateStatus).To(Equal(ConnStateVersionNegotiated))
|
||||
})
|
||||
|
||||
It("calls the ConnState callback when the connection is secure", func() {
|
||||
var connStateCalled bool
|
||||
var connStateStatus ConnState
|
||||
var connStateSession Session
|
||||
config.ConnState = func(s Session, state ConnState) {
|
||||
connStateStatus = state
|
||||
connStateSession = s
|
||||
connStateCalled = true
|
||||
}
|
||||
sess := &mockSession{}
|
||||
serv.cryptoChangeCallback(sess, false)
|
||||
Eventually(func() bool { return connStateCalled }).Should(BeTrue())
|
||||
Expect(connStateSession).To(Equal(sess))
|
||||
Expect(connStateStatus).To(Equal(ConnStateSecure))
|
||||
})
|
||||
|
||||
It("calls the ConnState callback when the connection is forward-secure", func() {
|
||||
var connStateCalled bool
|
||||
var connStateStatus ConnState
|
||||
var connStateSession Session
|
||||
config.ConnState = func(s Session, state ConnState) {
|
||||
connStateStatus = state
|
||||
connStateSession = s
|
||||
connStateCalled = true
|
||||
}
|
||||
It("accepts a session once the connection it is forward secure", func(done Done) {
|
||||
var acceptedSess Session
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
var err error
|
||||
acceptedSess, err = serv.Accept()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}()
|
||||
sess := &mockSession{}
|
||||
// serv.cryptoChangeCallback(sess, false)
|
||||
// Consistently(func() Session { return acceptedSess }).Should(BeNil())
|
||||
serv.cryptoChangeCallback(sess, true)
|
||||
Eventually(func() bool { return connStateCalled }).Should(BeTrue())
|
||||
Expect(connStateStatus).To(Equal(ConnStateForwardSecure))
|
||||
Expect(connStateSession).To(Equal(sess))
|
||||
})
|
||||
Eventually(func() Session { return acceptedSess }).Should(Equal(sess))
|
||||
close(done)
|
||||
}, 0.5)
|
||||
|
||||
It("assigns packets to existing sessions", func() {
|
||||
err := serv.handlePacket(nil, nil, firstPacket)
|
||||
|
@ -231,7 +206,7 @@ var _ = Describe("Server", func() {
|
|||
var returned bool
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
err := ln.Serve()
|
||||
_, err := ln.Accept()
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("use of closed network connection"))
|
||||
returned = true
|
||||
|
@ -240,12 +215,15 @@ var _ = Describe("Server", func() {
|
|||
Eventually(func() bool { return returned }).Should(BeTrue())
|
||||
})
|
||||
|
||||
It("errors when encountering a connection error", func() {
|
||||
It("errors when encountering a connection error", func(done Done) {
|
||||
testErr := errors.New("connection error")
|
||||
conn.readErr = testErr
|
||||
err := serv.Serve()
|
||||
go serv.serve()
|
||||
_, err := serv.Accept()
|
||||
Expect(err).To(MatchError(testErr))
|
||||
})
|
||||
Expect(serv.Close()).To(Succeed())
|
||||
close(done)
|
||||
}, 0.5)
|
||||
|
||||
It("closes all sessions when encountering a connection error", func() {
|
||||
err := serv.handlePacket(nil, nil, firstPacket)
|
||||
|
@ -254,8 +232,9 @@ var _ = Describe("Server", func() {
|
|||
Expect(serv.sessions[connID].(*mockSession).closed).To(BeFalse())
|
||||
testErr := errors.New("connection error")
|
||||
conn.readErr = testErr
|
||||
_ = serv.Serve()
|
||||
Expect(serv.sessions[connID].(*mockSession).closed).To(BeTrue())
|
||||
go serv.serve()
|
||||
Eventually(func() bool { return serv.sessions[connID].(*mockSession).closed }).Should(BeTrue())
|
||||
Expect(serv.Close()).To(Succeed())
|
||||
})
|
||||
|
||||
It("ignores delayed packets with mismatching versions", func() {
|
||||
|
@ -324,20 +303,17 @@ var _ = Describe("Server", func() {
|
|||
})
|
||||
|
||||
It("setups with the right values", func() {
|
||||
var connStateCallback ConnStateCallback = func(_ Session, _ ConnState) {}
|
||||
supportedVersions := []protocol.VersionNumber{1, 3, 5}
|
||||
config := Config{
|
||||
TLSConfig: &tls.Config{},
|
||||
ConnState: connStateCallback,
|
||||
Versions: supportedVersions,
|
||||
}
|
||||
ln, err := Listen(conn, &config)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
server := ln.(*server)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(server.deleteClosedSessionsAfter).To(Equal(protocol.ClosedSessionDeleteTimeout))
|
||||
Expect(server.sessions).ToNot(BeNil())
|
||||
Expect(server.scfg).ToNot(BeNil())
|
||||
Expect(server.config.ConnState).ToNot(BeNil())
|
||||
Expect(server.config.Versions).To(Equal(supportedVersions))
|
||||
})
|
||||
|
||||
|
@ -387,7 +363,7 @@ var _ = Describe("Server", func() {
|
|||
|
||||
var returned bool
|
||||
go func() {
|
||||
ln.Serve()
|
||||
ln.Accept()
|
||||
returned = true
|
||||
}()
|
||||
|
||||
|
@ -400,7 +376,7 @@ var _ = Describe("Server", func() {
|
|||
b.Bytes()...,
|
||||
)
|
||||
Expect(conn.dataWritten.Bytes()).To(Equal(expected))
|
||||
Expect(returned).To(BeFalse())
|
||||
Consistently(func() bool { return returned }).Should(BeFalse())
|
||||
})
|
||||
|
||||
It("sends a PublicReset for new connections that don't have the VersionFlag set", func() {
|
||||
|
@ -410,7 +386,7 @@ var _ = Describe("Server", func() {
|
|||
Expect(err).ToNot(HaveOccurred())
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
err := ln.Serve()
|
||||
_, err := ln.Accept()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}()
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue