mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-03 20:27:35 +03:00
implement DialNonFWSecure for the client
This commit is contained in:
parent
e6aeb143a7
commit
2bfa7e59cb
4 changed files with 67 additions and 38 deletions
66
client.go
66
client.go
|
@ -3,6 +3,7 @@ package quic
|
|||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
|
@ -17,11 +18,11 @@ type client struct {
|
|||
mutex sync.Mutex
|
||||
listenErr error
|
||||
|
||||
conn connection
|
||||
hostname string
|
||||
errorChan chan struct{}
|
||||
conn connection
|
||||
hostname string
|
||||
|
||||
handshakeChan chan struct{} // is closed as soon as the handshake completes
|
||||
errorChan chan struct{}
|
||||
handshakeChan <-chan handshakeEvent
|
||||
|
||||
config *Config
|
||||
versionNegotiated bool // has version negotiation completed yet
|
||||
|
@ -50,9 +51,9 @@ func DialAddr(addr string, config *Config) (Session, error) {
|
|||
return Dial(udpConn, udpAddr, addr, config)
|
||||
}
|
||||
|
||||
// Dial establishes a new QUIC connection to a server using a net.PacketConn.
|
||||
// DialNonFWSecure establishes a new non-forward-secure QUIC connection to a server using a net.PacketConn.
|
||||
// The host parameter is used for SNI.
|
||||
func Dial(pconn net.PacketConn, remoteAddr net.Addr, host string, config *Config) (Session, error) {
|
||||
func DialNonFWSecure(pconn net.PacketConn, remoteAddr net.Addr, host string, config *Config) (NonFWSession, error) {
|
||||
connID, err := utils.GenerateConnectionID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -65,13 +66,12 @@ func Dial(pconn net.PacketConn, remoteAddr net.Addr, host string, config *Config
|
|||
|
||||
clientConfig := populateClientConfig(config)
|
||||
c := &client{
|
||||
conn: &conn{pconn: pconn, currentAddr: remoteAddr},
|
||||
connectionID: connID,
|
||||
hostname: hostname,
|
||||
config: clientConfig,
|
||||
version: clientConfig.Versions[0],
|
||||
errorChan: make(chan struct{}),
|
||||
handshakeChan: make(chan struct{}),
|
||||
conn: &conn{pconn: pconn, currentAddr: remoteAddr},
|
||||
connectionID: connID,
|
||||
hostname: hostname,
|
||||
config: clientConfig,
|
||||
version: clientConfig.Versions[0],
|
||||
errorChan: make(chan struct{}),
|
||||
}
|
||||
|
||||
err = c.createNewSession(nil)
|
||||
|
@ -81,7 +81,21 @@ func Dial(pconn net.PacketConn, remoteAddr net.Addr, host string, config *Config
|
|||
|
||||
utils.Infof("Starting new connection to %s (%s), connectionID %x, version %d", hostname, c.conn.RemoteAddr().String(), c.connectionID, c.version)
|
||||
|
||||
return c.establishConnection()
|
||||
return c.session.(NonFWSession), c.establishSecureConnection()
|
||||
}
|
||||
|
||||
// Dial establishes a new QUIC connection to a server using a net.PacketConn.
|
||||
// The host parameter is used for SNI.
|
||||
func Dial(pconn net.PacketConn, remoteAddr net.Addr, host string, config *Config) (Session, error) {
|
||||
sess, err := DialNonFWSecure(pconn, remoteAddr, host, config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = sess.WaitUntilHandshakeComplete()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return sess, nil
|
||||
}
|
||||
|
||||
func populateClientConfig(config *Config) *Config {
|
||||
|
@ -97,14 +111,21 @@ func populateClientConfig(config *Config) *Config {
|
|||
}
|
||||
}
|
||||
|
||||
func (c *client) establishConnection() (Session, error) {
|
||||
// establishSecureConnection returns as soon as the connection is secure (as opposed to forward-secure)
|
||||
func (c *client) establishSecureConnection() error {
|
||||
go c.listen()
|
||||
|
||||
select {
|
||||
case <-c.errorChan:
|
||||
return nil, c.listenErr
|
||||
case <-c.handshakeChan:
|
||||
return c.session, nil
|
||||
return c.listenErr
|
||||
case ev := <-c.handshakeChan:
|
||||
if ev.err != nil {
|
||||
return ev.err
|
||||
}
|
||||
if ev.encLevel != protocol.EncryptionSecure {
|
||||
return fmt.Errorf("Client BUG: Expected encryption level to be secure, was %s", ev.encLevel)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -204,20 +225,13 @@ func (c *client) handlePacketWithVersionFlag(hdr *PublicHeader) error {
|
|||
return c.createNewSession(hdr.SupportedVersions)
|
||||
}
|
||||
|
||||
func (c *client) cryptoChangeCallback(_ Session, isForwardSecure bool) {
|
||||
if isForwardSecure {
|
||||
close(c.handshakeChan)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *client) createNewSession(negotiatedVersions []protocol.VersionNumber) error {
|
||||
var err error
|
||||
c.session, err = newClientSession(
|
||||
c.session, c.handshakeChan, err = newClientSession(
|
||||
c.conn,
|
||||
c.hostname,
|
||||
c.version,
|
||||
c.connectionID,
|
||||
c.cryptoChangeCallback,
|
||||
c.config,
|
||||
negotiatedVersions,
|
||||
)
|
||||
|
|
|
@ -38,7 +38,7 @@ var _ = Describe("Client", func() {
|
|||
version: protocol.SupportedVersions[0],
|
||||
conn: &conn{pconn: packetConn, currentAddr: addr},
|
||||
errorChan: make(chan struct{}),
|
||||
handshakeChan: make(chan struct{}),
|
||||
handshakeChan: make(chan handshakeEvent),
|
||||
}
|
||||
})
|
||||
|
||||
|
|
31
session.go
31
session.go
|
@ -38,6 +38,11 @@ var (
|
|||
// Once the callback has been called with isForwardSecure = true, it is guarantueed to not be called with isForwardSecure = false after that
|
||||
type cryptoChangeCallback func(session Session, isForwardSecure bool)
|
||||
|
||||
type handshakeEvent struct {
|
||||
encLevel protocol.EncryptionLevel
|
||||
err error
|
||||
}
|
||||
|
||||
type closeError struct {
|
||||
err error
|
||||
remote bool
|
||||
|
@ -88,6 +93,9 @@ type session struct {
|
|||
// will be closed as soon as the handshake completes, and receive any error that might occur until then
|
||||
// it is used to block WaitUntilHandshakeComplete()
|
||||
handshakeCompleteChan chan error
|
||||
// handshakeChan receives handshake events and is closed as soon the handshake completes
|
||||
// the receiving end of this channel is passed to the creator of the session
|
||||
handshakeChan chan<- handshakeEvent
|
||||
|
||||
nextAckScheduledTime time.Time
|
||||
|
||||
|
@ -139,6 +147,8 @@ func newSession(
|
|||
}
|
||||
aeadChanged := make(chan protocol.EncryptionLevel, 2)
|
||||
s.aeadChanged = aeadChanged
|
||||
handshakeChan := make(chan handshakeEvent, 1)
|
||||
s.handshakeChan = handshakeChan
|
||||
var err error
|
||||
s.cryptoSetup, err = handshake.NewCryptoSetup(connectionID, sourceAddr, v, sCfg, cryptoStream, s.connectionParameters, config.Versions, aeadChanged)
|
||||
if err != nil {
|
||||
|
@ -156,10 +166,9 @@ func newClientSession(
|
|||
hostname string,
|
||||
v protocol.VersionNumber,
|
||||
connectionID protocol.ConnectionID,
|
||||
cryptoChangeCallback cryptoChangeCallback,
|
||||
config *Config,
|
||||
negotiatedVersions []protocol.VersionNumber,
|
||||
) (*session, error) {
|
||||
) (*session, <-chan handshakeEvent, error) {
|
||||
s := &session{
|
||||
conn: conn,
|
||||
connectionID: connectionID,
|
||||
|
@ -167,7 +176,6 @@ func newClientSession(
|
|||
version: v,
|
||||
config: config,
|
||||
|
||||
cryptoChangeCallback: cryptoChangeCallback,
|
||||
connectionParameters: handshake.NewConnectionParamatersManager(protocol.PerspectiveClient, v),
|
||||
}
|
||||
|
||||
|
@ -176,6 +184,8 @@ func newClientSession(
|
|||
|
||||
aeadChanged := make(chan protocol.EncryptionLevel, 2)
|
||||
s.aeadChanged = aeadChanged
|
||||
handshakeChan := make(chan handshakeEvent, 1)
|
||||
s.handshakeChan = handshakeChan
|
||||
cryptoStream, _ := s.OpenStream()
|
||||
var err error
|
||||
s.cryptoSetup, err = handshake.NewCryptoSetupClient(
|
||||
|
@ -190,13 +200,13 @@ func newClientSession(
|
|||
negotiatedVersions,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
s.packer = newPacketPacker(connectionID, s.cryptoSetup, s.connectionParameters, s.streamFramer, s.perspective, s.version)
|
||||
s.unpacker = &packetUnpacker{aead: s.cryptoSetup, version: s.version}
|
||||
|
||||
return s, err
|
||||
return s, handshakeChan, err
|
||||
}
|
||||
|
||||
// setup is called from newSession and newClientSession and initializes values that are independent of the perspective
|
||||
|
@ -275,16 +285,22 @@ runLoop:
|
|||
// begins with the public header and we never copy it.
|
||||
putPacketBuffer(p.publicHeader.Raw)
|
||||
case l, ok := <-aeadChanged:
|
||||
if !ok {
|
||||
if !ok { // the aeadChanged chan was closed. This means that the handshake is completed.
|
||||
s.handshakeComplete = true
|
||||
aeadChanged = nil // prevent this case from ever being selected again
|
||||
close(s.handshakeChan)
|
||||
close(s.handshakeCompleteChan)
|
||||
} else {
|
||||
if l == protocol.EncryptionForwardSecure {
|
||||
s.packer.SetForwardSecure()
|
||||
}
|
||||
s.tryDecryptingQueuedPackets()
|
||||
s.cryptoChangeCallback(s, l == protocol.EncryptionForwardSecure)
|
||||
// TODO: remove this, when removing the cryptoChangeCallback for the server
|
||||
if s.perspective == protocol.PerspectiveServer {
|
||||
s.cryptoChangeCallback(s, l == protocol.EncryptionForwardSecure)
|
||||
} else {
|
||||
s.handshakeChan <- handshakeEvent{encLevel: l}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -314,6 +330,7 @@ runLoop:
|
|||
// otherwise this chan will already be closed
|
||||
if !s.handshakeComplete {
|
||||
s.handshakeCompleteChan <- closeErr.err
|
||||
s.handshakeChan <- handshakeEvent{err: closeErr.err}
|
||||
}
|
||||
s.handleCloseError(closeErr)
|
||||
close(s.runClosed)
|
||||
|
|
|
@ -162,12 +162,11 @@ var _ = Describe("Session", func() {
|
|||
cpm = &mockConnectionParametersManager{idleTime: 60 * time.Second}
|
||||
sess.connectionParameters = cpm
|
||||
|
||||
clientSess, err = newClientSession(
|
||||
clientSess, _, err = newClientSession(
|
||||
mconn,
|
||||
"hostname",
|
||||
protocol.Version35,
|
||||
0,
|
||||
func(Session, bool) {},
|
||||
populateClientConfig(&Config{}),
|
||||
nil,
|
||||
)
|
||||
|
@ -817,12 +816,11 @@ var _ = Describe("Session", func() {
|
|||
})
|
||||
|
||||
It("passes the transport parameters to the cryptoSetup, as a client", func() {
|
||||
s, err := newClientSession(
|
||||
s, _, err := newClientSession(
|
||||
nil,
|
||||
"hostname",
|
||||
protocol.Version35,
|
||||
0,
|
||||
func(Session, bool) {},
|
||||
populateClientConfig(&Config{RequestConnectionIDTruncation: true}),
|
||||
nil,
|
||||
)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue