mirror of
https://github.com/refraction-networking/utls.git
synced 2025-04-05 13:07:36 +03:00
Add MakeConnWithCompleteHandshake function (#18)
Add MakeConnWithCompleteHandshake function + regression test and usage example
This commit is contained in:
parent
a89e7e6da4
commit
323a55944c
4 changed files with 162 additions and 4 deletions
|
@ -249,6 +249,64 @@ func HttpGetGoogleWithRoller() (*http.Response, error) {
|
||||||
return httpGetOverConn(c, c.HandshakeState.ServerHello.AlpnProtocol)
|
return httpGetOverConn(c, c.HandshakeState.ServerHello.AlpnProtocol)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func forgeConn() {
|
||||||
|
// this gets tls connection with google.com
|
||||||
|
// then replaces underlying connection of that tls connection with an in-memory pipe
|
||||||
|
// to a forged local in-memory "server-side" connection,
|
||||||
|
// that uses cryptographic parameters passed by a client
|
||||||
|
clientTcp, err := net.DialTimeout("tcp", "google.com:443", 10*time.Second)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("net.DialTimeout error: %+v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
clientUtls := tls.UClient(clientTcp, nil, tls.HelloGolang)
|
||||||
|
defer clientUtls.Close()
|
||||||
|
clientUtls.SetSNI("google.com") // have to set SNI, if config was nil
|
||||||
|
err = clientUtls.Handshake()
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("clientUtls.Handshake() error: %+v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
serverConn, clientConn := net.Pipe()
|
||||||
|
|
||||||
|
clientUtls.SetNetConn(clientConn)
|
||||||
|
|
||||||
|
hs := clientUtls.HandshakeState
|
||||||
|
serverTls := tls.MakeConnWithCompleteHandshake(serverConn, hs.ServerHello.Vers, hs.ServerHello.CipherSuite,
|
||||||
|
hs.MasterSecret, hs.Hello.Random, hs.ServerHello.Random, false)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
clientUtls.Write([]byte("Hello, world!"))
|
||||||
|
resp := make([]byte, 13)
|
||||||
|
read, err := clientUtls.Read(resp)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("error reading client: %+v\n", err)
|
||||||
|
}
|
||||||
|
fmt.Printf("Client read %d bytes: %s\n", read, string(resp))
|
||||||
|
fmt.Println("Client closing...")
|
||||||
|
clientUtls.Close()
|
||||||
|
fmt.Println("client closed")
|
||||||
|
}()
|
||||||
|
|
||||||
|
buf := make([]byte, 13)
|
||||||
|
read, err := serverTls.Read(buf)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("error reading server: %+v\n", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("Server read %d bytes: %s\n", read, string(buf))
|
||||||
|
serverTls.Write([]byte("Test response"))
|
||||||
|
|
||||||
|
// Have to do a final read (that will error)
|
||||||
|
// to consume client's closeNotify
|
||||||
|
// because net Pipes are weird
|
||||||
|
serverTls.Read(buf)
|
||||||
|
fmt.Println("Server closed")
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
var response *http.Response
|
var response *http.Response
|
||||||
var err error
|
var err error
|
||||||
|
@ -312,6 +370,8 @@ func main() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
forgeConn()
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
56
u_conn.go
56
u_conn.go
|
@ -496,6 +496,62 @@ func (uconn *UConn) SetTLSVers(minTLSVers, maxTLSVers uint16) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (uconn *UConn) SetUnderlyingConn(c net.Conn) {
|
||||||
|
uconn.Conn.conn = c
|
||||||
|
}
|
||||||
|
|
||||||
|
func (uconn *UConn) SetNetConn(c net.Conn) {
|
||||||
|
uconn.Conn.conn = c
|
||||||
|
}
|
||||||
|
|
||||||
|
// MakeConnWithCompleteHandshake allows to forge both server and client side TLS connections.
|
||||||
|
// Major Hack Alert.
|
||||||
|
func MakeConnWithCompleteHandshake(tcpConn net.Conn, version uint16, cipherSuite uint16, masterSecret []byte, clientRandom []byte, serverRandom []byte, isClient bool) *Conn {
|
||||||
|
tlsConn := &Conn{conn: tcpConn, config: &Config{}, isClient: isClient}
|
||||||
|
cs := cipherSuiteByID(cipherSuite)
|
||||||
|
|
||||||
|
// This is mostly borrowed from establishKeys()
|
||||||
|
clientMAC, serverMAC, clientKey, serverKey, clientIV, serverIV :=
|
||||||
|
keysFromMasterSecret(version, cs, masterSecret, clientRandom, serverRandom,
|
||||||
|
cs.macLen, cs.keyLen, cs.ivLen)
|
||||||
|
|
||||||
|
var clientCipher, serverCipher interface{}
|
||||||
|
var clientHash, serverHash macFunction
|
||||||
|
if cs.cipher != nil {
|
||||||
|
clientCipher = cs.cipher(clientKey, clientIV, true /* for reading */)
|
||||||
|
clientHash = cs.mac(version, clientMAC)
|
||||||
|
serverCipher = cs.cipher(serverKey, serverIV, false /* not for reading */)
|
||||||
|
serverHash = cs.mac(version, serverMAC)
|
||||||
|
} else {
|
||||||
|
clientCipher = cs.aead(clientKey, clientIV)
|
||||||
|
serverCipher = cs.aead(serverKey, serverIV)
|
||||||
|
}
|
||||||
|
|
||||||
|
if isClient {
|
||||||
|
tlsConn.in.prepareCipherSpec(version, serverCipher, serverHash)
|
||||||
|
tlsConn.out.prepareCipherSpec(version, clientCipher, clientHash)
|
||||||
|
} else {
|
||||||
|
tlsConn.in.prepareCipherSpec(version, clientCipher, clientHash)
|
||||||
|
tlsConn.out.prepareCipherSpec(version, serverCipher, serverHash)
|
||||||
|
}
|
||||||
|
|
||||||
|
// skip the handshake states
|
||||||
|
tlsConn.handshakeStatus = 1
|
||||||
|
tlsConn.cipherSuite = cipherSuite
|
||||||
|
tlsConn.haveVers = true
|
||||||
|
tlsConn.vers = version
|
||||||
|
|
||||||
|
// Update to the new cipher specs
|
||||||
|
// and consume the finished messages
|
||||||
|
tlsConn.in.changeCipherSpec()
|
||||||
|
tlsConn.out.changeCipherSpec()
|
||||||
|
|
||||||
|
tlsConn.in.incSeq()
|
||||||
|
tlsConn.out.incSeq()
|
||||||
|
|
||||||
|
return tlsConn
|
||||||
|
}
|
||||||
|
|
||||||
func makeSupportedVersions(minVers, maxVers uint16) []uint16 {
|
func makeSupportedVersions(minVers, maxVers uint16) []uint16 {
|
||||||
a := make([]uint16, maxVers-minVers+1)
|
a := make([]uint16, maxVers-minVers+1)
|
||||||
for i := range a {
|
for i := range a {
|
||||||
|
|
|
@ -6,6 +6,7 @@ package tls
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"crypto/tls"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
|
@ -584,3 +585,44 @@ func (test *clientTest) runUTLS(t *testing.T, write bool, helloID ClientHelloID)
|
||||||
fmt.Printf("Wrote %s\n", path)
|
fmt.Printf("Wrote %s\n", path)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestUTLSMakeConnWithCompleteHandshake(t *testing.T) {
|
||||||
|
serverConn, clientConn := net.Pipe()
|
||||||
|
|
||||||
|
masterSecret := []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47}
|
||||||
|
clientRandom := []byte{40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71}
|
||||||
|
serverRandom := []byte{80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111}
|
||||||
|
serverTls := MakeConnWithCompleteHandshake(serverConn, tls.VersionTLS12, tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
|
||||||
|
masterSecret, clientRandom, serverRandom, false)
|
||||||
|
clientTls := MakeConnWithCompleteHandshake(clientConn, tls.VersionTLS12, tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
|
||||||
|
masterSecret, clientRandom, serverRandom, true)
|
||||||
|
|
||||||
|
clientMsg := []byte("Hello, world!")
|
||||||
|
serverMsg := []byte("Test response!")
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
clientTls.Write(clientMsg)
|
||||||
|
resp := make([]byte, 20)
|
||||||
|
read, err := clientTls.Read(resp)
|
||||||
|
if !bytes.Equal(resp[:read], serverMsg) {
|
||||||
|
t.Errorf("client expected to receive: %v, got %v\n",
|
||||||
|
serverMsg, resp[:read])
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("error reading client: %+v\n", err)
|
||||||
|
}
|
||||||
|
clientConn.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
buf := make([]byte, 20)
|
||||||
|
read, err := serverTls.Read(buf)
|
||||||
|
if !bytes.Equal(buf[:read], clientMsg) {
|
||||||
|
t.Errorf("server expected to receive: %v, got %v\n",
|
||||||
|
clientMsg, buf[:read])
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("error reading client: %+v\n", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
serverTls.Write(serverMsg)
|
||||||
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue