implement DialNonFWSecure for the client

This commit is contained in:
Marten Seemann 2017-05-09 08:59:29 +08:00
parent e6aeb143a7
commit 2bfa7e59cb
No known key found for this signature in database
GPG key ID: 3603F40B121FCDEA
4 changed files with 67 additions and 38 deletions

View file

@ -3,6 +3,7 @@ package quic
import ( import (
"bytes" "bytes"
"errors" "errors"
"fmt"
"net" "net"
"strings" "strings"
"sync" "sync"
@ -17,11 +18,11 @@ type client struct {
mutex sync.Mutex mutex sync.Mutex
listenErr error listenErr error
conn connection conn connection
hostname string hostname string
errorChan chan struct{}
handshakeChan chan struct{} // is closed as soon as the handshake completes errorChan chan struct{}
handshakeChan <-chan handshakeEvent
config *Config config *Config
versionNegotiated bool // has version negotiation completed yet versionNegotiated bool // has version negotiation completed yet
@ -50,9 +51,9 @@ func DialAddr(addr string, config *Config) (Session, error) {
return Dial(udpConn, udpAddr, addr, config) return Dial(udpConn, udpAddr, addr, config)
} }
// Dial establishes a new QUIC connection to a server using a net.PacketConn. // DialNonFWSecure establishes a new non-forward-secure QUIC connection to a server using a net.PacketConn.
// The host parameter is used for SNI. // The host parameter is used for SNI.
func Dial(pconn net.PacketConn, remoteAddr net.Addr, host string, config *Config) (Session, error) { func DialNonFWSecure(pconn net.PacketConn, remoteAddr net.Addr, host string, config *Config) (NonFWSession, error) {
connID, err := utils.GenerateConnectionID() connID, err := utils.GenerateConnectionID()
if err != nil { if err != nil {
return nil, err return nil, err
@ -65,13 +66,12 @@ func Dial(pconn net.PacketConn, remoteAddr net.Addr, host string, config *Config
clientConfig := populateClientConfig(config) clientConfig := populateClientConfig(config)
c := &client{ c := &client{
conn: &conn{pconn: pconn, currentAddr: remoteAddr}, conn: &conn{pconn: pconn, currentAddr: remoteAddr},
connectionID: connID, connectionID: connID,
hostname: hostname, hostname: hostname,
config: clientConfig, config: clientConfig,
version: clientConfig.Versions[0], version: clientConfig.Versions[0],
errorChan: make(chan struct{}), errorChan: make(chan struct{}),
handshakeChan: make(chan struct{}),
} }
err = c.createNewSession(nil) err = c.createNewSession(nil)
@ -81,7 +81,21 @@ func Dial(pconn net.PacketConn, remoteAddr net.Addr, host string, config *Config
utils.Infof("Starting new connection to %s (%s), connectionID %x, version %d", hostname, c.conn.RemoteAddr().String(), c.connectionID, c.version) utils.Infof("Starting new connection to %s (%s), connectionID %x, version %d", hostname, c.conn.RemoteAddr().String(), c.connectionID, c.version)
return c.establishConnection() return c.session.(NonFWSession), c.establishSecureConnection()
}
// Dial establishes a new QUIC connection to a server using a net.PacketConn.
// The host parameter is used for SNI.
func Dial(pconn net.PacketConn, remoteAddr net.Addr, host string, config *Config) (Session, error) {
sess, err := DialNonFWSecure(pconn, remoteAddr, host, config)
if err != nil {
return nil, err
}
err = sess.WaitUntilHandshakeComplete()
if err != nil {
return nil, err
}
return sess, nil
} }
func populateClientConfig(config *Config) *Config { func populateClientConfig(config *Config) *Config {
@ -97,14 +111,21 @@ func populateClientConfig(config *Config) *Config {
} }
} }
func (c *client) establishConnection() (Session, error) { // establishSecureConnection returns as soon as the connection is secure (as opposed to forward-secure)
func (c *client) establishSecureConnection() error {
go c.listen() go c.listen()
select { select {
case <-c.errorChan: case <-c.errorChan:
return nil, c.listenErr return c.listenErr
case <-c.handshakeChan: case ev := <-c.handshakeChan:
return c.session, nil if ev.err != nil {
return ev.err
}
if ev.encLevel != protocol.EncryptionSecure {
return fmt.Errorf("Client BUG: Expected encryption level to be secure, was %s", ev.encLevel)
}
return nil
} }
} }
@ -204,20 +225,13 @@ func (c *client) handlePacketWithVersionFlag(hdr *PublicHeader) error {
return c.createNewSession(hdr.SupportedVersions) return c.createNewSession(hdr.SupportedVersions)
} }
func (c *client) cryptoChangeCallback(_ Session, isForwardSecure bool) {
if isForwardSecure {
close(c.handshakeChan)
}
}
func (c *client) createNewSession(negotiatedVersions []protocol.VersionNumber) error { func (c *client) createNewSession(negotiatedVersions []protocol.VersionNumber) error {
var err error var err error
c.session, err = newClientSession( c.session, c.handshakeChan, err = newClientSession(
c.conn, c.conn,
c.hostname, c.hostname,
c.version, c.version,
c.connectionID, c.connectionID,
c.cryptoChangeCallback,
c.config, c.config,
negotiatedVersions, negotiatedVersions,
) )

View file

@ -38,7 +38,7 @@ var _ = Describe("Client", func() {
version: protocol.SupportedVersions[0], version: protocol.SupportedVersions[0],
conn: &conn{pconn: packetConn, currentAddr: addr}, conn: &conn{pconn: packetConn, currentAddr: addr},
errorChan: make(chan struct{}), errorChan: make(chan struct{}),
handshakeChan: make(chan struct{}), handshakeChan: make(chan handshakeEvent),
} }
}) })

View file

@ -38,6 +38,11 @@ var (
// Once the callback has been called with isForwardSecure = true, it is guarantueed to not be called with isForwardSecure = false after that // Once the callback has been called with isForwardSecure = true, it is guarantueed to not be called with isForwardSecure = false after that
type cryptoChangeCallback func(session Session, isForwardSecure bool) type cryptoChangeCallback func(session Session, isForwardSecure bool)
type handshakeEvent struct {
encLevel protocol.EncryptionLevel
err error
}
type closeError struct { type closeError struct {
err error err error
remote bool remote bool
@ -88,6 +93,9 @@ type session struct {
// will be closed as soon as the handshake completes, and receive any error that might occur until then // will be closed as soon as the handshake completes, and receive any error that might occur until then
// it is used to block WaitUntilHandshakeComplete() // it is used to block WaitUntilHandshakeComplete()
handshakeCompleteChan chan error handshakeCompleteChan chan error
// handshakeChan receives handshake events and is closed as soon the handshake completes
// the receiving end of this channel is passed to the creator of the session
handshakeChan chan<- handshakeEvent
nextAckScheduledTime time.Time nextAckScheduledTime time.Time
@ -139,6 +147,8 @@ func newSession(
} }
aeadChanged := make(chan protocol.EncryptionLevel, 2) aeadChanged := make(chan protocol.EncryptionLevel, 2)
s.aeadChanged = aeadChanged s.aeadChanged = aeadChanged
handshakeChan := make(chan handshakeEvent, 1)
s.handshakeChan = handshakeChan
var err error var err error
s.cryptoSetup, err = handshake.NewCryptoSetup(connectionID, sourceAddr, v, sCfg, cryptoStream, s.connectionParameters, config.Versions, aeadChanged) s.cryptoSetup, err = handshake.NewCryptoSetup(connectionID, sourceAddr, v, sCfg, cryptoStream, s.connectionParameters, config.Versions, aeadChanged)
if err != nil { if err != nil {
@ -156,10 +166,9 @@ func newClientSession(
hostname string, hostname string,
v protocol.VersionNumber, v protocol.VersionNumber,
connectionID protocol.ConnectionID, connectionID protocol.ConnectionID,
cryptoChangeCallback cryptoChangeCallback,
config *Config, config *Config,
negotiatedVersions []protocol.VersionNumber, negotiatedVersions []protocol.VersionNumber,
) (*session, error) { ) (*session, <-chan handshakeEvent, error) {
s := &session{ s := &session{
conn: conn, conn: conn,
connectionID: connectionID, connectionID: connectionID,
@ -167,7 +176,6 @@ func newClientSession(
version: v, version: v,
config: config, config: config,
cryptoChangeCallback: cryptoChangeCallback,
connectionParameters: handshake.NewConnectionParamatersManager(protocol.PerspectiveClient, v), connectionParameters: handshake.NewConnectionParamatersManager(protocol.PerspectiveClient, v),
} }
@ -176,6 +184,8 @@ func newClientSession(
aeadChanged := make(chan protocol.EncryptionLevel, 2) aeadChanged := make(chan protocol.EncryptionLevel, 2)
s.aeadChanged = aeadChanged s.aeadChanged = aeadChanged
handshakeChan := make(chan handshakeEvent, 1)
s.handshakeChan = handshakeChan
cryptoStream, _ := s.OpenStream() cryptoStream, _ := s.OpenStream()
var err error var err error
s.cryptoSetup, err = handshake.NewCryptoSetupClient( s.cryptoSetup, err = handshake.NewCryptoSetupClient(
@ -190,13 +200,13 @@ func newClientSession(
negotiatedVersions, negotiatedVersions,
) )
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
s.packer = newPacketPacker(connectionID, s.cryptoSetup, s.connectionParameters, s.streamFramer, s.perspective, s.version) s.packer = newPacketPacker(connectionID, s.cryptoSetup, s.connectionParameters, s.streamFramer, s.perspective, s.version)
s.unpacker = &packetUnpacker{aead: s.cryptoSetup, version: s.version} s.unpacker = &packetUnpacker{aead: s.cryptoSetup, version: s.version}
return s, err return s, handshakeChan, err
} }
// setup is called from newSession and newClientSession and initializes values that are independent of the perspective // setup is called from newSession and newClientSession and initializes values that are independent of the perspective
@ -275,16 +285,22 @@ runLoop:
// begins with the public header and we never copy it. // begins with the public header and we never copy it.
putPacketBuffer(p.publicHeader.Raw) putPacketBuffer(p.publicHeader.Raw)
case l, ok := <-aeadChanged: case l, ok := <-aeadChanged:
if !ok { if !ok { // the aeadChanged chan was closed. This means that the handshake is completed.
s.handshakeComplete = true s.handshakeComplete = true
aeadChanged = nil // prevent this case from ever being selected again aeadChanged = nil // prevent this case from ever being selected again
close(s.handshakeChan)
close(s.handshakeCompleteChan) close(s.handshakeCompleteChan)
} else { } else {
if l == protocol.EncryptionForwardSecure { if l == protocol.EncryptionForwardSecure {
s.packer.SetForwardSecure() s.packer.SetForwardSecure()
} }
s.tryDecryptingQueuedPackets() s.tryDecryptingQueuedPackets()
s.cryptoChangeCallback(s, l == protocol.EncryptionForwardSecure) // TODO: remove this, when removing the cryptoChangeCallback for the server
if s.perspective == protocol.PerspectiveServer {
s.cryptoChangeCallback(s, l == protocol.EncryptionForwardSecure)
} else {
s.handshakeChan <- handshakeEvent{encLevel: l}
}
} }
} }
@ -314,6 +330,7 @@ runLoop:
// otherwise this chan will already be closed // otherwise this chan will already be closed
if !s.handshakeComplete { if !s.handshakeComplete {
s.handshakeCompleteChan <- closeErr.err s.handshakeCompleteChan <- closeErr.err
s.handshakeChan <- handshakeEvent{err: closeErr.err}
} }
s.handleCloseError(closeErr) s.handleCloseError(closeErr)
close(s.runClosed) close(s.runClosed)

View file

@ -162,12 +162,11 @@ var _ = Describe("Session", func() {
cpm = &mockConnectionParametersManager{idleTime: 60 * time.Second} cpm = &mockConnectionParametersManager{idleTime: 60 * time.Second}
sess.connectionParameters = cpm sess.connectionParameters = cpm
clientSess, err = newClientSession( clientSess, _, err = newClientSession(
mconn, mconn,
"hostname", "hostname",
protocol.Version35, protocol.Version35,
0, 0,
func(Session, bool) {},
populateClientConfig(&Config{}), populateClientConfig(&Config{}),
nil, nil,
) )
@ -817,12 +816,11 @@ var _ = Describe("Session", func() {
}) })
It("passes the transport parameters to the cryptoSetup, as a client", func() { It("passes the transport parameters to the cryptoSetup, as a client", func() {
s, err := newClientSession( s, _, err := newClientSession(
nil, nil,
"hostname", "hostname",
protocol.Version35, protocol.Version35,
0, 0,
func(Session, bool) {},
populateClientConfig(&Config{RequestConnectionIDTruncation: true}), populateClientConfig(&Config{RequestConnectionIDTruncation: true}),
nil, nil,
) )