From e2467ffd0445817c7f91556955c88d9d6302e8b6 Mon Sep 17 00:00:00 2001 From: jmwample Date: Wed, 21 Jun 2023 12:58:04 -0600 Subject: [PATCH] wip - support preshared ticket based session resumption --- u_conn.go | 29 +++++--- u_server.go | 174 +++++++++++++++++++++++++++++++++++++++++++++++ u_server_test.go | 16 +++++ 3 files changed, 209 insertions(+), 10 deletions(-) create mode 100644 u_server.go create mode 100644 u_server_test.go diff --git a/u_conn.go b/u_conn.go index caff710..4b94d9a 100644 --- a/u_conn.go +++ b/u_conn.go @@ -588,6 +588,20 @@ func (uconn *UConn) GetOutKeystream(length int) ([]byte, error) { // Error is only returned if things are in clearly undesirable state // to help user fix them. func (uconn *UConn) SetTLSVers(minTLSVers, maxTLSVers uint16, specExtensions []TLSExtension) error { + + minTLSVers, maxTLSVers, err := getTLSVers(minTLSVers, maxTLSVers, specExtensions) + if err != nil { + return err + } + + uconn.HandshakeState.Hello.SupportedVersions = makeSupportedVersions(minTLSVers, maxTLSVers) + uconn.config.MinVersion = minTLSVers + uconn.config.MaxVersion = maxTLSVers + + return nil +} + +func getTLSVers(minTLSVers, maxTLSVers uint16, specExtensions []TLSExtension) (uint16, uint16, error) { if minTLSVers == 0 && maxTLSVers == 0 { // if version is not set explicitly in the ClientHelloSpec, check the SupportedVersions extension supportedVersionsExtensionsPresent := 0 @@ -615,7 +629,7 @@ func (uconn *UConn) SetTLSVers(minTLSVers, maxTLSVers uint16, specExtensions []T supportedVersionsExtensionsPresent += 1 minTLSVers, maxTLSVers = findVersionsInSupportedVersionsExtensions(ext.Versions) if minTLSVers == 0 && maxTLSVers == 0 { - return fmt.Errorf("SupportedVersions extension has invalid Versions field") + return 0, 0, fmt.Errorf("SupportedVersions extension has invalid Versions field") } // else: proceed } } @@ -626,24 +640,19 @@ func (uconn *UConn) SetTLSVers(minTLSVers, maxTLSVers uint16, specExtensions []T maxTLSVers = VersionTLS12 case 1: default: - return fmt.Errorf("uconn.Extensions contains %v separate SupportedVersions extensions", + return 0, 0, fmt.Errorf("uconn.Extensions contains %v separate SupportedVersions extensions", supportedVersionsExtensionsPresent) } } if minTLSVers < VersionTLS10 || minTLSVers > VersionTLS12 { - return fmt.Errorf("uTLS does not support 0x%X as min version", minTLSVers) + return 0, 0, fmt.Errorf("uTLS does not support 0x%X as min version", minTLSVers) } if maxTLSVers < VersionTLS10 || maxTLSVers > VersionTLS13 { - return fmt.Errorf("uTLS does not support 0x%X as max version", maxTLSVers) + return 0, 0, fmt.Errorf("uTLS does not support 0x%X as max version", maxTLSVers) } - - uconn.HandshakeState.Hello.SupportedVersions = makeSupportedVersions(minTLSVers, maxTLSVers) - uconn.config.MinVersion = minTLSVers - uconn.config.MaxVersion = maxTLSVers - - return nil + return minTLSVers, maxTLSVers, nil } func (uconn *UConn) SetUnderlyingConn(c net.Conn) { diff --git a/u_server.go b/u_server.go new file mode 100644 index 0000000..62603eb --- /dev/null +++ b/u_server.go @@ -0,0 +1,174 @@ +package tls + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/hmac" + "crypto/sha256" + "errors" + "fmt" + "io" + "time" +) + +// ServerSessionState contains the information that is serialized into a session +// ticket in order to later resume a connection. +type ServerSessionState struct { + Vers uint16 + CipherSuite uint16 + CreatedAt uint64 + MasterSecret []byte // opaque master_secret<1..2^16-1>; + // struct { opaque certificate<1..2^24-1> } Certificate; + Certificates [][]byte // Certificate certificate_list<0..2^24-1>; + + // usedOldKey is true if the ticket from which this session came from + // was encrypted with an older key and thus should be refreshed. + UsedOldKey bool +} + +// ForgeServerSessionState allows the creation of a Session (and SessionTicket) +// from a (presumably shared) secret value allowing a client to to +// "re-establish" a non-existent previous connection. With these values a +// ClientSessionState can be created to "resume" a session based on the secret +// value known to both the client and the server. +// +// Warning: you should probably not use this function, unless you are absolutely +// sure this is the functionality you are looking for. +func ForgeServerSessionState(masterSecret []byte, chID ClientHelloID) (*ServerSessionState, error) { + config := &Config{} + chSpec, err := utlsIdToSpec(chID) + if err != nil { + return nil, err + } + + clientVersions := []uint16{} + minVers, maxVers, err := getTLSVers(chSpec.TLSVersMin, chSpec.TLSVersMax, chSpec.Extensions) + if err != nil { + return nil, err + } + clientVersions = makeSupportedVersions(minVers, maxVers) + + vers, ok := config.mutualVersion(roleServer, clientVersions) + if !ok { + return nil, fmt.Errorf("unable to select mutual version") + } + + clientCipherSuites := make([]uint16, len(chSpec.CipherSuites)) + copy(clientCipherSuites, chSpec.CipherSuites) + + chosenCiphersuite, err := pickCipherSuite(clientCipherSuites, vers, config) + if err != nil { + return nil, err + } + + sessionState := &ServerSessionState{ + Vers: vers, + CipherSuite: chosenCiphersuite, + CreatedAt: uint64(time.Now().UnixMicro()), + MasterSecret: masterSecret, // TODO + Certificates: nil, + // We are fabricating this session state for the key so it can't be old. + UsedOldKey: false, + } + + return sessionState, nil +} + +// Marshal serializes the sessionState object to bytes. +func (ss *ServerSessionState) Marshal() ([]byte, error) { + pss := ss.toPrivate() + if pss == nil { + return nil, nil + } + return pss.marshal() +} + +func (ss *ServerSessionState) toPrivate() *sessionState { + if ss == nil { + return nil + } + return &sessionState{ + vers: ss.Vers, + cipherSuite: ss.CipherSuite, + createdAt: ss.CreatedAt, + masterSecret: ss.MasterSecret, + certificates: ss.Certificates, + usedOldKey: ss.UsedOldKey, + } +} + +// MakeEncryptedTicket creates an encrypted session ticket that a client can +// then use to "re-establish" a non-existent previous connection. The value +// provided as keyBytes should be added to the servers ticketKeys using something +// like SetSessionKeys. +func (ss *ServerSessionState) MakeEncryptedTicket(keyBytes [32]byte, config *Config) ([]byte, error) { + if config == nil { + config = &Config{} + } + key := config.ticketKeyFromBytes(keyBytes) + state, err := ss.Marshal() + if err != nil { + return nil, err + } + + encrypted := make([]byte, ticketKeyNameLen+aes.BlockSize+len(state)+sha256.Size) + keyName := encrypted[:ticketKeyNameLen] + iv := encrypted[ticketKeyNameLen : ticketKeyNameLen+aes.BlockSize] + macBytes := encrypted[len(encrypted)-sha256.Size:] + + if _, err := io.ReadFull(config.rand(), iv); err != nil { + return nil, err + } + + copy(keyName, key.keyName[:]) + block, err := aes.NewCipher(key.aesKey[:]) + if err != nil { + return nil, errors.New("tls: failed to create cipher while encrypting ticket: " + err.Error()) + } + cipher.NewCTR(block, iv).XORKeyStream(encrypted[ticketKeyNameLen+aes.BlockSize:], state) + + mac := hmac.New(sha256.New, key.hmacKey[:]) + mac.Write(encrypted[:len(encrypted)-sha256.Size]) + mac.Sum(macBytes[:0]) + + return encrypted, nil +} + +func pickCipherSuite(clientCipherSuites []uint16, vers uint16, config *Config) (uint16, error) { + preferenceOrder := cipherSuitesPreferenceOrder + if !hasAESGCMHardwareSupport || !aesgcmPreferred(clientCipherSuites) { + preferenceOrder = cipherSuitesPreferenceOrderNoAES + } + + configCipherSuites := config.cipherSuites() + preferenceList := make([]uint16, 0, len(configCipherSuites)) + for _, suiteID := range preferenceOrder { + for _, id := range configCipherSuites { + if id == suiteID { + preferenceList = append(preferenceList, id) + break + } + } + } + + var cipherSuiteOk = func(*cipherSuite) bool { + return true + } + suite := selectCipherSuite(preferenceList, clientCipherSuites, cipherSuiteOk) + if suite == nil { + return 0, errors.New("tls: no cipher suite supported by both client and server") + } + cipherSuite := suite.id + + for _, id := range clientCipherSuites { + if id == TLS_FALLBACK_SCSV { + // The client is doing a fallback connection. See RFC 7507. + if vers < config.maxSupportedVersion(roleServer) { + return 0, errors.New("tls: client using inappropriate protocol fallback") + } + break + } + } + + return cipherSuite, nil +} diff --git a/u_server_test.go b/u_server_test.go new file mode 100644 index 0000000..79132e0 --- /dev/null +++ b/u_server_test.go @@ -0,0 +1,16 @@ +package tls + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestUTLSServerForgeSession(t *testing.T) { + require.Nil(t, nil) +} + + +func TestUTLSServerForgeSession13(t *testing.T) { + require.Nil(t, nil) +}