mirror of
https://github.com/refraction-networking/utls.git
synced 2025-04-03 20:17:36 +03:00
feat: improve example docs: add detailed explanation about the design feat: add assertion on uApplyPatch
This commit is contained in:
parent
a040a404e6
commit
e707a3bcbe
7 changed files with 334 additions and 190 deletions
|
@ -1,120 +0,0 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
tls "github.com/refraction-networking/utls"
|
||||
)
|
||||
|
||||
type ClientSessionCache struct {
|
||||
sessionKeyMap map[string]*tls.ClientSessionState
|
||||
}
|
||||
|
||||
func NewClientSessionCache() tls.ClientSessionCache {
|
||||
return &ClientSessionCache{
|
||||
sessionKeyMap: make(map[string]*tls.ClientSessionState),
|
||||
}
|
||||
}
|
||||
|
||||
func (csc *ClientSessionCache) Get(sessionKey string) (session *tls.ClientSessionState, ok bool) {
|
||||
if session, ok = csc.sessionKeyMap[sessionKey]; ok {
|
||||
fmt.Printf("Getting session for %s\n", sessionKey)
|
||||
return session, true
|
||||
}
|
||||
fmt.Printf("Missing session for %s\n", sessionKey)
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (csc *ClientSessionCache) Put(sessionKey string, cs *tls.ClientSessionState) {
|
||||
if csc.sessionKeyMap == nil {
|
||||
fmt.Printf("Deleting session for %s\n", sessionKey)
|
||||
delete(csc.sessionKeyMap, sessionKey)
|
||||
} else {
|
||||
fmt.Printf("Putting session for %s\n", sessionKey)
|
||||
csc.sessionKeyMap[sessionKey] = cs
|
||||
}
|
||||
}
|
||||
|
||||
func runPskCheck(helloID tls.ClientHelloID) {
|
||||
const serverAddr string = "refraction.network:443"
|
||||
csc := NewClientSessionCache()
|
||||
tcpConn, err := net.Dial("tcp", serverAddr)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Everything below this line is brought to you by uTLS API, enjoy!
|
||||
|
||||
// use chs
|
||||
tlsConn := tls.UClient(tcpConn, &tls.Config{
|
||||
ServerName: strings.Split(serverAddr, ":")[0],
|
||||
// NextProtos: []string{"h2", "http/1.1"},
|
||||
ClientSessionCache: csc, // set this so session tickets will be saved
|
||||
}, helloID)
|
||||
|
||||
// HS
|
||||
err = tlsConn.Handshake()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if tlsConn.ConnectionState().HandshakeComplete {
|
||||
fmt.Println("Handshake complete")
|
||||
fmt.Printf("TLS Version: %04x\n", tlsConn.ConnectionState().Version)
|
||||
if tlsConn.ConnectionState().Version != tls.VersionTLS13 {
|
||||
fmt.Printf("Only TLS 1.3 suppports PSK\n")
|
||||
return
|
||||
}
|
||||
|
||||
if tlsConn.HandshakeState.State13.UsingPSK {
|
||||
panic("unintended using of PSK happened...")
|
||||
} else {
|
||||
fmt.Println("First connection, no PSK to use.")
|
||||
}
|
||||
|
||||
tlsConn.SetReadDeadline(time.Now().Add(1 * time.Second))
|
||||
tlsConn.Read(make([]byte, 1024)) // trigger a read so NewSessionTicket gets handled
|
||||
}
|
||||
tlsConn.Close()
|
||||
|
||||
tcpConnPSK, err := net.Dial("tcp", serverAddr)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
tlsConnPSK := tls.UClient(tcpConnPSK, &tls.Config{
|
||||
ServerName: strings.Split(serverAddr, ":")[0],
|
||||
ClientSessionCache: csc,
|
||||
}, helloID)
|
||||
|
||||
// HS
|
||||
err = tlsConnPSK.Handshake()
|
||||
fmt.Println(tlsConnPSK.HandshakeState.Hello.Raw)
|
||||
fmt.Println(tlsConnPSK.HandshakeState.Hello.PskIdentities)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if tlsConnPSK.ConnectionState().HandshakeComplete {
|
||||
fmt.Println("Handshake complete")
|
||||
fmt.Printf("TLS Version: %04x\n", tlsConnPSK.ConnectionState().Version)
|
||||
if tlsConnPSK.ConnectionState().Version != tls.VersionTLS13 {
|
||||
fmt.Printf("Only TLS 1.3 suppports PSK\n")
|
||||
return
|
||||
}
|
||||
|
||||
if tlsConnPSK.HandshakeState.State13.UsingPSK {
|
||||
fmt.Println("PSK used!")
|
||||
} else {
|
||||
panic("PSK not used for a resumption session!")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func main() {
|
||||
runPskCheck(tls.HelloChrome_100_PSK)
|
||||
runPskCheck(tls.HelloGolang)
|
||||
}
|
157
examples/tls-resumption/main.go
Normal file
157
examples/tls-resumption/main.go
Normal file
|
@ -0,0 +1,157 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
tls "github.com/refraction-networking/utls"
|
||||
)
|
||||
|
||||
type ClientSessionCache struct {
|
||||
sessionKeyMap map[string]*tls.ClientSessionState
|
||||
}
|
||||
|
||||
func NewClientSessionCache() tls.ClientSessionCache {
|
||||
return &ClientSessionCache{
|
||||
sessionKeyMap: make(map[string]*tls.ClientSessionState),
|
||||
}
|
||||
}
|
||||
|
||||
func (csc *ClientSessionCache) Get(sessionKey string) (session *tls.ClientSessionState, ok bool) {
|
||||
if session, ok = csc.sessionKeyMap[sessionKey]; ok {
|
||||
fmt.Printf("Getting session for %s\n", sessionKey)
|
||||
return session, true
|
||||
}
|
||||
fmt.Printf("Missing session for %s\n", sessionKey)
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (csc *ClientSessionCache) Put(sessionKey string, cs *tls.ClientSessionState) {
|
||||
if csc.sessionKeyMap == nil {
|
||||
fmt.Printf("Deleting session for %s\n", sessionKey)
|
||||
delete(csc.sessionKeyMap, sessionKey)
|
||||
} else {
|
||||
fmt.Printf("Putting session for %s\n", sessionKey)
|
||||
csc.sessionKeyMap[sessionKey] = cs
|
||||
}
|
||||
}
|
||||
|
||||
func runResumptionCheck(helloID tls.ClientHelloID, serverAddr string, retry int, verbose bool) {
|
||||
csc := NewClientSessionCache()
|
||||
tcpConn, err := net.Dial("tcp", serverAddr)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Everything below this line is brought to you by uTLS API, enjoy!
|
||||
|
||||
// use chs
|
||||
tlsConn := tls.UClient(tcpConn, &tls.Config{
|
||||
ServerName: strings.Split(serverAddr, ":")[0],
|
||||
// NextProtos: []string{"h2", "http/1.1"},
|
||||
ClientSessionCache: csc, // set this so session tickets will be saved
|
||||
}, helloID)
|
||||
|
||||
// HS
|
||||
err = tlsConn.Handshake()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
var tlsVer uint16
|
||||
|
||||
if tlsConn.ConnectionState().HandshakeComplete {
|
||||
tlsVer = tlsConn.ConnectionState().Version
|
||||
if verbose {
|
||||
fmt.Println("Handshake complete")
|
||||
fmt.Printf("TLS Version: %04x\n", tlsVer)
|
||||
}
|
||||
if tlsVer == tls.VersionTLS13 {
|
||||
if verbose {
|
||||
fmt.Printf("Expecting PSK resumption\n")
|
||||
}
|
||||
} else if tlsVer == tls.VersionTLS12 {
|
||||
if verbose {
|
||||
fmt.Printf("Expecting session ticket resumption\n")
|
||||
}
|
||||
} else {
|
||||
panic("Don't try resumption on old TLS versions")
|
||||
}
|
||||
|
||||
if tlsConn.HandshakeState.State13.UsingPSK {
|
||||
panic("unintended using of PSK happened...")
|
||||
} else if tlsConn.DidTls12Resume() {
|
||||
panic("unintended using of session ticket happened...")
|
||||
} else {
|
||||
if verbose {
|
||||
fmt.Println("First connection, no PSK/session ticket to use.")
|
||||
}
|
||||
}
|
||||
|
||||
tlsConn.SetReadDeadline(time.Now().Add(1 * time.Second))
|
||||
tlsConn.Read(make([]byte, 1024)) // trigger a read so NewSessionTicket gets handled
|
||||
}
|
||||
tlsConn.Close()
|
||||
|
||||
for i := 0; i < retry; i++ {
|
||||
tcpConnPSK, err := net.Dial("tcp", serverAddr)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
tlsConnPSK := tls.UClient(tcpConnPSK, &tls.Config{
|
||||
ServerName: strings.Split(serverAddr, ":")[0],
|
||||
ClientSessionCache: csc,
|
||||
}, helloID)
|
||||
|
||||
// HS
|
||||
err = tlsConnPSK.Handshake()
|
||||
if verbose {
|
||||
fmt.Printf("tlsConnPSK.HandshakeState.Hello.Raw %v\n", tlsConnPSK.HandshakeState.Hello.Raw)
|
||||
fmt.Printf("tlsConnPSK.HandshakeState.Hello.PskIdentities: %v\n", tlsConnPSK.HandshakeState.Hello.PskIdentities)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if tlsConnPSK.ConnectionState().HandshakeComplete {
|
||||
if verbose {
|
||||
fmt.Println("Handshake complete")
|
||||
}
|
||||
newVer := tlsConnPSK.ConnectionState().Version
|
||||
if verbose {
|
||||
fmt.Printf("TLS Version: %04x\n", newVer)
|
||||
}
|
||||
if newVer != tlsVer {
|
||||
panic("Tls version changed unexpectedly on the second connection")
|
||||
}
|
||||
|
||||
if tlsVer == tls.VersionTLS13 && tlsConnPSK.HandshakeState.State13.UsingPSK {
|
||||
fmt.Println("[PSK used]")
|
||||
return
|
||||
} else if tlsVer == tls.VersionTLS12 && tlsConnPSK.DidTls12Resume() {
|
||||
fmt.Println("[session ticket used]")
|
||||
return
|
||||
}
|
||||
}
|
||||
time.Sleep(700 * time.Millisecond)
|
||||
}
|
||||
panic(fmt.Sprintf("PSK or session ticket not used for a resumption session, server %s, helloID: %s", serverAddr, helloID.Client))
|
||||
}
|
||||
|
||||
func main() {
|
||||
tls13Url := "www.microsoft.com:443"
|
||||
tls12Url1 := "spocs.getpocket.com:443"
|
||||
tls12Url2 := "marketplace.visualstudio.com:443"
|
||||
runResumptionCheck(tls.HelloChrome_100_PSK, tls13Url, 1, false) // psk + utls
|
||||
runResumptionCheck(tls.HelloGolang, tls13Url, 1, false) // psk + crypto/tls
|
||||
|
||||
runResumptionCheck(tls.HelloChrome_100_PSK, tls12Url1, 10, false) // session ticket + utls
|
||||
runResumptionCheck(tls.HelloGolang, tls12Url1, 10, false) // session ticket + crypto/tls
|
||||
runResumptionCheck(tls.HelloChrome_100_PSK, tls12Url2, 10, false) // session ticket + utls
|
||||
runResumptionCheck(tls.HelloGolang, tls12Url2, 10, false) // session ticket + crypto/tls
|
||||
|
||||
}
|
|
@ -457,7 +457,7 @@ func (c *Conn) loadSession(hello *clientHelloMsg) (
|
|||
earlySecret = cipherSuite.extract(session.secret, nil)
|
||||
binderKey = cipherSuite.deriveSecret(earlySecret, resumptionBinderLabel, nil)
|
||||
// [UTLS SECTION START]
|
||||
if c.utls.sessionController != nil && !c.utls.sessionController.shouldWriteBinders() {
|
||||
if c.utls.sessionController != nil && !c.utls.sessionController.shouldLoadSessionWriteBinders() {
|
||||
return
|
||||
}
|
||||
// [UTLS SECTION END]
|
||||
|
|
23
u_common.go
23
u_common.go
|
@ -727,6 +727,14 @@ func EnableWeakCiphers() {
|
|||
}...)
|
||||
}
|
||||
|
||||
func mapSlice[T any, U any](slice []T, transform func(T) U) []U {
|
||||
newSlice := make([]U, 0, len(slice))
|
||||
for _, t := range slice {
|
||||
newSlice = append(newSlice, transform(t))
|
||||
}
|
||||
return newSlice
|
||||
}
|
||||
|
||||
func panicOnNil(failureMsg string, params ...any) {
|
||||
for i, p := range params {
|
||||
if p == nil {
|
||||
|
@ -735,22 +743,31 @@ func panicOnNil(failureMsg string, params ...any) {
|
|||
}
|
||||
}
|
||||
|
||||
func anyTrue[T any](slice []T, predicate func(t *T) bool) bool {
|
||||
func anyTrue[T any](slice []T, predicate func(i int, t *T) bool) bool {
|
||||
for i := 0; i < len(slice); i++ {
|
||||
if predicate(&slice[i]) {
|
||||
if predicate(i, &slice[i]) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func allTrue[T any](slice []T, predicate func(i int, t *T) bool) bool {
|
||||
for i := 0; i < len(slice); i++ {
|
||||
if !predicate(i, &slice[i]) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func uAssert(condition bool, msg string) {
|
||||
if !condition {
|
||||
panic(msg)
|
||||
}
|
||||
}
|
||||
|
||||
func sliceEq(sliceA []any, sliceB []any) bool {
|
||||
func sliceEq[T comparable](sliceA []T, sliceB []T) bool {
|
||||
if len(sliceA) != len(sliceB) {
|
||||
return false
|
||||
}
|
||||
|
|
12
u_conn.go
12
u_conn.go
|
@ -143,10 +143,10 @@ func (uconn *UConn) uLoadSession() error {
|
|||
case shouldSetTicket:
|
||||
uconn.sessionController.setSessionTicketToUConn()
|
||||
case shouldSetPsk:
|
||||
uconn.sessionController.setPsk()
|
||||
uconn.sessionController.setPskToHandshake()
|
||||
case shouldLoad:
|
||||
hello := uconn.HandshakeState.Hello.getPrivatePtr()
|
||||
uconn.sessionController.aboutToLoadSession()
|
||||
uconn.sessionController.utlsAboutToLoadSession()
|
||||
session, earlySecret, binderKey, err := uconn.loadSession(hello)
|
||||
if session == nil || err != nil {
|
||||
return err
|
||||
|
@ -164,10 +164,16 @@ func (uconn *UConn) uLoadSession() error {
|
|||
}
|
||||
|
||||
func (uconn *UConn) uApplyPatch() {
|
||||
helloLen := len(uconn.HandshakeState.Hello.Raw)
|
||||
if uconn.sessionController.shouldUpdateBinders() {
|
||||
uconn.sessionController.updateBinders()
|
||||
uconn.sessionController.setPsk()
|
||||
uconn.sessionController.setPskToHandshake()
|
||||
}
|
||||
uAssert(helloLen == len(uconn.HandshakeState.Hello.Raw), "tls: uApplyPatch Failed: the patch should never change the length of the marshaled clientHello")
|
||||
}
|
||||
|
||||
func (uconn *UConn) DidTls12Resume() bool {
|
||||
return uconn.didResume
|
||||
}
|
||||
|
||||
// SetSessionState12 sets the session ticket, which may be preshared or fake.
|
||||
|
|
|
@ -16,29 +16,73 @@ type PreSharedKeyCommon struct {
|
|||
Session *SessionState
|
||||
}
|
||||
|
||||
// The lifecycle of a PreSharedKeyExtension:
|
||||
//
|
||||
// Creation Phase:
|
||||
// - The extension is created.
|
||||
//
|
||||
// Write Phase:
|
||||
//
|
||||
// - [writeToUConn() called]:
|
||||
//
|
||||
// > - During this phase, it is important to note that implementations should not write any session data to the UConn (Underlying Connection) as the session is not yet loaded. The session context is not active at this point.
|
||||
//
|
||||
// Initialization Phase:
|
||||
//
|
||||
// - [IsInitialized() called]:
|
||||
//
|
||||
// If IsInitialized() returns true
|
||||
//
|
||||
// > - GetPreSharedKeyCommon() will be called subsequently and the PSK states in handshake/clientHello will be fully initialized.
|
||||
//
|
||||
// If IsInitialized() returns false:
|
||||
//
|
||||
// > - [conn.loadSession() called]:
|
||||
//
|
||||
// >> - Once the session is available:
|
||||
//
|
||||
// >>> - [InitializeByUtls() called]:
|
||||
//
|
||||
// >>>> - The InitializeByUtls() method is invoked to initialize the extension based on the loaded session data.
|
||||
//
|
||||
// >>>> - This step prepares the extension for further processing.
|
||||
//
|
||||
// Marshal Phase:
|
||||
//
|
||||
// - [Len() called], [Read() called]:
|
||||
//
|
||||
// > - Implementations should marshal the extension into bytes, using placeholder binders to maintain the correct length.
|
||||
//
|
||||
// Binders Preparation Phase:
|
||||
//
|
||||
// - [PatchBuiltHello(hello) called]:
|
||||
//
|
||||
// > - The client hello is already marshaled in the "hello.Raw" format.
|
||||
//
|
||||
// > - Implementations are expected to update the binders within the marshaled client hello.
|
||||
//
|
||||
// - [GetPreSharedKeyCommon() called]:
|
||||
//
|
||||
// > - Implementations should gather and provide the final pre-shared key (PSK) related data.
|
||||
//
|
||||
// > - This data will be incorporated into both the clientHello and HandshakeState, ensuring that the PSK-related information is properly set and ready for the handshake process.
|
||||
type PreSharedKeyExtension interface {
|
||||
// TLSExtension must be implemented by all PreSharedKeyExtension implementations.
|
||||
// However, the Read() method should return an error since it MUST NOT be used
|
||||
// for PreSharedKeyExtension.
|
||||
TLSExtension
|
||||
|
||||
// IsInitialized returns a boolean indicating whether the extension has been initialized.
|
||||
// If false is returned, utls will invoke `InitializeByUtls()` for the necessary initialization.
|
||||
IsInitialized() bool
|
||||
|
||||
// InitializeByUtls is invoked when IsInitialized() returns false.
|
||||
// It initializes the extension using a real and valid TLS 1.3 session.
|
||||
InitializeByUtls(session *SessionState, earlySecret []byte, binderKey []byte, identities []PskIdentity)
|
||||
|
||||
// GetBinders returns the binders that were computed during the handshake
|
||||
// to be set in the internal copy of the ClientHello. Only needed if expecting
|
||||
// to resume the session.
|
||||
//
|
||||
// FakePreSharedKeyExtension MUST return nil to make sure utls DOES NOT
|
||||
// try to do any session resumption.
|
||||
// GetPreSharedKeyCommon retrieves the final PreSharedKey-related states as defined in PreSharedKeyCommon.
|
||||
GetPreSharedKeyCommon() PreSharedKeyCommon
|
||||
|
||||
// ReadWithRawHello is used to read the extension from the ClientHello
|
||||
// instead of Read(), where the latter is used to read all other extensions.
|
||||
//
|
||||
// This is needed because the PSK extension needs to calculate the binder
|
||||
// based on all previous parts of the ClientHello.
|
||||
// PatchBuiltHello is called once the hello message is fully applied and marshaled.
|
||||
// Its purpose is to update the binders of PSK (Pre-Shared Key) identities.
|
||||
PatchBuiltHello(hello *PubClientHelloMsg) error
|
||||
|
||||
mustEmbedUnimplementedPreSharedKeyExtension() // this works like a type guard
|
||||
|
@ -112,6 +156,7 @@ func (e *UtlsPreSharedKeyExtension) GetPreSharedKeyCommon() PreSharedKeyCommon {
|
|||
|
||||
func pskExtLen(identities []PskIdentity, binders [][]byte) int {
|
||||
if len(identities) == 0 || len(binders) == 0 {
|
||||
// If there isn't psk identities, we don't write this ticket to the client hello, and therefore the length should be 0.
|
||||
return 0
|
||||
}
|
||||
length := 4 // extension type + extension length
|
||||
|
@ -243,9 +288,7 @@ func (e *UtlsPreSharedKeyExtension) UnmarshalJSON(_ []byte) error {
|
|||
// ClientHello.
|
||||
//
|
||||
// It does not compute binders based on ClientHello, but uses the binders specified instead.
|
||||
//
|
||||
// TODO: Only one of FakePreSharedKeyExtension and FakePreSharedKeyExtension should
|
||||
// be kept, the other one should be just removed. We still need to learn more of the safety
|
||||
// We still need to learn more of the safety
|
||||
// of hardcoding both Identities and Binders without recalculating the latter.
|
||||
type FakePreSharedKeyExtension struct {
|
||||
UnimplementedPreSharedKeyExtension
|
||||
|
@ -280,7 +323,7 @@ func (e *FakePreSharedKeyExtension) Len() int {
|
|||
|
||||
func (e *FakePreSharedKeyExtension) Read(b []byte) (int, error) {
|
||||
for _, b := range e.Binders {
|
||||
if !(anyTrue(validHashLen, func(valid *int) bool {
|
||||
if !(anyTrue(validHashLen, func(_ int, valid *int) bool {
|
||||
return len(b) == *valid
|
||||
})) {
|
||||
return 0, errors.New("tls: FakePreSharedKeyExtension.Read failed: invalid binder size")
|
||||
|
|
|
@ -2,39 +2,52 @@ package tls
|
|||
|
||||
import "fmt"
|
||||
|
||||
// Tracking the state of calling conn.loadSession
|
||||
type LoadSessionTrackerState int
|
||||
|
||||
const NeverCalled LoadSessionTrackerState = 0
|
||||
const UtlsAboutToCall LoadSessionTrackerState = 3
|
||||
const CalledByULoadSession LoadSessionTrackerState = 1
|
||||
const CalledByGoTLS LoadSessionTrackerState = 2
|
||||
const UtlsAboutToCall LoadSessionTrackerState = 1
|
||||
const CalledByULoadSession LoadSessionTrackerState = 2
|
||||
const CalledByGoTLS LoadSessionTrackerState = 3
|
||||
|
||||
// The state of the session controller
|
||||
type sessionState int
|
||||
|
||||
const NoSession sessionState = 0
|
||||
const TicketInitialized sessionState = 1
|
||||
const TicketAllSet sessionState = 4
|
||||
const PskExtInitialized sessionState = 2
|
||||
const PskAllSet sessionState = 3
|
||||
const TicketAllSet sessionState = 2
|
||||
const PskExtInitialized sessionState = 3
|
||||
const PskAllSet sessionState = 4
|
||||
|
||||
// sessionController is responsible for all session related
|
||||
// sessionController is responsible for managing and controlling all session related states. It manages the lifecycle of the session ticket extension and the psk extension, including initialization, removal if the client hello spec doesn't contain any of them, and setting the prepared state to the client hello.
|
||||
//
|
||||
// Users should never directly modify the underlying state. Violations will result in undefined behaviors.
|
||||
//
|
||||
// Users should never construct sessionController by themselves, use the function `newSessionController` instead.
|
||||
type sessionController struct {
|
||||
sessionTicketExt *SessionTicketExtension
|
||||
pskExtension PreSharedKeyExtension
|
||||
uconnRef *UConn
|
||||
state sessionState
|
||||
// sessionTicketExt logically owns the session ticket extension
|
||||
sessionTicketExt *SessionTicketExtension
|
||||
|
||||
// pskExtension logically owns the psk extension
|
||||
pskExtension PreSharedKeyExtension
|
||||
|
||||
// uconnRef is a reference to the uconn
|
||||
uconnRef *UConn
|
||||
|
||||
// state represents the internal state of the sessionController. Users are advised to modify the state only through designated methods and avoid direct manipulation, as doing so may result in undefined behavior.
|
||||
state sessionState
|
||||
|
||||
// loadSessionTracker keeps track of how the conn.loadSession method is being utilized.
|
||||
loadSessionTracker LoadSessionTrackerState
|
||||
|
||||
// callingLoadSession is a boolean flag that indicates whether the `conn.loadSession` function is currently being invoked.
|
||||
callingLoadSession bool
|
||||
locked bool
|
||||
|
||||
// locked is a boolean flag that becomes true once all states are appropriately set. Once `locked` is true, further modifications are disallowed, except for the binders.
|
||||
locked bool
|
||||
}
|
||||
|
||||
type shouldLoadSessionResult int
|
||||
|
||||
const shouldReturn shouldLoadSessionResult = 0
|
||||
const shouldSetTicket shouldLoadSessionResult = 1
|
||||
const shouldSetPsk shouldLoadSessionResult = 2
|
||||
const shouldLoad shouldLoadSessionResult = 3
|
||||
|
||||
// newSessionController constructs a new SessionController
|
||||
func newSessionController(uconn *UConn) *sessionController {
|
||||
return &sessionController{
|
||||
uconnRef: uconn,
|
||||
|
@ -51,10 +64,22 @@ func (s *sessionController) isSessionLocked() bool {
|
|||
return s.locked
|
||||
}
|
||||
|
||||
type shouldLoadSessionResult int
|
||||
|
||||
const shouldReturn shouldLoadSessionResult = 0
|
||||
const shouldSetTicket shouldLoadSessionResult = 1
|
||||
const shouldSetPsk shouldLoadSessionResult = 2
|
||||
const shouldLoad shouldLoadSessionResult = 3
|
||||
|
||||
// shouldLoadSession determines the appropriate action to take when it is time to load the session for the clientHello.
|
||||
// There are several possible scenarios:
|
||||
// - If a session ticket is already initialized, typically via the `initSessionTicketExt()` function, the ticket should be set in the client hello.
|
||||
// - If a pre-shared key (PSK) is already initialized, typically via the `overridePskExt()` function, the PSK should be set in the client hello.
|
||||
// - If both the `sessionTicketExt` and `pskExtension` are nil, which might occur if the client hello spec does not include them, we should skip the loadSession().
|
||||
// - In all other cases, the function proceeds to load the session.
|
||||
func (s *sessionController) shouldLoadSession() shouldLoadSessionResult {
|
||||
if s.sessionTicketExt == nil && s.pskExtension == nil || s.uconnRef.clientHelloBuildStatus != NotBuilt {
|
||||
fmt.Println("uLoadSession s.sessionTicketExt == nil && s.pskExtension == nil")
|
||||
// There's no need to load session since we don't have the related extensions.
|
||||
// No need to load session since we don't have the related extensions.
|
||||
return shouldReturn
|
||||
}
|
||||
if s.state == TicketInitialized {
|
||||
|
@ -66,11 +91,15 @@ func (s *sessionController) shouldLoadSession() shouldLoadSessionResult {
|
|||
return shouldLoad
|
||||
}
|
||||
|
||||
func (s *sessionController) aboutToLoadSession() {
|
||||
// utlsAboutToLoadSession updates the loadSessionTracker to `UtlsAboutToCall` to signal the initiation of a session loading operation,
|
||||
// provided that the preconditions are met. If the preconditions are not met (due to incorrect utls implementation), this function triggers a panic.
|
||||
func (s *sessionController) utlsAboutToLoadSession() {
|
||||
uAssert(s.state == NoSession && !s.locked, "tls: aboutToLoadSession failed: must only load session when the session of the client hello is not locked and when there's currently no session")
|
||||
s.loadSessionTracker = UtlsAboutToCall
|
||||
}
|
||||
|
||||
// commonCheck performs various common precondition checks, including validating the `clientHelloBuildStatus`,
|
||||
// checking the internal state, and verifying the provided parameters.
|
||||
func (s *sessionController) commonCheck(failureMsg string, params ...any) {
|
||||
if s.uconnRef.clientHelloBuildStatus != NotBuilt {
|
||||
panic(failureMsg + ": we can't modify the session after the clientHello is built")
|
||||
|
@ -81,11 +110,16 @@ func (s *sessionController) commonCheck(failureMsg string, params ...any) {
|
|||
panicOnNil(failureMsg, params...)
|
||||
}
|
||||
|
||||
// finalCheck performs a comprehensive check on the updated state to ensure the correctness of the changes.
|
||||
// If the checks pass successfully, the sessionController's state will be locked.
|
||||
// Any failure in passing the tests indicates incorrect implementations in the utls, which will result in triggering a panic.
|
||||
// Refer to the documentation for the `locked` field for more detailed information.
|
||||
func (s *sessionController) finalCheck() {
|
||||
uAssert(s.state == PskAllSet || s.state == TicketAllSet || s.state == NoSession, "tls: SessionController.finalCheck failed: the session is half set")
|
||||
s.locked = true
|
||||
}
|
||||
|
||||
// initSessionTicketExt initializes the ticket and sets the state to `TicketInitialized`.
|
||||
func (s *sessionController) initSessionTicketExt(session *SessionState, ticket []byte) {
|
||||
s.commonCheck("tls: initSessionTicket failed", s.sessionTicketExt, session, ticket)
|
||||
s.sessionTicketExt.Session = session
|
||||
|
@ -93,6 +127,7 @@ func (s *sessionController) initSessionTicketExt(session *SessionState, ticket [
|
|||
s.state = TicketInitialized
|
||||
}
|
||||
|
||||
// setSessionTicketToUConn write the ticket states from the session ticket extension to the client hello and handshake state.
|
||||
func (s *sessionController) setSessionTicketToUConn() {
|
||||
uAssert(s.sessionTicketExt != nil && s.state == TicketInitialized, "tls: setSessionTicketExt failed: invalid state")
|
||||
s.uconnRef.HandshakeState.Session = s.sessionTicketExt.Session
|
||||
|
@ -100,14 +135,9 @@ func (s *sessionController) setSessionTicketToUConn() {
|
|||
s.state = TicketAllSet
|
||||
}
|
||||
|
||||
func mapSlice[T any, U any](slice []T, transform func(T) U) []U {
|
||||
newSlice := make([]U, 0, len(slice))
|
||||
for _, t := range slice {
|
||||
newSlice = append(newSlice, transform(t))
|
||||
}
|
||||
return newSlice
|
||||
}
|
||||
|
||||
// initPSK initializes the PSK extension using a valid session. The PSK extension
|
||||
// should not be initialized previously, and the parameters must not be nil;
|
||||
// otherwise, this function will trigger a panic.
|
||||
func (s *sessionController) initPsk(session *SessionState, earlySecret []byte, binderKey []byte, pskIdentities []pskIdentity) {
|
||||
s.commonCheck("tls: initPsk failed", s.pskExtension, session, earlySecret, pskIdentities)
|
||||
uAssert(!s.pskExtension.IsInitialized(), "tls: initPsk failed: the psk extension is already initialized")
|
||||
|
@ -128,8 +158,9 @@ func (s *sessionController) initPsk(session *SessionState, earlySecret []byte, b
|
|||
s.state = PskExtInitialized
|
||||
}
|
||||
|
||||
func (s *sessionController) setPsk() {
|
||||
uAssert(s.pskExtension != nil && (s.state == PskExtInitialized || s.state == PskAllSet), "tls: setPsk failed: invalid state")
|
||||
// setPskToHandshake sets the psk to the handshake state and client hello.
|
||||
func (s *sessionController) setPskToHandshake() {
|
||||
uAssert(s.pskExtension != nil && (s.state == PskExtInitialized || s.state == PskAllSet), "tls: setPskToHandshake failed: invalid state")
|
||||
pskCommon := s.pskExtension.GetPreSharedKeyCommon()
|
||||
if s.state == PskExtInitialized {
|
||||
s.uconnRef.HandshakeState.State13.EarlySecret = pskCommon.EarlySecret
|
||||
|
@ -137,22 +168,19 @@ func (s *sessionController) setPsk() {
|
|||
s.uconnRef.HandshakeState.Hello.PskIdentities = pskCommon.Identities
|
||||
s.uconnRef.HandshakeState.Hello.PskBinders = pskCommon.Binders
|
||||
} else if s.state == PskAllSet {
|
||||
uAssert(sliceEq([]any{
|
||||
s.uconnRef.HandshakeState.State13.EarlySecret,
|
||||
s.uconnRef.HandshakeState.Session,
|
||||
s.uconnRef.HandshakeState.Hello.PskIdentities,
|
||||
s.uconnRef.HandshakeState.Hello.PskBinders,
|
||||
}, []any{
|
||||
pskCommon.EarlySecret,
|
||||
pskCommon.Session,
|
||||
pskCommon.Identities,
|
||||
pskCommon.Binders,
|
||||
}), "setPsk failed: only binders are allowed to change on state `PskAllSet`")
|
||||
uAssert(s.uconnRef.HandshakeState.Session == pskCommon.Session && sliceEq(s.uconnRef.HandshakeState.State13.EarlySecret, pskCommon.EarlySecret) &&
|
||||
allTrue(s.uconnRef.HandshakeState.Hello.PskIdentities, func(i int, psk *PskIdentity) bool {
|
||||
return pskCommon.Identities[i].ObfuscatedTicketAge == psk.ObfuscatedTicketAge && sliceEq(pskCommon.Identities[i].Label, psk.Label)
|
||||
}), "tls: setPskToHandshake failed: only binders are allowed to change on state `PskAllSet`")
|
||||
}
|
||||
s.uconnRef.HandshakeState.State13.BinderKey = pskCommon.BinderKey
|
||||
s.state = PskAllSet
|
||||
}
|
||||
|
||||
// shouldUpdateBinders determines whether binders should be updated based on the presence of an initialized psk extension.
|
||||
// This function returns true if an initialized psk extension exists. Binders are allowed to be updated when the state is `PskAllSet`,
|
||||
// as the `BuildHandshakeState` function can be called multiple times in this case. However, it's important to note that
|
||||
// the session state, apart from binders, should not be altered more than once.
|
||||
func (s *sessionController) shouldUpdateBinders() bool {
|
||||
if s.pskExtension == nil {
|
||||
return false
|
||||
|
@ -165,6 +193,7 @@ func (s *sessionController) updateBinders() {
|
|||
s.pskExtension.PatchBuiltHello(s.uconnRef.HandshakeState.Hello)
|
||||
}
|
||||
|
||||
// overridePskExt allows the user of utls to customize the psk extension.
|
||||
func (s *sessionController) overridePskExt(psk PreSharedKeyExtension) error {
|
||||
if s.state != NoSession {
|
||||
return fmt.Errorf("SetSessionState13 failed: there's already a session")
|
||||
|
@ -183,6 +212,11 @@ var customizedHellos = []ClientHelloID{
|
|||
HelloRandomizedNoALPN,
|
||||
}
|
||||
|
||||
// CheckSessionExt is designed to be called after applying client hello specs. It performs the following checks and fixups:
|
||||
// - If the session ticket extension or PSK extension is missing from the extension list, owned extensions are dropped and states are reset.
|
||||
// - Ensures that the session ticket extension or PSK extension matches the owned one.
|
||||
// - Ensures that there is only one session ticket extension or PSK extension.
|
||||
// - Ensures that the PSK extension is the last extension in the extension list.
|
||||
func (s *sessionController) checkSessionExt() {
|
||||
uAssert(s.uconnRef.clientHelloBuildStatus == NotBuilt, "tls: checkSessionExt failed: we can't modify the session after the clientHello is built")
|
||||
numSessionExt := 0
|
||||
|
@ -191,7 +225,7 @@ func (s *sessionController) checkSessionExt() {
|
|||
switch ext := e.(type) {
|
||||
case *SessionTicketExtension:
|
||||
if ext != s.uconnRef.sessionController.sessionTicketExt {
|
||||
if anyTrue(customizedHellos, func(h *ClientHelloID) bool {
|
||||
if anyTrue(customizedHellos, func(_ int, h *ClientHelloID) bool {
|
||||
return s.uconnRef.ClientHelloID.Client == h.Client
|
||||
}) {
|
||||
s.uconnRef.Extensions[i] = s.uconnRef.sessionController.sessionTicketExt
|
||||
|
@ -203,7 +237,7 @@ func (s *sessionController) checkSessionExt() {
|
|||
case PreSharedKeyExtension:
|
||||
uAssert(i == len(s.uconnRef.Extensions)-1, "tls: checkSessionExt failed: PreSharedKeyExtension must be the last extension")
|
||||
if ext != s.uconnRef.sessionController.pskExtension {
|
||||
if anyTrue(customizedHellos, func(h *ClientHelloID) bool {
|
||||
if anyTrue(customizedHellos, func(_ int, h *ClientHelloID) bool {
|
||||
return s.uconnRef.ClientHelloID.Client == h.Client
|
||||
}) {
|
||||
s.uconnRef.Extensions[i] = s.uconnRef.sessionController.pskExtension
|
||||
|
@ -233,6 +267,8 @@ func (s *sessionController) checkSessionExt() {
|
|||
}
|
||||
}
|
||||
|
||||
// onEnterLoadSessionCheck is intended to be invoked upon entering the `conn.loadSession` function.
|
||||
// It is designed to ensure the correctness of the utls implementation. If the utls implementation is found to be incorrect, this function will trigger a panic.
|
||||
func (s *sessionController) onEnterLoadSessionCheck() {
|
||||
uAssert(!s.locked, "tls: LoadSessionCoordinator.onEnterLoadSessionCheck failed: session is set and locked, no call to loadSession is allowed")
|
||||
switch s.loadSessionTracker {
|
||||
|
@ -245,6 +281,9 @@ func (s *sessionController) onEnterLoadSessionCheck() {
|
|||
}
|
||||
}
|
||||
|
||||
// onLoadSessionReturn is intended to be invoked upon returning from the `conn.loadSession` function.
|
||||
// It serves as a validation step for the correctness of the underlying utls implementation.
|
||||
// If the utls implementation is incorrect, this function will trigger a panic.
|
||||
func (s *sessionController) onLoadSessionReturn() {
|
||||
uAssert(s.callingLoadSession, "tls: LoadSessionCoordinator.onLoadSessionReturn failed: it's not loading sessions, perhaps this function is not being called by loadSession.")
|
||||
switch s.loadSessionTracker {
|
||||
|
@ -258,7 +297,9 @@ func (s *sessionController) onLoadSessionReturn() {
|
|||
s.callingLoadSession = false
|
||||
}
|
||||
|
||||
func (s *sessionController) shouldWriteBinders() bool {
|
||||
// shouldLoadSessionWriteBinders checks if `conn.loadSession` should proceed to write binders and marshal the client hello. If the utls implementation
|
||||
// is incorrect, this function will trigger a panic.
|
||||
func (s *sessionController) shouldLoadSessionWriteBinders() bool {
|
||||
uAssert(s.callingLoadSession, "tls: shouldWriteBinders failed: LoadSessionCoordinator isn't loading sessions, perhaps this function is not being called by loadSession.")
|
||||
|
||||
switch s.loadSessionTracker {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue