crypto/tls: add GetClientCertificate callback

Currently, the selection of a client certificate done internally based
on the limitations given by the server's request and the certifcates in
the Config. This means that it's not possible for an application to
control that selection based on details of the request.

This change adds a callback, GetClientCertificate, that is called by a
Client during the handshake and which allows applications to select the
best certificate at that time.

(Based on https://golang.org/cl/25570/ by Bernd Fix.)

Fixes #16626.

Change-Id: Ia4cea03235d2aa3c9fd49c99c227593c8e86ddd9
Reviewed-on: https://go-review.googlesource.com/32115
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
This commit is contained in:
Adam Langley 2016-10-26 10:05:03 -07:00
parent 97a4987572
commit 0b375e2be0
4 changed files with 287 additions and 73 deletions

View file

@ -1408,3 +1408,137 @@ func TestHandshakeRace(t *testing.T) {
<-readDone
}
}
func TestTLS11SignatureSchemes(t *testing.T) {
expected := tls11SignatureSchemesNumECDSA + tls11SignatureSchemesNumRSA
if expected != len(tls11SignatureSchemes) {
t.Errorf("expected to find %d TLS 1.1 signature schemes, but found %d", expected, len(tls11SignatureSchemes))
}
}
var getClientCertificateTests = []struct {
setup func(*Config)
expectedClientError string
verify func(*testing.T, int, *ConnectionState)
}{
{
func(clientConfig *Config) {
// Returning a Certificate with no certificate data
// should result in an empty message being sent to the
// server.
clientConfig.GetClientCertificate = func(cri *CertificateRequestInfo) (*Certificate, error) {
if len(cri.SignatureSchemes) == 0 {
panic("empty SignatureSchemes")
}
return new(Certificate), nil
}
},
"",
func(t *testing.T, testNum int, cs *ConnectionState) {
if l := len(cs.PeerCertificates); l != 0 {
t.Errorf("#%d: expected no certificates but got %d", testNum, l)
}
},
},
{
func(clientConfig *Config) {
// With TLS 1.1, the SignatureSchemes should be
// synthesised from the supported certificate types.
clientConfig.MaxVersion = VersionTLS11
clientConfig.GetClientCertificate = func(cri *CertificateRequestInfo) (*Certificate, error) {
if len(cri.SignatureSchemes) == 0 {
panic("empty SignatureSchemes")
}
return new(Certificate), nil
}
},
"",
func(t *testing.T, testNum int, cs *ConnectionState) {
if l := len(cs.PeerCertificates); l != 0 {
t.Errorf("#%d: expected no certificates but got %d", testNum, l)
}
},
},
{
func(clientConfig *Config) {
// Returning an error should abort the handshake with
// that error.
clientConfig.GetClientCertificate = func(cri *CertificateRequestInfo) (*Certificate, error) {
return nil, errors.New("GetClientCertificate")
}
},
"GetClientCertificate",
func(t *testing.T, testNum int, cs *ConnectionState) {
},
},
{
func(clientConfig *Config) {
clientConfig.GetClientCertificate = func(cri *CertificateRequestInfo) (*Certificate, error) {
return &testConfig.Certificates[0], nil
}
},
"",
func(t *testing.T, testNum int, cs *ConnectionState) {
if l := len(cs.VerifiedChains); l != 0 {
t.Errorf("#%d: expected some verified chains, but found none", testNum)
}
},
},
}
func TestGetClientCertificate(t *testing.T) {
issuer, err := x509.ParseCertificate(testRSACertificateIssuer)
if err != nil {
panic(err)
}
for i, test := range getClientCertificateTests {
serverConfig := testConfig.Clone()
serverConfig.ClientAuth = RequestClientCert
serverConfig.RootCAs = x509.NewCertPool()
serverConfig.RootCAs.AddCert(issuer)
clientConfig := testConfig.Clone()
test.setup(clientConfig)
type serverResult struct {
cs ConnectionState
err error
}
c, s := net.Pipe()
done := make(chan serverResult)
go func() {
defer s.Close()
server := Server(s, serverConfig)
err := server.Handshake()
var cs ConnectionState
if err == nil {
cs = server.ConnectionState()
}
done <- serverResult{cs, err}
}()
clientErr := Client(c, clientConfig).Handshake()
c.Close()
result := <-done
if clientErr != nil {
if len(test.expectedClientError) == 0 {
t.Errorf("#%d: client error: %v", i, clientErr)
} else if got := clientErr.Error(); got != test.expectedClientError {
t.Errorf("#%d: expected client error %q, but got %q", i, test.expectedClientError, got)
}
} else if len(test.expectedClientError) > 0 {
t.Errorf("#%d: expected client error %q, but got no error", i, test.expectedClientError)
} else if err := result.err; err != nil {
t.Errorf("#%d: server error: %v", i, err)
} else {
test.verify(t, i, &result.cs)
}
}
}