Implement certificate compression (#95)

Certificate compression is defined in RFC 8879:
https://datatracker.ietf.org/doc/html/rfc8879

This implementation is client-side only, for server certificates.

- Fixes #104.
This commit is contained in:
hwh33 2022-07-19 19:12:30 -06:00 committed by GitHub
parent 9d36ce3658
commit 7344e34650
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 276 additions and 50 deletions

View file

@ -6,14 +6,19 @@ package tls
import (
"bytes"
"compress/zlib"
"crypto"
"crypto/hmac"
"crypto/rsa"
"errors"
"fmt"
"hash"
"io"
"sync/atomic"
"time"
"github.com/andybalholm/brotli"
"github.com/klauspost/compress/zstd"
)
type clientHandshakeStateTLS13 struct {
@ -484,6 +489,24 @@ func (hs *clientHandshakeStateTLS13) readServerCertificate() error {
}
}
// [UTLS SECTION BEGINS]
receivedCompressedCert := false
// Check to see if we advertised any compression algorithms
if hs.uconn != nil && len(hs.uconn.certCompressionAlgs) > 0 {
// Check to see if the message is a compressed certificate message, otherwise move on.
compressedCertMsg, ok := msg.(*compressedCertificateMsg)
if ok {
receivedCompressedCert = true
hs.transcript.Write(compressedCertMsg.marshal())
msg, err = hs.decompressCert(*compressedCertMsg)
if err != nil {
return fmt.Errorf("tls: failed to decompress certificate message: %w", err)
}
}
}
// [UTLS SECTION ENDS]
certMsg, ok := msg.(*certificateMsgTLS13)
if !ok {
c.sendAlert(alertUnexpectedMessage)
@ -493,7 +516,12 @@ func (hs *clientHandshakeStateTLS13) readServerCertificate() error {
c.sendAlert(alertDecodeError)
return errors.New("tls: received empty certificates message")
}
hs.transcript.Write(certMsg.marshal())
// [UTLS SECTION BEGINS]
// Previously, this was simply 'hs.transcript.Write(certMsg.marshal())' (without the if).
if !receivedCompressedCert {
hs.transcript.Write(certMsg.marshal())
}
// [UTLS SECTION ENDS]
c.scts = certMsg.certificate.SignedCertificateTimestamps
c.ocspResponse = certMsg.certificate.OCSPStaple
@ -690,6 +718,80 @@ func (hs *clientHandshakeStateTLS13) sendClientFinished() error {
return nil
}
// [UTLS SECTION BEGINS]
func (hs *clientHandshakeStateTLS13) decompressCert(m compressedCertificateMsg) (*certificateMsgTLS13, error) {
var (
decompressed io.Reader
compressed = bytes.NewReader(m.compressedCertificateMessage)
c = hs.c
)
// Check to see if the peer responded with an algorithm we advertised.
supportedAlg := false
for _, alg := range hs.uconn.certCompressionAlgs {
if m.algorithm == uint16(alg) {
supportedAlg = true
}
}
if !supportedAlg {
c.sendAlert(alertBadCertificate)
return nil, fmt.Errorf("unadvertised algorithm (%d)", m.algorithm)
}
switch CertCompressionAlgo(m.algorithm) {
case CertCompressionBrotli:
decompressed = brotli.NewReader(compressed)
case CertCompressionZlib:
rc, err := zlib.NewReader(compressed)
if err != nil {
c.sendAlert(alertBadCertificate)
return nil, fmt.Errorf("failed to open zlib reader: %w", err)
}
defer rc.Close()
decompressed = rc
case CertCompressionZstd:
rc, err := zstd.NewReader(compressed)
if err != nil {
c.sendAlert(alertBadCertificate)
return nil, fmt.Errorf("failed to open zstd reader: %w", err)
}
defer rc.Close()
decompressed = rc
default:
c.sendAlert(alertBadCertificate)
return nil, fmt.Errorf("unsupported algorithm (%d)", m.algorithm)
}
rawMsg := make([]byte, m.uncompressedLength+4) // +4 for message type and uint24 length field
rawMsg[0] = typeCertificate
rawMsg[1] = uint8(m.uncompressedLength >> 16)
rawMsg[2] = uint8(m.uncompressedLength >> 8)
rawMsg[3] = uint8(m.uncompressedLength)
n, err := decompressed.Read(rawMsg[4:])
if err != nil {
c.sendAlert(alertBadCertificate)
return nil, err
}
if n < len(rawMsg)-4 {
// If, after decompression, the specified length does not match the actual length, the party
// receiving the invalid message MUST abort the connection with the "bad_certificate" alert.
// https://datatracker.ietf.org/doc/html/rfc8879#section-4
c.sendAlert(alertBadCertificate)
return nil, fmt.Errorf("decompressed len (%d) does not match specified len (%d)", n, m.uncompressedLength)
}
certMsg := new(certificateMsgTLS13)
if !certMsg.unmarshal(rawMsg) {
return nil, c.sendAlert(alertUnexpectedMessage)
}
return certMsg, nil
}
// [UTLS SECTION ENDS]
func (c *Conn) handleNewSessionTicket(msg *newSessionTicketMsgTLS13) error {
if !c.isClient {
c.sendAlert(alertUnexpectedMessage)