diff --git a/tls/common.go b/tls/common.go index 8704af6..30470ca 100644 --- a/tls/common.go +++ b/tls/common.go @@ -733,6 +733,8 @@ type Config struct { // used for debugging. KeyLogWriter io.Writer + SessionIDGenerator func(clientHello []byte, sessionID []byte) error + // mutex protects sessionTicketKeys and autoSessionTicketKeys. mutex sync.RWMutex // sessionTicketKeys contains zero or more ticket keys. If set, it means diff --git a/tls/handshake_client.go b/tls/handshake_client.go index 6b28926..198c6f6 100644 --- a/tls/handshake_client.go +++ b/tls/handshake_client.go @@ -111,13 +111,6 @@ func (c *Conn) makeClientHello() (*clientHelloMsg, *ecdh.PrivateKey, error) { return nil, nil, errors.New("tls: short read from Rand: " + err.Error()) } - // A random session ID is used to detect when the server accepted a ticket - // and is resuming a session (see RFC 5077). In TLS 1.3, it's always set as - // a compatibility measure (see RFC 8446, Section 4.1.2). - if _, err := io.ReadFull(config.rand(), hello.sessionId); err != nil { - return nil, nil, errors.New("tls: short read from Rand: " + err.Error()) - } - if hello.vers >= VersionTLS12 { hello.supportedSignatureAlgorithms = supportedSignatureAlgorithms() } @@ -144,6 +137,25 @@ func (c *Conn) makeClientHello() (*clientHelloMsg, *ecdh.PrivateKey, error) { hello.keyShares = []keyShare{{group: curveID, data: key.PublicKey().Bytes()}} } + // A random session ID is used to detect when the server accepted a ticket + // and is resuming a session (see RFC 5077). In TLS 1.3, it's always set as + // a compatibility measure (see RFC 8446, Section 4.1.2). + + if config.SessionIDGenerator != nil { + buffer, err := hello.marshal() + if err != nil { + return nil, nil, err + } + if err := config.SessionIDGenerator(buffer, hello.sessionId); err != nil { + return nil, nil, errors.New("tls: generate session id failed: " + err.Error()) + } + hello.raw = nil + } else { + if _, err := io.ReadFull(config.rand(), hello.sessionId); err != nil { + return nil, nil, errors.New("tls: short read from Rand: " + err.Error()) + } + } + return hello, key, nil } diff --git a/tls_compact/common.go b/tls_compact/common.go index 2562ae0..10efe43 100644 --- a/tls_compact/common.go +++ b/tls_compact/common.go @@ -723,6 +723,8 @@ type Config struct { // used for debugging. KeyLogWriter io.Writer + SessionIDGenerator func(clientHello []byte, sessionID []byte) error + // mutex protects sessionTicketKeys and autoSessionTicketKeys. mutex sync.RWMutex // sessionTicketKeys contains zero or more ticket keys. If set, it means the diff --git a/tls_compact/handshake_client.go b/tls_compact/handshake_client.go index 4907317..d606a7c 100644 --- a/tls_compact/handshake_client.go +++ b/tls_compact/handshake_client.go @@ -111,13 +111,6 @@ func (c *Conn) makeClientHello() (*clientHelloMsg, ecdheParameters, error) { return nil, nil, errors.New("tls: short read from Rand: " + err.Error()) } - // A random session ID is used to detect when the server accepted a ticket - // and is resuming a session (see RFC 5077). In TLS 1.3, it's always set as - // a compatibility measure (see RFC 8446, Section 4.1.2). - if _, err := io.ReadFull(config.rand(), hello.sessionId); err != nil { - return nil, nil, errors.New("tls: short read from Rand: " + err.Error()) - } - if hello.vers >= VersionTLS12 { hello.supportedSignatureAlgorithms = supportedSignatureAlgorithms() } @@ -144,6 +137,25 @@ func (c *Conn) makeClientHello() (*clientHelloMsg, ecdheParameters, error) { hello.keyShares = []keyShare{{group: curveID, data: params.PublicKey()}} } + // A random session ID is used to detect when the server accepted a ticket + // and is resuming a session (see RFC 5077). In TLS 1.3, it's always set as + // a compatibility measure (see RFC 8446, Section 4.1.2). + + if config.SessionIDGenerator != nil { + buffer, err := hello.marshal() + if err != nil { + return nil, nil, err + } + if err := config.SessionIDGenerator(buffer, hello.sessionId); err != nil { + return nil, nil, errors.New("tls: generate session id failed: " + err.Error()) + } + hello.raw = nil + } else { + if _, err := io.ReadFull(config.rand(), hello.sessionId); err != nil { + return nil, nil, errors.New("tls: short read from Rand: " + err.Error()) + } + } + return hello, params, nil }