utls/u_conn.go
Mingye Chen 9b2ccad920 feat: add support for ECH when using custom clienthello specs
Use the standard library's client hello marshaller to generate the ClientHelloInner and replace any extension that implements the EncryptedClientHelloExtension interface in the extensions list with it.
2025-03-08 16:42:47 -07:00

985 lines
30 KiB
Go

// Copyright 2017 Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tls
import (
"bufio"
"bytes"
"context"
"crypto/cipher"
"encoding/binary"
"errors"
"fmt"
"hash"
"net"
"slices"
"strconv"
"golang.org/x/crypto/cryptobyte"
)
// ClientHelloBuildStatus records whether a UConn's ClientHello has been built
// yet and, if so, which code path produced it.
type ClientHelloBuildStatus int

const (
	// NotBuilt is the zero value: no ClientHello has been built.
	NotBuilt ClientHelloBuildStatus = iota
	// BuildByUtls is set once the uTLS (mimicking/custom) path has built the ClientHello.
	BuildByUtls
	// BuildByGoTLS is set once the default Go TLS path has built the ClientHello.
	BuildByGoTLS
)
// UConn embeds a *Conn and adds the state uTLS uses to build, customize and
// marshal its own ClientHello.
type UConn struct {
*Conn
// Extensions is the ordered list of extensions written into the marshaled ClientHello.
Extensions []TLSExtension
// ClientHelloID selects how the ClientHello is built (see BuildHandshakeState).
ClientHelloID ClientHelloID
sessionController *sessionController
// clientHelloBuildStatus tracks whether the ClientHello was built and by which path.
clientHelloBuildStatus ClientHelloBuildStatus
clientHelloSpec *ClientHelloSpec
HandshakeState PubClientHandshakeState
greaseSeed [ssl_grease_last_index]uint16
// omitSNIExtension is set by RemoveSNIExtension; when true, SNI extensions are
// stripped right after the preset is applied on the first build.
omitSNIExtension bool
// skipResumptionOnNilExtension is copied from `Config.PreferSkipResumptionOnNilExtension`.
//
// By default, if ClientHelloSpec is predefined or utls-generated (as opposed to HelloCustom), this flag will be updated to true.
skipResumptionOnNilExtension bool
// certCompressionAlgs represents the set of advertised certificate compression
// algorithms, as specified in the ClientHello. This is only relevant client-side, for the
// server certificate. All other forms of certificate compression are unsupported.
certCompressionAlgs []CertCompressionAlgo
// ech extension is a shortcut to the ECH extension in the Extensions slice if there is one.
ech ECHExtension
// echCtx is the echClientContext returned by makeClientHello()
echCtx *echClientContext
}
// UClient wraps conn in a new client-side uTLS connection whose ClientHello is
// selected by clientHelloID.
//
// config may be nil, in which case an empty Config is used; make sure
// ServerName is eventually specified.
func UClient(conn net.Conn, config *Config, clientHelloID ClientHelloID) *UConn {
	if config == nil {
		config = &Config{}
	}

	tlsConn := &Conn{conn: conn, config: config, isClient: true}
	uconn := &UConn{
		Conn:          tlsConn,
		ClientHelloID: clientHelloID,
		HandshakeState: PubClientHandshakeState{
			C:     tlsConn,
			Hello: &PubClientHelloMsg{},
		},
	}

	// Wire the back-references and route handshakes through the uTLS client handshake.
	uconn.HandshakeState.uconn = uconn
	uconn.handshakeFn = uconn.clientHandshake

	// The same controller is reachable from both the UConn and the embedded Conn.
	controller := newSessionController(uconn)
	uconn.sessionController = controller
	uconn.utls.sessionController = controller

	// Any non-custom ClientHelloID enables the flag regardless of the config preference.
	uconn.skipResumptionOnNilExtension = config.PreferSkipResumptionOnNilExtension || clientHelloID.Client != helloCustom
	return uconn
}
// BuildHandshakeState behavior varies based on ClientHelloID and
// whether it was already called before.
// If HelloGolang:
//
// [only once] make default ClientHello and overwrite existing state
//
// If any other mimicking ClientHelloID is used:
//
// [only once] make ClientHello based on ID and overwrite existing state
// [each call] apply uconn.Extensions config to internal crypto/tls structures
// [each call] marshal ClientHello.
//
// BuildHandshakeState is automatically called before uTLS performs handshake,
// and should only be called explicitly to inspect/change fields of
// default/mimicked ClientHello.
// With the exception of session ticket and psk extensions, which cannot be changed
// after calling BuildHandshakeState, all other fields can be modified.
func (uconn *UConn) BuildHandshakeState() error {
// true: also load the session while building (see buildHandshakeState).
return uconn.buildHandshakeState(true)
}
// BuildHandshakeStateWithoutSession is the same as BuildHandshakeState, but does not
// set the session. This is only useful when you want to inspect the ClientHello before
// setting the session manually through SetSessionTicketExtension or SetPSKExtension.
// BuildHandshakeState is automatically called before uTLS performs handshake.
func (uconn *UConn) BuildHandshakeStateWithoutSession() error {
// false: skip session loading and the post-marshal session patch/final check.
return uconn.buildHandshakeState(false)
}
// buildHandshakeState prepares the ClientHello for the upcoming handshake.
//
// For HelloGolang the default Go ClientHello is generated once; later calls
// are no-ops. For every other ClientHelloID the preset is applied on the first
// build only, then on every call the extensions are applied to the connection
// and the ClientHello is marshaled. When loadSession is true the session is
// loaded before marshaling and, afterwards, the session patch and final check
// run and the status is recorded as BuildByUtls.
func (uconn *UConn) buildHandshakeState(loadSession bool) error {
	if uconn.ClientHelloID == HelloGolang {
		if uconn.clientHelloBuildStatus == BuildByGoTLS {
			// Already built by the Go path: nothing left to do.
			return nil
		}
		uAssert(uconn.clientHelloBuildStatus == NotBuilt, "BuildHandshakeState failed: invalid call, client hello has already been built by utls")

		// use default Golang ClientHello.
		hello, keySharePrivate, ech, err := uconn.makeClientHello()
		if err != nil {
			return err
		}

		state := &uconn.HandshakeState
		state.Hello = hello.getPublicPtr()
		state.State13.KeyShareKeys = keySharePrivate.ToPublic()
		state.C = uconn.Conn
		uconn.echCtx = ech
		uconn.clientHelloBuildStatus = BuildByGoTLS
		return nil
	}

	status := uconn.clientHelloBuildStatus
	uAssert(status == BuildByUtls || status == NotBuilt, "BuildHandshakeState failed: invalid call, client hello has already been built by go-tls")

	if status == NotBuilt {
		if err := uconn.applyPresetByID(uconn.ClientHelloID); err != nil {
			return err
		}
		if uconn.omitSNIExtension {
			uconn.removeSNIExtension()
		}
	}

	if err := uconn.ApplyConfig(); err != nil {
		return err
	}

	if loadSession {
		if err := uconn.uLoadSession(); err != nil {
			return err
		}
	}

	if err := uconn.MarshalClientHello(); err != nil {
		return err
	}

	if !loadSession {
		// Without a session the build status is deliberately left untouched.
		return nil
	}

	uconn.uApplyPatch()
	uconn.sessionController.finalCheck()
	uconn.clientHelloBuildStatus = BuildByUtls
	return nil
}
// uLoadSession asks the session controller what to do about session resumption
// and acts on the answer: re-apply an already prepared ticket or PSK extension,
// or load a session from the cache and initialize the matching extension.
// It does nothing when session tickets are disabled or no session cache is set.
func (uconn *UConn) uLoadSession() error {
	cfg := uconn.config
	if cfg.SessionTicketsDisabled || cfg.ClientSessionCache == nil {
		return nil
	}

	controller := uconn.sessionController
	switch controller.shouldLoadSession() {
	case shouldReturn:
		// Nothing to do.
	case shouldSetTicket:
		controller.setSessionTicketToUConn()
	case shouldSetPsk:
		controller.setPskToUConn()
	case shouldLoad:
		hello := uconn.HandshakeState.Hello.getPrivatePtr()
		controller.utlsAboutToLoadSession()
		session, earlySecret, binderKey, err := uconn.loadSession(hello)
		if err != nil || session == nil {
			// err is nil when there simply was no session to resume.
			return err
		}
		if session.version != VersionTLS12 {
			controller.initPskExt(session, earlySecret, binderKey, hello.pskIdentities)
			break
		}
		// We use the session ticket extension for tls 1.2 session resumption
		controller.initSessionTicketExt(session, hello.sessionTicket)
		controller.setSessionTicketToUConn()
	}
	return nil
}
// uApplyPatch updates the PSK binders (when the session controller says they
// need updating) and asserts that doing so did not change the length of the
// already marshaled ClientHello.
func (uconn *UConn) uApplyPatch() {
	lenBefore := len(uconn.HandshakeState.Hello.Raw)

	if controller := uconn.sessionController; controller.shouldUpdateBinders() {
		controller.updateBinders()
		controller.setPskToUConn()
	}

	lenAfter := len(uconn.HandshakeState.Hello.Raw)
	uAssert(lenBefore == lenAfter, "tls: uApplyPatch Failed: the patch should never change the length of the marshaled clientHello")
}
// DidTls12Resume returns the connection's didResume flag.
// NOTE(review): the name suggests TLS 1.2 resumption only, but the flag is
// returned as-is without checking the negotiated version; confirm against callers.
func (uconn *UConn) DidTls12Resume() bool {
return uconn.didResume
}
// SetSessionState installs a session ticket, which may be preshared or fake.
// A nil session leaves the body of the session ticket extension unset, although
// the extension itself still MAY be present for mimicking purposes.
// To have session tickets reused, use the same cache on following connections.
//
// Deprecated: This method is deprecated in favor of SetSessionTicketExtension,
// as it only handles session override of TLS 1.2
func (uconn *UConn) SetSessionState(session *ClientSessionState) error {
	if session == nil {
		return uconn.SetSessionTicketExtension(&SessionTicketExtension{Initialized: true})
	}
	return uconn.SetSessionTicketExtension(&SessionTicketExtension{
		Initialized: true,
		Ticket:      session.session.ticket,
		Session:     session.session,
	})
}
// SetSessionTicketExtension overrides the session ticket extension through the
// session controller.
// It fails when session tickets are disabled or no ClientSessionCache is
// configured; a nil extension is a no-op.
func (uconn *UConn) SetSessionTicketExtension(sessionTicketExt ISessionTicketExtension) error {
	cfg := uconn.config
	sessionsEnabled := !cfg.SessionTicketsDisabled && cfg.ClientSessionCache != nil

	switch {
	case !sessionsEnabled:
		return fmt.Errorf("tls: SetSessionTicketExtension failed: session is disabled")
	case sessionTicketExt == nil:
		return nil
	default:
		return uconn.sessionController.overrideSessionTicketExt(sessionTicketExt)
	}
}
// SetPskExtension overrides the psk extension used for tls 1.3 resumption and
// marks session tickets as supported in the ClientHello.
// It fails when session tickets are disabled or no ClientSessionCache is
// configured; a nil pskExt is a no-op.
func (uconn *UConn) SetPskExtension(pskExt PreSharedKeyExtension) error {
	if cfg := uconn.config; cfg.SessionTicketsDisabled || cfg.ClientSessionCache == nil {
		return fmt.Errorf("tls: SetPskExtension failed: session is disabled")
	}
	if pskExt != nil {
		uconn.HandshakeState.Hello.TicketSupported = true
		return uconn.sessionController.overridePskExt(pskExt)
	}
	return nil
}
// SetSessionCache stores cache as the config's ClientSessionCache and marks
// session tickets as supported in the ClientHello.
// If you want session tickets to be reused - use same cache on following connections
func (uconn *UConn) SetSessionCache(cache ClientSessionCache) {
uconn.config.ClientSessionCache = cache
uconn.HandshakeState.Hello.TicketSupported = true
}
// SetClientRandom sets the client random explicitly.
// BuildHandshakeState must be called before SetClientRandom.
// r must be exactly 32 bytes long; it is copied, so the caller keeps ownership
// of r.
func (uconn *UConn) SetClientRandom(r []byte) error {
	const randomLen = 32

	if n := len(r); n != randomLen {
		return errors.New("Incorrect client random length! Expected: 32, got: " + strconv.Itoa(n))
	}

	random := make([]byte, randomLen)
	copy(random, r)
	uconn.HandshakeState.Hello.Random = random
	return nil
}
// SetSNI passes sni through hostnameInSNI and stores the result both as the
// config's ServerName and in every *SNIExtension found in uconn.Extensions.
func (uconn *UConn) SetSNI(sni string) {
	hname := hostnameInSNI(sni)
	uconn.config.ServerName = hname

	for i := range uconn.Extensions {
		if sniExt, isSNI := uconn.Extensions[i].(*SNIExtension); isSNI {
			sniExt.ServerName = hname
		}
	}
}
// RemoveSNIExtension requests that SNI be left out of the extensions sent in
// the ClientHello. It only records the request; the extensions are filtered
// when the handshake state is first built.
// It returns an error when used with the HelloGolang ClientHelloID.
func (uconn *UConn) RemoveSNIExtension() error {
	switch uconn.ClientHelloID {
	case HelloGolang:
		return fmt.Errorf("cannot call RemoveSNIExtension on a UConn with a HelloGolang ClientHelloID")
	default:
		uconn.omitSNIExtension = true
		return nil
	}
}
// removeSNIExtension replaces uconn.Extensions with a fresh slice holding every
// extension except those of type *SNIExtension, preserving order.
func (uconn *UConn) removeSNIExtension() {
	kept := make([]TLSExtension, 0, len(uconn.Extensions))
	for _, ext := range uconn.Extensions {
		switch ext.(type) {
		case *SNIExtension:
			// Dropped.
		default:
			kept = append(kept, ext)
		}
	}
	uconn.Extensions = kept
}
// Handshake runs the client handshake using given clientHandshakeState
// Requires hs.hello, and, optionally, hs.session to be set.
// It is HandshakeContext called with context.Background().
func (c *UConn) Handshake() error {
return c.HandshakeContext(context.Background())
}
// HandshakeContext runs the client or server handshake
// protocol if it has not yet been run.
//
// The provided Context must be non-nil. If the context is canceled before
// the handshake is complete, the handshake is interrupted and an error is returned.
// Once the handshake has completed, cancellation of the context will not affect the
// connection.
func (c *UConn) HandshakeContext(ctx context.Context) error {
// Delegate to unexported method for named return
// without confusing documented signature.
return c.handshakeContext(ctx)
}
// handshakeContext implements HandshakeContext. It returns immediately when the
// handshake has already completed, otherwise it builds the uTLS handshake state
// (client side) and runs c.handshakeFn under the handshake mutex and the input
// lock. The named result ret lets the deferred interrupter cleanup replace the
// returned error with the context error.
func (c *UConn) handshakeContext(ctx context.Context) (ret error) {
// Fast sync/atomic-based exit if there is no handshake in flight and the
// last one succeeded without an error. Avoids the expensive context setup
// and mutex for most Read and Write calls.
if c.isHandshakeComplete.Load() {
return nil
}
handshakeCtx, cancel := context.WithCancel(ctx)
// Note: defer this before starting the "interrupter" goroutine
// so that we can tell the difference between the input being canceled and
// this cancellation. In the former case, we need to close the connection.
defer cancel()
// Start the "interrupter" goroutine, if this context might be canceled.
// (The background context cannot).
//
// The interrupter goroutine waits for the input context to be done and
// closes the connection if this happens before the function returns.
if c.quic != nil {
c.quic.cancelc = handshakeCtx.Done()
c.quic.cancel = cancel
} else if ctx.Done() != nil {
done := make(chan struct{})
// Buffered so the goroutine can always deliver its result and exit.
interruptRes := make(chan error, 1)
defer func() {
close(done)
if ctxErr := <-interruptRes; ctxErr != nil {
// Return context error to user.
ret = ctxErr
}
}()
go func() {
select {
case <-handshakeCtx.Done():
// Close the connection, discarding the error
_ = c.conn.Close()
interruptRes <- handshakeCtx.Err()
case <-done:
interruptRes <- nil
}
}()
}
c.handshakeMutex.Lock()
defer c.handshakeMutex.Unlock()
// A previously recorded handshake error is returned again without retrying.
if err := c.handshakeErr; err != nil {
return err
}
// Re-check under the mutex: the handshake may have completed while waiting for the lock.
if c.isHandshakeComplete.Load() {
return nil
}
c.in.Lock()
defer c.in.Unlock()
// [uTLS section begins]
if c.isClient {
// A build failure is returned directly and is not stored in c.handshakeErr.
err := c.BuildHandshakeState()
if err != nil {
return err
}
}
// [uTLS section ends]
c.handshakeErr = c.handshakeFn(handshakeCtx)
if c.handshakeErr == nil {
c.handshakes++
} else {
// If an error occurred during the handshake try to flush the
// alert that might be left in the buffer.
c.flush()
}
if c.handshakeErr == nil && !c.isHandshakeComplete.Load() {
c.handshakeErr = errors.New("tls: internal error: handshake should have had a result")
}
if c.handshakeErr != nil && c.isHandshakeComplete.Load() {
panic("tls: internal error: handshake returned an error but is marked successful")
}
if c.quic != nil {
if c.handshakeErr == nil {
c.quicHandshakeComplete()
// Provide the 1-RTT read secret now that the handshake is complete.
// The QUIC layer MUST NOT decrypt 1-RTT packets prior to completing
// the handshake (RFC 9001, Section 5.7).
c.quicSetReadSecret(QUICEncryptionLevelApplication, c.cipherSuite, c.in.trafficSecret)
} else {
var a alert
c.out.Lock()
if !errors.As(c.out.err, &a) {
a = alertInternalError
}
c.out.Unlock()
// Return an error which wraps both the handshake error and
// any alert error we may have sent, or alertInternalError
// if we didn't send an alert.
// Truncate the text of the alert to 0 characters.
c.handshakeErr = fmt.Errorf("%w%.0w", c.handshakeErr, AlertError(a))
}
close(c.quic.blockedc)
close(c.quic.signalc)
}
return c.handshakeErr
}
// Copy-pasted from tls.Conn in its entirety. But c.Handshake() is now utls' one, not tls.
// Write writes data to the connection.
// It runs the handshake first if needed and returns the number of bytes of b
// that were written.
func (c *UConn) Write(b []byte) (int, error) {
// interlock with Close below
for {
x := c.activeCall.Load()
// The low bit set means the connection is closed.
if x&1 != 0 {
return 0, net.ErrClosed
}
// Register this call by adding 2; it is undone when Write returns.
if c.activeCall.CompareAndSwap(x, x+2) {
defer c.activeCall.Add(-2)
break
}
}
if err := c.Handshake(); err != nil {
return 0, err
}
c.out.Lock()
defer c.out.Unlock()
if err := c.out.err; err != nil {
return 0, err
}
if !c.isHandshakeComplete.Load() {
return 0, alertInternalError
}
if c.closeNotifySent {
return 0, errShutdown
}
// SSL 3.0 and TLS 1.0 are susceptible to a chosen-plaintext
// attack when using block mode ciphers due to predictable IVs.
// This can be prevented by splitting each Application Data
// record into two records, effectively randomizing the IV.
//
// https://www.openssl.org/~bodo/tls-cbc.txt
// https://bugzilla.mozilla.org/show_bug.cgi?id=665814
// https://www.imperialviolet.org/2012/01/15/beastfollowup.html
// m counts the single byte already sent in the split-off first record.
var m int
if len(b) > 1 && c.vers <= VersionTLS10 {
if _, ok := c.out.cipher.(cipher.BlockMode); ok {
n, err := c.writeRecordLocked(recordTypeApplicationData, b[:1])
if err != nil {
return n, c.out.setErrorLocked(err)
}
m, b = 1, b[1:]
}
}
n, err := c.writeRecordLocked(recordTypeApplicationData, b)
return n + m, c.out.setErrorLocked(err)
}
// ApplyConfig lets every extension in uconn.Extensions, in order, write its
// settings to the connection. It stops at and returns the first error.
func (uconn *UConn) ApplyConfig() error {
	for i := range uconn.Extensions {
		if err := uconn.Extensions[i].writeToUConn(uconn); err != nil {
			return err
		}
	}
	return nil
}
// extensionsList returns the 16-bit extension type of every entry in
// uconn.Extensions, in order. Each type is obtained by serializing the
// extension and decoding the first two bytes of the result.
//
// An extension that produces no bytes (or fewer than two) contributes the
// value 0, because the decoded variable keeps its zero value.
func (uconn *UConn) extensionsList() []uint16 {
	outerExts := make([]uint16, 0, len(uconn.Extensions))
	for _, ext := range uconn.Extensions {
		// Size the scratch buffer from the extension itself rather than using a
		// fixed 2000 bytes: an extension longer than a fixed buffer may refuse to
		// serialize into it, which would silently be recorded as type 0.
		size := ext.Len()
		if size < 2 {
			// Always leave room for the two type bytes read below.
			size = 2
		}
		buffer := cryptobyte.String(make([]byte, size))

		// Best effort: this function has no error return, and a failed or empty
		// read simply leaves the buffer zeroed.
		_, _ = ext.Read(buffer)

		var extension uint16
		buffer.ReadUint16(&extension)
		outerExts = append(outerExts, extension)
	}
	return outerExts
}
// MarshalClientHello marshals the ClientHello into
// uconn.HandshakeState.Hello.Raw.
//
// Without an EncryptedClientHelloConfigList in the config it is equivalent to
// MarshalClientHelloNoECH. With one, it builds the inner ClientHello with the
// standard marshaller, points the SNI extension at the ECH config's public
// name, and replaces the extension implementing EncryptedClientHelloExtension
// with the real ECH payload. The outer hello is marshaled twice: first with a
// zero-filled placeholder of the final size (the bytes authenticated by the
// seal operation), then with the encrypted inner hello in place.
//
// An error is returned, and uconn.Extensions is left untouched, when ECH is
// configured but the extension list has no SNI extension or no ECH extension.
//
// NOTE(review): after a successful call the ECH entry in uconn.Extensions is a
// *GenericExtension; if that type does not implement
// EncryptedClientHelloExtension, a second call will report the ECH extension
// as missing. Confirm whether re-marshaling with ECH must be supported.
func (uconn *UConn) MarshalClientHello() error {
	if len(uconn.config.EncryptedClientHelloConfigList) == 0 {
		return uconn.MarshalClientHelloNoECH()
	}

	inner, _, ech, err := uconn.makeClientHello()
	if err != nil {
		return err
	}
	if ech == nil {
		// Defensive: everything below dereferences the ECH context.
		return errors.New("tls: internal error: missing ECH context while attempting ECH")
	}

	// Locate both extensions up front. slices.IndexFunc returns -1 when nothing
	// matches, which must never be used as an index.
	sniExtIdx := slices.IndexFunc(uconn.Extensions, func(ext TLSExtension) bool {
		_, ok := ext.(*SNIExtension)
		return ok
	})
	if sniExtIdx < 0 {
		return errors.New("tls: SNI extension missing while attempting ECH")
	}
	echExtIdx := slices.IndexFunc(uconn.Extensions, func(ext TLSExtension) bool {
		_, ok := ext.(EncryptedClientHelloExtension)
		return ok
	})
	if echExtIdx < 0 {
		return errors.New("tls: ECH extension missing while attempting ECH")
	}

	// Align the inner hello with the values already chosen for the outer one.
	inner.keyShares = KeyShares(uconn.HandshakeState.Hello.KeyShares).ToPrivate()
	inner.supportedSignatureAlgorithms = uconn.HandshakeState.Hello.SupportedSignatureAlgorithms
	inner.sessionId = uconn.HandshakeState.Hello.SessionId
	inner.supportedCurves = uconn.HandshakeState.Hello.SupportedCurves
	inner.supportedVersions = []uint16{VersionTLS13}

	ech.innerHello = inner
	encapKey := ech.encapsulatedKey

	encodedInner, err := encodeInnerClientHelloReorderOuterExts(inner, int(ech.config.MaxNameLength), uconn.extensionsList())
	if err != nil {
		return err
	}

	// NOTE: the tag lengths for all of the supported AEADs are the same (16
	// bytes), so we have hardcoded it here. If we add support for another AEAD
	// with a different tag length, we will need to change this.
	encryptedLen := len(encodedInner) + 16 // AEAD tag length
	outerECHExt, err := generateOuterECHExt(ech.config.ConfigID, ech.kdfID, ech.aeadID, encapKey, make([]byte, encryptedLen))
	if err != nil {
		return err
	}

	uconn.Extensions[sniExtIdx] = &SNIExtension{
		ServerName: string(ech.config.PublicName),
	}
	uconn.Extensions[echExtIdx] = &GenericExtension{
		Id:   extensionEncryptedClientHello,
		Data: outerECHExt,
	}

	// The outer hello keeps its existing Random; it is not regenerated here.

	// First pass: outer hello with the zero-filled placeholder payload.
	if err := uconn.MarshalClientHelloNoECH(); err != nil {
		return err
	}

	serializedOuter := uconn.HandshakeState.Hello.Raw[4:] // strip the four byte prefix
	encryptedInner, err := ech.hpkeContext.Seal(serializedOuter, encodedInner)
	if err != nil {
		return err
	}
	outerECHExt, err = generateOuterECHExt(ech.config.ConfigID, ech.kdfID, ech.aeadID, encapKey, encryptedInner)
	if err != nil {
		return err
	}
	uconn.Extensions[echExtIdx] = &GenericExtension{
		Id:   extensionEncryptedClientHello,
		Data: outerECHExt,
	}

	// Second pass: outer hello carrying the encrypted inner hello.
	if err := uconn.MarshalClientHelloNoECH(); err != nil {
		return err
	}

	uconn.echCtx = ech
	return nil
}
// MarshalClientHelloNoECH serializes the ClientHello, as if no ECH extension
// were present, and stores the result in uconn.HandshakeState.Hello.Raw.
//
// At most one UtlsPaddingExtension is allowed; its length is settled after all
// other extensions have been measured. An error is returned when an extension
// fails to serialize or when the produced length differs from the computed one.
func (uconn *UConn) MarshalClientHelloNoECH() error {
	hello := uconn.HandshakeState.Hello

	// Everything up to (not including) the extensions block.
	headerLength := 2 + 32 + 1 + len(hello.SessionId) +
		2 + len(hello.CipherSuites)*2 +
		1 + len(hello.CompressionMethods)

	var (
		extensionsLen int
		paddingExt    *UtlsPaddingExtension // reference to padding extension, if present
	)
	for _, ext := range uconn.Extensions {
		pe, isPadding := ext.(*UtlsPaddingExtension)
		if !isPadding {
			// Ordinary extension: its length counts right away.
			extensionsLen += ext.Len()
			continue
		}
		// Padding is sized later, once everything else is known.
		if paddingExt != nil {
			return errors.New("multiple padding extensions")
		}
		paddingExt = pe
	}

	if paddingExt != nil {
		// determine padding extension presence and length
		paddingExt.Update(headerLength + 4 + extensionsLen + 2)
		extensionsLen += paddingExt.Len()
	}

	hasExtensions := len(uconn.Extensions) > 0
	helloLen := headerLength
	if hasExtensions {
		helloLen += 2 + extensionsLen // 2 bytes for extensions' length
	}

	var helloBuffer bytes.Buffer
	w := bufio.NewWriterSize(&helloBuffer, helloLen+4) // 1 byte for tls record type, 3 for length

	// A buffered writer spares us an error check after every Write(): after the
	// first failure further writes are no-ops, and the error surfaces from the
	// single Flush() at the end.
	binary.Write(w, binary.BigEndian, typeClientHello)
	uint24Len := []byte{byte(helloLen >> 16), byte(helloLen >> 8), byte(helloLen)} // poor man's uint24
	binary.Write(w, binary.BigEndian, uint24Len)
	binary.Write(w, binary.BigEndian, hello.Vers)
	binary.Write(w, binary.BigEndian, hello.Random)

	binary.Write(w, binary.BigEndian, uint8(len(hello.SessionId)))
	binary.Write(w, binary.BigEndian, hello.SessionId)

	binary.Write(w, binary.BigEndian, uint16(len(hello.CipherSuites)<<1))
	for i := range hello.CipherSuites {
		binary.Write(w, binary.BigEndian, hello.CipherSuites[i])
	}

	binary.Write(w, binary.BigEndian, uint8(len(hello.CompressionMethods)))
	binary.Write(w, binary.BigEndian, hello.CompressionMethods)

	if hasExtensions {
		binary.Write(w, binary.BigEndian, uint16(extensionsLen))
		for _, ext := range uconn.Extensions {
			if _, err := w.ReadFrom(ext); err != nil {
				return err
			}
		}
	}

	if err := w.Flush(); err != nil {
		return err
	}

	if got, want := helloBuffer.Len(), 4+helloLen; got != want {
		return errors.New("utls: unexpected ClientHello length. Expected: " + strconv.Itoa(want) +
			". Got: " + strconv.Itoa(got))
	}

	hello.Raw = helloBuffer.Bytes()
	return nil
}
// GetOutKeystream seals length zero bytes with the outgoing AEAD cipher, using
// the current outgoing sequence number as nonce, and returns the result.
// It fails when the outgoing cipher is not a cipher.AEAD.
func (uconn *UConn) GetOutKeystream(length int) ([]byte, error) {
	zeros := make([]byte, length)

	outCipher, isAEAD := uconn.out.cipher.(cipher.AEAD)
	if !isAEAD {
		return nil, errors.New("could not convert OutCipher to cipher.AEAD")
	}

	// AEAD.Seal() does not mutate internal state, other ciphers might
	keystream := outCipher.Seal(nil, uconn.out.seq[:], zeros, nil)
	return keystream, nil
}
// SetTLSVers sets min and max TLS version in all appropriate places.
// Function will use first non-zero version parsed in following order:
//  1. Provided minTLSVers, maxTLSVers
//  2. specExtensions may have SupportedVersionsExtension
//  3. [default] min = TLS 1.0, max = TLS 1.2
//
// On success it updates the ClientHello's SupportedVersions as well as the
// config's MinVersion and MaxVersion.
//
// Error is only returned if things are in clearly undesirable state
// to help user fix them: an unusable or duplicated SupportedVersions
// extension, a version outside TLS 1.0 - TLS 1.3, or a min version greater
// than the max version.
func (uconn *UConn) SetTLSVers(minTLSVers, maxTLSVers uint16, specExtensions []TLSExtension) error {
	if minTLSVers == 0 && maxTLSVers == 0 {
		// versionBounds returns the lowest and highest non-GREASE version in
		// versions, or (0, 0) when there is none.
		versionBounds := func(versions []uint16) (minVers, maxVers uint16) {
			for _, vers := range versions {
				if isGREASEUint16(vers) {
					continue
				}
				if maxVers < vers || maxVers == 0 {
					maxVers = vers
				}
				if minVers > vers || minVers == 0 {
					minVers = vers
				}
			}
			return minVers, maxVers
		}

		// if version is not set explicitly in the ClientHelloSpec, check the SupportedVersions extension
		supportedVersionsExtensionsPresent := 0
		for _, e := range specExtensions {
			ext, ok := e.(*SupportedVersionsExtension)
			if !ok {
				continue
			}
			supportedVersionsExtensionsPresent++
			minTLSVers, maxTLSVers = versionBounds(ext.Versions)
			if minTLSVers == 0 && maxTLSVers == 0 {
				return fmt.Errorf("SupportedVersions extension has invalid Versions field")
			}
		}

		switch supportedVersionsExtensionsPresent {
		case 0:
			// if mandatory for TLS 1.3 extension is not present, just default to 1.2
			minTLSVers = VersionTLS10
			maxTLSVers = VersionTLS12
		case 1:
			// Exactly one extension: the bounds found above are used.
		default:
			return fmt.Errorf("uconn.Extensions contains %v separate SupportedVersions extensions",
				supportedVersionsExtensionsPresent)
		}
	}

	if minTLSVers < VersionTLS10 || minTLSVers > VersionTLS13 {
		return fmt.Errorf("uTLS does not support 0x%X as min version", minTLSVers)
	}

	if maxTLSVers < VersionTLS10 || maxTLSVers > VersionTLS13 {
		return fmt.Errorf("uTLS does not support 0x%X as max version", maxTLSVers)
	}

	// An inverted range cannot describe any set of versions; reject it instead
	// of deriving a version list from it.
	if minTLSVers > maxTLSVers {
		return fmt.Errorf("uTLS does not support min version 0x%X greater than max version 0x%X", minTLSVers, maxTLSVers)
	}

	uconn.HandshakeState.Hello.SupportedVersions = makeSupportedVersions(minTLSVers, maxTLSVers)
	uconn.config.MinVersion = minTLSVers
	uconn.config.MaxVersion = maxTLSVers

	return nil
}
// SetUnderlyingConn replaces the net.Conn held by the embedded Conn with c.
func (uconn *UConn) SetUnderlyingConn(c net.Conn) {
uconn.Conn.conn = c
}
// GetUnderlyingConn returns the net.Conn held by the embedded Conn.
func (uconn *UConn) GetUnderlyingConn() net.Conn {
return uconn.Conn.conn
}
// MakeConnWithCompleteHandshake allows to forge both server and client side TLS connections.
// Major Hack Alert.
//
// It derives the record-layer keys from the given master secret and randoms,
// installs them on a fresh Conn for the requested side, and marks the handshake
// as complete. It returns nil when cipherSuiteByID does not know cipherSuite.
func MakeConnWithCompleteHandshake(tcpConn net.Conn, version uint16, cipherSuite uint16, masterSecret []byte, clientRandom []byte, serverRandom []byte, isClient bool) *Conn {
	tlsConn := &Conn{conn: tcpConn, config: &Config{}, isClient: isClient}

	cs := cipherSuiteByID(cipherSuite)
	if cs == nil {
		// TODO: Support TLS 1.3 Cipher Suites
		return nil
	}

	// This is mostly borrowed from establishKeys()
	clientMAC, serverMAC, clientKey, serverKey, clientIV, serverIV :=
		keysFromMasterSecret(version, cs, masterSecret, clientRandom, serverRandom,
			cs.macLen, cs.keyLen, cs.ivLen)

	var (
		clientCipher, serverCipher interface{}
		clientHash, serverHash     hash.Hash
	)
	if cs.cipher == nil {
		// AEAD suite: no separate MAC, so both hashes stay nil.
		clientCipher = cs.aead(clientKey, clientIV)
		serverCipher = cs.aead(serverKey, serverIV)
	} else {
		clientCipher = cs.cipher(clientKey, clientIV, true /* for reading */)
		clientHash = cs.mac(clientMAC)
		serverCipher = cs.cipher(serverKey, serverIV, false /* not for reading */)
		serverHash = cs.mac(serverMAC)
	}

	// Each side reads with the peer's keys and writes with its own.
	if isClient {
		tlsConn.in.prepareCipherSpec(version, serverCipher, serverHash)
		tlsConn.out.prepareCipherSpec(version, clientCipher, clientHash)
	} else {
		tlsConn.in.prepareCipherSpec(version, clientCipher, clientHash)
		tlsConn.out.prepareCipherSpec(version, serverCipher, serverHash)
	}

	// skip the handshake states
	tlsConn.isHandshakeComplete.Store(true)
	tlsConn.cipherSuite = cipherSuite
	tlsConn.haveVers = true
	tlsConn.vers = version

	// Update to the new cipher specs
	// and consume the finished messages
	tlsConn.in.changeCipherSpec()
	tlsConn.out.changeCipherSpec()

	tlsConn.in.incSeq()
	tlsConn.out.incSeq()

	return tlsConn
}
// makeSupportedVersions returns every version from maxVers down to minVers,
// inclusive, in descending order.
//
// It returns nil when maxVers is lower than minVers. The length is computed in
// int so that the uint16 subtraction cannot wrap around and produce a huge,
// meaningless list.
func makeSupportedVersions(minVers, maxVers uint16) []uint16 {
	if maxVers < minVers {
		return nil
	}
	a := make([]uint16, int(maxVers)-int(minVers)+1)
	for i := range a {
		a[i] = maxVers - uint16(i)
	}
	return a
}
// utlsHandshakeMessageType extends (*Conn).readHandshake() with the handshake
// message types added by uTLS. It returns an empty message value for msgType;
// for any other type it sends an unexpected_message alert and returns the
// resulting error.
func (c *Conn) utlsHandshakeMessageType(msgType byte) (handshakeMessage, error) {
	switch msgType {
	case utlsTypeCompressedCertificate:
		return new(utlsCompressedCertificateMsg), nil
	case utlsTypeEncryptedExtensions:
		// The concrete message type depends on which side is reading.
		if !c.isClient {
			return new(utlsClientEncryptedExtensionsMsg), nil
		}
		return new(encryptedExtensionsMsg), nil
	}
	return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
}
// Extending (*Conn).connectionStateLocked()
// utlsConnectionStateLocked copies the uTLS-specific fields (peer application
// settings and ECH retry configs) from the connection into state.
func (c *Conn) utlsConnectionStateLocked(state *ConnectionState) {
state.PeerApplicationSettings = c.utls.peerApplicationSettings
state.ECHRetryConfigs = c.utls.echRetryConfigs
}
// utlsConnExtraFields groups the per-connection state that uTLS adds to Conn
// (accessed as c.utls).
type utlsConnExtraFields struct {
// Application Settings (ALPS)
hasApplicationSettings bool
peerApplicationSettings []byte
localApplicationSettings []byte
// Encrypted Client Hello (ECH)
echRetryConfigs []ECHConfig
// sessionController is set by UClient to the UConn's session controller.
sessionController *sessionController
}
// Read reads data from the connection.
//
// As Read calls [Conn.Handshake], in order to prevent indefinite blocking a deadline
// must be set for both Read and [Conn.Write] before Read is called when the handshake
// has not yet completed. See [Conn.SetDeadline], [Conn.SetReadDeadline], and
// [Conn.SetWriteDeadline].
func (c *UConn) Read(b []byte) (int, error) {
if err := c.Handshake(); err != nil {
return 0, err
}
if len(b) == 0 {
// Put this after Handshake, in case people were calling
// Read(nil) for the side effect of the Handshake.
return 0, nil
}
c.in.Lock()
defer c.in.Unlock()
// Keep reading records until application data is buffered, handling any
// post-handshake messages that arrive in between.
for c.input.Len() == 0 {
if err := c.readRecord(); err != nil {
return 0, err
}
for c.hand.Len() > 0 {
if err := c.handlePostHandshakeMessage(); err != nil {
return 0, err
}
}
}
n, _ := c.input.Read(b)
// If a close-notify alert is waiting, read it so that we can return (n,
// EOF) instead of (n, nil), to signal to the HTTP response reading
// goroutine that the connection is now closed. This eliminates a race
// where the HTTP response reading goroutine would otherwise not observe
// the EOF until its next read, by which time a client goroutine might
// have already tried to reuse the HTTP connection for a new request.
// See https://golang.org/cl/76400046 and https://golang.org/issue/3514
if n != 0 && c.input.Len() == 0 && c.rawInput.Len() > 0 &&
recordType(c.rawInput.Bytes()[0]) == recordTypeAlert {
if err := c.readRecord(); err != nil {
return n, err // will be io.EOF on closeNotify
}
}
return n, nil
}
// handleRenegotiation processes a HelloRequest handshake message.
// Depending on the side and on config.Renegotiation it either answers with a
// no_renegotiation alert or rebuilds the uTLS handshake state and runs a new
// client handshake.
func (c *UConn) handleRenegotiation() error {
if c.vers == VersionTLS13 {
return errors.New("tls: internal error: unexpected renegotiation")
}
msg, err := c.readHandshake(nil)
if err != nil {
return err
}
helloReq, ok := msg.(*helloRequestMsg)
if !ok {
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(helloReq, msg)
}
// Only the client side may renegotiate.
if !c.isClient {
return c.sendAlert(alertNoRenegotiation)
}
switch c.config.Renegotiation {
case RenegotiateNever:
return c.sendAlert(alertNoRenegotiation)
case RenegotiateOnceAsClient:
// More than one completed handshake means the single allowed renegotiation was used.
if c.handshakes > 1 {
return c.sendAlert(alertNoRenegotiation)
}
case RenegotiateFreelyAsClient:
// Ok.
default:
c.sendAlert(alertInternalError)
return errors.New("tls: unknown Renegotiation value")
}
c.handshakeMutex.Lock()
defer c.handshakeMutex.Unlock()
c.isHandshakeComplete.Store(false)
// [uTLS section begins]
if err = c.BuildHandshakeState(); err != nil {
return err
}
// [uTLS section ends]
if c.handshakeErr = c.clientHandshake(context.Background()); c.handshakeErr == nil {
c.handshakes++
}
return c.handshakeErr
}
// handlePostHandshakeMessage handles a handshake message that arrives once the
// handshake is complete. Up to TLS 1.2 such a message starts a renegotiation;
// in TLS 1.3 only new session tickets and key updates are accepted, and any
// other message is answered with an unexpected_message alert.
func (c *UConn) handlePostHandshakeMessage() error {
	if c.vers != VersionTLS13 {
		return c.handleRenegotiation()
	}

	msg, err := c.readHandshake(nil)
	if err != nil {
		return err
	}

	// Bound the number of records that do not advance the connection.
	c.retryCount++
	if c.retryCount > maxUselessRecords {
		c.sendAlert(alertUnexpectedMessage)
		return c.in.setErrorLocked(errors.New("tls: too many non-advancing records"))
	}

	switch typed := msg.(type) {
	case *newSessionTicketMsgTLS13:
		return c.handleNewSessionTicket(typed)
	case *keyUpdateMsg:
		return c.handleKeyUpdate(typed)
	default:
		// The QUIC layer is supposed to treat an unexpected post-handshake CertificateRequest
		// as a QUIC-level PROTOCOL_VIOLATION error (RFC 9001, Section 4.4). Returning an
		// unexpected_message alert here doesn't provide it with enough information to distinguish
		// this condition from other unexpected messages. This is probably fine.
		c.sendAlert(alertUnexpectedMessage)
		return fmt.Errorf("tls: received unexpected handshake message of type %T", msg)
	}
}