mirror of
https://github.com/refraction-networking/utls.git
synced 2025-04-03 03:57:36 +03:00
469 lines
14 KiB
Go
469 lines
14 KiB
Go
// Copyright 2017 Google Inc. All rights reserved.
|
|
// Use of this source code is governed by a BSD-style
|
|
// license that can be found in the LICENSE file.
|
|
|
|
package tls
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"crypto/cipher"
|
|
"encoding/binary"
|
|
"errors"
|
|
"io"
|
|
"net"
|
|
"strconv"
|
|
"sync/atomic"
|
|
)
|
|
|
|
type UConn struct {
|
|
*Conn
|
|
|
|
Extensions []TLSExtension
|
|
clientHelloID ClientHelloID
|
|
|
|
HandshakeState ClientHandshakeState
|
|
|
|
HandshakeStateBuilt bool
|
|
|
|
// sessionID may or may not depend on ticket; nil => random
|
|
GetSessionID func(ticket []byte) [32]byte
|
|
|
|
greaseSeed [ssl_grease_last_index]uint16
|
|
}
|
|
|
|
// UClient returns a new uTLS client, with behavior depending on clientHelloID.
|
|
// Config CAN be nil, but make sure to eventually specify ServerName.
|
|
func UClient(conn net.Conn, config *Config, clientHelloID ClientHelloID) *UConn {
|
|
if config == nil {
|
|
config = &Config{}
|
|
}
|
|
tlsConn := Conn{conn: conn, config: config, isClient: true}
|
|
handshakeState := ClientHandshakeState{C: &tlsConn, Hello: &ClientHelloMsg{}}
|
|
uconn := UConn{Conn: &tlsConn, clientHelloID: clientHelloID, HandshakeState: handshakeState}
|
|
return &uconn
|
|
}
|
|
|
|
// BuildHandshakeState() overwrites most fields, therefore, it is advised to manually call this function,
|
|
// if you need to inspect/change contents after parroting/making default Golang ClientHello.
|
|
// Otherwise, there is no need to call this function explicitly.
|
|
func (uconn *UConn) BuildHandshakeState() error {
|
|
if uconn.clientHelloID == HelloGolang {
|
|
// use default Golang ClientHello.
|
|
hello, err := makeClientHello(uconn.config)
|
|
if uconn.HandshakeState.Session != nil {
|
|
// session is lost at makeClientHello(), let's reapply
|
|
uconn.SetSessionState(uconn.HandshakeState.Session)
|
|
}
|
|
if err != nil {
|
|
return err
|
|
}
|
|
uconn.HandshakeState.Hello = hello.getPublicPtr()
|
|
} else {
|
|
var err error
|
|
err = uconn.applyPresetByID(uconn.clientHelloID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
err = uconn.ApplyConfig()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
err = uconn.MarshalClientHello()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
uconn.HandshakeStateBuilt = true
|
|
return nil
|
|
}
|
|
|
|
// If you want session tickets to be reused - use same cache on following connections
|
|
func (uconn *UConn) SetSessionState(session *ClientSessionState) error {
|
|
uconn.HandshakeState.Session = session
|
|
var sessionTicket []uint8
|
|
if session != nil {
|
|
sessionTicket = session.sessionTicket
|
|
}
|
|
uconn.HandshakeState.Hello.TicketSupported = true
|
|
uconn.HandshakeState.Hello.SessionTicket = sessionTicket
|
|
for _, ext := range uconn.Extensions {
|
|
st, ok := ext.(*SessionTicketExtension)
|
|
if !ok {
|
|
continue
|
|
}
|
|
st.Session = session
|
|
if session != nil {
|
|
if len(session.SessionTicket()) > 0 {
|
|
if uconn.GetSessionID != nil {
|
|
sid := uconn.GetSessionID(session.SessionTicket())
|
|
uconn.HandshakeState.Hello.SessionId = sid[:]
|
|
return nil
|
|
}
|
|
}
|
|
var sessionID [32]byte
|
|
_, err := io.ReadFull(uconn.config.rand(), uconn.HandshakeState.Hello.SessionId)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
uconn.HandshakeState.Hello.SessionId = sessionID[:]
|
|
}
|
|
return nil
|
|
}
|
|
return errors.New("SessionTicketExtension is not part of current spec and won't be sent")
|
|
}
|
|
|
|
// If you want session tickets to be reused - use same cache on following connections
|
|
func (uconn *UConn) SetSessionCache(cache ClientSessionCache) {
|
|
uconn.config.ClientSessionCache = cache
|
|
uconn.HandshakeState.Hello.TicketSupported = true
|
|
}
|
|
|
|
// r has to be 32 bytes long
|
|
func (uconn *UConn) SetClientRandom(r []byte) error {
|
|
if len(r) != 32 {
|
|
return errors.New("Incorrect client random length! Expected: 32, got: " + strconv.Itoa(len(r)))
|
|
} else {
|
|
uconn.HandshakeState.Hello.Random = make([]byte, 32)
|
|
copy(uconn.HandshakeState.Hello.Random, r)
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func (uconn *UConn) SetSNI(sni string) {
|
|
hname := hostnameInSNI(sni)
|
|
uconn.config.ServerName = hname
|
|
for _, ext := range uconn.Extensions {
|
|
sniExt, ok := ext.(*SNIExtension)
|
|
if ok {
|
|
sniExt.ServerName = hname
|
|
}
|
|
}
|
|
}
|
|
|
|
// Handshake runs the client handshake using given clientHandshakeState
|
|
// Requires hs.hello, and, optionally, hs.session to be set.
|
|
func (c *UConn) Handshake() error {
|
|
// This code was copied almost as is from tls/conn.go
|
|
// c.handshakeErr and c.handshakeComplete are protected by
|
|
// c.handshakeMutex. In order to perform a handshake, we need to lock
|
|
// c.in also and c.handshakeMutex must be locked after c.in.
|
|
//
|
|
// However, if a Read() operation is hanging then it'll be holding the
|
|
// lock on c.in and so taking it here would cause all operations that
|
|
// need to check whether a handshake is pending (such as Write) to
|
|
// block.
|
|
//
|
|
// Thus we first take c.handshakeMutex to check whether a handshake is
|
|
// needed.
|
|
//
|
|
// If so then, previously, this code would unlock handshakeMutex and
|
|
// then lock c.in and handshakeMutex in the correct order to run the
|
|
// handshake. The problem was that it was possible for a Read to
|
|
// complete the handshake once handshakeMutex was unlocked and then
|
|
// keep c.in while waiting for network data. Thus a concurrent
|
|
// operation could be blocked on c.in.
|
|
//
|
|
// Thus handshakeCond is used to signal that a goroutine is committed
|
|
// to running the handshake and other goroutines can wait on it if they
|
|
// need. handshakeCond is protected by handshakeMutex.
|
|
c.handshakeMutex.Lock()
|
|
defer c.handshakeMutex.Unlock()
|
|
|
|
if err := c.handshakeErr; err != nil {
|
|
return err
|
|
}
|
|
if c.handshakeComplete {
|
|
return nil
|
|
}
|
|
|
|
c.in.Lock()
|
|
defer c.in.Unlock()
|
|
|
|
if c.isClient {
|
|
if !c.HandshakeStateBuilt {
|
|
err := c.BuildHandshakeState()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
privateState := c.HandshakeState.getPrivatePtr()
|
|
c.handshakeErr = c.clientHandshakeWithState(privateState)
|
|
c.HandshakeState = *privateState.getPublicPtr()
|
|
} else {
|
|
c.handshakeErr = c.serverHandshake()
|
|
}
|
|
if c.handshakeErr == nil {
|
|
c.handshakes++
|
|
} else {
|
|
// If an error occurred during the hadshake try to flush the
|
|
// alert that might be left in the buffer.
|
|
c.flush()
|
|
}
|
|
|
|
if c.handshakeErr == nil && !c.handshakeComplete {
|
|
panic("handshake should have had a result.")
|
|
}
|
|
|
|
return c.handshakeErr
|
|
}
|
|
|
|
// Copy-pasted from tls.Conn in its entirety. But c.Handshake() is now utls' one, not tls.
|
|
// Write writes data to the connection.
|
|
func (c *UConn) Write(b []byte) (int, error) {
|
|
// interlock with Close below
|
|
for {
|
|
x := atomic.LoadInt32(&c.activeCall)
|
|
if x&1 != 0 {
|
|
return 0, errClosed
|
|
}
|
|
if atomic.CompareAndSwapInt32(&c.activeCall, x, x+2) {
|
|
defer atomic.AddInt32(&c.activeCall, -2)
|
|
break
|
|
}
|
|
}
|
|
|
|
if err := c.Handshake(); err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
c.out.Lock()
|
|
defer c.out.Unlock()
|
|
|
|
if err := c.out.err; err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
if !c.handshakeComplete {
|
|
return 0, alertInternalError
|
|
}
|
|
|
|
if c.closeNotifySent {
|
|
return 0, errShutdown
|
|
}
|
|
|
|
// SSL 3.0 and TLS 1.0 are susceptible to a chosen-plaintext
|
|
// attack when using block mode ciphers due to predictable IVs.
|
|
// This can be prevented by splitting each Application Data
|
|
// record into two records, effectively randomizing the IV.
|
|
//
|
|
// http://www.openssl.org/~bodo/tls-cbc.txt
|
|
// https://bugzilla.mozilla.org/show_bug.cgi?id=665814
|
|
// http://www.imperialviolet.org/2012/01/15/beastfollowup.html
|
|
|
|
var m int
|
|
if len(b) > 1 && c.vers <= VersionTLS10 {
|
|
if _, ok := c.out.cipher.(cipher.BlockMode); ok {
|
|
n, err := c.writeRecordLocked(recordTypeApplicationData, b[:1])
|
|
if err != nil {
|
|
return n, c.out.setErrorLocked(err)
|
|
}
|
|
m, b = 1, b[1:]
|
|
}
|
|
}
|
|
|
|
n, err := c.writeRecordLocked(recordTypeApplicationData, b)
|
|
return n + m, c.out.setErrorLocked(err)
|
|
}
|
|
|
|
// c.out.Mutex <= L; c.handshakeMutex <= L.
|
|
func (c *Conn) clientHandshakeWithState(hs *clientHandshakeState) error {
|
|
if c.config == nil {
|
|
c.config = defaultConfig()
|
|
}
|
|
|
|
// This may be a renegotiation handshake, in which case some fields
|
|
// need to be reset.
|
|
c.didResume = false
|
|
|
|
// UTLS: don't make new ClientHello, use hs.hello
|
|
|
|
// UTLS: preserve the check from makeClientHello [start]
|
|
if len(c.config.ServerName) == 0 && !c.config.InsecureSkipVerify {
|
|
return errors.New("tls: either ServerName or InsecureSkipVerify must be specified in the tls.Config")
|
|
}
|
|
|
|
nextProtosLength := 0
|
|
for _, proto := range c.config.NextProtos {
|
|
if l := len(proto); l == 0 || l > 255 {
|
|
return errors.New("tls: invalid NextProtos value")
|
|
} else {
|
|
nextProtosLength += 1 + l
|
|
}
|
|
}
|
|
|
|
if nextProtosLength > 0xffff {
|
|
return errors.New("tls: NextProtos values too large")
|
|
}
|
|
// UTLS: preserve the check from makeClientHello [end]
|
|
|
|
if c.handshakes > 0 {
|
|
hs.hello.secureRenegotiation = c.clientFinished[:]
|
|
}
|
|
|
|
var session *ClientSessionState
|
|
var cacheKey string
|
|
sessionCache := c.config.ClientSessionCache
|
|
|
|
if c.config.SessionTicketsDisabled {
|
|
sessionCache = nil
|
|
}
|
|
|
|
if sessionCache != nil {
|
|
hs.hello.ticketSupported = true
|
|
}
|
|
|
|
// Session resumption is not allowed if renegotiating because
|
|
// renegotiation is primarily used to allow a client to send a client
|
|
// certificate, which would be skipped if session resumption occurred.
|
|
if sessionCache != nil && c.handshakes == 0 && hs.session == nil { // [UTLS] hs.session == nil
|
|
// Try to resume a previously negotiated TLS session, if
|
|
// available.
|
|
cacheKey = clientSessionCacheKey(c.conn.RemoteAddr(), c.config)
|
|
candidateSession, ok := sessionCache.Get(cacheKey)
|
|
if ok {
|
|
// Check that the ciphersuite/version used for the
|
|
// previous session are still valid.
|
|
cipherSuiteOk := false
|
|
for _, id := range hs.hello.cipherSuites {
|
|
if id == candidateSession.cipherSuite {
|
|
cipherSuiteOk = true
|
|
break
|
|
}
|
|
}
|
|
|
|
versOk := candidateSession.vers >= c.config.minVersion() &&
|
|
candidateSession.vers <= c.config.maxVersion()
|
|
if versOk && cipherSuiteOk {
|
|
session = candidateSession
|
|
}
|
|
}
|
|
}
|
|
|
|
if session != nil {
|
|
if hs.hello.sessionTicket == nil { // [UTLS] check
|
|
hs.hello.sessionTicket = session.sessionTicket
|
|
}
|
|
// A random session ID is used to detect when the
|
|
// server accepted the ticket and is resuming a session
|
|
// (see RFC 5077).
|
|
|
|
if hs.hello.sessionId == nil { // [UTLS] check
|
|
hs.hello.sessionId = make([]byte, 16)
|
|
if _, err := io.ReadFull(c.config.rand(), hs.hello.sessionId); err != nil {
|
|
return errors.New("tls: short read from Rand: " + err.Error())
|
|
}
|
|
}
|
|
}
|
|
|
|
if err := hs.handshake(); err != nil {
|
|
return err
|
|
}
|
|
|
|
// If we had a successful handshake and hs.session is different from
|
|
// the one already cached - cache a new one
|
|
if sessionCache != nil && hs.session != nil && session != hs.session {
|
|
sessionCache.Put(cacheKey, hs.session)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (uconn *UConn) ApplyConfig() error {
|
|
for _, ext := range uconn.Extensions {
|
|
err := ext.writeToUConn(uconn)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (uconn *UConn) MarshalClientHello() error {
|
|
hello := uconn.HandshakeState.Hello
|
|
headerLength := 2 + 32 + 1 + len(hello.SessionId) +
|
|
2 + len(hello.CipherSuites)*2 +
|
|
1 + len(hello.CompressionMethods)
|
|
|
|
extensionsLen := 0
|
|
var paddingExt *UtlsPaddingExtension
|
|
for _, ext := range uconn.Extensions {
|
|
if pe, ok := ext.(*UtlsPaddingExtension); !ok {
|
|
// If not padding - just add length of extension to total length
|
|
extensionsLen += ext.Len()
|
|
} else {
|
|
// If padding - process it later
|
|
if paddingExt == nil {
|
|
paddingExt = pe
|
|
} else {
|
|
return errors.New("Multiple padding extensions!")
|
|
}
|
|
}
|
|
}
|
|
|
|
if paddingExt != nil {
|
|
// determine padding extension presence and length
|
|
paddingExt.Update(headerLength + 4 + extensionsLen + 2)
|
|
extensionsLen += paddingExt.Len()
|
|
}
|
|
|
|
helloLen := headerLength
|
|
if len(uconn.Extensions) > 0 {
|
|
helloLen += 2 + extensionsLen // 2 bytes for extensions' length
|
|
}
|
|
|
|
helloBuffer := bytes.Buffer{}
|
|
bufferedWriter := bufio.NewWriterSize(&helloBuffer, helloLen+4) // 1 byte for tls record type, 3 for length
|
|
// We use buffered Writer to avoid checking write errors after every Write(): whenever first error happens
|
|
// Write() will become noop, and error will be accessible via Flush(), which is called once in the end
|
|
|
|
binary.Write(bufferedWriter, binary.BigEndian, typeClientHello)
|
|
helloLenBytes := []byte{byte(helloLen >> 16), byte(helloLen >> 8), byte(helloLen)} // poor man's uint24
|
|
binary.Write(bufferedWriter, binary.BigEndian, helloLenBytes)
|
|
binary.Write(bufferedWriter, binary.BigEndian, hello.Vers)
|
|
|
|
binary.Write(bufferedWriter, binary.BigEndian, hello.Random)
|
|
|
|
binary.Write(bufferedWriter, binary.BigEndian, uint8(len(hello.SessionId)))
|
|
binary.Write(bufferedWriter, binary.BigEndian, hello.SessionId)
|
|
|
|
binary.Write(bufferedWriter, binary.BigEndian, uint16(len(hello.CipherSuites)<<1))
|
|
for _, suite := range hello.CipherSuites {
|
|
binary.Write(bufferedWriter, binary.BigEndian, suite)
|
|
}
|
|
|
|
binary.Write(bufferedWriter, binary.BigEndian, uint8(len(hello.CompressionMethods)))
|
|
binary.Write(bufferedWriter, binary.BigEndian, hello.CompressionMethods)
|
|
|
|
if len(uconn.Extensions) > 0 {
|
|
binary.Write(bufferedWriter, binary.BigEndian, uint16(extensionsLen))
|
|
for _, ext := range uconn.Extensions {
|
|
bufferedWriter.ReadFrom(ext)
|
|
}
|
|
}
|
|
|
|
if helloBuffer.Len() != 4+helloLen {
|
|
return errors.New("utls: unexpected ClientHello length. Expected: " + strconv.Itoa(4+helloLen) +
|
|
". Got: " + strconv.Itoa(helloBuffer.Len()))
|
|
}
|
|
|
|
err := bufferedWriter.Flush()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
hello.Raw = helloBuffer.Bytes()
|
|
return nil
|
|
}
|
|
|
|
// get current state of cipher and encrypt zeros to get keystream
|
|
func (uconn *UConn) GetOutKeystream(length int) ([]byte, error) {
|
|
zeros := make([]byte, length)
|
|
|
|
if outCipher, ok := uconn.out.cipher.(cipher.AEAD); ok {
|
|
// AEAD.Seal() does not mutate internal state, other ciphers might
|
|
return outCipher.Seal(nil, uconn.out.seq[:], zeros, nil), nil
|
|
}
|
|
return nil, errors.New("Could not convert OutCipher to cipher.AEAD")
|
|
}
|