fix: PSK failing if config session cache set

* Fix a bug causing PSK to fail if Config.ClientSessionCache is set.
* Removed `ClientSessionCacheOverride` from `UtlsPreSharedKeyExtension`. Set the `ClientSessionCache` in `Config`!

Co-Authored-By: zeeker999 <13848632+zeeker999@users.noreply.github.com>
This commit is contained in:
Gaukas Wang 2023-08-16 19:09:01 -06:00
parent 3d7eea3346
commit 3162534cc7
No known key found for this signature in database
GPG key ID: 9E2F8986D76F8B5D
6 changed files with 27 additions and 125 deletions

View file

@ -86,11 +86,9 @@ func main() {
} }
tlsConnPSK := tls.UClient(tcpConnPSK, &tls.Config{ tlsConnPSK := tls.UClient(tcpConnPSK, &tls.Config{
ServerName: strings.Split(serverAddr, ":")[0], ServerName: strings.Split(serverAddr, ":")[0],
// ClientSessionCache: csc, // set this will cause PSK to fail. This is a bug... ClientSessionCache: csc,
}, tls.HelloChrome_100_PSK, &tls.UtlsPreSharedKeyExtension{ }, tls.HelloChrome_100_PSK, &tls.UtlsPreSharedKeyExtension{})
ClientSessionCacheOverride: csc, // ONLY set your own ClientSessionCache here if you want to use PSK
})
// HS // HS
err = tlsConnPSK.Handshake() err = tlsConnPSK.Handshake()

View file

