From 323a55944c764a2cd5615c11f331fb5d74a25497 Mon Sep 17 00:00:00 2001 From: Eric Wustrow Date: Wed, 30 Jan 2019 11:25:23 -0700 Subject: [PATCH] Add MakeConnWithCompleteHandshake function (#18) Add MakeConnWithCompleteHandshake function + regression test and usage example --- examples/examples.go | 60 ++++++++++++++++++++++++++++++++++++++++ handshake_server_test.go | 8 +++--- u_conn.go | 56 +++++++++++++++++++++++++++++++++++++ u_conn_test.go | 42 ++++++++++++++++++++++++++++ 4 files changed, 162 insertions(+), 4 deletions(-) diff --git a/examples/examples.go b/examples/examples.go index 63f8eaa..6a89595 100644 --- a/examples/examples.go +++ b/examples/examples.go @@ -249,6 +249,64 @@ func HttpGetGoogleWithRoller() (*http.Response, error) { return httpGetOverConn(c, c.HandshakeState.ServerHello.AlpnProtocol) } +func forgeConn() { + // this gets tls connection with google.com + // then replaces underlying connection of that tls connection with an in-memory pipe + // to a forged local in-memory "server-side" connection, + // that uses cryptographic parameters passed by a client + clientTcp, err := net.DialTimeout("tcp", "google.com:443", 10*time.Second) + if err != nil { + fmt.Printf("net.DialTimeout error: %+v", err) + return + } + + clientUtls := tls.UClient(clientTcp, nil, tls.HelloGolang) + defer clientUtls.Close() + clientUtls.SetSNI("google.com") // have to set SNI, if config was nil + err = clientUtls.Handshake() + if err != nil { + fmt.Printf("clientUtls.Handshake() error: %+v", err) + } + + serverConn, clientConn := net.Pipe() + + clientUtls.SetNetConn(clientConn) + + hs := clientUtls.HandshakeState + serverTls := tls.MakeConnWithCompleteHandshake(serverConn, hs.ServerHello.Vers, hs.ServerHello.CipherSuite, + hs.MasterSecret, hs.Hello.Random, hs.ServerHello.Random, false) + + go func() { + clientUtls.Write([]byte("Hello, world!")) + resp := make([]byte, 13) + read, err := clientUtls.Read(resp) + if err != nil { + fmt.Printf("error reading client: %+v\n", err) + } + fmt.Printf("Client read %d bytes: %s\n", read, string(resp)) + fmt.Println("Client closing...") + clientUtls.Close() + fmt.Println("client closed") + }() + + buf := make([]byte, 13) + read, err := serverTls.Read(buf) + if err != nil { + fmt.Printf("error reading server: %+v\n", err) + } + + fmt.Printf("Server read %d bytes: %s\n", read, string(buf)) + serverTls.Write([]byte("Test response")) + + // Have to do a final read (that will error) + // to consume client's closeNotify + // because net Pipes are weird + serverTls.Read(buf) + fmt.Println("Server closed") + +} + + func main() { var response *http.Response var err error @@ -312,6 +370,8 @@ func main() { } } + forgeConn() + return } diff --git a/handshake_server_test.go b/handshake_server_test.go index 6aefa56..0bd0ae0 100644 --- a/handshake_server_test.go +++ b/handshake_server_test.go @@ -224,9 +224,9 @@ func TestDontSelectRSAWithECDSAKey(t *testing.T) { func TestRenegotiationExtension(t *testing.T) { clientHello := &clientHelloMsg{ - vers: VersionTLS12, - compressionMethods: []uint8{compressionNone}, - random: make([]byte, 32), + vers: VersionTLS12, + compressionMethods: []uint8{compressionNone}, + random: make([]byte, 32), secureRenegotiationSupported: true, cipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA}, } @@ -1166,7 +1166,7 @@ func TestFallbackSCSV(t *testing.T) { name: "FallbackSCSV", config: &serverConfig, // OpenSSL 1.0.1j is needed for the -fallback_scsv option. - command: []string{"openssl", "s_client", "-fallback_scsv"}, + command: []string{"openssl", "s_client", "-fallback_scsv"}, expectHandshakeErrorIncluding: "inappropriate protocol fallback", } runServerTestTLS11(t, test) diff --git a/u_conn.go b/u_conn.go index b76eec6..516d6f7 100644 --- a/u_conn.go +++ b/u_conn.go @@ -496,6 +496,62 @@ func (uconn *UConn) SetTLSVers(minTLSVers, maxTLSVers uint16) error { return nil } +func (uconn *UConn) SetUnderlyingConn(c net.Conn) { + uconn.Conn.conn = c +} + +func (uconn *UConn) SetNetConn(c net.Conn) { + uconn.Conn.conn = c +} + +// MakeConnWithCompleteHandshake allows to forge both server and client side TLS connections. +// Major Hack Alert. +func MakeConnWithCompleteHandshake(tcpConn net.Conn, version uint16, cipherSuite uint16, masterSecret []byte, clientRandom []byte, serverRandom []byte, isClient bool) *Conn { + tlsConn := &Conn{conn: tcpConn, config: &Config{}, isClient: isClient} + cs := cipherSuiteByID(cipherSuite) + + // This is mostly borrowed from establishKeys() + clientMAC, serverMAC, clientKey, serverKey, clientIV, serverIV := + keysFromMasterSecret(version, cs, masterSecret, clientRandom, serverRandom, + cs.macLen, cs.keyLen, cs.ivLen) + + var clientCipher, serverCipher interface{} + var clientHash, serverHash macFunction + if cs.cipher != nil { + clientCipher = cs.cipher(clientKey, clientIV, true /* for reading */) + clientHash = cs.mac(version, clientMAC) + serverCipher = cs.cipher(serverKey, serverIV, false /* not for reading */) + serverHash = cs.mac(version, serverMAC) + } else { + clientCipher = cs.aead(clientKey, clientIV) + serverCipher = cs.aead(serverKey, serverIV) + } + + if isClient { + tlsConn.in.prepareCipherSpec(version, serverCipher, serverHash) + tlsConn.out.prepareCipherSpec(version, clientCipher, clientHash) + } else { + tlsConn.in.prepareCipherSpec(version, clientCipher, clientHash) + tlsConn.out.prepareCipherSpec(version, serverCipher, serverHash) + } + + // skip the handshake states + tlsConn.handshakeStatus = 1 + tlsConn.cipherSuite = cipherSuite + tlsConn.haveVers = true + tlsConn.vers = version + + // Update to the new cipher specs + // and consume the finished messages + tlsConn.in.changeCipherSpec() + tlsConn.out.changeCipherSpec() + + tlsConn.in.incSeq() + tlsConn.out.incSeq() + + return tlsConn +} + func makeSupportedVersions(minVers, maxVers uint16) []uint16 { a := make([]uint16, maxVers-minVers+1) for i := range a { diff --git a/u_conn_test.go b/u_conn_test.go index 7f568b1..d4dd4cd 100644 --- a/u_conn_test.go +++ b/u_conn_test.go @@ -6,6 +6,7 @@ package tls import ( "bytes" + "crypto/tls" "fmt" "io" "net" @@ -584,3 +585,44 @@ func (test *clientTest) runUTLS(t *testing.T, write bool, helloID ClientHelloID) fmt.Printf("Wrote %s\n", path) } } + +func TestUTLSMakeConnWithCompleteHandshake(t *testing.T) { + serverConn, clientConn := net.Pipe() + + masterSecret := []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47} + clientRandom := []byte{40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71} + serverRandom := []byte{80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111} + serverTls := MakeConnWithCompleteHandshake(serverConn, tls.VersionTLS12, tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + masterSecret, clientRandom, serverRandom, false) + clientTls := MakeConnWithCompleteHandshake(clientConn, tls.VersionTLS12, tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + masterSecret, clientRandom, serverRandom, true) + + clientMsg := []byte("Hello, world!") + serverMsg := []byte("Test response!") + + go func() { + clientTls.Write(clientMsg) + resp := make([]byte, 20) + read, err := clientTls.Read(resp) + if !bytes.Equal(resp[:read], serverMsg) { + t.Errorf("client expected to receive: %v, got %v\n", + serverMsg, resp[:read]) + } + if err != nil { + t.Errorf("error reading client: %+v\n", err) + } + clientConn.Close() + }() + + buf := make([]byte, 20) + read, err := serverTls.Read(buf) + if !bytes.Equal(buf[:read], clientMsg) { + t.Errorf("server expected to receive: %v, got %v\n", + clientMsg, buf[:read]) + } + if err != nil { + t.Errorf("error reading client: %+v\n", err) + } + + serverTls.Write(serverMsg) +}