diff --git a/handshake_client_test.go b/handshake_client_test.go index 4570f5b..501f9c6 100644 --- a/handshake_client_test.go +++ b/handshake_client_test.go @@ -283,7 +283,7 @@ func (test *clientTest) loadData() (flows [][]byte, err error) { } func (test *clientTest) run(t *testing.T, write bool) { - var clientConn, serverConn net.Conn + var clientConn net.Conn var recordingConn *recordingConn var childProcess *exec.Cmd var stdin opensslInput @@ -302,178 +302,138 @@ func (test *clientTest) run(t *testing.T, write bool) { } }() } else { - clientConn, serverConn = localPipe(t) + flows, err := test.loadData() + if err != nil { + t.Fatalf("failed to load data from %s: %v", test.dataPath(), err) + } + clientConn = &replayingConn{t: t, flows: flows, reading: false} } - doneChan := make(chan bool) - defer func() { - clientConn.Close() - <-doneChan - }() - go func() { - defer close(doneChan) + config := test.config + if config == nil { + config = testConfig + } + client := Client(clientConn, config) + defer client.Close() - config := test.config - if config == nil { - config = testConfig + if _, err := client.Write([]byte("hello\n")); err != nil { + t.Errorf("Client.Write failed: %s", err) + return + } + + for i := 1; i <= test.numRenegotiations; i++ { + // The initial handshake will generate a + // handshakeComplete signal which needs to be quashed. + if i == 1 && write { + <-stdout.handshakeComplete } - client := Client(clientConn, config) - defer client.Close() - if _, err := client.Write([]byte("hello\n")); err != nil { + // OpenSSL will try to interleave application data and + // a renegotiation if we send both concurrently. + // Therefore: ask OpensSSL to start a renegotiation, run + // a goroutine to call client.Read and thus process the + // renegotiation request, watch for OpenSSL's stdout to + // indicate that the handshake is complete and, + // finally, have OpenSSL write something to cause + // client.Read to complete. + if write { + stdin <- opensslRenegotiate + } + + signalChan := make(chan struct{}) + + go func() { + defer close(signalChan) + + buf := make([]byte, 256) + n, err := client.Read(buf) + + if test.checkRenegotiationError != nil { + newErr := test.checkRenegotiationError(i, err) + if err != nil && newErr == nil { + return + } + err = newErr + } + + if err != nil { + t.Errorf("Client.Read failed after renegotiation #%d: %s", i, err) + return + } + + buf = buf[:n] + if !bytes.Equal([]byte(opensslSentinel), buf) { + t.Errorf("Client.Read returned %q, but wanted %q", string(buf), opensslSentinel) + } + + if expected := i + 1; client.handshakes != expected { + t.Errorf("client should have recorded %d handshakes, but believes that %d have occurred", expected, client.handshakes) + } + }() + + if write && test.renegotiationExpectedToFail != i { + <-stdout.handshakeComplete + stdin <- opensslSendSentinel + } + <-signalChan + } + + if test.sendKeyUpdate { + if write { + <-stdout.handshakeComplete + stdin <- opensslKeyUpdate + } + + doneRead := make(chan struct{}) + + go func() { + defer close(doneRead) + + buf := make([]byte, 256) + n, err := client.Read(buf) + + if err != nil { + t.Errorf("Client.Read failed after KeyUpdate: %s", err) + return + } + + buf = buf[:n] + if !bytes.Equal([]byte(opensslSentinel), buf) { + t.Errorf("Client.Read returned %q, but wanted %q", string(buf), opensslSentinel) + } + }() + + if write { + // There's no real reason to wait for the client KeyUpdate to + // send data with the new server keys, except that s_server + // drops writes if they are sent at the wrong time. + <-stdout.readKeyUpdate + stdin <- opensslSendSentinel + } + <-doneRead + + if _, err := client.Write([]byte("hello again\n")); err != nil { t.Errorf("Client.Write failed: %s", err) return } + } - for i := 1; i <= test.numRenegotiations; i++ { - // The initial handshake will generate a - // handshakeComplete signal which needs to be quashed. - if i == 1 && write { - <-stdout.handshakeComplete - } - - // OpenSSL will try to interleave application data and - // a renegotiation if we send both concurrently. - // Therefore: ask OpensSSL to start a renegotiation, run - // a goroutine to call client.Read and thus process the - // renegotiation request, watch for OpenSSL's stdout to - // indicate that the handshake is complete and, - // finally, have OpenSSL write something to cause - // client.Read to complete. - if write { - stdin <- opensslRenegotiate - } - - signalChan := make(chan struct{}) - - go func() { - defer close(signalChan) - - buf := make([]byte, 256) - n, err := client.Read(buf) - - if test.checkRenegotiationError != nil { - newErr := test.checkRenegotiationError(i, err) - if err != nil && newErr == nil { - return - } - err = newErr - } - - if err != nil { - t.Errorf("Client.Read failed after renegotiation #%d: %s", i, err) - return - } - - buf = buf[:n] - if !bytes.Equal([]byte(opensslSentinel), buf) { - t.Errorf("Client.Read returned %q, but wanted %q", string(buf), opensslSentinel) - } - - if expected := i + 1; client.handshakes != expected { - t.Errorf("client should have recorded %d handshakes, but believes that %d have occurred", expected, client.handshakes) - } - }() - - if write && test.renegotiationExpectedToFail != i { - <-stdout.handshakeComplete - stdin <- opensslSendSentinel - } - <-signalChan - } - - if test.sendKeyUpdate { - if write { - <-stdout.handshakeComplete - stdin <- opensslKeyUpdate - } - - doneRead := make(chan struct{}) - - go func() { - defer close(doneRead) - - buf := make([]byte, 256) - n, err := client.Read(buf) - - if err != nil { - t.Errorf("Client.Read failed after KeyUpdate: %s", err) - return - } - - buf = buf[:n] - if !bytes.Equal([]byte(opensslSentinel), buf) { - t.Errorf("Client.Read returned %q, but wanted %q", string(buf), opensslSentinel) - } - }() - - if write { - // There's no real reason to wait for the client KeyUpdate to - // send data with the new server keys, except that s_server - // drops writes if they are sent at the wrong time. - <-stdout.readKeyUpdate - stdin <- opensslSendSentinel - } - <-doneRead - - if _, err := client.Write([]byte("hello again\n")); err != nil { - t.Errorf("Client.Write failed: %s", err) - return - } - } - - if test.validate != nil { - if err := test.validate(client.ConnectionState()); err != nil { - t.Errorf("validate callback returned error: %s", err) - } - } - - // If the server sent us an alert after our last flight, give it a - // chance to arrive. - if write && test.renegotiationExpectedToFail == 0 { - if err := peekError(client); err != nil { - t.Errorf("final Read returned an error: %s", err) - } - } - }() - - if !write { - flows, err := test.loadData() - if err != nil { - t.Fatalf("%s: failed to load data from %s: %v", test.name, test.dataPath(), err) - } - for i, b := range flows { - if i%2 == 1 { - if *fast { - serverConn.SetWriteDeadline(time.Now().Add(1 * time.Second)) - } else { - serverConn.SetWriteDeadline(time.Now().Add(1 * time.Minute)) - } - serverConn.Write(b) - continue - } - bb := make([]byte, len(b)) - if *fast { - serverConn.SetReadDeadline(time.Now().Add(1 * time.Second)) - } else { - serverConn.SetReadDeadline(time.Now().Add(1 * time.Minute)) - } - _, err := io.ReadFull(serverConn, bb) - if err != nil { - t.Fatalf("%s, flow %d: %s", test.name, i+1, err) - } - if !bytes.Equal(b, bb) { - t.Fatalf("%s, flow %d: mismatch on read: got:%x want:%x", test.name, i+1, bb, b) - } + if test.validate != nil { + if err := test.validate(client.ConnectionState()); err != nil { + t.Errorf("validate callback returned error: %s", err) } } - <-doneChan - if !write { - serverConn.Close() + // If the server sent us an alert after our last flight, give it a + // chance to arrive. + if write && test.renegotiationExpectedToFail == 0 { + if err := peekError(client); err != nil { + t.Errorf("final Read returned an error: %s", err) + } } if write { + clientConn.Close() path := test.dataPath() out, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644) if err != nil { diff --git a/handshake_server_test.go b/handshake_server_test.go index 44bc8f1..94d3d0f 100644 --- a/handshake_server_test.go +++ b/handshake_server_test.go @@ -21,6 +21,7 @@ import ( "os/exec" "path/filepath" "runtime" + "slices" "strings" "testing" "time" @@ -659,7 +660,7 @@ func (test *serverTest) loadData() (flows [][]byte, err error) { } func (test *serverTest) run(t *testing.T, write bool) { - var clientConn, serverConn net.Conn + var serverConn net.Conn var recordingConn *recordingConn var childProcess *exec.Cmd @@ -676,65 +677,33 @@ func (test *serverTest) run(t *testing.T, write bool) { } }() } else { - clientConn, serverConn = localPipe(t) + flows, err := test.loadData() + if err != nil { + t.Fatalf("Failed to load data from %s", test.dataPath()) + } + serverConn = &replayingConn{t: t, flows: flows, reading: true} } config := test.config if config == nil { config = testConfig } server := Server(serverConn, config) - connStateChan := make(chan ConnectionState, 1) - go func() { - _, err := server.Write([]byte("hello, world\n")) - if len(test.expectHandshakeErrorIncluding) > 0 { - if err == nil { - t.Errorf("Error expected, but no error returned") - } else if s := err.Error(); !strings.Contains(s, test.expectHandshakeErrorIncluding) { - t.Errorf("Error expected containing '%s' but got '%s'", test.expectHandshakeErrorIncluding, s) - } - } else { - if err != nil { - t.Logf("Error from Server.Write: '%s'", err) - } - } - server.Close() - serverConn.Close() - connStateChan <- server.ConnectionState() - }() - if !write { - flows, err := test.loadData() + _, err := server.Write([]byte("hello, world\n")) + if len(test.expectHandshakeErrorIncluding) > 0 { + if err == nil { + t.Errorf("Error expected, but no error returned") + } else if s := err.Error(); !strings.Contains(s, test.expectHandshakeErrorIncluding) { + t.Errorf("Error expected containing '%s' but got '%s'", test.expectHandshakeErrorIncluding, s) + } + } else { if err != nil { - t.Fatalf("%s: failed to load data from %s", test.name, test.dataPath()) + t.Logf("Error from Server.Write: '%s'", err) } - for i, b := range flows { - if i%2 == 0 { - if *fast { - clientConn.SetWriteDeadline(time.Now().Add(1 * time.Second)) - } else { - clientConn.SetWriteDeadline(time.Now().Add(1 * time.Minute)) - } - clientConn.Write(b) - continue - } - bb := make([]byte, len(b)) - if *fast { - clientConn.SetReadDeadline(time.Now().Add(1 * time.Second)) - } else { - clientConn.SetReadDeadline(time.Now().Add(1 * time.Minute)) - } - n, err := io.ReadFull(clientConn, bb) - if err != nil { - t.Fatalf("%s #%d: %s\nRead %d, wanted %d, got %x, wanted %x\n", test.name, i+1, err, n, len(bb), bb[:n], b) - } - if !bytes.Equal(b, bb) { - t.Fatalf("%s #%d: mismatch on read: got:%x want:%x", test.name, i+1, bb, b) - } - } - clientConn.Close() } + server.Close() - connState := <-connStateChan + connState := server.ConnectionState() peerCerts := connState.PeerCertificates if len(peerCerts) == len(test.expectedPeerCerts) { for i, peerCert := range peerCerts { @@ -754,6 +723,7 @@ func (test *serverTest) run(t *testing.T, write bool) { } if write { + serverConn.Close() path := test.dataPath() out, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644) if err != nil { @@ -1330,37 +1300,14 @@ func benchmarkHandshakeServer(b *testing.B, version uint16, cipherSuite uint16, serverConn.Close() flows := serverConn.(*recordingConn).flows - feeder := make(chan struct{}) - clientConn, serverConn = localPipe(b) - - go func() { - for range feeder { - for i, f := range flows { - if i%2 == 0 { - clientConn.Write(f) - continue - } - ff := make([]byte, len(f)) - n, err := io.ReadFull(clientConn, ff) - if err != nil { - b.Errorf("#%d: %s\nRead %d, wanted %d, got %x, wanted %x\n", i+1, err, n, len(ff), ff[:n], f) - } - if !bytes.Equal(f, ff) { - b.Errorf("#%d: mismatch on read: got:%x want:%x", i+1, ff, f) - } - } - } - }() - b.ResetTimer() for i := 0; i < b.N; i++ { - feeder <- struct{}{} - server := Server(serverConn, config) + replay := &replayingConn{t: b, flows: slices.Clone(flows), reading: true} + server := Server(replay, config) if err := server.Handshake(); err != nil { b.Fatalf("handshake failed: %v", err) } } - close(feeder) } func BenchmarkHandshakeServer(b *testing.B) { diff --git a/handshake_test.go b/handshake_test.go index 57fc761..bc3d23d 100644 --- a/handshake_test.go +++ b/handshake_test.go @@ -6,6 +6,7 @@ package tls import ( "bufio" + "bytes" "crypto/ed25519" "crypto/x509" "encoding/hex" @@ -42,7 +43,6 @@ import ( var ( update = flag.Bool("update", false, "update golden files on failure") - fast = flag.Bool("fast", false, "impose a quick, possibly flaky timeout on recorded tests") keyFile = flag.String("keylog", "", "destination file for KeyLogWriter") bogoMode = flag.Bool("bogo-mode", false, "Enabled bogo shim mode, ignore everything else") bogoFilter = flag.String("bogo-filter", "", "BoGo test filter") @@ -223,6 +223,76 @@ func parseTestData(r io.Reader) (flows [][]byte, err error) { return flows, nil } +// replayingConn is a net.Conn that replays flows recorded by recordingConn. +type replayingConn struct { + t testing.TB + sync.Mutex + flows [][]byte + reading bool +} + +var _ net.Conn = (*replayingConn)(nil) + +func (r *replayingConn) Read(b []byte) (n int, err error) { + r.Lock() + defer r.Unlock() + + if !r.reading { + r.t.Errorf("expected write, got read") + return 0, fmt.Errorf("recording expected write, got read") + } + + n = copy(b, r.flows[0]) + r.flows[0] = r.flows[0][n:] + if len(r.flows[0]) == 0 { + r.flows = r.flows[1:] + if len(r.flows) == 0 { + return n, io.EOF + } else { + r.reading = false + } + } + return n, nil +} + +func (r *replayingConn) Write(b []byte) (n int, err error) { + r.Lock() + defer r.Unlock() + + if r.reading { + r.t.Errorf("expected read, got write") + return 0, fmt.Errorf("recording expected read, got write") + } + + if !bytes.HasPrefix(r.flows[0], b) { + r.t.Errorf("write mismatch: expected %x, got %x", r.flows[0], b) + return 0, fmt.Errorf("write mismatch") + } + r.flows[0] = r.flows[0][len(b):] + if len(r.flows[0]) == 0 { + r.flows = r.flows[1:] + r.reading = true + } + return len(b), nil +} + +func (r *replayingConn) Close() error { + r.Lock() + defer r.Unlock() + + if len(r.flows) > 0 { + r.t.Errorf("closed with unfinished flows") + return fmt.Errorf("unexpected close") + } + return nil +} + +func (r *replayingConn) LocalAddr() net.Addr { return nil } +func (r *replayingConn) RemoteAddr() net.Addr { return nil } +func (r *replayingConn) SetDeadline(t time.Time) error { return nil } +func (r *replayingConn) SetReadDeadline(t time.Time) error { return nil } +func (r *replayingConn) SetWriteDeadline(t time.Time) error { return nil } + // tempFile creates a temp file containing contents and returns its path. func tempFile(contents string) string { file, err := os.CreateTemp("", "go-tls-test")