diff --git a/README.md b/README.md index 8079da4..2e9148b 100644 --- a/README.md +++ b/README.md @@ -157,7 +157,7 @@ Currently, there is a simple function to set session ticket to any desired state ```Golang // If you want you session tickets to be reused - use same cache on following connections -func (uconn *UConn) SetSessionState(session *ClientSessionState) +func (uconn *UConn) SetSessionState12(session *ClientSessionState) ``` Note that session tickets (fake ones or otherwise) are not reused. @@ -294,7 +294,7 @@ Some customizations(such as setting session ticket/clientHello) have easy-to-use tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, masterSecret, nil, nil) - tlsConn.SetSessionState(sessionState) + tlsConn.SetSessionState12(sessionState) ``` For other customizations there are following functions diff --git a/examples/old/examples.go b/examples/old/examples.go index b17a7e9..2b2a2ba 100644 --- a/examples/old/examples.go +++ b/examples/old/examples.go @@ -140,7 +140,7 @@ func HttpGetTicket(hostname string, addr string) (*http.Response, error) { masterSecret, nil, nil) - err = uTlsConn.SetSessionState(sessionState) + err = uTlsConn.SetSessionState12(sessionState) if err != nil { return nil, err } @@ -174,7 +174,7 @@ func HttpGetTicketHelloID(hostname string, addr string, helloID tls.ClientHelloI masterSecret, nil, nil) - uTlsConn.SetSessionState(sessionState) + uTlsConn.SetSessionState12(sessionState) err = uTlsConn.Handshake() if err != nil { return nil, fmt.Errorf("uTlsConn.Handshake() error: %+v", err) diff --git a/examples/tls-psk/main.go b/examples/tls-psk/main.go index 0a97647..15517c0 100644 --- a/examples/tls-psk/main.go +++ b/examples/tls-psk/main.go @@ -38,7 +38,7 @@ func (csc *ClientSessionCache) Put(sessionKey string, cs *tls.ClientSessionState } } -func main() { +func runPskCheck(helloID tls.ClientHelloID) { const serverAddr string = "refraction.network:443" csc := NewClientSessionCache() tcpConn, err := net.Dial("tcp", serverAddr) @@ -53,7 +53,7 @@ func main() { ServerName: strings.Split(serverAddr, ":")[0], // NextProtos: []string{"h2", "http/1.1"}, ClientSessionCache: csc, // set this so session tickets will be saved - }, tls.HelloChrome_100) + }, helloID) // HS err = tlsConn.Handshake() @@ -88,10 +88,12 @@ func main() { tlsConnPSK := tls.UClient(tcpConnPSK, &tls.Config{ ServerName: strings.Split(serverAddr, ":")[0], ClientSessionCache: csc, - }, tls.HelloChrome_100_PSK, &tls.UtlsPreSharedKeyExtension{}) + }, helloID) // HS err = tlsConnPSK.Handshake() + fmt.Println(tlsConnPSK.HandshakeState.Hello.Raw) + fmt.Println(tlsConnPSK.HandshakeState.Hello.PskIdentities) if err != nil { panic(err) } @@ -111,3 +113,8 @@ func main() { } } } + +func main() { + runPskCheck(tls.HelloChrome_100_PSK) + runPskCheck(tls.HelloGolang) +} diff --git a/handshake_client.go b/handshake_client.go index aa56de3..176e5a3 100644 --- a/handshake_client.go +++ b/handshake_client.go @@ -312,6 +312,12 @@ func (c *Conn) clientHandshake(ctx context.Context) (err error) { func (c *Conn) loadSession(hello *clientHelloMsg) ( session *SessionState, earlySecret, binderKey []byte, err error) { + // [UTLS SECTION START] + if c.utls.sessionController != nil { + c.utls.sessionController.onEnterLoadSessionCheck() + defer c.utls.sessionController.onLoadSessionReturn() + } + // [UTLS SECTION END] if c.config.SessionTicketsDisabled || c.config.ClientSessionCache == nil { return nil, nil, nil, nil } @@ -324,12 +330,6 @@ func (c *Conn) loadSession(hello *clientHelloMsg) ( hello.pskModes = []uint8{pskModeDHE} } - // [UTLS BEGINS] - if c.utls.session != nil { - return c.utls.session, c.utls.earlySecret, c.utls.binderKey, nil - } - // [UTLS ENDS] - // Session resumption is not allowed if renegotiating because // renegotiation is primarily used to allow a client to send a client // certificate, which would be skipped if session resumption occurred. @@ -456,6 +456,11 @@ func (c *Conn) loadSession(hello *clientHelloMsg) ( // Compute the PSK binders. See RFC 8446, Section 4.2.11.2. earlySecret = cipherSuite.extract(session.secret, nil) binderKey = cipherSuite.deriveSecret(earlySecret, resumptionBinderLabel, nil) + // [UTLS SECTION START] + if c.utls.sessionController != nil && !c.utls.sessionController.shouldWriteBinders() { + return + } + // [UTLS SECTION END] transcript := cipherSuite.hash.New() helloBytes, err := hello.marshalWithoutBinders() if err != nil { @@ -466,11 +471,6 @@ func (c *Conn) loadSession(hello *clientHelloMsg) ( if err := hello.updateBinders(pskBinders); err != nil { return nil, nil, nil, err } - - c.utls.session = session // [uTLS] - c.utls.earlySecret = earlySecret // [uTLS] - c.utls.binderKey = binderKey // [uTLS] - return } diff --git a/u_clienthello_json_test.go b/u_clienthello_json_test.go index 9ab86c7..b1f1fb8 100644 --- a/u_clienthello_json_test.go +++ b/u_clienthello_json_test.go @@ -29,7 +29,7 @@ func testClientHelloSpecJSONUnmarshaler( t.Fatal(err) } - truthSpec, _ := utlsIdToSpec(truthClientHelloID) + truthSpec, _ := utlsIdToSpec(truthClientHelloID, &UtlsPreSharedKeyExtension{}, &SessionTicketExtension{}) jsonSpec := chsju.ClientHelloSpec() // Compare CipherSuites @@ -85,7 +85,7 @@ func testClientHelloSpecUnmarshalJSON( t.Fatal(err) } - truthSpec, _ := utlsIdToSpec(truthClientHelloID) + truthSpec, _ := utlsIdToSpec(truthClientHelloID, &UtlsPreSharedKeyExtension{}, &SessionTicketExtension{}) // Compare CipherSuites if !reflect.DeepEqual(jsonSpec.CipherSuites, truthSpec.CipherSuites) { diff --git a/u_common.go b/u_common.go index af2959b..d74d7a4 100644 --- a/u_common.go +++ b/u_common.go @@ -726,3 +726,38 @@ func EnableWeakCiphers() { suiteECDHE | suiteTLS12 | suiteSHA384, cipherAES, utlsMacSHA384, nil}, }...) } + +func panicOnNil(failureMsg string, params ...any) { + for i, p := range params { + if p == nil { + panic(fmt.Sprintf("%s: the [%d] parameter is nil", failureMsg, i)) + } + } +} + +func anyTrue[T any](slice []T, predicate func(t *T) bool) bool { + for i := 0; i < len(slice); i++ { + if predicate(&slice[i]) { + return true + } + } + return false +} + +func uAssert(condition bool, msg string) { + if !condition { + panic(msg) + } +} + +func sliceEq(sliceA []any, sliceB []any) bool { + if len(sliceA) != len(sliceB) { + return false + } + for i := 0; i < len(sliceA); i++ { + if sliceA[i] != sliceB[i] { + return false + } + } + return true +} diff --git a/u_conn.go b/u_conn.go index 2e49e2d..4249928 100644 --- a/u_conn.go +++ b/u_conn.go @@ -14,23 +14,26 @@ import ( "errors" "fmt" "hash" - "io" "net" "strconv" ) +type ClientHelloBuildStatus int + +const NotBuilt ClientHelloBuildStatus = 0 +const BuildByUtls ClientHelloBuildStatus = 1 +const BuildByGoTLS ClientHelloBuildStatus = 2 + type UConn struct { *Conn - Extensions []TLSExtension - ClientHelloID ClientHelloID - pskExtension []PreSharedKeyExtension + Extensions []TLSExtension + ClientHelloID ClientHelloID + sessionController *sessionController - ClientHelloBuilt bool - HandshakeState PubClientHandshakeState + clientHelloBuildStatus ClientHelloBuildStatus - // sessionID may or may not depend on ticket; nil => random - GetSessionID func(ticket []byte) [32]byte + HandshakeState PubClientHandshakeState greaseSeed [ssl_grease_last_index]uint16 @@ -44,15 +47,17 @@ type UConn struct { // UClient returns a new uTLS client, with behavior depending on clientHelloID. // Config CAN be nil, but make sure to eventually specify ServerName. -func UClient(conn net.Conn, config *Config, clientHelloID ClientHelloID, pskExtension ...PreSharedKeyExtension) *UConn { +func UClient(conn net.Conn, config *Config, clientHelloID ClientHelloID) *UConn { if config == nil { config = &Config{} } tlsConn := Conn{conn: conn, config: config, isClient: true} handshakeState := PubClientHandshakeState{C: &tlsConn, Hello: &PubClientHelloMsg{}} - uconn := UConn{Conn: &tlsConn, ClientHelloID: clientHelloID, pskExtension: pskExtension, HandshakeState: handshakeState} + uconn := UConn{Conn: &tlsConn, ClientHelloID: clientHelloID, HandshakeState: handshakeState} uconn.HandshakeState.uconn = &uconn uconn.handshakeFn = uconn.clientHandshake + uconn.sessionController = newSessionController(&uconn) + uconn.utls.sessionController = uconn.sessionController return &uconn } @@ -73,9 +78,10 @@ func UClient(conn net.Conn, config *Config, clientHelloID ClientHelloID, pskExte // default/mimicked ClientHello. func (uconn *UConn) BuildHandshakeState() error { if uconn.ClientHelloID == HelloGolang { - if uconn.ClientHelloBuilt { + if uconn.clientHelloBuildStatus == BuildByGoTLS { return nil } + uAssert(uconn.clientHelloBuildStatus == NotBuilt, "BuildHandshakeState failed: invalid call, client hello has already been built by utls") // use default Golang ClientHello. hello, keySharePrivate, err := uconn.makeClientHello() @@ -92,8 +98,10 @@ func (uconn *UConn) BuildHandshakeState() error { return fmt.Errorf("uTLS: unknown keySharePrivate type: %T", keySharePrivate) } uconn.HandshakeState.C = uconn.Conn + uconn.clientHelloBuildStatus = BuildByGoTLS } else { - if !uconn.ClientHelloBuilt { + uAssert(uconn.clientHelloBuildStatus == BuildByUtls || uconn.clientHelloBuildStatus == NotBuilt, "BuildHandshakeState failed: invalid call, client hello has already been built by go-tls") + if uconn.clientHelloBuildStatus == NotBuilt { err := uconn.applyPresetByID(uconn.ClientHelloID) if err != nil { return err @@ -107,51 +115,93 @@ func (uconn *UConn) BuildHandshakeState() error { if err != nil { return err } + + err = uconn.uLoadSession() + if err != nil { + return err + } + err = uconn.MarshalClientHello() if err != nil { return err } + + uconn.uApplyPatch() + + uconn.sessionController.finalCheck() + uconn.clientHelloBuildStatus = BuildByUtls } - uconn.ClientHelloBuilt = true return nil } -// SetSessionState sets the session ticket, which may be preshared or fake. +func (uconn *UConn) uLoadSession() error { + if cfg := uconn.config; cfg.SessionTicketsDisabled || cfg.ClientSessionCache == nil { + return nil + } + switch uconn.sessionController.shouldLoadSession() { + case shouldReturn: + case shouldSetTicket: + uconn.sessionController.setSessionTicketToUConn() + case shouldSetPsk: + uconn.sessionController.setPsk() + case shouldLoad: + hello := uconn.HandshakeState.Hello.getPrivatePtr() + uconn.sessionController.aboutToLoadSession() + session, earlySecret, binderKey, err := uconn.loadSession(hello) + if session == nil || err != nil { + return err + } + if session.version == VersionTLS12 { + // We use the session ticket extension for tls 1.2 session resumption + uconn.sessionController.initSessionTicketExt(session, hello.sessionTicket) + uconn.sessionController.setSessionTicketToUConn() + } else { + uconn.sessionController.initPsk(session, earlySecret, binderKey, hello.pskIdentities) + } + } + + return nil +} + +func (uconn *UConn) uApplyPatch() { + if uconn.sessionController.shouldUpdateBinders() { + uconn.sessionController.updateBinders() + uconn.sessionController.setPsk() + } +} + +// SetSessionState12 sets the session ticket, which may be preshared or fake. // If session is nil, the body of session ticket extension will be unset, // but the extension itself still MAY be present for mimicking purposes. // Session tickets to be reused - use same cache on following connections. -func (uconn *UConn) SetSessionState(session *ClientSessionState) error { - var sessionTicket []uint8 - if session != nil { - sessionTicket = session.ticket - uconn.HandshakeState.Session = session.session +func (uconn *UConn) SetSessionState12(session *ClientSessionState) error { + if uconn.config.SessionTicketsDisabled || uconn.config.ClientSessionCache == nil { + return fmt.Errorf("SetSessionState12 failed: session is disabled") } - uconn.HandshakeState.Hello.TicketSupported = true - uconn.HandshakeState.Hello.SessionTicket = sessionTicket - - for _, ext := range uconn.Extensions { - st, ok := ext.(*SessionTicketExtension) - if !ok { - continue - } - st.Session = session - if session != nil { - if len(session.SessionTicket()) > 0 { - if uconn.GetSessionID != nil { - sid := uconn.GetSessionID(session.SessionTicket()) - uconn.HandshakeState.Hello.SessionId = sid[:] - return nil - } - } - var sessionID [32]byte - _, err := io.ReadFull(uconn.config.rand(), sessionID[:]) - if err != nil { - return err - } - uconn.HandshakeState.Hello.SessionId = sessionID[:] - } + if session == nil { return nil } + if session.session == nil { + return fmt.Errorf("SetSessionState12 failed: session must not be nil") + } + if session.session.version != VersionTLS12 { + return fmt.Errorf("SetSessionState12 failed: SetSessionState12 only works for tls 1.2 session ticket; for tls 1.3 please customize PSK with SetSessionState13()") + } + uconn.sessionController.initSessionTicketExt(session.session, session.ticket) + return nil +} + +// SetSessionState13 sets the psk extension for tls 1.3 resumption +func (uconn *UConn) SetSessionState13(psk PreSharedKeyExtension) error { + if uconn.config.SessionTicketsDisabled || uconn.config.ClientSessionCache == nil { + return fmt.Errorf("SetSessionState13 failed: session is disabled") + } + if psk == nil { + return nil + } + + uconn.HandshakeState.Hello.TicketSupported = true + uconn.sessionController.overridePskExt(psk) return nil } @@ -397,7 +447,7 @@ func (c *UConn) clientHandshake(ctx context.Context) (err error) { hello := c.HandshakeState.Hello.getPrivatePtr() defer func() { c.HandshakeState.Hello = hello.getPublicPtr() }() - sessionIsAlreadySet := c.HandshakeState.Session != nil + sessionIsLocked := c.utls.sessionController.isSessionLocked() // after this point exactly 1 out of 2 HandshakeState pointers is non-nil, // useTLS13 variable tells which pointer @@ -434,9 +484,24 @@ func (c *UConn) clientHandshake(ctx context.Context) (err error) { if c.handshakes > 0 { hello.secureRenegotiation = c.clientFinished[:] } - // [uTLS section ends] - session, earlySecret, binderKey, err := c.loadSession(hello) + var ( + session *SessionState + earlySecret []byte + binderKey []byte + ) + if !sessionIsLocked { + // [uTLS section ends] + + session, earlySecret, binderKey, err = c.loadSession(hello) + + // [uTLS section start] + } else { + session = c.HandshakeState.Session + earlySecret = c.HandshakeState.State13.EarlySecret + binderKey = c.HandshakeState.State13.BinderKey + } + // [uTLS section ends] if err != nil { return err } @@ -491,7 +556,7 @@ func (c *UConn) clientHandshake(ctx context.Context) (err error) { hs13.serverHello = serverHello hs13.hello = hello hs13.keySharesParams = NewKeySharesParameters() - if !sessionIsAlreadySet { + if !sessionIsLocked { hs13.earlySecret = earlySecret hs13.binderKey = binderKey hs13.session = session @@ -547,7 +612,7 @@ func (uconn *UConn) MarshalClientHello() error { if paddingExt == nil { paddingExt = pe } else { - return errors.New("multiple padding extensions!") + return errors.New("multiple padding extensions") } } } @@ -589,27 +654,8 @@ func (uconn *UConn) MarshalClientHello() error { if len(uconn.Extensions) > 0 { binary.Write(bufferedWriter, binary.BigEndian, uint16(extensionsLen)) for _, ext := range uconn.Extensions { - switch typedExt := ext.(type) { - case PreSharedKeyExtension: - // PSK extension is handled separately - err := bufferedWriter.Flush() - if err != nil { - return fmt.Errorf("bufferedWriter.Flush(): %w", err) - } - hello.Raw = helloBuffer.Bytes() - // prepare buffer - buf := make([]byte, typedExt.Len()) - n, err := typedExt.ReadWithRawHello(hello.Raw, buf) - if err != nil && !errors.Is(err, io.EOF) { - return fmt.Errorf("(*PreSharedKeyExtension).ReadWithRawHello(): %w", err) - } - if n != typedExt.Len() { - return errors.New("uconn: PreSharedKeyExtension: read wrong number of bytes") - } - bufferedWriter.Write(buf) - hello.PskBinders = typedExt.GetBinders() - default: - bufferedWriter.ReadFrom(ext) + if _, err := bufferedWriter.ReadFrom(ext); err != nil { + return err } } } @@ -801,8 +847,5 @@ type utlsConnExtraFields struct { peerApplicationSettings []byte localApplicationSettings []byte - // session resumption (PSK) - session *SessionState - earlySecret []byte - binderKey []byte + sessionController *sessionController } diff --git a/u_fingerprinter_test.go b/u_fingerprinter_test.go index ad140bf..af125dc 100644 --- a/u_fingerprinter_test.go +++ b/u_fingerprinter_test.go @@ -252,7 +252,7 @@ func TestUTLSFingerprintClientHelloBluntMimicry(t *testing.T) { var extensionId uint16 = 0xfeed extensionData := []byte("random data") - specWithGeneric, err := utlsIdToSpec(HelloChrome_Auto) + specWithGeneric, err := utlsIdToSpec(HelloChrome_Auto, &UtlsPreSharedKeyExtension{}, &SessionTicketExtension{}) if err != nil { t.Errorf("got error: %v; expected to succeed", err) } @@ -293,11 +293,11 @@ func TestUTLSFingerprintClientHelloBluntMimicry(t *testing.T) { func TestUTLSFingerprintClientHelloAlwaysAddPadding(t *testing.T) { serverName := "foobar" - specWithoutPadding, err := utlsIdToSpec(HelloIOS_12_1) + specWithoutPadding, err := utlsIdToSpec(HelloIOS_12_1, &UtlsPreSharedKeyExtension{}, &SessionTicketExtension{}) if err != nil { t.Errorf("got error: %v; expected to succeed", err) } - specWithPadding, err := utlsIdToSpec(HelloChrome_83) + specWithPadding, err := utlsIdToSpec(HelloChrome_83, &UtlsPreSharedKeyExtension{}, &SessionTicketExtension{}) if err != nil { t.Errorf("got error: %v; expected to succeed", err) } diff --git a/u_parrots.go b/u_parrots.go index f8a3c75..6efb8b6 100644 --- a/u_parrots.go +++ b/u_parrots.go @@ -17,26 +17,18 @@ import ( ) var ErrUnknownClientHelloID = errors.New("tls: unknown ClientHelloID") -var ErrNotPSKClientHelloID = errors.New("tls: ClientHello does not contain pre_shared_key extension") -var ErrPSKExtensionExpected = errors.New("tls: pre_shared_key extension expected when fetching preset ClientHelloSpec") // UTLSIdToSpec converts a ClientHelloID to a corresponding ClientHelloSpec. // // Exported internal function utlsIdToSpec per request. -func UTLSIdToSpec(id ClientHelloID, pskExtension ...PreSharedKeyExtension) (ClientHelloSpec, error) { - if len(pskExtension) > 1 { - return ClientHelloSpec{}, errors.New("tls: at most one PreSharedKeyExtensions is allowed") - } - - chs, err := utlsIdToSpec(id) - if err != nil && errors.Is(err, ErrUnknownClientHelloID) { - chs, err = utlsIdToSpecWithPSK(id, pskExtension...) - } - - return chs, err +func UTLSIdToSpec(id ClientHelloID, pskExt PreSharedKeyExtension, sessionTicketExt *SessionTicketExtension) (ClientHelloSpec, error) { + return utlsIdToSpec(id, pskExt, sessionTicketExt) } -func utlsIdToSpec(id ClientHelloID) (ClientHelloSpec, error) { +func utlsIdToSpec(id ClientHelloID, pskExt PreSharedKeyExtension, sessionTicketExt *SessionTicketExtension) (ClientHelloSpec, error) { + if pskExt == nil || sessionTicketExt == nil { + panic("utlsIdToSpec failed: pskExt and sessionTicketExt must be non-nil pointers") + } switch id { case HelloChrome_58, HelloChrome_62: return ClientHelloSpec{ @@ -64,7 +56,7 @@ func utlsIdToSpec(id ClientHelloID) (ClientHelloSpec, error) { &RenegotiationInfoExtension{Renegotiation: RenegotiateOnceAsClient}, &SNIExtension{}, &ExtendedMasterSecretExtension{}, - &SessionTicketExtension{}, + sessionTicketExt, &SignatureAlgorithmsExtension{SupportedSignatureAlgorithms: []SignatureScheme{ ECDSAWithP256AndSHA256, PSSWithSHA256, @@ -119,7 +111,7 @@ func utlsIdToSpec(id ClientHelloID) (ClientHelloSpec, error) { &RenegotiationInfoExtension{Renegotiation: RenegotiateOnceAsClient}, &SNIExtension{}, &ExtendedMasterSecretExtension{}, - &SessionTicketExtension{}, + sessionTicketExt, &SignatureAlgorithmsExtension{SupportedSignatureAlgorithms: []SignatureScheme{ ECDSAWithP256AndSHA256, PSSWithSHA256, @@ -198,7 +190,7 @@ func utlsIdToSpec(id ClientHelloID) (ClientHelloSpec, error) { &SupportedPointsExtension{SupportedPoints: []byte{ 0x00, // pointFormatUncompressed }}, - &SessionTicketExtension{}, + sessionTicketExt, &ALPNExtension{AlpnProtocols: []string{"h2", "http/1.1"}}, &StatusRequestExtension{}, &SignatureAlgorithmsExtension{SupportedSignatureAlgorithms: []SignatureScheme{ @@ -271,7 +263,7 @@ func utlsIdToSpec(id ClientHelloID) (ClientHelloSpec, error) { &SupportedPointsExtension{SupportedPoints: []byte{ 0x00, // pointFormatUncompressed }}, - &SessionTicketExtension{}, + sessionTicketExt, &ALPNExtension{AlpnProtocols: []string{"h2", "http/1.1"}}, &StatusRequestExtension{}, &SignatureAlgorithmsExtension{SupportedSignatureAlgorithms: []SignatureScheme{ @@ -343,7 +335,7 @@ func utlsIdToSpec(id ClientHelloID) (ClientHelloSpec, error) { &SupportedPointsExtension{SupportedPoints: []byte{ 0x00, // pointFormatUncompressed }}, - &SessionTicketExtension{}, + sessionTicketExt, &ALPNExtension{AlpnProtocols: []string{"h2", "http/1.1"}}, &StatusRequestExtension{}, &SignatureAlgorithmsExtension{SupportedSignatureAlgorithms: []SignatureScheme{ @@ -415,7 +407,7 @@ func utlsIdToSpec(id ClientHelloID) (ClientHelloSpec, error) { &SupportedPointsExtension{SupportedPoints: []byte{ 0x00, // pointFormatUncompressed }}, - &SessionTicketExtension{}, + sessionTicketExt, &ALPNExtension{AlpnProtocols: []string{"h2", "http/1.1"}}, &StatusRequestExtension{}, &SignatureAlgorithmsExtension{SupportedSignatureAlgorithms: []SignatureScheme{ @@ -488,7 +480,7 @@ func utlsIdToSpec(id ClientHelloID) (ClientHelloSpec, error) { &SupportedPointsExtension{SupportedPoints: []byte{ 0x00, // pointFormatUncompressed }}, - &SessionTicketExtension{}, + sessionTicketExt, &ALPNExtension{AlpnProtocols: []string{"h2", "http/1.1"}}, &StatusRequestExtension{}, &SignatureAlgorithmsExtension{SupportedSignatureAlgorithms: []SignatureScheme{ @@ -559,7 +551,7 @@ func utlsIdToSpec(id ClientHelloID) (ClientHelloSpec, error) { &SupportedPointsExtension{SupportedPoints: []byte{ 0x00, // pointFormatUncompressed }}, - &SessionTicketExtension{}, + sessionTicketExt, &ALPNExtension{AlpnProtocols: []string{"h2", "http/1.1"}}, &StatusRequestExtension{}, &SignatureAlgorithmsExtension{SupportedSignatureAlgorithms: []SignatureScheme{ @@ -632,7 +624,7 @@ func utlsIdToSpec(id ClientHelloID) (ClientHelloSpec, error) { &SupportedPointsExtension{SupportedPoints: []byte{ 0x00, // pointFormatUncompressed }}, - &SessionTicketExtension{}, + sessionTicketExt, &ALPNExtension{AlpnProtocols: []string{"h2", "http/1.1"}}, &StatusRequestExtension{}, &SignatureAlgorithmsExtension{SupportedSignatureAlgorithms: []SignatureScheme{ @@ -695,7 +687,7 @@ func utlsIdToSpec(id ClientHelloID) (ClientHelloSpec, error) { &RenegotiationInfoExtension{Renegotiation: RenegotiateOnceAsClient}, &SupportedCurvesExtension{[]CurveID{X25519, CurveP256, CurveP384, CurveP521}}, &SupportedPointsExtension{SupportedPoints: []byte{pointFormatUncompressed}}, - &SessionTicketExtension{}, + sessionTicketExt, &ALPNExtension{AlpnProtocols: []string{"h2", "http/1.1"}}, &StatusRequestExtension{}, &SignatureAlgorithmsExtension{SupportedSignatureAlgorithms: []SignatureScheme{ @@ -757,7 +749,7 @@ func utlsIdToSpec(id ClientHelloID) (ClientHelloSpec, error) { &SupportedPointsExtension{SupportedPoints: []byte{ pointFormatUncompressed, }}, - &SessionTicketExtension{}, + sessionTicketExt, &ALPNExtension{AlpnProtocols: []string{"h2", "http/1.1"}}, &StatusRequestExtension{}, &KeyShareExtension{[]KeyShare{ @@ -828,7 +820,7 @@ func utlsIdToSpec(id ClientHelloID) (ClientHelloSpec, error) { &SupportedPointsExtension{SupportedPoints: []byte{ //ec_point_formats pointFormatUncompressed, }}, - &SessionTicketExtension{}, + sessionTicketExt, &ALPNExtension{AlpnProtocols: []string{"h2", "http/1.1"}}, //application_layer_protocol_negotiation &StatusRequestExtension{}, &FakeDelegatedCredentialsExtension{ @@ -909,7 +901,7 @@ func utlsIdToSpec(id ClientHelloID) (ClientHelloSpec, error) { &SupportedPointsExtension{SupportedPoints: []byte{ //ec_point_formats pointFormatUncompressed, }}, - &SessionTicketExtension{}, + sessionTicketExt, &ALPNExtension{AlpnProtocols: []string{"h2"}}, //application_layer_protocol_negotiation &StatusRequestExtension{}, &FakeDelegatedCredentialsExtension{ @@ -994,7 +986,7 @@ func utlsIdToSpec(id ClientHelloID) (ClientHelloSpec, error) { 0x0, // uncompressed }, }, - &SessionTicketExtension{}, + sessionTicketExt, &ALPNExtension{ AlpnProtocols: []string{ "h2", @@ -1426,7 +1418,7 @@ func utlsIdToSpec(id ClientHelloID) (ClientHelloSpec, error) { 0x0, // pointFormatUncompressed }, }, - &SessionTicketExtension{}, + sessionTicketExt, &ALPNExtension{ AlpnProtocols: []string{ "h2", @@ -1531,7 +1523,7 @@ func utlsIdToSpec(id ClientHelloID) (ClientHelloSpec, error) { 0x0, // uncompressed }, }, - &SessionTicketExtension{}, + sessionTicketExt, &ALPNExtension{ AlpnProtocols: []string{ "h2", @@ -1749,7 +1741,7 @@ func utlsIdToSpec(id ClientHelloID) (ClientHelloSpec, error) { 0x0, // pointFormatUncompressed }, }, - &SessionTicketExtension{}, + sessionTicketExt, &NPNExtension{}, &ALPNExtension{ AlpnProtocols: []string{ @@ -1823,7 +1815,7 @@ func utlsIdToSpec(id ClientHelloID) (ClientHelloSpec, error) { 0x0, // uncompressed }, }, - &SessionTicketExtension{}, + sessionTicketExt, &ALPNExtension{ AlpnProtocols: []string{ "h2", @@ -1931,7 +1923,7 @@ func utlsIdToSpec(id ClientHelloID) (ClientHelloSpec, error) { 0x0, // uncompressed }, }, - &SessionTicketExtension{}, + sessionTicketExt, &ALPNExtension{ AlpnProtocols: []string{ "h2", @@ -1995,24 +1987,6 @@ func utlsIdToSpec(id ClientHelloID) (ClientHelloSpec, error) { }, }, }, nil - default: - if id.Client == helloRandomized || id.Client == helloRandomizedALPN || id.Client == helloRandomizedNoALPN { - // Use empty values as they can be filled later by UConn.ApplyPreset or manually. - return generateRandomizedSpec(&id, "", nil, nil) - } - return ClientHelloSpec{}, fmt.Errorf("%w: %s", ErrUnknownClientHelloID, id.Str()) - } -} - -func utlsIdToSpecWithPSK(id ClientHelloID, pskExtension ...PreSharedKeyExtension) (ClientHelloSpec, error) { - switch id { - case HelloChrome_100_PSK, HelloChrome_112_PSK_Shuf, HelloChrome_114_Padding_PSK_Shuf, HelloChrome_115_PQ_PSK: - if len(pskExtension) == 0 || pskExtension[0] == nil { - return ClientHelloSpec{}, fmt.Errorf("%w: %s", ErrPSKExtensionExpected, id.Str()) - } - } - - switch id { case HelloChrome_100_PSK: return ClientHelloSpec{ CipherSuites: []uint16{ @@ -2050,7 +2024,7 @@ func utlsIdToSpecWithPSK(id ClientHelloID, pskExtension ...PreSharedKeyExtension &SupportedPointsExtension{SupportedPoints: []byte{ 0x00, // pointFormatUncompressed }}, - &SessionTicketExtension{}, + sessionTicketExt, &ALPNExtension{AlpnProtocols: []string{"h2", "http/1.1"}}, &StatusRequestExtension{}, &SignatureAlgorithmsExtension{SupportedSignatureAlgorithms: []SignatureScheme{ @@ -2081,7 +2055,7 @@ func utlsIdToSpecWithPSK(id ClientHelloID, pskExtension ...PreSharedKeyExtension }}, &ApplicationSettingsExtension{SupportedProtocols: []string{"h2"}}, &UtlsGREASEExtension{}, - pskExtension[0], + pskExt, }, }, nil case HelloChrome_112_PSK_Shuf: @@ -2121,7 +2095,7 @@ func utlsIdToSpecWithPSK(id ClientHelloID, pskExtension ...PreSharedKeyExtension &SupportedPointsExtension{SupportedPoints: []byte{ 0x00, // pointFormatUncompressed }}, - &SessionTicketExtension{}, + sessionTicketExt, &ALPNExtension{AlpnProtocols: []string{"h2", "http/1.1"}}, &StatusRequestExtension{}, &SignatureAlgorithmsExtension{SupportedSignatureAlgorithms: []SignatureScheme{ @@ -2152,7 +2126,7 @@ func utlsIdToSpecWithPSK(id ClientHelloID, pskExtension ...PreSharedKeyExtension }}, &ApplicationSettingsExtension{SupportedProtocols: []string{"h2"}}, &UtlsGREASEExtension{}, - pskExtension[0], + pskExt, }), }, nil case HelloChrome_114_Padding_PSK_Shuf: @@ -2192,7 +2166,7 @@ func utlsIdToSpecWithPSK(id ClientHelloID, pskExtension ...PreSharedKeyExtension &SupportedPointsExtension{SupportedPoints: []byte{ 0x00, // pointFormatUncompressed }}, - &SessionTicketExtension{}, + sessionTicketExt, &ALPNExtension{AlpnProtocols: []string{"h2", "http/1.1"}}, &StatusRequestExtension{}, &SignatureAlgorithmsExtension{SupportedSignatureAlgorithms: []SignatureScheme{ @@ -2224,7 +2198,7 @@ func utlsIdToSpecWithPSK(id ClientHelloID, pskExtension ...PreSharedKeyExtension &ApplicationSettingsExtension{SupportedProtocols: []string{"h2"}}, &UtlsGREASEExtension{}, &UtlsPaddingExtension{GetPaddingLen: BoringPaddingStyle}, - pskExtension[0], + pskExt, }), }, nil // Chrome w/ Post-Quantum Key Agreement @@ -2266,7 +2240,7 @@ func utlsIdToSpecWithPSK(id ClientHelloID, pskExtension ...PreSharedKeyExtension &SupportedPointsExtension{SupportedPoints: []byte{ 0x00, // pointFormatUncompressed }}, - &SessionTicketExtension{}, + sessionTicketExt, &ALPNExtension{AlpnProtocols: []string{"h2", "http/1.1"}}, &StatusRequestExtension{}, &SignatureAlgorithmsExtension{SupportedSignatureAlgorithms: []SignatureScheme{ @@ -2298,12 +2272,17 @@ func utlsIdToSpecWithPSK(id ClientHelloID, pskExtension ...PreSharedKeyExtension }}, &ApplicationSettingsExtension{SupportedProtocols: []string{"h2"}}, &UtlsGREASEExtension{}, - pskExtension[0], + pskExt, }), }, nil - } + default: + if id.Client == helloRandomized || id.Client == helloRandomizedALPN || id.Client == helloRandomizedNoALPN { + // Use empty values as they can be filled later by UConn.ApplyPreset or manually. + return generateRandomizedSpec(&id, "", nil) + } - return ClientHelloSpec{}, fmt.Errorf("%w: %s", ErrUnknownClientHelloID, id.Str()) + return ClientHelloSpec{}, fmt.Errorf("%w: %s", ErrUnknownClientHelloID, id.Str()) + } } // ShuffleChromeTLSExtensions shuffles the extensions in the ClientHelloSpec to avoid ossification. @@ -2345,9 +2324,8 @@ func (uconn *UConn) applyPresetByID(id ClientHelloID) (err error) { } case helloCustom: return nil - default: - spec, err = UTLSIdToSpec(id, uconn.pskExtension...) + spec, err = UTLSIdToSpec(id, uconn.sessionController.pskExtension, uconn.sessionController.sessionTicketExt) if err != nil { return err } @@ -2379,7 +2357,6 @@ func (uconn *UConn) ApplyPreset(p *ClientHelloSpec) error { } uconn.HandshakeState.State13.KeySharesParams = NewKeySharesParameters() hello := uconn.HandshakeState.Hello - session := uconn.HandshakeState.Session switch len(hello.Random) { case 0: @@ -2420,7 +2397,12 @@ func (uconn *UConn) ApplyPreset(p *ClientHelloSpec) error { hello.CipherSuites[i] = GetBoringGREASEValue(uconn.greaseSeed, ssl_grease_cipher) } } - uconn.GetSessionID = p.GetSessionID + var sessionID [32]byte + _, err = io.ReadFull(uconn.config.rand(), sessionID[:]) + if err != nil { + return err + } + uconn.HandshakeState.Hello.SessionId = sessionID[:] uconn.Extensions = make([]TLSExtension, len(p.Extensions)) copy(uconn.Extensions, p.Extensions) @@ -2445,23 +2427,6 @@ func (uconn *UConn) ApplyPreset(p *ClientHelloSpec) error { return errors.New("at most 2 grease extensions are supported") } grease_extensions_seen += 1 - case *SessionTicketExtension: - var cs *ClientSessionState - if session == nil && uconn.config.ClientSessionCache != nil { - cacheKey := uconn.clientSessionCacheKey() - cs, _ = uconn.config.ClientSessionCache.Get(cacheKey) - if cs != nil { - session = cs.session - } - } - // TLS 1.3 (PSK) resumption is handled by PreSharedKeyExtension in MarshalClientHello() - if session != nil && session.version == VersionTLS13 { - break - } - err := uconn.SetSessionState(cs) - if err != nil { - return err - } case *SupportedCurvesExtension: for i := range ext.Curves { if isGREASEUint16(uint16(ext.Curves[i])) { @@ -2528,22 +2493,18 @@ func (uconn *UConn) ApplyPreset(p *ClientHelloSpec) error { // but NextProtos is also used by ALPN and our spec nmay not actually have a NPN extension hello.NextProtoNeg = haveNPN + uconn.sessionController.checkSessionExt() + return nil } func (uconn *UConn) generateRandomizedSpec() (ClientHelloSpec, error) { - css := &ClientSessionState{ - session: uconn.HandshakeState.Session, - ticket: uconn.HandshakeState.Hello.SessionTicket, - } - - return generateRandomizedSpec(&uconn.ClientHelloID, uconn.serverName, css, uconn.config.NextProtos) + return generateRandomizedSpec(&uconn.ClientHelloID, uconn.serverName, uconn.config.NextProtos) } func generateRandomizedSpec( id *ClientHelloID, serverName string, - session *ClientSessionState, nextProtos []string, ) (ClientHelloSpec, error) { p := ClientHelloSpec{} @@ -2609,7 +2570,7 @@ func generateRandomizedSpec( p.CipherSuites = removeRandomCiphers(r, shuffledSuites, id.Weights.CipherSuites_Remove_RandomCiphers) sni := SNIExtension{serverName} - sessionTicket := SessionTicketExtension{Session: session} + sessionTicket := SessionTicketExtension{} sigAndHashAlgos := []SignatureScheme{ ECDSAWithP256AndSHA256, diff --git a/u_pre_shared_key.go b/u_pre_shared_key.go index c3eb580..948d3fb 100644 --- a/u_pre_shared_key.go +++ b/u_pre_shared_key.go @@ -8,26 +8,38 @@ import ( "golang.org/x/crypto/cryptobyte" ) +type PreSharedKeyCommon struct { + Identities []PskIdentity + Binders [][]byte + BinderKey []byte // this will be used to compute the binder when hello message is ready + EarlySecret []byte + Session *SessionState +} + type PreSharedKeyExtension interface { // TLSExtension must be implemented by all PreSharedKeyExtension implementations. // However, the Read() method should return an error since it MUST NOT be used // for PreSharedKeyExtension. TLSExtension + IsInitialized() bool + + InitializeByUtls(session *SessionState, earlySecret []byte, binderKey []byte, identities []PskIdentity) + // GetBinders returns the binders that were computed during the handshake // to be set in the internal copy of the ClientHello. Only needed if expecting // to resume the session. // // FakePreSharedKeyExtension MUST return nil to make sure utls DOES NOT // try to do any session resumption. - GetBinders() [][]byte + GetPreSharedKeyCommon() PreSharedKeyCommon // ReadWithRawHello is used to read the extension from the ClientHello // instead of Read(), where the latter is used to read all other extensions. // // This is needed because the PSK extension needs to calculate the binder // based on all previous parts of the ClientHello. - ReadWithRawHello(raw, b []byte) (int, error) + PatchBuiltHello(hello *PubClientHelloMsg) error mustEmbedUnimplementedPreSharedKeyExtension() // this works like a type guard } @@ -36,8 +48,16 @@ type UnimplementedPreSharedKeyExtension struct{} func (UnimplementedPreSharedKeyExtension) mustEmbedUnimplementedPreSharedKeyExtension() {} +func (*UnimplementedPreSharedKeyExtension) IsInitialized() bool { + panic("tls: IsInitialized is not implemented for the PreSharedKeyExtension") +} + +func (*UnimplementedPreSharedKeyExtension) InitializeByUtls(session *SessionState, earlySecret []byte, binderKey []byte, identities []PskIdentity) { + panic("tls: Initialize is not implemented for the PreSharedKeyExtension") +} + func (*UnimplementedPreSharedKeyExtension) writeToUConn(*UConn) error { - return errors.New("tls: writeToUConn is not implemented for the PreSharedKeyExtension") + panic("tls: writeToUConn is not implemented for the PreSharedKeyExtension") } func (*UnimplementedPreSharedKeyExtension) Len() int { @@ -45,108 +65,120 @@ func (*UnimplementedPreSharedKeyExtension) Len() int { } func (*UnimplementedPreSharedKeyExtension) Read([]byte) (int, error) { - return 0, errors.New("tls: Read is not implemented for the PreSharedKeyExtension") + panic("tls: Read is not implemented for the PreSharedKeyExtension") } -func (*UnimplementedPreSharedKeyExtension) GetBinders() [][]byte { - panic("tls: Binders is not implemented for the PreSharedKeyExtension") +func (*UnimplementedPreSharedKeyExtension) GetPreSharedKeyCommon() PreSharedKeyCommon { + panic("tls: GetPreSharedKeyCommon is not implemented for the PreSharedKeyExtension") } -func (*UnimplementedPreSharedKeyExtension) ReadWithRawHello(raw, b []byte) (int, error) { - return 0, errors.New("tls: ReadWithRawHello is not implemented for the PreSharedKeyExtension") +func (*UnimplementedPreSharedKeyExtension) PatchBuiltHello(hello *PubClientHelloMsg) error { + panic("tls: ReadWithRawHello is not implemented for the PreSharedKeyExtension") } // UtlsPreSharedKeyExtension is an extension used to set the PSK extension in the // ClientHello. type UtlsPreSharedKeyExtension struct { UnimplementedPreSharedKeyExtension + PreSharedKeyCommon + cipherSuite *cipherSuiteTLS13 + cachedLength *int +} - identities []pskIdentity - binders [][]byte - binderKey []byte // this will be used to compute the binder when hello message is ready - cipherSuite *cipherSuiteTLS13 - earlySecret []byte +func (e *UtlsPreSharedKeyExtension) IsInitialized() bool { + return e.Session != nil +} + +func (e *UtlsPreSharedKeyExtension) InitializeByUtls(session *SessionState, earlySecret []byte, binderKey []byte, identities []PskIdentity) { + e.Session = session + e.EarlySecret = earlySecret + e.BinderKey = binderKey + e.cipherSuite = cipherSuiteTLS13ByID(e.Session.cipherSuite) + e.Identities = identities + e.Binders = make([][]byte, 0, len(e.Identities)) + for i := 0; i < len(e.Identities); i++ { + e.Binders = append(e.Binders, make([]byte, e.cipherSuite.hash.Size())) + } } func (e *UtlsPreSharedKeyExtension) writeToUConn(uc *UConn) error { - err := e.preloadSession(uc) - if err != nil { - return err - } - - uc.HandshakeState.Hello.PskIdentities = pskIdentities(e.identities).ToPublic() - // uc.HandshakeState.Hello.PskBinders = e.binders - // uc.HandshakeState.Hello = hello.getPublicPtr() // write back to public hello - // uc.HandshakeState.State13.EarlySecret = e.earlySecret - // uc.HandshakeState.State13.BinderKey = e.binderKey - + uc.HandshakeState.Hello.TicketSupported = true // This doesn't matter though, as utls doesn't care about this field. We write this for consistency. return nil } -func (e *UtlsPreSharedKeyExtension) Len() int { +func (e *UtlsPreSharedKeyExtension) GetPreSharedKeyCommon() PreSharedKeyCommon { + return e.PreSharedKeyCommon +} + +func pskExtLen(identities []PskIdentity, binders [][]byte) int { + if len(identities) == 0 || len(binders) == 0 { + return 0 + } length := 4 // extension type + extension length length += 2 // identities length - for _, identity := range e.identities { - length += 2 + len(identity.label) + 4 // identity length + identity + obfuscated ticket age + for _, identity := range identities { + length += 2 + len(identity.Label) + 4 // identity length + identity + obfuscated ticket age } length += 2 // binders length - for _, binder := range e.binders { - length += len(binder) + 1 // binder length + binder + for _, binder := range binders { + length += len(binder) + 1 } return length } -func (e *UtlsPreSharedKeyExtension) Read(b []byte) (int, error) { - return 0, errors.New("tls: PreSharedKeyExtension shouldn't be read, use ReadWithRawHello() instead") +func (e *UtlsPreSharedKeyExtension) Len() int { + if e.Session == nil { + return 0 + } + if e.cachedLength != nil { + return *e.cachedLength + } + length := pskExtLen(e.Identities, e.Binders) + e.cachedLength = &length + return length } -// Binders must be called after ReadWithRawHello -func (e *UtlsPreSharedKeyExtension) GetBinders() [][]byte { - return e.binders -} - -func (e *UtlsPreSharedKeyExtension) ReadWithRawHello(raw, b []byte) (int, error) { - if len(b) < e.Len() { +func readPskIntoBytes(b []byte, identities []PskIdentity, binders [][]byte) (int, error) { + extLen := pskExtLen(identities, binders) + if extLen == 0 { + return 0, io.EOF + } + if len(b) < extLen { return 0, io.ErrShortBuffer } b[0] = byte(extensionPreSharedKey >> 8) b[1] = byte(extensionPreSharedKey) - b[2] = byte((e.Len() - 4) >> 8) - b[3] = byte(e.Len() - 4) + b[2] = byte((extLen - 4) >> 8) + b[3] = byte(extLen - 4) // identities length identitiesLength := 0 - for _, identity := range e.identities { - identitiesLength += 2 + len(identity.label) + 4 // identity length + identity + obfuscated ticket age + for _, identity := range identities { + identitiesLength += 2 + len(identity.Label) + 4 // identity length + identity + obfuscated ticket age } b[4] = byte(identitiesLength >> 8) b[5] = byte(identitiesLength) // identities offset := 6 - for _, identity := range e.identities { - b[offset] = byte(len(identity.label) >> 8) - b[offset+1] = byte(len(identity.label)) + for _, identity := range identities { + b[offset] = byte(len(identity.Label) >> 8) + b[offset+1] = byte(len(identity.Label)) offset += 2 - copy(b[offset:], identity.label) - offset += len(identity.label) - b[offset] = byte(identity.obfuscatedTicketAge >> 24) - b[offset+1] = byte(identity.obfuscatedTicketAge >> 16) - b[offset+2] = byte(identity.obfuscatedTicketAge >> 8) - b[offset+3] = byte(identity.obfuscatedTicketAge) + copy(b[offset:], identity.Label) + offset += len(identity.Label) + b[offset] = byte(identity.ObfuscatedTicketAge >> 24) + b[offset+1] = byte(identity.ObfuscatedTicketAge >> 16) + b[offset+2] = byte(identity.ObfuscatedTicketAge >> 8) + b[offset+3] = byte(identity.ObfuscatedTicketAge) offset += 4 } - // concatenate ClientHello and PreSharedKeyExtension - rawHelloSoFar := append(raw, b[:offset]...) - transcript := e.cipherSuite.hash.New() - transcript.Write(rawHelloSoFar) - e.binders = [][]byte{e.cipherSuite.finishedHash(e.binderKey, transcript)} - // binders length bindersLength := 0 - for _, binder := range e.binders { + for _, binder := range binders { + // check if binder size is valid bindersLength += len(binder) + 1 // binder length + binder } b[offset] = byte(bindersLength >> 8) @@ -154,39 +186,49 @@ func (e *UtlsPreSharedKeyExtension) ReadWithRawHello(raw, b []byte) (int, error) offset += 2 // binders - for _, binder := range e.binders { + for _, binder := range binders { b[offset] = byte(len(binder)) offset++ copy(b[offset:], binder) offset += len(binder) } - return e.Len(), io.EOF + return extLen, io.EOF } -func (e *UtlsPreSharedKeyExtension) preloadSession(uc *UConn) error { - // load Hello - hello := uc.HandshakeState.Hello.getPrivatePtr() - // try to use loadSession() - session, earlySecret, binderKey, err := uc.loadSession(hello) +func (e *UtlsPreSharedKeyExtension) Read(b []byte) (int, error) { + return readPskIntoBytes(b, e.Identities, e.Binders) +} + +func (e *UtlsPreSharedKeyExtension) PatchBuiltHello(hello *PubClientHelloMsg) error { + if e.Len() == 0 { + return nil + } + private := hello.getCachedPrivatePtr() + if private == nil { + private = hello.getPrivatePtr() + } + private.raw = hello.Raw + private.pskBinders = e.Binders // set the placeholder to the private Hello + + //--- mirror loadSession() begin ---// + transcript := e.cipherSuite.hash.New() + helloBytes, err := private.marshalWithoutBinders() // no marshal() will be actually called, as we have set the field `raw` if err != nil { return err } - if session != nil && session.version == VersionTLS13 && binderKey != nil { - e.identities = hello.pskIdentities - e.binders = hello.pskBinders - e.binderKey = binderKey - e.cipherSuite = cipherSuiteTLS13ByID(session.cipherSuite) - e.earlySecret = earlySecret - } else if session == nil { - return errors.New("tls: session not found.") - } else if session.version != VersionTLS13 { - return errors.New("tls: session is not for TLS 1.3.") - } else if binderKey == nil { - return errors.New("tls: binder key not found.") - } + transcript.Write(helloBytes) + pskBinders := [][]byte{e.cipherSuite.finishedHash(e.BinderKey, transcript)} - return nil + if err := private.updateBinders(pskBinders); err != nil { + return err + } + //--- mirror loadSession() end ---// + e.Binders = pskBinders + + // no need to care about other PSK related fields, they will be handled separately + + return io.EOF } func (e *UtlsPreSharedKeyExtension) Write(b []byte) (int, error) { @@ -212,6 +254,14 @@ type FakePreSharedKeyExtension struct { Binders [][]byte `json:"binders"` } +func (e *FakePreSharedKeyExtension) IsInitialized() bool { + return e.Identities != nil && e.Binders != nil +} + +func (e *FakePreSharedKeyExtension) InitializeByUtls(session *SessionState, earlySecret []byte, binderKey []byte, identities []PskIdentity) { + panic("InitializeByUtls failed: don't let utls initialize FakePreSharedKeyExtension; provide your own identities and binders or use UtlsPreSharedKeyExtension") +} + func (e *FakePreSharedKeyExtension) writeToUConn(uc *UConn) error { if uc.config.ClientSessionCache == nil { return nil // don't write the extension if there is no session cache @@ -225,85 +275,33 @@ func (e *FakePreSharedKeyExtension) writeToUConn(uc *UConn) error { } func (e *FakePreSharedKeyExtension) Len() int { - length := 4 // extension type + extension length - length += 2 // identities length - for _, identity := range e.Identities { - length += 2 + len(identity.Label) + 4 // identity length + identity + obfuscated ticket age - } - length += 2 // binders length - for _, binder := range e.Binders { - length += len(binder) - } - return length + return pskExtLen(e.Identities, e.Binders) } func (e *FakePreSharedKeyExtension) Read(b []byte) (int, error) { - return 0, errors.New("tls: PreSharedKeyExtension shouldn't be read, use ReadWithRawHello() instead") -} - -func (e *FakePreSharedKeyExtension) GetBinders() [][]byte { - return nil -} - -func (e *FakePreSharedKeyExtension) ReadWithRawHello(raw, b []byte) (int, error) { - if len(b) < e.Len() { - return 0, io.ErrShortBuffer - } - - b[0] = byte(extensionPreSharedKey >> 8) - b[1] = byte(extensionPreSharedKey) - b[2] = byte((e.Len() - 4) >> 8) - b[3] = byte(e.Len() - 4) - - // identities length - identitiesLength := 0 - for _, identity := range e.Identities { - identitiesLength += 2 + len(identity.Label) + 4 // identity length + identity + obfuscated ticket age - } - b[4] = byte(identitiesLength >> 8) - b[5] = byte(identitiesLength) - - // identities - offset := 6 - for _, identity := range e.Identities { - b[offset] = byte(len(identity.Label) >> 8) - b[offset+1] = byte(len(identity.Label)) - offset += 2 - copy(b[offset:], identity.Label) - offset += len(identity.Label) - b[offset] = byte(identity.ObfuscatedTicketAge >> 24) - b[offset+1] = byte(identity.ObfuscatedTicketAge >> 16) - b[offset+2] = byte(identity.ObfuscatedTicketAge >> 8) - b[offset+3] = byte(identity.ObfuscatedTicketAge) - offset += 4 - } - - // binders length - bindersLength := 0 -LOOP_BINDERS: - for _, binder := range e.Binders { - // check if binder size is valid - for _, cipherSuite := range cipherSuitesTLS13 { - if len(binder) == cipherSuite.hash.Size() { - bindersLength += len(binder) + 1 // binder length + binder - continue LOOP_BINDERS - } + for _, b := range e.Binders { + if !(anyTrue(validHashLen, func(valid *int) bool { + return len(b) == *valid + })) { + return 0, errors.New("tls: FakePreSharedKeyExtension.Read failed: invalid binder size") } - return 0, errors.New("tls: invalid binder size") } - b[offset] = byte(bindersLength >> 8) - b[offset+1] = byte(bindersLength) - offset += 2 + return readPskIntoBytes(b, e.Identities, e.Binders) +} - // binders - for _, binder := range e.Binders { - b[offset] = byte(len(binder)) - offset++ - copy(b[offset:], binder) - offset += len(binder) +func (e *FakePreSharedKeyExtension) GetPreSharedKeyCommon() PreSharedKeyCommon { + return PreSharedKeyCommon{ + Identities: e.Identities, + Binders: e.Binders, } +} - return e.Len(), io.EOF +var validHashLen = mapSlice(cipherSuitesTLS13, func(c *cipherSuiteTLS13) int { + return c.hash.Size() +}) + +func (*FakePreSharedKeyExtension) PatchBuiltHello(*PubClientHelloMsg) error { + return nil // no need to patch the hello since we don't need to update binders } func (e *FakePreSharedKeyExtension) Write(b []byte) (n int, err error) { diff --git a/u_public.go b/u_public.go index 55caa43..419b679 100644 --- a/u_public.go +++ b/u_public.go @@ -45,7 +45,7 @@ type TLS13OnlyState struct { EarlySecret []byte BinderKey []byte CertReq *CertificateRequestMsgTLS13 - UsingPSK bool + UsingPSK bool // don't set this field when building client hello SentDummyCCS bool Transcript hash.Hash TrafficSecret []byte // client_application_traffic_secret_0 @@ -251,7 +251,7 @@ type PubServerHelloMsg struct { OcspStapling bool Scts [][]byte ExtendedMasterSecret bool - TicketSupported bool + TicketSupported bool // used by go tls to determine whether to add the session ticket ext SecureRenegotiation []byte SecureRenegotiationSupported bool AlpnProtocol string @@ -357,13 +357,15 @@ type PubClientHelloMsg struct { PskIdentities []PskIdentity PskBinders [][]byte QuicTransportParameters []byte + + cachedPrivateHello *clientHelloMsg // todo: further optimize to reduce clientHelloMsg construction } func (chm *PubClientHelloMsg) getPrivatePtr() *clientHelloMsg { if chm == nil { return nil } else { - return &clientHelloMsg{ + private := &clientHelloMsg{ raw: chm.Raw, vers: chm.Vers, random: chm.Random, @@ -395,6 +397,16 @@ func (chm *PubClientHelloMsg) getPrivatePtr() *clientHelloMsg { nextProtoNeg: chm.NextProtoNeg, } + chm.cachedPrivateHello = private + return private + } +} + +func (chm *PubClientHelloMsg) getCachedPrivatePtr() *clientHelloMsg { + if chm == nil { + return nil + } else { + return chm.cachedPrivateHello } } @@ -432,6 +444,7 @@ func (chm *clientHelloMsg) getPublicPtr() *PubClientHelloMsg { PskIdentities: pskIdentities(chm.pskIdentities).ToPublic(), PskBinders: chm.pskBinders, QuicTransportParameters: chm.quicTransportParameters, + cachedPrivateHello: chm, } } } diff --git a/u_session_controller.go b/u_session_controller.go new file mode 100644 index 0000000..b4f0841 --- /dev/null +++ b/u_session_controller.go @@ -0,0 +1,272 @@ +package tls + +import "fmt" + +type LoadSessionTrackerState int + +const NeverCalled LoadSessionTrackerState = 0 +const UtlsAboutToCall LoadSessionTrackerState = 3 +const CalledByULoadSession LoadSessionTrackerState = 1 +const CalledByGoTLS LoadSessionTrackerState = 2 + +type sessionState int + +const NoSession sessionState = 0 +const TicketInitialized sessionState = 1 +const TicketAllSet sessionState = 4 +const PskExtInitialized sessionState = 2 +const PskAllSet sessionState = 3 + +// sessionController is responsible for all session related +type sessionController struct { + sessionTicketExt *SessionTicketExtension + pskExtension PreSharedKeyExtension + uconnRef *UConn + state sessionState + loadSessionTracker LoadSessionTrackerState + callingLoadSession bool + locked bool +} + +type shouldLoadSessionResult int + +const shouldReturn shouldLoadSessionResult = 0 +const shouldSetTicket shouldLoadSessionResult = 1 +const shouldSetPsk shouldLoadSessionResult = 2 +const shouldLoad shouldLoadSessionResult = 3 + +func newSessionController(uconn *UConn) *sessionController { + return &sessionController{ + uconnRef: uconn, + sessionTicketExt: &SessionTicketExtension{}, + pskExtension: &UtlsPreSharedKeyExtension{}, + state: NoSession, + locked: false, + callingLoadSession: false, + loadSessionTracker: NeverCalled, + } +} + +func (s *sessionController) isSessionLocked() bool { + return s.locked +} + +func (s *sessionController) shouldLoadSession() shouldLoadSessionResult { + if s.sessionTicketExt == nil && s.pskExtension == nil || s.uconnRef.clientHelloBuildStatus != NotBuilt { + fmt.Println("uLoadSession s.sessionTicketExt == nil && s.pskExtension == nil") + // There's no need to load session since we don't have the related extensions. + return shouldReturn + } + if s.state == TicketInitialized { + return shouldSetTicket + } + if s.state == PskExtInitialized { + return shouldSetPsk + } + return shouldLoad +} + +func (s *sessionController) aboutToLoadSession() { + uAssert(s.state == NoSession && !s.locked, "tls: aboutToLoadSession failed: must only load session when the session of the client hello is not locked and when there's currently no session") + s.loadSessionTracker = UtlsAboutToCall +} + +func (s *sessionController) commonCheck(failureMsg string, params ...any) { + if s.uconnRef.clientHelloBuildStatus != NotBuilt { + panic(failureMsg + ": we can't modify the session after the clientHello is built") + } + if s.state != NoSession { + panic(failureMsg + ": the session already set") + } + panicOnNil(failureMsg, params...) +} + +func (s *sessionController) finalCheck() { + uAssert(s.state == PskAllSet || s.state == TicketAllSet || s.state == NoSession, "tls: SessionController.finalCheck failed: the session is half set") + s.locked = true +} + +func (s *sessionController) initSessionTicketExt(session *SessionState, ticket []byte) { + s.commonCheck("tls: initSessionTicket failed", s.sessionTicketExt, session, ticket) + s.sessionTicketExt.Session = session + s.sessionTicketExt.Ticket = ticket + s.state = TicketInitialized +} + +func (s *sessionController) setSessionTicketToUConn() { + uAssert(s.sessionTicketExt != nil && s.state == TicketInitialized, "tls: setSessionTicketExt failed: invalid state") + s.uconnRef.HandshakeState.Session = s.sessionTicketExt.Session + s.uconnRef.HandshakeState.Hello.SessionTicket = s.sessionTicketExt.Ticket + s.state = TicketAllSet +} + +func mapSlice[T any, U any](slice []T, transform func(T) U) []U { + newSlice := make([]U, 0, len(slice)) + for _, t := range slice { + newSlice = append(newSlice, transform(t)) + } + return newSlice +} + +func (s *sessionController) initPsk(session *SessionState, earlySecret []byte, binderKey []byte, pskIdentities []pskIdentity) { + s.commonCheck("tls: initPsk failed", s.pskExtension, session, earlySecret, pskIdentities) + uAssert(!s.pskExtension.IsInitialized(), "tls: initPsk failed: the psk extension is already initialized") + + publicPskIdentities := mapSlice(pskIdentities, func(private pskIdentity) PskIdentity { + return PskIdentity{ + Label: private.label, + ObfuscatedTicketAge: private.obfuscatedTicketAge, + } + }) + s.pskExtension.InitializeByUtls(session, earlySecret, binderKey, publicPskIdentities) + uAssert(s.pskExtension.IsInitialized(), "the psk extension is not initialized after initialization") + s.uconnRef.HandshakeState.State13.BinderKey = binderKey + s.uconnRef.HandshakeState.State13.EarlySecret = earlySecret + s.uconnRef.HandshakeState.Session = session + s.uconnRef.HandshakeState.Hello.PskIdentities = publicPskIdentities + // binders are not expected to be available at this point + s.state = PskExtInitialized +} + +func (s *sessionController) setPsk() { + uAssert(s.pskExtension != nil && (s.state == PskExtInitialized || s.state == PskAllSet), "tls: setPsk failed: invalid state") + pskCommon := s.pskExtension.GetPreSharedKeyCommon() + if s.state == PskExtInitialized { + s.uconnRef.HandshakeState.State13.EarlySecret = pskCommon.EarlySecret + s.uconnRef.HandshakeState.Session = pskCommon.Session + s.uconnRef.HandshakeState.Hello.PskIdentities = pskCommon.Identities + s.uconnRef.HandshakeState.Hello.PskBinders = pskCommon.Binders + } else if s.state == PskAllSet { + uAssert(sliceEq([]any{ + s.uconnRef.HandshakeState.State13.EarlySecret, + s.uconnRef.HandshakeState.Session, + s.uconnRef.HandshakeState.Hello.PskIdentities, + s.uconnRef.HandshakeState.Hello.PskBinders, + }, []any{ + pskCommon.EarlySecret, + pskCommon.Session, + pskCommon.Identities, + pskCommon.Binders, + }), "setPsk failed: only binders are allowed to change on state `PskAllSet`") + } + s.uconnRef.HandshakeState.State13.BinderKey = pskCommon.BinderKey + s.state = PskAllSet +} + +func (s *sessionController) shouldUpdateBinders() bool { + if s.pskExtension == nil { + return false + } + return s.state == PskExtInitialized || s.state == PskAllSet +} + +func (s *sessionController) updateBinders() { + uAssert(s.shouldUpdateBinders(), "tls: updateBinders failed: shouldn't update binders") + s.pskExtension.PatchBuiltHello(s.uconnRef.HandshakeState.Hello) +} + +func (s *sessionController) overridePskExt(psk PreSharedKeyExtension) error { + if s.state != NoSession { + return fmt.Errorf("SetSessionState13 failed: there's already a session") + } + s.pskExtension = psk + if psk.IsInitialized() { + s.state = PskExtInitialized + } + return nil +} + +var customizedHellos = []ClientHelloID{ + HelloCustom, + HelloRandomized, + HelloRandomizedALPN, + HelloRandomizedNoALPN, +} + +func (s *sessionController) checkSessionExt() { + uAssert(s.uconnRef.clientHelloBuildStatus == NotBuilt, "tls: checkSessionExt failed: we can't modify the session after the clientHello is built") + numSessionExt := 0 + hasPskExt := false + for i, e := range s.uconnRef.Extensions { + switch ext := e.(type) { + case *SessionTicketExtension: + if ext != s.uconnRef.sessionController.sessionTicketExt { + if anyTrue(customizedHellos, func(h *ClientHelloID) bool { + return s.uconnRef.ClientHelloID.Client == h.Client + }) { + s.uconnRef.Extensions[i] = s.uconnRef.sessionController.sessionTicketExt + } else { + panic(fmt.Sprintf("tls: checkSessionExt failed: sessionTicketExtShortcut != SessionTicketExtension from the extension list and the clientHello is build from presets: [%v]", s.uconnRef.ClientHelloID)) + } + } + numSessionExt += 1 + case PreSharedKeyExtension: + uAssert(i == len(s.uconnRef.Extensions)-1, "tls: checkSessionExt failed: PreSharedKeyExtension must be the last extension") + if ext != s.uconnRef.sessionController.pskExtension { + if anyTrue(customizedHellos, func(h *ClientHelloID) bool { + return s.uconnRef.ClientHelloID.Client == h.Client + }) { + s.uconnRef.Extensions[i] = s.uconnRef.sessionController.pskExtension + } else { + panic(fmt.Sprintf("tls: checkSessionExt failed: pskExtensionShortcut != PreSharedKeyExtension from the extension list and the clientHello is build from presets: [%v]", s.uconnRef.ClientHelloID)) + } + } + hasPskExt = true + } + } + if !(s.state == NoSession || s.state == TicketInitialized || s.state == PskExtInitialized) { + panic(fmt.Sprintf("tls: checkSessionExt failed: can't remove session ticket extension; the session ticket extension is unused, but the internal state is: %d", s.state)) + } + if numSessionExt == 0 { + s.sessionTicketExt = nil + s.uconnRef.HandshakeState.Session = nil + s.uconnRef.HandshakeState.Hello.SessionTicket = nil + } else if numSessionExt > 1 { + panic("checkSessionExt failed: multiple session ticket extensions in the extension list") + } + if !hasPskExt { + s.pskExtension = nil + s.uconnRef.HandshakeState.State13.BinderKey = nil + s.uconnRef.HandshakeState.State13.EarlySecret = nil + s.uconnRef.HandshakeState.Session = nil + s.uconnRef.HandshakeState.Hello.PskIdentities = nil + } +} + +func (s *sessionController) onEnterLoadSessionCheck() { + uAssert(!s.locked, "tls: LoadSessionCoordinator.onEnterLoadSessionCheck failed: session is set and locked, no call to loadSession is allowed") + switch s.loadSessionTracker { + case UtlsAboutToCall, NeverCalled: + s.callingLoadSession = true + case CalledByULoadSession, CalledByGoTLS: + panic("tls: LoadSessionCoordinator.onEnterLoadSessionCheck failed: you must not call loadSession() twice") + default: + panic("tls: LoadSessionCoordinator.onEnterLoadSessionCheck failed: unimplemented state") + } +} + +func (s *sessionController) onLoadSessionReturn() { + uAssert(s.callingLoadSession, "tls: LoadSessionCoordinator.onLoadSessionReturn failed: it's not loading sessions, perhaps this function is not being called by loadSession.") + switch s.loadSessionTracker { + case NeverCalled: + s.loadSessionTracker = CalledByGoTLS + case UtlsAboutToCall: + s.loadSessionTracker = CalledByULoadSession + default: + panic("tls: LoadSessionCoordinator.onLoadSessionReturn failed: unimplemented state") + } + s.callingLoadSession = false +} + +func (s *sessionController) shouldWriteBinders() bool { + uAssert(s.callingLoadSession, "tls: shouldWriteBinders failed: LoadSessionCoordinator isn't loading sessions, perhaps this function is not being called by loadSession.") + + switch s.loadSessionTracker { + case NeverCalled: + return true + case UtlsAboutToCall: + return false + default: + panic("tls: shouldWriteBinders failed: unimplemented state") + } +} diff --git a/u_tls_extensions.go b/u_tls_extensions.go index 90f6969..4072d35 100644 --- a/u_tls_extensions.go +++ b/u_tls_extensions.go @@ -802,22 +802,19 @@ func (e *SCTExtension) Write(_ []byte) (int, error) { // SessionTicketExtension implements session_ticket (35) type SessionTicketExtension struct { - Session *ClientSessionState + Session *SessionState + Ticket []byte } func (e *SessionTicketExtension) writeToUConn(uc *UConn) error { - if e.Session != nil { - uc.HandshakeState.Session = e.Session.session - uc.HandshakeState.Hello.SessionTicket = e.Session.ticket - } + // session states are handled later. At this point tickets aren't + // being loaded by utls, so don't write anything to the UConn. + uc.HandshakeState.Hello.TicketSupported = true // This doesn't really matter, this field is only used to add session ticket ext in go tls. return nil } func (e *SessionTicketExtension) Len() int { - if e.Session != nil { - return 4 + len(e.Session.ticket) - } - return 4 + return 4 + len(e.Ticket) } func (e *SessionTicketExtension) Read(b []byte) (int, error) { @@ -832,7 +829,7 @@ func (e *SessionTicketExtension) Read(b []byte) (int, error) { b[2] = byte(extBodyLen >> 8) b[3] = byte(extBodyLen) if extBodyLen > 0 { - copy(b[4:], e.Session.ticket) + copy(b[4:], e.Ticket) } return e.Len(), io.EOF }