diff --git a/common.go b/common.go index eec6e1e..5b68742 100644 --- a/common.go +++ b/common.go @@ -7,6 +7,7 @@ package tls import ( "bytes" "container/list" + "context" "crypto" "crypto/ecdsa" "crypto/ed25519" @@ -443,6 +444,16 @@ type ClientHelloInfo struct { // config is embedded by the GetCertificate or GetConfigForClient caller, // for use with SupportsCertificate. config *Config + + // ctx is the context of the handshake that is in progress. + ctx context.Context +} + +// Context returns the context of the handshake that is in progress. +// This context is a child of the context passed to HandshakeContext, +// if any, and is canceled when the handshake concludes. +func (c *ClientHelloInfo) Context() context.Context { + return c.ctx } // CertificateRequestInfo contains information from a server's @@ -461,6 +472,16 @@ type CertificateRequestInfo struct { // Version is the TLS version that was negotiated for this connection. Version uint16 + + // ctx is the context of the handshake that is in progress. + ctx context.Context +} + +// Context returns the context of the handshake that is in progress. +// This context is a child of the context passed to HandshakeContext, +// if any, and is canceled when the handshake concludes. +func (c *CertificateRequestInfo) Context() context.Context { + return c.ctx } // RenegotiationSupport enumerates the different levels of support for TLS diff --git a/conn.go b/conn.go index 72ad52c..969f357 100644 --- a/conn.go +++ b/conn.go @@ -8,6 +8,7 @@ package tls import ( "bytes" + "context" "crypto/cipher" "crypto/subtle" "crypto/x509" @@ -27,7 +28,7 @@ type Conn struct { // constant conn net.Conn isClient bool - handshakeFn func() error // (*Conn).clientHandshake or serverHandshake + handshakeFn func(context.Context) error // (*Conn).clientHandshake or serverHandshake // handshakeStatus is 1 if the connection is currently transferring // application data (i.e. is not currently processing a handshake). @@ -1190,7 +1191,7 @@ func (c *Conn) handleRenegotiation() error { defer c.handshakeMutex.Unlock() atomic.StoreUint32(&c.handshakeStatus, 0) - if c.handshakeErr = c.clientHandshake(); c.handshakeErr == nil { + if c.handshakeErr = c.clientHandshake(context.Background()); c.handshakeErr == nil { c.handshakes++ } return c.handshakeErr @@ -1373,8 +1374,61 @@ func (c *Conn) closeNotify() error { // first Read or Write will call it automatically. // // For control over canceling or setting a timeout on a handshake, use -// the Dialer's DialContext method. +// HandshakeContext or the Dialer's DialContext method instead. func (c *Conn) Handshake() error { + return c.HandshakeContext(context.Background()) +} + +// HandshakeContext runs the client or server handshake +// protocol if it has not yet been run. +// +// The provided Context must be non-nil. If the context is canceled before +// the handshake is complete, the handshake is interrupted and an error is returned. +// Once the handshake has completed, cancellation of the context will not affect the +// connection. +// +// Most uses of this package need not call HandshakeContext explicitly: the +// first Read or Write will call it automatically. +func (c *Conn) HandshakeContext(ctx context.Context) error { + // Delegate to unexported method for named return + // without confusing documented signature. + return c.handshakeContext(ctx) +} + +func (c *Conn) handshakeContext(ctx context.Context) (ret error) { + handshakeCtx, cancel := context.WithCancel(ctx) + // Note: defer this before starting the "interrupter" goroutine + // so that we can tell the difference between the input being canceled and + // this cancellation. In the former case, we need to close the connection. + defer cancel() + + // Start the "interrupter" goroutine, if this context might be canceled. + // (The background context cannot). + // + // The interrupter goroutine waits for the input context to be done and + // closes the connection if this happens before the function returns. + if ctx.Done() != nil { + done := make(chan struct{}) + interruptRes := make(chan error, 1) + defer func() { + close(done) + if ctxErr := <-interruptRes; ctxErr != nil { + // Return context error to user. + ret = ctxErr + } + }() + go func() { + select { + case <-handshakeCtx.Done(): + // Close the connection, discarding the error + _ = c.conn.Close() + interruptRes <- handshakeCtx.Err() + case <-done: + interruptRes <- nil + } + }() + } + c.handshakeMutex.Lock() defer c.handshakeMutex.Unlock() @@ -1388,7 +1442,7 @@ func (c *Conn) Handshake() error { c.in.Lock() defer c.in.Unlock() - c.handshakeErr = c.handshakeFn() + c.handshakeErr = c.handshakeFn(handshakeCtx) if c.handshakeErr == nil { c.handshakes++ } else { diff --git a/handshake_client.go b/handshake_client.go index e684b21..92e33e7 100644 --- a/handshake_client.go +++ b/handshake_client.go @@ -6,6 +6,7 @@ package tls import ( "bytes" + "context" "crypto" "crypto/ecdsa" "crypto/ed25519" @@ -24,6 +25,7 @@ import ( type clientHandshakeState struct { c *Conn + ctx context.Context serverHello *serverHelloMsg hello *clientHelloMsg suite *cipherSuite @@ -134,7 +136,7 @@ func (c *Conn) makeClientHello() (*clientHelloMsg, ecdheParameters, error) { return hello, params, nil } -func (c *Conn) clientHandshake() (err error) { +func (c *Conn) clientHandshake(ctx context.Context) (err error) { if c.config == nil { c.config = defaultConfig() } @@ -198,6 +200,7 @@ func (c *Conn) clientHandshake() (err error) { if c.vers == VersionTLS13 { hs := &clientHandshakeStateTLS13{ c: c, + ctx: ctx, serverHello: serverHello, hello: hello, ecdheParams: ecdheParams, @@ -212,6 +215,7 @@ func (c *Conn) clientHandshake() (err error) { hs := &clientHandshakeState{ c: c, + ctx: ctx, serverHello: serverHello, hello: hello, session: session, @@ -540,7 +544,7 @@ func (hs *clientHandshakeState) doFullHandshake() error { certRequested = true hs.finishedHash.Write(certReq.marshal()) - cri := certificateRequestInfoFromMsg(c.vers, certReq) + cri := certificateRequestInfoFromMsg(hs.ctx, c.vers, certReq) if chainToSend, err = c.getClientCertificate(cri); err != nil { c.sendAlert(alertInternalError) return err @@ -880,10 +884,11 @@ func (c *Conn) verifyServerCertificate(certificates [][]byte) error { // certificateRequestInfoFromMsg generates a CertificateRequestInfo from a TLS // <= 1.2 CertificateRequest, making an effort to fill in missing information. -func certificateRequestInfoFromMsg(vers uint16, certReq *certificateRequestMsg) *CertificateRequestInfo { +func certificateRequestInfoFromMsg(ctx context.Context, vers uint16, certReq *certificateRequestMsg) *CertificateRequestInfo { cri := &CertificateRequestInfo{ AcceptableCAs: certReq.certificateAuthorities, Version: vers, + ctx: ctx, } var rsaAvail, ecAvail bool diff --git a/handshake_client_test.go b/handshake_client_test.go index 693f968..2f30b30 100644 --- a/handshake_client_test.go +++ b/handshake_client_test.go @@ -6,6 +6,7 @@ package tls import ( "bytes" + "context" "crypto/rsa" "crypto/x509" "encoding/base64" @@ -20,6 +21,7 @@ import ( "os/exec" "path/filepath" "reflect" + "runtime" "strconv" "strings" "testing" @@ -2511,3 +2513,37 @@ func testResumptionKeepsOCSPAndSCT(t *testing.T, ver uint16) { serverConfig.Certificates[0].SignedCertificateTimestamps, ccs.SignedCertificateTimestamps) } } + +func TestClientHandshakeContextCancellation(t *testing.T) { + c, s := localPipe(t) + serverConfig := testConfig.Clone() + serverErr := make(chan error, 1) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go func() { + defer close(serverErr) + defer s.Close() + conn := Server(s, serverConfig) + _, err := conn.readClientHello(ctx) + cancel() + serverErr <- err + }() + cli := Client(c, testConfig) + err := cli.HandshakeContext(ctx) + if err == nil { + t.Fatal("Client handshake did not error when the context was canceled") + } + if err != context.Canceled { + t.Errorf("Unexpected client handshake error: %v", err) + } + if err := <-serverErr; err != nil { + t.Errorf("Unexpected server error: %v", err) + } + if runtime.GOARCH == "wasm" { + t.Skip("conn.Close does not error as expected when called multiple times on WASM") + } + err = cli.Close() + if err == nil { + t.Error("Client connection was not closed when the context was canceled") + } +} diff --git a/handshake_client_tls13.go b/handshake_client_tls13.go index daa5d97..be37c68 100644 --- a/handshake_client_tls13.go +++ b/handshake_client_tls13.go @@ -6,6 +6,7 @@ package tls import ( "bytes" + "context" "crypto" "crypto/hmac" "crypto/rsa" @@ -17,6 +18,7 @@ import ( type clientHandshakeStateTLS13 struct { c *Conn + ctx context.Context serverHello *serverHelloMsg hello *clientHelloMsg ecdheParams ecdheParameters @@ -555,6 +557,7 @@ func (hs *clientHandshakeStateTLS13) sendClientCertificate() error { AcceptableCAs: hs.certReq.certificateAuthorities, SignatureSchemes: hs.certReq.supportedSignatureAlgorithms, Version: c.vers, + ctx: hs.ctx, }) if err != nil { return err diff --git a/handshake_server.go b/handshake_server.go index 9c3e0f6..5a572a9 100644 --- a/handshake_server.go +++ b/handshake_server.go @@ -5,6 +5,7 @@ package tls import ( + "context" "crypto" "crypto/ecdsa" "crypto/ed25519" @@ -23,6 +24,7 @@ import ( // It's discarded once the handshake has completed. type serverHandshakeState struct { c *Conn + ctx context.Context clientHello *clientHelloMsg hello *serverHelloMsg suite *cipherSuite @@ -37,8 +39,8 @@ type serverHandshakeState struct { } // serverHandshake performs a TLS handshake as a server. -func (c *Conn) serverHandshake() error { - clientHello, err := c.readClientHello() +func (c *Conn) serverHandshake(ctx context.Context) error { + clientHello, err := c.readClientHello(ctx) if err != nil { return err } @@ -46,6 +48,7 @@ func (c *Conn) serverHandshake() error { if c.vers == VersionTLS13 { hs := serverHandshakeStateTLS13{ c: c, + ctx: ctx, clientHello: clientHello, } return hs.handshake() @@ -53,6 +56,7 @@ func (c *Conn) serverHandshake() error { hs := serverHandshakeState{ c: c, + ctx: ctx, clientHello: clientHello, } return hs.handshake() @@ -124,7 +128,7 @@ func (hs *serverHandshakeState) handshake() error { } // readClientHello reads a ClientHello message and selects the protocol version. -func (c *Conn) readClientHello() (*clientHelloMsg, error) { +func (c *Conn) readClientHello(ctx context.Context) (*clientHelloMsg, error) { msg, err := c.readHandshake() if err != nil { return nil, err @@ -138,7 +142,7 @@ func (c *Conn) readClientHello() (*clientHelloMsg, error) { var configForClient *Config originalConfig := c.config if c.config.GetConfigForClient != nil { - chi := clientHelloInfo(c, clientHello) + chi := clientHelloInfo(ctx, c, clientHello) if configForClient, err = c.config.GetConfigForClient(chi); err != nil { c.sendAlert(alertInternalError) return nil, err @@ -220,7 +224,7 @@ func (hs *serverHandshakeState) processClientHello() error { } } - hs.cert, err = c.config.getCertificate(clientHelloInfo(c, hs.clientHello)) + hs.cert, err = c.config.getCertificate(clientHelloInfo(hs.ctx, c, hs.clientHello)) if err != nil { if err == errNoCertificates { c.sendAlert(alertUnrecognizedName) @@ -828,7 +832,7 @@ func (c *Conn) processCertsFromClient(certificate Certificate) error { return nil } -func clientHelloInfo(c *Conn, clientHello *clientHelloMsg) *ClientHelloInfo { +func clientHelloInfo(ctx context.Context, c *Conn, clientHello *clientHelloMsg) *ClientHelloInfo { supportedVersions := clientHello.supportedVersions if len(clientHello.supportedVersions) == 0 { supportedVersions = supportedVersionsFromMax(clientHello.vers) @@ -844,5 +848,6 @@ func clientHelloInfo(c *Conn, clientHello *clientHelloMsg) *ClientHelloInfo { SupportedVersions: supportedVersions, Conn: c.conn, config: c.config, + ctx: ctx, } } diff --git a/handshake_server_test.go b/handshake_server_test.go index d6bf9e4..432b4cf 100644 --- a/handshake_server_test.go +++ b/handshake_server_test.go @@ -6,6 +6,7 @@ package tls import ( "bytes" + "context" "crypto" "crypto/elliptic" "crypto/x509" @@ -17,6 +18,7 @@ import ( "os" "os/exec" "path/filepath" + "runtime" "strings" "testing" "time" @@ -38,10 +40,12 @@ func testClientHelloFailure(t *testing.T, serverConfig *Config, m handshakeMessa cli.writeRecord(recordTypeHandshake, m.marshal()) c.Close() }() + ctx := context.Background() conn := Server(s, serverConfig) - ch, err := conn.readClientHello() + ch, err := conn.readClientHello(ctx) hs := serverHandshakeState{ c: conn, + ctx: ctx, clientHello: ch, } if err == nil { @@ -1421,9 +1425,11 @@ func TestSNIGivenOnFailure(t *testing.T) { c.Close() }() conn := Server(s, serverConfig) - ch, err := conn.readClientHello() + ctx := context.Background() + ch, err := conn.readClientHello(ctx) hs := serverHandshakeState{ c: conn, + ctx: ctx, clientHello: ch, } if err == nil { @@ -1939,3 +1945,112 @@ func TestAESCipherReordering13(t *testing.T) { }) } } + +func TestServerHandshakeContextCancellation(t *testing.T) { + c, s := localPipe(t) + clientConfig := testConfig.Clone() + clientErr := make(chan error, 1) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go func() { + defer close(clientErr) + defer c.Close() + clientHello := &clientHelloMsg{ + vers: VersionTLS10, + random: make([]byte, 32), + cipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA}, + compressionMethods: []uint8{compressionNone}, + } + cli := Client(c, clientConfig) + _, err := cli.writeRecord(recordTypeHandshake, clientHello.marshal()) + cancel() + clientErr <- err + }() + conn := Server(s, testConfig) + err := conn.HandshakeContext(ctx) + if err == nil { + t.Fatal("Server handshake did not error when the context was canceled") + } + if err != context.Canceled { + t.Errorf("Unexpected server handshake error: %v", err) + } + if err := <-clientErr; err != nil { + t.Errorf("Unexpected client error: %v", err) + } + if runtime.GOARCH == "wasm" { + t.Skip("conn.Close does not error as expected when called multiple times on WASM") + } + err = conn.Close() + if err == nil { + t.Error("Server connection was not closed when the context was canceled") + } +} + +// TestHandshakeContextHierarchy tests whether the contexts +// available to GetClientCertificate and GetCertificate are +// derived from the context provided to HandshakeContext, and +// that those contexts are cancelled after HandshakeContext has +// returned. +func TestHandshakeContextHierarchy(t *testing.T) { + c, s := localPipe(t) + clientErr := make(chan error, 1) + clientConfig := testConfig.Clone() + serverConfig := testConfig.Clone() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + key := struct{}{} + ctx = context.WithValue(ctx, key, true) + go func() { + defer close(clientErr) + defer c.Close() + var innerCtx context.Context + clientConfig.Certificates = nil + clientConfig.GetClientCertificate = func(certificateRequest *CertificateRequestInfo) (*Certificate, error) { + if val, ok := certificateRequest.Context().Value(key).(bool); !ok || !val { + t.Errorf("GetClientCertificate context was not child of HandshakeContext") + } + innerCtx = certificateRequest.Context() + return &Certificate{ + Certificate: [][]byte{testRSACertificate}, + PrivateKey: testRSAPrivateKey, + }, nil + } + cli := Client(c, clientConfig) + err := cli.HandshakeContext(ctx) + if err != nil { + clientErr <- err + return + } + select { + case <-innerCtx.Done(): + default: + t.Errorf("GetClientCertificate context was not cancelled after HandshakeContext returned.") + } + }() + var innerCtx context.Context + serverConfig.Certificates = nil + serverConfig.ClientAuth = RequestClientCert + serverConfig.GetCertificate = func(clientHello *ClientHelloInfo) (*Certificate, error) { + if val, ok := clientHello.Context().Value(key).(bool); !ok || !val { + t.Errorf("GetClientCertificate context was not child of HandshakeContext") + } + innerCtx = clientHello.Context() + return &Certificate{ + Certificate: [][]byte{testRSACertificate}, + PrivateKey: testRSAPrivateKey, + }, nil + } + conn := Server(s, serverConfig) + err := conn.HandshakeContext(ctx) + if err != nil { + t.Errorf("Unexpected server handshake error: %v", err) + } + select { + case <-innerCtx.Done(): + default: + t.Errorf("GetCertificate context was not cancelled after HandshakeContext returned.") + } + if err := <-clientErr; err != nil { + t.Errorf("Unexpected client error: %v", err) + } +} diff --git a/handshake_server_tls13.go b/handshake_server_tls13.go index c2c288a..c7837d2 100644 --- a/handshake_server_tls13.go +++ b/handshake_server_tls13.go @@ -6,6 +6,7 @@ package tls import ( "bytes" + "context" "crypto" "crypto/hmac" "crypto/rsa" @@ -23,6 +24,7 @@ const maxClientPSKIdentities = 5 type serverHandshakeStateTLS13 struct { c *Conn + ctx context.Context clientHello *clientHelloMsg hello *serverHelloMsg sentDummyCCS bool @@ -374,7 +376,7 @@ func (hs *serverHandshakeStateTLS13) pickCertificate() error { return c.sendAlert(alertMissingExtension) } - certificate, err := c.config.getCertificate(clientHelloInfo(c, hs.clientHello)) + certificate, err := c.config.getCertificate(clientHelloInfo(hs.ctx, c, hs.clientHello)) if err != nil { if err == errNoCertificates { c.sendAlert(alertUnrecognizedName) diff --git a/tls.go b/tls.go index 4989364..b529c70 100644 --- a/tls.go +++ b/tls.go @@ -25,7 +25,6 @@ import ( "net" "os" "strings" - "time" ) // Server returns a new TLS server side connection @@ -119,28 +118,16 @@ func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config) (* } func dial(ctx context.Context, netDialer *net.Dialer, network, addr string, config *Config) (*Conn, error) { - // We want the Timeout and Deadline values from dialer to cover the - // whole process: TCP connection and TLS handshake. This means that we - // also need to start our own timers now. - timeout := netDialer.Timeout + if netDialer.Timeout != 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, netDialer.Timeout) + defer cancel() + } if !netDialer.Deadline.IsZero() { - deadlineTimeout := time.Until(netDialer.Deadline) - if timeout == 0 || deadlineTimeout < timeout { - timeout = deadlineTimeout - } - } - - // hsErrCh is non-nil if we might not wait for Handshake to complete. - var hsErrCh chan error - if timeout != 0 || ctx.Done() != nil { - hsErrCh = make(chan error, 2) - } - if timeout != 0 { - timer := time.AfterFunc(timeout, func() { - hsErrCh <- timeoutError{} - }) - defer timer.Stop() + var cancel context.CancelFunc + ctx, cancel = context.WithDeadline(ctx, netDialer.Deadline) + defer cancel() } rawConn, err := netDialer.DialContext(ctx, network, addr) @@ -167,34 +154,10 @@ func dial(ctx context.Context, netDialer *net.Dialer, network, addr string, conf } conn := Client(rawConn, config) - - if hsErrCh == nil { - err = conn.Handshake() - } else { - go func() { - hsErrCh <- conn.Handshake() - }() - - select { - case <-ctx.Done(): - err = ctx.Err() - case err = <-hsErrCh: - if err != nil { - // If the error was due to the context - // closing, prefer the context's error, rather - // than some random network teardown error. - if e := ctx.Err(); e != nil { - err = e - } - } - } - } - - if err != nil { + if err := conn.HandshakeContext(ctx); err != nil { rawConn.Close() return nil, err } - return conn, nil }