diff --git a/client.go b/client.go index e8e9cc87..9c4e3526 100644 --- a/client.go +++ b/client.go @@ -3,6 +3,7 @@ package quic import ( "bytes" "errors" + "fmt" "net" "strings" "sync" @@ -17,11 +18,11 @@ type client struct { mutex sync.Mutex listenErr error - conn connection - hostname string - errorChan chan struct{} + conn connection + hostname string - handshakeChan chan struct{} // is closed as soon as the handshake completes + errorChan chan struct{} + handshakeChan <-chan handshakeEvent config *Config versionNegotiated bool // has version negotiation completed yet @@ -50,9 +51,9 @@ func DialAddr(addr string, config *Config) (Session, error) { return Dial(udpConn, udpAddr, addr, config) } -// Dial establishes a new QUIC connection to a server using a net.PacketConn. +// DialNonFWSecure establishes a new non-forward-secure QUIC connection to a server using a net.PacketConn. // The host parameter is used for SNI. -func Dial(pconn net.PacketConn, remoteAddr net.Addr, host string, config *Config) (Session, error) { +func DialNonFWSecure(pconn net.PacketConn, remoteAddr net.Addr, host string, config *Config) (NonFWSession, error) { connID, err := utils.GenerateConnectionID() if err != nil { return nil, err @@ -65,13 +66,12 @@ func Dial(pconn net.PacketConn, remoteAddr net.Addr, host string, config *Config clientConfig := populateClientConfig(config) c := &client{ - conn: &conn{pconn: pconn, currentAddr: remoteAddr}, - connectionID: connID, - hostname: hostname, - config: clientConfig, - version: clientConfig.Versions[0], - errorChan: make(chan struct{}), - handshakeChan: make(chan struct{}), + conn: &conn{pconn: pconn, currentAddr: remoteAddr}, + connectionID: connID, + hostname: hostname, + config: clientConfig, + version: clientConfig.Versions[0], + errorChan: make(chan struct{}), } err = c.createNewSession(nil) @@ -81,7 +81,21 @@ func Dial(pconn net.PacketConn, remoteAddr net.Addr, host string, config *Config utils.Infof("Starting new connection to %s (%s), connectionID %x, version %d", hostname, c.conn.RemoteAddr().String(), c.connectionID, c.version) - return c.establishConnection() + return c.session.(NonFWSession), c.establishSecureConnection() +} + +// Dial establishes a new QUIC connection to a server using a net.PacketConn. +// The host parameter is used for SNI. +func Dial(pconn net.PacketConn, remoteAddr net.Addr, host string, config *Config) (Session, error) { + sess, err := DialNonFWSecure(pconn, remoteAddr, host, config) + if err != nil { + return nil, err + } + err = sess.WaitUntilHandshakeComplete() + if err != nil { + return nil, err + } + return sess, nil } func populateClientConfig(config *Config) *Config { @@ -97,14 +111,21 @@ func populateClientConfig(config *Config) *Config { } } -func (c *client) establishConnection() (Session, error) { +// establishSecureConnection returns as soon as the connection is secure (as opposed to forward-secure) +func (c *client) establishSecureConnection() error { go c.listen() select { case <-c.errorChan: - return nil, c.listenErr - case <-c.handshakeChan: - return c.session, nil + return c.listenErr + case ev := <-c.handshakeChan: + if ev.err != nil { + return ev.err + } + if ev.encLevel != protocol.EncryptionSecure { + return fmt.Errorf("Client BUG: Expected encryption level to be secure, was %s", ev.encLevel) + } + return nil } } @@ -204,20 +225,13 @@ func (c *client) handlePacketWithVersionFlag(hdr *PublicHeader) error { return c.createNewSession(hdr.SupportedVersions) } -func (c *client) cryptoChangeCallback(_ Session, isForwardSecure bool) { - if isForwardSecure { - close(c.handshakeChan) - } -} - func (c *client) createNewSession(negotiatedVersions []protocol.VersionNumber) error { var err error - c.session, err = newClientSession( + c.session, c.handshakeChan, err = newClientSession( c.conn, c.hostname, c.version, c.connectionID, - c.cryptoChangeCallback, c.config, negotiatedVersions, ) diff --git a/client_test.go b/client_test.go index 04bab6f6..de20e6eb 100644 --- a/client_test.go +++ b/client_test.go @@ -38,7 +38,7 @@ var _ = Describe("Client", func() { version: protocol.SupportedVersions[0], conn: &conn{pconn: packetConn, currentAddr: addr}, errorChan: make(chan struct{}), - handshakeChan: make(chan struct{}), + handshakeChan: make(chan handshakeEvent), } }) diff --git a/session.go b/session.go index 87ab49b5..b8f96d42 100644 --- a/session.go +++ b/session.go @@ -38,6 +38,11 @@ var ( // Once the callback has been called with isForwardSecure = true, it is guarantueed to not be called with isForwardSecure = false after that type cryptoChangeCallback func(session Session, isForwardSecure bool) +type handshakeEvent struct { + encLevel protocol.EncryptionLevel + err error +} + type closeError struct { err error remote bool @@ -88,6 +93,9 @@ type session struct { // will be closed as soon as the handshake completes, and receive any error that might occur until then // it is used to block WaitUntilHandshakeComplete() handshakeCompleteChan chan error + // handshakeChan receives handshake events and is closed as soon the handshake completes + // the receiving end of this channel is passed to the creator of the session + handshakeChan chan<- handshakeEvent nextAckScheduledTime time.Time @@ -139,6 +147,8 @@ func newSession( } aeadChanged := make(chan protocol.EncryptionLevel, 2) s.aeadChanged = aeadChanged + handshakeChan := make(chan handshakeEvent, 1) + s.handshakeChan = handshakeChan var err error s.cryptoSetup, err = handshake.NewCryptoSetup(connectionID, sourceAddr, v, sCfg, cryptoStream, s.connectionParameters, config.Versions, aeadChanged) if err != nil { @@ -156,10 +166,9 @@ func newClientSession( hostname string, v protocol.VersionNumber, connectionID protocol.ConnectionID, - cryptoChangeCallback cryptoChangeCallback, config *Config, negotiatedVersions []protocol.VersionNumber, -) (*session, error) { +) (*session, <-chan handshakeEvent, error) { s := &session{ conn: conn, connectionID: connectionID, @@ -167,7 +176,6 @@ func newClientSession( version: v, config: config, - cryptoChangeCallback: cryptoChangeCallback, connectionParameters: handshake.NewConnectionParamatersManager(protocol.PerspectiveClient, v), } @@ -176,6 +184,8 @@ func newClientSession( aeadChanged := make(chan protocol.EncryptionLevel, 2) s.aeadChanged = aeadChanged + handshakeChan := make(chan handshakeEvent, 1) + s.handshakeChan = handshakeChan cryptoStream, _ := s.OpenStream() var err error s.cryptoSetup, err = handshake.NewCryptoSetupClient( @@ -190,13 +200,13 @@ func newClientSession( negotiatedVersions, ) if err != nil { - return nil, err + return nil, nil, err } s.packer = newPacketPacker(connectionID, s.cryptoSetup, s.connectionParameters, s.streamFramer, s.perspective, s.version) s.unpacker = &packetUnpacker{aead: s.cryptoSetup, version: s.version} - return s, err + return s, handshakeChan, err } // setup is called from newSession and newClientSession and initializes values that are independent of the perspective @@ -275,16 +285,22 @@ runLoop: // begins with the public header and we never copy it. putPacketBuffer(p.publicHeader.Raw) case l, ok := <-aeadChanged: - if !ok { + if !ok { // the aeadChanged chan was closed. This means that the handshake is completed. s.handshakeComplete = true aeadChanged = nil // prevent this case from ever being selected again + close(s.handshakeChan) close(s.handshakeCompleteChan) } else { if l == protocol.EncryptionForwardSecure { s.packer.SetForwardSecure() } s.tryDecryptingQueuedPackets() - s.cryptoChangeCallback(s, l == protocol.EncryptionForwardSecure) + // TODO: remove this, when removing the cryptoChangeCallback for the server + if s.perspective == protocol.PerspectiveServer { + s.cryptoChangeCallback(s, l == protocol.EncryptionForwardSecure) + } else { + s.handshakeChan <- handshakeEvent{encLevel: l} + } } } @@ -314,6 +330,7 @@ runLoop: // otherwise this chan will already be closed if !s.handshakeComplete { s.handshakeCompleteChan <- closeErr.err + s.handshakeChan <- handshakeEvent{err: closeErr.err} } s.handleCloseError(closeErr) close(s.runClosed) diff --git a/session_test.go b/session_test.go index 97bd04a3..dbb52a74 100644 --- a/session_test.go +++ b/session_test.go @@ -162,12 +162,11 @@ var _ = Describe("Session", func() { cpm = &mockConnectionParametersManager{idleTime: 60 * time.Second} sess.connectionParameters = cpm - clientSess, err = newClientSession( + clientSess, _, err = newClientSession( mconn, "hostname", protocol.Version35, 0, - func(Session, bool) {}, populateClientConfig(&Config{}), nil, ) @@ -817,12 +816,11 @@ var _ = Describe("Session", func() { }) It("passes the transport parameters to the cryptoSetup, as a client", func() { - s, err := newClientSession( + s, _, err := newClientSession( nil, "hostname", protocol.Version35, 0, - func(Session, bool) {}, populateClientConfig(&Config{RequestConnectionIDTruncation: true}), nil, )