implement a listener that returns early sessions

This commit is contained in:
Marten Seemann 2019-07-27 01:55:46 -04:00
parent cc76441539
commit 5cbb8d6597
5 changed files with 551 additions and 308 deletions

View file

@ -171,6 +171,19 @@ type Session interface {
ConnectionState() tls.ConnectionState
}
// An EarlySession is a session that is handshaking.
// Data sent during the handshake is encrypted using the forward secure keys.
// When using client certificates, the client's identity is only verified
// after completion of the handshake.
type EarlySession interface {
Session
// Blocks until the handshake completes (or fails).
// Data sent before completion of the handshake is encrypted with 1-RTT keys.
// Note that the client's identity hasn't been verified yet.
HandshakeComplete() context.Context
}
// Config contains all configuration data needed for a QUIC server or client.
type Config struct {
// The QUIC versions that can be negotiated.
@ -234,3 +247,14 @@ type Listener interface {
// Accept returns new sessions. It should be called in a loop.
Accept(context.Context) (Session, error)
}
// An EarlyListener listens for incoming QUIC connections,
// and returns them before the handshake completes.
type EarlyListener interface {
// Close the server. All active sessions will be closed.
Close() error
// Addr returns the local network addr that the server is listening on.
Addr() net.Addr
// Accept returns new early sessions. It should be called in a loop.
Accept(context.Context) (EarlySession, error)
}

View file

@ -277,6 +277,20 @@ func (mr *MockQuicSessionMockRecorder) destroy(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "destroy", reflect.TypeOf((*MockQuicSession)(nil).destroy), arg0)
}
// earlySessionReady mocks base method
func (m *MockQuicSession) earlySessionReady() <-chan struct{} {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "earlySessionReady")
ret0, _ := ret[0].(<-chan struct{})
return ret0
}
// earlySessionReady indicates an expected call of earlySessionReady
func (mr *MockQuicSessionMockRecorder) earlySessionReady() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "earlySessionReady", reflect.TypeOf((*MockQuicSession)(nil).earlySessionReady))
}
// getPerspective mocks base method
func (m *MockQuicSession) getPerspective() protocol.Perspective {
m.ctrl.T.Helper()

126
server.go
View file

@ -45,8 +45,8 @@ type packetHandlerManager interface {
}
type quicSession interface {
Session
HandshakeComplete() context.Context
EarlySession
earlySessionReady() <-chan struct{}
handlePacket(*receivedPacket)
GetVersion() protocol.VersionNumber
getPerspective() protocol.Perspective
@ -64,9 +64,11 @@ type sessionRunner interface {
}
// A Listener of QUIC
type server struct {
type baseServer struct {
mutex sync.Mutex
acceptEarlySessions bool
tlsConf *tls.Config
config *Config
@ -86,19 +88,40 @@ type server struct {
errorChan chan struct{}
closed bool
sessionQueue chan Session
sessionQueue chan quicSession
sessionQueueLen int32 // to be used as an atomic
logger utils.Logger
}
var _ Listener = &server{}
var _ unknownPacketHandler = &server{}
var _ Listener = &baseServer{}
var _ unknownPacketHandler = &baseServer{}
type earlyServer struct{ *baseServer }
var _ EarlyListener = &earlyServer{}
func (s *earlyServer) Accept(ctx context.Context) (EarlySession, error) {
return s.baseServer.accept(ctx)
}
// ListenAddr creates a QUIC server listening on a given address.
// The tls.Config must not be nil and must contain a certificate configuration.
// The quic.Config may be nil, in that case the default values will be used.
func ListenAddr(addr string, tlsConf *tls.Config, config *Config) (Listener, error) {
return listenAddr(addr, tlsConf, config, false)
}
// ListenAddrEarly works like ListenAddr, but it returns sessions before the handshake completes.
func ListenAddrEarly(addr string, tlsConf *tls.Config, config *Config) (EarlyListener, error) {
s, err := listenAddr(addr, tlsConf, config, true)
if err != nil {
return nil, err
}
return &earlyServer{s}, nil
}
func listenAddr(addr string, tlsConf *tls.Config, config *Config, acceptEarly bool) (*baseServer, error) {
udpAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return nil, err
@ -107,7 +130,7 @@ func ListenAddr(addr string, tlsConf *tls.Config, config *Config) (Listener, err
if err != nil {
return nil, err
}
serv, err := listen(conn, tlsConf, config)
serv, err := listen(conn, tlsConf, config, acceptEarly)
if err != nil {
return nil, err
}
@ -123,10 +146,20 @@ func ListenAddr(addr string, tlsConf *tls.Config, config *Config) (Listener, err
// Furthermore, it must define an application control (using NextProtos).
// The quic.Config may be nil, in that case the default values will be used.
func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener, error) {
return listen(conn, tlsConf, config)
return listen(conn, tlsConf, config, false)
}
func listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (*server, error) {
// ListenEarly works like Listen, but it returns sessions before the handshake completes.
func ListenEarly(conn net.PacketConn, tlsConf *tls.Config, config *Config) (EarlyListener, error) {
s, err := listen(conn, tlsConf, config, true)
if err != nil {
return nil, err
}
s.acceptEarlySessions = true
return &earlyServer{s}, nil
}
func listen(conn net.PacketConn, tlsConf *tls.Config, config *Config, acceptEarly bool) (*baseServer, error) {
// TODO(#1655): only require that tls.Config.Certificates or tls.Config.GetCertificate is set
if tlsConf == nil {
return nil, errors.New("quic: tls.Config not set")
@ -146,16 +179,17 @@ func listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (*server,
if err != nil {
return nil, err
}
s := &server{
conn: conn,
tlsConf: tlsConf,
config: config,
tokenGenerator: tokenGenerator,
sessionHandler: sessionHandler,
sessionQueue: make(chan Session),
errorChan: make(chan struct{}),
newSession: newSession,
logger: utils.DefaultLogger.WithPrefix("server"),
s := &baseServer{
conn: conn,
tlsConf: tlsConf,
config: config,
tokenGenerator: tokenGenerator,
sessionHandler: sessionHandler,
sessionQueue: make(chan quicSession),
errorChan: make(chan struct{}),
newSession: newSession,
logger: utils.DefaultLogger.WithPrefix("server"),
acceptEarlySessions: acceptEarly,
}
sessionHandler.SetServer(s)
s.logger.Debugf("Listening for %s connections on %s", conn.LocalAddr().Network(), conn.LocalAddr().String())
@ -248,13 +282,17 @@ func populateServerConfig(config *Config) *Config {
}
}
// Accept returns newly openend sessions
func (s *server) Accept(ctx context.Context) (Session, error) {
var sess Session
// Accept returns sessions that already completed the handshake.
// It is only valid if acceptEarlySessions is false.
func (s *baseServer) Accept(ctx context.Context) (Session, error) {
return s.accept(ctx)
}
func (s *baseServer) accept(ctx context.Context) (quicSession, error) {
select {
case <-ctx.Done():
return nil, ctx.Err()
case sess = <-s.sessionQueue:
case sess := <-s.sessionQueue:
atomic.AddInt32(&s.sessionQueueLen, -1)
return sess, nil
case <-s.errorChan:
@ -263,7 +301,7 @@ func (s *server) Accept(ctx context.Context) (Session, error) {
}
// Close the server
func (s *server) Close() error {
func (s *baseServer) Close() error {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.closed {
@ -284,7 +322,7 @@ func (s *server) Close() error {
return err
}
func (s *server) setCloseError(e error) {
func (s *baseServer) setCloseError(e error) {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.closed {
@ -296,11 +334,11 @@ func (s *server) setCloseError(e error) {
}
// Addr returns the server's network address
func (s *server) Addr() net.Addr {
func (s *baseServer) Addr() net.Addr {
return s.conn.LocalAddr()
}
func (s *server) handlePacket(p *receivedPacket) {
func (s *baseServer) handlePacket(p *receivedPacket) {
go func() {
if shouldReleaseBuffer := s.handlePacketImpl(p); !shouldReleaseBuffer {
p.buffer.Release()
@ -308,7 +346,7 @@ func (s *server) handlePacket(p *receivedPacket) {
}()
}
func (s *server) handlePacketImpl(p *receivedPacket) bool /* was the packet passed on to a session */ {
func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* was the packet passed on to a session */ {
if len(p.data) < protocol.MinInitialPacketSize {
s.logger.Debugf("Dropping a packet that is too small to be a valid Initial (%d bytes)", len(p.data))
return false
@ -352,7 +390,7 @@ func (s *server) handlePacketImpl(p *receivedPacket) bool /* was the packet pass
return true
}
func (s *server) handleInitialImpl(p *receivedPacket, hdr *wire.Header) (quicSession, protocol.ConnectionID, error) {
func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) (quicSession, protocol.ConnectionID, error) {
if len(hdr.Token) == 0 && hdr.DestConnectionID.Len() < protocol.MinConnectionIDLenInitial {
return nil, nil, errors.New("too short connection ID")
}
@ -402,7 +440,7 @@ func (s *server) handleInitialImpl(p *receivedPacket, hdr *wire.Header) (quicSes
return sess, connID, nil
}
func (s *server) createNewSession(
func (s *baseServer) createNewSession(
remoteAddr net.Addr,
origDestConnID protocol.ConnectionID,
clientDestConnID protocol.ConnectionID,
@ -442,16 +480,26 @@ func (s *server) createNewSession(
return nil, err
}
go sess.run()
go s.waitUntilHandshakeComplete(sess)
go s.handleNewSession(sess)
return sess, nil
}
func (s *server) waitUntilHandshakeComplete(sess quicSession) {
func (s *baseServer) handleNewSession(sess quicSession) {
sessCtx := sess.Context()
select {
case <-sess.HandshakeComplete().Done():
case <-sessCtx.Done():
return
if s.acceptEarlySessions {
// wait until the early session is ready (or the handshake fails)
select {
case <-sess.earlySessionReady():
case <-sessCtx.Done():
return
}
} else {
// wait until the handshake is complete (or fails)
select {
case <-sess.HandshakeComplete().Done():
case <-sessCtx.Done():
return
}
}
atomic.AddInt32(&s.sessionQueueLen, 1)
@ -464,7 +512,7 @@ func (s *server) waitUntilHandshakeComplete(sess quicSession) {
}
}
func (s *server) sendRetry(remoteAddr net.Addr, hdr *wire.Header) error {
func (s *baseServer) sendRetry(remoteAddr net.Addr, hdr *wire.Header) error {
token, err := s.tokenGenerator.NewRetryToken(remoteAddr, hdr.DestConnectionID)
if err != nil {
return err
@ -494,7 +542,7 @@ func (s *server) sendRetry(remoteAddr net.Addr, hdr *wire.Header) error {
return nil
}
func (s *server) sendServerBusy(remoteAddr net.Addr, hdr *wire.Header) error {
func (s *baseServer) sendServerBusy(remoteAddr net.Addr, hdr *wire.Header) error {
sealer, _, err := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveServer)
if err != nil {
return err
@ -541,7 +589,7 @@ func (s *server) sendServerBusy(remoteAddr net.Addr, hdr *wire.Header) error {
return nil
}
func (s *server) sendVersionNegotiationPacket(p *receivedPacket, hdr *wire.Header) {
func (s *baseServer) sendVersionNegotiationPacket(p *receivedPacket, hdr *wire.Header) {
s.logger.Debugf("Client offered version %s, sending Version Negotiation", hdr.Version)
data, err := wire.ComposeVersionNegotiation(hdr.SrcConnectionID, hdr.DestConnectionID, s.config.Versions)
if err != nil {

View file

@ -39,6 +39,12 @@ var _ = Describe("Server", func() {
}
}
parseHeader := func(data []byte) *wire.Header {
hdr, _, _, err := wire.ParsePacket(data, 0)
Expect(err).ToNot(HaveOccurred())
return hdr
}
BeforeEach(func() {
conn = newMockPacketConn()
conn.addr = &net.UDPAddr{}
@ -61,7 +67,7 @@ var _ = Describe("Server", func() {
It("fills in default values if options are not set in the Config", func() {
ln, err := Listen(conn, tlsConf, &Config{})
Expect(err).ToNot(HaveOccurred())
server := ln.(*server)
server := ln.(*baseServer)
Expect(server.config.Versions).To(Equal(protocol.SupportedVersions))
Expect(server.config.HandshakeTimeout).To(Equal(protocol.DefaultHandshakeTimeout))
Expect(server.config.IdleTimeout).To(Equal(protocol.DefaultIdleTimeout))
@ -86,7 +92,7 @@ var _ = Describe("Server", func() {
}
ln, err := Listen(conn, tlsConf, &config)
Expect(err).ToNot(HaveOccurred())
server := ln.(*server)
server := ln.(*baseServer)
Expect(server.sessionHandler).ToNot(BeNil())
Expect(server.config.Versions).To(Equal(supportedVersions))
Expect(server.config.HandshakeTimeout).To(Equal(1337 * time.Hour))
@ -103,8 +109,7 @@ var _ = Describe("Server", func() {
addr := "127.0.0.1:13579"
ln, err := ListenAddr(addr, tlsConf, &Config{})
Expect(err).ToNot(HaveOccurred())
serv := ln.(*server)
Expect(serv.Addr().String()).To(Equal(addr))
Expect(ln.Addr().String()).To(Equal(addr))
// stop the listener
Expect(ln.Close()).To(Succeed())
})
@ -121,173 +126,427 @@ var _ = Describe("Server", func() {
Expect(err).To(BeAssignableToTypeOf(&net.OpError{}))
})
Context("handling packets", func() {
var serv *server
Context("server accepting sessions that completed the handshake", func() {
var serv *baseServer
BeforeEach(func() {
ln, err := Listen(conn, tlsConf, nil)
Expect(err).ToNot(HaveOccurred())
serv = ln.(*server)
serv = ln.(*baseServer)
})
parseHeader := func(data []byte) *wire.Header {
hdr, _, _, err := wire.ParsePacket(data, 0)
Context("handling packets", func() {
It("drops Initial packets with a too short connection ID", func() {
serv.handlePacket(getPacket(&wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeInitial,
DestConnectionID: protocol.ConnectionID{1, 2, 3, 4},
Version: serv.config.Versions[0],
}, nil))
Consistently(conn.dataWritten).ShouldNot(Receive())
})
It("drops too small Initial", func() {
serv.handlePacket(getPacket(&wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeInitial,
DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8},
Version: serv.config.Versions[0],
}, make([]byte, protocol.MinInitialPacketSize-100),
))
Consistently(conn.dataWritten).ShouldNot(Receive())
})
It("drops packets with a too short connection ID", func() {
serv.handlePacket(getPacket(&wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeInitial,
SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8},
DestConnectionID: protocol.ConnectionID{1, 2, 3, 4},
Version: serv.config.Versions[0],
}, make([]byte, protocol.MinInitialPacketSize)))
Consistently(conn.dataWritten).ShouldNot(Receive())
})
It("drops non-Initial packets", func() {
serv.handlePacket(getPacket(
&wire.Header{
Type: protocol.PacketTypeHandshake,
Version: serv.config.Versions[0],
},
[]byte("invalid"),
))
})
It("decodes the token from the Token field", func() {
raddr := &net.UDPAddr{
IP: net.IPv4(192, 168, 13, 37),
Port: 1337,
}
done := make(chan struct{})
serv.config.AcceptToken = func(addr net.Addr, token *Token) bool {
Expect(addr).To(Equal(raddr))
Expect(token).ToNot(BeNil())
close(done)
return false
}
token, err := serv.tokenGenerator.NewRetryToken(raddr, nil)
Expect(err).ToNot(HaveOccurred())
packet := getPacket(&wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeInitial,
Token: token,
Version: serv.config.Versions[0],
}, make([]byte, protocol.MinInitialPacketSize))
packet.remoteAddr = raddr
serv.handlePacket(packet)
Eventually(done).Should(BeClosed())
})
It("passes an empty token to the callback, if decoding fails", func() {
raddr := &net.UDPAddr{
IP: net.IPv4(192, 168, 13, 37),
Port: 1337,
}
done := make(chan struct{})
serv.config.AcceptToken = func(addr net.Addr, token *Token) bool {
Expect(addr).To(Equal(raddr))
Expect(token).To(BeNil())
close(done)
return false
}
packet := getPacket(&wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeInitial,
Token: []byte("foobar"),
Version: serv.config.Versions[0],
}, make([]byte, protocol.MinInitialPacketSize))
packet.remoteAddr = raddr
serv.handlePacket(packet)
Eventually(done).Should(BeClosed())
})
It("sends a Version Negotiation Packet for unsupported versions", func() {
srcConnID := protocol.ConnectionID{1, 2, 3, 4, 5}
destConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6}
packet := getPacket(&wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeInitial,
SrcConnectionID: srcConnID,
DestConnectionID: destConnID,
Version: 0x42,
}, make([]byte, protocol.MinInitialPacketSize))
packet.remoteAddr = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
serv.handlePacket(packet)
var write mockPacketConnWrite
Eventually(conn.dataWritten).Should(Receive(&write))
Expect(write.to.String()).To(Equal("127.0.0.1:1337"))
Expect(wire.IsVersionNegotiationPacket(write.data)).To(BeTrue())
hdr := parseHeader(write.data)
Expect(hdr.DestConnectionID).To(Equal(srcConnID))
Expect(hdr.SrcConnectionID).To(Equal(destConnID))
Expect(hdr.SupportedVersions).ToNot(ContainElement(protocol.VersionNumber(0x42)))
})
It("replies with a Retry packet, if a Token is required", func() {
serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return false }
hdr := &wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeInitial,
SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1},
DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
Version: protocol.VersionTLS,
}
packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
packet.remoteAddr = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
serv.handlePacket(packet)
var write mockPacketConnWrite
Eventually(conn.dataWritten).Should(Receive(&write))
Expect(write.to.String()).To(Equal("127.0.0.1:1337"))
replyHdr := parseHeader(write.data)
Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry))
Expect(replyHdr.SrcConnectionID).ToNot(Equal(hdr.DestConnectionID))
Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID))
Expect(replyHdr.OrigDestConnectionID).To(Equal(hdr.DestConnectionID))
Expect(replyHdr.Token).ToNot(BeEmpty())
})
It("creates a session, if no Token is required", func() {
serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true }
hdr := &wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeInitial,
SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1},
DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
Version: protocol.VersionTLS,
}
p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
run := make(chan struct{})
serv.newSession = func(
_ connection,
_ sessionRunner,
origConnID protocol.ConnectionID,
destConnID protocol.ConnectionID,
srcConnID protocol.ConnectionID,
_ *Config,
_ *tls.Config,
_ *handshake.TransportParameters,
_ *handshake.TokenGenerator,
_ utils.Logger,
_ protocol.VersionNumber,
) (quicSession, error) {
Expect(origConnID).To(Equal(hdr.DestConnectionID))
Expect(destConnID).To(Equal(hdr.SrcConnectionID))
// make sure we're using a server-generated connection ID
Expect(srcConnID).ToNot(Equal(hdr.DestConnectionID))
Expect(srcConnID).ToNot(Equal(hdr.SrcConnectionID))
sess := NewMockQuicSession(mockCtrl)
sess.EXPECT().handlePacket(p)
sess.EXPECT().run().Do(func() { close(run) })
sess.EXPECT().Context().Return(context.Background())
sess.EXPECT().HandshakeComplete().Return(context.Background())
return sess, nil
}
done := make(chan struct{})
go func() {
defer GinkgoRecover()
serv.handlePacket(p)
// the Handshake packet is written by the session
Consistently(conn.dataWritten).ShouldNot(Receive())
close(done)
}()
// make sure we're using a server-generated connection ID
Eventually(run).Should(BeClosed())
Eventually(done).Should(BeClosed())
})
It("rejects new connection attempts if the accept queue is full", func() {
serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true }
senderAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 42}
hdr := &wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeInitial,
SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1},
DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
Version: protocol.VersionTLS,
}
p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
p.remoteAddr = senderAddr
serv.newSession = func(
_ connection,
runner sessionRunner,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ *Config,
_ *tls.Config,
_ *handshake.TransportParameters,
_ *handshake.TokenGenerator,
_ utils.Logger,
_ protocol.VersionNumber,
) (quicSession, error) {
sess := NewMockQuicSession(mockCtrl)
sess.EXPECT().handlePacket(p)
sess.EXPECT().run()
sess.EXPECT().Context().Return(context.Background())
ctx, cancel := context.WithCancel(context.Background())
cancel()
sess.EXPECT().HandshakeComplete().Return(ctx)
return sess, nil
}
var wg sync.WaitGroup
wg.Add(protocol.MaxAcceptQueueSize)
for i := 0; i < protocol.MaxAcceptQueueSize; i++ {
go func() {
defer GinkgoRecover()
defer wg.Done()
serv.handlePacket(p)
Consistently(conn.dataWritten).ShouldNot(Receive())
}()
}
wg.Wait()
serv.handlePacket(p)
var reject mockPacketConnWrite
Eventually(conn.dataWritten).Should(Receive(&reject))
Expect(reject.to).To(Equal(senderAddr))
rejectHdr := parseHeader(reject.data)
Expect(rejectHdr.Type).To(Equal(protocol.PacketTypeInitial))
Expect(rejectHdr.Version).To(Equal(hdr.Version))
Expect(rejectHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID))
Expect(rejectHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID))
})
It("doesn't accept new sessions if they were closed in the mean time", func() {
serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true }
senderAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 42}
hdr := &wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeInitial,
SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1},
DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
Version: protocol.VersionTLS,
}
p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
p.remoteAddr = senderAddr
ctx, cancel := context.WithCancel(context.Background())
sessionCreated := make(chan struct{})
sess := NewMockQuicSession(mockCtrl)
serv.newSession = func(
_ connection,
runner sessionRunner,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ *Config,
_ *tls.Config,
_ *handshake.TransportParameters,
_ *handshake.TokenGenerator,
_ utils.Logger,
_ protocol.VersionNumber,
) (quicSession, error) {
sess.EXPECT().handlePacket(p)
sess.EXPECT().run()
sess.EXPECT().Context().Return(ctx)
ctx, cancel := context.WithCancel(context.Background())
cancel()
sess.EXPECT().HandshakeComplete().Return(ctx)
close(sessionCreated)
return sess, nil
}
serv.handlePacket(p)
Consistently(conn.dataWritten).ShouldNot(Receive())
Eventually(sessionCreated).Should(BeClosed())
cancel()
time.Sleep(scaleDuration(200 * time.Millisecond))
done := make(chan struct{})
go func() {
defer GinkgoRecover()
serv.Accept(context.Background())
close(done)
}()
Consistently(done).ShouldNot(BeClosed())
// make the go routine return
sess.EXPECT().getPerspective()
Expect(serv.Close()).To(Succeed())
Eventually(done).Should(BeClosed())
})
})
Context("accepting sessions", func() {
It("returns Accept when an error occurs", func() {
testErr := errors.New("test err")
done := make(chan struct{})
go func() {
defer GinkgoRecover()
_, err := serv.Accept(context.Background())
Expect(err).To(MatchError(testErr))
close(done)
}()
serv.setCloseError(testErr)
Eventually(done).Should(BeClosed())
})
It("returns immediately, if an error occurred before", func() {
testErr := errors.New("test err")
serv.setCloseError(testErr)
for i := 0; i < 3; i++ {
_, err := serv.Accept(context.Background())
Expect(err).To(MatchError(testErr))
}
})
It("returns when the context is canceled", func() {
ctx, cancel := context.WithCancel(context.Background())
done := make(chan struct{})
go func() {
defer GinkgoRecover()
_, err := serv.Accept(ctx)
Expect(err).To(MatchError("context canceled"))
close(done)
}()
Consistently(done).ShouldNot(BeClosed())
cancel()
Eventually(done).Should(BeClosed())
})
It("accepts new sessions when the handshake completes", func() {
sess := NewMockQuicSession(mockCtrl)
done := make(chan struct{})
go func() {
defer GinkgoRecover()
s, err := serv.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
Expect(s).To(Equal(sess))
close(done)
}()
ctx, cancel := context.WithCancel(context.Background()) // handshake context
serv.newSession = func(
_ connection,
runner sessionRunner,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ *Config,
_ *tls.Config,
_ *handshake.TransportParameters,
_ *handshake.TokenGenerator,
_ utils.Logger,
_ protocol.VersionNumber,
) (quicSession, error) {
sess.EXPECT().HandshakeComplete().Return(ctx)
sess.EXPECT().run().Do(func() {})
sess.EXPECT().Context().Return(context.Background())
return sess, nil
}
_, err := serv.createNewSession(&net.UDPAddr{}, nil, nil, nil, nil, protocol.VersionWhatever)
Expect(err).ToNot(HaveOccurred())
Consistently(done).ShouldNot(BeClosed())
cancel() // complete the handshake
Eventually(done).Should(BeClosed())
})
})
})
Context("server accepting sessions that haven't completed the handshake", func() {
var serv *earlyServer
BeforeEach(func() {
ln, err := ListenEarly(conn, tlsConf, nil)
Expect(err).ToNot(HaveOccurred())
return hdr
}
It("drops Initial packets with a too short connection ID", func() {
serv.handlePacket(getPacket(&wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeInitial,
DestConnectionID: protocol.ConnectionID{1, 2, 3, 4},
Version: serv.config.Versions[0],
}, nil))
Consistently(conn.dataWritten).ShouldNot(Receive())
serv = ln.(*earlyServer)
})
It("drops too small Initial", func() {
serv.handlePacket(getPacket(&wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeInitial,
DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8},
Version: serv.config.Versions[0],
}, make([]byte, protocol.MinInitialPacketSize-100),
))
Consistently(conn.dataWritten).ShouldNot(Receive())
})
It("accepts new sessions when they become ready", func() {
sess := NewMockQuicSession(mockCtrl)
It("drops packets with a too short connection ID", func() {
serv.handlePacket(getPacket(&wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeInitial,
SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8},
DestConnectionID: protocol.ConnectionID{1, 2, 3, 4},
Version: serv.config.Versions[0],
}, make([]byte, protocol.MinInitialPacketSize)))
Consistently(conn.dataWritten).ShouldNot(Receive())
})
It("drops non-Initial packets", func() {
serv.handlePacket(getPacket(
&wire.Header{
Type: protocol.PacketTypeHandshake,
Version: serv.config.Versions[0],
},
[]byte("invalid"),
))
})
It("decodes the token from the Token field", func() {
raddr := &net.UDPAddr{
IP: net.IPv4(192, 168, 13, 37),
Port: 1337,
}
done := make(chan struct{})
serv.config.AcceptToken = func(addr net.Addr, token *Token) bool {
Expect(addr).To(Equal(raddr))
Expect(token).ToNot(BeNil())
go func() {
defer GinkgoRecover()
s, err := serv.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
Expect(s).To(Equal(sess))
close(done)
return false
}
token, err := serv.tokenGenerator.NewRetryToken(raddr, nil)
Expect(err).ToNot(HaveOccurred())
packet := getPacket(&wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeInitial,
Token: token,
Version: serv.config.Versions[0],
}, make([]byte, protocol.MinInitialPacketSize))
packet.remoteAddr = raddr
serv.handlePacket(packet)
Eventually(done).Should(BeClosed())
})
}()
It("passes an empty token to the callback, if decoding fails", func() {
raddr := &net.UDPAddr{
IP: net.IPv4(192, 168, 13, 37),
Port: 1337,
}
done := make(chan struct{})
serv.config.AcceptToken = func(addr net.Addr, token *Token) bool {
Expect(addr).To(Equal(raddr))
Expect(token).To(BeNil())
close(done)
return false
}
packet := getPacket(&wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeInitial,
Token: []byte("foobar"),
Version: serv.config.Versions[0],
}, make([]byte, protocol.MinInitialPacketSize))
packet.remoteAddr = raddr
serv.handlePacket(packet)
Eventually(done).Should(BeClosed())
})
It("sends a Version Negotiation Packet for unsupported versions", func() {
srcConnID := protocol.ConnectionID{1, 2, 3, 4, 5}
destConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6}
packet := getPacket(&wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeInitial,
SrcConnectionID: srcConnID,
DestConnectionID: destConnID,
Version: 0x42,
}, make([]byte, protocol.MinInitialPacketSize))
packet.remoteAddr = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
serv.handlePacket(packet)
var write mockPacketConnWrite
Eventually(conn.dataWritten).Should(Receive(&write))
Expect(write.to.String()).To(Equal("127.0.0.1:1337"))
Expect(wire.IsVersionNegotiationPacket(write.data)).To(BeTrue())
hdr := parseHeader(write.data)
Expect(hdr.DestConnectionID).To(Equal(srcConnID))
Expect(hdr.SrcConnectionID).To(Equal(destConnID))
Expect(hdr.SupportedVersions).ToNot(ContainElement(protocol.VersionNumber(0x42)))
})
It("replies with a Retry packet, if a Token is required", func() {
serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return false }
hdr := &wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeInitial,
SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1},
DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
Version: protocol.VersionTLS,
}
packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
packet.remoteAddr = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
serv.handlePacket(packet)
var write mockPacketConnWrite
Eventually(conn.dataWritten).Should(Receive(&write))
Expect(write.to.String()).To(Equal("127.0.0.1:1337"))
replyHdr := parseHeader(write.data)
Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry))
Expect(replyHdr.SrcConnectionID).ToNot(Equal(hdr.DestConnectionID))
Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID))
Expect(replyHdr.OrigDestConnectionID).To(Equal(hdr.DestConnectionID))
Expect(replyHdr.Token).ToNot(BeEmpty())
})
It("creates a session, if no Token is required", func() {
serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true }
hdr := &wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeInitial,
SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1},
DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
Version: protocol.VersionTLS,
}
p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
run := make(chan struct{})
ready := make(chan struct{})
serv.newSession = func(
_ connection,
_ sessionRunner,
origConnID protocol.ConnectionID,
destConnID protocol.ConnectionID,
srcConnID protocol.ConnectionID,
runner sessionRunner,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ *Config,
_ *tls.Config,
_ *handshake.TransportParameters,
@ -295,29 +554,15 @@ var _ = Describe("Server", func() {
_ utils.Logger,
_ protocol.VersionNumber,
) (quicSession, error) {
Expect(origConnID).To(Equal(hdr.DestConnectionID))
Expect(destConnID).To(Equal(hdr.SrcConnectionID))
// make sure we're using a server-generated connection ID
Expect(srcConnID).ToNot(Equal(hdr.DestConnectionID))
Expect(srcConnID).ToNot(Equal(hdr.SrcConnectionID))
sess := NewMockQuicSession(mockCtrl)
sess.EXPECT().handlePacket(p)
sess.EXPECT().run().Do(func() { close(run) })
sess.EXPECT().run().Do(func() {})
sess.EXPECT().earlySessionReady().Return(ready)
sess.EXPECT().Context().Return(context.Background())
sess.EXPECT().HandshakeComplete().Return(context.Background())
return sess, nil
}
done := make(chan struct{})
go func() {
defer GinkgoRecover()
serv.handlePacket(p)
// the Handshake packet is written by the session
Consistently(conn.dataWritten).ShouldNot(Receive())
close(done)
}()
// make sure we're using a server-generated connection ID
Eventually(run).Should(BeClosed())
_, err := serv.createNewSession(&net.UDPAddr{}, nil, nil, nil, nil, protocol.VersionWhatever)
Expect(err).ToNot(HaveOccurred())
Consistently(done).ShouldNot(BeClosed())
close(ready)
Eventually(done).Should(BeClosed())
})
@ -347,13 +592,13 @@ var _ = Describe("Server", func() {
_ utils.Logger,
_ protocol.VersionNumber,
) (quicSession, error) {
ready := make(chan struct{})
close(ready)
sess := NewMockQuicSession(mockCtrl)
sess.EXPECT().handlePacket(p)
sess.EXPECT().run()
sess.EXPECT().earlySessionReady().Return(ready)
sess.EXPECT().Context().Return(context.Background())
ctx, cancel := context.WithCancel(context.Background())
cancel()
sess.EXPECT().HandshakeComplete().Return(ctx)
return sess, nil
}
@ -410,10 +655,8 @@ var _ = Describe("Server", func() {
) (quicSession, error) {
sess.EXPECT().handlePacket(p)
sess.EXPECT().run()
sess.EXPECT().earlySessionReady()
sess.EXPECT().Context().Return(ctx)
ctx, cancel := context.WithCancel(context.Background())
cancel()
sess.EXPECT().HandshakeComplete().Return(ctx)
close(sessionCreated)
return sess, nil
}
@ -438,93 +681,6 @@ var _ = Describe("Server", func() {
Eventually(done).Should(BeClosed())
})
})
Context("accepting sessions", func() {
var serv *server
BeforeEach(func() {
ln, err := Listen(conn, tlsConf, nil)
Expect(err).ToNot(HaveOccurred())
serv = ln.(*server)
})
It("returns Accept when an error occurs", func() {
testErr := errors.New("test err")
done := make(chan struct{})
go func() {
defer GinkgoRecover()
_, err := serv.Accept(context.Background())
Expect(err).To(MatchError(testErr))
close(done)
}()
serv.setCloseError(testErr)
Eventually(done).Should(BeClosed())
})
It("returns immediately, if an error occurred before", func() {
testErr := errors.New("test err")
serv.setCloseError(testErr)
for i := 0; i < 3; i++ {
_, err := serv.Accept(context.Background())
Expect(err).To(MatchError(testErr))
}
})
It("returns when the context is canceled", func() {
ctx, cancel := context.WithCancel(context.Background())
done := make(chan struct{})
go func() {
defer GinkgoRecover()
_, err := serv.Accept(ctx)
Expect(err).To(MatchError("context canceled"))
close(done)
}()
Consistently(done).ShouldNot(BeClosed())
cancel()
Eventually(done).Should(BeClosed())
})
It("accepts new sessions when the handshake completes", func() {
sess := NewMockQuicSession(mockCtrl)
done := make(chan struct{})
go func() {
defer GinkgoRecover()
s, err := serv.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
Expect(s).To(Equal(sess))
close(done)
}()
ctx, cancel := context.WithCancel(context.Background()) // handshake context
serv.newSession = func(
_ connection,
runner sessionRunner,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ *Config,
_ *tls.Config,
_ *handshake.TransportParameters,
_ *handshake.TokenGenerator,
_ utils.Logger,
_ protocol.VersionNumber,
) (quicSession, error) {
sess.EXPECT().HandshakeComplete().Return(ctx)
sess.EXPECT().run().Do(func() {})
sess.EXPECT().Context().Return(context.Background())
return sess, nil
}
_, err := serv.createNewSession(&net.UDPAddr{}, nil, nil, nil, nil, protocol.VersionWhatever)
Expect(err).ToNot(HaveOccurred())
Consistently(done).ShouldNot(BeClosed())
cancel() // complete the handshake
Eventually(done).Should(BeClosed())
})
})
})
var _ = Describe("default source address verification", func() {

View file

@ -171,6 +171,7 @@ type session struct {
}
var _ Session = &session{}
var _ EarlySession = &session{}
var _ streamSender = &session{}
var newSession = func(