mirror of
https://github.com/refraction-networking/utls.git
synced 2025-04-03 03:57:36 +03:00
377 lines
11 KiB
Go
377 lines
11 KiB
Go
// Copyright 2017 Google Inc. All rights reserved.
|
|
// Use of this source code is governed by a BSD-style
|
|
// license that can be found in the LICENSE file.
|
|
|
|
package tls
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"encoding/binary"
|
|
"errors"
|
|
"io"
|
|
"net"
|
|
"strconv"
|
|
"sync"
|
|
)
|
|
|
|
type UConn struct {
|
|
*Conn
|
|
|
|
Extensions []TLSExtension
|
|
clientHelloID ClientHelloID
|
|
|
|
HandshakeState ClientHandshakeState
|
|
|
|
HandshakeStateBuilt bool
|
|
}
|
|
|
|
// UClient returns a new uTLS client, with behavior depending on clientHelloID.
|
|
// Config CAN be nil, but make sure to eventually specify ServerName.
|
|
func UClient(conn net.Conn, config *Config, clientHelloID ClientHelloID) *UConn {
|
|
if config == nil {
|
|
config = &Config{}
|
|
}
|
|
tlsConn := Conn{conn: conn, config: config, isClient: true}
|
|
handshakeState := ClientHandshakeState{C: &tlsConn, Hello: &ClientHelloMsg{}}
|
|
uconn := UConn{Conn: &tlsConn, clientHelloID: clientHelloID, HandshakeState: handshakeState}
|
|
return &uconn
|
|
}
|
|
|
|
// BuildHandshakeState() overwrites most fields, therefore, it is advised to manually call this function,
|
|
// if you need to inspect/change contents after parroting/making default Golang ClientHello.
|
|
// Otherwise, there is no need to call this function explicitly.
|
|
func (uconn *UConn) BuildHandshakeState() error {
|
|
if uconn.clientHelloID == HelloGolang {
|
|
// use default Golang ClientHello.
|
|
hello, err := makeClientHello(uconn.config)
|
|
if uconn.HandshakeState.Session != nil {
|
|
// session is lost at makeClientHello(), let's reapply
|
|
uconn.SetSessionState(uconn.HandshakeState.Session)
|
|
}
|
|
if err != nil {
|
|
return err
|
|
}
|
|
uconn.HandshakeState.Hello = hello.getPublicPtr()
|
|
} else {
|
|
err := uconn.generateClientHelloConfig(uconn.clientHelloID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
err = uconn.applyConfig()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
err = uconn.marshalClientHello()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
uconn.HandshakeStateBuilt = true
|
|
return nil
|
|
}
|
|
|
|
// If you want you session tickets to be reused - use same cache on following connections
|
|
func (uconn *UConn) SetSessionState(session *ClientSessionState) {
|
|
uconn.HandshakeState.Session = session
|
|
if session != nil {
|
|
uconn.HandshakeState.Hello.SessionTicket = session.sessionTicket
|
|
}
|
|
uconn.HandshakeState.Hello.TicketSupported = true
|
|
for _, ext := range uconn.Extensions {
|
|
st, ok := ext.(*SessionTicketExtension)
|
|
if ok {
|
|
st.Session = session
|
|
}
|
|
}
|
|
}
|
|
|
|
// If you want you session tickets to be reused - use same cache on following connections
|
|
func (uconn *UConn) SetSessionCache(cache ClientSessionCache) {
|
|
uconn.config.ClientSessionCache = cache
|
|
uconn.HandshakeState.Hello.TicketSupported = true
|
|
}
|
|
|
|
// r has to be 32 bytes long
|
|
func (uconn *UConn) SetClientRandom(r []byte) error {
|
|
if len(r) != 32 {
|
|
return errors.New("Incorrect client random length! Expected: 32, got: " + strconv.Itoa(len(r)))
|
|
} else {
|
|
uconn.HandshakeState.Hello.Random = make([]byte, 32)
|
|
copy(uconn.HandshakeState.Hello.Random, r)
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func (uconn *UConn) SetSNI(sni string) {
|
|
hname := hostnameInSNI(sni)
|
|
uconn.config.ServerName = hname
|
|
for _, ext := range uconn.Extensions {
|
|
sniExt, ok := ext.(*SNIExtension)
|
|
if ok {
|
|
sniExt.ServerName = hname
|
|
}
|
|
}
|
|
}
|
|
|
|
// Handshake runs the client handshake using given clientHandshakeState
|
|
// Requires hs.hello, and, optionally, hs.session to be set.
|
|
func (c *UConn) Handshake() error {
|
|
// This code was copied almost as is from tls/conn.go
|
|
// c.handshakeErr and c.handshakeComplete are protected by
|
|
// c.handshakeMutex. In order to perform a handshake, we need to lock
|
|
// c.in also and c.handshakeMutex must be locked after c.in.
|
|
//
|
|
// However, if a Read() operation is hanging then it'll be holding the
|
|
// lock on c.in and so taking it here would cause all operations that
|
|
// need to check whether a handshake is pending (such as Write) to
|
|
// block.
|
|
//
|
|
// Thus we first take c.handshakeMutex to check whether a handshake is
|
|
// needed.
|
|
//
|
|
// If so then, previously, this code would unlock handshakeMutex and
|
|
// then lock c.in and handshakeMutex in the correct order to run the
|
|
// handshake. The problem was that it was possible for a Read to
|
|
// complete the handshake once handshakeMutex was unlocked and then
|
|
// keep c.in while waiting for network data. Thus a concurrent
|
|
// operation could be blocked on c.in.
|
|
//
|
|
// Thus handshakeCond is used to signal that a goroutine is committed
|
|
// to running the handshake and other goroutines can wait on it if they
|
|
// need. handshakeCond is protected by handshakeMutex.
|
|
c.handshakeMutex.Lock()
|
|
defer c.handshakeMutex.Unlock()
|
|
|
|
for {
|
|
if err := c.handshakeErr; err != nil {
|
|
return err
|
|
}
|
|
if c.handshakeComplete {
|
|
return nil
|
|
}
|
|
if c.handshakeCond == nil {
|
|
break
|
|
}
|
|
|
|
c.handshakeCond.Wait()
|
|
}
|
|
|
|
// Set handshakeCond to indicate that this goroutine is committing to
|
|
// running the handshake.
|
|
c.handshakeCond = sync.NewCond(&c.handshakeMutex)
|
|
c.handshakeMutex.Unlock()
|
|
|
|
c.in.Lock()
|
|
defer c.in.Unlock()
|
|
|
|
c.handshakeMutex.Lock()
|
|
|
|
// The handshake cannot have completed when handshakeMutex was unlocked
|
|
// because this goroutine set handshakeCond.
|
|
if c.handshakeErr != nil || c.handshakeComplete {
|
|
panic("handshake should not have been able to complete after handshakeCond was set")
|
|
}
|
|
|
|
if !c.isClient {
|
|
panic("Servers should not call ClientHandshakeWithState()")
|
|
}
|
|
|
|
if !c.HandshakeStateBuilt {
|
|
err := c.BuildHandshakeState()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
privateState := c.HandshakeState.getPrivatePtr()
|
|
c.handshakeErr = c.clientHandshakeWithState(privateState)
|
|
c.HandshakeState = *privateState.getPublicPtr()
|
|
|
|
if c.handshakeErr == nil {
|
|
c.handshakes++
|
|
} else {
|
|
// If an error occurred during the hadshake try to flush the
|
|
// alert that might be left in the buffer.
|
|
c.flush()
|
|
}
|
|
|
|
if c.handshakeErr == nil && !c.handshakeComplete {
|
|
panic("handshake should have had a result.")
|
|
}
|
|
|
|
// Wake any other goroutines that are waiting for this handshake to complete.
|
|
c.handshakeCond.Broadcast()
|
|
c.handshakeCond = nil
|
|
|
|
return c.handshakeErr
|
|
}
|
|
|
|
// c.out.Mutex <= L; c.handshakeMutex <= L.
|
|
func (c *UConn) clientHandshakeWithState(hs *clientHandshakeState) error {
|
|
// This code was copied almost as is from tls/handshake_client.go
|
|
if c.config == nil {
|
|
c.config = &Config{}
|
|
}
|
|
|
|
// This may be a renegotiation handshake, in which case some fields
|
|
// need to be reset.
|
|
c.didResume = false
|
|
|
|
if len(c.config.ServerName) == 0 && !c.config.InsecureSkipVerify {
|
|
return errors.New("tls: either ServerName or InsecureSkipVerify must be specified in the tls.Config")
|
|
}
|
|
|
|
nextProtosLength := 0
|
|
for _, proto := range c.config.NextProtos {
|
|
if l := len(proto); l == 0 || l > 255 {
|
|
return errors.New("tls: invalid NextProtos value")
|
|
} else {
|
|
nextProtosLength += 1 + l
|
|
}
|
|
}
|
|
if nextProtosLength > 0xffff {
|
|
return errors.New("tls: NextProtos values too large")
|
|
}
|
|
|
|
var session *ClientSessionState
|
|
sessionCache := c.config.ClientSessionCache
|
|
cacheKey := clientSessionCacheKey(c.conn.RemoteAddr(), c.config)
|
|
|
|
// If sessionCache is set but session itself isn't - try to retrieve session from cache
|
|
if sessionCache != nil && hs.session != nil {
|
|
hs.hello.ticketSupported = true
|
|
// Session resumption is not allowed if renegotiating because
|
|
// renegotiation is primarily used to allow a client to send a client
|
|
// certificate, which would be skipped if session resumption occurred.
|
|
if c.handshakes == 0 {
|
|
// Try to resume a previously negotiated TLS session, if
|
|
// available.
|
|
candidateSession, ok := sessionCache.Get(cacheKey)
|
|
if ok {
|
|
// Check that the ciphersuite/version used for the
|
|
// previous session are still valid.
|
|
cipherSuiteOk := false
|
|
for _, id := range hs.hello.cipherSuites {
|
|
if id == candidateSession.cipherSuite {
|
|
cipherSuiteOk = true
|
|
break
|
|
}
|
|
}
|
|
|
|
versOk := candidateSession.vers >= c.config.minVersion() &&
|
|
candidateSession.vers <= c.config.maxVersion()
|
|
if versOk && cipherSuiteOk {
|
|
session = candidateSession
|
|
}
|
|
if session != nil {
|
|
hs.hello.sessionTicket = session.sessionTicket
|
|
// A random session ID is used to detect when the
|
|
// server accepted the ticket and is resuming a session
|
|
// (see RFC 5077).
|
|
hs.hello.sessionId = make([]byte, 16)
|
|
if _, err := io.ReadFull(c.config.rand(), hs.hello.sessionId); err != nil {
|
|
return errors.New("tls: short read from Rand: " + err.Error())
|
|
}
|
|
}
|
|
hs.session = session
|
|
}
|
|
}
|
|
}
|
|
|
|
if err := hs.handshake(); err != nil {
|
|
return err
|
|
}
|
|
// If we had a successful handshake and hs.session is different from the one already cached - cache a new one
|
|
if sessionCache != nil && hs.session != nil && hs.session != session {
|
|
sessionCache.Put(cacheKey, hs.session)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (uconn *UConn) applyConfig() error {
|
|
for _, ext := range uconn.Extensions {
|
|
err := ext.writeToUConn(uconn)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (uconn *UConn) marshalClientHello() error {
|
|
hello := uconn.HandshakeState.Hello
|
|
headerLength := 2 + 32 + 1 + len(hello.SessionId) +
|
|
2 + len(hello.CipherSuites)*2 +
|
|
1 + len(hello.CompressionMethods)
|
|
|
|
extensionsLen := 0
|
|
var paddingExt *utlsPaddingExtension
|
|
for _, ext := range uconn.Extensions {
|
|
if pe, ok := ext.(*utlsPaddingExtension); !ok {
|
|
// If not padding - just add length of extension to total length
|
|
extensionsLen += ext.Len()
|
|
} else {
|
|
// If padding - process it later
|
|
if paddingExt == nil {
|
|
paddingExt = pe
|
|
} else {
|
|
return errors.New("Multiple padding extensions!")
|
|
}
|
|
}
|
|
}
|
|
|
|
if paddingExt != nil {
|
|
// determine padding extension presence and length
|
|
paddingExt.Update(headerLength + 4 + extensionsLen + 2)
|
|
extensionsLen += paddingExt.Len()
|
|
}
|
|
|
|
helloLen := headerLength
|
|
if len(uconn.Extensions) > 0 {
|
|
helloLen += 2 + extensionsLen // 2 bytes for extensions' length
|
|
}
|
|
|
|
helloBuffer := bytes.Buffer{}
|
|
bufferedWriter := bufio.NewWriterSize(&helloBuffer, helloLen+4) // 1 byte for tls record type, 3 for length
|
|
// We use buffered Writer to avoid checking write errors after every Write(): whenever first error happens
|
|
// Write() will become noop, and error will be accessible via Flush(), which is called once in the end
|
|
|
|
binary.Write(bufferedWriter, binary.BigEndian, typeClientHello)
|
|
helloLenBytes := []byte{byte(helloLen >> 16), byte(helloLen >> 8), byte(helloLen)} // poor man's uint24
|
|
binary.Write(bufferedWriter, binary.BigEndian, helloLenBytes)
|
|
binary.Write(bufferedWriter, binary.BigEndian, hello.Vers)
|
|
|
|
binary.Write(bufferedWriter, binary.BigEndian, hello.Random)
|
|
|
|
binary.Write(bufferedWriter, binary.BigEndian, uint8(len(hello.SessionId)))
|
|
binary.Write(bufferedWriter, binary.BigEndian, hello.SessionId)
|
|
|
|
binary.Write(bufferedWriter, binary.BigEndian, uint16(len(hello.CipherSuites)<<1))
|
|
for _, suite := range hello.CipherSuites {
|
|
binary.Write(bufferedWriter, binary.BigEndian, suite)
|
|
}
|
|
|
|
binary.Write(bufferedWriter, binary.BigEndian, uint8(len(hello.CompressionMethods)))
|
|
binary.Write(bufferedWriter, binary.BigEndian, hello.CompressionMethods)
|
|
|
|
if len(uconn.Extensions) > 0 {
|
|
binary.Write(bufferedWriter, binary.BigEndian, uint16(extensionsLen))
|
|
for _, ext := range uconn.Extensions {
|
|
bufferedWriter.ReadFrom(ext)
|
|
}
|
|
}
|
|
|
|
if helloBuffer.Len() != 4+helloLen {
|
|
return errors.New("utls: unexpected ClientHello length. Expected: " + strconv.Itoa(4+helloLen) +
|
|
". Got: " + strconv.Itoa(helloBuffer.Len()))
|
|
}
|
|
|
|
err := bufferedWriter.Flush()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
hello.Raw = helloBuffer.Bytes()
|
|
return nil
|
|
}
|