fix: PSK failing if config session cache set

* Fix a bug causing PSK to fail if Config.ClientSessionCache is set.
* Removed `ClientSessionCacheOverride` from `UtlsPreSharedKeyExtension`. Set the `ClientSessionCache` in `Config`!

Co-Authored-By: zeeker999 <13848632+zeeker999@users.noreply.github.com>
This commit is contained in:
Gaukas Wang 2023-08-16 19:09:01 -06:00
parent 3d7eea3346
commit 3162534cc7
No known key found for this signature in database
GPG key ID: 9E2F8986D76F8B5D
6 changed files with 27 additions and 125 deletions

View file

@ -86,11 +86,9 @@ func main() {
}
tlsConnPSK := tls.UClient(tcpConnPSK, &tls.Config{
ServerName: strings.Split(serverAddr, ":")[0],
// ClientSessionCache: csc, // set this will cause PSK to fail. This is a bug...
}, tls.HelloChrome_100_PSK, &tls.UtlsPreSharedKeyExtension{
ClientSessionCacheOverride: csc, // ONLY set your own ClientSessionCache here if you want to use PSK
})
ServerName: strings.Split(serverAddr, ":")[0],
ClientSessionCache: csc,
}, tls.HelloChrome_100_PSK, &tls.UtlsPreSharedKeyExtension{})
// HS
err = tlsConnPSK.Handshake()

View file

@ -85,9 +85,9 @@ func (c *CompressionMethodsJSONUnmarshaler) CompressionMethods() []uint8 {
}
type TLSExtensionsJSONUnmarshaler struct {
AllowUnknownExt bool // if set, unknown extensions will be added as GenericExtension, without recovering ext payload
ClientSessionCache ClientSessionCache // if set, PSK extension will be Unmarshaled into UtlsPreSharedKeyExtension. Otherwise FakePreSharedKeyExtension.
extensions []TLSExtensionJSON
AllowUnknownExt bool // if set, unknown extensions will be added as GenericExtension, without recovering ext payload
UseRealPSK bool // if set, PSK extension will be real PSK extension, otherwise it will be fake PSK extension
extensions []TLSExtensionJSON
}
func (e *TLSExtensionsJSONUnmarshaler) UnmarshalJSON(jsonStr []byte) error {
@ -120,8 +120,8 @@ func (e *TLSExtensionsJSONUnmarshaler) UnmarshalJSON(jsonStr []byte) error {
switch extID {
case extensionPreSharedKey:
// PSK extension, need to see if we do real or fake PSK
if e.ClientSessionCache != nil {
ext = &UtlsPreSharedKeyExtension{ClientSessionCacheOverride: e.ClientSessionCache}
if e.UseRealPSK {
ext = &UtlsPreSharedKeyExtension{}
} else {
ext = &FakePreSharedKeyExtension{}
}

View file

@ -210,7 +210,7 @@ func (chs *ClientHelloSpec) ReadCompressionMethods(compressionMethods []byte) er
// a byte slice into []TLSExtension.
//
// If keepPSK is not set, the PSK extension will cause an error.
func (chs *ClientHelloSpec) ReadTLSExtensions(b []byte, allowBluntMimicry bool, clientSessionCache ...ClientSessionCache) error {
func (chs *ClientHelloSpec) ReadTLSExtensions(b []byte, allowBluntMimicry bool, realPSK bool) error {
extensions := cryptobyte.String(b)
for !extensions.Empty() {
var extension uint16
@ -228,8 +228,8 @@ func (chs *ClientHelloSpec) ReadTLSExtensions(b []byte, allowBluntMimicry bool,
switch extension {
case extensionPreSharedKey:
// PSK extension, need to see if we do real or fake PSK
if len(clientSessionCache) > 0 && clientSessionCache[0] != nil {
extWriter = &UtlsPreSharedKeyExtension{ClientSessionCacheOverride: clientSessionCache[0]}
if realPSK {
extWriter = &UtlsPreSharedKeyExtension{}
} else {
extWriter = &FakePreSharedKeyExtension{}
}
@ -464,14 +464,20 @@ func (chs *ClientHelloSpec) ImportTLSClientHelloFromJSON(jsonB []byte) error {
}
// FromRaw converts a ClientHello message in the form of raw bytes into a ClientHelloSpec.
func (chs *ClientHelloSpec) FromRaw(raw []byte, allowBluntMimicry ...bool) error {
//
// ctrlFlags: []bool{bluntMimicry, realPSK}
func (chs *ClientHelloSpec) FromRaw(raw []byte, ctrlFlags ...bool) error {
if chs == nil {
return errors.New("cannot unmarshal into nil ClientHelloSpec")
}
var bluntMimicry = false
if len(allowBluntMimicry) == 1 {
bluntMimicry = allowBluntMimicry[0]
var realPSK = false
if len(ctrlFlags) > 0 {
bluntMimicry = ctrlFlags[0]
}
if len(ctrlFlags) > 1 {
realPSK = ctrlFlags[1]
}
*chs = ClientHelloSpec{} // reset
@ -538,89 +544,7 @@ func (chs *ClientHelloSpec) FromRaw(raw []byte, allowBluntMimicry ...bool) error
return errors.New("unable to read extensions data")
}
if err := chs.ReadTLSExtensions(extensions, bluntMimicry); err != nil {
return err
}
return nil
}
// FromRaw converts a ClientHello message in the form of raw bytes into a ClientHelloSpec.
func (chs *ClientHelloSpec) FromRawWithClientSessionCache(raw []byte, csc ClientSessionCache, allowBluntMimicry ...bool) error {
if chs == nil {
return errors.New("cannot unmarshal into nil ClientHelloSpec")
}
var bluntMimicry = false
if len(allowBluntMimicry) == 1 {
bluntMimicry = allowBluntMimicry[0]
}
*chs = ClientHelloSpec{} // reset
s := cryptobyte.String(raw)
var contentType uint8
var recordVersion uint16
if !s.ReadUint8(&contentType) || // record type
!s.ReadUint16(&recordVersion) || !s.Skip(2) { // record version and length
return errors.New("unable to read record type, version, and length")
}
if recordType(contentType) != recordTypeHandshake {
return errors.New("record is not a handshake")
}
var handshakeVersion uint16
var handshakeType uint8
if !s.ReadUint8(&handshakeType) || !s.Skip(3) || // message type and 3 byte length
!s.ReadUint16(&handshakeVersion) || !s.Skip(32) { // 32 byte random
return errors.New("unable to read handshake message type, length, and random")
}
if handshakeType != typeClientHello {
return errors.New("handshake message is not a ClientHello")
}
chs.TLSVersMin = recordVersion
chs.TLSVersMax = handshakeVersion
var ignoredSessionID cryptobyte.String
if !s.ReadUint8LengthPrefixed(&ignoredSessionID) {
return errors.New("unable to read session id")
}
// CipherSuites
var cipherSuitesBytes cryptobyte.String
if !s.ReadUint16LengthPrefixed(&cipherSuitesBytes) {
return errors.New("unable to read ciphersuites")
}
if err := chs.ReadCipherSuites(cipherSuitesBytes); err != nil {
return err
}
// CompressionMethods
var compressionMethods cryptobyte.String
if !s.ReadUint8LengthPrefixed(&compressionMethods) {
return errors.New("unable to read compression methods")
}
if err := chs.ReadCompressionMethods(compressionMethods); err != nil {
return err
}
if s.Empty() {
// Extensions are optional
return nil
}
var extensions cryptobyte.String
if !s.ReadUint16LengthPrefixed(&extensions) {
return errors.New("unable to read extensions data")
}
if err := chs.ReadTLSExtensions(extensions, bluntMimicry, csc); err != nil {
if err := chs.ReadTLSExtensions(extensions, bluntMimicry, realPSK); err != nil {
return err
}

View file

@ -19,7 +19,7 @@ type Fingerprinter struct {
// (including things like different SNI lengths) would cause padding to be necessary
AlwaysAddPadding bool
ClientSessionCache ClientSessionCache // if set, PSK extension will be made into UtlsPreSharedKeyExtension. Otherwise FakePreSharedKeyExtension.
RealPSKResumption bool // if set, PSK extension (if any) will be real PSK extension, otherwise it will be fake PSK extension
}
// FingerprintClientHello returns a ClientHelloSpec which is based on the
@ -46,11 +46,7 @@ func (f *Fingerprinter) FingerprintClientHello(data []byte) (clientHelloSpec *Cl
func (f *Fingerprinter) RawClientHello(raw []byte) (clientHelloSpec *ClientHelloSpec, err error) {
clientHelloSpec = &ClientHelloSpec{}
if f.ClientSessionCache != nil {
err = clientHelloSpec.FromRawWithClientSessionCache(raw, f.ClientSessionCache, f.AllowBluntMimicry)
} else {
err = clientHelloSpec.FromRaw(raw, f.AllowBluntMimicry)
}
err = clientHelloSpec.FromRaw(raw, f.AllowBluntMimicry, f.RealPSKResumption)
if err != nil {
return nil, err
}

View file

@ -2453,7 +2453,10 @@ func (uconn *UConn) ApplyPreset(p *ClientHelloSpec) error {
if cs != nil {
session = cs.session
}
// TLS 1.3 (PSK) resumption is handled by PreSharedKeyExtension in MarshalClientHello()
}
// TLS 1.3 (PSK) resumption is handled by PreSharedKeyExtension in MarshalClientHello()
if session != nil && session.version == VersionTLS13 {
break
}
err := uconn.SetSessionState(cs)
if err != nil {

View file

@ -61,13 +61,6 @@ func (*UnimplementedPreSharedKeyExtension) ReadWithRawHello(raw, b []byte) (int,
type UtlsPreSharedKeyExtension struct {
UnimplementedPreSharedKeyExtension
// ClientSessionCacheOverride is used to specify the ClientSessionCache to be used
// for PSK-resumption.
//
// bug: tls.Config.ClientSessionCache must be nil for PSK-resumption to work, even though
// it is supposed to be overridden by ClientSessionCacheOverride.
ClientSessionCacheOverride ClientSessionCache
identities []pskIdentity
binders [][]byte
binderKey []byte // this will be used to compute the binder when hello message is ready
@ -172,12 +165,6 @@ func (e *UtlsPreSharedKeyExtension) ReadWithRawHello(raw, b []byte) (int, error)
}
func (e *UtlsPreSharedKeyExtension) preloadSession(uc *UConn) error {
// var sessionCache ClientSessionCache
// must set either e.Session or uc.config.ClientSessionCache
if e.ClientSessionCacheOverride != nil {
uc.config.ClientSessionCache = e.ClientSessionCacheOverride
}
// load Hello
hello := uc.HandshakeState.Hello.getPrivatePtr()
// try to use loadSession()
@ -203,16 +190,10 @@ func (e *UtlsPreSharedKeyExtension) preloadSession(uc *UConn) error {
}
func (e *UtlsPreSharedKeyExtension) Write(b []byte) (int, error) {
if e.ClientSessionCacheOverride == nil {
return 0, errors.New("tls: ClientSessionCache must be set to use UtlsPreSharedKeyExtension")
}
return len(b), nil // ignore the data
}
func (e *UtlsPreSharedKeyExtension) UnmarshalJSON(_ []byte) error {
if e.ClientSessionCacheOverride == nil {
return errors.New("tls: ClientSessionCache must be set to use UtlsPreSharedKeyExtension")
}
return nil // ignore the data
}