new: Support TLS-PSK (TLS 1.3) (#231)

* uTLS: X25519Kyber768Draft00 hybrid post-quantum key agreement by cloudflare/go (#222)

* crypto/tls: Add hybrid post-quantum key agreement  (#13)

* import: client-side KEM from cloudflare/go

* import: server-side KEM from cloudflare/go

* fix: modify test to get rid of CFEvents.

Note: uTLS does not promise any server-side functionality, and this change is made to be able to conduct unit tests which requires both side to be able to handle KEM Curves.

Co-authored-by: Christopher Wood <caw@heapingbits.net>
Co-Authored-By: Bas Westerbaan <bas@westerbaan.name>

----

Based on:

* crypto/tls: Add hybrid post-quantum key agreement 

Adds X25519Kyber512Draft00, X25519Kyber768Draft00, and
P256Kyber768Draft00 hybrid post-quantum key agreements with temporary
group identifiers.

The hybrid post-quantum key exchanges uses plain X{25519,448} instead
of HPKE, which we assume will be more likely to be adopted. The order
is chosen to match CECPQ2.

Not enabled by default.

Adds CFEvents to detect `HelloRetryRequest`s and to signal which
key agreement was used.

Co-authored-by: Christopher Wood <caw@heapingbits.net>

 [bas, 1.20.1: also adds P256Kyber768Draft00]
 [pwu, 1.20.4: updated circl to v1.3.3, moved code to cfevent.go]

* crypto: add support for CIRCL signature schemes

* only partially port the commit from cloudflare/go. We would stick to the official x509 at the cost of incompatibility.

Co-Authored-By: Bas Westerbaan <bas@westerbaan.name>
Co-Authored-By: Christopher Patton <3453007+cjpatton@users.noreply.github.com>
Co-Authored-By: Peter Wu <peter@lekensteyn.nl>

* crypto/tls: add new X25519Kyber768Draft00 code point

Ported from cloudflare/go to support the upcoming new post-quantum keyshare.

----

* Point tls.X25519Kyber768Draft00 to the new 0x6399 identifier while the
  old 0xfe31 identifier is available as tls.X25519Kyber768Draft00Old.
* Make sure that the kem.PrivateKey can always be mapped to the CurveID
  that was linked to it. This is needed since we now have two ID
  aliasing to the same scheme, and clients need to be able to detect
  whether the key share presented by the server actually matches the key
  share that the client originally sent.
* Update tests, add the new identifier and remove unnecessary code.

Link: https://mailarchive.ietf.org/arch/msg/tls/HAWpNpgptl--UZNSYuvsjB-Pc2k/
Link: https://datatracker.ietf.org/doc/draft-tls-westerbaan-xyber768d00/02/
Co-Authored-By: Peter Wu <peter@lekensteyn.nl>
Co-Authored-By: Bas Westerbaan <bas@westerbaan.name>

---------

Co-authored-by: Bas Westerbaan <bas@westerbaan.name>
Co-authored-by: Christopher Patton <3453007+cjpatton@users.noreply.github.com>
Co-authored-by: Peter Wu <peter@lekensteyn.nl>

* new: enable PQ parrots (#225)

* Redesign KeySharesEcdheParameters into KeySharesParameters which supports multiple types of keys.

* Optimize program logic to prevent using unwanted keys

* new: more parrots and safety update (#227)

* new: PQ and other parrots

Add new preset parrots:
- HelloChrome_114_Padding_PSK_Shuf
- HelloChrome_115_PQ
- HelloChrome_115_PQ_PSK

* new: ShuffleChromeTLSExtensions

Implement a new function `ShuffleChromeTLSExtensions(exts []TLSExtension) []TLSExtension`.

* update: include psk parameter for parrot-related functions

Update following functions' prototype to accept an optional pskExtension (of type *FakePreSharedKeyExtension):
- `UClient(conn net.Conn, config *Config, clientHelloID ClientHelloID)` => `UClient(conn net.Conn, config *Config, clientHelloID ClientHelloID, pskExtension ...*FakePreSharedKeyExtension)`
- `UTLSIdToSpec(id ClientHelloID)` => `UTLSIdToSpec(id ClientHelloID, pskExtension ...*FakePreSharedKeyExtension)`

* new: pre-defined error from UTLSIdToSpec

Update UTLSIdToSpec to return more comprehensive errors by pre-defining them, allowing easier error comparing/unwrapping.

* new: UtlsPreSharedKeyExtension

In `u_pre_shared_key.go`, create `PreSharedKeyExtension` as an interface, with 3 implementations:
- `UtlsPreSharedKeyExtension` implements full support for `pre_shared_key` less resuming after seeing HRR.
- `FakePreSharedKeyExtension` uses CipherSuiteID, SessionSecret and Identities to calculate the corresponding binders and send them, without setting the internal states. Therefore if the server accepts the PSK and tries to resume, the connection fails.
- `HardcodedPreSharedKeyExtension` allows user to hardcode Identities and Binders to be sent in the extension without setting the internal states. Therefore if the server accepts the PSK and tries to resume, the connection fails.

TODO: Only one of FakePreSharedKeyExtension and HardcodedPreSharedKeyExtension should be kept, the other one should be just removed. We still need to learn more of the safety of hardcoding both Identities and Binders without recalculating the latter.

* update: PSK minor changes and example

* Updates PSK implementations for more comprehensible interfaces when applying preset/json/raw fingerprints.
* Revert FakePreSharedKeyExtension to the old implementation. Add binder size checking.
* Implement TLS-PSK example

New bug: setting `tls.Config.ClientSessionCache` will cause PSK to fail. Currently users must set only `tls.UtlsPreSharedKeyExtension.ClientSessionCacheOverride`.

* fix: PSK failing if config session cache set

* Fix a bug causing PSK to fail if Config.ClientSessionCache is set.
* Removed `ClientSessionCacheOverride` from `UtlsPreSharedKeyExtension`. Set the `ClientSessionCache` in `Config`!

Co-Authored-By: zeeker999 <13848632+zeeker999@users.noreply.github.com>

* Optimize tls resumption (#235)

* feat: bug fix and refactor

* feat: improve example docs: add detailed explanation about the design feat: add assertion on uApplyPatch

* fix: address comments
feat: add option `OmitEmptyPsk` and throw error on empty psk by default
feat: revert changes to public interfaces

* fix: weird residue caused by merging conflict

* fix: remove merge conflict residue code

---------

Co-authored-by: Bas Westerbaan <bas@westerbaan.name>
Co-authored-by: Christopher Patton <3453007+cjpatton@users.noreply.github.com>
Co-authored-by: Peter Wu <peter@lekensteyn.nl>
Co-authored-by: zeeker999 <13848632+zeeker999@users.noreply.github.com>
Co-authored-by: 3andne <52860475+3andne@users.noreply.github.com>
This commit is contained in:
Gaukas Wang 2023-08-27 12:48:31 -06:00 committed by GitHub
parent 45e7f1de14
commit 8094658e76
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
18 changed files with 1377 additions and 371 deletions

View file

@ -684,6 +684,14 @@ type Config struct {
// This field is ignored when InsecureSkipVerify is true.
InsecureSkipTimeVerify bool // [uTLS]
// OmitEmptyPsk determines whether utls will automatically conceal
// the psk extension when it is empty. When the psk extension is empty, the
// browser omits it from the client hello. Utls can mimic this behavior,
// but it deviates from the provided client hello specification, rendering
// it unsuitable as the default behavior. Users have the option to enable
// this behavior at their own discretion.
OmitEmptyPsk bool // [uTLS]
// InsecureServerNameToVerify is used to verify the hostname on the returned
// certificates. It is intended to use with spoofed ServerName.
// If InsecureServerNameToVerify is "*", crypto/tls will do normal
@ -881,6 +889,7 @@ func (c *Config) Clone() *Config {
InsecureSkipVerify: c.InsecureSkipVerify,
InsecureSkipTimeVerify: c.InsecureSkipTimeVerify,
InsecureServerNameToVerify: c.InsecureServerNameToVerify,
OmitEmptyPsk: c.OmitEmptyPsk,
CipherSuites: c.CipherSuites,
PreferServerCipherSuites: c.PreferServerCipherSuites,
SessionTicketsDisabled: c.SessionTicketsDisabled,

View file

@ -96,7 +96,7 @@ type Conn struct {
// clientProtocol is the negotiated ALPN protocol.
clientProtocol string
utls utlsConnExtraFields // [UTLS】 used for extensive things such as ALPS
utls utlsConnExtraFields // [UTLS] used for extensive things such as ALPS, PSK, etc
// input/output
in, out halfConn

View file

@ -0,0 +1,159 @@
package main
import (
"fmt"
"net"
"strings"
"time"
tls "github.com/refraction-networking/utls"
)
type ClientSessionCache struct {
sessionKeyMap map[string]*tls.ClientSessionState
}
func NewClientSessionCache() tls.ClientSessionCache {
return &ClientSessionCache{
sessionKeyMap: make(map[string]*tls.ClientSessionState),
}
}
func (csc *ClientSessionCache) Get(sessionKey string) (session *tls.ClientSessionState, ok bool) {
if session, ok = csc.sessionKeyMap[sessionKey]; ok {
fmt.Printf("Getting session for %s\n", sessionKey)
return session, true
}
fmt.Printf("Missing session for %s\n", sessionKey)
return nil, false
}
func (csc *ClientSessionCache) Put(sessionKey string, cs *tls.ClientSessionState) {
if csc.sessionKeyMap == nil {
fmt.Printf("Deleting session for %s\n", sessionKey)
delete(csc.sessionKeyMap, sessionKey)
} else {
fmt.Printf("Putting session for %s\n", sessionKey)
csc.sessionKeyMap[sessionKey] = cs
}
}
func runResumptionCheck(helloID tls.ClientHelloID, serverAddr string, retry int, verbose bool) {
csc := NewClientSessionCache()
tcpConn, err := net.Dial("tcp", serverAddr)
if err != nil {
panic(err)
}
// Everything below this line is brought to you by uTLS API, enjoy!
// use chs
tlsConn := tls.UClient(tcpConn, &tls.Config{
ServerName: strings.Split(serverAddr, ":")[0],
// NextProtos: []string{"h2", "http/1.1"},
ClientSessionCache: csc, // set this so session tickets will be saved
OmitEmptyPsk: true,
}, helloID)
// HS
err = tlsConn.Handshake()
if err != nil {
panic(err)
}
var tlsVer uint16
if tlsConn.ConnectionState().HandshakeComplete {
tlsVer = tlsConn.ConnectionState().Version
if verbose {
fmt.Println("Handshake complete")
fmt.Printf("TLS Version: %04x\n", tlsVer)
}
if tlsVer == tls.VersionTLS13 {
if verbose {
fmt.Printf("Expecting PSK resumption\n")
}
} else if tlsVer == tls.VersionTLS12 {
if verbose {
fmt.Printf("Expecting session ticket resumption\n")
}
} else {
panic("Don't try resumption on old TLS versions")
}
if tlsConn.HandshakeState.State13.UsingPSK {
panic("unintended using of PSK happened...")
} else if tlsConn.DidTls12Resume() {
panic("unintended using of session ticket happened...")
} else {
if verbose {
fmt.Println("First connection, no PSK/session ticket to use.")
}
}
tlsConn.SetReadDeadline(time.Now().Add(1 * time.Second))
tlsConn.Read(make([]byte, 1024)) // trigger a read so NewSessionTicket gets handled
}
tlsConn.Close()
for i := 0; i < retry; i++ {
tcpConnPSK, err := net.Dial("tcp", serverAddr)
if err != nil {
panic(err)
}
tlsConnPSK := tls.UClient(tcpConnPSK, &tls.Config{
ServerName: strings.Split(serverAddr, ":")[0],
ClientSessionCache: csc,
OmitEmptyPsk: true,
}, helloID)
// HS
err = tlsConnPSK.Handshake()
if verbose {
fmt.Printf("tlsConnPSK.HandshakeState.Hello.Raw %v\n", tlsConnPSK.HandshakeState.Hello.Raw)
fmt.Printf("tlsConnPSK.HandshakeState.Hello.PskIdentities: %v\n", tlsConnPSK.HandshakeState.Hello.PskIdentities)
}
if err != nil {
panic(err)
}
if tlsConnPSK.ConnectionState().HandshakeComplete {
if verbose {
fmt.Println("Handshake complete")
}
newVer := tlsConnPSK.ConnectionState().Version
if verbose {
fmt.Printf("TLS Version: %04x\n", newVer)
}
if newVer != tlsVer {
panic("Tls version changed unexpectedly on the second connection")
}
if tlsVer == tls.VersionTLS13 && tlsConnPSK.HandshakeState.State13.UsingPSK {
fmt.Println("[PSK used]")
return
} else if tlsVer == tls.VersionTLS12 && tlsConnPSK.DidTls12Resume() {
fmt.Println("[session ticket used]")
return
}
}
time.Sleep(700 * time.Millisecond)
}
panic(fmt.Sprintf("PSK or session ticket not used for a resumption session, server %s, helloID: %s", serverAddr, helloID.Client))
}
func main() {
tls13Url := "www.microsoft.com:443"
tls12Url1 := "spocs.getpocket.com:443"
tls12Url2 := "marketplace.visualstudio.com:443"
runResumptionCheck(tls.HelloChrome_100_PSK, tls13Url, 1, false) // psk + utls
runResumptionCheck(tls.HelloGolang, tls13Url, 1, false) // psk + crypto/tls
runResumptionCheck(tls.HelloChrome_100_PSK, tls12Url1, 10, false) // session ticket + utls
runResumptionCheck(tls.HelloGolang, tls12Url1, 10, false) // session ticket + crypto/tls
runResumptionCheck(tls.HelloChrome_100_PSK, tls12Url2, 10, false) // session ticket + utls
runResumptionCheck(tls.HelloGolang, tls12Url2, 10, false) // session ticket + crypto/tls
}

View file

@ -312,6 +312,12 @@ func (c *Conn) clientHandshake(ctx context.Context) (err error) {
func (c *Conn) loadSession(hello *clientHelloMsg) (
session *SessionState, earlySecret, binderKey []byte, err error) {
// [UTLS SECTION START]
if c.utls.sessionController != nil {
c.utls.sessionController.onEnterLoadSessionCheck()
defer c.utls.sessionController.onLoadSessionReturn()
}
// [UTLS SECTION END]
if c.config.SessionTicketsDisabled || c.config.ClientSessionCache == nil {
return nil, nil, nil, nil
}
@ -450,6 +456,11 @@ func (c *Conn) loadSession(hello *clientHelloMsg) (
// Compute the PSK binders. See RFC 8446, Section 4.2.11.2.
earlySecret = cipherSuite.extract(session.secret, nil)
binderKey = cipherSuite.deriveSecret(earlySecret, resumptionBinderLabel, nil)
// [UTLS SECTION START]
if c.utls.sessionController != nil && !c.utls.sessionController.shouldLoadSessionWriteBinders() {
return
}
// [UTLS SECTION END]
transcript := cipherSuite.hash.New()
helloBytes, err := hello.marshalWithoutBinders()
if err != nil {
@ -460,7 +471,6 @@ func (c *Conn) loadSession(hello *clientHelloMsg) (
if err := hello.updateBinders(pskBinders); err != nil {
return nil, nil, nil, err
}
return
}

View file

@ -319,7 +319,7 @@ func (m *clientHelloMsg) marshal() ([]byte, error) {
}
// marshalWithoutBinders returns the ClientHello through the
// FakePreSharedKeyExtension.identities field, according to RFC 8446, Section
// PreSharedKeyExtension.identities field, according to RFC 8446, Section
// 4.2.11.2. Note that m.pskBinders must be set to slices of the correct length.
func (m *clientHelloMsg) marshalWithoutBinders() ([]byte, error) {
bindersLen := 2 // uint16 length prefix

View file

@ -854,7 +854,7 @@ func TestCloneNonFuncFields(t *testing.T) {
f.Set(reflect.ValueOf("b"))
case "ClientAuth":
f.Set(reflect.ValueOf(VerifyClientCertIfGiven))
case "InsecureSkipVerify", "InsecureSkipTimeVerify", "SessionTicketsDisabled", "DynamicRecordSizingDisabled", "PreferServerCipherSuites":
case "InsecureSkipVerify", "InsecureSkipTimeVerify", "SessionTicketsDisabled", "DynamicRecordSizingDisabled", "PreferServerCipherSuites", "OmitEmptyPsk":
f.Set(reflect.ValueOf(true))
case "InsecureServerNameToVerify":
f.Set(reflect.ValueOf("c"))

View file

@ -85,7 +85,9 @@ func (c *CompressionMethodsJSONUnmarshaler) CompressionMethods() []uint8 {
}
type TLSExtensionsJSONUnmarshaler struct {
extensions []TLSExtensionJSON
AllowUnknownExt bool // if set, unknown extensions will be added as GenericExtension, without recovering ext payload
UseRealPSK bool // if set, PSK extension will be real PSK extension, otherwise it will be fake PSK extension
extensions []TLSExtensionJSON
}
func (e *TLSExtensionsJSONUnmarshaler) UnmarshalJSON(jsonStr []byte) error {
@ -107,14 +109,28 @@ func (e *TLSExtensionsJSONUnmarshaler) UnmarshalJSON(jsonStr []byte) error {
// get extension type from ID
var ext TLSExtension = ExtensionFromID(extID)
if ext == nil {
// fallback to generic extension
ext = genericExtension(extID, accepter.extNameOnly.Name)
if e.AllowUnknownExt {
// fallback to generic extension, without recovering ext payload
ext = genericExtension(extID, accepter.extNameOnly.Name)
} else {
return fmt.Errorf("extension %s (%d) is not JSON compatible", accepter.extNameOnly.Name, extID)
}
}
switch extID {
case extensionPreSharedKey:
// PSK extension, need to see if we do real or fake PSK
if e.UseRealPSK {
ext = &UtlsPreSharedKeyExtension{}
} else {
ext = &FakePreSharedKeyExtension{}
}
}
if extJsonCompatible, ok := ext.(TLSExtensionJSON); ok {
exts = append(exts, extJsonCompatible)
} else {
return fmt.Errorf("extension %d (%s) is not JSON compatible", extID, accepter.extNameOnly.Name)
return fmt.Errorf("extension %s (%d) is not JSON compatible", accepter.extNameOnly.Name, extID)
}
}
}

View file

@ -210,7 +210,7 @@ func (chs *ClientHelloSpec) ReadCompressionMethods(compressionMethods []byte) er
// a byte slice into []TLSExtension.
//
// If keepPSK is not set, the PSK extension will cause an error.
func (chs *ClientHelloSpec) ReadTLSExtensions(b []byte, allowBluntMimicry bool) error {
func (chs *ClientHelloSpec) ReadTLSExtensions(b []byte, allowBluntMimicry bool, realPSK bool) error {
extensions := cryptobyte.String(b)
for !extensions.Empty() {
var extension uint16
@ -225,6 +225,16 @@ func (chs *ClientHelloSpec) ReadTLSExtensions(b []byte, allowBluntMimicry bool)
ext := ExtensionFromID(extension)
extWriter, ok := ext.(TLSExtensionWriter)
if ext != nil && ok { // known extension and implements TLSExtensionWriter properly
switch extension {
case extensionPreSharedKey:
// PSK extension, need to see if we do real or fake PSK
if realPSK {
extWriter = &UtlsPreSharedKeyExtension{}
} else {
extWriter = &FakePreSharedKeyExtension{}
}
}
if extension == extensionSupportedVersions {
chs.TLSVersMin = 0
chs.TLSVersMax = 0
@ -247,13 +257,15 @@ func (chs *ClientHelloSpec) ReadTLSExtensions(b []byte, allowBluntMimicry bool)
func (chs *ClientHelloSpec) AlwaysAddPadding() {
alreadyHasPadding := false
for _, ext := range chs.Extensions {
for idx, ext := range chs.Extensions {
if _, ok := ext.(*UtlsPaddingExtension); ok {
alreadyHasPadding = true
break
}
if _, ok := ext.(*FakePreSharedKeyExtension); ok {
alreadyHasPadding = true // PSK must be last, so we don't need to add padding
if _, ok := ext.(PreSharedKeyExtension); ok {
alreadyHasPadding = true // PSK must be last, so we can't append padding after it
// instead we will insert padding before PSK
chs.Extensions = append(chs.Extensions[:idx], append([]TLSExtension{&UtlsPaddingExtension{GetPaddingLen: BoringPaddingStyle}}, chs.Extensions[idx:]...)...)
break
}
}
@ -452,14 +464,20 @@ func (chs *ClientHelloSpec) ImportTLSClientHelloFromJSON(jsonB []byte) error {
}
// FromRaw converts a ClientHello message in the form of raw bytes into a ClientHelloSpec.
func (chs *ClientHelloSpec) FromRaw(raw []byte, allowBluntMimicry ...bool) error {
//
// ctrlFlags: []bool{bluntMimicry, realPSK}
func (chs *ClientHelloSpec) FromRaw(raw []byte, ctrlFlags ...bool) error {
if chs == nil {
return errors.New("cannot unmarshal into nil ClientHelloSpec")
}
var bluntMimicry = false
if len(allowBluntMimicry) == 1 {
bluntMimicry = allowBluntMimicry[0]
var realPSK = false
if len(ctrlFlags) > 0 {
bluntMimicry = ctrlFlags[0]
}
if len(ctrlFlags) > 1 {
realPSK = ctrlFlags[1]
}
*chs = ClientHelloSpec{} // reset
@ -526,7 +544,7 @@ func (chs *ClientHelloSpec) FromRaw(raw []byte, allowBluntMimicry ...bool) error
return errors.New("unable to read extensions data")
}
if err := chs.ReadTLSExtensions(extensions, bluntMimicry); err != nil {
if err := chs.ReadTLSExtensions(extensions, bluntMimicry, realPSK); err != nil {
return err
}
@ -708,3 +726,61 @@ func EnableWeakCiphers() {
suiteECDHE | suiteTLS12 | suiteSHA384, cipherAES, utlsMacSHA384, nil},
}...)
}
func mapSlice[T any, U any](slice []T, transform func(T) U) []U {
newSlice := make([]U, 0, len(slice))
for _, t := range slice {
newSlice = append(newSlice, transform(t))
}
return newSlice
}
func panicOnNil(caller string, params ...any) {
for i, p := range params {
if p == nil {
panic(fmt.Sprintf("tls: %s failed: the [%d] parameter is nil", caller, i))
}
}
}
func anyTrue[T any](slice []T, predicate func(i int, t *T) bool) bool {
for i := 0; i < len(slice); i++ {
if predicate(i, &slice[i]) {
return true
}
}
return false
}
func allTrue[T any](slice []T, predicate func(i int, t *T) bool) bool {
for i := 0; i < len(slice); i++ {
if !predicate(i, &slice[i]) {
return false
}
}
return true
}
func uAssert(condition bool, msg string) {
if !condition {
panic(msg)
}
}
func sliceEq[T comparable](sliceA []T, sliceB []T) bool {
if len(sliceA) != len(sliceB) {
return false
}
for i := 0; i < len(sliceA); i++ {
if sliceA[i] != sliceB[i] {
return false
}
}
return true
}
type Initializable interface {
// IsInitialized returns a boolean indicating whether the extension has been initialized.
// If false is returned, utls will initialize the extension.
IsInitialized() bool
}

209
u_conn.go
View file

@ -14,23 +14,26 @@ import (
"errors"
"fmt"
"hash"
"io"
"net"
"strconv"
)
type ClientHelloBuildStatus int
const NotBuilt ClientHelloBuildStatus = 0
const BuildByUtls ClientHelloBuildStatus = 1
const BuildByGoTLS ClientHelloBuildStatus = 2
type UConn struct {
*Conn
Extensions []TLSExtension
ClientHelloID ClientHelloID
pskExtension []*FakePreSharedKeyExtension
Extensions []TLSExtension
ClientHelloID ClientHelloID
sessionController *sessionController
ClientHelloBuilt bool
HandshakeState PubClientHandshakeState
clientHelloBuildStatus ClientHelloBuildStatus
// sessionID may or may not depend on ticket; nil => random
GetSessionID func(ticket []byte) [32]byte
HandshakeState PubClientHandshakeState
greaseSeed [ssl_grease_last_index]uint16
@ -44,15 +47,17 @@ type UConn struct {
// UClient returns a new uTLS client, with behavior depending on clientHelloID.
// Config CAN be nil, but make sure to eventually specify ServerName.
func UClient(conn net.Conn, config *Config, clientHelloID ClientHelloID, pskExtension ...*FakePreSharedKeyExtension) *UConn {
func UClient(conn net.Conn, config *Config, clientHelloID ClientHelloID) *UConn {
if config == nil {
config = &Config{}
}
tlsConn := Conn{conn: conn, config: config, isClient: true}
handshakeState := PubClientHandshakeState{C: &tlsConn, Hello: &PubClientHelloMsg{}}
uconn := UConn{Conn: &tlsConn, ClientHelloID: clientHelloID, pskExtension: pskExtension, HandshakeState: handshakeState}
uconn := UConn{Conn: &tlsConn, ClientHelloID: clientHelloID, HandshakeState: handshakeState}
uconn.HandshakeState.uconn = &uconn
uconn.handshakeFn = uconn.clientHandshake
uconn.sessionController = newSessionController(&uconn)
uconn.utls.sessionController = uconn.sessionController
return &uconn
}
@ -73,9 +78,10 @@ func UClient(conn net.Conn, config *Config, clientHelloID ClientHelloID, pskExte
// default/mimicked ClientHello.
func (uconn *UConn) BuildHandshakeState() error {
if uconn.ClientHelloID == HelloGolang {
if uconn.ClientHelloBuilt {
if uconn.clientHelloBuildStatus == BuildByGoTLS {
return nil
}
uAssert(uconn.clientHelloBuildStatus == NotBuilt, "BuildHandshakeState failed: invalid call, client hello has already been built by utls")
// use default Golang ClientHello.
hello, keySharePrivate, err := uconn.makeClientHello()
@ -92,8 +98,10 @@ func (uconn *UConn) BuildHandshakeState() error {
return fmt.Errorf("uTLS: unknown keySharePrivate type: %T", keySharePrivate)
}
uconn.HandshakeState.C = uconn.Conn
uconn.clientHelloBuildStatus = BuildByGoTLS
} else {
if !uconn.ClientHelloBuilt {
uAssert(uconn.clientHelloBuildStatus == BuildByUtls || uconn.clientHelloBuildStatus == NotBuilt, "BuildHandshakeState failed: invalid call, client hello has already been built by go-tls")
if uconn.clientHelloBuildStatus == NotBuilt {
err := uconn.applyPresetByID(uconn.ClientHelloID)
if err != nil {
return err
@ -107,52 +115,106 @@ func (uconn *UConn) BuildHandshakeState() error {
if err != nil {
return err
}
err = uconn.uLoadSession()
if err != nil {
return err
}
err = uconn.MarshalClientHello()
if err != nil {
return err
}
uconn.uApplyPatch()
uconn.sessionController.finalCheck()
uconn.clientHelloBuildStatus = BuildByUtls
}
uconn.ClientHelloBuilt = true
return nil
}
func (uconn *UConn) uLoadSession() error {
if cfg := uconn.config; cfg.SessionTicketsDisabled || cfg.ClientSessionCache == nil {
return nil
}
switch uconn.sessionController.shouldLoadSession() {
case shouldReturn:
case shouldSetTicket:
uconn.sessionController.setSessionTicketToUConn()
case shouldSetPsk:
uconn.sessionController.setPskToUConn()
case shouldLoad:
hello := uconn.HandshakeState.Hello.getPrivatePtr()
uconn.sessionController.utlsAboutToLoadSession()
session, earlySecret, binderKey, err := uconn.loadSession(hello)
if session == nil || err != nil {
return err
}
if session.version == VersionTLS12 {
// We use the session ticket extension for tls 1.2 session resumption
uconn.sessionController.initSessionTicketExt(session, hello.sessionTicket)
uconn.sessionController.setSessionTicketToUConn()
} else {
uconn.sessionController.initPskExt(session, earlySecret, binderKey, hello.pskIdentities)
}
}
return nil
}
func (uconn *UConn) uApplyPatch() {
helloLen := len(uconn.HandshakeState.Hello.Raw)
if uconn.sessionController.shouldUpdateBinders() {
uconn.sessionController.updateBinders()
uconn.sessionController.setPskToUConn()
}
uAssert(helloLen == len(uconn.HandshakeState.Hello.Raw), "tls: uApplyPatch Failed: the patch should never change the length of the marshaled clientHello")
}
func (uconn *UConn) DidTls12Resume() bool {
return uconn.didResume
}
// SetSessionState sets the session ticket, which may be preshared or fake.
// If session is nil, the body of session ticket extension will be unset,
// but the extension itself still MAY be present for mimicking purposes.
// Session tickets to be reused - use same cache on following connections.
//
// Deprecated: This method is deprecated in favor of SetSessionTicketExtension,
// as it only handles session override of TLS 1.2
func (uconn *UConn) SetSessionState(session *ClientSessionState) error {
var sessionTicket []uint8
sessionTicketExt := &SessionTicketExtension{Initialized: true}
if session != nil {
sessionTicket = session.ticket
uconn.HandshakeState.Session = session.session
sessionTicketExt.Ticket = session.ticket
sessionTicketExt.Session = session.session
}
uconn.HandshakeState.Hello.TicketSupported = true
uconn.HandshakeState.Hello.SessionTicket = sessionTicket
return uconn.SetSessionTicketExtension(sessionTicketExt)
}
for _, ext := range uconn.Extensions {
st, ok := ext.(*SessionTicketExtension)
if !ok {
continue
}
st.Session = session
if session != nil {
if len(session.SessionTicket()) > 0 {
if uconn.GetSessionID != nil {
sid := uconn.GetSessionID(session.SessionTicket())
uconn.HandshakeState.Hello.SessionId = sid[:]
return nil
}
}
var sessionID [32]byte
_, err := io.ReadFull(uconn.config.rand(), sessionID[:])
if err != nil {
return err
}
uconn.HandshakeState.Hello.SessionId = sessionID[:]
}
// SetSessionTicket sets the session ticket extension.
// If extension is nil, this will be a no-op.
func (uconn *UConn) SetSessionTicketExtension(sessionTicketExt ISessionTicketExtension) error {
if uconn.config.SessionTicketsDisabled || uconn.config.ClientSessionCache == nil {
return fmt.Errorf("tls: SetSessionTicketExtension failed: session is disabled")
}
if sessionTicketExt == nil {
return nil
}
return nil
return uconn.sessionController.overrideSessionTicketExt(sessionTicketExt)
}
// SetPskExtension sets the psk extension for tls 1.3 resumption. This is a no-op if the psk is nil.
func (uconn *UConn) SetPskExtension(pskExt PreSharedKeyExtension) error {
if uconn.config.SessionTicketsDisabled || uconn.config.ClientSessionCache == nil {
return fmt.Errorf("tls: SetPskExtension failed: session is disabled")
}
if pskExt == nil {
return nil
}
uconn.HandshakeState.Hello.TicketSupported = true
return uconn.sessionController.overridePskExt(pskExt)
}
// If you want session tickets to be reused - use same cache on following connections
@ -397,7 +459,7 @@ func (c *UConn) clientHandshake(ctx context.Context) (err error) {
hello := c.HandshakeState.Hello.getPrivatePtr()
defer func() { c.HandshakeState.Hello = hello.getPublicPtr() }()
sessionIsAlreadySet := c.HandshakeState.Session != nil
sessionIsLocked := c.utls.sessionController.isSessionLocked()
// after this point exactly 1 out of 2 HandshakeState pointers is non-nil,
// useTLS13 variable tells which pointer
@ -434,9 +496,24 @@ func (c *UConn) clientHandshake(ctx context.Context) (err error) {
if c.handshakes > 0 {
hello.secureRenegotiation = c.clientFinished[:]
}
// [uTLS section ends]
session, earlySecret, binderKey, err := c.loadSession(hello)
var (
session *SessionState
earlySecret []byte
binderKey []byte
)
if !sessionIsLocked {
// [uTLS section ends]
session, earlySecret, binderKey, err = c.loadSession(hello)
// [uTLS section start]
} else {
session = c.HandshakeState.Session
earlySecret = c.HandshakeState.State13.EarlySecret
binderKey = c.HandshakeState.State13.BinderKey
}
// [uTLS section ends]
if err != nil {
return err
}
@ -456,21 +533,20 @@ func (c *UConn) clientHandshake(ctx context.Context) (err error) {
}()
}
cacheKey := c.clientSessionCacheKey()
if c.config.ClientSessionCache != nil {
cs, ok := c.config.ClientSessionCache.Get(cacheKey)
if !sessionIsAlreadySet && ok { // uTLS: do not overwrite already set session
err = c.SetSessionState(cs)
if err != nil {
return
}
}
}
if _, err := c.writeHandshakeRecord(hello, nil); err != nil {
return err
}
if hello.earlyData {
suite := cipherSuiteTLS13ByID(session.cipherSuite)
transcript := suite.hash.New()
if err := transcriptMsg(hello, transcript); err != nil {
return err
}
earlyTrafficSecret := suite.deriveSecret(earlySecret, clientEarlyTrafficLabel, transcript)
c.quicSetWriteSecret(QUICEncryptionLevelEarly, suite.id, earlyTrafficSecret)
}
msg, err := c.readHandshake(nil)
if err != nil {
return err
@ -491,9 +567,11 @@ func (c *UConn) clientHandshake(ctx context.Context) (err error) {
hs13 := c.HandshakeState.toPrivate13()
hs13.serverHello = serverHello
hs13.hello = hello
if !sessionIsAlreadySet {
hs13.keySharesParams = NewKeySharesParameters()
if !sessionIsLocked {
hs13.earlySecret = earlySecret
hs13.binderKey = binderKey
hs13.session = session
}
hs13.ctx = ctx
// In TLS 1.3, session tickets are delivered after the handshake.
@ -508,6 +586,7 @@ func (c *UConn) clientHandshake(ctx context.Context) (err error) {
hs12.serverHello = serverHello
hs12.hello = hello
hs12.ctx = ctx
hs12.session = session
err = hs12.handshake()
if handshakeState := hs12.toPublic12(); handshakeState != nil {
c.HandshakeState = *handshakeState
@ -515,17 +594,6 @@ func (c *UConn) clientHandshake(ctx context.Context) (err error) {
if err != nil {
return err
}
// If we had a successful handshake and hs.session is different from
// the one already cached - cache a new one.
if cacheKey != "" && hs12.session != nil && session != hs12.session {
hs12cs := &ClientSessionState{
ticket: hs12.ticket,
session: hs12.session,
}
c.config.ClientSessionCache.Put(cacheKey, hs12cs)
}
return nil
}
@ -556,7 +624,7 @@ func (uconn *UConn) MarshalClientHello() error {
if paddingExt == nil {
paddingExt = pe
} else {
return errors.New("multiple padding extensions!")
return errors.New("multiple padding extensions")
}
}
}
@ -598,7 +666,9 @@ func (uconn *UConn) MarshalClientHello() error {
if len(uconn.Extensions) > 0 {
binary.Write(bufferedWriter, binary.BigEndian, uint16(extensionsLen))
for _, ext := range uconn.Extensions {
bufferedWriter.ReadFrom(ext)
if _, err := bufferedWriter.ReadFrom(ext); err != nil {
return err
}
}
}
@ -784,7 +854,10 @@ func (c *Conn) utlsConnectionStateLocked(state *ConnectionState) {
}
type utlsConnExtraFields struct {
// Application Settings (ALPS)
hasApplicationSettings bool
peerApplicationSettings []byte
localApplicationSettings []byte
sessionController *sessionController
}

View file

@ -12,6 +12,7 @@ import (
"net"
"os"
"os/exec"
"runtime/debug"
"strings"
"testing"
"time"
@ -522,6 +523,9 @@ func (test *clientTest) runUTLS(t *testing.T, write bool, hello helloStrategy, o
doneChan := make(chan bool)
go func() {
defer func() {
if err := recover(); err != nil {
fmt.Printf("panic occurred: %v\n %s\n", err, string(debug.Stack()))
}
// Give time to the send buffer to drain, to avoid the kernel
// sending a RST and cutting off the flow. See Issue 18701.
time.Sleep(10 * time.Millisecond)

View file

@ -18,6 +18,8 @@ type Fingerprinter struct {
// have any padding, but you suspect that other changes you make to the final hello
// (including things like different SNI lengths) would cause padding to be necessary
AlwaysAddPadding bool
RealPSKResumption bool // if set, PSK extension (if any) will be real PSK extension, otherwise it will be fake PSK extension
}
// FingerprintClientHello returns a ClientHelloSpec which is based on the
@ -43,7 +45,8 @@ func (f *Fingerprinter) FingerprintClientHello(data []byte) (clientHelloSpec *Cl
// as a more precise name for the function
func (f *Fingerprinter) RawClientHello(raw []byte) (clientHelloSpec *ClientHelloSpec, err error) {
clientHelloSpec = &ClientHelloSpec{}
err = clientHelloSpec.FromRaw(raw, f.AllowBluntMimicry)
err = clientHelloSpec.FromRaw(raw, f.AllowBluntMimicry, f.RealPSKResumption)
if err != nil {
return nil, err
}

View file

@ -17,23 +17,12 @@ import (
)
var ErrUnknownClientHelloID = errors.New("tls: unknown ClientHelloID")
var ErrNotPSKClientHelloID = errors.New("tls: ClientHello does not contain pre_shared_key extension")
var ErrPSKExtensionExpected = errors.New("tls: pre_shared_key extension expected when fetching preset ClientHelloSpec")
// UTLSIdToSpec converts a ClientHelloID to a corresponding ClientHelloSpec.
//
// Exported internal function utlsIdToSpec per request.
func UTLSIdToSpec(id ClientHelloID, pskExtension ...*FakePreSharedKeyExtension) (ClientHelloSpec, error) {
if len(pskExtension) > 1 {
return ClientHelloSpec{}, errors.New("tls: at most one FakePreSharedKeyExtensions is allowed")
}
chs, err := utlsIdToSpec(id)
if err != nil && errors.Is(err, ErrUnknownClientHelloID) {
chs, err = utlsIdToSpecWithPSK(id, pskExtension...)
}
return chs, err
func UTLSIdToSpec(id ClientHelloID) (ClientHelloSpec, error) {
return utlsIdToSpec(id)
}
func utlsIdToSpec(id ClientHelloID) (ClientHelloSpec, error) {
@ -1995,24 +1984,6 @@ func utlsIdToSpec(id ClientHelloID) (ClientHelloSpec, error) {
},
},
}, nil
default:
if id.Client == helloRandomized || id.Client == helloRandomizedALPN || id.Client == helloRandomizedNoALPN {
// Use empty values as they can be filled later by UConn.ApplyPreset or manually.
return generateRandomizedSpec(&id, "", nil, nil)
}
return ClientHelloSpec{}, fmt.Errorf("%w: %s", ErrUnknownClientHelloID, id.Str())
}
}
func utlsIdToSpecWithPSK(id ClientHelloID, pskExtension ...*FakePreSharedKeyExtension) (ClientHelloSpec, error) {
switch id {
case HelloChrome_100_PSK, HelloChrome_112_PSK_Shuf, HelloChrome_114_Padding_PSK_Shuf, HelloChrome_115_PQ_PSK:
if len(pskExtension) == 0 || pskExtension[0] == nil {
return ClientHelloSpec{}, fmt.Errorf("%w: %s", ErrPSKExtensionExpected, id.Str())
}
}
switch id {
case HelloChrome_100_PSK:
return ClientHelloSpec{
CipherSuites: []uint16{
@ -2081,7 +2052,7 @@ func utlsIdToSpecWithPSK(id ClientHelloID, pskExtension ...*FakePreSharedKeyExte
}},
&ApplicationSettingsExtension{SupportedProtocols: []string{"h2"}},
&UtlsGREASEExtension{},
pskExtension[0],
&UtlsPreSharedKeyExtension{},
},
}, nil
case HelloChrome_112_PSK_Shuf:
@ -2152,7 +2123,7 @@ func utlsIdToSpecWithPSK(id ClientHelloID, pskExtension ...*FakePreSharedKeyExte
}},
&ApplicationSettingsExtension{SupportedProtocols: []string{"h2"}},
&UtlsGREASEExtension{},
pskExtension[0],
&UtlsPreSharedKeyExtension{},
}),
}, nil
case HelloChrome_114_Padding_PSK_Shuf:
@ -2224,7 +2195,7 @@ func utlsIdToSpecWithPSK(id ClientHelloID, pskExtension ...*FakePreSharedKeyExte
&ApplicationSettingsExtension{SupportedProtocols: []string{"h2"}},
&UtlsGREASEExtension{},
&UtlsPaddingExtension{GetPaddingLen: BoringPaddingStyle},
pskExtension[0],
&UtlsPreSharedKeyExtension{},
}),
}, nil
// Chrome w/ Post-Quantum Key Agreement
@ -2298,12 +2269,17 @@ func utlsIdToSpecWithPSK(id ClientHelloID, pskExtension ...*FakePreSharedKeyExte
}},
&ApplicationSettingsExtension{SupportedProtocols: []string{"h2"}},
&UtlsGREASEExtension{},
pskExtension[0],
&UtlsPreSharedKeyExtension{},
}),
}, nil
}
default:
if id.Client == helloRandomized || id.Client == helloRandomizedALPN || id.Client == helloRandomizedNoALPN {
// Use empty values as they can be filled later by UConn.ApplyPreset or manually.
return generateRandomizedSpec(&id, "", nil)
}
return ClientHelloSpec{}, fmt.Errorf("%w: %s", ErrUnknownClientHelloID, id.Str())
return ClientHelloSpec{}, fmt.Errorf("%w: %s", ErrUnknownClientHelloID, id.Str())
}
}
// ShuffleChromeTLSExtensions shuffles the extensions in the ClientHelloSpec to avoid ossification.
@ -2315,7 +2291,7 @@ func ShuffleChromeTLSExtensions(exts []TLSExtension) []TLSExtension {
// and returns true on success. For these extensions are considered positionally invariant.
var skipShuf = func(idx int, exts []TLSExtension) bool {
switch exts[idx].(type) {
case *UtlsGREASEExtension, *UtlsPaddingExtension, *FakePreSharedKeyExtension:
case *UtlsGREASEExtension, *UtlsPaddingExtension, PreSharedKeyExtension:
return true
default:
return false
@ -2345,9 +2321,8 @@ func (uconn *UConn) applyPresetByID(id ClientHelloID) (err error) {
}
case helloCustom:
return nil
default:
spec, err = UTLSIdToSpec(id, uconn.pskExtension...)
spec, err = UTLSIdToSpec(id)
if err != nil {
return err
}
@ -2379,7 +2354,6 @@ func (uconn *UConn) ApplyPreset(p *ClientHelloSpec) error {
}
uconn.HandshakeState.State13.KeySharesParams = NewKeySharesParameters()
hello := uconn.HandshakeState.Hello
session := uconn.HandshakeState.Session
switch len(hello.Random) {
case 0:
@ -2420,7 +2394,12 @@ func (uconn *UConn) ApplyPreset(p *ClientHelloSpec) error {
hello.CipherSuites[i] = GetBoringGREASEValue(uconn.greaseSeed, ssl_grease_cipher)
}
}
uconn.GetSessionID = p.GetSessionID
var sessionID [32]byte
_, err = io.ReadFull(uconn.config.rand(), sessionID[:])
if err != nil {
return err
}
uconn.HandshakeState.Hello.SessionId = sessionID[:]
uconn.Extensions = make([]TLSExtension, len(p.Extensions))
copy(uconn.Extensions, p.Extensions)
@ -2445,20 +2424,6 @@ func (uconn *UConn) ApplyPreset(p *ClientHelloSpec) error {
return errors.New("at most 2 grease extensions are supported")
}
grease_extensions_seen += 1
case *SessionTicketExtension:
var cs *ClientSessionState
if session == nil && uconn.config.ClientSessionCache != nil {
cacheKey := uconn.clientSessionCacheKey()
cs, _ = uconn.config.ClientSessionCache.Get(cacheKey)
if cs != nil {
session = cs.session
}
// TODO: use uconn.loadSession(hello.getPrivateObj()) to support TLS 1.3 PSK-style resumption
}
err := uconn.SetSessionState(cs)
if err != nil {
return err
}
case *SupportedCurvesExtension:
for i := range ext.Curves {
if isGREASEUint16(uint16(ext.Curves[i])) {
@ -2525,22 +2490,21 @@ func (uconn *UConn) ApplyPreset(p *ClientHelloSpec) error {
// but NextProtos is also used by ALPN and our spec nmay not actually have a NPN extension
hello.NextProtoNeg = haveNPN
err = uconn.sessionController.syncSessionExts()
if err != nil {
return err
}
return nil
}
func (uconn *UConn) generateRandomizedSpec() (ClientHelloSpec, error) {
css := &ClientSessionState{
session: uconn.HandshakeState.Session,
ticket: uconn.HandshakeState.Hello.SessionTicket,
}
return generateRandomizedSpec(&uconn.ClientHelloID, uconn.serverName, css, uconn.config.NextProtos)
return generateRandomizedSpec(&uconn.ClientHelloID, uconn.serverName, uconn.config.NextProtos)
}
func generateRandomizedSpec(
id *ClientHelloID,
serverName string,
session *ClientSessionState,
nextProtos []string,
) (ClientHelloSpec, error) {
p := ClientHelloSpec{}
@ -2606,7 +2570,7 @@ func generateRandomizedSpec(
p.CipherSuites = removeRandomCiphers(r, shuffledSuites, id.Weights.CipherSuites_Remove_RandomCiphers)
sni := SNIExtension{serverName}
sessionTicket := SessionTicketExtension{Session: session}
sessionTicket := SessionTicketExtension{}
sigAndHashAlgos := []SignatureScheme{
ECDSAWithP256AndSHA256,

469
u_pre_shared_key.go Normal file
View file

@ -0,0 +1,469 @@
package tls
import (
"encoding/json"
"errors"
"io"
"golang.org/x/crypto/cryptobyte"
)
var ErrEmptyPsk = errors.New("tls: empty psk detected; remove the psk extension for this connection or set OmitEmptyPsk to true to conceal it in utls")
type PreSharedKeyCommon struct {
Identities []PskIdentity
Binders [][]byte
BinderKey []byte // this will be used to compute the binder when hello message is ready
EarlySecret []byte
Session *SessionState
}
// The lifecycle of a PreSharedKeyExtension:
//
// Creation Phase:
// - The extension is created.
//
// Write Phase:
//
// - [writeToUConn() called]:
//
// > - During this phase, it is important to note that implementations should not write any session data to the UConn (Underlying Connection) as the session is not yet loaded. The session context is not active at this point.
//
// Initialization Phase:
//
// - [IsInitialized() called]:
//
// If IsInitialized() returns true
//
// > - GetPreSharedKeyCommon() will be called subsequently and the PSK states in handshake/clientHello will be fully initialized.
//
// If IsInitialized() returns false:
//
// > - [conn.loadSession() called]:
//
// >> - Once the session is available:
//
// >>> - [InitializeByUtls() called]:
//
// >>>> - The InitializeByUtls() method is invoked to initialize the extension based on the loaded session data.
//
// >>>> - This step prepares the extension for further processing.
//
// Marshal Phase:
//
// - [Len() called], [Read() called]:
//
// > - Implementations should marshal the extension into bytes, using placeholder binders to maintain the correct length.
//
// Binders Preparation Phase:
//
// - [PatchBuiltHello(hello) called]:
//
// > - The client hello is already marshaled in the "hello.Raw" format.
//
// > - Implementations are expected to update the binders within the marshaled client hello.
//
// - [GetPreSharedKeyCommon() called]:
//
// > - Implementations should gather and provide the final pre-shared key (PSK) related data.
//
// > - This data will be incorporated into both the clientHello and HandshakeState, ensuring that the PSK-related information is properly set and ready for the handshake process.
type PreSharedKeyExtension interface {
// TLSExtension must be implemented by all PreSharedKeyExtension implementations.
TLSExtension
// If false is returned, utls will invoke `InitializeByUtls()` for the necessary initialization.
Initializable
SetOmitEmptyPsk(val bool)
// InitializeByUtls is invoked when IsInitialized() returns false.
// It initializes the extension using a real and valid TLS 1.3 session.
InitializeByUtls(session *SessionState, earlySecret []byte, binderKey []byte, identities []PskIdentity)
// GetPreSharedKeyCommon retrieves the final PreSharedKey-related states as defined in PreSharedKeyCommon.
GetPreSharedKeyCommon() PreSharedKeyCommon
// PatchBuiltHello is called once the hello message is fully applied and marshaled.
// Its purpose is to update the binders of PSK (Pre-Shared Key) identities.
PatchBuiltHello(hello *PubClientHelloMsg) error
mustEmbedUnimplementedPreSharedKeyExtension() // this works like a type guard
}
type UnimplementedPreSharedKeyExtension struct{}
func (UnimplementedPreSharedKeyExtension) mustEmbedUnimplementedPreSharedKeyExtension() {}
func (*UnimplementedPreSharedKeyExtension) IsInitialized() bool {
panic("tls: IsInitialized is not implemented for the PreSharedKeyExtension")
}
func (*UnimplementedPreSharedKeyExtension) InitializeByUtls(session *SessionState, earlySecret []byte, binderKey []byte, identities []PskIdentity) {
panic("tls: Initialize is not implemented for the PreSharedKeyExtension")
}
func (*UnimplementedPreSharedKeyExtension) writeToUConn(*UConn) error {
panic("tls: writeToUConn is not implemented for the PreSharedKeyExtension")
}
func (*UnimplementedPreSharedKeyExtension) Len() int {
panic("tls: Len is not implemented for the PreSharedKeyExtension")
}
func (*UnimplementedPreSharedKeyExtension) Read([]byte) (int, error) {
panic("tls: Read is not implemented for the PreSharedKeyExtension")
}
func (*UnimplementedPreSharedKeyExtension) GetPreSharedKeyCommon() PreSharedKeyCommon {
panic("tls: GetPreSharedKeyCommon is not implemented for the PreSharedKeyExtension")
}
func (*UnimplementedPreSharedKeyExtension) PatchBuiltHello(hello *PubClientHelloMsg) error {
panic("tls: ReadWithRawHello is not implemented for the PreSharedKeyExtension")
}
func (*UnimplementedPreSharedKeyExtension) SetOmitEmptyPsk(val bool) {
panic("tls: SetOmitEmptyPsk is not implemented for the PreSharedKeyExtension")
}
// UtlsPreSharedKeyExtension is an extension used to set the PSK extension in the
// ClientHello.
type UtlsPreSharedKeyExtension struct {
UnimplementedPreSharedKeyExtension
PreSharedKeyCommon
cipherSuite *cipherSuiteTLS13
cachedLength *int
OmitEmptyPsk bool
}
func (e *UtlsPreSharedKeyExtension) IsInitialized() bool {
return e.Session != nil
}
func (e *UtlsPreSharedKeyExtension) InitializeByUtls(session *SessionState, earlySecret []byte, binderKey []byte, identities []PskIdentity) {
e.Session = session
e.EarlySecret = earlySecret
e.BinderKey = binderKey
e.cipherSuite = cipherSuiteTLS13ByID(e.Session.cipherSuite)
e.Identities = identities
e.Binders = make([][]byte, 0, len(e.Identities))
for i := 0; i < len(e.Identities); i++ {
e.Binders = append(e.Binders, make([]byte, e.cipherSuite.hash.Size()))
}
}
func (e *UtlsPreSharedKeyExtension) writeToUConn(uc *UConn) error {
uc.HandshakeState.Hello.TicketSupported = true // This doesn't matter though, as utls doesn't care about this field. We write this for consistency.
return nil
}
func (e *UtlsPreSharedKeyExtension) GetPreSharedKeyCommon() PreSharedKeyCommon {
return e.PreSharedKeyCommon
}
func pskExtLen(identities []PskIdentity, binders [][]byte) int {
if len(identities) == 0 || len(binders) == 0 {
// If there isn't psk identities, we don't write this ticket to the client hello, and therefore the length should be 0.
return 0
}
length := 4 // extension type + extension length
length += 2 // identities length
for _, identity := range identities {
length += 2 + len(identity.Label) + 4 // identity length + identity + obfuscated ticket age
}
length += 2 // binders length
for _, binder := range binders {
length += len(binder) + 1
}
return length
}
func (e *UtlsPreSharedKeyExtension) Len() int {
if e.Session == nil {
return 0
}
if e.cachedLength != nil {
return *e.cachedLength
}
length := pskExtLen(e.Identities, e.Binders)
e.cachedLength = &length
return length
}
func readPskIntoBytes(b []byte, identities []PskIdentity, binders [][]byte) (int, error) {
extLen := pskExtLen(identities, binders)
if extLen == 0 {
return 0, io.EOF
}
if len(b) < extLen {
return 0, io.ErrShortBuffer
}
b[0] = byte(extensionPreSharedKey >> 8)
b[1] = byte(extensionPreSharedKey)
b[2] = byte((extLen - 4) >> 8)
b[3] = byte(extLen - 4)
// identities length
identitiesLength := 0
for _, identity := range identities {
identitiesLength += 2 + len(identity.Label) + 4 // identity length + identity + obfuscated ticket age
}
b[4] = byte(identitiesLength >> 8)
b[5] = byte(identitiesLength)
// identities
offset := 6
for _, identity := range identities {
b[offset] = byte(len(identity.Label) >> 8)
b[offset+1] = byte(len(identity.Label))
offset += 2
copy(b[offset:], identity.Label)
offset += len(identity.Label)
b[offset] = byte(identity.ObfuscatedTicketAge >> 24)
b[offset+1] = byte(identity.ObfuscatedTicketAge >> 16)
b[offset+2] = byte(identity.ObfuscatedTicketAge >> 8)
b[offset+3] = byte(identity.ObfuscatedTicketAge)
offset += 4
}
// binders length
bindersLength := 0
for _, binder := range binders {
// check if binder size is valid
bindersLength += len(binder) + 1 // binder length + binder
}
b[offset] = byte(bindersLength >> 8)
b[offset+1] = byte(bindersLength)
offset += 2
// binders
for _, binder := range binders {
b[offset] = byte(len(binder))
offset++
copy(b[offset:], binder)
offset += len(binder)
}
return extLen, io.EOF
}
func (e *UtlsPreSharedKeyExtension) SetOmitEmptyPsk(val bool) {
e.OmitEmptyPsk = val
}
func (e *UtlsPreSharedKeyExtension) Read(b []byte) (int, error) {
if !e.OmitEmptyPsk && e.Len() == 0 {
return 0, ErrEmptyPsk
}
return readPskIntoBytes(b, e.Identities, e.Binders)
}
func (e *UtlsPreSharedKeyExtension) PatchBuiltHello(hello *PubClientHelloMsg) error {
if e.Len() == 0 {
return nil
}
private := hello.getCachedPrivatePtr()
if private == nil {
private = hello.getPrivatePtr()
}
private.raw = hello.Raw
private.pskBinders = e.Binders // set the placeholder to the private Hello
//--- mirror loadSession() begin ---//
transcript := e.cipherSuite.hash.New()
helloBytes, err := private.marshalWithoutBinders() // no marshal() will be actually called, as we have set the field `raw`
if err != nil {
return err
}
transcript.Write(helloBytes)
pskBinders := [][]byte{e.cipherSuite.finishedHash(e.BinderKey, transcript)}
if err := private.updateBinders(pskBinders); err != nil {
return err
}
//--- mirror loadSession() end ---//
e.Binders = pskBinders
// no need to care about other PSK related fields, they will be handled separately
return io.EOF
}
func (e *UtlsPreSharedKeyExtension) Write(b []byte) (int, error) {
return len(b), nil // ignore the data
}
func (e *UtlsPreSharedKeyExtension) UnmarshalJSON(_ []byte) error {
return nil // ignore the data
}
// FakePreSharedKeyExtension is an extension used to set the PSK extension in the
// ClientHello.
//
// It does not compute binders based on ClientHello, but uses the binders specified instead.
// We still need to learn more of the safety
// of hardcoding both Identities and Binders without recalculating the latter.
type FakePreSharedKeyExtension struct {
UnimplementedPreSharedKeyExtension
Identities []PskIdentity `json:"identities"`
Binders [][]byte `json:"binders"`
OmitEmptyPsk bool
}
func (e *FakePreSharedKeyExtension) IsInitialized() bool {
return e.Identities != nil && e.Binders != nil
}
func (e *FakePreSharedKeyExtension) InitializeByUtls(session *SessionState, earlySecret []byte, binderKey []byte, identities []PskIdentity) {
panic("InitializeByUtls failed: don't let utls initialize FakePreSharedKeyExtension; provide your own identities and binders or use UtlsPreSharedKeyExtension")
}
func (e *FakePreSharedKeyExtension) writeToUConn(uc *UConn) error {
if uc.config.ClientSessionCache == nil {
return nil // don't write the extension if there is no session cache
}
if session, ok := uc.config.ClientSessionCache.Get(uc.clientSessionCacheKey()); !ok || session == nil {
return nil // don't write the extension if there is no session cache available for this session
}
uc.HandshakeState.Hello.PskIdentities = e.Identities
uc.HandshakeState.Hello.PskBinders = e.Binders
return nil
}
func (e *FakePreSharedKeyExtension) Len() int {
return pskExtLen(e.Identities, e.Binders)
}
func (e *FakePreSharedKeyExtension) SetOmitEmptyPsk(val bool) {
e.OmitEmptyPsk = val
}
func (e *FakePreSharedKeyExtension) Read(b []byte) (int, error) {
if !e.OmitEmptyPsk && e.Len() == 0 {
return 0, ErrEmptyPsk
}
for _, b := range e.Binders {
if !(anyTrue(validHashLen, func(_ int, valid *int) bool {
return len(b) == *valid
})) {
return 0, errors.New("tls: FakePreSharedKeyExtension.Read failed: invalid binder size")
}
}
return readPskIntoBytes(b, e.Identities, e.Binders)
}
func (e *FakePreSharedKeyExtension) GetPreSharedKeyCommon() PreSharedKeyCommon {
return PreSharedKeyCommon{
Identities: e.Identities,
Binders: e.Binders,
}
}
var validHashLen = mapSlice(cipherSuitesTLS13, func(c *cipherSuiteTLS13) int {
return c.hash.Size()
})
func (*FakePreSharedKeyExtension) PatchBuiltHello(*PubClientHelloMsg) error {
return nil // no need to patch the hello since we don't need to update binders
}
func (e *FakePreSharedKeyExtension) Write(b []byte) (n int, err error) {
fullLen := len(b)
s := cryptobyte.String(b)
var identitiesLength uint16
if !s.ReadUint16(&identitiesLength) {
return 0, errors.New("tls: invalid PSK extension")
}
// identities
for identitiesLength > 0 {
var identityLength uint16
if !s.ReadUint16(&identityLength) {
return 0, errors.New("tls: invalid PSK extension")
}
identitiesLength -= 2
if identityLength > identitiesLength {
return 0, errors.New("tls: invalid PSK extension")
}
var identity []byte
if !s.ReadBytes(&identity, int(identityLength)) {
return 0, errors.New("tls: invalid PSK extension")
}
identitiesLength -= identityLength // identity
var obfuscatedTicketAge uint32
if !s.ReadUint32(&obfuscatedTicketAge) {
return 0, errors.New("tls: invalid PSK extension")
}
e.Identities = append(e.Identities, PskIdentity{
Label: identity,
ObfuscatedTicketAge: obfuscatedTicketAge,
})
identitiesLength -= 4 // obfuscated ticket age
}
var bindersLength uint16
if !s.ReadUint16(&bindersLength) {
return 0, errors.New("tls: invalid PSK extension")
}
// binders
for bindersLength > 0 {
var binderLength uint8
if !s.ReadUint8(&binderLength) {
return 0, errors.New("tls: invalid PSK extension")
}
bindersLength -= 1
if uint16(binderLength) > bindersLength {
return 0, errors.New("tls: invalid PSK extension")
}
var binder []byte
if !s.ReadBytes(&binder, int(binderLength)) {
return 0, errors.New("tls: invalid PSK extension")
}
e.Binders = append(e.Binders, binder)
bindersLength -= uint16(binderLength)
}
return fullLen, nil
}
func (e *FakePreSharedKeyExtension) UnmarshalJSON(data []byte) error {
var pskAccepter struct {
PskIdentities []PskIdentity `json:"identities"`
PskBinders [][]byte `json:"binders"`
}
if err := json.Unmarshal(data, &pskAccepter); err != nil {
return err
}
e.Identities = pskAccepter.PskIdentities
e.Binders = pskAccepter.PskBinders
return nil
}
// type guard
var (
_ PreSharedKeyExtension = (*UtlsPreSharedKeyExtension)(nil)
_ TLSExtensionJSON = (*UtlsPreSharedKeyExtension)(nil)
_ PreSharedKeyExtension = (*FakePreSharedKeyExtension)(nil)
_ TLSExtensionJSON = (*FakePreSharedKeyExtension)(nil)
_ TLSExtensionWriter = (*FakePreSharedKeyExtension)(nil)
)
// type ExternalPreSharedKeyExtension struct{} // TODO: wait for whoever cares about external PSK to implement it

View file

@ -45,7 +45,7 @@ type TLS13OnlyState struct {
EarlySecret []byte
BinderKey []byte
CertReq *CertificateRequestMsgTLS13
UsingPSK bool
UsingPSK bool // don't set this field when building client hello
SentDummyCCS bool
Transcript hash.Hash
TrafficSecret []byte // client_application_traffic_secret_0
@ -251,7 +251,7 @@ type PubServerHelloMsg struct {
OcspStapling bool
Scts [][]byte
ExtendedMasterSecret bool
TicketSupported bool
TicketSupported bool // used by go tls to determine whether to add the session ticket ext
SecureRenegotiation []byte
SecureRenegotiationSupported bool
AlpnProtocol string
@ -357,13 +357,15 @@ type PubClientHelloMsg struct {
PskIdentities []PskIdentity
PskBinders [][]byte
QuicTransportParameters []byte
cachedPrivateHello *clientHelloMsg // todo: further optimize to reduce clientHelloMsg construction
}
func (chm *PubClientHelloMsg) getPrivatePtr() *clientHelloMsg {
if chm == nil {
return nil
} else {
return &clientHelloMsg{
private := &clientHelloMsg{
raw: chm.Raw,
vers: chm.Vers,
random: chm.Random,
@ -395,6 +397,16 @@ func (chm *PubClientHelloMsg) getPrivatePtr() *clientHelloMsg {
nextProtoNeg: chm.NextProtoNeg,
}
chm.cachedPrivateHello = private
return private
}
}
func (chm *PubClientHelloMsg) getCachedPrivatePtr() *clientHelloMsg {
if chm == nil {
return nil
} else {
return chm.cachedPrivateHello
}
}
@ -432,6 +444,7 @@ func (chm *clientHelloMsg) getPublicPtr() *PubClientHelloMsg {
PskIdentities: pskIdentities(chm.pskIdentities).ToPublic(),
PskBinders: chm.pskBinders,
QuicTransportParameters: chm.quicTransportParameters,
cachedPrivateHello: chm,
}
}
}

346
u_session_controller.go Normal file
View file

@ -0,0 +1,346 @@
package tls
import (
"errors"
"fmt"
)
// Tracking the state of calling conn.loadSession
type LoadSessionTrackerState int
const NeverCalled LoadSessionTrackerState = 0
const UtlsAboutToCall LoadSessionTrackerState = 1
const CalledByULoadSession LoadSessionTrackerState = 2
const CalledByGoTLS LoadSessionTrackerState = 3
// The state of the session controller
type sessionControllerState int
const NoSession sessionControllerState = 0
const SessionTicketExtInitialized sessionControllerState = 1
const SessionTicketExtAllSet sessionControllerState = 2
const PskExtInitialized sessionControllerState = 3
const PskExtAllSet sessionControllerState = 4
// sessionController is responsible for managing and controlling all session related states. It manages the lifecycle of the session ticket extension and the psk extension, including initialization, removal if the client hello spec doesn't contain any of them, and setting the prepared state to the client hello.
//
// Users should never directly modify the underlying state. Violations will result in undefined behaviors.
//
// Users should never construct sessionController by themselves, use the function `newSessionController` instead.
type sessionController struct {
// sessionTicketExt logically owns the session ticket extension
sessionTicketExt ISessionTicketExtension
// pskExtension logically owns the psk extension
pskExtension PreSharedKeyExtension
// uconnRef is a reference to the uconn
uconnRef *UConn
// state represents the internal state of the sessionController. Users are advised to modify the state only through designated methods and avoid direct manipulation, as doing so may result in undefined behavior.
state sessionControllerState
// loadSessionTracker keeps track of how the conn.loadSession method is being utilized.
loadSessionTracker LoadSessionTrackerState
// callingLoadSession is a boolean flag that indicates whether the `conn.loadSession` function is currently being invoked.
callingLoadSession bool
// locked is a boolean flag that becomes true once all states are appropriately set. Once `locked` is true, further modifications are disallowed, except for the binders.
locked bool
}
// newSessionController constructs a new SessionController
func newSessionController(uconn *UConn) *sessionController {
return &sessionController{
uconnRef: uconn,
sessionTicketExt: &SessionTicketExtension{},
pskExtension: &UtlsPreSharedKeyExtension{},
state: NoSession,
locked: false,
callingLoadSession: false,
loadSessionTracker: NeverCalled,
}
}
func (s *sessionController) isSessionLocked() bool {
return s.locked
}
type shouldLoadSessionResult int
const shouldReturn shouldLoadSessionResult = 0
const shouldSetTicket shouldLoadSessionResult = 1
const shouldSetPsk shouldLoadSessionResult = 2
const shouldLoad shouldLoadSessionResult = 3
// shouldLoadSession determines the appropriate action to take when it is time to load the session for the clientHello.
// There are several possible scenarios:
// - If a session ticket is already initialized, typically via the `initSessionTicketExt()` function, the ticket should be set in the client hello.
// - If a pre-shared key (PSK) is already initialized, typically via the `overridePskExt()` function, the PSK should be set in the client hello.
// - If both the `sessionTicketExt` and `pskExtension` are nil, which might occur if the client hello spec does not include them, we should skip the loadSession().
// - In all other cases, the function proceeds to load the session.
func (s *sessionController) shouldLoadSession() shouldLoadSessionResult {
if s.sessionTicketExt == nil && s.pskExtension == nil || s.uconnRef.clientHelloBuildStatus != NotBuilt {
// No need to load session since we don't have the related extensions.
return shouldReturn
}
if s.state == SessionTicketExtInitialized {
return shouldSetTicket
}
if s.state == PskExtInitialized {
return shouldSetPsk
}
return shouldLoad
}
// utlsAboutToLoadSession updates the loadSessionTracker to `UtlsAboutToCall` to signal the initiation of a session loading operation,
// provided that the preconditions are met. If the preconditions are not met (due to incorrect utls implementation), this function triggers a panic.
func (s *sessionController) utlsAboutToLoadSession() {
uAssert(s.state == NoSession && !s.locked, "tls: aboutToLoadSession failed: must only load session when the session of the client hello is not locked and when there's currently no session")
s.loadSessionTracker = UtlsAboutToCall
}
func (s *sessionController) assertHelloNotBuilt(caller string) {
if s.uconnRef.clientHelloBuildStatus != NotBuilt {
panic(fmt.Sprintf("tls: %s failed: we can't modify the session after the clientHello is built", caller))
}
}
func (s *sessionController) assertControllerState(caller string, desired sessionControllerState, moreDesiredStates ...sessionControllerState) {
if s.state != desired && !anyTrue(moreDesiredStates, func(_ int, state *sessionControllerState) bool {
return s.state == *state
}) {
panic(fmt.Sprintf("tls: %s failed: undesired controller state %d", caller, s.state))
}
}
func (s *sessionController) assertNotLocked(caller string) {
if s.locked {
panic(fmt.Sprintf("tls: %s failed: you must not modify the session after it's locked", caller))
}
}
// finalCheck performs a comprehensive check on the updated state to ensure the correctness of the changes.
// If the checks pass successfully, the sessionController's state will be locked.
// Any failure in passing the tests indicates incorrect implementations in the utls, which will result in triggering a panic.
// Refer to the documentation for the `locked` field for more detailed information.
func (s *sessionController) finalCheck() {
s.assertControllerState("SessionController.finalCheck", PskExtAllSet, SessionTicketExtAllSet, NoSession)
s.locked = true
}
func initializationGuard[E Initializable, I func(E)](extension E, initializer I) {
uAssert(!extension.IsInitialized(), "tls: initialization failed: the extension is already initialized")
initializer(extension)
uAssert(extension.IsInitialized(), "tls: initialization failed: the extension is not initialized after initialization")
}
// initSessionTicketExt initializes the ticket and sets the state to `TicketInitialized`.
func (s *sessionController) initSessionTicketExt(session *SessionState, ticket []byte) {
s.assertNotLocked("initSessionTicketExt")
s.assertHelloNotBuilt("initSessionTicketExt")
s.assertControllerState("initSessionTicketExt", NoSession)
panicOnNil("initSessionTicketExt", s.sessionTicketExt, session, ticket)
initializationGuard(s.sessionTicketExt, func(e ISessionTicketExtension) {
s.sessionTicketExt.InitializeByUtls(session, ticket)
})
s.state = SessionTicketExtInitialized
}
// initPSK initializes the PSK extension using a valid session. The PSK extension
// should not be initialized previously, and the parameters must not be nil;
// otherwise, this function will trigger a panic.
func (s *sessionController) initPskExt(session *SessionState, earlySecret []byte, binderKey []byte, pskIdentities []pskIdentity) {
s.assertNotLocked("initPskExt")
s.assertHelloNotBuilt("initPskExt")
s.assertControllerState("initPskExt", NoSession)
panicOnNil("initPskExt", s.pskExtension, session, earlySecret, pskIdentities)
initializationGuard(s.pskExtension, func(e PreSharedKeyExtension) {
publicPskIdentities := mapSlice(pskIdentities, func(private pskIdentity) PskIdentity {
return PskIdentity{
Label: private.label,
ObfuscatedTicketAge: private.obfuscatedTicketAge,
}
})
e.InitializeByUtls(session, earlySecret, binderKey, publicPskIdentities)
})
s.state = PskExtInitialized
}
// setSessionTicketToUConn write the ticket states from the session ticket extension to the client hello and handshake state.
func (s *sessionController) setSessionTicketToUConn() {
uAssert(s.sessionTicketExt != nil && s.state == SessionTicketExtInitialized, "tls: setSessionTicketExt failed: invalid state")
s.uconnRef.HandshakeState.Session = s.sessionTicketExt.GetSession()
s.uconnRef.HandshakeState.Hello.SessionTicket = s.sessionTicketExt.GetTicket()
s.state = SessionTicketExtAllSet
}
// setPskToUConn sets the psk to the handshake state and client hello.
func (s *sessionController) setPskToUConn() {
uAssert(s.pskExtension != nil && (s.state == PskExtInitialized || s.state == PskExtAllSet), "tls: setPskToUConn failed: invalid state")
pskCommon := s.pskExtension.GetPreSharedKeyCommon()
if s.state == PskExtInitialized {
s.uconnRef.HandshakeState.State13.EarlySecret = pskCommon.EarlySecret
s.uconnRef.HandshakeState.Session = pskCommon.Session
s.uconnRef.HandshakeState.Hello.PskIdentities = pskCommon.Identities
s.uconnRef.HandshakeState.Hello.PskBinders = pskCommon.Binders
} else if s.state == PskExtAllSet {
uAssert(s.uconnRef.HandshakeState.Session == pskCommon.Session && sliceEq(s.uconnRef.HandshakeState.State13.EarlySecret, pskCommon.EarlySecret) &&
allTrue(s.uconnRef.HandshakeState.Hello.PskIdentities, func(i int, psk *PskIdentity) bool {
return pskCommon.Identities[i].ObfuscatedTicketAge == psk.ObfuscatedTicketAge && sliceEq(pskCommon.Identities[i].Label, psk.Label)
}), "tls: setPskToUConn failed: only binders are allowed to change on state `PskAllSet`")
}
s.uconnRef.HandshakeState.State13.BinderKey = pskCommon.BinderKey
s.state = PskExtAllSet
}
// shouldUpdateBinders determines whether binders should be updated based on the presence of an initialized psk extension.
// This function returns true if an initialized psk extension exists. Binders are allowed to be updated when the state is `PskAllSet`,
// as the `BuildHandshakeState` function can be called multiple times in this case. However, it's important to note that
// the session state, apart from binders, should not be altered more than once.
func (s *sessionController) shouldUpdateBinders() bool {
if s.pskExtension == nil {
return false
}
return (s.state == PskExtInitialized || s.state == PskExtAllSet)
}
func (s *sessionController) updateBinders() {
uAssert(s.shouldUpdateBinders(), "tls: updateBinders failed: shouldn't update binders")
s.pskExtension.PatchBuiltHello(s.uconnRef.HandshakeState.Hello)
}
func (s *sessionController) overrideExtension(extension Initializable, override func(), initializedState sessionControllerState) error {
panicOnNil("overrideExtension", extension)
s.assertNotLocked("overrideExtension")
s.assertControllerState("overrideExtension", NoSession)
override()
if extension.IsInitialized() {
s.state = initializedState
}
return nil
}
// overridePskExt allows the user of utls to customize the psk extension.
func (s *sessionController) overridePskExt(pskExt PreSharedKeyExtension) error {
return s.overrideExtension(pskExt, func() { s.pskExtension = pskExt }, PskExtInitialized)
}
// overridePskExt allows the user of utls to customize the session ticket extension.
func (s *sessionController) overrideSessionTicketExt(sessionTicketExt ISessionTicketExtension) error {
return s.overrideExtension(sessionTicketExt, func() { s.sessionTicketExt = sessionTicketExt }, SessionTicketExtInitialized)
}
// syncSessionExts synchronizes the sessionController with the session-related
// extensions from the extension list after applying client hello specs.
//
// - If the extension list is missing the session ticket extension or PSK
// extension, owned extensions are dropped and states are reset.
// - If the user provides a session ticket extension or PSK extension, the
// corresponding extension from the extension list will be replaced.
// - If the user doesn't provide session-related extensions, the extensions
// from the extension list will be utilized.
//
// This function ensures that there is only one session ticket extension or PSK
// extension, and that the PSK extension is the last extension in the extension
// list.
func (s *sessionController) syncSessionExts() error {
uAssert(s.uconnRef.clientHelloBuildStatus == NotBuilt, "tls: checkSessionExts failed: we can't modify the session after the clientHello is built")
s.assertNotLocked("checkSessionExts")
s.assertHelloNotBuilt("checkSessionExts")
s.assertControllerState("checkSessionExts", NoSession, SessionTicketExtInitialized, PskExtInitialized)
numSessionExt := 0
hasPskExt := false
for i, e := range s.uconnRef.Extensions {
switch ext := e.(type) {
case ISessionTicketExtension:
uAssert(numSessionExt == 0, "tls: checkSessionExts failed: multiple ISessionTicketExtensions in the extension list")
if s.sessionTicketExt == nil {
// If there isn't a user-provided session ticket extension, use the one from the spec
s.sessionTicketExt = ext
} else {
// Otherwise, replace the one in the extension list with the user-provided one
s.uconnRef.Extensions[i] = s.sessionTicketExt
}
numSessionExt += 1
case PreSharedKeyExtension:
uAssert(i == len(s.uconnRef.Extensions)-1, "tls: checkSessionExts failed: PreSharedKeyExtension must be the last extension")
if s.pskExtension == nil {
// If there isn't a user-provided psk extension, use the one from the spec
s.pskExtension = ext
} else {
// Otherwise, replace the one in the extension list with the user-provided one
s.uconnRef.Extensions[i] = s.pskExtension
}
s.pskExtension.SetOmitEmptyPsk(s.uconnRef.config.OmitEmptyPsk)
hasPskExt = true
}
}
if numSessionExt == 0 {
if s.state == SessionTicketExtInitialized {
return errors.New("tls: checkSessionExts failed: the user provided a session ticket, but the specification doesn't contain one")
}
s.sessionTicketExt = nil
s.uconnRef.HandshakeState.Session = nil
s.uconnRef.HandshakeState.Hello.SessionTicket = nil
}
if !hasPskExt {
if s.state == PskExtInitialized {
return errors.New("tls: checkSessionExts failed: the user provided a psk, but the specification doesn't contain one")
}
s.pskExtension = nil
s.uconnRef.HandshakeState.State13.BinderKey = nil
s.uconnRef.HandshakeState.State13.EarlySecret = nil
s.uconnRef.HandshakeState.Session = nil
s.uconnRef.HandshakeState.Hello.PskIdentities = nil
}
return nil
}
// onEnterLoadSessionCheck is intended to be invoked upon entering the `conn.loadSession` function.
// It is designed to ensure the correctness of the utls implementation. If the utls implementation is found to be incorrect, this function will trigger a panic.
func (s *sessionController) onEnterLoadSessionCheck() {
uAssert(!s.locked, "tls: LoadSessionCoordinator.onEnterLoadSessionCheck failed: session is set and locked, no call to loadSession is allowed")
switch s.loadSessionTracker {
case UtlsAboutToCall, NeverCalled:
s.callingLoadSession = true
case CalledByULoadSession, CalledByGoTLS:
panic("tls: LoadSessionCoordinator.onEnterLoadSessionCheck failed: you must not call loadSession() twice")
default:
panic("tls: LoadSessionCoordinator.onEnterLoadSessionCheck failed: unimplemented state")
}
}
// onLoadSessionReturn is intended to be invoked upon returning from the `conn.loadSession` function.
// It serves as a validation step for the correctness of the underlying utls implementation.
// If the utls implementation is incorrect, this function will trigger a panic.
func (s *sessionController) onLoadSessionReturn() {
uAssert(s.callingLoadSession, "tls: LoadSessionCoordinator.onLoadSessionReturn failed: it's not loading sessions, perhaps this function is not being called by loadSession.")
switch s.loadSessionTracker {
case NeverCalled:
s.loadSessionTracker = CalledByGoTLS
case UtlsAboutToCall:
s.loadSessionTracker = CalledByULoadSession
default:
panic("tls: LoadSessionCoordinator.onLoadSessionReturn failed: unimplemented state")
}
s.callingLoadSession = false
}
// shouldLoadSessionWriteBinders checks if `conn.loadSession` should proceed to write binders and marshal the client hello. If the utls implementation
// is incorrect, this function will trigger a panic.
func (s *sessionController) shouldLoadSessionWriteBinders() bool {
uAssert(s.callingLoadSession, "tls: shouldWriteBinders failed: LoadSessionCoordinator isn't loading sessions, perhaps this function is not being called by loadSession.")
switch s.loadSessionTracker {
case NeverCalled:
return true
case UtlsAboutToCall:
return false
default:
panic("tls: shouldWriteBinders failed: unimplemented state")
}
}

82
u_session_ticket.go Normal file
View file

@ -0,0 +1,82 @@
package tls
import "io"
type ISessionTicketExtension interface {
TLSExtension
// If false is returned, utls will invoke `InitializeByUtls()` for the necessary initialization.
Initializable
// InitializeByUtls is invoked when IsInitialized() returns false.
// It initializes the extension using a real and valid TLS 1.2 session.
InitializeByUtls(session *SessionState, ticket []byte)
GetSession() *SessionState
GetTicket() []byte
}
// SessionTicketExtension implements session_ticket (35)
type SessionTicketExtension struct {
Session *SessionState
Ticket []byte
Initialized bool
}
func (e *SessionTicketExtension) writeToUConn(uc *UConn) error {
// session states are handled later. At this point tickets aren't
// being loaded by utls, so don't write anything to the UConn.
uc.HandshakeState.Hello.TicketSupported = true // This doesn't really matter, this field is only used to add session ticket ext in go tls.
return nil
}
func (e *SessionTicketExtension) Len() int {
return 4 + len(e.Ticket)
}
func (e *SessionTicketExtension) Read(b []byte) (int, error) {
if len(b) < e.Len() {
return 0, io.ErrShortBuffer
}
extBodyLen := e.Len() - 4
b[0] = byte(extensionSessionTicket >> 8)
b[1] = byte(extensionSessionTicket)
b[2] = byte(extBodyLen >> 8)
b[3] = byte(extBodyLen)
if extBodyLen > 0 {
copy(b[4:], e.Ticket)
}
return e.Len(), io.EOF
}
func (e *SessionTicketExtension) IsInitialized() bool {
return e.Initialized
}
func (e *SessionTicketExtension) InitializeByUtls(session *SessionState, ticket []byte) {
uAssert(!e.Initialized, "tls: InitializeByUtls failed: the SessionTicketExtension is initialized")
uAssert(session.version == VersionTLS12 && session != nil && ticket != nil, "tls: InitializeByUtls failed: the session is not a tls 1.2 session")
e.Session = session
e.Ticket = ticket
e.Initialized = true
}
func (e *SessionTicketExtension) UnmarshalJSON(_ []byte) error {
return nil // no-op
}
func (e *SessionTicketExtension) Write(_ []byte) (int, error) {
// RFC 5077, Section 3.2
return 0, nil
}
func (e *SessionTicketExtension) GetSession() *SessionState {
return e.Session
}
func (e *SessionTicketExtension) GetTicket() []byte {
return e.Ticket
}

View file

@ -48,7 +48,7 @@ func ExtensionFromID(id uint16) TLSExtension {
case extensionSessionTicket:
return &SessionTicketExtension{}
case extensionPreSharedKey:
return &FakePreSharedKeyExtension{}
return (PreSharedKeyExtension)(&FakePreSharedKeyExtension{}) // To use the result, caller needs further inspection to decide between Fake or Utls.
// case extensionEarlyData:
// return &EarlyDataExtension{}
case extensionSupportedVersions:
@ -800,52 +800,6 @@ func (e *SCTExtension) Write(_ []byte) (int, error) {
return 0, nil
}
// SessionTicketExtension implements session_ticket (35)
type SessionTicketExtension struct {
Session *ClientSessionState
}
func (e *SessionTicketExtension) writeToUConn(uc *UConn) error {
if e.Session != nil {
uc.HandshakeState.Session = e.Session.session
uc.HandshakeState.Hello.SessionTicket = e.Session.ticket
}
return nil
}
func (e *SessionTicketExtension) Len() int {
if e.Session != nil {
return 4 + len(e.Session.ticket)
}
return 4
}
func (e *SessionTicketExtension) Read(b []byte) (int, error) {
if len(b) < e.Len() {
return 0, io.ErrShortBuffer
}
extBodyLen := e.Len() - 4
b[0] = byte(extensionSessionTicket >> 8)
b[1] = byte(extensionSessionTicket)
b[2] = byte(extBodyLen >> 8)
b[3] = byte(extBodyLen)
if extBodyLen > 0 {
copy(b[4:], e.Session.ticket)
}
return e.Len(), io.EOF
}
func (e *SessionTicketExtension) UnmarshalJSON(_ []byte) error {
return nil // no-op
}
func (e *SessionTicketExtension) Write(_ []byte) (int, error) {
// RFC 5077, Section 3.2
return 0, nil
}
// GenericExtension allows to include in ClientHello arbitrary unsupported extensions.
// It is not defined in TLS RFCs nor by IANA.
// If a server echoes this extension back, the handshake will likely fail due to no further support.
@ -1893,175 +1847,3 @@ func (e *FakeDelegatedCredentialsExtension) UnmarshalJSON(data []byte) error {
}
return nil
}
// FakePreSharedKeyExtension is an extension used to set the PSK extension in the
// ClientHello.
//
// Unfortunately, even when the PSK extension is set, there will be no PSK-based
// resumption since crypto/tls does not implement PSK.
type FakePreSharedKeyExtension struct {
PskIdentities []PskIdentity `json:"identities"`
PskBinders [][]byte `json:"binders"`
}
func (e *FakePreSharedKeyExtension) writeToUConn(uc *UConn) error {
if uc.config.ClientSessionCache == nil {
return nil // don't write the extension if there is no session cache
}
if session, ok := uc.config.ClientSessionCache.Get(uc.clientSessionCacheKey()); !ok || session == nil {
return nil // don't write the extension if there is no session cache available for this session
}
uc.HandshakeState.Hello.PskIdentities = e.PskIdentities
uc.HandshakeState.Hello.PskBinders = e.PskBinders
return nil
}
func (e *FakePreSharedKeyExtension) Len() int {
length := 4 // extension type + extension length
length += 2 // identities length
for _, identity := range e.PskIdentities {
length += 2 + len(identity.Label) + 4 // identity length + identity + obfuscated ticket age
}
length += 2 // binders length
for _, binder := range e.PskBinders {
length += len(binder)
}
return length
}
func (e *FakePreSharedKeyExtension) Read(b []byte) (int, error) {
if len(b) < e.Len() {
return 0, io.ErrShortBuffer
}
b[0] = byte(extensionPreSharedKey >> 8)
b[1] = byte(extensionPreSharedKey)
b[2] = byte((e.Len() - 4) >> 8)
b[3] = byte(e.Len() - 4)
// identities length
identitiesLength := 0
for _, identity := range e.PskIdentities {
identitiesLength += 2 + len(identity.Label) + 4 // identity length + identity + obfuscated ticket age
}
b[4] = byte(identitiesLength >> 8)
b[5] = byte(identitiesLength)
// identities
offset := 6
for _, identity := range e.PskIdentities {
b[offset] = byte(len(identity.Label) >> 8)
b[offset+1] = byte(len(identity.Label))
offset += 2
copy(b[offset:], identity.Label)
offset += len(identity.Label)
b[offset] = byte(identity.ObfuscatedTicketAge >> 24)
b[offset+1] = byte(identity.ObfuscatedTicketAge >> 16)
b[offset+2] = byte(identity.ObfuscatedTicketAge >> 8)
b[offset+3] = byte(identity.ObfuscatedTicketAge)
offset += 4
}
// binders length
bindersLength := 0
for _, binder := range e.PskBinders {
bindersLength += len(binder)
}
b[offset] = byte(bindersLength >> 8)
b[offset+1] = byte(bindersLength)
offset += 2
// binders
for _, binder := range e.PskBinders {
copy(b[offset:], binder)
offset += len(binder)
}
return e.Len(), io.EOF
}
func (e *FakePreSharedKeyExtension) Write(b []byte) (n int, err error) {
fullLen := len(b)
s := cryptobyte.String(b)
var identitiesLength uint16
if !s.ReadUint16(&identitiesLength) {
return 0, errors.New("tls: invalid PSK extension")
}
// identities
for identitiesLength > 0 {
var identityLength uint16
if !s.ReadUint16(&identityLength) {
return 0, errors.New("tls: invalid PSK extension")
}
identitiesLength -= 2
if identityLength > identitiesLength {
return 0, errors.New("tls: invalid PSK extension")
}
var identity []byte
if !s.ReadBytes(&identity, int(identityLength)) {
return 0, errors.New("tls: invalid PSK extension")
}
identitiesLength -= identityLength // identity
var obfuscatedTicketAge uint32
if !s.ReadUint32(&obfuscatedTicketAge) {
return 0, errors.New("tls: invalid PSK extension")
}
e.PskIdentities = append(e.PskIdentities, PskIdentity{
Label: identity,
ObfuscatedTicketAge: obfuscatedTicketAge,
})
identitiesLength -= 4 // obfuscated ticket age
}
var bindersLength uint16
if !s.ReadUint16(&bindersLength) {
return 0, errors.New("tls: invalid PSK extension")
}
// binders
for bindersLength > 0 {
var binderLength uint8
if !s.ReadUint8(&binderLength) {
return 0, errors.New("tls: invalid PSK extension")
}
bindersLength -= 1
if uint16(binderLength) > bindersLength {
return 0, errors.New("tls: invalid PSK extension")
}
var binder []byte
if !s.ReadBytes(&binder, int(binderLength)) {
return 0, errors.New("tls: invalid PSK extension")
}
e.PskBinders = append(e.PskBinders, binder)
bindersLength -= uint16(binderLength)
}
return fullLen, nil
}
func (e *FakePreSharedKeyExtension) UnmarshalJSON(data []byte) error {
var pskAccepter struct {
PskIdentities []PskIdentity `json:"identities"`
PskBinders [][]byte `json:"binders"`
}
if err := json.Unmarshal(data, &pskAccepter); err != nil {
return err
}
e.PskIdentities = pskAccepter.PskIdentities
e.PskBinders = pskAccepter.PskBinders
return nil
}