diff --git a/conn_test.go b/conn_test.go index 5c7f7ce..f948717 100644 --- a/conn_test.go +++ b/conn_test.go @@ -134,12 +134,13 @@ func TestCertificateSelection(t *testing.T) { // Run with multiple crypto configs to test the logic for computing TLS record overheads. func runDynamicRecordSizingTest(t *testing.T, config *Config) { - clientConn, serverConn := net.Pipe() + clientConn, serverConn := localPipe(t) serverConfig := config.Clone() serverConfig.DynamicRecordSizingDisabled = false tlsConn := Server(serverConn, serverConfig) + handshakeDone := make(chan struct{}) recordSizesChan := make(chan []int, 1) go func() { // This goroutine performs a TLS handshake over clientConn and @@ -153,6 +154,7 @@ func runDynamicRecordSizingTest(t *testing.T, config *Config) { t.Errorf("Error from client handshake: %v", err) return } + close(handshakeDone) var recordHeader [recordHeaderLen]byte var record []byte @@ -192,6 +194,7 @@ func runDynamicRecordSizingTest(t *testing.T, config *Config) { if err := tlsConn.Handshake(); err != nil { t.Fatalf("Error from server handshake: %s", err) } + <-handshakeDone // The server writes these plaintexts in order. plaintext := bytes.Join([][]byte{ @@ -269,7 +272,7 @@ func (conn *hairpinConn) Close() error { func TestHairpinInClose(t *testing.T) { // This tests that the underlying net.Conn can call back into the // tls.Conn when being closed without deadlocking. - client, server := net.Pipe() + client, server := localPipe(t) defer server.Close() defer client.Close() diff --git a/handshake_client_test.go b/handshake_client_test.go index 1f1c93d..dcd6914 100644 --- a/handshake_client_test.go +++ b/handshake_client_test.go @@ -179,7 +179,7 @@ func (test *clientTest) connFromCommand() (conn *recordingConn, child *exec.Cmd, var pemOut bytes.Buffer pem.Encode(&pemOut, &pem.Block{Type: pemType + " PRIVATE KEY", Bytes: derBytes}) - keyPath := tempFile(string(pemOut.Bytes())) + keyPath := tempFile(pemOut.String()) defer os.Remove(keyPath) var command []string @@ -293,7 +293,7 @@ func (test *clientTest) run(t *testing.T, write bool) { } clientConn = recordingConn } else { - clientConn, serverConn = net.Pipe() + clientConn, serverConn = localPipe(t) } config := test.config @@ -682,7 +682,7 @@ func TestClientResumption(t *testing.T) { } testResumeState := func(test string, didResume bool) { - _, hs, err := testHandshake(clientConfig, serverConfig) + _, hs, err := testHandshake(t, clientConfig, serverConfig) if err != nil { t.Fatalf("%s: handshake failed: %s", test, err) } @@ -800,7 +800,7 @@ func TestKeyLog(t *testing.T) { serverConfig := testConfig.Clone() serverConfig.KeyLogWriter = &serverBuf - c, s := net.Pipe() + c, s := localPipe(t) done := make(chan bool) go func() { @@ -838,8 +838,8 @@ func TestKeyLog(t *testing.T) { } } - checkKeylogLine("client", string(clientBuf.Bytes())) - checkKeylogLine("server", string(serverBuf.Bytes())) + checkKeylogLine("client", clientBuf.String()) + checkKeylogLine("server", serverBuf.String()) } func TestHandshakeClientALPNMatch(t *testing.T) { @@ -1021,7 +1021,7 @@ var hostnameInSNITests = []struct { func TestHostnameInSNI(t *testing.T) { for _, tt := range hostnameInSNITests { - c, s := net.Pipe() + c, s := localPipe(t) go func(host string) { Client(c, &Config{ServerName: host, InsecureSkipVerify: true}).Handshake() @@ -1059,7 +1059,7 @@ func TestServerSelectingUnconfiguredCipherSuite(t *testing.T) { // This checks that the server can't select a cipher suite that the // client didn't offer. See #13174. - c, s := net.Pipe() + c, s := localPipe(t) errChan := make(chan error, 1) go func() { @@ -1228,7 +1228,7 @@ func TestVerifyPeerCertificate(t *testing.T) { } for i, test := range tests { - c, s := net.Pipe() + c, s := localPipe(t) done := make(chan error) var clientCalled, serverCalled bool @@ -1287,7 +1287,7 @@ func (b *brokenConn) Write(data []byte) (int, error) { func TestFailedWrite(t *testing.T) { // Test that a write error during the handshake is returned. for _, breakAfter := range []int{0, 1} { - c, s := net.Pipe() + c, s := localPipe(t) done := make(chan bool) go func() { @@ -1321,7 +1321,7 @@ func (wcc *writeCountingConn) Write(data []byte) (int, error) { } func TestBuffering(t *testing.T) { - c, s := net.Pipe() + c, s := localPipe(t) done := make(chan bool) clientWCC := &writeCountingConn{Conn: c} @@ -1350,7 +1350,7 @@ func TestBuffering(t *testing.T) { } func TestAlertFlushing(t *testing.T) { - c, s := net.Pipe() + c, s := localPipe(t) done := make(chan bool) clientWCC := &writeCountingConn{Conn: c} @@ -1399,7 +1399,7 @@ func TestHandshakeRace(t *testing.T) { // order to provide some evidence that there are no races or deadlocks // in the handshake locking. for i := 0; i < 32; i++ { - c, s := net.Pipe() + c, s := localPipe(t) go func() { server := Server(s, testConfig) @@ -1430,7 +1430,7 @@ func TestHandshakeRace(t *testing.T) { go func() { <-startRead var reply [1]byte - if n, err := client.Read(reply[:]); err != nil || n != 1 { + if _, err := io.ReadFull(client, reply[:]); err != nil { panic(err) } c.Close() @@ -1559,7 +1559,7 @@ func TestGetClientCertificate(t *testing.T) { err error } - c, s := net.Pipe() + c, s := localPipe(t) done := make(chan serverResult) go func() { @@ -1637,7 +1637,7 @@ RwBA9Xk1KBNF } func TestCloseClientConnectionOnIdleServer(t *testing.T) { - clientConn, serverConn := net.Pipe() + clientConn, serverConn := localPipe(t) client := Client(clientConn, testConfig.Clone()) go func() { var b [1]byte @@ -1647,8 +1647,8 @@ func TestCloseClientConnectionOnIdleServer(t *testing.T) { client.SetWriteDeadline(time.Now().Add(time.Second)) err := client.Handshake() if err != nil { - if !strings.Contains(err.Error(), "read/write on closed pipe") { - t.Errorf("Error expected containing 'read/write on closed pipe' but got '%s'", err.Error()) + if err, ok := err.(net.Error); ok && err.Timeout() { + t.Errorf("Expected a closed network connection error but got '%s'", err.Error()) } } else { t.Errorf("Error expected, but no error returned") diff --git a/handshake_server_test.go b/handshake_server_test.go index c366f47..44c67ed 100644 --- a/handshake_server_test.go +++ b/handshake_server_test.go @@ -70,10 +70,7 @@ func testClientHello(t *testing.T, serverConfig *Config, m handshakeMessage) { } func testClientHelloFailure(t *testing.T, serverConfig *Config, m handshakeMessage, expectedSubStr string) { - // Create in-memory network connection, - // send message to server. Should return - // expected error. - c, s := net.Pipe() + c, s := localPipe(t) go func() { cli := Client(c, testConfig) if ch, ok := m.(*clientHelloMsg); ok { @@ -201,25 +198,26 @@ func TestRenegotiationExtension(t *testing.T) { cipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA}, } - var buf []byte - c, s := net.Pipe() + bufChan := make(chan []byte) + c, s := localPipe(t) go func() { cli := Client(c, testConfig) cli.vers = clientHello.vers cli.writeRecord(recordTypeHandshake, clientHello.marshal()) - buf = make([]byte, 1024) + buf := make([]byte, 1024) n, err := c.Read(buf) if err != nil { t.Errorf("Server read returned error: %s", err) return } - buf = buf[:n] c.Close() + bufChan <- buf[:n] }() Server(s, testConfig).Handshake() + buf := <-bufChan if len(buf) < 5+4 { t.Fatalf("Server returned short message of length %d", len(buf)) @@ -262,22 +260,27 @@ func TestTLS12OnlyCipherSuites(t *testing.T) { supportedPoints: []uint8{pointFormatUncompressed}, } - c, s := net.Pipe() - var reply interface{} - var clientErr error + c, s := localPipe(t) + replyChan := make(chan interface{}) go func() { cli := Client(c, testConfig) cli.vers = clientHello.vers cli.writeRecord(recordTypeHandshake, clientHello.marshal()) - reply, clientErr = cli.readHandshake() + reply, err := cli.readHandshake() c.Close() + if err != nil { + replyChan <- err + } else { + replyChan <- reply + } }() config := testConfig.Clone() config.CipherSuites = clientHello.cipherSuites Server(s, config).Handshake() s.Close() - if clientErr != nil { - t.Fatal(clientErr) + reply := <-replyChan + if err, ok := reply.(error); ok { + t.Fatal(err) } serverHello, ok := reply.(*serverHelloMsg) if !ok { @@ -289,7 +292,7 @@ func TestTLS12OnlyCipherSuites(t *testing.T) { } func TestAlertForwarding(t *testing.T) { - c, s := net.Pipe() + c, s := localPipe(t) go func() { Client(c, testConfig).sendAlert(alertUnknownCA) c.Close() @@ -303,7 +306,7 @@ func TestAlertForwarding(t *testing.T) { } func TestClose(t *testing.T) { - c, s := net.Pipe() + c, s := localPipe(t) go c.Close() err := Server(s, testConfig).Handshake() @@ -313,8 +316,8 @@ func TestClose(t *testing.T) { } } -func testHandshake(clientConfig, serverConfig *Config) (serverState, clientState ConnectionState, err error) { - c, s := net.Pipe() +func testHandshake(t *testing.T, clientConfig, serverConfig *Config) (serverState, clientState ConnectionState, err error) { + c, s := localPipe(t) done := make(chan bool) go func() { cli := Client(c, clientConfig) @@ -341,7 +344,7 @@ func TestVersion(t *testing.T) { clientConfig := &Config{ InsecureSkipVerify: true, } - state, _, err := testHandshake(clientConfig, serverConfig) + state, _, err := testHandshake(t, clientConfig, serverConfig) if err != nil { t.Fatalf("handshake failed: %s", err) } @@ -360,7 +363,7 @@ func TestCipherSuitePreference(t *testing.T) { CipherSuites: []uint16{TLS_RSA_WITH_AES_128_CBC_SHA, TLS_RSA_WITH_RC4_128_SHA}, InsecureSkipVerify: true, } - state, _, err := testHandshake(clientConfig, serverConfig) + state, _, err := testHandshake(t, clientConfig, serverConfig) if err != nil { t.Fatalf("handshake failed: %s", err) } @@ -370,7 +373,7 @@ func TestCipherSuitePreference(t *testing.T) { } serverConfig.PreferServerCipherSuites = true - state, _, err = testHandshake(clientConfig, serverConfig) + state, _, err = testHandshake(t, clientConfig, serverConfig) if err != nil { t.Fatalf("handshake failed: %s", err) } @@ -391,7 +394,7 @@ func TestSCTHandshake(t *testing.T) { clientConfig := &Config{ InsecureSkipVerify: true, } - _, state, err := testHandshake(clientConfig, serverConfig) + _, state, err := testHandshake(t, clientConfig, serverConfig) if err != nil { t.Fatalf("handshake failed: %s", err) } @@ -420,13 +423,13 @@ func TestCrossVersionResume(t *testing.T) { // Establish a session at TLS 1.1. clientConfig.MaxVersion = VersionTLS11 - _, _, err := testHandshake(clientConfig, serverConfig) + _, _, err := testHandshake(t, clientConfig, serverConfig) if err != nil { t.Fatalf("handshake failed: %s", err) } // The client session cache now contains a TLS 1.1 session. - state, _, err := testHandshake(clientConfig, serverConfig) + state, _, err := testHandshake(t, clientConfig, serverConfig) if err != nil { t.Fatalf("handshake failed: %s", err) } @@ -436,7 +439,7 @@ func TestCrossVersionResume(t *testing.T) { // Test that the server will decline to resume at a lower version. clientConfig.MaxVersion = VersionTLS10 - state, _, err = testHandshake(clientConfig, serverConfig) + state, _, err = testHandshake(t, clientConfig, serverConfig) if err != nil { t.Fatalf("handshake failed: %s", err) } @@ -445,7 +448,7 @@ func TestCrossVersionResume(t *testing.T) { } // The client session cache now contains a TLS 1.0 session. - state, _, err = testHandshake(clientConfig, serverConfig) + state, _, err = testHandshake(t, clientConfig, serverConfig) if err != nil { t.Fatalf("handshake failed: %s", err) } @@ -455,7 +458,7 @@ func TestCrossVersionResume(t *testing.T) { // Test that the server will decline to resume at a higher version. clientConfig.MaxVersion = VersionTLS11 - state, _, err = testHandshake(clientConfig, serverConfig) + state, _, err = testHandshake(t, clientConfig, serverConfig) if err != nil { t.Fatalf("handshake failed: %s", err) } @@ -579,7 +582,7 @@ func (test *serverTest) run(t *testing.T, write bool) { } serverConn = recordingConn } else { - clientConn, serverConn = net.Pipe() + clientConn, serverConn = localPipe(t) } config := test.config if config == nil { @@ -832,7 +835,7 @@ func TestHandshakeServerSNIGetCertificate(t *testing.T) { nameToCert := config.NameToCertificate config.NameToCertificate = nil config.GetCertificate = func(clientHello *ClientHelloInfo) (*Certificate, error) { - cert, _ := nameToCert[clientHello.ServerName] + cert := nameToCert[clientHello.ServerName] return cert, nil } test := &serverTest{ @@ -1025,7 +1028,7 @@ func benchmarkHandshakeServer(b *testing.B, cipherSuite uint16, curve CurveID, c config.Certificates[0].PrivateKey = key config.BuildNameToCertificate() - clientConn, serverConn := net.Pipe() + clientConn, serverConn := localPipe(b) serverConn = &recordingConn{Conn: serverConn} go func() { client := Client(clientConn, testConfig) @@ -1039,7 +1042,7 @@ func benchmarkHandshakeServer(b *testing.B, cipherSuite uint16, curve CurveID, c flows := serverConn.(*recordingConn).flows feeder := make(chan struct{}) - clientConn, serverConn = net.Pipe() + clientConn, serverConn = localPipe(b) go func() { for range feeder { @@ -1051,10 +1054,10 @@ func benchmarkHandshakeServer(b *testing.B, cipherSuite uint16, curve CurveID, c ff := make([]byte, len(f)) n, err := io.ReadFull(clientConn, ff) if err != nil { - b.Fatalf("#%d: %s\nRead %d, wanted %d, got %x, wanted %x\n", i+1, err, n, len(ff), ff[:n], f) + b.Errorf("#%d: %s\nRead %d, wanted %d, got %x, wanted %x\n", i+1, err, n, len(ff), ff[:n], f) } if !bytes.Equal(f, ff) { - b.Fatalf("#%d: mismatch on read: got:%x want:%x", i+1, ff, f) + b.Errorf("#%d: mismatch on read: got:%x want:%x", i+1, ff, f) } } } @@ -1216,7 +1219,7 @@ func TestSNIGivenOnFailure(t *testing.T) { // Erase the server's cipher suites to ensure the handshake fails. serverConfig.CipherSuites = nil - c, s := net.Pipe() + c, s := localPipe(t) go func() { cli := Client(c, testConfig) cli.vers = clientHello.vers @@ -1346,7 +1349,7 @@ func TestGetConfigForClient(t *testing.T) { configReturned = config return config, err } - c, s := net.Pipe() + c, s := localPipe(t) done := make(chan error) go func() { @@ -1423,7 +1426,7 @@ var testECDSAPrivateKey = &ecdsa.PrivateKey{ var testP256PrivateKey, _ = x509.ParseECPrivateKey(fromHex("30770201010420012f3b52bc54c36ba3577ad45034e2e8efe1e6999851284cb848725cfe029991a00a06082a8648ce3d030107a14403420004c02c61c9b16283bbcc14956d886d79b358aa614596975f78cece787146abf74c2d5dc578c0992b4f3c631373479ebf3892efe53d21c4f4f1cc9a11c3536b7f75")) func TestCloseServerConnectionOnIdleClient(t *testing.T) { - clientConn, serverConn := net.Pipe() + clientConn, serverConn := localPipe(t) server := Server(serverConn, testConfig.Clone()) go func() { clientConn.Write([]byte{'0'}) @@ -1432,8 +1435,8 @@ func TestCloseServerConnectionOnIdleClient(t *testing.T) { server.SetReadDeadline(time.Now().Add(time.Second)) err := server.Handshake() if err != nil { - if !strings.Contains(err.Error(), "read/write on closed pipe") { - t.Errorf("Error expected containing 'read/write on closed pipe' but got '%s'", err.Error()) + if err, ok := err.(net.Error); ok && err.Timeout() { + t.Errorf("Expected a closed network connection error but got '%s'", err.Error()) } } else { t.Errorf("Error expected, but no error returned") diff --git a/handshake_test.go b/handshake_test.go index 4b3fa23..18d4624 100644 --- a/handshake_test.go +++ b/handshake_test.go @@ -13,6 +13,7 @@ import ( "io" "io/ioutil" "net" + "os" "os/exec" "strconv" "strings" @@ -224,3 +225,45 @@ func tempFile(contents string) string { file.Close() return path } + +// localListener is set up by TestMain and used by localPipe to create Conn +// pairs like net.Pipe, but connected by an actual buffered TCP connection. +var localListener struct { + sync.Mutex + net.Listener +} + +func localPipe(t testing.TB) (net.Conn, net.Conn) { + localListener.Lock() + defer localListener.Unlock() + c := make(chan net.Conn) + go func() { + conn, err := localListener.Accept() + if err != nil { + t.Errorf("Failed to accept local connection: %v", err) + } + c <- conn + }() + addr := localListener.Addr() + c1, err := net.Dial(addr.Network(), addr.String()) + if err != nil { + t.Fatalf("Failed to dial local connection: %v", err) + } + c2 := <-c + return c1, c2 +} + +func TestMain(m *testing.M) { + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + l, err = net.Listen("tcp6", "[::1]:0") + } + if err != nil { + fmt.Fprintf(os.Stderr, "Failed to open local listener: %v", err) + os.Exit(1) + } + localListener.Listener = l + exitCode := m.Run() + localListener.Close() + os.Exit(exitCode) +}