@ -85,9 +85,9 @@ func (c *CompressionMethodsJSONUnmarshaler) CompressionMethods() []uint8 {
} }
type TLSExtensionsJSONUnmarshaler struct { type TLSExtensionsJSONUnmarshaler struct {
AllowUnknownExt bool // if set, unknown extensions will be added as GenericExtension, without recovering ext payload AllowUnknownExt bool // if set, unknown extensions will be added as GenericExtension, without recovering ext payload
ClientSessionCache ClientSessionCache // if set, PSK extension will be Unmarshaled into UtlsPreSharedKeyExtension. Otherwise FakePreSharedKeyExtension. UseRealPSK bool // if set, PSK extension will be real PSK extension, otherwise it will be fake PSK extension
extensions []TLSExtensionJSON extensions []TLSExtensionJSON
} }
func (e *TLSExtensionsJSONUnmarshaler) UnmarshalJSON(jsonStr []byte) error { func (e *TLSExtensionsJSONUnmarshaler) UnmarshalJSON(jsonStr []byte) error {
@ -120,8 +120,8 @@ func (e *TLSExtensionsJSONUnmarshaler) UnmarshalJSON(jsonStr []byte) error {
switch extID { switch extID {
case extensionPreSharedKey: case extensionPreSharedKey:
// PSK extension, need to see if we do real or fake PSK // PSK extension, need to see if we do real or fake PSK
if e.ClientSessionCache != nil { if e.UseRealPSK {
ext = &UtlsPreSharedKeyExtension{ClientSessionCacheOverride: e.ClientSessionCache} ext = &UtlsPreSharedKeyExtension{}
} else { } else {
ext = &FakePreSharedKeyExtension{} ext = &FakePreSharedKeyExtension{}
} }

View file

@ -210,7 +210,7 @@ func (chs *ClientHelloSpec) ReadCompressionMethods(compressionMethods []byte) er
// a byte slice into []TLSExtension. // a byte slice into []TLSExtension.
// //
// If keepPSK is not set, the PSK extension will cause an error. // If keepPSK is not set, the PSK extension will cause an error.
func (chs *ClientHelloSpec) ReadTLSExtensions(b []byte, allowBluntMimicry bool, clientSessionCache ...ClientSessionCache) error { func (chs *ClientHelloSpec) ReadTLSExtensions(b []byte, allowBluntMimicry bool, realPSK bool) error {
extensions := cryptobyte.String(b) extensions := cryptobyte.String(b)
for !extensions.Empty() { for !extensions.Empty() {
var extension uint16 var extension uint16
@ -228,8 +228,8 @@ func (chs *ClientHelloSpec) ReadTLSExtensions(b []byte, allowBluntMimicry bool,
switch extension { switch extension {
case extensionPreSharedKey: case extensionPreSharedKey:
// PSK extension, need to see if we do real or fake PSK // PSK extension, need to see if we do real or fake PSK
if len(clientSessionCache) > 0 && clientSessionCache[0] != nil { if realPSK {
extWriter = &UtlsPreSharedKeyExtension{ClientSessionCacheOverride: clientSessionCache[0]} extWriter = &UtlsPreSharedKeyExtension{}
} else { } else {
extWriter = &FakePreSharedKeyExtension{} extWriter = &FakePreSharedKeyExtension{}
} }
@ -464,14 +464,20 @@ func (chs *ClientHelloSpec) ImportTLSClientHelloFromJSON(jsonB []byte) error {
} }
// FromRaw converts a ClientHello message in the form of raw bytes into a ClientHelloSpec. // FromRaw converts a ClientHello message in the form of raw bytes into a ClientHelloSpec.
func (chs *ClientHelloSpec) FromRaw(raw []byte, allowBluntMimicry ...bool) error { //
// ctrlFlags: []bool{bluntMimicry, realPSK}
func (chs *ClientHelloSpec) FromRaw(raw []byte, ctrlFlags ...bool) error {
if chs == nil { if chs == nil {
return errors.New("cannot unmarshal into nil ClientHelloSpec") return errors.New("cannot unmarshal into nil ClientHelloSpec")
} }
var bluntMimicry = false var bluntMimicry = false
if len(allowBluntMimicry) == 1 { var realPSK = false
bluntMimicry = allowBluntMimicry[0] if len(ctrlFlags) > 0 {
bluntMimicry = ctrlFlags[0]
}
if len(ctrlFlags) > 1 {
realPSK = ctrlFlags[1]
} }
*chs = ClientHelloSpec{} // reset *chs = ClientHelloSpec{} // reset
@ -538,89 +544,7 @@ func (chs *ClientHelloSpec) FromRaw(raw []byte, allowBluntMimicry ...bool) error
return errors.New("unable to read extensions data") return errors.New("unable to read extensions data")
} }
if err := chs.ReadTLSExtensions(extensions, bluntMimicry); err != nil { if err := chs.ReadTLSExtensions(extensions, bluntMimicry, realPSK); err != nil {
return err
}
return nil
}
// FromRaw converts a ClientHello message in the form of raw bytes into a ClientHelloSpec.
func (chs *ClientHelloSpec) FromRawWithClientSessionCache(raw []byte, csc ClientSessionCache, allowBluntMimicry ...bool) error {
if chs == nil {
return errors.New("cannot unmarshal into nil ClientHelloSpec")
}
var bluntMimicry = false
if len(allowBluntMimicry) == 1 {
bluntMimicry = allowBluntMimicry[0]
}
*chs = ClientHelloSpec{} // reset
s := cryptobyte.String(raw)
var contentType uint8
var recordVersion uint16
if !s.ReadUint8(&contentType) || // record type
!s.ReadUint16(&recordVersion) || !s.Skip(2) { // record version and length
return errors.New("unable to read record type, version, and length")
}
if recordType(contentType) != recordTypeHandshake {
return errors.New("record is not a handshake")
}
var handshakeVersion uint16
var handshakeType uint8
if !s.ReadUint8(&handshakeType) || !s.Skip(3) || // message type and 3 byte length
!s.ReadUint16(&handshakeVersion) || !s.Skip(32) { // 32 byte random
return errors.New("unable to read handshake message type, length, and random")
}
if handshakeType != typeClientHello {
return errors.New("handshake message is not a ClientHello")
}
chs.TLSVersMin = recordVersion
chs.TLSVersMax = handshakeVersion
var ignoredSessionID cryptobyte.String
if !s.ReadUint8LengthPrefixed(&ignoredSessionID) {
return errors.New("unable to read session id")
}
// CipherSuites
var cipherSuitesBytes cryptobyte.String
if !s.ReadUint16LengthPrefixed(&cipherSuitesBytes) {
return errors.New("unable to read ciphersuites")
}
if err := chs.ReadCipherSuites(cipherSuitesBytes); err != nil {
return err
}
// CompressionMethods
var compressionMethods cryptobyte.String
if !s.ReadUint8LengthPrefixed(&compressionMethods) {
return errors.New("unable to read compression methods")
}
if err := chs.ReadCompressionMethods(compressionMethods); err != nil {
return err
}
if s.Empty() {
// Extensions are optional
return nil
}
var extensions cryptobyte.String
if !s.ReadUint16LengthPrefixed(&extensions) {
return errors.New("unable to read extensions data")
}
if err := chs.ReadTLSExtensions(extensions, bluntMimicry, csc); err != nil {
return err return err
} }

View file

@ -19,7 +19,7 @@ type Fingerprinter struct {
// (including things like different SNI lengths) would cause padding to be necessary // (including things like different SNI lengths) would cause padding to be necessary
AlwaysAddPadding bool AlwaysAddPadding bool
ClientSessionCache ClientSessionCache // if set, PSK extension will be made into UtlsPreSharedKeyExtension. Otherwise FakePreSharedKeyExtension. RealPSKResumption bool // if set, PSK extension (if any) will be real PSK extension, otherwise it will be fake PSK extension
} }
// FingerprintClientHello returns a ClientHelloSpec which is based on the // FingerprintClientHello returns a ClientHelloSpec which is based on the
@ -46,11 +46,7 @@ func (f *Fingerprinter) FingerprintClientHello(data []byte) (clientHelloSpec *Cl
func (f *Fingerprinter) RawClientHello(raw []byte) (clientHelloSpec *ClientHelloSpec, err error) { func (f *Fingerprinter) RawClientHello(raw []byte) (clientHelloSpec *ClientHelloSpec, err error) {
clientHelloSpec = &ClientHelloSpec{} clientHelloSpec = &ClientHelloSpec{}
if f.ClientSessionCache != nil { err = clientHelloSpec.FromRaw(raw, f.AllowBluntMimicry, f.RealPSKResumption)
err = clientHelloSpec.FromRawWithClientSessionCache(raw, f.ClientSessionCache, f.AllowBluntMimicry)
} else {
err = clientHelloSpec.FromRaw(raw, f.AllowBluntMimicry)
}
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -2453,7 +2453,10 @@ func (uconn *UConn) ApplyPreset(p *ClientHelloSpec) error {
if cs != nil { if cs != nil {
session = cs.session session = cs.session
} }
// TLS 1.3 (PSK) resumption is handled by PreSharedKeyExtension in MarshalClientHello() }
// TLS 1.3 (PSK) resumption is handled by PreSharedKeyExtension in MarshalClientHello()
if session != nil && session.version == VersionTLS13 {
break
} }
err := uconn.SetSessionState(cs) err := uconn.SetSessionState(cs)
if err != nil { if err != nil {

View file

@ -61,13 +61,6 @@ func (*UnimplementedPreSharedKeyExtension) ReadWithRawHello(raw, b []byte) (int,
type UtlsPreSharedKeyExtension struct { type UtlsPreSharedKeyExtension struct {
UnimplementedPreSharedKeyExtension UnimplementedPreSharedKeyExtension
// ClientSessionCacheOverride is used to specify the ClientSessionCache to be used
// for PSK-resumption.
//
// bug: tls.Config.ClientSessionCache must be nil for PSK-resumption to work, even though
// it is supposed to be overridden by ClientSessionCacheOverride.
ClientSessionCacheOverride ClientSessionCache
identities []pskIdentity identities []pskIdentity
binders [][]byte binders [][]byte
binderKey []byte // this will be used to compute the binder when hello message is ready binderKey []byte // this will be used to compute the binder when hello message is ready
@ -172,12 +165,6 @@ func (e *UtlsPreSharedKeyExtension) ReadWithRawHello(raw, b []byte) (int, error)
} }
func (e *UtlsPreSharedKeyExtension) preloadSession(uc *UConn) error { func (e *UtlsPreSharedKeyExtension) preloadSession(uc *UConn) error {
// var sessionCache ClientSessionCache
// must set either e.Session or uc.config.ClientSessionCache
if e.ClientSessionCacheOverride != nil {
uc.config.ClientSessionCache = e.ClientSessionCacheOverride
}
// load Hello // load Hello
hello := uc.HandshakeState.Hello.getPrivatePtr() hello := uc.HandshakeState.Hello.getPrivatePtr()
// try to use loadSession() // try to use loadSession()
@ -203,16 +190,10 @@ func (e *UtlsPreSharedKeyExtension) preloadSession(uc *UConn) error {
} }
func (e *UtlsPreSharedKeyExtension) Write(b []byte) (int, error) { func (e *UtlsPreSharedKeyExtension) Write(b []byte) (int, error) {
if e.ClientSessionCacheOverride == nil {
return 0, errors.New("tls: ClientSessionCache must be set to use UtlsPreSharedKeyExtension")
}
return len(b), nil // ignore the data return len(b), nil // ignore the data
} }
func (e *UtlsPreSharedKeyExtension) UnmarshalJSON(_ []byte) error { func (e *UtlsPreSharedKeyExtension) UnmarshalJSON(_ []byte) error {
if e.ClientSessionCacheOverride == nil {
return errors.New("tls: ClientSessionCache must be set to use UtlsPreSharedKeyExtension")
}
return nil // ignore the data return nil // ignore the data
} }