implement a Listener.Accept() method

This commit is contained in:
Marten Seemann 2017-05-05 15:45:57 +08:00
parent 0bd3b61e6a
commit 30a0211243
No known key found for this signature in database
GPG key ID: 3603F40B121FCDEA
6 changed files with 104 additions and 121 deletions

View file

@ -6,6 +6,7 @@ import (
"fmt"
"io"
"math/rand"
"net"
"time"
"github.com/lucas-clemente/quic-go/protocol"
@ -27,33 +28,32 @@ var _ = Describe("Benchmarks", func() {
Measure("transferring a file", func(b Benchmarker) {
rand.Read(data) // no need to check for an error. math.Rand.Read never errors
// start the server
sconf := &Config{
TLSConfig: testdata.GetTLSConfig(),
ConnState: func(sess Session, cs ConnState) {
if cs != ConnStateForwardSecure {
return
}
var ln Listener
defer GinkgoRecover()
str, err := sess.OpenStream()
Expect(err).ToNot(HaveOccurred())
_, err = str.Write(data)
Expect(err).ToNot(HaveOccurred())
err = str.Close()
Expect(err).ToNot(HaveOccurred())
},
}
ln, err := ListenAddr("localhost:0", sconf)
Expect(err).ToNot(HaveOccurred())
// Serve will error as soon as ln is closed. Ignore all errors here
go ln.Serve()
serverAddr := make(chan net.Addr)
// start the server
go func() {
defer GinkgoRecover()
var err error
ln, err = ListenAddr("localhost:0", &Config{TLSConfig: testdata.GetTLSConfig()})
Expect(err).ToNot(HaveOccurred())
serverAddr <- ln.Addr()
sess, err := ln.Accept()
Expect(err).ToNot(HaveOccurred())
str, err := sess.OpenStream()
Expect(err).ToNot(HaveOccurred())
_, err = str.Write(data)
Expect(err).ToNot(HaveOccurred())
err = str.Close()
Expect(err).ToNot(HaveOccurred())
}()
// start the client
cconf := &Config{
conf := &Config{
TLSConfig: &tls.Config{InsecureSkipVerify: true},
}
sess, err := DialAddr(ln.Addr().String(), cconf)
addr := <-serverAddr
sess, err := DialAddr(addr.String(), conf)
Expect(err).ToNot(HaveOccurred())
str, err := sess.AcceptStream()
Expect(err).ToNot(HaveOccurred())

View file

@ -33,26 +33,22 @@ func main() {
func echoServer() error {
cfgServer := &quic.Config{
TLSConfig: generateTLSConfig(),
ConnState: func(sess quic.Session, cs quic.ConnState) {
// Ignore unless the handshake is finished
if cs != quic.ConnStateForwardSecure {
return
}
go func() {
stream, err := sess.AcceptStream()
if err != nil {
panic(err)
}
// Echo through the loggingWriter
go io.Copy(loggingWriter{stream}, stream)
}()
},
}
listener, err := quic.ListenAddr(addr, cfgServer)
if err != nil {
return err
}
return listener.Serve()
sess, err := listener.Accept()
if err != nil {
return err
}
stream, err := sess.AcceptStream()
if err != nil {
panic(err)
}
// Echo through the loggingWriter
_, err = io.Copy(loggingWriter{stream}, stream)
return err
}
func clientMain() error {

View file

@ -85,13 +85,7 @@ func (s *Server) serveImpl(tlsConfig *tls.Config, conn *net.UDPConn) error {
config := quic.Config{
TLSConfig: tlsConfig,
ConnState: func(session quic.Session, connState quic.ConnState) {
sess := session.(streamCreator)
if connState == quic.ConnStateVersionNegotiated {
s.handleHeaderStream(sess)
}
},
Versions: protocol.SupportedVersions,
Versions: protocol.SupportedVersions,
}
var ln quic.Listener
@ -107,7 +101,14 @@ func (s *Server) serveImpl(tlsConfig *tls.Config, conn *net.UDPConn) error {
}
s.listener = ln
s.listenerMutex.Unlock()
return ln.Serve()
for {
sess, err := ln.Accept()
if err != nil {
return err
}
go s.handleHeaderStream(sess.(streamCreator))
}
}
func (s *Server) handleHeaderStream(session streamCreator) {

View file

@ -79,6 +79,6 @@ type Listener interface {
Close() error
// Addr returns the local network addr that the server is listening on.
Addr() net.Addr
// Serve starts the main server loop, and blocks until a network error occurs or the server is closed.
Serve() error
// Accept returns new sessions. It should be called in a loop.
Accept() (Session, error)
}

View file

@ -34,6 +34,10 @@ type server struct {
sessionsMutex sync.RWMutex
deleteClosedSessionsAfter time.Duration
serverError error
sessionQueue chan Session
errorChan chan struct{}
newSession func(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, cryptoChangeCallback cryptoChangeCallback, config *Config) (packetHandler, error)
}
@ -66,7 +70,7 @@ func Listen(conn net.PacketConn, config *Config) (Listener, error) {
return nil, err
}
return &server{
s := &server{
conn: conn,
config: populateServerConfig(config),
certChain: certChain,
@ -74,7 +78,11 @@ func Listen(conn net.PacketConn, config *Config) (Listener, error) {
sessions: map[protocol.ConnectionID]packetHandler{},
newSession: newSession,
deleteClosedSessionsAfter: protocol.ClosedSessionDeleteTimeout,
}, nil
sessionQueue: make(chan Session, 5),
errorChan: make(chan struct{}),
}
go s.serve()
return s, nil
}
func populateServerConfig(config *Config) *Config {
@ -85,13 +93,12 @@ func populateServerConfig(config *Config) *Config {
return &Config{
TLSConfig: config.TLSConfig,
ConnState: config.ConnState,
Versions: versions,
}
}
// Listen listens on an existing PacketConn
func (s *server) Serve() error {
// serve listens on an existing PacketConn
func (s *server) serve() {
for {
data := getPacketBuffer()
data = data[:protocol.MaxReceivePacketSize]
@ -99,14 +106,27 @@ func (s *server) Serve() error {
// If it does, we only read a truncated packet, which will then end up undecryptable
n, remoteAddr, err := s.conn.ReadFrom(data)
if err != nil {
s.serverError = err
close(s.errorChan)
_ = s.Close()
return err
return
}
data = data[:n]
if err := s.handlePacket(s.conn, remoteAddr, data); err != nil {
utils.Errorf("error handling packet: %s", err.Error())
}
}
}
// Accept returns newly openend sessions
func (s *server) Accept() (Session, error) {
var sess Session
select {
case sess = <-s.sessionQueue:
return sess, nil
case <-s.errorChan:
return nil, s.serverError
}
}
// Close the server
@ -212,10 +232,6 @@ func (s *server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet
s.removeConnection(hdr.ConnectionID)
}()
if s.config.ConnState != nil {
go s.config.ConnState(session, ConnStateVersionNegotiated)
}
}
if session == nil {
// Late packet for closed session
@ -231,14 +247,8 @@ func (s *server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet
}
func (s *server) cryptoChangeCallback(session Session, isForwardSecure bool) {
var state ConnState
if isForwardSecure {
state = ConnStateForwardSecure
} else {
state = ConnStateSecure
}
if s.config.ConnState != nil {
go s.config.ConnState(session, state)
s.sessionQueue <- session
}
}

View file

@ -87,10 +87,12 @@ var _ = Describe("Server", func() {
BeforeEach(func() {
serv = &server{
sessions: make(map[protocol.ConnectionID]packetHandler),
newSession: newMockSession,
conn: conn,
config: config,
sessions: make(map[protocol.ConnectionID]packetHandler),
newSession: newMockSession,
conn: conn,
config: config,
sessionQueue: make(chan Session, 5),
errorChan: make(chan struct{}),
}
b := &bytes.Buffer{}
utils.WriteUint32(b, protocol.VersionNumberToTag(protocol.SupportedVersions[0]))
@ -115,56 +117,29 @@ var _ = Describe("Server", func() {
})
It("creates new sessions", func() {
var connStateCalled bool
var connStateStatus ConnState
var connStateSession Session
config.ConnState = func(s Session, state ConnState) {
connStateStatus = state
connStateSession = s
connStateCalled = true
}
err := serv.handlePacket(nil, nil, firstPacket)
Expect(err).ToNot(HaveOccurred())
Expect(serv.sessions).To(HaveLen(1))
sess := serv.sessions[connID].(*mockSession)
Expect(sess.connectionID).To(Equal(connID))
Expect(sess.packetCount).To(Equal(1))
Eventually(func() bool { return connStateCalled }).Should(BeTrue())
Expect(connStateSession).To(Equal(sess))
Expect(connStateStatus).To(Equal(ConnStateVersionNegotiated))
})
It("calls the ConnState callback when the connection is secure", func() {
var connStateCalled bool
var connStateStatus ConnState
var connStateSession Session
config.ConnState = func(s Session, state ConnState) {
connStateStatus = state
connStateSession = s
connStateCalled = true
}
sess := &mockSession{}
serv.cryptoChangeCallback(sess, false)
Eventually(func() bool { return connStateCalled }).Should(BeTrue())
Expect(connStateSession).To(Equal(sess))
Expect(connStateStatus).To(Equal(ConnStateSecure))
})
It("calls the ConnState callback when the connection is forward-secure", func() {
var connStateCalled bool
var connStateStatus ConnState
var connStateSession Session
config.ConnState = func(s Session, state ConnState) {
connStateStatus = state
connStateSession = s
connStateCalled = true
}
It("accepts a session once the connection it is forward secure", func(done Done) {
var acceptedSess Session
go func() {
defer GinkgoRecover()
var err error
acceptedSess, err = serv.Accept()
Expect(err).ToNot(HaveOccurred())
}()
sess := &mockSession{}
// serv.cryptoChangeCallback(sess, false)
// Consistently(func() Session { return acceptedSess }).Should(BeNil())
serv.cryptoChangeCallback(sess, true)
Eventually(func() bool { return connStateCalled }).Should(BeTrue())
Expect(connStateStatus).To(Equal(ConnStateForwardSecure))
Expect(connStateSession).To(Equal(sess))
})
Eventually(func() Session { return acceptedSess }).Should(Equal(sess))
close(done)
}, 0.5)
It("assigns packets to existing sessions", func() {
err := serv.handlePacket(nil, nil, firstPacket)
@ -231,7 +206,7 @@ var _ = Describe("Server", func() {
var returned bool
go func() {
defer GinkgoRecover()
err := ln.Serve()
_, err := ln.Accept()
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("use of closed network connection"))
returned = true
@ -240,12 +215,15 @@ var _ = Describe("Server", func() {
Eventually(func() bool { return returned }).Should(BeTrue())
})
It("errors when encountering a connection error", func() {
It("errors when encountering a connection error", func(done Done) {
testErr := errors.New("connection error")
conn.readErr = testErr
err := serv.Serve()
go serv.serve()
_, err := serv.Accept()
Expect(err).To(MatchError(testErr))
})
Expect(serv.Close()).To(Succeed())
close(done)
}, 0.5)
It("closes all sessions when encountering a connection error", func() {
err := serv.handlePacket(nil, nil, firstPacket)
@ -254,8 +232,9 @@ var _ = Describe("Server", func() {
Expect(serv.sessions[connID].(*mockSession).closed).To(BeFalse())
testErr := errors.New("connection error")
conn.readErr = testErr
_ = serv.Serve()
Expect(serv.sessions[connID].(*mockSession).closed).To(BeTrue())
go serv.serve()
Eventually(func() bool { return serv.sessions[connID].(*mockSession).closed }).Should(BeTrue())
Expect(serv.Close()).To(Succeed())
})
It("ignores delayed packets with mismatching versions", func() {
@ -324,20 +303,17 @@ var _ = Describe("Server", func() {
})
It("setups with the right values", func() {
var connStateCallback ConnStateCallback = func(_ Session, _ ConnState) {}
supportedVersions := []protocol.VersionNumber{1, 3, 5}
config := Config{
TLSConfig: &tls.Config{},
ConnState: connStateCallback,
Versions: supportedVersions,
}
ln, err := Listen(conn, &config)
Expect(err).ToNot(HaveOccurred())
server := ln.(*server)
Expect(err).ToNot(HaveOccurred())
Expect(server.deleteClosedSessionsAfter).To(Equal(protocol.ClosedSessionDeleteTimeout))
Expect(server.sessions).ToNot(BeNil())
Expect(server.scfg).ToNot(BeNil())
Expect(server.config.ConnState).ToNot(BeNil())
Expect(server.config.Versions).To(Equal(supportedVersions))
})
@ -387,7 +363,7 @@ var _ = Describe("Server", func() {
var returned bool
go func() {
ln.Serve()
ln.Accept()
returned = true
}()
@ -400,7 +376,7 @@ var _ = Describe("Server", func() {
b.Bytes()...,
)
Expect(conn.dataWritten.Bytes()).To(Equal(expected))
Expect(returned).To(BeFalse())
Consistently(func() bool { return returned }).Should(BeFalse())
})
It("sends a PublicReset for new connections that don't have the VersionFlag set", func() {
@ -410,7 +386,7 @@ var _ = Describe("Server", func() {
Expect(err).ToNot(HaveOccurred())
go func() {
defer GinkgoRecover()
err := ln.Serve()
_, err := ln.Accept()
Expect(err).ToNot(HaveOccurred())
}()