mirror of
https://github.com/refraction-networking/utls.git
synced 2025-04-04 12:37:35 +03:00
[dev.boringcrypto] all: merge master into dev.boringcrypto
Change-Id: Idd59c37d2fd759b0f73d2ee01b30f72ef4e9aee8
This commit is contained in:
commit
8dbc9ce040
6 changed files with 354 additions and 30 deletions
36
conn.go
36
conn.go
|
@ -24,8 +24,9 @@ import (
|
||||||
// It implements the net.Conn interface.
|
// It implements the net.Conn interface.
|
||||||
type Conn struct {
|
type Conn struct {
|
||||||
// constant
|
// constant
|
||||||
conn net.Conn
|
conn net.Conn
|
||||||
isClient bool
|
isClient bool
|
||||||
|
handshakeFn func() error // (*Conn).clientHandshake or serverHandshake
|
||||||
|
|
||||||
// handshakeStatus is 1 if the connection is currently transferring
|
// handshakeStatus is 1 if the connection is currently transferring
|
||||||
// application data (i.e. is not currently processing a handshake).
|
// application data (i.e. is not currently processing a handshake).
|
||||||
|
@ -162,9 +163,22 @@ type halfConn struct {
|
||||||
trafficSecret []byte // current TLS 1.3 traffic secret
|
trafficSecret []byte // current TLS 1.3 traffic secret
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type permamentError struct {
|
||||||
|
err net.Error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *permamentError) Error() string { return e.err.Error() }
|
||||||
|
func (e *permamentError) Unwrap() error { return e.err }
|
||||||
|
func (e *permamentError) Timeout() bool { return e.err.Timeout() }
|
||||||
|
func (e *permamentError) Temporary() bool { return false }
|
||||||
|
|
||||||
func (hc *halfConn) setErrorLocked(err error) error {
|
func (hc *halfConn) setErrorLocked(err error) error {
|
||||||
hc.err = err
|
if e, ok := err.(net.Error); ok {
|
||||||
return err
|
hc.err = &permamentError{err: e}
|
||||||
|
} else {
|
||||||
|
hc.err = err
|
||||||
|
}
|
||||||
|
return hc.err
|
||||||
}
|
}
|
||||||
|
|
||||||
// prepareCipherSpec sets the encryption and MAC states
|
// prepareCipherSpec sets the encryption and MAC states
|
||||||
|
@ -1320,8 +1334,12 @@ func (c *Conn) closeNotify() error {
|
||||||
|
|
||||||
// Handshake runs the client or server handshake
|
// Handshake runs the client or server handshake
|
||||||
// protocol if it has not yet been run.
|
// protocol if it has not yet been run.
|
||||||
// Most uses of this package need not call Handshake
|
//
|
||||||
// explicitly: the first Read or Write will call it automatically.
|
// Most uses of this package need not call Handshake explicitly: the
|
||||||
|
// first Read or Write will call it automatically.
|
||||||
|
//
|
||||||
|
// For control over canceling or setting a timeout on a handshake, use
|
||||||
|
// the Dialer's DialContext method.
|
||||||
func (c *Conn) Handshake() error {
|
func (c *Conn) Handshake() error {
|
||||||
c.handshakeMutex.Lock()
|
c.handshakeMutex.Lock()
|
||||||
defer c.handshakeMutex.Unlock()
|
defer c.handshakeMutex.Unlock()
|
||||||
|
@ -1336,11 +1354,7 @@ func (c *Conn) Handshake() error {
|
||||||
c.in.Lock()
|
c.in.Lock()
|
||||||
defer c.in.Unlock()
|
defer c.in.Unlock()
|
||||||
|
|
||||||
if c.isClient {
|
c.handshakeErr = c.handshakeFn()
|
||||||
c.handshakeErr = c.clientHandshake()
|
|
||||||
} else {
|
|
||||||
c.handshakeErr = c.serverHandshake()
|
|
||||||
}
|
|
||||||
if c.handshakeErr == nil {
|
if c.handshakeErr == nil {
|
||||||
c.handshakes++
|
c.handshakes++
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -355,7 +355,8 @@ func TestAlertForwarding(t *testing.T) {
|
||||||
|
|
||||||
err := Server(s, testConfig).Handshake()
|
err := Server(s, testConfig).Handshake()
|
||||||
s.Close()
|
s.Close()
|
||||||
if e, ok := err.(*net.OpError); !ok || e.Err != error(alertUnknownCA) {
|
var opErr *net.OpError
|
||||||
|
if !errors.As(err, &opErr) || opErr.Err != error(alertUnknownCA) {
|
||||||
t.Errorf("Got error: %s; expected: %s", err, error(alertUnknownCA))
|
t.Errorf("Got error: %s; expected: %s", err, error(alertUnknownCA))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -173,11 +173,8 @@ func (p *nistParameters) SharedKey(peerPublicKey []byte) []byte {
|
||||||
}
|
}
|
||||||
|
|
||||||
xShared, _ := curve.ScalarMult(x, y, p.privateKey)
|
xShared, _ := curve.ScalarMult(x, y, p.privateKey)
|
||||||
sharedKey := make([]byte, (curve.Params().BitSize+7)>>3)
|
sharedKey := make([]byte, (curve.Params().BitSize+7)/8)
|
||||||
xBytes := xShared.Bytes()
|
return xShared.FillBytes(sharedKey)
|
||||||
copy(sharedKey[len(sharedKey)-len(xBytes):], xBytes)
|
|
||||||
|
|
||||||
return sharedKey
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type x25519Parameters struct {
|
type x25519Parameters struct {
|
||||||
|
|
121
link_test.go
Normal file
121
link_test.go
Normal file
|
@ -0,0 +1,121 @@
|
||||||
|
// Copyright 2020 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package tls
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"internal/testenv"
|
||||||
|
"io/ioutil"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Tests that the linker is able to remove references to the Client or Server if unused.
|
||||||
|
func TestLinkerGC(t *testing.T) {
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("skipping in short mode")
|
||||||
|
}
|
||||||
|
t.Parallel()
|
||||||
|
goBin := testenv.GoToolPath(t)
|
||||||
|
testenv.MustHaveGoBuild(t)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
program string
|
||||||
|
want []string
|
||||||
|
bad []string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "empty_import",
|
||||||
|
program: `package main
|
||||||
|
import _ "crypto/tls"
|
||||||
|
func main() {}
|
||||||
|
`,
|
||||||
|
bad: []string{
|
||||||
|
"tls.(*Conn)",
|
||||||
|
"type.crypto/tls.clientHandshakeState",
|
||||||
|
"type.crypto/tls.serverHandshakeState",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "only_conn",
|
||||||
|
program: `package main
|
||||||
|
import "crypto/tls"
|
||||||
|
var c = new(tls.Conn)
|
||||||
|
func main() {}
|
||||||
|
`,
|
||||||
|
want: []string{"tls.(*Conn)"},
|
||||||
|
bad: []string{
|
||||||
|
"type.crypto/tls.clientHandshakeState",
|
||||||
|
"type.crypto/tls.serverHandshakeState",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "client_and_server",
|
||||||
|
program: `package main
|
||||||
|
import "crypto/tls"
|
||||||
|
func main() {
|
||||||
|
tls.Dial("", "", nil)
|
||||||
|
tls.Server(nil, nil)
|
||||||
|
}
|
||||||
|
`,
|
||||||
|
want: []string{
|
||||||
|
"crypto/tls.(*Conn).clientHandshake",
|
||||||
|
"crypto/tls.(*Conn).serverHandshake",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "only_client",
|
||||||
|
program: `package main
|
||||||
|
import "crypto/tls"
|
||||||
|
func main() { tls.Dial("", "", nil) }
|
||||||
|
`,
|
||||||
|
want: []string{
|
||||||
|
"crypto/tls.(*Conn).clientHandshake",
|
||||||
|
},
|
||||||
|
bad: []string{
|
||||||
|
"crypto/tls.(*Conn).serverHandshake",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
// TODO: add only_server like func main() { tls.Server(nil, nil) }
|
||||||
|
// That currently brings in the client via Conn.handleRenegotiation.
|
||||||
|
|
||||||
|
}
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
goFile := filepath.Join(tmpDir, "x.go")
|
||||||
|
exeFile := filepath.Join(tmpDir, "x.exe")
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if err := ioutil.WriteFile(goFile, []byte(tt.program), 0644); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
os.Remove(exeFile)
|
||||||
|
cmd := exec.Command(goBin, "build", "-o", "x.exe", "x.go")
|
||||||
|
cmd.Dir = tmpDir
|
||||||
|
if out, err := cmd.CombinedOutput(); err != nil {
|
||||||
|
t.Fatalf("compile: %v, %s", err, out)
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd = exec.Command(goBin, "tool", "nm", "x.exe")
|
||||||
|
cmd.Dir = tmpDir
|
||||||
|
nm, err := cmd.CombinedOutput()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("nm: %v, %s", err, nm)
|
||||||
|
}
|
||||||
|
for _, sym := range tt.want {
|
||||||
|
if !bytes.Contains(nm, []byte(sym)) {
|
||||||
|
t.Errorf("expected symbol %q not found", sym)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, sym := range tt.bad {
|
||||||
|
if bytes.Contains(nm, []byte(sym)) {
|
||||||
|
t.Errorf("unexpected symbol %q found", sym)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
104
tls.go
104
tls.go
|
@ -13,6 +13,7 @@ package tls
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"crypto"
|
"crypto"
|
||||||
"crypto/ecdsa"
|
"crypto/ecdsa"
|
||||||
"crypto/ed25519"
|
"crypto/ed25519"
|
||||||
|
@ -32,7 +33,12 @@ import (
|
||||||
// The configuration config must be non-nil and must include
|
// The configuration config must be non-nil and must include
|
||||||
// at least one certificate or else set GetCertificate.
|
// at least one certificate or else set GetCertificate.
|
||||||
func Server(conn net.Conn, config *Config) *Conn {
|
func Server(conn net.Conn, config *Config) *Conn {
|
||||||
return &Conn{conn: conn, config: config}
|
c := &Conn{
|
||||||
|
conn: conn,
|
||||||
|
config: config,
|
||||||
|
}
|
||||||
|
c.handshakeFn = c.serverHandshake
|
||||||
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
// Client returns a new TLS client side connection
|
// Client returns a new TLS client side connection
|
||||||
|
@ -40,7 +46,13 @@ func Server(conn net.Conn, config *Config) *Conn {
|
||||||
// The config cannot be nil: users must set either ServerName or
|
// The config cannot be nil: users must set either ServerName or
|
||||||
// InsecureSkipVerify in the config.
|
// InsecureSkipVerify in the config.
|
||||||
func Client(conn net.Conn, config *Config) *Conn {
|
func Client(conn net.Conn, config *Config) *Conn {
|
||||||
return &Conn{conn: conn, config: config, isClient: true}
|
c := &Conn{
|
||||||
|
conn: conn,
|
||||||
|
config: config,
|
||||||
|
isClient: true,
|
||||||
|
}
|
||||||
|
c.handshakeFn = c.clientHandshake
|
||||||
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
// A listener implements a network listener (net.Listener) for TLS connections.
|
// A listener implements a network listener (net.Listener) for TLS connections.
|
||||||
|
@ -100,29 +112,35 @@ func (timeoutError) Temporary() bool { return true }
|
||||||
// DialWithDialer interprets a nil configuration as equivalent to the zero
|
// DialWithDialer interprets a nil configuration as equivalent to the zero
|
||||||
// configuration; see the documentation of Config for the defaults.
|
// configuration; see the documentation of Config for the defaults.
|
||||||
func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config) (*Conn, error) {
|
func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config) (*Conn, error) {
|
||||||
|
return dial(context.Background(), dialer, network, addr, config)
|
||||||
|
}
|
||||||
|
|
||||||
|
func dial(ctx context.Context, netDialer *net.Dialer, network, addr string, config *Config) (*Conn, error) {
|
||||||
// We want the Timeout and Deadline values from dialer to cover the
|
// We want the Timeout and Deadline values from dialer to cover the
|
||||||
// whole process: TCP connection and TLS handshake. This means that we
|
// whole process: TCP connection and TLS handshake. This means that we
|
||||||
// also need to start our own timers now.
|
// also need to start our own timers now.
|
||||||
timeout := dialer.Timeout
|
timeout := netDialer.Timeout
|
||||||
|
|
||||||
if !dialer.Deadline.IsZero() {
|
if !netDialer.Deadline.IsZero() {
|
||||||
deadlineTimeout := time.Until(dialer.Deadline)
|
deadlineTimeout := time.Until(netDialer.Deadline)
|
||||||
if timeout == 0 || deadlineTimeout < timeout {
|
if timeout == 0 || deadlineTimeout < timeout {
|
||||||
timeout = deadlineTimeout
|
timeout = deadlineTimeout
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var errChannel chan error
|
// hsErrCh is non-nil if we might not wait for Handshake to complete.
|
||||||
|
var hsErrCh chan error
|
||||||
|
if timeout != 0 || ctx.Done() != nil {
|
||||||
|
hsErrCh = make(chan error, 2)
|
||||||
|
}
|
||||||
if timeout != 0 {
|
if timeout != 0 {
|
||||||
errChannel = make(chan error, 2)
|
|
||||||
timer := time.AfterFunc(timeout, func() {
|
timer := time.AfterFunc(timeout, func() {
|
||||||
errChannel <- timeoutError{}
|
hsErrCh <- timeoutError{}
|
||||||
})
|
})
|
||||||
defer timer.Stop()
|
defer timer.Stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
rawConn, err := dialer.Dial(network, addr)
|
rawConn, err := netDialer.DialContext(ctx, network, addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -147,14 +165,26 @@ func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config) (*
|
||||||
|
|
||||||
conn := Client(rawConn, config)
|
conn := Client(rawConn, config)
|
||||||
|
|
||||||
if timeout == 0 {
|
if hsErrCh == nil {
|
||||||
err = conn.Handshake()
|
err = conn.Handshake()
|
||||||
} else {
|
} else {
|
||||||
go func() {
|
go func() {
|
||||||
errChannel <- conn.Handshake()
|
hsErrCh <- conn.Handshake()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
err = <-errChannel
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
err = ctx.Err()
|
||||||
|
case err = <-hsErrCh:
|
||||||
|
if err != nil {
|
||||||
|
// If the error was due to the context
|
||||||
|
// closing, prefer the context's error, rather
|
||||||
|
// than some random network teardown error.
|
||||||
|
if e := ctx.Err(); e != nil {
|
||||||
|
err = e
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -175,6 +205,54 @@ func Dial(network, addr string, config *Config) (*Conn, error) {
|
||||||
return DialWithDialer(new(net.Dialer), network, addr, config)
|
return DialWithDialer(new(net.Dialer), network, addr, config)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Dialer dials TLS connections given a configuration and a Dialer for the
|
||||||
|
// underlying connection.
|
||||||
|
type Dialer struct {
|
||||||
|
// NetDialer is the optional dialer to use for the TLS connections'
|
||||||
|
// underlying TCP connections.
|
||||||
|
// A nil NetDialer is equivalent to the net.Dialer zero value.
|
||||||
|
NetDialer *net.Dialer
|
||||||
|
|
||||||
|
// Config is the TLS configuration to use for new connections.
|
||||||
|
// A nil configuration is equivalent to the zero
|
||||||
|
// configuration; see the documentation of Config for the
|
||||||
|
// defaults.
|
||||||
|
Config *Config
|
||||||
|
}
|
||||||
|
|
||||||
|
// Dial connects to the given network address and initiates a TLS
|
||||||
|
// handshake, returning the resulting TLS connection.
|
||||||
|
//
|
||||||
|
// The returned Conn, if any, will always be of type *Conn.
|
||||||
|
func (d *Dialer) Dial(network, addr string) (net.Conn, error) {
|
||||||
|
return d.DialContext(context.Background(), network, addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Dialer) netDialer() *net.Dialer {
|
||||||
|
if d.NetDialer != nil {
|
||||||
|
return d.NetDialer
|
||||||
|
}
|
||||||
|
return new(net.Dialer)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Dial connects to the given network address and initiates a TLS
|
||||||
|
// handshake, returning the resulting TLS connection.
|
||||||
|
//
|
||||||
|
// The provided Context must be non-nil. If the context expires before
|
||||||
|
// the connection is complete, an error is returned. Once successfully
|
||||||
|
// connected, any expiration of the context will not affect the
|
||||||
|
// connection.
|
||||||
|
//
|
||||||
|
// The returned Conn, if any, will always be of type *Conn.
|
||||||
|
func (d *Dialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||||
|
c, err := dial(ctx, d.netDialer(), network, addr, d.Config)
|
||||||
|
if err != nil {
|
||||||
|
// Don't return c (a typed nil) in an interface.
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
|
||||||
// LoadX509KeyPair reads and parses a public/private key pair from a pair
|
// LoadX509KeyPair reads and parses a public/private key pair from a pair
|
||||||
// of files. The files must contain PEM encoded data. The certificate file
|
// of files. The files must contain PEM encoded data. The certificate file
|
||||||
// may contain intermediate certificates following the leaf certificate to
|
// may contain intermediate certificates following the leaf certificate to
|
||||||
|
|
113
tls_test.go
113
tls_test.go
|
@ -6,6 +6,7 @@ package tls
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"crypto"
|
"crypto"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
@ -201,6 +202,118 @@ func TestDialTimeout(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDeadlineOnWrite(t *testing.T) {
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("skipping in short mode")
|
||||||
|
}
|
||||||
|
|
||||||
|
ln := newLocalListener(t)
|
||||||
|
defer ln.Close()
|
||||||
|
|
||||||
|
srvCh := make(chan *Conn, 1)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
sconn, err := ln.Accept()
|
||||||
|
if err != nil {
|
||||||
|
srvCh <- nil
|
||||||
|
return
|
||||||
|
}
|
||||||
|
srv := Server(sconn, testConfig.Clone())
|
||||||
|
if err := srv.Handshake(); err != nil {
|
||||||
|
srvCh <- nil
|
||||||
|
return
|
||||||
|
}
|
||||||
|
srvCh <- srv
|
||||||
|
}()
|
||||||
|
|
||||||
|
clientConfig := testConfig.Clone()
|
||||||
|
clientConfig.MaxVersion = VersionTLS12
|
||||||
|
conn, err := Dial("tcp", ln.Addr().String(), clientConfig)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
srv := <-srvCh
|
||||||
|
if srv == nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Make sure the client/server is setup correctly and is able to do a typical Write/Read
|
||||||
|
buf := make([]byte, 6)
|
||||||
|
if _, err := srv.Write([]byte("foobar")); err != nil {
|
||||||
|
t.Errorf("Write err: %v", err)
|
||||||
|
}
|
||||||
|
if n, err := conn.Read(buf); n != 6 || err != nil || string(buf) != "foobar" {
|
||||||
|
t.Errorf("Read = %d, %v, data %q; want 6, nil, foobar", n, err, buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set a deadline which should cause Write to timeout
|
||||||
|
if err = srv.SetDeadline(time.Now()); err != nil {
|
||||||
|
t.Fatalf("SetDeadline(time.Now()) err: %v", err)
|
||||||
|
}
|
||||||
|
if _, err = srv.Write([]byte("should fail")); err == nil {
|
||||||
|
t.Fatal("Write should have timed out")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clear deadline and make sure it still times out
|
||||||
|
if err = srv.SetDeadline(time.Time{}); err != nil {
|
||||||
|
t.Fatalf("SetDeadline(time.Time{}) err: %v", err)
|
||||||
|
}
|
||||||
|
if _, err = srv.Write([]byte("This connection is permanently broken")); err == nil {
|
||||||
|
t.Fatal("Write which previously failed should still time out")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the error
|
||||||
|
if ne := err.(net.Error); ne.Temporary() != false {
|
||||||
|
t.Error("Write timed out but incorrectly classified the error as Temporary")
|
||||||
|
}
|
||||||
|
if !isTimeoutError(err) {
|
||||||
|
t.Error("Write timed out but did not classify the error as a Timeout")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type readerFunc func([]byte) (int, error)
|
||||||
|
|
||||||
|
func (f readerFunc) Read(b []byte) (int, error) { return f(b) }
|
||||||
|
|
||||||
|
// TestDialer tests that tls.Dialer.DialContext can abort in the middle of a handshake.
|
||||||
|
// (The other cases are all handled by the existing dial tests in this package, which
|
||||||
|
// all also flow through the same code shared code paths)
|
||||||
|
func TestDialer(t *testing.T) {
|
||||||
|
ln := newLocalListener(t)
|
||||||
|
defer ln.Close()
|
||||||
|
|
||||||
|
unblockServer := make(chan struct{}) // close-only
|
||||||
|
defer close(unblockServer)
|
||||||
|
go func() {
|
||||||
|
conn, err := ln.Accept()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
<-unblockServer
|
||||||
|
}()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
d := Dialer{Config: &Config{
|
||||||
|
Rand: readerFunc(func(b []byte) (n int, err error) {
|
||||||
|
// By the time crypto/tls wants randomness, that means it has a TCP
|
||||||
|
// connection, so we're past the Dialer's dial and now blocked
|
||||||
|
// in a handshake. Cancel our context and see if we get unstuck.
|
||||||
|
// (Our TCP listener above never reads or writes, so the Handshake
|
||||||
|
// would otherwise be stuck forever)
|
||||||
|
cancel()
|
||||||
|
return len(b), nil
|
||||||
|
}),
|
||||||
|
ServerName: "foo",
|
||||||
|
}}
|
||||||
|
_, err := d.DialContext(ctx, "tcp", ln.Addr().String())
|
||||||
|
if err != context.Canceled {
|
||||||
|
t.Errorf("err = %v; want context.Canceled", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func isTimeoutError(err error) bool {
|
func isTimeoutError(err error) bool {
|
||||||
if ne, ok := err.(net.Error); ok {
|
if ne, ok := err.(net.Error); ok {
|
||||||
return ne.Timeout()
|
return ne.Timeout()
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue