From e911b27e233eae2b760537920f2ab3109f276944 Mon Sep 17 00:00:00 2001 From: Filippo Valsorda Date: Sun, 21 May 2023 21:17:56 +0200 Subject: [PATCH] crypto/tls: use SessionState on the client side Another internal change, that allows exposing the new APIs easily in following CLs. For #60105 Change-Id: I9c61b9f6e9d29af633f952444f514bcbbe82fe4e Reviewed-on: https://go-review.googlesource.com/c/go/+/496819 Reviewed-by: Matthew Dempsky TryBot-Result: Gopher Robot Reviewed-by: Damien Neil Run-TryBot: Filippo Valsorda --- cache.go | 2 +- common.go | 19 ---- handshake_client.go | 128 +++++++++++++---------- handshake_client_test.go | 4 +- handshake_client_tls13.go | 44 ++++---- handshake_messages_test.go | 66 +++++++++--- handshake_server.go | 28 ++--- handshake_server_tls13.go | 22 ++-- ticket.go | 205 ++++++++++++++++++++++++++++++++++--- 9 files changed, 350 insertions(+), 168 deletions(-) diff --git a/cache.go b/cache.go index 09f5825..a767761 100644 --- a/cache.go +++ b/cache.go @@ -39,7 +39,7 @@ type certCache struct { sync.Map } -var clientCertCache = new(certCache) +var globalCertCache = new(certCache) // activeCert is a handle to a certificate held in the cache. Once there are // no alive activeCerts for a given certificate, the certificate is removed diff --git a/common.go b/common.go index 58e9730..ccaf7d3 100644 --- a/common.go +++ b/common.go @@ -330,25 +330,6 @@ func requiresClientCert(c ClientAuthType) bool { } } -// ClientSessionState contains the state needed by clients to resume TLS -// sessions. -type ClientSessionState struct { - sessionTicket []uint8 // Encrypted ticket used for session resumption with server - vers uint16 // TLS version negotiated for the session - cipherSuite uint16 // Ciphersuite negotiated for the session - masterSecret []byte // Full handshake MasterSecret, or TLS 1.3 resumption_master_secret - serverCertificates []*x509.Certificate // Certificate chain presented by the server - verifiedChains [][]*x509.Certificate // Certificate chains we built for verification - receivedAt time.Time // When the session ticket was received from the server - ocspResponse []byte // Stapled OCSP response presented by the server - scts [][]byte // SCTs presented by the server - - // TLS 1.3 fields. - nonce []byte // Ticket nonce sent by the server, to derive PSK - useBy time.Time // Expiration of the ticket lifetime as set by the server - ageAdd uint32 // Random obfuscation factor for sending the ticket age -} - // ClientSessionCache is a cache of ClientSessionState objects that can be used // by a client to resume a TLS session with a given server. ClientSessionCache // implementations should expect to be called concurrently from different diff --git a/handshake_client.go b/handshake_client.go index 9f74cc4..2156e91 100644 --- a/handshake_client.go +++ b/handshake_client.go @@ -31,7 +31,8 @@ type clientHandshakeState struct { suite *cipherSuite finishedHash finishedHash masterSecret []byte - session *ClientSessionState + session *SessionState // the session being resumed + ticket []byte // a fresh ticket received during this handshake } var testingOnlyForceClientHelloSignatureAlgorithms []SignatureScheme @@ -177,11 +178,11 @@ func (c *Conn) clientHandshake(ctx context.Context) (err error) { } c.serverName = hello.serverName - cacheKey, session, earlySecret, binderKey, err := c.loadSession(hello) + session, earlySecret, binderKey, err := c.loadSession(hello) if err != nil { return err } - if cacheKey != "" && session != nil { + if session != nil { defer func() { // If we got a handshake failure when resuming a session, throw away // the session ticket. See RFC 5077, Section 3.2. @@ -190,7 +191,9 @@ func (c *Conn) clientHandshake(ctx context.Context) (err error) { // does require servers to abort on invalid binders, so we need to // delete tickets to recover from a corrupted PSK. if err != nil { - c.config.ClientSessionCache.Put(cacheKey, nil) + if cacheKey := c.clientSessionCacheKey(); cacheKey != "" { + c.config.ClientSessionCache.Put(cacheKey, nil) + } } }() } @@ -255,19 +258,13 @@ func (c *Conn) clientHandshake(ctx context.Context) (err error) { return err } - // If we had a successful handshake and hs.session is different from - // the one already cached - cache a new one. - if cacheKey != "" && hs.session != nil && session != hs.session { - c.config.ClientSessionCache.Put(cacheKey, hs.session) - } - return nil } -func (c *Conn) loadSession(hello *clientHelloMsg) (cacheKey string, - session *ClientSessionState, earlySecret, binderKey []byte, err error) { +func (c *Conn) loadSession(hello *clientHelloMsg) ( + session *SessionState, earlySecret, binderKey []byte, err error) { if c.config.SessionTicketsDisabled || c.config.ClientSessionCache == nil { - return "", nil, nil, nil, nil + return nil, nil, nil, nil } hello.ticketSupported = true @@ -282,29 +279,30 @@ func (c *Conn) loadSession(hello *clientHelloMsg) (cacheKey string, // renegotiation is primarily used to allow a client to send a client // certificate, which would be skipped if session resumption occurred. if c.handshakes != 0 { - return "", nil, nil, nil, nil + return nil, nil, nil, nil } // Try to resume a previously negotiated TLS session, if available. - cacheKey = c.clientSessionCacheKey() + cacheKey := c.clientSessionCacheKey() if cacheKey == "" { - return "", nil, nil, nil, nil + return nil, nil, nil, nil } - session, ok := c.config.ClientSessionCache.Get(cacheKey) - if !ok || session == nil { - return cacheKey, nil, nil, nil, nil + cs, ok := c.config.ClientSessionCache.Get(cacheKey) + if !ok || cs == nil { + return nil, nil, nil, nil } + session = cs.session // Check that version used for the previous session is still valid. versOk := false for _, v := range hello.supportedVersions { - if v == session.vers { + if v == session.version { versOk = true break } } if !versOk { - return cacheKey, nil, nil, nil, nil + return nil, nil, nil, nil } // Check that the cached server certificate is not expired, and that it's @@ -313,41 +311,41 @@ func (c *Conn) loadSession(hello *clientHelloMsg) (cacheKey string, if !c.config.InsecureSkipVerify { if len(session.verifiedChains) == 0 { // The original connection had InsecureSkipVerify, while this doesn't. - return cacheKey, nil, nil, nil, nil + return nil, nil, nil, nil } - serverCert := session.serverCertificates[0] + serverCert := session.peerCertificates[0] if c.config.time().After(serverCert.NotAfter) { // Expired certificate, delete the entry. c.config.ClientSessionCache.Put(cacheKey, nil) - return cacheKey, nil, nil, nil, nil + return nil, nil, nil, nil } if err := serverCert.VerifyHostname(c.config.ServerName); err != nil { - return cacheKey, nil, nil, nil, nil + return nil, nil, nil, nil } } - if session.vers != VersionTLS13 { + if session.version != VersionTLS13 { // In TLS 1.2 the cipher suite must match the resumed session. Ensure we // are still offering it. if mutualCipherSuite(hello.cipherSuites, session.cipherSuite) == nil { - return cacheKey, nil, nil, nil, nil + return nil, nil, nil, nil } - hello.sessionTicket = session.sessionTicket + hello.sessionTicket = cs.ticket return } // Check that the session ticket is not expired. - if c.config.time().After(session.useBy) { + if c.config.time().After(time.Unix(int64(session.useBy), 0)) { c.config.ClientSessionCache.Put(cacheKey, nil) - return cacheKey, nil, nil, nil, nil + return nil, nil, nil, nil } // In TLS 1.3 the KDF hash must match the resumed session. Ensure we // offer at least one cipher suite with that hash. cipherSuite := cipherSuiteTLS13ByID(session.cipherSuite) if cipherSuite == nil { - return cacheKey, nil, nil, nil, nil + return nil, nil, nil, nil } cipherSuiteOk := false for _, offeredID := range hello.cipherSuites { @@ -358,32 +356,30 @@ func (c *Conn) loadSession(hello *clientHelloMsg) (cacheKey string, } } if !cipherSuiteOk { - return cacheKey, nil, nil, nil, nil + return nil, nil, nil, nil } // Set the pre_shared_key extension. See RFC 8446, Section 4.2.11.1. - ticketAge := uint32(c.config.time().Sub(session.receivedAt) / time.Millisecond) + ticketAge := c.config.time().Sub(time.Unix(int64(session.createdAt), 0)) identity := pskIdentity{ - label: session.sessionTicket, - obfuscatedTicketAge: ticketAge + session.ageAdd, + label: cs.ticket, + obfuscatedTicketAge: uint32(ticketAge/time.Millisecond) + session.ageAdd, } hello.pskIdentities = []pskIdentity{identity} hello.pskBinders = [][]byte{make([]byte, cipherSuite.hash.Size())} // Compute the PSK binders. See RFC 8446, Section 4.2.11.2. - psk := cipherSuite.expandLabel(session.masterSecret, "resumption", - session.nonce, cipherSuite.hash.Size()) - earlySecret = cipherSuite.extract(psk, nil) + earlySecret = cipherSuite.extract(session.secret, nil) binderKey = cipherSuite.deriveSecret(earlySecret, resumptionBinderLabel, nil) transcript := cipherSuite.hash.New() helloBytes, err := hello.marshalWithoutBinders() if err != nil { - return "", nil, nil, nil, err + return nil, nil, nil, err } transcript.Write(helloBytes) pskBinders := [][]byte{cipherSuite.finishedHash(binderKey, transcript)} if err := hello.updateBinders(pskBinders); err != nil { - return "", nil, nil, nil, err + return nil, nil, nil, err } return @@ -485,6 +481,9 @@ func (hs *clientHandshakeState) handshake() error { return err } } + if err := hs.saveSessionTicket(); err != nil { + return err + } c.ekm = ekmFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.hello.random, hs.serverHello.random) c.isHandshakeComplete.Store(true) @@ -752,7 +751,7 @@ func (hs *clientHandshakeState) processServerHello() (bool, error) { return false, nil } - if hs.session.vers != c.vers { + if hs.session.version != c.vers { c.sendAlert(alertHandshakeFailure) return false, errors.New("tls: server resumed a session with a different version") } @@ -762,9 +761,10 @@ func (hs *clientHandshakeState) processServerHello() (bool, error) { return false, errors.New("tls: server resumed a session with a different cipher suite") } - // Restore masterSecret, peerCerts, and ocspResponse from previous state - hs.masterSecret = hs.session.masterSecret - c.peerCertificates = hs.session.serverCertificates + // Restore master secret and certificates from previous state + hs.masterSecret = hs.session.secret + c.peerCertificates = hs.session.peerCertificates + c.activeCertHandles = hs.c.activeCertHandles c.verifiedChains = hs.session.verifiedChains c.ocspResponse = hs.session.ocspResponse // Let the ServerHello SCTs override the session SCTs from the original @@ -836,8 +836,13 @@ func (hs *clientHandshakeState) readSessionTicket() error { if !hs.serverHello.ticketSupported { return nil } - c := hs.c + + if !hs.hello.ticketSupported { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: server sent unrequested session ticket") + } + msg, err := c.readHandshake(&hs.finishedHash) if err != nil { return err @@ -848,18 +853,29 @@ func (hs *clientHandshakeState) readSessionTicket() error { return unexpectedMessageError(sessionTicketMsg, msg) } - hs.session = &ClientSessionState{ - sessionTicket: sessionTicketMsg.ticket, - vers: c.vers, - cipherSuite: hs.suite.id, - masterSecret: hs.masterSecret, - serverCertificates: c.peerCertificates, - verifiedChains: c.verifiedChains, - receivedAt: c.config.time(), - ocspResponse: c.ocspResponse, - scts: c.scts, + hs.ticket = sessionTicketMsg.ticket + return nil +} + +func (hs *clientHandshakeState) saveSessionTicket() error { + if hs.ticket == nil { + return nil + } + c := hs.c + + cacheKey := c.clientSessionCacheKey() + if cacheKey == "" { + return nil } + session, err := c.sessionState() + if err != nil { + return err + } + session.secret = hs.masterSecret + + cs := &ClientSessionState{ticket: hs.ticket, session: session} + c.config.ClientSessionCache.Put(cacheKey, cs) return nil } @@ -885,7 +901,7 @@ func (c *Conn) verifyServerCertificate(certificates [][]byte) error { activeHandles := make([]*activeCert, len(certificates)) certs := make([]*x509.Certificate, len(certificates)) for i, asn1Data := range certificates { - cert, err := clientCertCache.newCert(asn1Data) + cert, err := globalCertCache.newCert(asn1Data) if err != nil { c.sendAlert(alertBadCertificate) return errors.New("tls: failed to parse certificate from server: " + err.Error()) diff --git a/handshake_client_test.go b/handshake_client_test.go index fef5038..cf7c09b 100644 --- a/handshake_client_test.go +++ b/handshake_client_test.go @@ -916,14 +916,14 @@ func testResumption(t *testing.T, version uint16) { } getTicket := func() []byte { - return clientConfig.ClientSessionCache.(*lruSessionCache).q.Front().Value.(*lruSessionCacheEntry).state.sessionTicket + return clientConfig.ClientSessionCache.(*lruSessionCache).q.Front().Value.(*lruSessionCacheEntry).state.ticket } deleteTicket := func() { ticketKey := clientConfig.ClientSessionCache.(*lruSessionCache).q.Front().Value.(*lruSessionCacheEntry).sessionKey clientConfig.ClientSessionCache.Put(ticketKey, nil) } corruptTicket := func() { - clientConfig.ClientSessionCache.(*lruSessionCache).q.Front().Value.(*lruSessionCacheEntry).state.masterSecret[0] ^= 0xff + clientConfig.ClientSessionCache.(*lruSessionCache).q.Front().Value.(*lruSessionCacheEntry).state.session.secret[0] ^= 0xff } randomKey := func() [32]byte { var k [32]byte diff --git a/handshake_client_tls13.go b/handshake_client_tls13.go index 15e0a74..b26992b 100644 --- a/handshake_client_tls13.go +++ b/handshake_client_tls13.go @@ -23,7 +23,7 @@ type clientHandshakeStateTLS13 struct { hello *clientHelloMsg ecdheKey *ecdh.PrivateKey - session *ClientSessionState + session *SessionState earlySecret []byte binderKey []byte @@ -256,8 +256,8 @@ func (hs *clientHandshakeStateTLS13) processHelloRetryRequest() error { } if pskSuite.hash == hs.suite.hash { // Update binders and obfuscated_ticket_age. - ticketAge := uint32(c.config.time().Sub(hs.session.receivedAt) / time.Millisecond) - hs.hello.pskIdentities[0].obfuscatedTicketAge = ticketAge + hs.session.ageAdd + ticketAge := c.config.time().Sub(time.Unix(int64(hs.session.createdAt), 0)) + hs.hello.pskIdentities[0].obfuscatedTicketAge = uint32(ticketAge/time.Millisecond) + hs.session.ageAdd transcript := hs.suite.hash.New() transcript.Write([]byte{typeMessageHash, 0, 0, uint8(len(chHash))}) @@ -355,7 +355,8 @@ func (hs *clientHandshakeStateTLS13) processServerHello() error { hs.usingPSK = true c.didResume = true - c.peerCertificates = hs.session.serverCertificates + c.peerCertificates = hs.session.peerCertificates + c.activeCertHandles = hs.session.activeCertHandles c.verifiedChains = hs.session.verifiedChains c.ocspResponse = hs.session.ocspResponse c.scts = hs.session.scts @@ -719,28 +720,21 @@ func (c *Conn) handleNewSessionTicket(msg *newSessionTicketMsgTLS13) error { return c.sendAlert(alertInternalError) } - // Save the resumption_master_secret and nonce instead of deriving the PSK - // to do the least amount of work on NewSessionTicket messages before we - // know if the ticket will be used. Forward secrecy of resumed connections - // is guaranteed by the requirement for pskModeDHE. - session := &ClientSessionState{ - sessionTicket: msg.label, - vers: c.vers, - cipherSuite: c.cipherSuite, - masterSecret: c.resumptionSecret, - serverCertificates: c.peerCertificates, - verifiedChains: c.verifiedChains, - receivedAt: c.config.time(), - nonce: msg.nonce, - useBy: c.config.time().Add(lifetime), - ageAdd: msg.ageAdd, - ocspResponse: c.ocspResponse, - scts: c.scts, - } + psk := cipherSuite.expandLabel(c.resumptionSecret, "resumption", + msg.nonce, cipherSuite.hash.Size()) - cacheKey := c.clientSessionCacheKey() - if cacheKey != "" { - c.config.ClientSessionCache.Put(cacheKey, session) + session, err := c.sessionState() + if err != nil { + c.sendAlert(alertInternalError) + return err + } + session.secret = psk + session.useBy = uint64(c.config.time().Add(lifetime).Unix()) + session.ageAdd = msg.ageAdd + cs := &ClientSessionState{ticket: msg.label, session: session} + + if cacheKey := c.clientSessionCacheKey(); cacheKey != "" { + c.config.ClientSessionCache.Put(cacheKey, cs) } return nil diff --git a/handshake_messages_test.go b/handshake_messages_test.go index b280f09..85efacf 100644 --- a/handshake_messages_test.go +++ b/handshake_messages_test.go @@ -6,7 +6,9 @@ package tls import ( "bytes" + "crypto/x509" "encoding/hex" + "math" "math/rand" "reflect" "strings" @@ -71,6 +73,10 @@ func TestMarshalUnmarshal(t *testing.T) { } m.marshal() // to fill any marshal cache in the message + if m, ok := m.(*SessionState); ok { + m.activeCertHandles = nil + } + if !reflect.DeepEqual(m1, m) { t.Errorf("#%d got:%#v want:%#v %x", i, m, m1, marshaled) break @@ -97,7 +103,7 @@ func TestFuzz(t *testing.T) { rand := rand.New(rand.NewSource(0)) for _, m := range tests { for j := 0; j < 1000; j++ { - len := rand.Intn(100) + len := rand.Intn(1000) bytes := randomBytes(len, rand) // This just looks for crashes due to bounds errors etc. m.unmarshal(bytes) @@ -313,23 +319,59 @@ func (*newSessionTicketMsg) Generate(rand *rand.Rand, size int) reflect.Value { return reflect.ValueOf(m) } +var sessionTestCerts []*x509.Certificate + +func init() { + cert, err := x509.ParseCertificate(testRSACertificate) + if err != nil { + panic(err) + } + sessionTestCerts = append(sessionTestCerts, cert) + cert, err = x509.ParseCertificate(testRSACertificateIssuer) + if err != nil { + panic(err) + } + sessionTestCerts = append(sessionTestCerts, cert) +} + func (*SessionState) Generate(rand *rand.Rand, size int) reflect.Value { s := &SessionState{} - s.version = uint16(rand.Intn(10000)) - s.cipherSuite = uint16(rand.Intn(10000)) - s.secret = randomBytes(rand.Intn(100)+1, rand) + isTLS13 := rand.Intn(10) > 5 + if isTLS13 { + s.version = VersionTLS13 + } else { + s.version = uint16(rand.Intn(VersionTLS13)) + } + s.isClient = rand.Intn(10) > 5 + s.cipherSuite = uint16(rand.Intn(math.MaxUint16)) s.createdAt = uint64(rand.Int63()) - for i := 0; i < rand.Intn(2)+1; i++ { - s.certificate.Certificate = append( - s.certificate.Certificate, randomBytes(rand.Intn(500)+1, rand)) + s.secret = randomBytes(rand.Intn(100)+1, rand) + if s.isClient || rand.Intn(10) > 5 { + if rand.Intn(10) > 5 { + s.peerCertificates = sessionTestCerts + } else { + s.peerCertificates = sessionTestCerts[:1] + } } - if rand.Intn(10) > 5 { - s.certificate.OCSPStaple = randomBytes(rand.Intn(100)+1, rand) + if rand.Intn(10) > 5 && s.peerCertificates != nil { + s.ocspResponse = randomBytes(rand.Intn(100)+1, rand) } - if rand.Intn(10) > 5 { + if rand.Intn(10) > 5 && s.peerCertificates != nil { for i := 0; i < rand.Intn(2)+1; i++ { - s.certificate.SignedCertificateTimestamps = append( - s.certificate.SignedCertificateTimestamps, randomBytes(rand.Intn(500)+1, rand)) + s.scts = append(s.scts, randomBytes(rand.Intn(500)+1, rand)) + } + } + if s.isClient { + for i := 0; i < rand.Intn(3); i++ { + if rand.Intn(10) > 5 { + s.verifiedChains = append(s.verifiedChains, s.peerCertificates) + } else { + s.verifiedChains = append(s.verifiedChains, s.peerCertificates[:1]) + } + } + if isTLS13 { + s.useBy = uint64(rand.Int63()) + s.ageAdd = uint32(rand.Int63() & math.MaxUint32) } } return reflect.ValueOf(s) diff --git a/handshake_server.go b/handshake_server.go index 5e5badc..7dda656 100644 --- a/handshake_server.go +++ b/handshake_server.go @@ -448,7 +448,7 @@ func (hs *serverHandshakeState) checkForResumption() bool { return false } - sessionHasClientCerts := len(hs.sessionState.certificate.Certificate) != 0 + sessionHasClientCerts := len(hs.sessionState.peerCertificates) != 0 needClientCerts := requiresClientCert(c.config.ClientAuth) if needClientCerts && !sessionHasClientCerts { return false @@ -481,7 +481,7 @@ func (hs *serverHandshakeState) doResumeHandshake() error { return err } - if err := c.processCertsFromClient(hs.sessionState.certificate); err != nil { + if err := c.processCertsFromClient(hs.sessionState.certificate()); err != nil { return err } @@ -759,27 +759,15 @@ func (hs *serverHandshakeState) sendSessionTicket() error { c := hs.c m := new(newSessionTicketMsg) - createdAt := uint64(c.config.time().Unix()) + state, err := c.sessionState() + if err != nil { + return err + } + state.secret = hs.masterSecret if hs.sessionState != nil { // If this is re-wrapping an old key, then keep // the original time it was created. - createdAt = hs.sessionState.createdAt - } - - var certsFromClient [][]byte - for _, cert := range c.peerCertificates { - certsFromClient = append(certsFromClient, cert.Raw) - } - state := SessionState{ - version: c.vers, - cipherSuite: hs.suite.id, - createdAt: createdAt, - secret: hs.masterSecret, - certificate: Certificate{ - Certificate: certsFromClient, - OCSPStaple: c.ocspResponse, - SignedCertificateTimestamps: c.scts, - }, + state.createdAt = hs.sessionState.createdAt } stateBytes, err := state.Bytes() if err != nil { diff --git a/handshake_server_tls13.go b/handshake_server_tls13.go index f770a21..6753ad4 100644 --- a/handshake_server_tls13.go +++ b/handshake_server_tls13.go @@ -301,7 +301,7 @@ func (hs *serverHandshakeStateTLS13) checkForResumption() error { // PSK connections don't re-establish client certificates, but carry // them over in the session ticket. Ensure the presence of client certs // in the ticket is consistent with the configured requirements. - sessionHasClientCerts := len(sessionState.certificate.Certificate) != 0 + sessionHasClientCerts := len(sessionState.peerCertificates) != 0 needClientCerts := requiresClientCert(c.config.ClientAuth) if needClientCerts && !sessionHasClientCerts { continue @@ -331,7 +331,7 @@ func (hs *serverHandshakeStateTLS13) checkForResumption() error { } c.didResume = true - if err := c.processCertsFromClient(sessionState.certificate); err != nil { + if err := c.processCertsFromClient(sessionState.certificate()); err != nil { return err } @@ -776,21 +776,11 @@ func (hs *serverHandshakeStateTLS13) sendSessionTickets() error { m := new(newSessionTicketMsgTLS13) - var certsFromClient [][]byte - for _, cert := range c.peerCertificates { - certsFromClient = append(certsFromClient, cert.Raw) - } - state := &SessionState{ - version: c.vers, - cipherSuite: hs.suite.id, - createdAt: uint64(c.config.time().Unix()), - secret: psk, - certificate: Certificate{ - Certificate: certsFromClient, - OCSPStaple: c.ocspResponse, - SignedCertificateTimestamps: c.scts, - }, + state, err := c.sessionState() + if err != nil { + return err } + state.secret = psk stateBytes, err := state.Bytes() if err != nil { c.sendAlert(alertInternalError) diff --git a/ticket.go b/ticket.go index dfa0d43..44bedd6 100644 --- a/ticket.go +++ b/ticket.go @@ -10,6 +10,7 @@ import ( "crypto/hmac" "crypto/sha256" "crypto/subtle" + "crypto/x509" "errors" "io" @@ -18,12 +19,63 @@ import ( // A SessionState is a resumable session. type SessionState struct { - version uint16 // uint16 version; - // uint8 revision = 1; + // Encoded as a SessionState (in the language of RFC 8446, Section 3). + // + // enum { server(1), client(2) } SessionStateType; + // + // opaque Certificate<1..2^24-1>; + // + // Certificate CertificateChain<0..2^24-1>; + // + // struct { + // uint16 version; + // SessionStateType type; + // uint16 cipher_suite; + // uint64 created_at; + // opaque secret<1..2^8-1>; + // CertificateEntry certificate_list<0..2^24-1>; + // select (SessionState.type) { + // case server: /* empty */; + // case client: struct { + // CertificateChain verified_chains<0..2^24-1>; /* excluding leaf */ + // select (SessionState.version) { + // case VersionTLS10..VersionTLS12: /* empty */; + // case VersionTLS13: struct { + // uint64 use_by; + // uint32 age_add; + // }; + // }; + // }; + // }; + // } SessionState; + // + + version uint16 + isClient bool cipherSuite uint16 - createdAt uint64 - secret []byte // opaque master_secret<1..2^8-1>; - certificate Certificate // CertificateEntry certificate_list<0..2^24-1>; + // createdAt is the generation time of the secret on the sever (which for + // TLS 1.0–1.2 might be earlier than the current session) and the time at + // which the ticket was received on the client. + createdAt uint64 // seconds since UNIX epoch + secret []byte // master secret for TLS 1.2, or the PSK for TLS 1.3 + peerCertificates []*x509.Certificate + activeCertHandles []*activeCert + ocspResponse []byte + scts [][]byte + + // Client-side fields. + verifiedChains [][]*x509.Certificate + + // Client-side TLS 1.3-only fields. + useBy uint64 // seconds since UNIX epoch + ageAdd uint32 +} + +// ClientSessionState contains the state needed by clients to resume TLS +// sessions. +type ClientSessionState struct { + ticket []byte + session *SessionState } // Bytes encodes the session, including any private fields, so that it can be @@ -31,38 +83,157 @@ type SessionState struct { // // The specific encoding should be considered opaque and may change incompatibly // between Go versions. -func (m *SessionState) Bytes() ([]byte, error) { +func (s *SessionState) Bytes() ([]byte, error) { var b cryptobyte.Builder - b.AddUint16(m.version) - b.AddUint8(1) // revision - b.AddUint16(m.cipherSuite) - addUint64(&b, m.createdAt) + b.AddUint16(s.version) + if s.isClient { + b.AddUint8(2) // client + } else { + b.AddUint8(1) // server + } + b.AddUint16(s.cipherSuite) + addUint64(&b, s.createdAt) b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.secret) + b.AddBytes(s.secret) }) - marshalCertificate(&b, m.certificate) + marshalCertificate(&b, s.certificate()) + if s.isClient { + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + for _, chain := range s.verifiedChains { + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + // We elide the first certificate because it's always the leaf. + if len(chain) == 0 { + b.SetError(errors.New("tls: internal error: empty verified chain")) + return + } + for _, cert := range chain[1:] { + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(cert.Raw) + }) + } + }) + } + }) + if s.version >= VersionTLS13 { + addUint64(&b, s.useBy) + b.AddUint32(s.ageAdd) + } + } return b.Bytes() } +func (s *SessionState) certificate() Certificate { + return Certificate{ + Certificate: certificatesToBytesSlice(s.peerCertificates), + OCSPStaple: s.ocspResponse, + SignedCertificateTimestamps: s.scts, + } +} + +func certificatesToBytesSlice(certs []*x509.Certificate) [][]byte { + s := make([][]byte, 0, len(certs)) + for _, c := range certs { + s = append(s, c.Raw) + } + return s +} + // ParseSessionState parses a [SessionState] encoded by [SessionState.Bytes]. func ParseSessionState(data []byte) (*SessionState, error) { ss := &SessionState{} s := cryptobyte.String(data) - var revision uint8 + var typ uint8 + var cert Certificate if !s.ReadUint16(&ss.version) || - !s.ReadUint8(&revision) || - revision != 1 || + !s.ReadUint8(&typ) || + (typ != 1 && typ != 2) || !s.ReadUint16(&ss.cipherSuite) || !readUint64(&s, &ss.createdAt) || !readUint8LengthPrefixed(&s, &ss.secret) || len(ss.secret) == 0 || - !unmarshalCertificate(&s, &ss.certificate) || - !s.Empty() { + !unmarshalCertificate(&s, &cert) { + return nil, errors.New("tls: invalid session encoding") + } + for _, cert := range cert.Certificate { + c, err := globalCertCache.newCert(cert) + if err != nil { + return nil, err + } + ss.activeCertHandles = append(ss.activeCertHandles, c) + ss.peerCertificates = append(ss.peerCertificates, c.cert) + } + ss.ocspResponse = cert.OCSPStaple + ss.scts = cert.SignedCertificateTimestamps + if isClient := typ == 2; !isClient { + if !s.Empty() { + return nil, errors.New("tls: invalid session encoding") + } + return ss, nil + } + ss.isClient = true + if len(ss.peerCertificates) == 0 { + return nil, errors.New("tls: no server certificates in client session") + } + var chainList cryptobyte.String + if !s.ReadUint24LengthPrefixed(&chainList) { + return nil, errors.New("tls: invalid session encoding") + } + for !chainList.Empty() { + var certList cryptobyte.String + if !chainList.ReadUint24LengthPrefixed(&certList) { + return nil, errors.New("tls: invalid session encoding") + } + var chain []*x509.Certificate + chain = append(chain, ss.peerCertificates[0]) + for !certList.Empty() { + var cert []byte + if !readUint24LengthPrefixed(&certList, &cert) { + return nil, errors.New("tls: invalid session encoding") + } + c, err := globalCertCache.newCert(cert) + if err != nil { + return nil, err + } + ss.activeCertHandles = append(ss.activeCertHandles, c) + chain = append(chain, c.cert) + } + ss.verifiedChains = append(ss.verifiedChains, chain) + } + if ss.version < VersionTLS13 { + if !s.Empty() { + return nil, errors.New("tls: invalid session encoding") + } + return ss, nil + } + if !s.ReadUint64(&ss.useBy) || !s.ReadUint32(&ss.ageAdd) || !s.Empty() { return nil, errors.New("tls: invalid session encoding") } return ss, nil } +// sessionState returns a partially filled-out [SessionState] with information +// from the current connection. +func (c *Conn) sessionState() (*SessionState, error) { + var verifiedChains [][]*x509.Certificate + if c.isClient { + verifiedChains = c.verifiedChains + if len(c.peerCertificates) == 0 { + return nil, errors.New("tls: internal error: empty peer certificates") + } + } + return &SessionState{ + version: c.vers, + cipherSuite: c.cipherSuite, + createdAt: uint64(c.config.time().Unix()), + peerCertificates: c.peerCertificates, + activeCertHandles: c.activeCertHandles, + ocspResponse: c.ocspResponse, + scts: c.scts, + isClient: c.isClient, + verifiedChains: verifiedChains, + }, nil +} + func (c *Conn) encryptTicket(state []byte) ([]byte, error) { if len(c.ticketKeys) == 0 { return nil, errors.New("tls: internal error: session ticket keys unavailable")