feat: bug fix and refactor

This commit is contained in:
3andne 2023-08-20 18:53:04 -07:00
parent 3162534cc7
commit a040a404e6
13 changed files with 675 additions and 349 deletions

View file

@ -157,7 +157,7 @@ Currently, there is a simple function to set session ticket to any desired state
```Golang
// If you want you session tickets to be reused - use same cache on following connections
func (uconn *UConn) SetSessionState(session *ClientSessionState)
func (uconn *UConn) SetSessionState12(session *ClientSessionState)
```
Note that session tickets (fake ones or otherwise) are not reused.
@ -294,7 +294,7 @@ Some customizations(such as setting session ticket/clientHello) have easy-to-use
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
masterSecret,
nil, nil)
tlsConn.SetSessionState(sessionState)
tlsConn.SetSessionState12(sessionState)
```
For other customizations there are following functions

View file

@ -140,7 +140,7 @@ func HttpGetTicket(hostname string, addr string) (*http.Response, error) {
masterSecret,
nil, nil)
err = uTlsConn.SetSessionState(sessionState)
err = uTlsConn.SetSessionState12(sessionState)
if err != nil {
return nil, err
}
@ -174,7 +174,7 @@ func HttpGetTicketHelloID(hostname string, addr string, helloID tls.ClientHelloI
masterSecret,
nil, nil)
uTlsConn.SetSessionState(sessionState)
uTlsConn.SetSessionState12(sessionState)
err = uTlsConn.Handshake()
if err != nil {
return nil, fmt.Errorf("uTlsConn.Handshake() error: %+v", err)

View file

@ -38,7 +38,7 @@ func (csc *ClientSessionCache) Put(sessionKey string, cs *tls.ClientSessionState
}
}
func main() {
func runPskCheck(helloID tls.ClientHelloID) {
const serverAddr string = "refraction.network:443"
csc := NewClientSessionCache()
tcpConn, err := net.Dial("tcp", serverAddr)
@ -53,7 +53,7 @@ func main() {
ServerName: strings.Split(serverAddr, ":")[0],
// NextProtos: []string{"h2", "http/1.1"},
ClientSessionCache: csc, // set this so session tickets will be saved
}, tls.HelloChrome_100)
}, helloID)
// HS
err = tlsConn.Handshake()
@ -88,10 +88,12 @@ func main() {
tlsConnPSK := tls.UClient(tcpConnPSK, &tls.Config{
ServerName: strings.Split(serverAddr, ":")[0],
ClientSessionCache: csc,
}, tls.HelloChrome_100_PSK, &tls.UtlsPreSharedKeyExtension{})
}, helloID)
// HS
err = tlsConnPSK.Handshake()
fmt.Println(tlsConnPSK.HandshakeState.Hello.Raw)
fmt.Println(tlsConnPSK.HandshakeState.Hello.PskIdentities)
if err != nil {
panic(err)
}
@ -111,3 +113,8 @@ func main() {
}
}
}
func main() {
runPskCheck(tls.HelloChrome_100_PSK)
runPskCheck(tls.HelloGolang)
}

View file

@ -312,6 +312,12 @@ func (c *Conn) clientHandshake(ctx context.Context) (err error) {
func (c *Conn) loadSession(hello *clientHelloMsg) (
session *SessionState, earlySecret, binderKey []byte, err error) {
// [UTLS SECTION START]
if c.utls.sessionController != nil {
c.utls.sessionController.onEnterLoadSessionCheck()
defer c.utls.sessionController.onLoadSessionReturn()
}
// [UTLS SECTION END]
if c.config.SessionTicketsDisabled || c.config.ClientSessionCache == nil {
return nil, nil, nil, nil
}
@ -324,12 +330,6 @@ func (c *Conn) loadSession(hello *clientHelloMsg) (
hello.pskModes = []uint8{pskModeDHE}
}
// [UTLS BEGINS]
if c.utls.session != nil {
return c.utls.session, c.utls.earlySecret, c.utls.binderKey, nil
}
// [UTLS ENDS]
// Session resumption is not allowed if renegotiating because
// renegotiation is primarily used to allow a client to send a client
// certificate, which would be skipped if session resumption occurred.
@ -456,6 +456,11 @@ func (c *Conn) loadSession(hello *clientHelloMsg) (
// Compute the PSK binders. See RFC 8446, Section 4.2.11.2.
earlySecret = cipherSuite.extract(session.secret, nil)
binderKey = cipherSuite.deriveSecret(earlySecret, resumptionBinderLabel, nil)
// [UTLS SECTION START]
if c.utls.sessionController != nil && !c.utls.sessionController.shouldWriteBinders() {
return
}
// [UTLS SECTION END]
transcript := cipherSuite.hash.New()
helloBytes, err := hello.marshalWithoutBinders()
if err != nil {
@ -466,11 +471,6 @@ func (c *Conn) loadSession(hello *clientHelloMsg) (
if err := hello.updateBinders(pskBinders); err != nil {
return nil, nil, nil, err
}
c.utls.session = session // [uTLS]
c.utls.earlySecret = earlySecret // [uTLS]
c.utls.binderKey = binderKey // [uTLS]
return
}

View file

@ -29,7 +29,7 @@ func testClientHelloSpecJSONUnmarshaler(
t.Fatal(err)
}
truthSpec, _ := utlsIdToSpec(truthClientHelloID)
truthSpec, _ := utlsIdToSpec(truthClientHelloID, &UtlsPreSharedKeyExtension{}, &SessionTicketExtension{})
jsonSpec := chsju.ClientHelloSpec()
// Compare CipherSuites
@ -85,7 +85,7 @@ func testClientHelloSpecUnmarshalJSON(
t.Fatal(err)
}
truthSpec, _ := utlsIdToSpec(truthClientHelloID)
truthSpec, _ := utlsIdToSpec(truthClientHelloID, &UtlsPreSharedKeyExtension{}, &SessionTicketExtension{})
// Compare CipherSuites
if !reflect.DeepEqual(jsonSpec.CipherSuites, truthSpec.CipherSuites) {

View file

@ -726,3 +726,38 @@ func EnableWeakCiphers() {
suiteECDHE | suiteTLS12 | suiteSHA384, cipherAES, utlsMacSHA384, nil},
}...)
}
func panicOnNil(failureMsg string, params ...any) {
for i, p := range params {
if p == nil {
panic(fmt.Sprintf("%s: the [%d] parameter is nil", failureMsg, i))
}
}
}
func anyTrue[T any](slice []T, predicate func(t *T) bool) bool {
for i := 0; i < len(slice); i++ {
if predicate(&slice[i]) {
return true
}
}
return false
}
func uAssert(condition bool, msg string) {
if !condition {
panic(msg)
}
}
func sliceEq(sliceA []any, sliceB []any) bool {
if len(sliceA) != len(sliceB) {
return false
}
for i := 0; i < len(sliceA); i++ {
if sliceA[i] != sliceB[i] {
return false
}
}
return true
}

189
u_conn.go
View file

@ -14,23 +14,26 @@ import (
"errors"
"fmt"
"hash"
"io"
"net"
"strconv"
)
type ClientHelloBuildStatus int
const NotBuilt ClientHelloBuildStatus = 0
const BuildByUtls ClientHelloBuildStatus = 1
const BuildByGoTLS ClientHelloBuildStatus = 2
type UConn struct {
*Conn
Extensions []TLSExtension
ClientHelloID ClientHelloID
pskExtension []PreSharedKeyExtension
Extensions []TLSExtension
ClientHelloID ClientHelloID
sessionController *sessionController
ClientHelloBuilt bool
HandshakeState PubClientHandshakeState
clientHelloBuildStatus ClientHelloBuildStatus
// sessionID may or may not depend on ticket; nil => random
GetSessionID func(ticket []byte) [32]byte
HandshakeState PubClientHandshakeState
greaseSeed [ssl_grease_last_index]uint16
@ -44,15 +47,17 @@ type UConn struct {
// UClient returns a new uTLS client, with behavior depending on clientHelloID.
// Config CAN be nil, but make sure to eventually specify ServerName.
func UClient(conn net.Conn, config *Config, clientHelloID ClientHelloID, pskExtension ...PreSharedKeyExtension) *UConn {
func UClient(conn net.Conn, config *Config, clientHelloID ClientHelloID) *UConn {
if config == nil {
config = &Config{}
}
tlsConn := Conn{conn: conn, config: config, isClient: true}
handshakeState := PubClientHandshakeState{C: &tlsConn, Hello: &PubClientHelloMsg{}}
uconn := UConn{Conn: &tlsConn, ClientHelloID: clientHelloID, pskExtension: pskExtension, HandshakeState: handshakeState}
uconn := UConn{Conn: &tlsConn, ClientHelloID: clientHelloID, HandshakeState: handshakeState}
uconn.HandshakeState.uconn = &uconn
uconn.handshakeFn = uconn.clientHandshake
uconn.sessionController = newSessionController(&uconn)
uconn.utls.sessionController = uconn.sessionController
return &uconn
}
@ -73,9 +78,10 @@ func UClient(conn net.Conn, config *Config, clientHelloID ClientHelloID, pskExte
// default/mimicked ClientHello.
func (uconn *UConn) BuildHandshakeState() error {
if uconn.ClientHelloID == HelloGolang {
if uconn.ClientHelloBuilt {
if uconn.clientHelloBuildStatus == BuildByGoTLS {
return nil
}
uAssert(uconn.clientHelloBuildStatus == NotBuilt, "BuildHandshakeState failed: invalid call, client hello has already been built by utls")
// use default Golang ClientHello.
hello, keySharePrivate, err := uconn.makeClientHello()
@ -92,8 +98,10 @@ func (uconn *UConn) BuildHandshakeState() error {
return fmt.Errorf("uTLS: unknown keySharePrivate type: %T", keySharePrivate)
}
uconn.HandshakeState.C = uconn.Conn
uconn.clientHelloBuildStatus = BuildByGoTLS
} else {
if !uconn.ClientHelloBuilt {
uAssert(uconn.clientHelloBuildStatus == BuildByUtls || uconn.clientHelloBuildStatus == NotBuilt, "BuildHandshakeState failed: invalid call, client hello has already been built by go-tls")
if uconn.clientHelloBuildStatus == NotBuilt {
err := uconn.applyPresetByID(uconn.ClientHelloID)
if err != nil {
return err
@ -107,51 +115,93 @@ func (uconn *UConn) BuildHandshakeState() error {
if err != nil {
return err
}
err = uconn.uLoadSession()
if err != nil {
return err
}
err = uconn.MarshalClientHello()
if err != nil {
return err
}
uconn.uApplyPatch()
uconn.sessionController.finalCheck()
uconn.clientHelloBuildStatus = BuildByUtls
}
uconn.ClientHelloBuilt = true
return nil
}
// SetSessionState sets the session ticket, which may be preshared or fake.
func (uconn *UConn) uLoadSession() error {
if cfg := uconn.config; cfg.SessionTicketsDisabled || cfg.ClientSessionCache == nil {
return nil
}
switch uconn.sessionController.shouldLoadSession() {
case shouldReturn:
case shouldSetTicket:
uconn.sessionController.setSessionTicketToUConn()
case shouldSetPsk:
uconn.sessionController.setPsk()
case shouldLoad:
hello := uconn.HandshakeState.Hello.getPrivatePtr()
uconn.sessionController.aboutToLoadSession()
session, earlySecret, binderKey, err := uconn.loadSession(hello)
if session == nil || err != nil {
return err
}
if session.version == VersionTLS12 {
// We use the session ticket extension for tls 1.2 session resumption
uconn.sessionController.initSessionTicketExt(session, hello.sessionTicket)
uconn.sessionController.setSessionTicketToUConn()
} else {
uconn.sessionController.initPsk(session, earlySecret, binderKey, hello.pskIdentities)
}
}
return nil
}
func (uconn *UConn) uApplyPatch() {
if uconn.sessionController.shouldUpdateBinders() {
uconn.sessionController.updateBinders()
uconn.sessionController.setPsk()
}
}
// SetSessionState12 sets the session ticket, which may be preshared or fake.
// If session is nil, the body of session ticket extension will be unset,
// but the extension itself still MAY be present for mimicking purposes.
// Session tickets to be reused - use same cache on following connections.
func (uconn *UConn) SetSessionState(session *ClientSessionState) error {
var sessionTicket []uint8
if session != nil {
sessionTicket = session.ticket
uconn.HandshakeState.Session = session.session
func (uconn *UConn) SetSessionState12(session *ClientSessionState) error {
if uconn.config.SessionTicketsDisabled || uconn.config.ClientSessionCache == nil {
return fmt.Errorf("SetSessionState12 failed: session is disabled")
}
uconn.HandshakeState.Hello.TicketSupported = true
uconn.HandshakeState.Hello.SessionTicket = sessionTicket
for _, ext := range uconn.Extensions {
st, ok := ext.(*SessionTicketExtension)
if !ok {
continue
}
st.Session = session
if session != nil {
if len(session.SessionTicket()) > 0 {
if uconn.GetSessionID != nil {
sid := uconn.GetSessionID(session.SessionTicket())
uconn.HandshakeState.Hello.SessionId = sid[:]
return nil
}
}
var sessionID [32]byte
_, err := io.ReadFull(uconn.config.rand(), sessionID[:])
if err != nil {
return err
}
uconn.HandshakeState.Hello.SessionId = sessionID[:]
}
if session == nil {
return nil
}
if session.session == nil {
return fmt.Errorf("SetSessionState12 failed: session must not be nil")
}
if session.session.version != VersionTLS12 {
return fmt.Errorf("SetSessionState12 failed: SetSessionState12 only works for tls 1.2 session ticket; for tls 1.3 please customize PSK with SetSessionState13()")
}
uconn.sessionController.initSessionTicketExt(session.session, session.ticket)
return nil
}
// SetSessionState13 sets the psk extension for tls 1.3 resumption
func (uconn *UConn) SetSessionState13(psk PreSharedKeyExtension) error {
if uconn.config.SessionTicketsDisabled || uconn.config.ClientSessionCache == nil {
return fmt.Errorf("SetSessionState13 failed: session is disabled")
}
if psk == nil {
return nil
}
uconn.HandshakeState.Hello.TicketSupported = true
uconn.sessionController.overridePskExt(psk)
return nil
}
@ -397,7 +447,7 @@ func (c *UConn) clientHandshake(ctx context.Context) (err error) {
hello := c.HandshakeState.Hello.getPrivatePtr()
defer func() { c.HandshakeState.Hello = hello.getPublicPtr() }()
sessionIsAlreadySet := c.HandshakeState.Session != nil
sessionIsLocked := c.utls.sessionController.isSessionLocked()
// after this point exactly 1 out of 2 HandshakeState pointers is non-nil,
// useTLS13 variable tells which pointer
@ -434,9 +484,24 @@ func (c *UConn) clientHandshake(ctx context.Context) (err error) {
if c.handshakes > 0 {
hello.secureRenegotiation = c.clientFinished[:]
}
// [uTLS section ends]
session, earlySecret, binderKey, err := c.loadSession(hello)
var (
session *SessionState
earlySecret []byte
binderKey []byte
)
if !sessionIsLocked {
// [uTLS section ends]
session, earlySecret, binderKey, err = c.loadSession(hello)
// [uTLS section start]
} else {
session = c.HandshakeState.Session
earlySecret = c.HandshakeState.State13.EarlySecret
binderKey = c.HandshakeState.State13.BinderKey
}
// [uTLS section ends]
if err != nil {
return err
}
@ -491,7 +556,7 @@ func (c *UConn) clientHandshake(ctx context.Context) (err error) {
hs13.serverHello = serverHello
hs13.hello = hello
hs13.keySharesParams = NewKeySharesParameters()
if !sessionIsAlreadySet {
if !sessionIsLocked {
hs13.earlySecret = earlySecret
hs13.binderKey = binderKey
hs13.session = session
@ -547,7 +612,7 @@ func (uconn *UConn) MarshalClientHello() error {
if paddingExt == nil {
paddingExt = pe
} else {
return errors.New("multiple padding extensions!")
return errors.New("multiple padding extensions")
}
}
}
@ -589,27 +654,8 @@ func (uconn *UConn) MarshalClientHello() error {
if len(uconn.Extensions) > 0 {
binary.Write(bufferedWriter, binary.BigEndian, uint16(extensionsLen))
for _, ext := range uconn.Extensions {
switch typedExt := ext.(type) {
case PreSharedKeyExtension:
// PSK extension is handled separately
err := bufferedWriter.Flush()
if err != nil {
return fmt.Errorf("bufferedWriter.Flush(): %w", err)
}
hello.Raw = helloBuffer.Bytes()
// prepare buffer
buf := make([]byte, typedExt.Len())
n, err := typedExt.ReadWithRawHello(hello.Raw, buf)
if err != nil && !errors.Is(err, io.EOF) {
return fmt.Errorf("(*PreSharedKeyExtension).ReadWithRawHello(): %w", err)
}
if n != typedExt.Len() {
return errors.New("uconn: PreSharedKeyExtension: read wrong number of bytes")
}
bufferedWriter.Write(buf)
hello.PskBinders = typedExt.GetBinders()
default:
bufferedWriter.ReadFrom(ext)
if _, err := bufferedWriter.ReadFrom(ext); err != nil {
return err
}
}
}
@ -801,8 +847,5 @@ type utlsConnExtraFields struct {
peerApplicationSettings []byte
localApplicationSettings []byte
// session resumption (PSK)
session *SessionState
earlySecret []byte
binderKey []byte
sessionController *sessionController
}

View file

@ -252,7 +252,7 @@ func TestUTLSFingerprintClientHelloBluntMimicry(t *testing.T) {
var extensionId uint16 = 0xfeed
extensionData := []byte("random data")
specWithGeneric, err := utlsIdToSpec(HelloChrome_Auto)
specWithGeneric, err := utlsIdToSpec(HelloChrome_Auto, &UtlsPreSharedKeyExtension{}, &SessionTicketExtension{})
if err != nil {
t.Errorf("got error: %v; expected to succeed", err)
}
@ -293,11 +293,11 @@ func TestUTLSFingerprintClientHelloBluntMimicry(t *testing.T) {
func TestUTLSFingerprintClientHelloAlwaysAddPadding(t *testing.T) {
serverName := "foobar"
specWithoutPadding, err := utlsIdToSpec(HelloIOS_12_1)
specWithoutPadding, err := utlsIdToSpec(HelloIOS_12_1, &UtlsPreSharedKeyExtension{}, &SessionTicketExtension{})
if err != nil {
t.Errorf("got error: %v; expected to succeed", err)
}
specWithPadding, err := utlsIdToSpec(HelloChrome_83)
specWithPadding, err := utlsIdToSpec(HelloChrome_83, &UtlsPreSharedKeyExtension{}, &SessionTicketExtension{})
if err != nil {
t.Errorf("got error: %v; expected to succeed", err)
}

View file

@ -17,26 +17,18 @@ import (
)
var ErrUnknownClientHelloID = errors.New("tls: unknown ClientHelloID")
var ErrNotPSKClientHelloID = errors.New("tls: ClientHello does not contain pre_shared_key extension")
var ErrPSKExtensionExpected = errors.New("tls: pre_shared_key extension expected when fetching preset ClientHelloSpec")
// UTLSIdToSpec converts a ClientHelloID to a corresponding ClientHelloSpec.
//
// Exported internal function utlsIdToSpec per request.
func UTLSIdToSpec(id ClientHelloID, pskExtension ...PreSharedKeyExtension) (ClientHelloSpec, error) {
if len(pskExtension) > 1 {
return ClientHelloSpec{}, errors.New("tls: at most one PreSharedKeyExtensions is allowed")
}
chs, err := utlsIdToSpec(id)
if err != nil && errors.Is(err, ErrUnknownClientHelloID) {
chs, err = utlsIdToSpecWithPSK(id, pskExtension...)
}
return chs, err
func UTLSIdToSpec(id ClientHelloID, pskExt PreSharedKeyExtension, sessionTicketExt *SessionTicketExtension) (ClientHelloSpec, error) {
return utlsIdToSpec(id, pskExt, sessionTicketExt)
}
func utlsIdToSpec(id ClientHelloID) (ClientHelloSpec, error) {
func utlsIdToSpec(id ClientHelloID, pskExt PreSharedKeyExtension, sessionTicketExt *SessionTicketExtension) (ClientHelloSpec, error) {
if pskExt == nil || sessionTicketExt == nil {
panic("utlsIdToSpec failed: pskExt and sessionTicketExt must be non-nil pointers")
}
switch id {
case HelloChrome_58, HelloChrome_62:
return ClientHelloSpec{
@ -64,7 +56,7 @@ func utlsIdToSpec(id ClientHelloID) (ClientHelloSpec, error) {
&RenegotiationInfoExtension{Renegotiation: RenegotiateOnceAsClient},
&SNIExtension{},
&ExtendedMasterSecretExtension{},
&SessionTicketExtension{},
sessionTicketExt,
&SignatureAlgorithmsExtension{SupportedSignatureAlgorithms: []SignatureScheme{
ECDSAWithP256AndSHA256,
PSSWithSHA256,
@ -119,7 +111,7 @@ func utlsIdToSpec(id ClientHelloID) (ClientHelloSpec, error) {
&RenegotiationInfoExtension{Renegotiation: RenegotiateOnceAsClient},
&SNIExtension{},
&ExtendedMasterSecretExtension{},
&SessionTicketExtension{},
sessionTicketExt,
&SignatureAlgorithmsExtension{SupportedSignatureAlgorithms: []SignatureScheme{
ECDSAWithP256AndSHA256,
PSSWithSHA256,
@ -198,7 +190,7 @@ func utlsIdToSpec(id ClientHelloID) (ClientHelloSpec, error) {
&SupportedPointsExtension{SupportedPoints: []byte{
0x00, // pointFormatUncompressed
}},
&SessionTicketExtension{},
sessionTicketExt,
&ALPNExtension{AlpnProtocols: []string{"h2", "http/1.1"}},
&StatusRequestExtension{},
&SignatureAlgorithmsExtension{SupportedSignatureAlgorithms: []SignatureScheme{
@ -271,7 +263,7 @@ func utlsIdToSpec(id ClientHelloID) (ClientHelloSpec, error) {
&SupportedPointsExtension{SupportedPoints: []byte{
0x00, // pointFormatUncompressed
}},
&SessionTicketExtension{},
sessionTicketExt,
&ALPNExtension{AlpnProtocols: []string{"h2", "http/1.1"}},
&StatusRequestExtension{},
&SignatureAlgorithmsExtension{SupportedSignatureAlgorithms: []SignatureScheme{
@ -343,7 +335,7 @@ func utlsIdToSpec(id ClientHelloID) (ClientHelloSpec, error) {
&SupportedPointsExtension{SupportedPoints: []byte{
0x00, // pointFormatUncompressed
}},
&SessionTicketExtension{},
sessionTicketExt,
&ALPNExtension{AlpnProtocols: []string{"h2", "http/1.1"}},
&StatusRequestExtension{},
&SignatureAlgorithmsExtension{SupportedSignatureAlgorithms: []SignatureScheme{
@ -415,7 +407,7 @@ func utlsIdToSpec(id ClientHelloID) (ClientHelloSpec, error) {
&SupportedPointsExtension{SupportedPoints: []byte{
0x00, // pointFormatUncompressed
}},
&SessionTicketExtension{},
sessionTicketExt,
&ALPNExtension{AlpnProtocols: []string{"h2", "http/1.1"}},
&StatusRequestExtension{},
&SignatureAlgorithmsExtension{SupportedSignatureAlgorithms: []SignatureScheme{
@ -488,7 +480,7 @@ func utlsIdToSpec(id ClientHelloID) (ClientHelloSpec, error) {
&SupportedPointsExtension{SupportedPoints: []byte{
0x00, // pointFormatUncompressed
}},
&SessionTicketExtension{},
sessionTicketExt,
&ALPNExtension{AlpnProtocols: []string{"h2", "http/1.1"}},
&StatusRequestExtension{},
&SignatureAlgorithmsExtension{SupportedSignatureAlgorithms: []SignatureScheme{
@ -559,7 +551,7 @@ func utlsIdToSpec(id ClientHelloID) (ClientHelloSpec, error) {
&SupportedPointsExtension{SupportedPoints: []byte{
0x00, // pointFormatUncompressed
}},
&SessionTicketExtension{},
sessionTicketExt,
&ALPNExtension{AlpnProtocols: []string{"h2", "http/1.1"}},
&StatusRequestExtension{},
&SignatureAlgorithmsExtension{SupportedSignatureAlgorithms: []SignatureScheme{
@ -632,7 +624,7 @@ func utlsIdToSpec(id ClientHelloID) (ClientHelloSpec, error) {
&SupportedPointsExtension{SupportedPoints: []byte{
0x00, // pointFormatUncompressed
}},
&SessionTicketExtension{},
sessionTicketExt,
&ALPNExtension{AlpnProtocols: []string{"h2", "http/1.1"}},
&StatusRequestExtension{},
&SignatureAlgorithmsExtension{SupportedSignatureAlgorithms: []SignatureScheme{
@ -695,7 +687,7 @@ func utlsIdToSpec(id ClientHelloID) (ClientHelloSpec, error) {
&RenegotiationInfoExtension{Renegotiation: RenegotiateOnceAsClient},
&SupportedCurvesExtension{[]CurveID{X25519, CurveP256, CurveP384, CurveP521}},
&SupportedPointsExtension{SupportedPoints: []byte{pointFormatUncompressed}},
&SessionTicketExtension{},
sessionTicketExt,
&ALPNExtension{AlpnProtocols: []string{"h2", "http/1.1"}},
&StatusRequestExtension{},
&SignatureAlgorithmsExtension{SupportedSignatureAlgorithms: []SignatureScheme{
@ -757,7 +749,7 @@ func utlsIdToSpec(id ClientHelloID) (ClientHelloSpec, error) {
&SupportedPointsExtension{SupportedPoints: []byte{
pointFormatUncompressed,
}},
&SessionTicketExtension{},
sessionTicketExt,
&ALPNExtension{AlpnProtocols: []string{"h2", "http/1.1"}},
&StatusRequestExtension{},
&KeyShareExtension{[]KeyShare{
@ -828,7 +820,7 @@ func utlsIdToSpec(id ClientHelloID) (ClientHelloSpec, error) {
&SupportedPointsExtension{SupportedPoints: []byte{ //ec_point_formats
pointFormatUncompressed,
}},
&SessionTicketExtension{},
sessionTicketExt,
&ALPNExtension{AlpnProtocols: []string{"h2", "http/1.1"}}, //application_layer_protocol_negotiation
&StatusRequestExtension{},
&FakeDelegatedCredentialsExtension{
@ -909,7 +901,7 @@ func utlsIdToSpec(id ClientHelloID) (ClientHelloSpec, error) {
&SupportedPointsExtension{SupportedPoints: []byte{ //ec_point_formats
pointFormatUncompressed,
}},
&SessionTicketExtension{},
sessionTicketExt,
&ALPNExtension{AlpnProtocols: []string{"h2"}}, //application_layer_protocol_negotiation
&StatusRequestExtension{},
&FakeDelegatedCredentialsExtension{
@ -994,7 +986,7 @@ func utlsIdToSpec(id ClientHelloID) (ClientHelloSpec, error) {
0x0, // uncompressed
},
},
&SessionTicketExtension{},
sessionTicketExt,
&ALPNExtension{
AlpnProtocols: []string{
"h2",
@ -1426,7 +1418,7 @@ func utlsIdToSpec(id ClientHelloID) (ClientHelloSpec, error) {
0x0, // pointFormatUncompressed
},
},
&SessionTicketExtension{},
sessionTicketExt,
&ALPNExtension{
AlpnProtocols: []string{
"h2",
@ -1531,7 +1523,7 @@ func utlsIdToSpec(id ClientHelloID) (ClientHelloSpec, error) {
0x0, // uncompressed
},
},
&SessionTicketExtension{},
sessionTicketExt,
&ALPNExtension{
AlpnProtocols: []string{
"h2",
@ -1749,7 +1741,7 @@ func utlsIdToSpec(id ClientHelloID) (ClientHelloSpec, error) {
0x0, // pointFormatUncompressed
},
},
&SessionTicketExtension{},
sessionTicketExt,
&NPNExtension{},
&ALPNExtension{
AlpnProtocols: []string{
@ -1823,7 +1815,7 @@ func utlsIdToSpec(id ClientHelloID) (ClientHelloSpec, error) {
0x0, // uncompressed
},
},
&SessionTicketExtension{},
sessionTicketExt,
&ALPNExtension{
AlpnProtocols: []string{
"h2",
@ -1931,7 +1923,7 @@ func utlsIdToSpec(id ClientHelloID) (ClientHelloSpec, error) {
0x0, // uncompressed
},
},
&SessionTicketExtension{},
sessionTicketExt,
&ALPNExtension{
AlpnProtocols: []string{
"h2",
@ -1995,24 +1987,6 @@ func utlsIdToSpec(id ClientHelloID) (ClientHelloSpec, error) {
},
},
}, nil
default:
if id.Client == helloRandomized || id.Client == helloRandomizedALPN || id.Client == helloRandomizedNoALPN {
// Use empty values as they can be filled later by UConn.ApplyPreset or manually.
return generateRandomizedSpec(&id, "", nil, nil)
}
return ClientHelloSpec{}, fmt.Errorf("%w: %s", ErrUnknownClientHelloID, id.Str())
}
}
func utlsIdToSpecWithPSK(id ClientHelloID, pskExtension ...PreSharedKeyExtension) (ClientHelloSpec, error) {
switch id {
case HelloChrome_100_PSK, HelloChrome_112_PSK_Shuf, HelloChrome_114_Padding_PSK_Shuf, HelloChrome_115_PQ_PSK:
if len(pskExtension) == 0 || pskExtension[0] == nil {
return ClientHelloSpec{}, fmt.Errorf("%w: %s", ErrPSKExtensionExpected, id.Str())
}
}
switch id {
case HelloChrome_100_PSK:
return ClientHelloSpec{
CipherSuites: []uint16{
@ -2050,7 +2024,7 @@ func utlsIdToSpecWithPSK(id ClientHelloID, pskExtension ...PreSharedKeyExtension
&SupportedPointsExtension{SupportedPoints: []byte{
0x00, // pointFormatUncompressed
}},
&SessionTicketExtension{},
sessionTicketExt,
&ALPNExtension{AlpnProtocols: []string{"h2", "http/1.1"}},
&StatusRequestExtension{},
&SignatureAlgorithmsExtension{SupportedSignatureAlgorithms: []SignatureScheme{
@ -2081,7 +2055,7 @@ func utlsIdToSpecWithPSK(id ClientHelloID, pskExtension ...PreSharedKeyExtension
}},
&ApplicationSettingsExtension{SupportedProtocols: []string{"h2"}},
&UtlsGREASEExtension{},
pskExtension[0],
pskExt,
},
}, nil
case HelloChrome_112_PSK_Shuf:
@ -2121,7 +2095,7 @@ func utlsIdToSpecWithPSK(id ClientHelloID, pskExtension ...PreSharedKeyExtension
&SupportedPointsExtension{SupportedPoints: []byte{
0x00, // pointFormatUncompressed
}},
&SessionTicketExtension{},
sessionTicketExt,
&ALPNExtension{AlpnProtocols: []string{"h2", "http/1.1"}},
&StatusRequestExtension{},
&SignatureAlgorithmsExtension{SupportedSignatureAlgorithms: []SignatureScheme{
@ -2152,7 +2126,7 @@ func utlsIdToSpecWithPSK(id ClientHelloID, pskExtension ...PreSharedKeyExtension
}},
&ApplicationSettingsExtension{SupportedProtocols: []string{"h2"}},
&UtlsGREASEExtension{},
pskExtension[0],
pskExt,
}),
}, nil
case HelloChrome_114_Padding_PSK_Shuf:
@ -2192,7 +2166,7 @@ func utlsIdToSpecWithPSK(id ClientHelloID, pskExtension ...PreSharedKeyExtension
&SupportedPointsExtension{SupportedPoints: []byte{
0x00, // pointFormatUncompressed
}},
&SessionTicketExtension{},
sessionTicketExt,
&ALPNExtension{AlpnProtocols: []string{"h2", "http/1.1"}},
&StatusRequestExtension{},
&SignatureAlgorithmsExtension{SupportedSignatureAlgorithms: []SignatureScheme{
@ -2224,7 +2198,7 @@ func utlsIdToSpecWithPSK(id ClientHelloID, pskExtension ...PreSharedKeyExtension
&ApplicationSettingsExtension{SupportedProtocols: []string{"h2"}},
&UtlsGREASEExtension{},
&UtlsPaddingExtension{GetPaddingLen: BoringPaddingStyle},
pskExtension[0],
pskExt,
}),
}, nil
// Chrome w/ Post-Quantum Key Agreement
@ -2266,7 +2240,7 @@ func utlsIdToSpecWithPSK(id ClientHelloID, pskExtension ...PreSharedKeyExtension
&SupportedPointsExtension{SupportedPoints: []byte{
0x00, // pointFormatUncompressed
}},
&SessionTicketExtension{},
sessionTicketExt,
&ALPNExtension{AlpnProtocols: []string{"h2", "http/1.1"}},
&StatusRequestExtension{},
&SignatureAlgorithmsExtension{SupportedSignatureAlgorithms: []SignatureScheme{
@ -2298,12 +2272,17 @@ func utlsIdToSpecWithPSK(id ClientHelloID, pskExtension ...PreSharedKeyExtension
}},
&ApplicationSettingsExtension{SupportedProtocols: []string{"h2"}},
&UtlsGREASEExtension{},
pskExtension[0],
pskExt,
}),
}, nil
}
default:
if id.Client == helloRandomized || id.Client == helloRandomizedALPN || id.Client == helloRandomizedNoALPN {
// Use empty values as they can be filled later by UConn.ApplyPreset or manually.
return generateRandomizedSpec(&id, "", nil)
}
return ClientHelloSpec{}, fmt.Errorf("%w: %s", ErrUnknownClientHelloID, id.Str())
return ClientHelloSpec{}, fmt.Errorf("%w: %s", ErrUnknownClientHelloID, id.Str())
}
}
// ShuffleChromeTLSExtensions shuffles the extensions in the ClientHelloSpec to avoid ossification.
@ -2345,9 +2324,8 @@ func (uconn *UConn) applyPresetByID(id ClientHelloID) (err error) {
}
case helloCustom:
return nil
default:
spec, err = UTLSIdToSpec(id, uconn.pskExtension...)
spec, err = UTLSIdToSpec(id, uconn.sessionController.pskExtension, uconn.sessionController.sessionTicketExt)
if err != nil {
return err
}
@ -2379,7 +2357,6 @@ func (uconn *UConn) ApplyPreset(p *ClientHelloSpec) error {
}
uconn.HandshakeState.State13.KeySharesParams = NewKeySharesParameters()
hello := uconn.HandshakeState.Hello
session := uconn.HandshakeState.Session
switch len(hello.Random) {
case 0:
@ -2420,7 +2397,12 @@ func (uconn *UConn) ApplyPreset(p *ClientHelloSpec) error {
hello.CipherSuites[i] = GetBoringGREASEValue(uconn.greaseSeed, ssl_grease_cipher)
}
}
uconn.GetSessionID = p.GetSessionID
var sessionID [32]byte
_, err = io.ReadFull(uconn.config.rand(), sessionID[:])
if err != nil {
return err
}
uconn.HandshakeState.Hello.SessionId = sessionID[:]
uconn.Extensions = make([]TLSExtension, len(p.Extensions))
copy(uconn.Extensions, p.Extensions)
@ -2445,23 +2427,6 @@ func (uconn *UConn) ApplyPreset(p *ClientHelloSpec) error {
return errors.New("at most 2 grease extensions are supported")
}
grease_extensions_seen += 1
case *SessionTicketExtension:
var cs *ClientSessionState
if session == nil && uconn.config.ClientSessionCache != nil {
cacheKey := uconn.clientSessionCacheKey()
cs, _ = uconn.config.ClientSessionCache.Get(cacheKey)
if cs != nil {
session = cs.session
}
}
// TLS 1.3 (PSK) resumption is handled by PreSharedKeyExtension in MarshalClientHello()
if session != nil && session.version == VersionTLS13 {
break
}
err := uconn.SetSessionState(cs)
if err != nil {
return err
}
case *SupportedCurvesExtension:
for i := range ext.Curves {
if isGREASEUint16(uint16(ext.Curves[i])) {
@ -2528,22 +2493,18 @@ func (uconn *UConn) ApplyPreset(p *ClientHelloSpec) error {
// but NextProtos is also used by ALPN and our spec nmay not actually have a NPN extension
hello.NextProtoNeg = haveNPN
uconn.sessionController.checkSessionExt()
return nil
}
func (uconn *UConn) generateRandomizedSpec() (ClientHelloSpec, error) {
css := &ClientSessionState{
session: uconn.HandshakeState.Session,
ticket: uconn.HandshakeState.Hello.SessionTicket,
}
return generateRandomizedSpec(&uconn.ClientHelloID, uconn.serverName, css, uconn.config.NextProtos)
return generateRandomizedSpec(&uconn.ClientHelloID, uconn.serverName, uconn.config.NextProtos)
}
func generateRandomizedSpec(
id *ClientHelloID,
serverName string,
session *ClientSessionState,
nextProtos []string,
) (ClientHelloSpec, error) {
p := ClientHelloSpec{}
@ -2609,7 +2570,7 @@ func generateRandomizedSpec(
p.CipherSuites = removeRandomCiphers(r, shuffledSuites, id.Weights.CipherSuites_Remove_RandomCiphers)
sni := SNIExtension{serverName}
sessionTicket := SessionTicketExtension{Session: session}
sessionTicket := SessionTicketExtension{}
sigAndHashAlgos := []SignatureScheme{
ECDSAWithP256AndSHA256,

View file

@ -8,26 +8,38 @@ import (
"golang.org/x/crypto/cryptobyte"
)
type PreSharedKeyCommon struct {
Identities []PskIdentity
Binders [][]byte
BinderKey []byte // this will be used to compute the binder when hello message is ready
EarlySecret []byte
Session *SessionState
}
type PreSharedKeyExtension interface {
// TLSExtension must be implemented by all PreSharedKeyExtension implementations.
// However, the Read() method should return an error since it MUST NOT be used
// for PreSharedKeyExtension.
TLSExtension
IsInitialized() bool
InitializeByUtls(session *SessionState, earlySecret []byte, binderKey []byte, identities []PskIdentity)
// GetBinders returns the binders that were computed during the handshake
// to be set in the internal copy of the ClientHello. Only needed if expecting
// to resume the session.
//
// FakePreSharedKeyExtension MUST return nil to make sure utls DOES NOT
// try to do any session resumption.
GetBinders() [][]byte
GetPreSharedKeyCommon() PreSharedKeyCommon
// ReadWithRawHello is used to read the extension from the ClientHello
// instead of Read(), where the latter is used to read all other extensions.
//
// This is needed because the PSK extension needs to calculate the binder
// based on all previous parts of the ClientHello.
ReadWithRawHello(raw, b []byte) (int, error)
PatchBuiltHello(hello *PubClientHelloMsg) error
mustEmbedUnimplementedPreSharedKeyExtension() // this works like a type guard
}
@ -36,8 +48,16 @@ type UnimplementedPreSharedKeyExtension struct{}
func (UnimplementedPreSharedKeyExtension) mustEmbedUnimplementedPreSharedKeyExtension() {}
func (*UnimplementedPreSharedKeyExtension) IsInitialized() bool {
panic("tls: IsInitialized is not implemented for the PreSharedKeyExtension")
}
func (*UnimplementedPreSharedKeyExtension) InitializeByUtls(session *SessionState, earlySecret []byte, binderKey []byte, identities []PskIdentity) {
panic("tls: Initialize is not implemented for the PreSharedKeyExtension")
}
func (*UnimplementedPreSharedKeyExtension) writeToUConn(*UConn) error {
return errors.New("tls: writeToUConn is not implemented for the PreSharedKeyExtension")
panic("tls: writeToUConn is not implemented for the PreSharedKeyExtension")
}
func (*UnimplementedPreSharedKeyExtension) Len() int {
@ -45,108 +65,120 @@ func (*UnimplementedPreSharedKeyExtension) Len() int {
}
func (*UnimplementedPreSharedKeyExtension) Read([]byte) (int, error) {
return 0, errors.New("tls: Read is not implemented for the PreSharedKeyExtension")
panic("tls: Read is not implemented for the PreSharedKeyExtension")
}
func (*UnimplementedPreSharedKeyExtension) GetBinders() [][]byte {
panic("tls: Binders is not implemented for the PreSharedKeyExtension")
func (*UnimplementedPreSharedKeyExtension) GetPreSharedKeyCommon() PreSharedKeyCommon {
panic("tls: GetPreSharedKeyCommon is not implemented for the PreSharedKeyExtension")
}
func (*UnimplementedPreSharedKeyExtension) ReadWithRawHello(raw, b []byte) (int, error) {
return 0, errors.New("tls: ReadWithRawHello is not implemented for the PreSharedKeyExtension")
func (*UnimplementedPreSharedKeyExtension) PatchBuiltHello(hello *PubClientHelloMsg) error {
panic("tls: ReadWithRawHello is not implemented for the PreSharedKeyExtension")
}
// UtlsPreSharedKeyExtension is an extension used to set the PSK extension in the
// ClientHello.
type UtlsPreSharedKeyExtension struct {
UnimplementedPreSharedKeyExtension
PreSharedKeyCommon
cipherSuite *cipherSuiteTLS13
cachedLength *int
}
identities []pskIdentity
binders [][]byte
binderKey []byte // this will be used to compute the binder when hello message is ready
cipherSuite *cipherSuiteTLS13
earlySecret []byte
func (e *UtlsPreSharedKeyExtension) IsInitialized() bool {
return e.Session != nil
}
func (e *UtlsPreSharedKeyExtension) InitializeByUtls(session *SessionState, earlySecret []byte, binderKey []byte, identities []PskIdentity) {
e.Session = session
e.EarlySecret = earlySecret
e.BinderKey = binderKey
e.cipherSuite = cipherSuiteTLS13ByID(e.Session.cipherSuite)
e.Identities = identities
e.Binders = make([][]byte, 0, len(e.Identities))
for i := 0; i < len(e.Identities); i++ {
e.Binders = append(e.Binders, make([]byte, e.cipherSuite.hash.Size()))
}
}
func (e *UtlsPreSharedKeyExtension) writeToUConn(uc *UConn) error {
err := e.preloadSession(uc)
if err != nil {
return err
}
uc.HandshakeState.Hello.PskIdentities = pskIdentities(e.identities).ToPublic()
// uc.HandshakeState.Hello.PskBinders = e.binders
// uc.HandshakeState.Hello = hello.getPublicPtr() // write back to public hello
// uc.HandshakeState.State13.EarlySecret = e.earlySecret
// uc.HandshakeState.State13.BinderKey = e.binderKey
uc.HandshakeState.Hello.TicketSupported = true // This doesn't matter though, as utls doesn't care about this field. We write this for consistency.
return nil
}
func (e *UtlsPreSharedKeyExtension) Len() int {
func (e *UtlsPreSharedKeyExtension) GetPreSharedKeyCommon() PreSharedKeyCommon {
return e.PreSharedKeyCommon
}
func pskExtLen(identities []PskIdentity, binders [][]byte) int {
if len(identities) == 0 || len(binders) == 0 {
return 0
}
length := 4 // extension type + extension length
length += 2 // identities length
for _, identity := range e.identities {
length += 2 + len(identity.label) + 4 // identity length + identity + obfuscated ticket age
for _, identity := range identities {
length += 2 + len(identity.Label) + 4 // identity length + identity + obfuscated ticket age
}
length += 2 // binders length
for _, binder := range e.binders {
length += len(binder) + 1 // binder length + binder
for _, binder := range binders {
length += len(binder) + 1
}
return length
}
func (e *UtlsPreSharedKeyExtension) Read(b []byte) (int, error) {
return 0, errors.New("tls: PreSharedKeyExtension shouldn't be read, use ReadWithRawHello() instead")
func (e *UtlsPreSharedKeyExtension) Len() int {
if e.Session == nil {
return 0
}
if e.cachedLength != nil {
return *e.cachedLength
}
length := pskExtLen(e.Identities, e.Binders)
e.cachedLength = &length
return length
}
// Binders must be called after ReadWithRawHello
func (e *UtlsPreSharedKeyExtension) GetBinders() [][]byte {
return e.binders
}
func (e *UtlsPreSharedKeyExtension) ReadWithRawHello(raw, b []byte) (int, error) {
if len(b) < e.Len() {
func readPskIntoBytes(b []byte, identities []PskIdentity, binders [][]byte) (int, error) {
extLen := pskExtLen(identities, binders)
if extLen == 0 {
return 0, io.EOF
}
if len(b) < extLen {
return 0, io.ErrShortBuffer
}
b[0] = byte(extensionPreSharedKey >> 8)
b[1] = byte(extensionPreSharedKey)
b[2] = byte((e.Len() - 4) >> 8)
b[3] = byte(e.Len() - 4)
b[2] = byte((extLen - 4) >> 8)
b[3] = byte(extLen - 4)
// identities length
identitiesLength := 0
for _, identity := range e.identities {
identitiesLength += 2 + len(identity.label) + 4 // identity length + identity + obfuscated ticket age
for _, identity := range identities {
identitiesLength += 2 + len(identity.Label) + 4 // identity length + identity + obfuscated ticket age
}
b[4] = byte(identitiesLength >> 8)
b[5] = byte(identitiesLength)
// identities
offset := 6
for _, identity := range e.identities {
b[offset] = byte(len(identity.label) >> 8)
b[offset+1] = byte(len(identity.label))
for _, identity := range identities {
b[offset] = byte(len(identity.Label) >> 8)
b[offset+1] = byte(len(identity.Label))
offset += 2
copy(b[offset:], identity.label)
offset += len(identity.label)
b[offset] = byte(identity.obfuscatedTicketAge >> 24)
b[offset+1] = byte(identity.obfuscatedTicketAge >> 16)
b[offset+2] = byte(identity.obfuscatedTicketAge >> 8)
b[offset+3] = byte(identity.obfuscatedTicketAge)
copy(b[offset:], identity.Label)
offset += len(identity.Label)
b[offset] = byte(identity.ObfuscatedTicketAge >> 24)
b[offset+1] = byte(identity.ObfuscatedTicketAge >> 16)
b[offset+2] = byte(identity.ObfuscatedTicketAge >> 8)
b[offset+3] = byte(identity.ObfuscatedTicketAge)
offset += 4
}
// concatenate ClientHello and PreSharedKeyExtension
rawHelloSoFar := append(raw, b[:offset]...)
transcript := e.cipherSuite.hash.New()
transcript.Write(rawHelloSoFar)
e.binders = [][]byte{e.cipherSuite.finishedHash(e.binderKey, transcript)}
// binders length
bindersLength := 0
for _, binder := range e.binders {
for _, binder := range binders {
// check if binder size is valid
bindersLength += len(binder) + 1 // binder length + binder
}
b[offset] = byte(bindersLength >> 8)
@ -154,39 +186,49 @@ func (e *UtlsPreSharedKeyExtension) ReadWithRawHello(raw, b []byte) (int, error)
offset += 2
// binders
for _, binder := range e.binders {
for _, binder := range binders {
b[offset] = byte(len(binder))
offset++
copy(b[offset:], binder)
offset += len(binder)
}
return e.Len(), io.EOF
return extLen, io.EOF
}
func (e *UtlsPreSharedKeyExtension) preloadSession(uc *UConn) error {
// load Hello
hello := uc.HandshakeState.Hello.getPrivatePtr()
// try to use loadSession()
session, earlySecret, binderKey, err := uc.loadSession(hello)
func (e *UtlsPreSharedKeyExtension) Read(b []byte) (int, error) {
return readPskIntoBytes(b, e.Identities, e.Binders)
}
func (e *UtlsPreSharedKeyExtension) PatchBuiltHello(hello *PubClientHelloMsg) error {
if e.Len() == 0 {
return nil
}
private := hello.getCachedPrivatePtr()
if private == nil {
private = hello.getPrivatePtr()
}
private.raw = hello.Raw
private.pskBinders = e.Binders // set the placeholder to the private Hello
//--- mirror loadSession() begin ---//
transcript := e.cipherSuite.hash.New()
helloBytes, err := private.marshalWithoutBinders() // no marshal() will be actually called, as we have set the field `raw`
if err != nil {
return err
}
if session != nil && session.version == VersionTLS13 && binderKey != nil {
e.identities = hello.pskIdentities
e.binders = hello.pskBinders
e.binderKey = binderKey
e.cipherSuite = cipherSuiteTLS13ByID(session.cipherSuite)
e.earlySecret = earlySecret
} else if session == nil {
return errors.New("tls: session not found.")
} else if session.version != VersionTLS13 {
return errors.New("tls: session is not for TLS 1.3.")
} else if binderKey == nil {
return errors.New("tls: binder key not found.")
}
transcript.Write(helloBytes)
pskBinders := [][]byte{e.cipherSuite.finishedHash(e.BinderKey, transcript)}
return nil
if err := private.updateBinders(pskBinders); err != nil {
return err
}
//--- mirror loadSession() end ---//
e.Binders = pskBinders
// no need to care about other PSK related fields, they will be handled separately
return io.EOF
}
func (e *UtlsPreSharedKeyExtension) Write(b []byte) (int, error) {
@ -212,6 +254,14 @@ type FakePreSharedKeyExtension struct {
Binders [][]byte `json:"binders"`
}
func (e *FakePreSharedKeyExtension) IsInitialized() bool {
return e.Identities != nil && e.Binders != nil
}
func (e *FakePreSharedKeyExtension) InitializeByUtls(session *SessionState, earlySecret []byte, binderKey []byte, identities []PskIdentity) {
panic("InitializeByUtls failed: don't let utls initialize FakePreSharedKeyExtension; provide your own identities and binders or use UtlsPreSharedKeyExtension")
}
func (e *FakePreSharedKeyExtension) writeToUConn(uc *UConn) error {
if uc.config.ClientSessionCache == nil {
return nil // don't write the extension if there is no session cache
@ -225,85 +275,33 @@ func (e *FakePreSharedKeyExtension) writeToUConn(uc *UConn) error {
}
func (e *FakePreSharedKeyExtension) Len() int {
length := 4 // extension type + extension length
length += 2 // identities length
for _, identity := range e.Identities {
length += 2 + len(identity.Label) + 4 // identity length + identity + obfuscated ticket age
}
length += 2 // binders length
for _, binder := range e.Binders {
length += len(binder)
}
return length
return pskExtLen(e.Identities, e.Binders)
}
func (e *FakePreSharedKeyExtension) Read(b []byte) (int, error) {
return 0, errors.New("tls: PreSharedKeyExtension shouldn't be read, use ReadWithRawHello() instead")
}
func (e *FakePreSharedKeyExtension) GetBinders() [][]byte {
return nil
}
func (e *FakePreSharedKeyExtension) ReadWithRawHello(raw, b []byte) (int, error) {
if len(b) < e.Len() {
return 0, io.ErrShortBuffer
}
b[0] = byte(extensionPreSharedKey >> 8)
b[1] = byte(extensionPreSharedKey)
b[2] = byte((e.Len() - 4) >> 8)
b[3] = byte(e.Len() - 4)
// identities length
identitiesLength := 0
for _, identity := range e.Identities {
identitiesLength += 2 + len(identity.Label) + 4 // identity length + identity + obfuscated ticket age
}
b[4] = byte(identitiesLength >> 8)
b[5] = byte(identitiesLength)
// identities
offset := 6
for _, identity := range e.Identities {
b[offset] = byte(len(identity.Label) >> 8)
b[offset+1] = byte(len(identity.Label))
offset += 2
copy(b[offset:], identity.Label)
offset += len(identity.Label)
b[offset] = byte(identity.ObfuscatedTicketAge >> 24)
b[offset+1] = byte(identity.ObfuscatedTicketAge >> 16)
b[offset+2] = byte(identity.ObfuscatedTicketAge >> 8)
b[offset+3] = byte(identity.ObfuscatedTicketAge)
offset += 4
}
// binders length
bindersLength := 0
LOOP_BINDERS:
for _, binder := range e.Binders {
// check if binder size is valid
for _, cipherSuite := range cipherSuitesTLS13 {
if len(binder) == cipherSuite.hash.Size() {
bindersLength += len(binder) + 1 // binder length + binder
continue LOOP_BINDERS
}
for _, b := range e.Binders {
if !(anyTrue(validHashLen, func(valid *int) bool {
return len(b) == *valid
})) {
return 0, errors.New("tls: FakePreSharedKeyExtension.Read failed: invalid binder size")
}
return 0, errors.New("tls: invalid binder size")
}
b[offset] = byte(bindersLength >> 8)
b[offset+1] = byte(bindersLength)
offset += 2
return readPskIntoBytes(b, e.Identities, e.Binders)
}
// binders
for _, binder := range e.Binders {
b[offset] = byte(len(binder))
offset++
copy(b[offset:], binder)
offset += len(binder)
func (e *FakePreSharedKeyExtension) GetPreSharedKeyCommon() PreSharedKeyCommon {
return PreSharedKeyCommon{
Identities: e.Identities,
Binders: e.Binders,
}
}
return e.Len(), io.EOF
var validHashLen = mapSlice(cipherSuitesTLS13, func(c *cipherSuiteTLS13) int {
return c.hash.Size()
})
func (*FakePreSharedKeyExtension) PatchBuiltHello(*PubClientHelloMsg) error {
return nil // no need to patch the hello since we don't need to update binders
}
func (e *FakePreSharedKeyExtension) Write(b []byte) (n int, err error) {

View file

@ -45,7 +45,7 @@ type TLS13OnlyState struct {
EarlySecret []byte
BinderKey []byte
CertReq *CertificateRequestMsgTLS13
UsingPSK bool
UsingPSK bool // don't set this field when building client hello
SentDummyCCS bool
Transcript hash.Hash
TrafficSecret []byte // client_application_traffic_secret_0
@ -251,7 +251,7 @@ type PubServerHelloMsg struct {
OcspStapling bool
Scts [][]byte
ExtendedMasterSecret bool
TicketSupported bool
TicketSupported bool // used by go tls to determine whether to add the session ticket ext
SecureRenegotiation []byte
SecureRenegotiationSupported bool
AlpnProtocol string
@ -357,13 +357,15 @@ type PubClientHelloMsg struct {
PskIdentities []PskIdentity
PskBinders [][]byte
QuicTransportParameters []byte
cachedPrivateHello *clientHelloMsg // todo: further optimize to reduce clientHelloMsg construction
}
func (chm *PubClientHelloMsg) getPrivatePtr() *clientHelloMsg {
if chm == nil {
return nil
} else {
return &clientHelloMsg{
private := &clientHelloMsg{
raw: chm.Raw,
vers: chm.Vers,
random: chm.Random,
@ -395,6 +397,16 @@ func (chm *PubClientHelloMsg) getPrivatePtr() *clientHelloMsg {
nextProtoNeg: chm.NextProtoNeg,
}
chm.cachedPrivateHello = private
return private
}
}
func (chm *PubClientHelloMsg) getCachedPrivatePtr() *clientHelloMsg {
if chm == nil {
return nil
} else {
return chm.cachedPrivateHello
}
}
@ -432,6 +444,7 @@ func (chm *clientHelloMsg) getPublicPtr() *PubClientHelloMsg {
PskIdentities: pskIdentities(chm.pskIdentities).ToPublic(),
PskBinders: chm.pskBinders,
QuicTransportParameters: chm.quicTransportParameters,
cachedPrivateHello: chm,
}
}
}

272
u_session_controller.go Normal file
View file

@ -0,0 +1,272 @@
package tls
import "fmt"
type LoadSessionTrackerState int
const NeverCalled LoadSessionTrackerState = 0
const UtlsAboutToCall LoadSessionTrackerState = 3
const CalledByULoadSession LoadSessionTrackerState = 1
const CalledByGoTLS LoadSessionTrackerState = 2
type sessionState int
const NoSession sessionState = 0
const TicketInitialized sessionState = 1
const TicketAllSet sessionState = 4
const PskExtInitialized sessionState = 2
const PskAllSet sessionState = 3
// sessionController is responsible for all session related
type sessionController struct {
sessionTicketExt *SessionTicketExtension
pskExtension PreSharedKeyExtension
uconnRef *UConn
state sessionState
loadSessionTracker LoadSessionTrackerState
callingLoadSession bool
locked bool
}
type shouldLoadSessionResult int
const shouldReturn shouldLoadSessionResult = 0
const shouldSetTicket shouldLoadSessionResult = 1
const shouldSetPsk shouldLoadSessionResult = 2
const shouldLoad shouldLoadSessionResult = 3
func newSessionController(uconn *UConn) *sessionController {
return &sessionController{
uconnRef: uconn,
sessionTicketExt: &SessionTicketExtension{},
pskExtension: &UtlsPreSharedKeyExtension{},
state: NoSession,
locked: false,
callingLoadSession: false,
loadSessionTracker: NeverCalled,
}
}
func (s *sessionController) isSessionLocked() bool {
return s.locked
}
func (s *sessionController) shouldLoadSession() shouldLoadSessionResult {
if s.sessionTicketExt == nil && s.pskExtension == nil || s.uconnRef.clientHelloBuildStatus != NotBuilt {
fmt.Println("uLoadSession s.sessionTicketExt == nil && s.pskExtension == nil")
// There's no need to load session since we don't have the related extensions.
return shouldReturn
}
if s.state == TicketInitialized {
return shouldSetTicket
}
if s.state == PskExtInitialized {
return shouldSetPsk
}
return shouldLoad
}
func (s *sessionController) aboutToLoadSession() {
uAssert(s.state == NoSession && !s.locked, "tls: aboutToLoadSession failed: must only load session when the session of the client hello is not locked and when there's currently no session")
s.loadSessionTracker = UtlsAboutToCall
}
func (s *sessionController) commonCheck(failureMsg string, params ...any) {
if s.uconnRef.clientHelloBuildStatus != NotBuilt {
panic(failureMsg + ": we can't modify the session after the clientHello is built")
}
if s.state != NoSession {
panic(failureMsg + ": the session already set")
}
panicOnNil(failureMsg, params...)
}
func (s *sessionController) finalCheck() {
uAssert(s.state == PskAllSet || s.state == TicketAllSet || s.state == NoSession, "tls: SessionController.finalCheck failed: the session is half set")
s.locked = true
}
func (s *sessionController) initSessionTicketExt(session *SessionState, ticket []byte) {
s.commonCheck("tls: initSessionTicket failed", s.sessionTicketExt, session, ticket)
s.sessionTicketExt.Session = session
s.sessionTicketExt.Ticket = ticket
s.state = TicketInitialized
}
func (s *sessionController) setSessionTicketToUConn() {
uAssert(s.sessionTicketExt != nil && s.state == TicketInitialized, "tls: setSessionTicketExt failed: invalid state")
s.uconnRef.HandshakeState.Session = s.sessionTicketExt.Session
s.uconnRef.HandshakeState.Hello.SessionTicket = s.sessionTicketExt.Ticket
s.state = TicketAllSet
}
func mapSlice[T any, U any](slice []T, transform func(T) U) []U {
newSlice := make([]U, 0, len(slice))
for _, t := range slice {
newSlice = append(newSlice, transform(t))
}
return newSlice
}
func (s *sessionController) initPsk(session *SessionState, earlySecret []byte, binderKey []byte, pskIdentities []pskIdentity) {
s.commonCheck("tls: initPsk failed", s.pskExtension, session, earlySecret, pskIdentities)
uAssert(!s.pskExtension.IsInitialized(), "tls: initPsk failed: the psk extension is already initialized")
publicPskIdentities := mapSlice(pskIdentities, func(private pskIdentity) PskIdentity {
return PskIdentity{
Label: private.label,
ObfuscatedTicketAge: private.obfuscatedTicketAge,
}
})
s.pskExtension.InitializeByUtls(session, earlySecret, binderKey, publicPskIdentities)
uAssert(s.pskExtension.IsInitialized(), "the psk extension is not initialized after initialization")
s.uconnRef.HandshakeState.State13.BinderKey = binderKey
s.uconnRef.HandshakeState.State13.EarlySecret = earlySecret
s.uconnRef.HandshakeState.Session = session
s.uconnRef.HandshakeState.Hello.PskIdentities = publicPskIdentities
// binders are not expected to be available at this point
s.state = PskExtInitialized
}
func (s *sessionController) setPsk() {
uAssert(s.pskExtension != nil && (s.state == PskExtInitialized || s.state == PskAllSet), "tls: setPsk failed: invalid state")
pskCommon := s.pskExtension.GetPreSharedKeyCommon()
if s.state == PskExtInitialized {
s.uconnRef.HandshakeState.State13.EarlySecret = pskCommon.EarlySecret
s.uconnRef.HandshakeState.Session = pskCommon.Session
s.uconnRef.HandshakeState.Hello.PskIdentities = pskCommon.Identities
s.uconnRef.HandshakeState.Hello.PskBinders = pskCommon.Binders
} else if s.state == PskAllSet {
uAssert(sliceEq([]any{
s.uconnRef.HandshakeState.State13.EarlySecret,
s.uconnRef.HandshakeState.Session,
s.uconnRef.HandshakeState.Hello.PskIdentities,
s.uconnRef.HandshakeState.Hello.PskBinders,
}, []any{
pskCommon.EarlySecret,
pskCommon.Session,
pskCommon.Identities,
pskCommon.Binders,
}), "setPsk failed: only binders are allowed to change on state `PskAllSet`")
}
s.uconnRef.HandshakeState.State13.BinderKey = pskCommon.BinderKey
s.state = PskAllSet
}
func (s *sessionController) shouldUpdateBinders() bool {
if s.pskExtension == nil {
return false
}
return s.state == PskExtInitialized || s.state == PskAllSet
}
func (s *sessionController) updateBinders() {
uAssert(s.shouldUpdateBinders(), "tls: updateBinders failed: shouldn't update binders")
s.pskExtension.PatchBuiltHello(s.uconnRef.HandshakeState.Hello)
}
func (s *sessionController) overridePskExt(psk PreSharedKeyExtension) error {
if s.state != NoSession {
return fmt.Errorf("SetSessionState13 failed: there's already a session")
}
s.pskExtension = psk
if psk.IsInitialized() {
s.state = PskExtInitialized
}
return nil
}
var customizedHellos = []ClientHelloID{
HelloCustom,
HelloRandomized,
HelloRandomizedALPN,
HelloRandomizedNoALPN,
}
func (s *sessionController) checkSessionExt() {
uAssert(s.uconnRef.clientHelloBuildStatus == NotBuilt, "tls: checkSessionExt failed: we can't modify the session after the clientHello is built")
numSessionExt := 0
hasPskExt := false
for i, e := range s.uconnRef.Extensions {
switch ext := e.(type) {
case *SessionTicketExtension:
if ext != s.uconnRef.sessionController.sessionTicketExt {
if anyTrue(customizedHellos, func(h *ClientHelloID) bool {
return s.uconnRef.ClientHelloID.Client == h.Client
}) {
s.uconnRef.Extensions[i] = s.uconnRef.sessionController.sessionTicketExt
} else {
panic(fmt.Sprintf("tls: checkSessionExt failed: sessionTicketExtShortcut != SessionTicketExtension from the extension list and the clientHello is build from presets: [%v]", s.uconnRef.ClientHelloID))
}
}
numSessionExt += 1
case PreSharedKeyExtension:
uAssert(i == len(s.uconnRef.Extensions)-1, "tls: checkSessionExt failed: PreSharedKeyExtension must be the last extension")
if ext != s.uconnRef.sessionController.pskExtension {
if anyTrue(customizedHellos, func(h *ClientHelloID) bool {
return s.uconnRef.ClientHelloID.Client == h.Client
}) {
s.uconnRef.Extensions[i] = s.uconnRef.sessionController.pskExtension
} else {
panic(fmt.Sprintf("tls: checkSessionExt failed: pskExtensionShortcut != PreSharedKeyExtension from the extension list and the clientHello is build from presets: [%v]", s.uconnRef.ClientHelloID))
}
}
hasPskExt = true
}
}
if !(s.state == NoSession || s.state == TicketInitialized || s.state == PskExtInitialized) {
panic(fmt.Sprintf("tls: checkSessionExt failed: can't remove session ticket extension; the session ticket extension is unused, but the internal state is: %d", s.state))
}
if numSessionExt == 0 {
s.sessionTicketExt = nil
s.uconnRef.HandshakeState.Session = nil
s.uconnRef.HandshakeState.Hello.SessionTicket = nil
} else if numSessionExt > 1 {
panic("checkSessionExt failed: multiple session ticket extensions in the extension list")
}
if !hasPskExt {
s.pskExtension = nil
s.uconnRef.HandshakeState.State13.BinderKey = nil
s.uconnRef.HandshakeState.State13.EarlySecret = nil
s.uconnRef.HandshakeState.Session = nil
s.uconnRef.HandshakeState.Hello.PskIdentities = nil
}
}
func (s *sessionController) onEnterLoadSessionCheck() {
uAssert(!s.locked, "tls: LoadSessionCoordinator.onEnterLoadSessionCheck failed: session is set and locked, no call to loadSession is allowed")
switch s.loadSessionTracker {
case UtlsAboutToCall, NeverCalled:
s.callingLoadSession = true
case CalledByULoadSession, CalledByGoTLS:
panic("tls: LoadSessionCoordinator.onEnterLoadSessionCheck failed: you must not call loadSession() twice")
default:
panic("tls: LoadSessionCoordinator.onEnterLoadSessionCheck failed: unimplemented state")
}
}
func (s *sessionController) onLoadSessionReturn() {
uAssert(s.callingLoadSession, "tls: LoadSessionCoordinator.onLoadSessionReturn failed: it's not loading sessions, perhaps this function is not being called by loadSession.")
switch s.loadSessionTracker {
case NeverCalled:
s.loadSessionTracker = CalledByGoTLS
case UtlsAboutToCall:
s.loadSessionTracker = CalledByULoadSession
default:
panic("tls: LoadSessionCoordinator.onLoadSessionReturn failed: unimplemented state")
}
s.callingLoadSession = false
}
func (s *sessionController) shouldWriteBinders() bool {
uAssert(s.callingLoadSession, "tls: shouldWriteBinders failed: LoadSessionCoordinator isn't loading sessions, perhaps this function is not being called by loadSession.")
switch s.loadSessionTracker {
case NeverCalled:
return true
case UtlsAboutToCall:
return false
default:
panic("tls: shouldWriteBinders failed: unimplemented state")
}
}

View file

@ -802,22 +802,19 @@ func (e *SCTExtension) Write(_ []byte) (int, error) {
// SessionTicketExtension implements session_ticket (35)
type SessionTicketExtension struct {
Session *ClientSessionState
Session *SessionState
Ticket []byte
}
func (e *SessionTicketExtension) writeToUConn(uc *UConn) error {
if e.Session != nil {
uc.HandshakeState.Session = e.Session.session
uc.HandshakeState.Hello.SessionTicket = e.Session.ticket
}
// session states are handled later. At this point tickets aren't
// being loaded by utls, so don't write anything to the UConn.
uc.HandshakeState.Hello.TicketSupported = true // This doesn't really matter, this field is only used to add session ticket ext in go tls.
return nil
}
func (e *SessionTicketExtension) Len() int {
if e.Session != nil {
return 4 + len(e.Session.ticket)
}
return 4
return 4 + len(e.Ticket)
}
func (e *SessionTicketExtension) Read(b []byte) (int, error) {
@ -832,7 +829,7 @@ func (e *SessionTicketExtension) Read(b []byte) (int, error) {
b[2] = byte(extBodyLen >> 8)
b[3] = byte(extBodyLen)
if extBodyLen > 0 {
copy(b[4:], e.Session.ticket)
copy(b[4:], e.Ticket)
}
return e.Len(), io.EOF
}