diff --git a/common.go b/common.go index 5fd92d3..dba9650 100644 --- a/common.go +++ b/common.go @@ -447,6 +447,10 @@ type ClientHelloInfo struct { // might be rejected if used. SupportedVersions []uint16 + // Extensions lists the IDs of the extensions presented by the client + // in the client hello. + Extensions []uint16 + // Conn is the underlying net.Conn for the connection. Do not read // from, or write to, this connection; that will cause the TLS // connection to fail. diff --git a/handshake_messages.go b/handshake_messages.go index 8620b66..823caff 100644 --- a/handshake_messages.go +++ b/handshake_messages.go @@ -97,6 +97,8 @@ type clientHelloMsg struct { pskBinders [][]byte quicTransportParameters []byte encryptedClientHello []byte + // extensions are only populated on the server-side of a handshake + extensions []uint16 } func (m *clientHelloMsg) marshalMsg(echInner bool) ([]byte, error) { @@ -467,6 +469,7 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool { return false } seenExts[extension] = true + m.extensions = append(m.extensions, extension) switch extension { case extensionServerName: diff --git a/handshake_messages_test.go b/handshake_messages_test.go index 197a1c5..2c360e6 100644 --- a/handshake_messages_test.go +++ b/handshake_messages_test.go @@ -76,6 +76,18 @@ func TestMarshalUnmarshal(t *testing.T) { m.activeCertHandles = nil } + if ch, ok := m.(*clientHelloMsg); ok { + // extensions is special cased, as it is only populated by the + // server-side of a handshake and is not expected to roundtrip + // through marshal + unmarshal. m ends up with the list of + // extensions necessary to serialize the other fields of + // clientHelloMsg, so check that it is non-empty, then clear it. + if len(ch.extensions) == 0 { + t.Errorf("expected ch.extensions to be populated on unmarshal") + } + ch.extensions = nil + } + // clientHelloMsg and serverHelloMsg, when unmarshalled, store // their original representation, for later use in the handshake // transcript. In order to prevent DeepEqual from failing since diff --git a/handshake_server.go b/handshake_server.go index ac3d915..bc4e51b 100644 --- a/handshake_server.go +++ b/handshake_server.go @@ -963,6 +963,7 @@ func clientHelloInfo(ctx context.Context, c *Conn, clientHello *clientHelloMsg) SignatureSchemes: clientHello.supportedSignatureAlgorithms, SupportedProtos: clientHello.alpnProtocols, SupportedVersions: supportedVersions, + Extensions: clientHello.extensions, Conn: c.conn, config: c.config, ctx: ctx, diff --git a/handshake_server_test.go b/handshake_server_test.go index 94d3d0f..788a26a 100644 --- a/handshake_server_test.go +++ b/handshake_server_test.go @@ -23,6 +23,7 @@ import ( "runtime" "slices" "strings" + "sync/atomic" "testing" "time" ) @@ -1066,6 +1067,65 @@ func TestHandshakeServerSNIGetCertificateNotFound(t *testing.T) { runServerTestTLS12(t, test) } +// TestHandshakeServerGetCertificateExtensions tests to make sure that the +// Extensions passed to GetCertificate match what we expect based on the +// clientHelloMsg +func TestHandshakeServerGetCertificateExtensions(t *testing.T) { + const errMsg = "TestHandshakeServerGetCertificateExtensions error" + // ensure the test condition inside our GetCertificate callback + // is actually invoked + var called atomic.Int32 + + testVersions := []uint16{VersionTLS12, VersionTLS13} + for _, vers := range testVersions { + t.Run(fmt.Sprintf("TLS version %04x", vers), func(t *testing.T) { + pk, _ := ecdh.X25519().GenerateKey(rand.Reader) + clientHello := &clientHelloMsg{ + vers: vers, + random: make([]byte, 32), + cipherSuites: []uint16{TLS_CHACHA20_POLY1305_SHA256}, + compressionMethods: []uint8{compressionNone}, + serverName: "test", + keyShares: []keyShare{{group: X25519, data: pk.PublicKey().Bytes()}}, + supportedCurves: []CurveID{X25519}, + supportedSignatureAlgorithms: []SignatureScheme{Ed25519}, + } + + // the clientHelloMsg initialized just above is serialized with + // two extensions: server_name(0) and application_layer_protocol_negotiation(16) + expectedExtensions := []uint16{ + extensionServerName, + extensionSupportedCurves, + extensionSignatureAlgorithms, + extensionKeyShare, + } + + if vers == VersionTLS13 { + clientHello.supportedVersions = []uint16{VersionTLS13} + expectedExtensions = append(expectedExtensions, extensionSupportedVersions) + } + + // Go's TLS client presents extensions in the ClientHello sorted by extension ID + slices.Sort(expectedExtensions) + + serverConfig := testConfig.Clone() + serverConfig.GetCertificate = func(clientHello *ClientHelloInfo) (*Certificate, error) { + if !slices.Equal(expectedExtensions, clientHello.Extensions) { + t.Errorf("expected extensions on ClientHelloInfo (%v) to match clientHelloMsg (%v)", expectedExtensions, clientHello.Extensions) + } + called.Add(1) + + return nil, errors.New(errMsg) + } + testClientHelloFailure(t, serverConfig, clientHello, errMsg) + }) + } + + if int(called.Load()) != len(testVersions) { + t.Error("expected our GetCertificate test to be called twice") + } +} + // TestHandshakeServerSNIGetCertificateError tests to make sure that errors in // GetCertificate result in a tls alert. func TestHandshakeServerSNIGetCertificateError(t *testing.T) {