crypto/tls: add HandshakeContext method to Conn

Adds the (*tls.Conn).HandshakeContext method. This allows
us to pass the context provided down the call stack to
eventually reach the tls.ClientHelloInfo and
tls.CertificateRequestInfo structs.
These contexts are exposed to the user as read-only via Context()
methods.

This allows users of (*tls.Config).GetCertificate and
(*tls.Config).GetClientCertificate to use the context for
request scoped parameters and cancellation.

Replace uses of (*tls.Conn).Handshake with (*tls.Conn).HandshakeContext
where appropriate, to propagate existing contexts.

Fixes #32406

Change-Id: I259939c744bdc9b805bf51a845a8bc462c042483
Reviewed-on: https://go-review.googlesource.com/c/go/+/295370
Run-TryBot: Johan Brandhorst-Satzkorn <johan.brandhorst@gmail.com>
TryBot-Result: Go Bot <gobot@golang.org>
Trust: Katie Hockman <katie@golang.org>
Reviewed-by: Filippo Valsorda <filippo@golang.org>
This commit is contained in:
Johan Brandhorst 2020-08-01 12:18:31 +01:00 committed by Johan Brandhorst-Satzkorn
parent 2708f2d5a3
commit 93cad92f83
9 changed files with 266 additions and 62 deletions

View file

@ -6,6 +6,7 @@ package tls
import (
"bytes"
"context"
"crypto"
"crypto/elliptic"
"crypto/x509"
@ -17,6 +18,7 @@ import (
"os"
"os/exec"
"path/filepath"
"runtime"
"strings"
"testing"
"time"
@ -38,10 +40,12 @@ func testClientHelloFailure(t *testing.T, serverConfig *Config, m handshakeMessa
cli.writeRecord(recordTypeHandshake, m.marshal())
c.Close()
}()
ctx := context.Background()
conn := Server(s, serverConfig)
ch, err := conn.readClientHello()
ch, err := conn.readClientHello(ctx)
hs := serverHandshakeState{
c: conn,
ctx: ctx,
clientHello: ch,
}
if err == nil {
@ -1421,9 +1425,11 @@ func TestSNIGivenOnFailure(t *testing.T) {
c.Close()
}()
conn := Server(s, serverConfig)
ch, err := conn.readClientHello()
ctx := context.Background()
ch, err := conn.readClientHello(ctx)
hs := serverHandshakeState{
c: conn,
ctx: ctx,
clientHello: ch,
}
if err == nil {
@ -1939,3 +1945,112 @@ func TestAESCipherReordering13(t *testing.T) {
})
}
}
func TestServerHandshakeContextCancellation(t *testing.T) {
c, s := localPipe(t)
clientConfig := testConfig.Clone()
clientErr := make(chan error, 1)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go func() {
defer close(clientErr)
defer c.Close()
clientHello := &clientHelloMsg{
vers: VersionTLS10,
random: make([]byte, 32),
cipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA},
compressionMethods: []uint8{compressionNone},
}
cli := Client(c, clientConfig)
_, err := cli.writeRecord(recordTypeHandshake, clientHello.marshal())
cancel()
clientErr <- err
}()
conn := Server(s, testConfig)
err := conn.HandshakeContext(ctx)
if err == nil {
t.Fatal("Server handshake did not error when the context was canceled")
}
if err != context.Canceled {
t.Errorf("Unexpected server handshake error: %v", err)
}
if err := <-clientErr; err != nil {
t.Errorf("Unexpected client error: %v", err)
}
if runtime.GOARCH == "wasm" {
t.Skip("conn.Close does not error as expected when called multiple times on WASM")
}
err = conn.Close()
if err == nil {
t.Error("Server connection was not closed when the context was canceled")
}
}
// TestHandshakeContextHierarchy tests whether the contexts
// available to GetClientCertificate and GetCertificate are
// derived from the context provided to HandshakeContext, and
// that those contexts are cancelled after HandshakeContext has
// returned.
func TestHandshakeContextHierarchy(t *testing.T) {
c, s := localPipe(t)
clientErr := make(chan error, 1)
clientConfig := testConfig.Clone()
serverConfig := testConfig.Clone()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
key := struct{}{}
ctx = context.WithValue(ctx, key, true)
go func() {
defer close(clientErr)
defer c.Close()
var innerCtx context.Context
clientConfig.Certificates = nil
clientConfig.GetClientCertificate = func(certificateRequest *CertificateRequestInfo) (*Certificate, error) {
if val, ok := certificateRequest.Context().Value(key).(bool); !ok || !val {
t.Errorf("GetClientCertificate context was not child of HandshakeContext")
}
innerCtx = certificateRequest.Context()
return &Certificate{
Certificate: [][]byte{testRSACertificate},
PrivateKey: testRSAPrivateKey,
}, nil
}
cli := Client(c, clientConfig)
err := cli.HandshakeContext(ctx)
if err != nil {
clientErr <- err
return
}
select {
case <-innerCtx.Done():
default:
t.Errorf("GetClientCertificate context was not cancelled after HandshakeContext returned.")
}
}()
var innerCtx context.Context
serverConfig.Certificates = nil
serverConfig.ClientAuth = RequestClientCert
serverConfig.GetCertificate = func(clientHello *ClientHelloInfo) (*Certificate, error) {
if val, ok := clientHello.Context().Value(key).(bool); !ok || !val {
t.Errorf("GetClientCertificate context was not child of HandshakeContext")
}
innerCtx = clientHello.Context()
return &Certificate{
Certificate: [][]byte{testRSACertificate},
PrivateKey: testRSAPrivateKey,
}, nil
}
conn := Server(s, serverConfig)
err := conn.HandshakeContext(ctx)
if err != nil {
t.Errorf("Unexpected server handshake error: %v", err)
}
select {
case <-innerCtx.Done():
default:
t.Errorf("GetCertificate context was not cancelled after HandshakeContext returned.")
}
if err := <-clientErr; err != nil {
t.Errorf("Unexpected client error: %v", err)
}
}