implement stateless handling of Initial packets for the TLS server

This commit is contained in:
Marten Seemann 2017-11-11 08:38:27 +08:00
parent 57c6f3ceb5
commit 25a6dc9654
36 changed files with 1617 additions and 724 deletions

106
server.go
View file

@ -19,6 +19,7 @@ import (
// packetHandler handles packets
type packetHandler interface {
Session
getCryptoStream() cryptoStream
handshakeStatus() <-chan handshakeEvent
handlePacket(*receivedPacket)
GetVersion() protocol.VersionNumber
@ -33,6 +34,9 @@ type server struct {
conn net.PacketConn
supportsTLS bool
serverTLS *serverTLS
certChain crypto.CertChain
scfg *handshake.ServerConfig
@ -77,11 +81,21 @@ func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener,
if err != nil {
return nil, err
}
config = populateServerConfig(config)
// check if any of the supported versions supports TLS
var supportsTLS bool
for _, v := range config.Versions {
if v.UsesTLS() {
supportsTLS = true
break
}
}
s := &server{
conn: conn,
tlsConf: tlsConf,
config: populateServerConfig(config),
config: config,
certChain: certChain,
scfg: scfg,
sessions: map[protocol.ConnectionID]packetHandler{},
@ -89,12 +103,47 @@ func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener,
deleteClosedSessionsAfter: protocol.ClosedSessionDeleteTimeout,
sessionQueue: make(chan Session, 5),
errorChan: make(chan struct{}),
supportsTLS: supportsTLS,
}
if supportsTLS {
if err := s.setupTLS(); err != nil {
return nil, err
}
}
go s.serve()
utils.Debugf("Listening for %s connections on %s", conn.LocalAddr().Network(), conn.LocalAddr().String())
return s, nil
}
func (s *server) setupTLS() error {
cookieHandler, err := handshake.NewCookieHandler(s.config.AcceptCookie)
if err != nil {
return err
}
serverTLS, sessionChan, err := newServerTLS(s.conn, s.config, cookieHandler, s.tlsConf)
if err != nil {
return err
}
s.serverTLS = serverTLS
// handle TLS connection establishment statelessly
go func() {
for {
select {
case <-s.errorChan:
return
case sess := <-sessionChan:
// TODO: think about what to do with connection ID collisions
connID := sess.(*session).connectionID
s.sessionsMutex.Lock()
s.sessions[connID] = sess
s.sessionsMutex.Unlock()
s.runHandshakeAndSession(sess, connID)
}
}
}()
return nil
}
var defaultAcceptCookie = func(clientAddr net.Addr, cookie *Cookie) bool {
if cookie == nil {
return false
@ -225,8 +274,16 @@ func (s *server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet
return qerr.Error(qerr.InvalidPacketHeader, err.Error())
}
hdr.Raw = packet[:len(packet)-r.Len()]
packetData := packet[len(packet)-r.Len():]
connID := hdr.ConnectionID
if hdr.Type == protocol.PacketTypeInitial {
if s.supportsTLS {
go s.serverTLS.HandleInitial(remoteAddr, hdr, packetData)
}
return nil
}
s.sessionsMutex.RLock()
session, sessionKnown := s.sessions[connID]
s.sessionsMutex.RUnlock()
@ -279,11 +336,6 @@ func (s *server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet
return err
}
}
// send an IETF draft style Version Negotiation Packet, if the client sent an unsupported version with an IETF draft style header
if hdr.Type == protocol.PacketTypeInitial && !protocol.IsSupportedVersion(s.config.Versions, hdr.Version) {
_, err := pconn.WriteTo(wire.ComposeVersionNegotiation(hdr.ConnectionID, hdr.PacketNumber, s.config.Versions), remoteAddr)
return err
}
if !sessionKnown {
version := hdr.Version
@ -307,34 +359,38 @@ func (s *server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet
s.sessions[connID] = session
s.sessionsMutex.Unlock()
go func() {
// session.run() returns as soon as the session is closed
_ = session.run()
s.removeConnection(connID)
}()
go func() {
for {
ev := <-session.handshakeStatus()
if ev.err != nil {
return
}
if ev.encLevel == protocol.EncryptionForwardSecure {
break
}
}
s.sessionQueue <- session
}()
s.runHandshakeAndSession(session, connID)
}
session.handlePacket(&receivedPacket{
remoteAddr: remoteAddr,
header: hdr,
data: packet[len(packet)-r.Len():],
data: packetData,
rcvTime: rcvTime,
})
return nil
}
func (s *server) runHandshakeAndSession(session packetHandler, connID protocol.ConnectionID) {
go func() {
_ = session.run()
// session.run() returns as soon as the session is closed
s.removeConnection(connID)
}()
go func() {
for {
ev := <-session.handshakeStatus()
if ev.err != nil {
return
}
if ev.encLevel == protocol.EncryptionForwardSecure {
break
}
}
s.sessionQueue <- session
}()
}
func (s *server) removeConnection(id protocol.ConnectionID) {
s.sessionsMutex.Lock()
s.sessions[id] = nil