mirror of
https://github.com/refraction-networking/utls.git
synced 2025-04-05 21:17:35 +03:00
feat: bug fix and refactor
This commit is contained in:
parent
3162534cc7
commit
a040a404e6
13 changed files with 675 additions and 349 deletions
|
@ -157,7 +157,7 @@ Currently, there is a simple function to set session ticket to any desired state
|
|||
|
||||
```Golang
|
||||
// If you want you session tickets to be reused - use same cache on following connections
|
||||
func (uconn *UConn) SetSessionState(session *ClientSessionState)
|
||||
func (uconn *UConn) SetSessionState12(session *ClientSessionState)
|
||||
```
|
||||
|
||||
Note that session tickets (fake ones or otherwise) are not reused.
|
||||
|
@ -294,7 +294,7 @@ Some customizations(such as setting session ticket/clientHello) have easy-to-use
|
|||
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
|
||||
masterSecret,
|
||||
nil, nil)
|
||||
tlsConn.SetSessionState(sessionState)
|
||||
tlsConn.SetSessionState12(sessionState)
|
||||
```
|
||||
|
||||
For other customizations there are following functions
|
||||
|
|
|
@ -140,7 +140,7 @@ func HttpGetTicket(hostname string, addr string) (*http.Response, error) {
|
|||
masterSecret,
|
||||
nil, nil)
|
||||
|
||||
err = uTlsConn.SetSessionState(sessionState)
|
||||
err = uTlsConn.SetSessionState12(sessionState)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -174,7 +174,7 @@ func HttpGetTicketHelloID(hostname string, addr string, helloID tls.ClientHelloI
|
|||
masterSecret,
|
||||
nil, nil)
|
||||
|
||||
uTlsConn.SetSessionState(sessionState)
|
||||
uTlsConn.SetSessionState12(sessionState)
|
||||
err = uTlsConn.Handshake()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("uTlsConn.Handshake() error: %+v", err)
|
||||
|
|
|
@ -38,7 +38,7 @@ func (csc *ClientSessionCache) Put(sessionKey string, cs *tls.ClientSessionState
|
|||
}
|
||||
}
|
||||
|
||||
func main() {
|
||||
func runPskCheck(helloID tls.ClientHelloID) {
|
||||
const serverAddr string = "refraction.network:443"
|
||||
csc := NewClientSessionCache()
|
||||
tcpConn, err := net.Dial("tcp", serverAddr)
|
||||
|
@ -53,7 +53,7 @@ func main() {
|
|||
ServerName: strings.Split(serverAddr, ":")[0],
|
||||
// NextProtos: []string{"h2", "http/1.1"},
|
||||
ClientSessionCache: csc, // set this so session tickets will be saved
|
||||
}, tls.HelloChrome_100)
|
||||
}, helloID)
|
||||
|
||||
// HS
|
||||
err = tlsConn.Handshake()
|
||||
|
@ -88,10 +88,12 @@ func main() {
|
|||
tlsConnPSK := tls.UClient(tcpConnPSK, &tls.Config{
|
||||
ServerName: strings.Split(serverAddr, ":")[0],
|
||||
ClientSessionCache: csc,
|
||||
}, tls.HelloChrome_100_PSK, &tls.UtlsPreSharedKeyExtension{})
|
||||
}, helloID)
|
||||
|
||||
// HS
|
||||
err = tlsConnPSK.Handshake()
|
||||
fmt.Println(tlsConnPSK.HandshakeState.Hello.Raw)
|
||||
fmt.Println(tlsConnPSK.HandshakeState.Hello.PskIdentities)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
@ -111,3 +113,8 @@ func main() {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func main() {
|
||||
runPskCheck(tls.HelloChrome_100_PSK)
|
||||
runPskCheck(tls.HelloGolang)
|
||||
}
|
||||
|
|
|
@ -312,6 +312,12 @@ func (c *Conn) clientHandshake(ctx context.Context) (err error) {
|
|||
|
||||
func (c *Conn) loadSession(hello *clientHelloMsg) (
|
||||
session *SessionState, earlySecret, binderKey []byte, err error) {
|
||||
// [UTLS SECTION START]
|
||||
if c.utls.sessionController != nil {
|
||||
c.utls.sessionController.onEnterLoadSessionCheck()
|
||||
defer c.utls.sessionController.onLoadSessionReturn()
|
||||
}
|
||||
// [UTLS SECTION END]
|
||||
if c.config.SessionTicketsDisabled || c.config.ClientSessionCache == nil {
|
||||
return nil, nil, nil, nil
|
||||
}
|
||||
|
@ -324,12 +330,6 @@ func (c *Conn) loadSession(hello *clientHelloMsg) (
|
|||
hello.pskModes = []uint8{pskModeDHE}
|
||||
}
|
||||
|
||||
// [UTLS BEGINS]
|
||||
if c.utls.session != nil {
|
||||
return c.utls.session, c.utls.earlySecret, c.utls.binderKey, nil
|
||||
}
|
||||
// [UTLS ENDS]
|
||||
|
||||
// Session resumption is not allowed if renegotiating because
|
||||
// renegotiation is primarily used to allow a client to send a client
|
||||
// certificate, which would be skipped if session resumption occurred.
|
||||
|
@ -456,6 +456,11 @@ func (c *Conn) loadSession(hello *clientHelloMsg) (
|
|||
// Compute the PSK binders. See RFC 8446, Section 4.2.11.2.
|
||||
earlySecret = cipherSuite.extract(session.secret, nil)
|
||||
binderKey = cipherSuite.deriveSecret(earlySecret, resumptionBinderLabel, nil)
|
||||
// [UTLS SECTION START]
|
||||
if c.utls.sessionController != nil && !c.utls.sessionController.shouldWriteBinders() {
|
||||
return
|
||||
}
|
||||
// [UTLS SECTION END]
|
||||
transcript := cipherSuite.hash.New()
|
||||
helloBytes, err := hello.marshalWithoutBinders()
|
||||
if err != nil {
|
||||
|
@ -466,11 +471,6 @@ func (c *Conn) loadSession(hello *clientHelloMsg) (
|
|||
if err := hello.updateBinders(pskBinders); err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
c.utls.session = session // [uTLS]
|
||||
c.utls.earlySecret = earlySecret // [uTLS]
|
||||
c.utls.binderKey = binderKey // [uTLS]
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -29,7 +29,7 @@ func testClientHelloSpecJSONUnmarshaler(
|
|||
t.Fatal(err)
|
||||
}
|
||||
|
||||
truthSpec, _ := utlsIdToSpec(truthClientHelloID)
|
||||
truthSpec, _ := utlsIdToSpec(truthClientHelloID, &UtlsPreSharedKeyExtension{}, &SessionTicketExtension{})
|
||||
jsonSpec := chsju.ClientHelloSpec()
|
||||
|
||||
// Compare CipherSuites
|
||||
|
@ -85,7 +85,7 @@ func testClientHelloSpecUnmarshalJSON(
|
|||
t.Fatal(err)
|
||||
}
|
||||
|
||||
truthSpec, _ := utlsIdToSpec(truthClientHelloID)
|
||||
truthSpec, _ := utlsIdToSpec(truthClientHelloID, &UtlsPreSharedKeyExtension{}, &SessionTicketExtension{})
|
||||
|
||||
// Compare CipherSuites
|
||||
if !reflect.DeepEqual(jsonSpec.CipherSuites, truthSpec.CipherSuites) {
|
||||
|
|
35
u_common.go
35
u_common.go
|
@ -726,3 +726,38 @@ func EnableWeakCiphers() {
|
|||
suiteECDHE | suiteTLS12 | suiteSHA384, cipherAES, utlsMacSHA384, nil},
|
||||
}...)
|
||||
}
|
||||
|
||||
func panicOnNil(failureMsg string, params ...any) {
|
||||
for i, p := range params {
|
||||
if p == nil {
|
||||
panic(fmt.Sprintf("%s: the [%d] parameter is nil", failureMsg, i))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func anyTrue[T any](slice []T, predicate func(t *T) bool) bool {
|
||||
for i := 0; i < len(slice); i++ {
|
||||
if predicate(&slice[i]) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func uAssert(condition bool, msg string) {
|
||||
if !condition {
|
||||
panic(msg)
|
||||
}
|
||||
}
|
||||
|
||||
func sliceEq(sliceA []any, sliceB []any) bool {
|
||||
if len(sliceA) != len(sliceB) {
|
||||
return false
|
||||
}
|
||||
for i := 0; i < len(sliceA); i++ {
|
||||
if sliceA[i] != sliceB[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
|
189
u_conn.go
189
u_conn.go
|
@ -14,23 +14,26 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"hash"
|
||||
"io"
|
||||
"net"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
type ClientHelloBuildStatus int
|
||||
|
||||
const NotBuilt ClientHelloBuildStatus = 0
|
||||
const BuildByUtls ClientHelloBuildStatus = 1
|
||||
const BuildByGoTLS ClientHelloBuildStatus = 2
|
||||
|
||||
type UConn struct {
|
||||
*Conn
|
||||
|
||||
Extensions []TLSExtension
|
||||
ClientHelloID ClientHelloID
|
||||
pskExtension []PreSharedKeyExtension
|
||||
Extensions []TLSExtension
|
||||
ClientHelloID ClientHelloID
|
||||
sessionController *sessionController
|
||||
|
||||
ClientHelloBuilt bool
|
||||
HandshakeState PubClientHandshakeState
|
||||
clientHelloBuildStatus ClientHelloBuildStatus
|
||||
|
||||
// sessionID may or may not depend on ticket; nil => random
|
||||
GetSessionID func(ticket []byte) [32]byte
|
||||
HandshakeState PubClientHandshakeState
|
||||
|
||||
greaseSeed [ssl_grease_last_index]uint16
|
||||
|
||||
|
@ -44,15 +47,17 @@ type UConn struct {
|
|||
|
||||
// UClient returns a new uTLS client, with behavior depending on clientHelloID.
|
||||
// Config CAN be nil, but make sure to eventually specify ServerName.
|
||||
func UClient(conn net.Conn, config *Config, clientHelloID ClientHelloID, pskExtension ...PreSharedKeyExtension) *UConn {
|
||||
func UClient(conn net.Conn, config *Config, clientHelloID ClientHelloID) *UConn {
|
||||
if config == nil {
|
||||
config = &Config{}
|
||||
}
|
||||
tlsConn := Conn{conn: conn, config: config, isClient: true}
|
||||
handshakeState := PubClientHandshakeState{C: &tlsConn, Hello: &PubClientHelloMsg{}}
|
||||
uconn := UConn{Conn: &tlsConn, ClientHelloID: clientHelloID, pskExtension: pskExtension, HandshakeState: handshakeState}
|
||||
uconn := UConn{Conn: &tlsConn, ClientHelloID: clientHelloID, HandshakeState: handshakeState}
|
||||
uconn.HandshakeState.uconn = &uconn
|
||||
uconn.handshakeFn = uconn.clientHandshake
|
||||
uconn.sessionController = newSessionController(&uconn)
|
||||
uconn.utls.sessionController = uconn.sessionController
|
||||
return &uconn
|
||||
}
|
||||
|
||||
|
@ -73,9 +78,10 @@ func UClient(conn net.Conn, config *Config, clientHelloID ClientHelloID, pskExte
|
|||
// default/mimicked ClientHello.
|
||||
func (uconn *UConn) BuildHandshakeState() error {
|
||||
if uconn.ClientHelloID == HelloGolang {
|
||||
if uconn.ClientHelloBuilt {
|
||||
if uconn.clientHelloBuildStatus == BuildByGoTLS {
|
||||
return nil
|
||||
}
|
||||
uAssert(uconn.clientHelloBuildStatus == NotBuilt, "BuildHandshakeState failed: invalid call, client hello has already been built by utls")
|
||||
|
||||
// use default Golang ClientHello.
|
||||
hello, keySharePrivate, err := uconn.makeClientHello()
|
||||
|
@ -92,8 +98,10 @@ func (uconn *UConn) BuildHandshakeState() error {
|
|||
return fmt.Errorf("uTLS: unknown keySharePrivate type: %T", keySharePrivate)
|
||||
}
|
||||
uconn.HandshakeState.C = uconn.Conn
|
||||
uconn.clientHelloBuildStatus = BuildByGoTLS
|
||||
} else {
|
||||
if !uconn.ClientHelloBuilt {
|
||||
uAssert(uconn.clientHelloBuildStatus == BuildByUtls || uconn.clientHelloBuildStatus == NotBuilt, "BuildHandshakeState failed: invalid call, client hello has already been built by go-tls")
|
||||
if uconn.clientHelloBuildStatus == NotBuilt {
|
||||
err := uconn.applyPresetByID(uconn.ClientHelloID)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -107,51 +115,93 @@ func (uconn *UConn) BuildHandshakeState() error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = uconn.uLoadSession()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = uconn.MarshalClientHello()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
uconn.uApplyPatch()
|
||||
|
||||
uconn.sessionController.finalCheck()
|
||||
uconn.clientHelloBuildStatus = BuildByUtls
|
||||
}
|
||||
uconn.ClientHelloBuilt = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetSessionState sets the session ticket, which may be preshared or fake.
|
||||
func (uconn *UConn) uLoadSession() error {
|
||||
if cfg := uconn.config; cfg.SessionTicketsDisabled || cfg.ClientSessionCache == nil {
|
||||
return nil
|
||||
}
|
||||
switch uconn.sessionController.shouldLoadSession() {
|
||||
case shouldReturn:
|
||||
case shouldSetTicket:
|
||||
uconn.sessionController.setSessionTicketToUConn()
|
||||
case shouldSetPsk:
|
||||
uconn.sessionController.setPsk()
|
||||
case shouldLoad:
|
||||
hello := uconn.HandshakeState.Hello.getPrivatePtr()
|
||||
uconn.sessionController.aboutToLoadSession()
|
||||
session, earlySecret, binderKey, err := uconn.loadSession(hello)
|
||||
if session == nil || err != nil {
|
||||
return err
|
||||
}
|
||||
if session.version == VersionTLS12 {
|
||||
// We use the session ticket extension for tls 1.2 session resumption
|
||||
uconn.sessionController.initSessionTicketExt(session, hello.sessionTicket)
|
||||
uconn.sessionController.setSessionTicketToUConn()
|
||||
} else {
|
||||
uconn.sessionController.initPsk(session, earlySecret, binderKey, hello.pskIdentities)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (uconn *UConn) uApplyPatch() {
|
||||
if uconn.sessionController.shouldUpdateBinders() {
|
||||
uconn.sessionController.updateBinders()
|
||||
uconn.sessionController.setPsk()
|
||||
}
|
||||
}
|
||||
|
||||
// SetSessionState12 sets the session ticket, which may be preshared or fake.
|
||||
// If session is nil, the body of session ticket extension will be unset,
|
||||
// but the extension itself still MAY be present for mimicking purposes.
|
||||
// Session tickets to be reused - use same cache on following connections.
|
||||
func (uconn *UConn) SetSessionState(session *ClientSessionState) error {
|
||||
var sessionTicket []uint8
|
||||
if session != nil {
|
||||
sessionTicket = session.ticket
|
||||
uconn.HandshakeState.Session = session.session
|
||||
func (uconn *UConn) SetSessionState12(session *ClientSessionState) error {
|
||||
if uconn.config.SessionTicketsDisabled || uconn.config.ClientSessionCache == nil {
|
||||
return fmt.Errorf("SetSessionState12 failed: session is disabled")
|
||||
}
|
||||
uconn.HandshakeState.Hello.TicketSupported = true
|
||||
uconn.HandshakeState.Hello.SessionTicket = sessionTicket
|
||||
|
||||
for _, ext := range uconn.Extensions {
|
||||
st, ok := ext.(*SessionTicketExtension)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
st.Session = session
|
||||
if session != nil {
|
||||
if len(session.SessionTicket()) > 0 {
|
||||
if uconn.GetSessionID != nil {
|
||||
sid := uconn.GetSessionID(session.SessionTicket())
|
||||
uconn.HandshakeState.Hello.SessionId = sid[:]
|
||||
return nil
|
||||
}
|
||||
}
|
||||
var sessionID [32]byte
|
||||
_, err := io.ReadFull(uconn.config.rand(), sessionID[:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
uconn.HandshakeState.Hello.SessionId = sessionID[:]
|
||||
}
|
||||
if session == nil {
|
||||
return nil
|
||||
}
|
||||
if session.session == nil {
|
||||
return fmt.Errorf("SetSessionState12 failed: session must not be nil")
|
||||
}
|
||||
if session.session.version != VersionTLS12 {
|
||||
return fmt.Errorf("SetSessionState12 failed: SetSessionState12 only works for tls 1.2 session ticket; for tls 1.3 please customize PSK with SetSessionState13()")
|
||||
}
|
||||
uconn.sessionController.initSessionTicketExt(session.session, session.ticket)
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetSessionState13 sets the psk extension for tls 1.3 resumption
|
||||
func (uconn *UConn) SetSessionState13(psk PreSharedKeyExtension) error {
|
||||
if uconn.config.SessionTicketsDisabled || uconn.config.ClientSessionCache == nil {
|
||||
return fmt.Errorf("SetSessionState13 failed: session is disabled")
|
||||
}
|
||||
if psk == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
uconn.HandshakeState.Hello.TicketSupported = true
|
||||
uconn.sessionController.overridePskExt(psk)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -397,7 +447,7 @@ func (c *UConn) clientHandshake(ctx context.Context) (err error) {
|
|||
hello := c.HandshakeState.Hello.getPrivatePtr()
|
||||
defer func() { c.HandshakeState.Hello = hello.getPublicPtr() }()
|
||||
|
||||
sessionIsAlreadySet := c.HandshakeState.Session != nil
|
||||
sessionIsLocked := c.utls.sessionController.isSessionLocked()
|
||||
|
||||
// after this point exactly 1 out of 2 HandshakeState pointers is non-nil,
|
||||
// useTLS13 variable tells which pointer
|
||||
|
@ -434,9 +484,24 @@ func (c *UConn) clientHandshake(ctx context.Context) (err error) {
|
|||
if c.handshakes > 0 {
|
||||
hello.secureRenegotiation = c.clientFinished[:]
|
||||
}
|
||||
// [uTLS section ends]
|
||||
|
||||
session, earlySecret, binderKey, err := c.loadSession(hello)
|
||||
var (
|
||||
session *SessionState
|
||||
earlySecret []byte
|
||||
binderKey []byte
|
||||
)
|
||||
if !sessionIsLocked {
|
||||
// [uTLS section ends]
|
||||
|
||||
session, earlySecret, binderKey, err = c.loadSession(hello)
|
||||
|
||||
// [uTLS section start]
|
||||
} else {
|
||||
session = c.HandshakeState.Session
|
||||
earlySecret = c.HandshakeState.State13.EarlySecret
|
||||
binderKey = c.HandshakeState.State13.BinderKey
|
||||
}
|
||||
// [uTLS section ends]
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -491,7 +556,7 @@ func (c *UConn) clientHandshake(ctx context.Context) (err error) {
|
|||
hs13.serverHello = serverHello
|
||||
hs13.hello = hello
|
||||
hs13.keySharesParams = NewKeySharesParameters()
|
||||
if !sessionIsAlreadySet {
|
||||
if !sessionIsLocked {
|
||||
hs13.earlySecret = earlySecret
|
||||
hs13.binderKey = binderKey
|
||||
hs13.session = session
|
||||
|
@ -547,7 +612,7 @@ func (uconn *UConn) MarshalClientHello() error {
|
|||
if paddingExt == nil {
|
||||
paddingExt = pe
|
||||
} else {
|
||||
return errors.New("multiple padding extensions!")
|
||||
return errors.New("multiple padding extensions")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -589,27 +654,8 @@ func (uconn *UConn) MarshalClientHello() error {
|
|||
if len(uconn.Extensions) > 0 {
|
||||
binary.Write(bufferedWriter, binary.BigEndian, uint16(extensionsLen))
|
||||
for _, ext := range uconn.Extensions {
|
||||
switch typedExt := ext.(type) {
|
||||
case PreSharedKeyExtension:
|
||||
// PSK extension is handled separately
|
||||
err := bufferedWriter.Flush()
|
||||
if err != nil {
|
||||
return fmt.Errorf("bufferedWriter.Flush(): %w", err)
|
||||
}
|
||||
hello.Raw = helloBuffer.Bytes()
|
||||
// prepare buffer
|
||||
buf := make([]byte, typedExt.Len())
|
||||
n, err := typedExt.ReadWithRawHello(hello.Raw, buf)
|
||||
if err != nil && !errors.Is(err, io.EOF) {
|
||||
return fmt.Errorf("(*PreSharedKeyExtension).ReadWithRawHello(): %w", err)
|
||||
}
|
||||
if n != typedExt.Len() {
|
||||
return errors.New("uconn: PreSharedKeyExtension: read wrong number of bytes")
|
||||
}
|
||||
bufferedWriter.Write(buf)
|
||||
hello.PskBinders = typedExt.GetBinders()
|
||||
default:
|
||||
bufferedWriter.ReadFrom(ext)
|
||||
if _, err := bufferedWriter.ReadFrom(ext); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -801,8 +847,5 @@ type utlsConnExtraFields struct {
|
|||
peerApplicationSettings []byte
|
||||
localApplicationSettings []byte
|
||||
|
||||
// session resumption (PSK)
|
||||
session *SessionState
|
||||
earlySecret []byte
|
||||
binderKey []byte
|
||||
sessionController *sessionController
|
||||
}
|
||||
|
|
|
@ -252,7 +252,7 @@ func TestUTLSFingerprintClientHelloBluntMimicry(t *testing.T) {
|
|||
var extensionId uint16 = 0xfeed
|
||||
extensionData := []byte("random data")
|
||||
|
||||
specWithGeneric, err := utlsIdToSpec(HelloChrome_Auto)
|
||||
specWithGeneric, err := utlsIdToSpec(HelloChrome_Auto, &UtlsPreSharedKeyExtension{}, &SessionTicketExtension{})
|
||||
if err != nil {
|
||||
t.Errorf("got error: %v; expected to succeed", err)
|
||||
}
|
||||
|
@ -293,11 +293,11 @@ func TestUTLSFingerprintClientHelloBluntMimicry(t *testing.T) {
|
|||
func TestUTLSFingerprintClientHelloAlwaysAddPadding(t *testing.T) {
|
||||
serverName := "foobar"
|
||||
|
||||
specWithoutPadding, err := utlsIdToSpec(HelloIOS_12_1)
|
||||
specWithoutPadding, err := utlsIdToSpec(HelloIOS_12_1, &UtlsPreSharedKeyExtension{}, &SessionTicketExtension{})
|
||||
if err != nil {
|
||||
t.Errorf("got error: %v; expected to succeed", err)
|
||||
}
|
||||
specWithPadding, err := utlsIdToSpec(HelloChrome_83)
|
||||
specWithPadding, err := utlsIdToSpec(HelloChrome_83, &UtlsPreSharedKeyExtension{}, &SessionTicketExtension{})
|
||||
if err != nil {
|
||||
t.Errorf("got error: %v; expected to succeed", err)
|
||||
}
|
||||
|
|
141
u_parrots.go
141
u_parrots.go
|
@ -17,26 +17,18 @@ import (
|
|||
)
|
||||
|
||||
var ErrUnknownClientHelloID = errors.New("tls: unknown ClientHelloID")
|
||||
var ErrNotPSKClientHelloID = errors.New("tls: ClientHello does not contain pre_shared_key extension")
|
||||
var ErrPSKExtensionExpected = errors.New("tls: pre_shared_key extension expected when fetching preset ClientHelloSpec")
|
||||
|
||||
// UTLSIdToSpec converts a ClientHelloID to a corresponding ClientHelloSpec.
|
||||
//
|
||||
// Exported internal function utlsIdToSpec per request.
|
||||
func UTLSIdToSpec(id ClientHelloID, pskExtension ...PreSharedKeyExtension) (ClientHelloSpec, error) {
|
||||
if len(pskExtension) > 1 {
|
||||
return ClientHelloSpec{}, errors.New("tls: at most one PreSharedKeyExtensions is allowed")
|
||||
}
|
||||
|
||||
chs, err := utlsIdToSpec(id)
|
||||
if err != nil && errors.Is(err, ErrUnknownClientHelloID) {
|
||||
chs, err = utlsIdToSpecWithPSK(id, pskExtension...)
|
||||
}
|
||||
|
||||
return chs, err
|
||||
func UTLSIdToSpec(id ClientHelloID, pskExt PreSharedKeyExtension, sessionTicketExt *SessionTicketExtension) (ClientHelloSpec, error) {
|
||||
return utlsIdToSpec(id, pskExt, sessionTicketExt)
|
||||
}
|
||||
|
||||
func utlsIdToSpec(id ClientHelloID) (ClientHelloSpec, error) {
|
||||
func utlsIdToSpec(id ClientHelloID, pskExt PreSharedKeyExtension, sessionTicketExt *SessionTicketExtension) (ClientHelloSpec, error) {
|
||||
if pskExt == nil || sessionTicketExt == nil {
|
||||
panic("utlsIdToSpec failed: pskExt and sessionTicketExt must be non-nil pointers")
|
||||
}
|
||||
switch id {
|
||||
case HelloChrome_58, HelloChrome_62:
|
||||
return ClientHelloSpec{
|
||||
|
@ -64,7 +56,7 @@ func utlsIdToSpec(id ClientHelloID) (ClientHelloSpec, error) {
|
|||
&RenegotiationInfoExtension{Renegotiation: RenegotiateOnceAsClient},
|
||||
&SNIExtension{},
|
||||
&ExtendedMasterSecretExtension{},
|
||||
&SessionTicketExtension{},
|
||||
sessionTicketExt,
|
||||
&SignatureAlgorithmsExtension{SupportedSignatureAlgorithms: []SignatureScheme{
|
||||
ECDSAWithP256AndSHA256,
|
||||
PSSWithSHA256,
|
||||
|
@ -119,7 +111,7 @@ func utlsIdToSpec(id ClientHelloID) (ClientHelloSpec, error) {
|
|||
&RenegotiationInfoExtension{Renegotiation: RenegotiateOnceAsClient},
|
||||
&SNIExtension{},
|
||||
&ExtendedMasterSecretExtension{},
|
||||
&SessionTicketExtension{},
|
||||
sessionTicketExt,
|
||||
&SignatureAlgorithmsExtension{SupportedSignatureAlgorithms: []SignatureScheme{
|
||||
ECDSAWithP256AndSHA256,
|
||||
PSSWithSHA256,
|
||||
|
@ -198,7 +190,7 @@ func utlsIdToSpec(id ClientHelloID) (ClientHelloSpec, error) {
|
|||
&SupportedPointsExtension{SupportedPoints: []byte{
|
||||
0x00, // pointFormatUncompressed
|
||||
}},
|
||||
&SessionTicketExtension{},
|
||||
sessionTicketExt,
|
||||
&ALPNExtension{AlpnProtocols: []string{"h2", "http/1.1"}},
|
||||
&StatusRequestExtension{},
|
||||
&SignatureAlgorithmsExtension{SupportedSignatureAlgorithms: []SignatureScheme{
|
||||
|
@ -271,7 +263,7 @@ func utlsIdToSpec(id ClientHelloID) (ClientHelloSpec, error) {
|
|||
&SupportedPointsExtension{SupportedPoints: []byte{
|
||||
0x00, // pointFormatUncompressed
|
||||
}},
|
||||
&SessionTicketExtension{},
|
||||
sessionTicketExt,
|
||||
&ALPNExtension{AlpnProtocols: []string{"h2", "http/1.1"}},
|
||||
&StatusRequestExtension{},
|
||||
&SignatureAlgorithmsExtension{SupportedSignatureAlgorithms: []SignatureScheme{
|
||||
|
@ -343,7 +335,7 @@ func utlsIdToSpec(id ClientHelloID) (ClientHelloSpec, error) {
|
|||
&SupportedPointsExtension{SupportedPoints: []byte{
|
||||
0x00, // pointFormatUncompressed
|
||||
}},
|
||||
&SessionTicketExtension{},
|
||||
sessionTicketExt,
|
||||
&ALPNExtension{AlpnProtocols: []string{"h2", "http/1.1"}},
|
||||
&StatusRequestExtension{},
|
||||
&SignatureAlgorithmsExtension{SupportedSignatureAlgorithms: []SignatureScheme{
|
||||
|
@ -415,7 +407,7 @@ func utlsIdToSpec(id ClientHelloID) (ClientHelloSpec, error) {
|
|||
&SupportedPointsExtension{SupportedPoints: []byte{
|
||||
0x00, // pointFormatUncompressed
|
||||
}},
|
||||
&SessionTicketExtension{},
|
||||
sessionTicketExt,
|
||||
&ALPNExtension{AlpnProtocols: []string{"h2", "http/1.1"}},
|
||||
&StatusRequestExtension{},
|
||||
&SignatureAlgorithmsExtension{SupportedSignatureAlgorithms: []SignatureScheme{
|
||||
|
@ -488,7 +480,7 @@ func utlsIdToSpec(id ClientHelloID) (ClientHelloSpec, error) {
|
|||
&SupportedPointsExtension{SupportedPoints: []byte{
|
||||
0x00, // pointFormatUncompressed
|
||||
}},
|
||||
&SessionTicketExtension{},
|
||||
sessionTicketExt,
|
||||
&ALPNExtension{AlpnProtocols: []string{"h2", "http/1.1"}},
|
||||
&StatusRequestExtension{},
|
||||
&SignatureAlgorithmsExtension{SupportedSignatureAlgorithms: []SignatureScheme{
|
||||
|
@ -559,7 +551,7 @@ func utlsIdToSpec(id ClientHelloID) (ClientHelloSpec, error) {
|
|||
&SupportedPointsExtension{SupportedPoints: []byte{
|
||||
0x00, // pointFormatUncompressed
|
||||
}},
|
||||
&SessionTicketExtension{},
|
||||
sessionTicketExt,
|
||||
&ALPNExtension{AlpnProtocols: []string{"h2", "http/1.1"}},
|
||||
&StatusRequestExtension{},
|
||||
&SignatureAlgorithmsExtension{SupportedSignatureAlgorithms: []SignatureScheme{
|
||||
|
@ -632,7 +624,7 @@ func utlsIdToSpec(id ClientHelloID) (ClientHelloSpec, error) {
|
|||
&SupportedPointsExtension{SupportedPoints: []byte{
|
||||
0x00, // pointFormatUncompressed
|
||||
}},
|
||||
&SessionTicketExtension{},
|
||||
sessionTicketExt,
|
||||
&ALPNExtension{AlpnProtocols: []string{"h2", "http/1.1"}},
|
||||
&StatusRequestExtension{},
|
||||
&SignatureAlgorithmsExtension{SupportedSignatureAlgorithms: []SignatureScheme{
|
||||
|
@ -695,7 +687,7 @@ func utlsIdToSpec(id ClientHelloID) (ClientHelloSpec, error) {
|
|||
&RenegotiationInfoExtension{Renegotiation: RenegotiateOnceAsClient},
|
||||
&SupportedCurvesExtension{[]CurveID{X25519, CurveP256, CurveP384, CurveP521}},
|
||||
&SupportedPointsExtension{SupportedPoints: []byte{pointFormatUncompressed}},
|
||||
&SessionTicketExtension{},
|
||||
sessionTicketExt,
|
||||
&ALPNExtension{AlpnProtocols: []string{"h2", "http/1.1"}},
|
||||
&StatusRequestExtension{},
|
||||
&SignatureAlgorithmsExtension{SupportedSignatureAlgorithms: []SignatureScheme{
|
||||
|
@ -757,7 +749,7 @@ func utlsIdToSpec(id ClientHelloID) (ClientHelloSpec, error) {
|
|||
&SupportedPointsExtension{SupportedPoints: []byte{
|
||||
pointFormatUncompressed,
|
||||
}},
|
||||
&SessionTicketExtension{},
|
||||
sessionTicketExt,
|
||||
&ALPNExtension{AlpnProtocols: []string{"h2", "http/1.1"}},
|
||||
&StatusRequestExtension{},
|
||||
&KeyShareExtension{[]KeyShare{
|
||||
|
@ -828,7 +820,7 @@ func utlsIdToSpec(id ClientHelloID) (ClientHelloSpec, error) {
|
|||
&SupportedPointsExtension{SupportedPoints: []byte{ //ec_point_formats
|
||||
pointFormatUncompressed,
|
||||
}},
|
||||
&SessionTicketExtension{},
|
||||
sessionTicketExt,
|
||||
&ALPNExtension{AlpnProtocols: []string{"h2", "http/1.1"}}, //application_layer_protocol_negotiation
|
||||
&StatusRequestExtension{},
|
||||
&FakeDelegatedCredentialsExtension{
|
||||
|
@ -909,7 +901,7 @@ func utlsIdToSpec(id ClientHelloID) (ClientHelloSpec, error) {
|
|||
&SupportedPointsExtension{SupportedPoints: []byte{ //ec_point_formats
|
||||
pointFormatUncompressed,
|
||||
}},
|
||||
&SessionTicketExtension{},
|
||||
sessionTicketExt,
|
||||
&ALPNExtension{AlpnProtocols: []string{"h2"}}, //application_layer_protocol_negotiation
|
||||
&StatusRequestExtension{},
|
||||
&FakeDelegatedCredentialsExtension{
|
||||
|
@ -994,7 +986,7 @@ func utlsIdToSpec(id ClientHelloID) (ClientHelloSpec, error) {
|
|||
0x0, // uncompressed
|
||||
},
|
||||
},
|
||||
&SessionTicketExtension{},
|
||||
sessionTicketExt,
|
||||
&ALPNExtension{
|
||||
AlpnProtocols: []string{
|
||||
"h2",
|
||||
|
@ -1426,7 +1418,7 @@ func utlsIdToSpec(id ClientHelloID) (ClientHelloSpec, error) {
|
|||
0x0, // pointFormatUncompressed
|
||||
},
|
||||
},
|
||||
&SessionTicketExtension{},
|
||||
sessionTicketExt,
|
||||
&ALPNExtension{
|
||||
AlpnProtocols: []string{
|
||||
"h2",
|
||||
|
@ -1531,7 +1523,7 @@ func utlsIdToSpec(id ClientHelloID) (ClientHelloSpec, error) {
|
|||
0x0, // uncompressed
|
||||
},
|
||||
},
|
||||
&SessionTicketExtension{},
|
||||
sessionTicketExt,
|
||||
&ALPNExtension{
|
||||
AlpnProtocols: []string{
|
||||
"h2",
|
||||
|
@ -1749,7 +1741,7 @@ func utlsIdToSpec(id ClientHelloID) (ClientHelloSpec, error) {
|
|||
0x0, // pointFormatUncompressed
|
||||
},
|
||||
},
|
||||
&SessionTicketExtension{},
|
||||
sessionTicketExt,
|
||||
&NPNExtension{},
|
||||
&ALPNExtension{
|
||||
AlpnProtocols: []string{
|
||||
|
@ -1823,7 +1815,7 @@ func utlsIdToSpec(id ClientHelloID) (ClientHelloSpec, error) {
|
|||
0x0, // uncompressed
|
||||
},
|
||||
},
|
||||
&SessionTicketExtension{},
|
||||
sessionTicketExt,
|
||||
&ALPNExtension{
|
||||
AlpnProtocols: []string{
|
||||
"h2",
|
||||
|
@ -1931,7 +1923,7 @@ func utlsIdToSpec(id ClientHelloID) (ClientHelloSpec, error) {
|
|||
0x0, // uncompressed
|
||||
},
|
||||
},
|
||||
&SessionTicketExtension{},
|
||||
sessionTicketExt,
|
||||
&ALPNExtension{
|
||||
AlpnProtocols: []string{
|
||||
"h2",
|
||||
|
@ -1995,24 +1987,6 @@ func utlsIdToSpec(id ClientHelloID) (ClientHelloSpec, error) {
|
|||
},
|
||||
},
|
||||
}, nil
|
||||
default:
|
||||
if id.Client == helloRandomized || id.Client == helloRandomizedALPN || id.Client == helloRandomizedNoALPN {
|
||||
// Use empty values as they can be filled later by UConn.ApplyPreset or manually.
|
||||
return generateRandomizedSpec(&id, "", nil, nil)
|
||||
}
|
||||
return ClientHelloSpec{}, fmt.Errorf("%w: %s", ErrUnknownClientHelloID, id.Str())
|
||||
}
|
||||
}
|
||||
|
||||
func utlsIdToSpecWithPSK(id ClientHelloID, pskExtension ...PreSharedKeyExtension) (ClientHelloSpec, error) {
|
||||
switch id {
|
||||
case HelloChrome_100_PSK, HelloChrome_112_PSK_Shuf, HelloChrome_114_Padding_PSK_Shuf, HelloChrome_115_PQ_PSK:
|
||||
if len(pskExtension) == 0 || pskExtension[0] == nil {
|
||||
return ClientHelloSpec{}, fmt.Errorf("%w: %s", ErrPSKExtensionExpected, id.Str())
|
||||
}
|
||||
}
|
||||
|
||||
switch id {
|
||||
case HelloChrome_100_PSK:
|
||||
return ClientHelloSpec{
|
||||
CipherSuites: []uint16{
|
||||
|
@ -2050,7 +2024,7 @@ func utlsIdToSpecWithPSK(id ClientHelloID, pskExtension ...PreSharedKeyExtension
|
|||
&SupportedPointsExtension{SupportedPoints: []byte{
|
||||
0x00, // pointFormatUncompressed
|
||||
}},
|
||||
&SessionTicketExtension{},
|
||||
sessionTicketExt,
|
||||
&ALPNExtension{AlpnProtocols: []string{"h2", "http/1.1"}},
|
||||
&StatusRequestExtension{},
|
||||
&SignatureAlgorithmsExtension{SupportedSignatureAlgorithms: []SignatureScheme{
|
||||
|
@ -2081,7 +2055,7 @@ func utlsIdToSpecWithPSK(id ClientHelloID, pskExtension ...PreSharedKeyExtension
|
|||
}},
|
||||
&ApplicationSettingsExtension{SupportedProtocols: []string{"h2"}},
|
||||
&UtlsGREASEExtension{},
|
||||
pskExtension[0],
|
||||
pskExt,
|
||||
},
|
||||
}, nil
|
||||
case HelloChrome_112_PSK_Shuf:
|
||||
|
@ -2121,7 +2095,7 @@ func utlsIdToSpecWithPSK(id ClientHelloID, pskExtension ...PreSharedKeyExtension
|
|||
&SupportedPointsExtension{SupportedPoints: []byte{
|
||||
0x00, // pointFormatUncompressed
|
||||
}},
|
||||
&SessionTicketExtension{},
|
||||
sessionTicketExt,
|
||||
&ALPNExtension{AlpnProtocols: []string{"h2", "http/1.1"}},
|
||||
&StatusRequestExtension{},
|
||||
&SignatureAlgorithmsExtension{SupportedSignatureAlgorithms: []SignatureScheme{
|
||||
|
@ -2152,7 +2126,7 @@ func utlsIdToSpecWithPSK(id ClientHelloID, pskExtension ...PreSharedKeyExtension
|
|||
}},
|
||||
&ApplicationSettingsExtension{SupportedProtocols: []string{"h2"}},
|
||||
&UtlsGREASEExtension{},
|
||||
pskExtension[0],
|
||||
pskExt,
|
||||
}),
|
||||
}, nil
|
||||
case HelloChrome_114_Padding_PSK_Shuf:
|
||||
|
@ -2192,7 +2166,7 @@ func utlsIdToSpecWithPSK(id ClientHelloID, pskExtension ...PreSharedKeyExtension
|
|||
&SupportedPointsExtension{SupportedPoints: []byte{
|
||||
0x00, // pointFormatUncompressed
|
||||
}},
|
||||
&SessionTicketExtension{},
|
||||
sessionTicketExt,
|
||||
&ALPNExtension{AlpnProtocols: []string{"h2", "http/1.1"}},
|
||||
&StatusRequestExtension{},
|
||||
&SignatureAlgorithmsExtension{SupportedSignatureAlgorithms: []SignatureScheme{
|
||||
|
@ -2224,7 +2198,7 @@ func utlsIdToSpecWithPSK(id ClientHelloID, pskExtension ...PreSharedKeyExtension
|
|||
&ApplicationSettingsExtension{SupportedProtocols: []string{"h2"}},
|
||||
&UtlsGREASEExtension{},
|
||||
&UtlsPaddingExtension{GetPaddingLen: BoringPaddingStyle},
|
||||
pskExtension[0],
|
||||
pskExt,
|
||||
}),
|
||||
}, nil
|
||||
// Chrome w/ Post-Quantum Key Agreement
|
||||
|
@ -2266,7 +2240,7 @@ func utlsIdToSpecWithPSK(id ClientHelloID, pskExtension ...PreSharedKeyExtension
|
|||
&SupportedPointsExtension{SupportedPoints: []byte{
|
||||
0x00, // pointFormatUncompressed
|
||||
}},
|
||||
&SessionTicketExtension{},
|
||||
sessionTicketExt,
|
||||
&ALPNExtension{AlpnProtocols: []string{"h2", "http/1.1"}},
|
||||
&StatusRequestExtension{},
|
||||
&SignatureAlgorithmsExtension{SupportedSignatureAlgorithms: []SignatureScheme{
|
||||
|
@ -2298,12 +2272,17 @@ func utlsIdToSpecWithPSK(id ClientHelloID, pskExtension ...PreSharedKeyExtension
|
|||
}},
|
||||
&ApplicationSettingsExtension{SupportedProtocols: []string{"h2"}},
|
||||
&UtlsGREASEExtension{},
|
||||
pskExtension[0],
|
||||
pskExt,
|
||||
}),
|
||||
}, nil
|
||||
}
|
||||
default:
|
||||
if id.Client == helloRandomized || id.Client == helloRandomizedALPN || id.Client == helloRandomizedNoALPN {
|
||||
// Use empty values as they can be filled later by UConn.ApplyPreset or manually.
|
||||
return generateRandomizedSpec(&id, "", nil)
|
||||
}
|
||||
|
||||
return ClientHelloSpec{}, fmt.Errorf("%w: %s", ErrUnknownClientHelloID, id.Str())
|
||||
return ClientHelloSpec{}, fmt.Errorf("%w: %s", ErrUnknownClientHelloID, id.Str())
|
||||
}
|
||||
}
|
||||
|
||||
// ShuffleChromeTLSExtensions shuffles the extensions in the ClientHelloSpec to avoid ossification.
|
||||
|
@ -2345,9 +2324,8 @@ func (uconn *UConn) applyPresetByID(id ClientHelloID) (err error) {
|
|||
}
|
||||
case helloCustom:
|
||||
return nil
|
||||
|
||||
default:
|
||||
spec, err = UTLSIdToSpec(id, uconn.pskExtension...)
|
||||
spec, err = UTLSIdToSpec(id, uconn.sessionController.pskExtension, uconn.sessionController.sessionTicketExt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -2379,7 +2357,6 @@ func (uconn *UConn) ApplyPreset(p *ClientHelloSpec) error {
|
|||
}
|
||||
uconn.HandshakeState.State13.KeySharesParams = NewKeySharesParameters()
|
||||
hello := uconn.HandshakeState.Hello
|
||||
session := uconn.HandshakeState.Session
|
||||
|
||||
switch len(hello.Random) {
|
||||
case 0:
|
||||
|
@ -2420,7 +2397,12 @@ func (uconn *UConn) ApplyPreset(p *ClientHelloSpec) error {
|
|||
hello.CipherSuites[i] = GetBoringGREASEValue(uconn.greaseSeed, ssl_grease_cipher)
|
||||
}
|
||||
}
|
||||
uconn.GetSessionID = p.GetSessionID
|
||||
var sessionID [32]byte
|
||||
_, err = io.ReadFull(uconn.config.rand(), sessionID[:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
uconn.HandshakeState.Hello.SessionId = sessionID[:]
|
||||
uconn.Extensions = make([]TLSExtension, len(p.Extensions))
|
||||
copy(uconn.Extensions, p.Extensions)
|
||||
|
||||
|
@ -2445,23 +2427,6 @@ func (uconn *UConn) ApplyPreset(p *ClientHelloSpec) error {
|
|||
return errors.New("at most 2 grease extensions are supported")
|
||||
}
|
||||
grease_extensions_seen += 1
|
||||
case *SessionTicketExtension:
|
||||
var cs *ClientSessionState
|
||||
if session == nil && uconn.config.ClientSessionCache != nil {
|
||||
cacheKey := uconn.clientSessionCacheKey()
|
||||
cs, _ = uconn.config.ClientSessionCache.Get(cacheKey)
|
||||
if cs != nil {
|
||||
session = cs.session
|
||||
}
|
||||
}
|
||||
// TLS 1.3 (PSK) resumption is handled by PreSharedKeyExtension in MarshalClientHello()
|
||||
if session != nil && session.version == VersionTLS13 {
|
||||
break
|
||||
}
|
||||
err := uconn.SetSessionState(cs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
case *SupportedCurvesExtension:
|
||||
for i := range ext.Curves {
|
||||
if isGREASEUint16(uint16(ext.Curves[i])) {
|
||||
|
@ -2528,22 +2493,18 @@ func (uconn *UConn) ApplyPreset(p *ClientHelloSpec) error {
|
|||
// but NextProtos is also used by ALPN and our spec nmay not actually have a NPN extension
|
||||
hello.NextProtoNeg = haveNPN
|
||||
|
||||
uconn.sessionController.checkSessionExt()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (uconn *UConn) generateRandomizedSpec() (ClientHelloSpec, error) {
|
||||
css := &ClientSessionState{
|
||||
session: uconn.HandshakeState.Session,
|
||||
ticket: uconn.HandshakeState.Hello.SessionTicket,
|
||||
}
|
||||
|
||||
return generateRandomizedSpec(&uconn.ClientHelloID, uconn.serverName, css, uconn.config.NextProtos)
|
||||
return generateRandomizedSpec(&uconn.ClientHelloID, uconn.serverName, uconn.config.NextProtos)
|
||||
}
|
||||
|
||||
func generateRandomizedSpec(
|
||||
id *ClientHelloID,
|
||||
serverName string,
|
||||
session *ClientSessionState,
|
||||
nextProtos []string,
|
||||
) (ClientHelloSpec, error) {
|
||||
p := ClientHelloSpec{}
|
||||
|
@ -2609,7 +2570,7 @@ func generateRandomizedSpec(
|
|||
p.CipherSuites = removeRandomCiphers(r, shuffledSuites, id.Weights.CipherSuites_Remove_RandomCiphers)
|
||||
|
||||
sni := SNIExtension{serverName}
|
||||
sessionTicket := SessionTicketExtension{Session: session}
|
||||
sessionTicket := SessionTicketExtension{}
|
||||
|
||||
sigAndHashAlgos := []SignatureScheme{
|
||||
ECDSAWithP256AndSHA256,
|
||||
|
|
|
@ -8,26 +8,38 @@ import (
|
|||
"golang.org/x/crypto/cryptobyte"
|
||||
)
|
||||
|
||||
type PreSharedKeyCommon struct {
|
||||
Identities []PskIdentity
|
||||
Binders [][]byte
|
||||
BinderKey []byte // this will be used to compute the binder when hello message is ready
|
||||
EarlySecret []byte
|
||||
Session *SessionState
|
||||
}
|
||||
|
||||
type PreSharedKeyExtension interface {
|
||||
// TLSExtension must be implemented by all PreSharedKeyExtension implementations.
|
||||
// However, the Read() method should return an error since it MUST NOT be used
|
||||
// for PreSharedKeyExtension.
|
||||
TLSExtension
|
||||
|
||||
IsInitialized() bool
|
||||
|
||||
InitializeByUtls(session *SessionState, earlySecret []byte, binderKey []byte, identities []PskIdentity)
|
||||
|
||||
// GetBinders returns the binders that were computed during the handshake
|
||||
// to be set in the internal copy of the ClientHello. Only needed if expecting
|
||||
// to resume the session.
|
||||
//
|
||||
// FakePreSharedKeyExtension MUST return nil to make sure utls DOES NOT
|
||||
// try to do any session resumption.
|
||||
GetBinders() [][]byte
|
||||
GetPreSharedKeyCommon() PreSharedKeyCommon
|
||||
|
||||
// ReadWithRawHello is used to read the extension from the ClientHello
|
||||
// instead of Read(), where the latter is used to read all other extensions.
|
||||
//
|
||||
// This is needed because the PSK extension needs to calculate the binder
|
||||
// based on all previous parts of the ClientHello.
|
||||
ReadWithRawHello(raw, b []byte) (int, error)
|
||||
PatchBuiltHello(hello *PubClientHelloMsg) error
|
||||
|
||||
mustEmbedUnimplementedPreSharedKeyExtension() // this works like a type guard
|
||||
}
|
||||
|
@ -36,8 +48,16 @@ type UnimplementedPreSharedKeyExtension struct{}
|
|||
|
||||
func (UnimplementedPreSharedKeyExtension) mustEmbedUnimplementedPreSharedKeyExtension() {}
|
||||
|
||||
func (*UnimplementedPreSharedKeyExtension) IsInitialized() bool {
|
||||
panic("tls: IsInitialized is not implemented for the PreSharedKeyExtension")
|
||||
}
|
||||
|
||||
func (*UnimplementedPreSharedKeyExtension) InitializeByUtls(session *SessionState, earlySecret []byte, binderKey []byte, identities []PskIdentity) {
|
||||
panic("tls: Initialize is not implemented for the PreSharedKeyExtension")
|
||||
}
|
||||
|
||||
func (*UnimplementedPreSharedKeyExtension) writeToUConn(*UConn) error {
|
||||
return errors.New("tls: writeToUConn is not implemented for the PreSharedKeyExtension")
|
||||
panic("tls: writeToUConn is not implemented for the PreSharedKeyExtension")
|
||||
}
|
||||
|
||||
func (*UnimplementedPreSharedKeyExtension) Len() int {
|
||||
|
@ -45,108 +65,120 @@ func (*UnimplementedPreSharedKeyExtension) Len() int {
|
|||
}
|
||||
|
||||
func (*UnimplementedPreSharedKeyExtension) Read([]byte) (int, error) {
|
||||
return 0, errors.New("tls: Read is not implemented for the PreSharedKeyExtension")
|
||||
panic("tls: Read is not implemented for the PreSharedKeyExtension")
|
||||
}
|
||||
|
||||
func (*UnimplementedPreSharedKeyExtension) GetBinders() [][]byte {
|
||||
panic("tls: Binders is not implemented for the PreSharedKeyExtension")
|
||||
func (*UnimplementedPreSharedKeyExtension) GetPreSharedKeyCommon() PreSharedKeyCommon {
|
||||
panic("tls: GetPreSharedKeyCommon is not implemented for the PreSharedKeyExtension")
|
||||
}
|
||||
|
||||
func (*UnimplementedPreSharedKeyExtension) ReadWithRawHello(raw, b []byte) (int, error) {
|
||||
return 0, errors.New("tls: ReadWithRawHello is not implemented for the PreSharedKeyExtension")
|
||||
func (*UnimplementedPreSharedKeyExtension) PatchBuiltHello(hello *PubClientHelloMsg) error {
|
||||
panic("tls: ReadWithRawHello is not implemented for the PreSharedKeyExtension")
|
||||
}
|
||||
|
||||
// UtlsPreSharedKeyExtension is an extension used to set the PSK extension in the
|
||||
// ClientHello.
|
||||
type UtlsPreSharedKeyExtension struct {
|
||||
UnimplementedPreSharedKeyExtension
|
||||
PreSharedKeyCommon
|
||||
cipherSuite *cipherSuiteTLS13
|
||||
cachedLength *int
|
||||
}
|
||||
|
||||
identities []pskIdentity
|
||||
binders [][]byte
|
||||
binderKey []byte // this will be used to compute the binder when hello message is ready
|
||||
cipherSuite *cipherSuiteTLS13
|
||||
earlySecret []byte
|
||||
func (e *UtlsPreSharedKeyExtension) IsInitialized() bool {
|
||||
return e.Session != nil
|
||||
}
|
||||
|
||||
func (e *UtlsPreSharedKeyExtension) InitializeByUtls(session *SessionState, earlySecret []byte, binderKey []byte, identities []PskIdentity) {
|
||||
e.Session = session
|
||||
e.EarlySecret = earlySecret
|
||||
e.BinderKey = binderKey
|
||||
e.cipherSuite = cipherSuiteTLS13ByID(e.Session.cipherSuite)
|
||||
e.Identities = identities
|
||||
e.Binders = make([][]byte, 0, len(e.Identities))
|
||||
for i := 0; i < len(e.Identities); i++ {
|
||||
e.Binders = append(e.Binders, make([]byte, e.cipherSuite.hash.Size()))
|
||||
}
|
||||
}
|
||||
|
||||
func (e *UtlsPreSharedKeyExtension) writeToUConn(uc *UConn) error {
|
||||
err := e.preloadSession(uc)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
uc.HandshakeState.Hello.PskIdentities = pskIdentities(e.identities).ToPublic()
|
||||
// uc.HandshakeState.Hello.PskBinders = e.binders
|
||||
// uc.HandshakeState.Hello = hello.getPublicPtr() // write back to public hello
|
||||
// uc.HandshakeState.State13.EarlySecret = e.earlySecret
|
||||
// uc.HandshakeState.State13.BinderKey = e.binderKey
|
||||
|
||||
uc.HandshakeState.Hello.TicketSupported = true // This doesn't matter though, as utls doesn't care about this field. We write this for consistency.
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *UtlsPreSharedKeyExtension) Len() int {
|
||||
func (e *UtlsPreSharedKeyExtension) GetPreSharedKeyCommon() PreSharedKeyCommon {
|
||||
return e.PreSharedKeyCommon
|
||||
}
|
||||
|
||||
func pskExtLen(identities []PskIdentity, binders [][]byte) int {
|
||||
if len(identities) == 0 || len(binders) == 0 {
|
||||
return 0
|
||||
}
|
||||
length := 4 // extension type + extension length
|
||||
length += 2 // identities length
|
||||
for _, identity := range e.identities {
|
||||
length += 2 + len(identity.label) + 4 // identity length + identity + obfuscated ticket age
|
||||
for _, identity := range identities {
|
||||
length += 2 + len(identity.Label) + 4 // identity length + identity + obfuscated ticket age
|
||||
}
|
||||
length += 2 // binders length
|
||||
for _, binder := range e.binders {
|
||||
length += len(binder) + 1 // binder length + binder
|
||||
for _, binder := range binders {
|
||||
length += len(binder) + 1
|
||||
}
|
||||
return length
|
||||
}
|
||||
|
||||
func (e *UtlsPreSharedKeyExtension) Read(b []byte) (int, error) {
|
||||
return 0, errors.New("tls: PreSharedKeyExtension shouldn't be read, use ReadWithRawHello() instead")
|
||||
func (e *UtlsPreSharedKeyExtension) Len() int {
|
||||
if e.Session == nil {
|
||||
return 0
|
||||
}
|
||||
if e.cachedLength != nil {
|
||||
return *e.cachedLength
|
||||
}
|
||||
length := pskExtLen(e.Identities, e.Binders)
|
||||
e.cachedLength = &length
|
||||
return length
|
||||
}
|
||||
|
||||
// Binders must be called after ReadWithRawHello
|
||||
func (e *UtlsPreSharedKeyExtension) GetBinders() [][]byte {
|
||||
return e.binders
|
||||
}
|
||||
|
||||
func (e *UtlsPreSharedKeyExtension) ReadWithRawHello(raw, b []byte) (int, error) {
|
||||
if len(b) < e.Len() {
|
||||
func readPskIntoBytes(b []byte, identities []PskIdentity, binders [][]byte) (int, error) {
|
||||
extLen := pskExtLen(identities, binders)
|
||||
if extLen == 0 {
|
||||
return 0, io.EOF
|
||||
}
|
||||
if len(b) < extLen {
|
||||
return 0, io.ErrShortBuffer
|
||||
}
|
||||
|
||||
b[0] = byte(extensionPreSharedKey >> 8)
|
||||
b[1] = byte(extensionPreSharedKey)
|
||||
b[2] = byte((e.Len() - 4) >> 8)
|
||||
b[3] = byte(e.Len() - 4)
|
||||
b[2] = byte((extLen - 4) >> 8)
|
||||
b[3] = byte(extLen - 4)
|
||||
|
||||
// identities length
|
||||
identitiesLength := 0
|
||||
for _, identity := range e.identities {
|
||||
identitiesLength += 2 + len(identity.label) + 4 // identity length + identity + obfuscated ticket age
|
||||
for _, identity := range identities {
|
||||
identitiesLength += 2 + len(identity.Label) + 4 // identity length + identity + obfuscated ticket age
|
||||
}
|
||||
b[4] = byte(identitiesLength >> 8)
|
||||
b[5] = byte(identitiesLength)
|
||||
|
||||
// identities
|
||||
offset := 6
|
||||
for _, identity := range e.identities {
|
||||
b[offset] = byte(len(identity.label) >> 8)
|
||||
b[offset+1] = byte(len(identity.label))
|
||||
for _, identity := range identities {
|
||||
b[offset] = byte(len(identity.Label) >> 8)
|
||||
b[offset+1] = byte(len(identity.Label))
|
||||
offset += 2
|
||||
copy(b[offset:], identity.label)
|
||||
offset += len(identity.label)
|
||||
b[offset] = byte(identity.obfuscatedTicketAge >> 24)
|
||||
b[offset+1] = byte(identity.obfuscatedTicketAge >> 16)
|
||||
b[offset+2] = byte(identity.obfuscatedTicketAge >> 8)
|
||||
b[offset+3] = byte(identity.obfuscatedTicketAge)
|
||||
copy(b[offset:], identity.Label)
|
||||
offset += len(identity.Label)
|
||||
b[offset] = byte(identity.ObfuscatedTicketAge >> 24)
|
||||
b[offset+1] = byte(identity.ObfuscatedTicketAge >> 16)
|
||||
b[offset+2] = byte(identity.ObfuscatedTicketAge >> 8)
|
||||
b[offset+3] = byte(identity.ObfuscatedTicketAge)
|
||||
offset += 4
|
||||
}
|
||||
|
||||
// concatenate ClientHello and PreSharedKeyExtension
|
||||
rawHelloSoFar := append(raw, b[:offset]...)
|
||||
transcript := e.cipherSuite.hash.New()
|
||||
transcript.Write(rawHelloSoFar)
|
||||
e.binders = [][]byte{e.cipherSuite.finishedHash(e.binderKey, transcript)}
|
||||
|
||||
// binders length
|
||||
bindersLength := 0
|
||||
for _, binder := range e.binders {
|
||||
for _, binder := range binders {
|
||||
// check if binder size is valid
|
||||
bindersLength += len(binder) + 1 // binder length + binder
|
||||
}
|
||||
b[offset] = byte(bindersLength >> 8)
|
||||
|
@ -154,39 +186,49 @@ func (e *UtlsPreSharedKeyExtension) ReadWithRawHello(raw, b []byte) (int, error)
|
|||
offset += 2
|
||||
|
||||
// binders
|
||||
for _, binder := range e.binders {
|
||||
for _, binder := range binders {
|
||||
b[offset] = byte(len(binder))
|
||||
offset++
|
||||
copy(b[offset:], binder)
|
||||
offset += len(binder)
|
||||
}
|
||||
|
||||
return e.Len(), io.EOF
|
||||
return extLen, io.EOF
|
||||
}
|
||||
|
||||
func (e *UtlsPreSharedKeyExtension) preloadSession(uc *UConn) error {
|
||||
// load Hello
|
||||
hello := uc.HandshakeState.Hello.getPrivatePtr()
|
||||
// try to use loadSession()
|
||||
session, earlySecret, binderKey, err := uc.loadSession(hello)
|
||||
func (e *UtlsPreSharedKeyExtension) Read(b []byte) (int, error) {
|
||||
return readPskIntoBytes(b, e.Identities, e.Binders)
|
||||
}
|
||||
|
||||
func (e *UtlsPreSharedKeyExtension) PatchBuiltHello(hello *PubClientHelloMsg) error {
|
||||
if e.Len() == 0 {
|
||||
return nil
|
||||
}
|
||||
private := hello.getCachedPrivatePtr()
|
||||
if private == nil {
|
||||
private = hello.getPrivatePtr()
|
||||
}
|
||||
private.raw = hello.Raw
|
||||
private.pskBinders = e.Binders // set the placeholder to the private Hello
|
||||
|
||||
//--- mirror loadSession() begin ---//
|
||||
transcript := e.cipherSuite.hash.New()
|
||||
helloBytes, err := private.marshalWithoutBinders() // no marshal() will be actually called, as we have set the field `raw`
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if session != nil && session.version == VersionTLS13 && binderKey != nil {
|
||||
e.identities = hello.pskIdentities
|
||||
e.binders = hello.pskBinders
|
||||
e.binderKey = binderKey
|
||||
e.cipherSuite = cipherSuiteTLS13ByID(session.cipherSuite)
|
||||
e.earlySecret = earlySecret
|
||||
} else if session == nil {
|
||||
return errors.New("tls: session not found.")
|
||||
} else if session.version != VersionTLS13 {
|
||||
return errors.New("tls: session is not for TLS 1.3.")
|
||||
} else if binderKey == nil {
|
||||
return errors.New("tls: binder key not found.")
|
||||
}
|
||||
transcript.Write(helloBytes)
|
||||
pskBinders := [][]byte{e.cipherSuite.finishedHash(e.BinderKey, transcript)}
|
||||
|
||||
return nil
|
||||
if err := private.updateBinders(pskBinders); err != nil {
|
||||
return err
|
||||
}
|
||||
//--- mirror loadSession() end ---//
|
||||
e.Binders = pskBinders
|
||||
|
||||
// no need to care about other PSK related fields, they will be handled separately
|
||||
|
||||
return io.EOF
|
||||
}
|
||||
|
||||
func (e *UtlsPreSharedKeyExtension) Write(b []byte) (int, error) {
|
||||
|
@ -212,6 +254,14 @@ type FakePreSharedKeyExtension struct {
|
|||
Binders [][]byte `json:"binders"`
|
||||
}
|
||||
|
||||
func (e *FakePreSharedKeyExtension) IsInitialized() bool {
|
||||
return e.Identities != nil && e.Binders != nil
|
||||
}
|
||||
|
||||
func (e *FakePreSharedKeyExtension) InitializeByUtls(session *SessionState, earlySecret []byte, binderKey []byte, identities []PskIdentity) {
|
||||
panic("InitializeByUtls failed: don't let utls initialize FakePreSharedKeyExtension; provide your own identities and binders or use UtlsPreSharedKeyExtension")
|
||||
}
|
||||
|
||||
func (e *FakePreSharedKeyExtension) writeToUConn(uc *UConn) error {
|
||||
if uc.config.ClientSessionCache == nil {
|
||||
return nil // don't write the extension if there is no session cache
|
||||
|
@ -225,85 +275,33 @@ func (e *FakePreSharedKeyExtension) writeToUConn(uc *UConn) error {
|
|||
}
|
||||
|
||||
func (e *FakePreSharedKeyExtension) Len() int {
|
||||
length := 4 // extension type + extension length
|
||||
length += 2 // identities length
|
||||
for _, identity := range e.Identities {
|
||||
length += 2 + len(identity.Label) + 4 // identity length + identity + obfuscated ticket age
|
||||
}
|
||||
length += 2 // binders length
|
||||
for _, binder := range e.Binders {
|
||||
length += len(binder)
|
||||
}
|
||||
return length
|
||||
return pskExtLen(e.Identities, e.Binders)
|
||||
}
|
||||
|
||||
func (e *FakePreSharedKeyExtension) Read(b []byte) (int, error) {
|
||||
return 0, errors.New("tls: PreSharedKeyExtension shouldn't be read, use ReadWithRawHello() instead")
|
||||
}
|
||||
|
||||
func (e *FakePreSharedKeyExtension) GetBinders() [][]byte {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *FakePreSharedKeyExtension) ReadWithRawHello(raw, b []byte) (int, error) {
|
||||
if len(b) < e.Len() {
|
||||
return 0, io.ErrShortBuffer
|
||||
}
|
||||
|
||||
b[0] = byte(extensionPreSharedKey >> 8)
|
||||
b[1] = byte(extensionPreSharedKey)
|
||||
b[2] = byte((e.Len() - 4) >> 8)
|
||||
b[3] = byte(e.Len() - 4)
|
||||
|
||||
// identities length
|
||||
identitiesLength := 0
|
||||
for _, identity := range e.Identities {
|
||||
identitiesLength += 2 + len(identity.Label) + 4 // identity length + identity + obfuscated ticket age
|
||||
}
|
||||
b[4] = byte(identitiesLength >> 8)
|
||||
b[5] = byte(identitiesLength)
|
||||
|
||||
// identities
|
||||
offset := 6
|
||||
for _, identity := range e.Identities {
|
||||
b[offset] = byte(len(identity.Label) >> 8)
|
||||
b[offset+1] = byte(len(identity.Label))
|
||||
offset += 2
|
||||
copy(b[offset:], identity.Label)
|
||||
offset += len(identity.Label)
|
||||
b[offset] = byte(identity.ObfuscatedTicketAge >> 24)
|
||||
b[offset+1] = byte(identity.ObfuscatedTicketAge >> 16)
|
||||
b[offset+2] = byte(identity.ObfuscatedTicketAge >> 8)
|
||||
b[offset+3] = byte(identity.ObfuscatedTicketAge)
|
||||
offset += 4
|
||||
}
|
||||
|
||||
// binders length
|
||||
bindersLength := 0
|
||||
LOOP_BINDERS:
|
||||
for _, binder := range e.Binders {
|
||||
// check if binder size is valid
|
||||
for _, cipherSuite := range cipherSuitesTLS13 {
|
||||
if len(binder) == cipherSuite.hash.Size() {
|
||||
bindersLength += len(binder) + 1 // binder length + binder
|
||||
continue LOOP_BINDERS
|
||||
}
|
||||
for _, b := range e.Binders {
|
||||
if !(anyTrue(validHashLen, func(valid *int) bool {
|
||||
return len(b) == *valid
|
||||
})) {
|
||||
return 0, errors.New("tls: FakePreSharedKeyExtension.Read failed: invalid binder size")
|
||||
}
|
||||
return 0, errors.New("tls: invalid binder size")
|
||||
}
|
||||
b[offset] = byte(bindersLength >> 8)
|
||||
b[offset+1] = byte(bindersLength)
|
||||
offset += 2
|
||||
return readPskIntoBytes(b, e.Identities, e.Binders)
|
||||
}
|
||||
|
||||
// binders
|
||||
for _, binder := range e.Binders {
|
||||
b[offset] = byte(len(binder))
|
||||
offset++
|
||||
copy(b[offset:], binder)
|
||||
offset += len(binder)
|
||||
func (e *FakePreSharedKeyExtension) GetPreSharedKeyCommon() PreSharedKeyCommon {
|
||||
return PreSharedKeyCommon{
|
||||
Identities: e.Identities,
|
||||
Binders: e.Binders,
|
||||
}
|
||||
}
|
||||
|
||||
return e.Len(), io.EOF
|
||||
var validHashLen = mapSlice(cipherSuitesTLS13, func(c *cipherSuiteTLS13) int {
|
||||
return c.hash.Size()
|
||||
})
|
||||
|
||||
func (*FakePreSharedKeyExtension) PatchBuiltHello(*PubClientHelloMsg) error {
|
||||
return nil // no need to patch the hello since we don't need to update binders
|
||||
}
|
||||
|
||||
func (e *FakePreSharedKeyExtension) Write(b []byte) (n int, err error) {
|
||||
|
|
19
u_public.go
19
u_public.go
|
@ -45,7 +45,7 @@ type TLS13OnlyState struct {
|
|||
EarlySecret []byte
|
||||
BinderKey []byte
|
||||
CertReq *CertificateRequestMsgTLS13
|
||||
UsingPSK bool
|
||||
UsingPSK bool // don't set this field when building client hello
|
||||
SentDummyCCS bool
|
||||
Transcript hash.Hash
|
||||
TrafficSecret []byte // client_application_traffic_secret_0
|
||||
|
@ -251,7 +251,7 @@ type PubServerHelloMsg struct {
|
|||
OcspStapling bool
|
||||
Scts [][]byte
|
||||
ExtendedMasterSecret bool
|
||||
TicketSupported bool
|
||||
TicketSupported bool // used by go tls to determine whether to add the session ticket ext
|
||||
SecureRenegotiation []byte
|
||||
SecureRenegotiationSupported bool
|
||||
AlpnProtocol string
|
||||
|
@ -357,13 +357,15 @@ type PubClientHelloMsg struct {
|
|||
PskIdentities []PskIdentity
|
||||
PskBinders [][]byte
|
||||
QuicTransportParameters []byte
|
||||
|
||||
cachedPrivateHello *clientHelloMsg // todo: further optimize to reduce clientHelloMsg construction
|
||||
}
|
||||
|
||||
func (chm *PubClientHelloMsg) getPrivatePtr() *clientHelloMsg {
|
||||
if chm == nil {
|
||||
return nil
|
||||
} else {
|
||||
return &clientHelloMsg{
|
||||
private := &clientHelloMsg{
|
||||
raw: chm.Raw,
|
||||
vers: chm.Vers,
|
||||
random: chm.Random,
|
||||
|
@ -395,6 +397,16 @@ func (chm *PubClientHelloMsg) getPrivatePtr() *clientHelloMsg {
|
|||
|
||||
nextProtoNeg: chm.NextProtoNeg,
|
||||
}
|
||||
chm.cachedPrivateHello = private
|
||||
return private
|
||||
}
|
||||
}
|
||||
|
||||
func (chm *PubClientHelloMsg) getCachedPrivatePtr() *clientHelloMsg {
|
||||
if chm == nil {
|
||||
return nil
|
||||
} else {
|
||||
return chm.cachedPrivateHello
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -432,6 +444,7 @@ func (chm *clientHelloMsg) getPublicPtr() *PubClientHelloMsg {
|
|||
PskIdentities: pskIdentities(chm.pskIdentities).ToPublic(),
|
||||
PskBinders: chm.pskBinders,
|
||||
QuicTransportParameters: chm.quicTransportParameters,
|
||||
cachedPrivateHello: chm,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
272
u_session_controller.go
Normal file
272
u_session_controller.go
Normal file
|
@ -0,0 +1,272 @@
|
|||
package tls
|
||||
|
||||
import "fmt"
|
||||
|
||||
type LoadSessionTrackerState int
|
||||
|
||||
const NeverCalled LoadSessionTrackerState = 0
|
||||
const UtlsAboutToCall LoadSessionTrackerState = 3
|
||||
const CalledByULoadSession LoadSessionTrackerState = 1
|
||||
const CalledByGoTLS LoadSessionTrackerState = 2
|
||||
|
||||
type sessionState int
|
||||
|
||||
const NoSession sessionState = 0
|
||||
const TicketInitialized sessionState = 1
|
||||
const TicketAllSet sessionState = 4
|
||||
const PskExtInitialized sessionState = 2
|
||||
const PskAllSet sessionState = 3
|
||||
|
||||
// sessionController is responsible for all session related
|
||||
type sessionController struct {
|
||||
sessionTicketExt *SessionTicketExtension
|
||||
pskExtension PreSharedKeyExtension
|
||||
uconnRef *UConn
|
||||
state sessionState
|
||||
loadSessionTracker LoadSessionTrackerState
|
||||
callingLoadSession bool
|
||||
locked bool
|
||||
}
|
||||
|
||||
type shouldLoadSessionResult int
|
||||
|
||||
const shouldReturn shouldLoadSessionResult = 0
|
||||
const shouldSetTicket shouldLoadSessionResult = 1
|
||||
const shouldSetPsk shouldLoadSessionResult = 2
|
||||
const shouldLoad shouldLoadSessionResult = 3
|
||||
|
||||
func newSessionController(uconn *UConn) *sessionController {
|
||||
return &sessionController{
|
||||
uconnRef: uconn,
|
||||
sessionTicketExt: &SessionTicketExtension{},
|
||||
pskExtension: &UtlsPreSharedKeyExtension{},
|
||||
state: NoSession,
|
||||
locked: false,
|
||||
callingLoadSession: false,
|
||||
loadSessionTracker: NeverCalled,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *sessionController) isSessionLocked() bool {
|
||||
return s.locked
|
||||
}
|
||||
|
||||
func (s *sessionController) shouldLoadSession() shouldLoadSessionResult {
|
||||
if s.sessionTicketExt == nil && s.pskExtension == nil || s.uconnRef.clientHelloBuildStatus != NotBuilt {
|
||||
fmt.Println("uLoadSession s.sessionTicketExt == nil && s.pskExtension == nil")
|
||||
// There's no need to load session since we don't have the related extensions.
|
||||
return shouldReturn
|
||||
}
|
||||
if s.state == TicketInitialized {
|
||||
return shouldSetTicket
|
||||
}
|
||||
if s.state == PskExtInitialized {
|
||||
return shouldSetPsk
|
||||
}
|
||||
return shouldLoad
|
||||
}
|
||||
|
||||
func (s *sessionController) aboutToLoadSession() {
|
||||
uAssert(s.state == NoSession && !s.locked, "tls: aboutToLoadSession failed: must only load session when the session of the client hello is not locked and when there's currently no session")
|
||||
s.loadSessionTracker = UtlsAboutToCall
|
||||
}
|
||||
|
||||
func (s *sessionController) commonCheck(failureMsg string, params ...any) {
|
||||
if s.uconnRef.clientHelloBuildStatus != NotBuilt {
|
||||
panic(failureMsg + ": we can't modify the session after the clientHello is built")
|
||||
}
|
||||
if s.state != NoSession {
|
||||
panic(failureMsg + ": the session already set")
|
||||
}
|
||||
panicOnNil(failureMsg, params...)
|
||||
}
|
||||
|
||||
func (s *sessionController) finalCheck() {
|
||||
uAssert(s.state == PskAllSet || s.state == TicketAllSet || s.state == NoSession, "tls: SessionController.finalCheck failed: the session is half set")
|
||||
s.locked = true
|
||||
}
|
||||
|
||||
func (s *sessionController) initSessionTicketExt(session *SessionState, ticket []byte) {
|
||||
s.commonCheck("tls: initSessionTicket failed", s.sessionTicketExt, session, ticket)
|
||||
s.sessionTicketExt.Session = session
|
||||
s.sessionTicketExt.Ticket = ticket
|
||||
s.state = TicketInitialized
|
||||
}
|
||||
|
||||
func (s *sessionController) setSessionTicketToUConn() {
|
||||
uAssert(s.sessionTicketExt != nil && s.state == TicketInitialized, "tls: setSessionTicketExt failed: invalid state")
|
||||
s.uconnRef.HandshakeState.Session = s.sessionTicketExt.Session
|
||||
s.uconnRef.HandshakeState.Hello.SessionTicket = s.sessionTicketExt.Ticket
|
||||
s.state = TicketAllSet
|
||||
}
|
||||
|
||||
func mapSlice[T any, U any](slice []T, transform func(T) U) []U {
|
||||
newSlice := make([]U, 0, len(slice))
|
||||
for _, t := range slice {
|
||||
newSlice = append(newSlice, transform(t))
|
||||
}
|
||||
return newSlice
|
||||
}
|
||||
|
||||
func (s *sessionController) initPsk(session *SessionState, earlySecret []byte, binderKey []byte, pskIdentities []pskIdentity) {
|
||||
s.commonCheck("tls: initPsk failed", s.pskExtension, session, earlySecret, pskIdentities)
|
||||
uAssert(!s.pskExtension.IsInitialized(), "tls: initPsk failed: the psk extension is already initialized")
|
||||
|
||||
publicPskIdentities := mapSlice(pskIdentities, func(private pskIdentity) PskIdentity {
|
||||
return PskIdentity{
|
||||
Label: private.label,
|
||||
ObfuscatedTicketAge: private.obfuscatedTicketAge,
|
||||
}
|
||||
})
|
||||
s.pskExtension.InitializeByUtls(session, earlySecret, binderKey, publicPskIdentities)
|
||||
uAssert(s.pskExtension.IsInitialized(), "the psk extension is not initialized after initialization")
|
||||
s.uconnRef.HandshakeState.State13.BinderKey = binderKey
|
||||
s.uconnRef.HandshakeState.State13.EarlySecret = earlySecret
|
||||
s.uconnRef.HandshakeState.Session = session
|
||||
s.uconnRef.HandshakeState.Hello.PskIdentities = publicPskIdentities
|
||||
// binders are not expected to be available at this point
|
||||
s.state = PskExtInitialized
|
||||
}
|
||||
|
||||
func (s *sessionController) setPsk() {
|
||||
uAssert(s.pskExtension != nil && (s.state == PskExtInitialized || s.state == PskAllSet), "tls: setPsk failed: invalid state")
|
||||
pskCommon := s.pskExtension.GetPreSharedKeyCommon()
|
||||
if s.state == PskExtInitialized {
|
||||
s.uconnRef.HandshakeState.State13.EarlySecret = pskCommon.EarlySecret
|
||||
s.uconnRef.HandshakeState.Session = pskCommon.Session
|
||||
s.uconnRef.HandshakeState.Hello.PskIdentities = pskCommon.Identities
|
||||
s.uconnRef.HandshakeState.Hello.PskBinders = pskCommon.Binders
|
||||
} else if s.state == PskAllSet {
|
||||
uAssert(sliceEq([]any{
|
||||
s.uconnRef.HandshakeState.State13.EarlySecret,
|
||||
s.uconnRef.HandshakeState.Session,
|
||||
s.uconnRef.HandshakeState.Hello.PskIdentities,
|
||||
s.uconnRef.HandshakeState.Hello.PskBinders,
|
||||
}, []any{
|
||||
pskCommon.EarlySecret,
|
||||
pskCommon.Session,
|
||||
pskCommon.Identities,
|
||||
pskCommon.Binders,
|
||||
}), "setPsk failed: only binders are allowed to change on state `PskAllSet`")
|
||||
}
|
||||
s.uconnRef.HandshakeState.State13.BinderKey = pskCommon.BinderKey
|
||||
s.state = PskAllSet
|
||||
}
|
||||
|
||||
func (s *sessionController) shouldUpdateBinders() bool {
|
||||
if s.pskExtension == nil {
|
||||
return false
|
||||
}
|
||||
return s.state == PskExtInitialized || s.state == PskAllSet
|
||||
}
|
||||
|
||||
func (s *sessionController) updateBinders() {
|
||||
uAssert(s.shouldUpdateBinders(), "tls: updateBinders failed: shouldn't update binders")
|
||||
s.pskExtension.PatchBuiltHello(s.uconnRef.HandshakeState.Hello)
|
||||
}
|
||||
|
||||
func (s *sessionController) overridePskExt(psk PreSharedKeyExtension) error {
|
||||
if s.state != NoSession {
|
||||
return fmt.Errorf("SetSessionState13 failed: there's already a session")
|
||||
}
|
||||
s.pskExtension = psk
|
||||
if psk.IsInitialized() {
|
||||
s.state = PskExtInitialized
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var customizedHellos = []ClientHelloID{
|
||||
HelloCustom,
|
||||
HelloRandomized,
|
||||
HelloRandomizedALPN,
|
||||
HelloRandomizedNoALPN,
|
||||
}
|
||||
|
||||
func (s *sessionController) checkSessionExt() {
|
||||
uAssert(s.uconnRef.clientHelloBuildStatus == NotBuilt, "tls: checkSessionExt failed: we can't modify the session after the clientHello is built")
|
||||
numSessionExt := 0
|
||||
hasPskExt := false
|
||||
for i, e := range s.uconnRef.Extensions {
|
||||
switch ext := e.(type) {
|
||||
case *SessionTicketExtension:
|
||||
if ext != s.uconnRef.sessionController.sessionTicketExt {
|
||||
if anyTrue(customizedHellos, func(h *ClientHelloID) bool {
|
||||
return s.uconnRef.ClientHelloID.Client == h.Client
|
||||
}) {
|
||||
s.uconnRef.Extensions[i] = s.uconnRef.sessionController.sessionTicketExt
|
||||
} else {
|
||||
panic(fmt.Sprintf("tls: checkSessionExt failed: sessionTicketExtShortcut != SessionTicketExtension from the extension list and the clientHello is build from presets: [%v]", s.uconnRef.ClientHelloID))
|
||||
}
|
||||
}
|
||||
numSessionExt += 1
|
||||
case PreSharedKeyExtension:
|
||||
uAssert(i == len(s.uconnRef.Extensions)-1, "tls: checkSessionExt failed: PreSharedKeyExtension must be the last extension")
|
||||
if ext != s.uconnRef.sessionController.pskExtension {
|
||||
if anyTrue(customizedHellos, func(h *ClientHelloID) bool {
|
||||
return s.uconnRef.ClientHelloID.Client == h.Client
|
||||
}) {
|
||||
s.uconnRef.Extensions[i] = s.uconnRef.sessionController.pskExtension
|
||||
} else {
|
||||
panic(fmt.Sprintf("tls: checkSessionExt failed: pskExtensionShortcut != PreSharedKeyExtension from the extension list and the clientHello is build from presets: [%v]", s.uconnRef.ClientHelloID))
|
||||
}
|
||||
}
|
||||
hasPskExt = true
|
||||
}
|
||||
}
|
||||
if !(s.state == NoSession || s.state == TicketInitialized || s.state == PskExtInitialized) {
|
||||
panic(fmt.Sprintf("tls: checkSessionExt failed: can't remove session ticket extension; the session ticket extension is unused, but the internal state is: %d", s.state))
|
||||
}
|
||||
if numSessionExt == 0 {
|
||||
s.sessionTicketExt = nil
|
||||
s.uconnRef.HandshakeState.Session = nil
|
||||
s.uconnRef.HandshakeState.Hello.SessionTicket = nil
|
||||
} else if numSessionExt > 1 {
|
||||
panic("checkSessionExt failed: multiple session ticket extensions in the extension list")
|
||||
}
|
||||
if !hasPskExt {
|
||||
s.pskExtension = nil
|
||||
s.uconnRef.HandshakeState.State13.BinderKey = nil
|
||||
s.uconnRef.HandshakeState.State13.EarlySecret = nil
|
||||
s.uconnRef.HandshakeState.Session = nil
|
||||
s.uconnRef.HandshakeState.Hello.PskIdentities = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *sessionController) onEnterLoadSessionCheck() {
|
||||
uAssert(!s.locked, "tls: LoadSessionCoordinator.onEnterLoadSessionCheck failed: session is set and locked, no call to loadSession is allowed")
|
||||
switch s.loadSessionTracker {
|
||||
case UtlsAboutToCall, NeverCalled:
|
||||
s.callingLoadSession = true
|
||||
case CalledByULoadSession, CalledByGoTLS:
|
||||
panic("tls: LoadSessionCoordinator.onEnterLoadSessionCheck failed: you must not call loadSession() twice")
|
||||
default:
|
||||
panic("tls: LoadSessionCoordinator.onEnterLoadSessionCheck failed: unimplemented state")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *sessionController) onLoadSessionReturn() {
|
||||
uAssert(s.callingLoadSession, "tls: LoadSessionCoordinator.onLoadSessionReturn failed: it's not loading sessions, perhaps this function is not being called by loadSession.")
|
||||
switch s.loadSessionTracker {
|
||||
case NeverCalled:
|
||||
s.loadSessionTracker = CalledByGoTLS
|
||||
case UtlsAboutToCall:
|
||||
s.loadSessionTracker = CalledByULoadSession
|
||||
default:
|
||||
panic("tls: LoadSessionCoordinator.onLoadSessionReturn failed: unimplemented state")
|
||||
}
|
||||
s.callingLoadSession = false
|
||||
}
|
||||
|
||||
func (s *sessionController) shouldWriteBinders() bool {
|
||||
uAssert(s.callingLoadSession, "tls: shouldWriteBinders failed: LoadSessionCoordinator isn't loading sessions, perhaps this function is not being called by loadSession.")
|
||||
|
||||
switch s.loadSessionTracker {
|
||||
case NeverCalled:
|
||||
return true
|
||||
case UtlsAboutToCall:
|
||||
return false
|
||||
default:
|
||||
panic("tls: shouldWriteBinders failed: unimplemented state")
|
||||
}
|
||||
}
|
|
@ -802,22 +802,19 @@ func (e *SCTExtension) Write(_ []byte) (int, error) {
|
|||
|
||||
// SessionTicketExtension implements session_ticket (35)
|
||||
type SessionTicketExtension struct {
|
||||
Session *ClientSessionState
|
||||
Session *SessionState
|
||||
Ticket []byte
|
||||
}
|
||||
|
||||
func (e *SessionTicketExtension) writeToUConn(uc *UConn) error {
|
||||
if e.Session != nil {
|
||||
uc.HandshakeState.Session = e.Session.session
|
||||
uc.HandshakeState.Hello.SessionTicket = e.Session.ticket
|
||||
}
|
||||
// session states are handled later. At this point tickets aren't
|
||||
// being loaded by utls, so don't write anything to the UConn.
|
||||
uc.HandshakeState.Hello.TicketSupported = true // This doesn't really matter, this field is only used to add session ticket ext in go tls.
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *SessionTicketExtension) Len() int {
|
||||
if e.Session != nil {
|
||||
return 4 + len(e.Session.ticket)
|
||||
}
|
||||
return 4
|
||||
return 4 + len(e.Ticket)
|
||||
}
|
||||
|
||||
func (e *SessionTicketExtension) Read(b []byte) (int, error) {
|
||||
|
@ -832,7 +829,7 @@ func (e *SessionTicketExtension) Read(b []byte) (int, error) {
|
|||
b[2] = byte(extBodyLen >> 8)
|
||||
b[3] = byte(extBodyLen)
|
||||
if extBodyLen > 0 {
|
||||
copy(b[4:], e.Session.ticket)
|
||||
copy(b[4:], e.Ticket)
|
||||
}
|
||||
return e.Len(), io.EOF
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue