diff --git a/tls_test.go b/tls_test.go index d8a43ad..3e43a56 100644 --- a/tls_test.go +++ b/tls_test.go @@ -170,35 +170,59 @@ func TestDialTimeout(t *testing.T) { if testing.Short() { t.Skip("skipping in short mode") } - listener := newLocalListener(t) - addr := listener.Addr().String() - defer listener.Close() + timeout := 100 * time.Microsecond + for !t.Failed() { + acceptc := make(chan net.Conn) + listener := newLocalListener(t) + go func() { + for { + conn, err := listener.Accept() + if err != nil { + close(acceptc) + return + } + acceptc <- conn + } + }() - complete := make(chan bool) - defer close(complete) - - go func() { - conn, err := listener.Accept() - if err != nil { - t.Error(err) - return + addr := listener.Addr().String() + dialer := &net.Dialer{ + Timeout: timeout, + } + if conn, err := DialWithDialer(dialer, "tcp", addr, nil); err == nil { + conn.Close() + t.Errorf("DialWithTimeout unexpectedly completed successfully") + } else if !isTimeoutError(err) { + t.Errorf("resulting error not a timeout: %v\nType %T: %#v", err, err, err) } - <-complete - conn.Close() - }() - dialer := &net.Dialer{ - Timeout: 10 * time.Millisecond, - } + listener.Close() - var err error - if _, err = DialWithDialer(dialer, "tcp", addr, nil); err == nil { - t.Fatal("DialWithTimeout completed successfully") - } + // We're looking for a timeout during the handshake, so check that the + // Listener actually accepted the connection to initiate it. (If the server + // takes too long to accept the connection, we might cancel before the + // underlying net.Conn is ever dialed — without ever attempting a + // handshake.) + lconn, ok := <-acceptc + if ok { + // The Listener accepted a connection, so assume that it was from our + // Dial: we triggered the timeout at the point where we wanted it! + t.Logf("Listener accepted a connection from %s", lconn.RemoteAddr()) + lconn.Close() + } + // Close any spurious extra connecitions from the listener. (This is + // possible if there are, for example, stray Dial calls from other tests.) + for extraConn := range acceptc { + t.Logf("spurious extra connection from %s", extraConn.RemoteAddr()) + extraConn.Close() + } + if ok { + break + } - if !isTimeoutError(err) { - t.Errorf("resulting error not a timeout: %v\nType %T: %#v", err, err, err) + t.Logf("with timeout %v, DialWithDialer returned before listener accepted any connections; retrying", timeout) + timeout *= 2 } }