diff --git a/common.go b/common.go index b3d7f89..064b18a 100644 --- a/common.go +++ b/common.go @@ -89,6 +89,7 @@ const ( extensionSupportedPoints uint16 = 11 extensionSignatureAlgorithms uint16 = 13 extensionALPN uint16 = 16 + extensionStatusRequestV2 uint16 = 17 extensionSCT uint16 = 18 extensionDelegatedCredentials uint16 = 34 extensionSessionTicket uint16 = 35 @@ -100,7 +101,7 @@ const ( extensionCertificateAuthorities uint16 = 47 extensionSignatureAlgorithmsCert uint16 = 50 extensionKeyShare uint16 = 51 - extensionNextProtoNeg uint16 = 13172 // not IANA assigned // Pending discussion on whether or not remove this. crypto/tls removed it on Nov 21, 2019. + extensionNextProtoNeg uint16 = 13172 // not IANA assigned // Pending discussion on whether or not remove this. crypto/tls removed it on Nov 21, 2019. extensionRenegotiationInfo uint16 = 0xff01 ) diff --git a/conn.go b/conn.go index 805354b..954d623 100644 --- a/conn.go +++ b/conn.go @@ -1081,11 +1081,14 @@ func (c *Conn) readHandshake() (any, error) { m = new(endOfEarlyDataMsg) case typeKeyUpdate: m = new(keyUpdateMsg) - // [UTLS SECTION BEGINS] - case typeCompressedCertificate: - m = new(compressedCertificateMsg) - // [UTLS SECTION ENDS] default: + // [UTLS SECTION BEGINS] + var err error + m, err = c.utlsHandshakeMessageType(data[0]) // see u_conn.go + if err == nil { + break + } + // [UTLS SECTION ENDS] return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) } diff --git a/handshake_client_tls13.go b/handshake_client_tls13.go index 47aa202..04e8218 100644 --- a/handshake_client_tls13.go +++ b/handshake_client_tls13.go @@ -6,7 +6,6 @@ package tls import ( "bytes" - "compress/zlib" "context" "crypto" "crypto/hmac" @@ -14,12 +13,8 @@ import ( "errors" "fmt" "hash" - "io" "sync/atomic" "time" - - "github.com/andybalholm/brotli" - "github.com/klauspost/compress/zstd" ) type clientHandshakeStateTLS13 struct { @@ -516,19 +511,15 @@ func (hs *clientHandshakeStateTLS13) readServerCertificate() error { } // [UTLS SECTION BEGINS] - receivedCompressedCert := false - // Check to see if we advertised any compression algorithms - if hs.uconn != nil && len(hs.uconn.certCompressionAlgs) > 0 { - // Check to see if the message is a compressed certificate message, otherwise move on. - compressedCertMsg, ok := msg.(*compressedCertificateMsg) - if ok { - receivedCompressedCert = true - hs.transcript.Write(compressedCertMsg.marshal()) - - msg, err = hs.decompressCert(*compressedCertMsg) - if err != nil { - return fmt.Errorf("tls: failed to decompress certificate message: %w", err) - } + var skipWritingCertToTranscript bool = false + if hs.uconn != nil { + processedMsg, err := hs.utlsReadServerCertificate(msg) + if err != nil { + return err + } + if processedMsg != nil { + skipWritingCertToTranscript = true + msg = processedMsg // msg is now a processed-by-extension certificateMsg } } // [UTLS SECTION ENDS] @@ -544,7 +535,7 @@ func (hs *clientHandshakeStateTLS13) readServerCertificate() error { } // [UTLS SECTION BEGINS] // Previously, this was simply 'hs.transcript.Write(certMsg.marshal())' (without the if). - if !receivedCompressedCert { + if !skipWritingCertToTranscript { hs.transcript.Write(certMsg.marshal()) } // [UTLS SECTION ENDS] @@ -729,80 +720,6 @@ func (hs *clientHandshakeStateTLS13) sendClientFinished() error { return nil } -// [UTLS SECTION BEGINS] -func (hs *clientHandshakeStateTLS13) decompressCert(m compressedCertificateMsg) (*certificateMsgTLS13, error) { - var ( - decompressed io.Reader - compressed = bytes.NewReader(m.compressedCertificateMessage) - c = hs.c - ) - - // Check to see if the peer responded with an algorithm we advertised. - supportedAlg := false - for _, alg := range hs.uconn.certCompressionAlgs { - if m.algorithm == uint16(alg) { - supportedAlg = true - } - } - if !supportedAlg { - c.sendAlert(alertBadCertificate) - return nil, fmt.Errorf("unadvertised algorithm (%d)", m.algorithm) - } - - switch CertCompressionAlgo(m.algorithm) { - case CertCompressionBrotli: - decompressed = brotli.NewReader(compressed) - - case CertCompressionZlib: - rc, err := zlib.NewReader(compressed) - if err != nil { - c.sendAlert(alertBadCertificate) - return nil, fmt.Errorf("failed to open zlib reader: %w", err) - } - defer rc.Close() - decompressed = rc - - case CertCompressionZstd: - rc, err := zstd.NewReader(compressed) - if err != nil { - c.sendAlert(alertBadCertificate) - return nil, fmt.Errorf("failed to open zstd reader: %w", err) - } - defer rc.Close() - decompressed = rc - - default: - c.sendAlert(alertBadCertificate) - return nil, fmt.Errorf("unsupported algorithm (%d)", m.algorithm) - } - - rawMsg := make([]byte, m.uncompressedLength+4) // +4 for message type and uint24 length field - rawMsg[0] = typeCertificate - rawMsg[1] = uint8(m.uncompressedLength >> 16) - rawMsg[2] = uint8(m.uncompressedLength >> 8) - rawMsg[3] = uint8(m.uncompressedLength) - - n, err := decompressed.Read(rawMsg[4:]) - if err != nil { - c.sendAlert(alertBadCertificate) - return nil, err - } - if n < len(rawMsg)-4 { - // If, after decompression, the specified length does not match the actual length, the party - // receiving the invalid message MUST abort the connection with the "bad_certificate" alert. - // https://datatracker.ietf.org/doc/html/rfc8879#section-4 - c.sendAlert(alertBadCertificate) - return nil, fmt.Errorf("decompressed len (%d) does not match specified len (%d)", n, m.uncompressedLength) - } - certMsg := new(certificateMsgTLS13) - if !certMsg.unmarshal(rawMsg) { - return nil, c.sendAlert(alertUnexpectedMessage) - } - return certMsg, nil -} - -// [UTLS SECTION ENDS] - func (c *Conn) handleNewSessionTicket(msg *newSessionTicketMsgTLS13) error { if !c.isClient { c.sendAlert(alertUnexpectedMessage) diff --git a/u_common.go b/u_common.go index a80bce3..61bd144 100644 --- a/u_common.go +++ b/u_common.go @@ -16,6 +16,14 @@ import ( // Things, supported by utls, but not crypto/tls' are prefixed with "utls" // Supported things, that have changed their ID are prefixed with "Old" // Supported but disabled things are prefixed with "Disabled". We will _enable_ them. + +// TLS handshake message types. +const ( + // https://datatracker.ietf.org/doc/html/rfc8879#section-7.2 + typeCompressedCertificate uint8 = 25 +) + +// TLS const ( utlsExtensionPadding uint16 = 21 utlsExtensionExtendedMasterSecret uint16 = 23 // https://tools.ietf.org/html/rfc7627 @@ -25,15 +33,10 @@ const ( // extensions with 'fake' prefix break connection, if server echoes them back fakeExtensionTokenBinding uint16 = 24 - fakeExtensionChannelIDOld uint16 = 30031 // not IANA assigned + fakeOldExtensionChannelID uint16 = 30031 // not IANA assigned fakeExtensionChannelID uint16 = 30032 // not IANA assigned fakeExtensionALPS uint16 = 17513 // not IANA assigned fakeExtensionDelegatedCredentials uint16 = 34 - - fakeRecordSizeLimit uint16 = 0x001c - - // https://datatracker.ietf.org/doc/html/rfc8879#section-7.2 - typeCompressedCertificate uint8 = 25 ) const ( @@ -60,6 +63,11 @@ const ( FAKE_TLS_ECDHE_ECDSA_WITH_3DES_EDE_CBC_SHA = uint16(0xc008) ) +// Other things +const ( + fakeRecordSizeLimit uint16 = 0x001c +) + // newest signatures var ( FakePKCS1WithSHA224 SignatureScheme = 0x0301 diff --git a/u_conn.go b/u_conn.go index b1feec3..b729a0c 100644 --- a/u_conn.go +++ b/u_conn.go @@ -710,3 +710,13 @@ func makeSupportedVersions(minVers, maxVers uint16) []uint16 { } return a } + +// Extending (*Conn).readHandshake() to support more customized handshake messages. +func (c *Conn) utlsHandshakeMessageType(msgType byte) (handshakeMessage, error) { + switch msgType { + case typeCompressedCertificate: + return new(compressedCertificateMsg), nil + default: + return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + } +} diff --git a/u_fingerprinter.go b/u_fingerprinter.go index d0c1bcc..d905c94 100644 --- a/u_fingerprinter.go +++ b/u_fingerprinter.go @@ -321,7 +321,7 @@ func (f *Fingerprinter) FingerprintClientHello(data []byte) (*ClientHelloSpec, e case fakeExtensionChannelID: clientHelloSpec.Extensions = append(clientHelloSpec.Extensions, &FakeChannelIDExtension{}) - case fakeExtensionChannelIDOld: + case fakeOldExtensionChannelID: clientHelloSpec.Extensions = append(clientHelloSpec.Extensions, &FakeChannelIDExtension{true}) case fakeExtensionTokenBinding: diff --git a/u_handshake_client.go b/u_handshake_client.go new file mode 100644 index 0000000..8cd6b15 --- /dev/null +++ b/u_handshake_client.go @@ -0,0 +1,113 @@ +// Copyright 2022 uTLS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tls + +import ( + "bytes" + "compress/zlib" + "fmt" + "io" + + "github.com/andybalholm/brotli" + "github.com/klauspost/compress/zstd" +) + +// This function is called by (*clientHandshakeStateTLS13.)readServerCertificate() +// to retrieve the certificate out of a message read by (*Conn.)readHandshake() +func (hs *clientHandshakeStateTLS13) utlsReadServerCertificate(msg any) (processedMsg any, err error) { + for _, ext := range hs.uconn.Extensions { + switch ext.(type) { + case *UtlsCompressCertExtension: + // Included Compressed Certificate extension + if len(hs.uconn.certCompressionAlgs) > 0 { + compressedCertMsg, ok := msg.(*compressedCertificateMsg) + if ok { + hs.transcript.Write(compressedCertMsg.marshal()) + msg, err = hs.decompressCert(*compressedCertMsg) + if err != nil { + return nil, fmt.Errorf("tls: failed to decompress certificate message: %w", err) + } else { + return msg, nil + } + } + } + default: + continue + } + } + return nil, nil +} + +// called by (*clientHandshakeStateTLS13.)utlsReadServerCertificate() when UtlsCompressCertExtension is used +func (hs *clientHandshakeStateTLS13) decompressCert(m compressedCertificateMsg) (*certificateMsgTLS13, error) { + var ( + decompressed io.Reader + compressed = bytes.NewReader(m.compressedCertificateMessage) + c = hs.c + ) + + // Check to see if the peer responded with an algorithm we advertised. + supportedAlg := false + for _, alg := range hs.uconn.certCompressionAlgs { + if m.algorithm == uint16(alg) { + supportedAlg = true + } + } + if !supportedAlg { + c.sendAlert(alertBadCertificate) + return nil, fmt.Errorf("unadvertised algorithm (%d)", m.algorithm) + } + + switch CertCompressionAlgo(m.algorithm) { + case CertCompressionBrotli: + decompressed = brotli.NewReader(compressed) + + case CertCompressionZlib: + rc, err := zlib.NewReader(compressed) + if err != nil { + c.sendAlert(alertBadCertificate) + return nil, fmt.Errorf("failed to open zlib reader: %w", err) + } + defer rc.Close() + decompressed = rc + + case CertCompressionZstd: + rc, err := zstd.NewReader(compressed) + if err != nil { + c.sendAlert(alertBadCertificate) + return nil, fmt.Errorf("failed to open zstd reader: %w", err) + } + defer rc.Close() + decompressed = rc + + default: + c.sendAlert(alertBadCertificate) + return nil, fmt.Errorf("unsupported algorithm (%d)", m.algorithm) + } + + rawMsg := make([]byte, m.uncompressedLength+4) // +4 for message type and uint24 length field + rawMsg[0] = typeCertificate + rawMsg[1] = uint8(m.uncompressedLength >> 16) + rawMsg[2] = uint8(m.uncompressedLength >> 8) + rawMsg[3] = uint8(m.uncompressedLength) + + n, err := decompressed.Read(rawMsg[4:]) + if err != nil { + c.sendAlert(alertBadCertificate) + return nil, err + } + if n < len(rawMsg)-4 { + // If, after decompression, the specified length does not match the actual length, the party + // receiving the invalid message MUST abort the connection with the "bad_certificate" alert. + // https://datatracker.ietf.org/doc/html/rfc8879#section-4 + c.sendAlert(alertBadCertificate) + return nil, fmt.Errorf("decompressed len (%d) does not match specified len (%d)", n, m.uncompressedLength) + } + certMsg := new(certificateMsgTLS13) + if !certMsg.unmarshal(rawMsg) { + return nil, c.sendAlert(alertUnexpectedMessage) + } + return certMsg, nil +} diff --git a/u_handshake_messages.go b/u_handshake_messages.go index 07203c2..13446c7 100644 --- a/u_handshake_messages.go +++ b/u_handshake_messages.go @@ -1,3 +1,7 @@ +// Copyright 2022 uTLS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + package tls import ( diff --git a/u_tls_extensions.go b/u_tls_extensions.go index c4816d3..5ace5ba 100644 --- a/u_tls_extensions.go +++ b/u_tls_extensions.go @@ -863,7 +863,7 @@ func (e *FakeChannelIDExtension) Read(b []byte) (int, error) { } extensionID := fakeExtensionChannelID if e.OldExtensionID { - extensionID = fakeExtensionChannelIDOld + extensionID = fakeOldExtensionChannelID } // https://tools.ietf.org/html/draft-balfanz-tls-channelid-00 b[0] = byte(extensionID >> 8)