diff --git a/cipher_suites.go b/cipher_suites.go index cbe14f8..6596562 100644 --- a/cipher_suites.go +++ b/cipher_suites.go @@ -162,7 +162,7 @@ type cipherSuite struct { // flags is a bitmask of the suite* values, above. flags int cipher func(key, iv []byte, isRead bool) interface{} - mac func(version uint16, macKey []byte) macFunction + mac func(key []byte) hash.Hash aead func(key, fixedNonce []byte) aead } @@ -249,30 +249,21 @@ func cipherAES(key, iv []byte, isRead bool) interface{} { return cipher.NewCBCEncrypter(block, iv) } -// macSHA1 returns a macFunction for the given protocol version. -func macSHA1(version uint16, key []byte) macFunction { +// macSHA1 returns a SHA-1 based constant time MAC. +func macSHA1(key []byte) hash.Hash { h := sha1.New // The BoringCrypto SHA1 does not have a constant-time // checksum function, so don't try to use it. if !boring.Enabled { h = newConstantTimeHash(h) } - return tls10MAC{h: hmac.New(h, key)} + return hmac.New(h, key) } -// macSHA256 returns a SHA-256 based MAC. These are only supported in TLS 1.2 -// so the given version is ignored. -func macSHA256(version uint16, key []byte) macFunction { - return tls10MAC{h: hmac.New(sha256.New, key)} -} - -type macFunction interface { - // Size returns the length of the MAC. - Size() int - // MAC appends the MAC of (seq, header, data) to out. The extra data is fed - // into the MAC after obtaining the result to normalize timing. The result - // is only valid until the next invocation of MAC as the buffer is reused. - MAC(seq, header, data, extra []byte) []byte +// macSHA256 returns a SHA-256 based MAC. This is only supported in TLS 1.2 and +// is currently only used in disabled-by-default cipher suites. +func macSHA256(key []byte) hash.Hash { + return hmac.New(sha256.New, key) } type aead interface { @@ -431,26 +422,14 @@ func newConstantTimeHash(h func() hash.Hash) func() hash.Hash { } // tls10MAC implements the TLS 1.0 MAC function. RFC 2246, Section 6.2.3. -type tls10MAC struct { - h hash.Hash - buf []byte -} - -func (s tls10MAC) Size() int { - return s.h.Size() -} - -// MAC is guaranteed to take constant time, as long as -// len(seq)+len(header)+len(data)+len(extra) is constant. extra is not fed into -// the MAC, but is only provided to make the timing profile constant. -func (s tls10MAC) MAC(seq, header, data, extra []byte) []byte { - s.h.Reset() - s.h.Write(seq) - s.h.Write(header) - s.h.Write(data) - res := s.h.Sum(s.buf[:0]) +func tls10MAC(h hash.Hash, out, seq, header, data, extra []byte) []byte { + h.Reset() + h.Write(seq) + h.Write(header) + h.Write(data) + res := h.Sum(out) if extra != nil { - s.h.Write(extra) + h.Write(extra) } return res } diff --git a/common.go b/common.go index 5dd586a..7a00994 100644 --- a/common.go +++ b/common.go @@ -7,6 +7,7 @@ package tls import ( "bytes" "container/list" + "context" "crypto" "crypto/ecdsa" "crypto/ed25519" @@ -294,10 +295,26 @@ func (cs *ConnectionState) ExportKeyingMaterial(label string, context []byte, le type ClientAuthType int const ( + // NoClientCert indicates that no client certificate should be requested + // during the handshake, and if any certificates are sent they will not + // be verified. NoClientCert ClientAuthType = iota + // RequestClientCert indicates that a client certificate should be requested + // during the handshake, but does not require that the client send any + // certificates. RequestClientCert + // RequireAnyClientCert indicates that a client certificate should be requested + // during the handshake, and that at least one certificate is required to be + // sent by the client, but that certificate is not required to be valid. RequireAnyClientCert + // VerifyClientCertIfGiven indicates that a client certificate should be requested + // during the handshake, but does not require that the client sends a + // certificate. If the client does send a certificate it is required to be + // valid. VerifyClientCertIfGiven + // RequireAndVerifyClientCert indicates that a client certificate should be requested + // during the handshake, and that at least one valid certificate is required + // to be sent by the client. RequireAndVerifyClientCert ) @@ -428,6 +445,16 @@ type ClientHelloInfo struct { // config is embedded by the GetCertificate or GetConfigForClient caller, // for use with SupportsCertificate. config *Config + + // ctx is the context of the handshake that is in progress. + ctx context.Context +} + +// Context returns the context of the handshake that is in progress. +// This context is a child of the context passed to HandshakeContext, +// if any, and is canceled when the handshake concludes. +func (c *ClientHelloInfo) Context() context.Context { + return c.ctx } // CertificateRequestInfo contains information from a server's @@ -446,6 +473,16 @@ type CertificateRequestInfo struct { // Version is the TLS version that was negotiated for this connection. Version uint16 + + // ctx is the context of the handshake that is in progress. + ctx context.Context +} + +// Context returns the context of the handshake that is in progress. +// This context is a child of the context passed to HandshakeContext, +// if any, and is canceled when the handshake concludes. +func (c *CertificateRequestInfo) Context() context.Context { + return c.ctx } // RenegotiationSupport enumerates the different levels of support for TLS @@ -1256,7 +1293,9 @@ func (c *Config) BuildNameToCertificate() { if err != nil { continue } - if len(x509Cert.Subject.CommonName) > 0 { + // If SANs are *not* present, some clients will consider the certificate + // valid for the name in the Common Name. + if x509Cert.Subject.CommonName != "" && len(x509Cert.DNSNames) == 0 { c.NameToCertificate[x509Cert.Subject.CommonName] = cert } for _, san := range x509Cert.DNSNames { diff --git a/conn.go b/conn.go index f1d4cb9..2788c3c 100644 --- a/conn.go +++ b/conn.go @@ -8,11 +8,13 @@ package tls import ( "bytes" + "context" "crypto/cipher" "crypto/subtle" "crypto/x509" "errors" "fmt" + "hash" "io" "net" "sync" @@ -26,7 +28,7 @@ type Conn struct { // constant conn net.Conn isClient bool - handshakeFn func() error // (*Conn).clientHandshake or serverHandshake + handshakeFn func(context.Context) error // (*Conn).clientHandshake or serverHandshake // handshakeStatus is 1 if the connection is currently transferring // application data (i.e. is not currently processing a handshake). @@ -94,7 +96,6 @@ type Conn struct { rawInput bytes.Buffer // raw input, starting with a record header input bytes.Reader // application data waiting to be read, from rawInput.Next hand bytes.Buffer // handshake data waiting to be read - outBuf []byte // scratch buffer used by out.encrypt buffering bool // whether records are buffered in sendBuf sendBuf []byte // a buffer of records waiting to be sent @@ -155,15 +156,16 @@ func (c *Conn) SetWriteDeadline(t time.Time) error { type halfConn struct { sync.Mutex - err error // first permanent error - version uint16 // protocol version - cipher interface{} // cipher algorithm - mac macFunction - seq [8]byte // 64-bit sequence number - additionalData [13]byte // to avoid allocs; interface method args escape + err error // first permanent error + version uint16 // protocol version + cipher interface{} // cipher algorithm + mac hash.Hash + seq [8]byte // 64-bit sequence number + + scratchBuf [13]byte // to avoid allocs; interface method args escape nextCipher interface{} // next encryption state - nextMac macFunction // next MAC algorithm + nextMac hash.Hash // next MAC algorithm trafficSecret []byte // current TLS 1.3 traffic secret } @@ -188,7 +190,7 @@ func (hc *halfConn) setErrorLocked(err error) error { // prepareCipherSpec sets the encryption and MAC states // that a subsequent changeCipherSpec will use. -func (hc *halfConn) prepareCipherSpec(version uint16, cipher interface{}, mac macFunction) { +func (hc *halfConn) prepareCipherSpec(version uint16, cipher interface{}, mac hash.Hash) { hc.version = version hc.nextCipher = cipher hc.nextMac = mac @@ -350,15 +352,14 @@ func (hc *halfConn) decrypt(record []byte) ([]byte, recordType, error) { } payload = payload[explicitNonceLen:] - additionalData := hc.additionalData[:] + var additionalData []byte if hc.version == VersionTLS13 { additionalData = record[:recordHeaderLen] } else { - copy(additionalData, hc.seq[:]) - copy(additionalData[8:], record[:3]) + additionalData = append(hc.scratchBuf[:0], hc.seq[:]...) + additionalData = append(additionalData, record[:3]...) n := len(payload) - c.Overhead() - additionalData[11] = byte(n >> 8) - additionalData[12] = byte(n) + additionalData = append(additionalData, byte(n>>8), byte(n)) } var err error @@ -424,7 +425,7 @@ func (hc *halfConn) decrypt(record []byte) ([]byte, recordType, error) { record[3] = byte(n >> 8) record[4] = byte(n) remoteMAC := payload[n : n+macSize] - localMAC := hc.mac.MAC(hc.seq[0:], record[:recordHeaderLen], payload[:n], payload[n+macSize:]) + localMAC := tls10MAC(hc.mac, hc.scratchBuf[:0], hc.seq[:], record[:recordHeaderLen], payload[:n], payload[n+macSize:]) // This is equivalent to checking the MACs and paddingGood // separately, but in constant-time to prevent distinguishing @@ -460,7 +461,7 @@ func sliceForAppend(in []byte, n int) (head, tail []byte) { } // encrypt encrypts payload, adding the appropriate nonce and/or MAC, and -// appends it to record, which contains the record header. +// appends it to record, which must already contain the record header. func (hc *halfConn) encrypt(record, payload []byte, rand io.Reader) ([]byte, error) { if hc.cipher == nil { return append(record, payload...), nil @@ -477,7 +478,7 @@ func (hc *halfConn) encrypt(record, payload []byte, rand io.Reader) ([]byte, err // an 8 bytes nonce but its nonces must be unpredictable (see RFC // 5246, Appendix F.3), forcing us to use randomness. That's not // 3DES' biggest problem anyway because the birthday bound on block - // collision is reached first due to its simlarly small block size + // collision is reached first due to its similarly small block size // (see the Sweet32 attack). copy(explicitNonce, hc.seq[:]) } else { @@ -487,14 +488,10 @@ func (hc *halfConn) encrypt(record, payload []byte, rand io.Reader) ([]byte, err } } - var mac []byte - if hc.mac != nil { - mac = hc.mac.MAC(hc.seq[:], record[:recordHeaderLen], payload, nil) - } - var dst []byte switch c := hc.cipher.(type) { case cipher.Stream: + mac := tls10MAC(hc.mac, hc.scratchBuf[:0], hc.seq[:], record[:recordHeaderLen], payload, nil) record, dst = sliceForAppend(record, len(payload)+len(mac)) c.XORKeyStream(dst[:len(payload)], payload) c.XORKeyStream(dst[len(payload):], mac) @@ -518,11 +515,12 @@ func (hc *halfConn) encrypt(record, payload []byte, rand io.Reader) ([]byte, err record = c.Seal(record[:recordHeaderLen], nonce, record[recordHeaderLen:], record[:recordHeaderLen]) } else { - copy(hc.additionalData[:], hc.seq[:]) - copy(hc.additionalData[8:], record) - record = c.Seal(record, nonce, payload, hc.additionalData[:]) + additionalData := append(hc.scratchBuf[:0], hc.seq[:]...) + additionalData = append(additionalData, record[:recordHeaderLen]...) + record = c.Seal(record, nonce, payload, additionalData) } case cbcMode: + mac := tls10MAC(hc.mac, hc.scratchBuf[:0], hc.seq[:], record[:recordHeaderLen], payload, nil) blockSize := c.BlockSize() plaintextLen := len(payload) + len(mac) paddingLen := blockSize - plaintextLen%blockSize @@ -928,9 +926,28 @@ func (c *Conn) flush() (int, error) { return n, err } +// outBufPool pools the record-sized scratch buffers used by writeRecordLocked. +var outBufPool = sync.Pool{ + New: func() interface{} { + return new([]byte) + }, +} + // writeRecordLocked writes a TLS record with the given type and payload to the // connection and updates the record layer state. func (c *Conn) writeRecordLocked(typ recordType, data []byte) (int, error) { + outBufPtr := outBufPool.Get().(*[]byte) + outBuf := *outBufPtr + defer func() { + // You might be tempted to simplify this by just passing &outBuf to Put, + // but that would make the local copy of the outBuf slice header escape + // to the heap, causing an allocation. Instead, we keep around the + // pointer to the slice header returned by Get, which is already on the + // heap, and overwrite and return that. + *outBufPtr = outBuf + outBufPool.Put(outBufPtr) + }() + var n int for len(data) > 0 { m := len(data) @@ -938,8 +955,8 @@ func (c *Conn) writeRecordLocked(typ recordType, data []byte) (int, error) { m = maxPayload } - _, c.outBuf = sliceForAppend(c.outBuf[:0], recordHeaderLen) - c.outBuf[0] = byte(typ) + _, outBuf = sliceForAppend(outBuf[:0], recordHeaderLen) + outBuf[0] = byte(typ) vers := c.vers if vers == 0 { // Some TLS servers fail if the record version is @@ -950,17 +967,17 @@ func (c *Conn) writeRecordLocked(typ recordType, data []byte) (int, error) { // See RFC 8446, Section 5.1. vers = VersionTLS12 } - c.outBuf[1] = byte(vers >> 8) - c.outBuf[2] = byte(vers) - c.outBuf[3] = byte(m >> 8) - c.outBuf[4] = byte(m) + outBuf[1] = byte(vers >> 8) + outBuf[2] = byte(vers) + outBuf[3] = byte(m >> 8) + outBuf[4] = byte(m) var err error - c.outBuf, err = c.out.encrypt(c.outBuf, data[:m], c.config.rand()) + outBuf, err = c.out.encrypt(outBuf, data[:m], c.config.rand()) if err != nil { return n, err } - if _, err := c.write(c.outBuf); err != nil { + if _, err := c.write(outBuf); err != nil { return n, err } n += m @@ -1074,6 +1091,11 @@ var ( ) // Write writes data to the connection. +// +// As Write calls Handshake, in order to prevent indefinite blocking a deadline +// must be set for both Read and Write before Write is called when the handshake +// has not yet completed. See SetDeadline, SetReadDeadline, and +// SetWriteDeadline. func (c *Conn) Write(b []byte) (int, error) { // interlock with Close below for { @@ -1169,7 +1191,7 @@ func (c *Conn) handleRenegotiation() error { defer c.handshakeMutex.Unlock() atomic.StoreUint32(&c.handshakeStatus, 0) - if c.handshakeErr = c.clientHandshake(); c.handshakeErr == nil { + if c.handshakeErr = c.clientHandshake(context.Background()); c.handshakeErr == nil { c.handshakes++ } return c.handshakeErr @@ -1232,8 +1254,12 @@ func (c *Conn) handleKeyUpdate(keyUpdate *keyUpdateMsg) error { return nil } -// Read can be made to time out and return a net.Error with Timeout() == true -// after a fixed time limit; see SetDeadline and SetReadDeadline. +// Read reads data from the connection. +// +// As Read calls Handshake, in order to prevent indefinite blocking a deadline +// must be set for both Read and Write before Read is called when the handshake +// has not yet completed. See SetDeadline, SetReadDeadline, and +// SetWriteDeadline. func (c *Conn) Read(b []byte) (int, error) { if err := c.Handshake(); err != nil { return 0, err @@ -1301,9 +1327,10 @@ func (c *Conn) Close() error { } var alertErr error - if c.handshakeComplete() { - alertErr = c.closeNotify() + if err := c.closeNotify(); err != nil { + alertErr = fmt.Errorf("tls: failed to send closeNotify alert (but connection was closed anyway): %w", err) + } } if err := c.conn.Close(); err != nil { @@ -1330,8 +1357,12 @@ func (c *Conn) closeNotify() error { defer c.out.Unlock() if !c.closeNotifySent { + // Set a Write Deadline to prevent possibly blocking forever. + c.SetWriteDeadline(time.Now().Add(time.Second * 5)) c.closeNotifyErr = c.sendAlertLocked(alertCloseNotify) c.closeNotifySent = true + // Any subsequent writes will fail. + c.SetWriteDeadline(time.Now()) } return c.closeNotifyErr } @@ -1343,8 +1374,61 @@ func (c *Conn) closeNotify() error { // first Read or Write will call it automatically. // // For control over canceling or setting a timeout on a handshake, use -// the Dialer's DialContext method. +// HandshakeContext or the Dialer's DialContext method instead. func (c *Conn) Handshake() error { + return c.HandshakeContext(context.Background()) +} + +// HandshakeContext runs the client or server handshake +// protocol if it has not yet been run. +// +// The provided Context must be non-nil. If the context is canceled before +// the handshake is complete, the handshake is interrupted and an error is returned. +// Once the handshake has completed, cancellation of the context will not affect the +// connection. +// +// Most uses of this package need not call HandshakeContext explicitly: the +// first Read or Write will call it automatically. +func (c *Conn) HandshakeContext(ctx context.Context) error { + // Delegate to unexported method for named return + // without confusing documented signature. + return c.handshakeContext(ctx) +} + +func (c *Conn) handshakeContext(ctx context.Context) (ret error) { + handshakeCtx, cancel := context.WithCancel(ctx) + // Note: defer this before starting the "interrupter" goroutine + // so that we can tell the difference between the input being canceled and + // this cancellation. In the former case, we need to close the connection. + defer cancel() + + // Start the "interrupter" goroutine, if this context might be canceled. + // (The background context cannot). + // + // The interrupter goroutine waits for the input context to be done and + // closes the connection if this happens before the function returns. + if ctx.Done() != nil { + done := make(chan struct{}) + interruptRes := make(chan error, 1) + defer func() { + close(done) + if ctxErr := <-interruptRes; ctxErr != nil { + // Return context error to user. + ret = ctxErr + } + }() + go func() { + select { + case <-handshakeCtx.Done(): + // Close the connection, discarding the error + _ = c.conn.Close() + interruptRes <- handshakeCtx.Err() + case <-done: + interruptRes <- nil + } + }() + } + c.handshakeMutex.Lock() defer c.handshakeMutex.Unlock() @@ -1358,7 +1442,7 @@ func (c *Conn) Handshake() error { c.in.Lock() defer c.in.Unlock() - c.handshakeErr = c.handshakeFn() + c.handshakeErr = c.handshakeFn(handshakeCtx) if c.handshakeErr == nil { c.handshakes++ } else { diff --git a/handshake_client.go b/handshake_client.go index e089762..6d32a72 100644 --- a/handshake_client.go +++ b/handshake_client.go @@ -6,6 +6,7 @@ package tls import ( "bytes" + "context" "crypto" "crypto/ecdsa" "crypto/ed25519" @@ -14,6 +15,7 @@ import ( "crypto/x509" "errors" "fmt" + "hash" "io" "net" "strings" @@ -23,6 +25,7 @@ import ( type clientHandshakeState struct { c *Conn + ctx context.Context serverHello *serverHelloMsg hello *clientHelloMsg suite *cipherSuite @@ -136,7 +139,7 @@ func (c *Conn) makeClientHello() (*clientHelloMsg, ecdheParameters, error) { return hello, params, nil } -func (c *Conn) clientHandshake() (err error) { +func (c *Conn) clientHandshake(ctx context.Context) (err error) { if c.config == nil { c.config = defaultConfig() } @@ -200,6 +203,7 @@ func (c *Conn) clientHandshake() (err error) { if c.vers == VersionTLS13 { hs := &clientHandshakeStateTLS13{ c: c, + ctx: ctx, serverHello: serverHello, hello: hello, ecdheParams: ecdheParams, @@ -214,6 +218,7 @@ func (c *Conn) clientHandshake() (err error) { hs := &clientHandshakeState{ c: c, + ctx: ctx, serverHello: serverHello, hello: hello, session: session, @@ -542,7 +547,7 @@ func (hs *clientHandshakeState) doFullHandshake() error { certRequested = true hs.finishedHash.Write(certReq.marshal()) - cri := certificateRequestInfoFromMsg(c.vers, certReq) + cri := certificateRequestInfoFromMsg(hs.ctx, c.vers, certReq) if chainToSend, err = c.getClientCertificate(cri); err != nil { c.sendAlert(alertInternalError) return err @@ -650,12 +655,12 @@ func (hs *clientHandshakeState) establishKeys() error { clientMAC, serverMAC, clientKey, serverKey, clientIV, serverIV := keysFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.hello.random, hs.serverHello.random, hs.suite.macLen, hs.suite.keyLen, hs.suite.ivLen) var clientCipher, serverCipher interface{} - var clientHash, serverHash macFunction + var clientHash, serverHash hash.Hash if hs.suite.cipher != nil { clientCipher = hs.suite.cipher(clientKey, clientIV, false /* not for reading */) - clientHash = hs.suite.mac(c.vers, clientMAC) + clientHash = hs.suite.mac(clientMAC) serverCipher = hs.suite.cipher(serverKey, serverIV, true /* for reading */) - serverHash = hs.suite.mac(c.vers, serverMAC) + serverHash = hs.suite.mac(serverMAC) } else { clientCipher = hs.suite.aead(clientKey, clientIV) serverCipher = hs.suite.aead(serverKey, serverIV) @@ -884,10 +889,11 @@ func (c *Conn) verifyServerCertificate(certificates [][]byte) error { // certificateRequestInfoFromMsg generates a CertificateRequestInfo from a TLS // <= 1.2 CertificateRequest, making an effort to fill in missing information. -func certificateRequestInfoFromMsg(vers uint16, certReq *certificateRequestMsg) *CertificateRequestInfo { +func certificateRequestInfoFromMsg(ctx context.Context, vers uint16, certReq *certificateRequestMsg) *CertificateRequestInfo { cri := &CertificateRequestInfo{ AcceptableCAs: certReq.certificateAuthorities, Version: vers, + ctx: ctx, } var rsaAvail, ecAvail bool diff --git a/handshake_client_test.go b/handshake_client_test.go index 12b0254..8889e2c 100644 --- a/handshake_client_test.go +++ b/handshake_client_test.go @@ -6,6 +6,7 @@ package tls import ( "bytes" + "context" "crypto/rsa" "crypto/x509" "encoding/base64" @@ -20,6 +21,7 @@ import ( "os/exec" "path/filepath" "reflect" + "runtime" "strconv" "strings" "testing" @@ -2511,3 +2513,37 @@ func testResumptionKeepsOCSPAndSCT(t *testing.T, ver uint16) { serverConfig.Certificates[0].SignedCertificateTimestamps, ccs.SignedCertificateTimestamps) } } + +func TestClientHandshakeContextCancellation(t *testing.T) { + c, s := localPipe(t) + serverConfig := testConfig.Clone() + serverErr := make(chan error, 1) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go func() { + defer close(serverErr) + defer s.Close() + conn := Server(s, serverConfig) + _, err := conn.readClientHello(ctx) + cancel() + serverErr <- err + }() + cli := Client(c, testConfig) + err := cli.HandshakeContext(ctx) + if err == nil { + t.Fatal("Client handshake did not error when the context was canceled") + } + if err != context.Canceled { + t.Errorf("Unexpected client handshake error: %v", err) + } + if err := <-serverErr; err != nil { + t.Errorf("Unexpected server error: %v", err) + } + if runtime.GOARCH == "wasm" { + t.Skip("conn.Close does not error as expected when called multiple times on WASM") + } + err = cli.Close() + if err == nil { + t.Error("Client connection was not closed when the context was canceled") + } +} diff --git a/handshake_client_tls13.go b/handshake_client_tls13.go index fab26b2..7a6ef04 100644 --- a/handshake_client_tls13.go +++ b/handshake_client_tls13.go @@ -6,6 +6,7 @@ package tls import ( "bytes" + "context" "crypto" "crypto/hmac" "crypto/rsa" @@ -17,6 +18,7 @@ import ( type clientHandshakeStateTLS13 struct { c *Conn + ctx context.Context serverHello *serverHelloMsg hello *clientHelloMsg ecdheParams ecdheParameters @@ -553,6 +555,7 @@ func (hs *clientHandshakeStateTLS13) sendClientCertificate() error { AcceptableCAs: hs.certReq.certificateAuthorities, SignatureSchemes: hs.certReq.supportedSignatureAlgorithms, Version: c.vers, + ctx: hs.ctx, }) if err != nil { return err diff --git a/handshake_server.go b/handshake_server.go index 99e30b8..17abe0f 100644 --- a/handshake_server.go +++ b/handshake_server.go @@ -5,6 +5,7 @@ package tls import ( + "context" "crypto" "crypto/ecdsa" "crypto/ed25519" @@ -13,6 +14,7 @@ import ( "crypto/x509" "errors" "fmt" + "hash" "io" "sync/atomic" "time" @@ -22,6 +24,7 @@ import ( // It's discarded once the handshake has completed. type serverHandshakeState struct { c *Conn + ctx context.Context clientHello *clientHelloMsg hello *serverHelloMsg suite *cipherSuite @@ -36,8 +39,8 @@ type serverHandshakeState struct { } // serverHandshake performs a TLS handshake as a server. -func (c *Conn) serverHandshake() error { - clientHello, err := c.readClientHello() +func (c *Conn) serverHandshake(ctx context.Context) error { + clientHello, err := c.readClientHello(ctx) if err != nil { return err } @@ -45,6 +48,7 @@ func (c *Conn) serverHandshake() error { if c.vers == VersionTLS13 { hs := serverHandshakeStateTLS13{ c: c, + ctx: ctx, clientHello: clientHello, } return hs.handshake() @@ -52,6 +56,7 @@ func (c *Conn) serverHandshake() error { hs := serverHandshakeState{ c: c, + ctx: ctx, clientHello: clientHello, } return hs.handshake() @@ -123,7 +128,7 @@ func (hs *serverHandshakeState) handshake() error { } // readClientHello reads a ClientHello message and selects the protocol version. -func (c *Conn) readClientHello() (*clientHelloMsg, error) { +func (c *Conn) readClientHello(ctx context.Context) (*clientHelloMsg, error) { msg, err := c.readHandshake() if err != nil { return nil, err @@ -137,7 +142,7 @@ func (c *Conn) readClientHello() (*clientHelloMsg, error) { var configForClient *Config originalConfig := c.config if c.config.GetConfigForClient != nil { - chi := clientHelloInfo(c, clientHello) + chi := clientHelloInfo(ctx, c, clientHello) if configForClient, err = c.config.GetConfigForClient(chi); err != nil { c.sendAlert(alertInternalError) return nil, err @@ -219,7 +224,7 @@ func (hs *serverHandshakeState) processClientHello() error { } } - hs.cert, err = c.config.getCertificate(clientHelloInfo(c, hs.clientHello)) + hs.cert, err = c.config.getCertificate(clientHelloInfo(hs.ctx, c, hs.clientHello)) if err != nil { if err == errNoCertificates { c.sendAlert(alertUnrecognizedName) @@ -641,13 +646,13 @@ func (hs *serverHandshakeState) establishKeys() error { keysFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.clientHello.random, hs.hello.random, hs.suite.macLen, hs.suite.keyLen, hs.suite.ivLen) var clientCipher, serverCipher interface{} - var clientHash, serverHash macFunction + var clientHash, serverHash hash.Hash if hs.suite.aead == nil { clientCipher = hs.suite.cipher(clientKey, clientIV, true /* for reading */) - clientHash = hs.suite.mac(c.vers, clientMAC) + clientHash = hs.suite.mac(clientMAC) serverCipher = hs.suite.cipher(serverKey, serverIV, false /* not for reading */) - serverHash = hs.suite.mac(c.vers, serverMAC) + serverHash = hs.suite.mac(serverMAC) } else { clientCipher = hs.suite.aead(clientKey, clientIV) serverCipher = hs.suite.aead(serverKey, serverIV) @@ -815,7 +820,7 @@ func (c *Conn) processCertsFromClient(certificate Certificate) error { return nil } -func clientHelloInfo(c *Conn, clientHello *clientHelloMsg) *ClientHelloInfo { +func clientHelloInfo(ctx context.Context, c *Conn, clientHello *clientHelloMsg) *ClientHelloInfo { supportedVersions := clientHello.supportedVersions if len(clientHello.supportedVersions) == 0 { supportedVersions = supportedVersionsFromMax(clientHello.vers) @@ -831,5 +836,6 @@ func clientHelloInfo(c *Conn, clientHello *clientHelloMsg) *ClientHelloInfo { SupportedVersions: supportedVersions, Conn: c.conn, config: c.config, + ctx: ctx, } } diff --git a/handshake_server_test.go b/handshake_server_test.go index a7a5324..c4416c3 100644 --- a/handshake_server_test.go +++ b/handshake_server_test.go @@ -6,6 +6,7 @@ package tls import ( "bytes" + "context" "crypto" "crypto/elliptic" "crypto/x509" @@ -17,6 +18,7 @@ import ( "os" "os/exec" "path/filepath" + "runtime" "strings" "testing" "time" @@ -36,10 +38,12 @@ func testClientHelloFailure(t *testing.T, serverConfig *Config, m handshakeMessa cli.writeRecord(recordTypeHandshake, m.marshal()) c.Close() }() + ctx := context.Background() conn := Server(s, serverConfig) - ch, err := conn.readClientHello() + ch, err := conn.readClientHello(ctx) hs := serverHandshakeState{ c: conn, + ctx: ctx, clientHello: ch, } if err == nil { @@ -1418,9 +1422,11 @@ func TestSNIGivenOnFailure(t *testing.T) { c.Close() }() conn := Server(s, serverConfig) - ch, err := conn.readClientHello() + ctx := context.Background() + ch, err := conn.readClientHello(ctx) hs := serverHandshakeState{ c: conn, + ctx: ctx, clientHello: ch, } if err == nil { @@ -1673,3 +1679,43 @@ func TestMultipleCertificates(t *testing.T) { t.Errorf("expected RSA certificate, got %v", got) } } + +func TestServerHandshakeContextCancellation(t *testing.T) { + c, s := localPipe(t) + clientConfig := testConfig.Clone() + clientErr := make(chan error, 1) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go func() { + defer close(clientErr) + defer c.Close() + clientHello := &clientHelloMsg{ + vers: VersionTLS10, + random: make([]byte, 32), + cipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA}, + compressionMethods: []uint8{compressionNone}, + } + cli := Client(c, clientConfig) + _, err := cli.writeRecord(recordTypeHandshake, clientHello.marshal()) + cancel() + clientErr <- err + }() + conn := Server(s, testConfig) + err := conn.HandshakeContext(ctx) + if err == nil { + t.Fatal("Server handshake did not error when the context was canceled") + } + if err != context.Canceled { + t.Errorf("Unexpected server handshake error: %v", err) + } + if err := <-clientErr; err != nil { + t.Errorf("Unexpected client error: %v", err) + } + if runtime.GOARCH == "wasm" { + t.Skip("conn.Close does not error as expected when called multiple times on WASM") + } + err = conn.Close() + if err == nil { + t.Error("Server connection was not closed when the context was canceled") + } +} diff --git a/handshake_server_tls13.go b/handshake_server_tls13.go index 7a84efd..222d634 100644 --- a/handshake_server_tls13.go +++ b/handshake_server_tls13.go @@ -6,6 +6,7 @@ package tls import ( "bytes" + "context" "crypto" "crypto/hmac" "crypto/rsa" @@ -23,6 +24,7 @@ const maxClientPSKIdentities = 5 type serverHandshakeStateTLS13 struct { c *Conn + ctx context.Context clientHello *clientHelloMsg hello *serverHelloMsg sentDummyCCS bool @@ -365,7 +367,7 @@ func (hs *serverHandshakeStateTLS13) pickCertificate() error { return c.sendAlert(alertMissingExtension) } - certificate, err := c.config.getCertificate(clientHelloInfo(c, hs.clientHello)) + certificate, err := c.config.getCertificate(clientHelloInfo(hs.ctx, c, hs.clientHello)) if err != nil { if err == errNoCertificates { c.sendAlert(alertUnrecognizedName) diff --git a/handshake_test.go b/handshake_test.go index f55cd16..605be58 100644 --- a/handshake_test.go +++ b/handshake_test.go @@ -86,7 +86,7 @@ func checkOpenSSLVersion() error { println("to update the test data.") println("") println("Configure it with:") - println("./Configure enable-weak-ssl-ciphers") + println("./Configure enable-weak-ssl-ciphers no-shared") println("and then add the apps/ directory at the front of your PATH.") println("***********************************************") @@ -403,7 +403,7 @@ func testHandshake(t *testing.T, clientConfig, serverConfig *Config) (serverStat } defer cli.Close() clientState = cli.ConnectionState() - buf, err := ioutil.ReadAll(cli) + buf, err := io.ReadAll(cli) if err != nil { t.Errorf("failed to call cli.Read: %v", err) } diff --git a/tls.go b/tls.go index 454aa0b..bf577ca 100644 --- a/tls.go +++ b/tls.go @@ -25,7 +25,6 @@ import ( "io/ioutil" "net" "strings" - "time" ) // Server returns a new TLS server side connection @@ -116,28 +115,16 @@ func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config) (* } func dial(ctx context.Context, netDialer *net.Dialer, network, addr string, config *Config) (*Conn, error) { - // We want the Timeout and Deadline values from dialer to cover the - // whole process: TCP connection and TLS handshake. This means that we - // also need to start our own timers now. - timeout := netDialer.Timeout + if netDialer.Timeout != 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, netDialer.Timeout) + defer cancel() + } if !netDialer.Deadline.IsZero() { - deadlineTimeout := time.Until(netDialer.Deadline) - if timeout == 0 || deadlineTimeout < timeout { - timeout = deadlineTimeout - } - } - - // hsErrCh is non-nil if we might not wait for Handshake to complete. - var hsErrCh chan error - if timeout != 0 || ctx.Done() != nil { - hsErrCh = make(chan error, 2) - } - if timeout != 0 { - timer := time.AfterFunc(timeout, func() { - hsErrCh <- timeoutError{} - }) - defer timer.Stop() + var cancel context.CancelFunc + ctx, cancel = context.WithDeadline(ctx, netDialer.Deadline) + defer cancel() } rawConn, err := netDialer.DialContext(ctx, network, addr) @@ -164,34 +151,10 @@ func dial(ctx context.Context, netDialer *net.Dialer, network, addr string, conf } conn := Client(rawConn, config) - - if hsErrCh == nil { - err = conn.Handshake() - } else { - go func() { - hsErrCh <- conn.Handshake() - }() - - select { - case <-ctx.Done(): - err = ctx.Err() - case err = <-hsErrCh: - if err != nil { - // If the error was due to the context - // closing, prefer the context's error, rather - // than some random network teardown error. - if e := ctx.Err(); e != nil { - err = e - } - } - } - } - - if err != nil { + if err := conn.HandshakeContext(ctx); err != nil { rawConn.Close() return nil, err } - return conn, nil } diff --git a/tls_test.go b/tls_test.go index 4ab8a43..9995538 100644 --- a/tls_test.go +++ b/tls_test.go @@ -14,7 +14,6 @@ import ( "fmt" "internal/testenv" "io" - "io/ioutil" "math" "net" "os" @@ -594,7 +593,7 @@ func TestConnCloseWrite(t *testing.T) { } defer srv.Close() - data, err := ioutil.ReadAll(srv) + data, err := io.ReadAll(srv) if err != nil { return err } @@ -635,7 +634,7 @@ func TestConnCloseWrite(t *testing.T) { return fmt.Errorf("CloseWrite error = %v; want errShutdown", err) } - data, err := ioutil.ReadAll(conn) + data, err := io.ReadAll(conn) if err != nil { return err } @@ -698,7 +697,7 @@ func TestWarningAlertFlood(t *testing.T) { } defer srv.Close() - _, err = ioutil.ReadAll(srv) + _, err = io.ReadAll(srv) if err == nil { return errors.New("unexpected lack of error from server") }