From dac782ca098e35f820f129e6535a19cb310fceb7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Tue, 21 Feb 2023 21:05:04 +0800 Subject: [PATCH] Make server address optional --- client.go | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/client.go b/client.go index 9825f03..141be0f 100644 --- a/client.go +++ b/client.go @@ -42,9 +42,7 @@ func NewClient(config ClientConfig) (*Client, error) { tlsHandshake: config.TLSHandshake, logger: config.Logger, } - if !client.server.IsValid() || client.dialer == nil { - return nil, os.ErrInvalid - } + switch client.version { case 1, 2, 3: default: @@ -61,10 +59,17 @@ func (c *Client) SetHandshakeFunc(handshakeFunc TLSHandshakeFunc) { } func (c *Client) DialContext(ctx context.Context) (net.Conn, error) { + if !c.server.IsValid() { + return nil, os.ErrInvalid + } conn, err := c.dialer.DialContext(ctx, N.NetworkTCP, c.server) if err != nil { return nil, err } + return c.DialContextConn(ctx, conn) +} + +func (c *Client) DialContextConn(ctx context.Context, conn net.Conn) (net.Conn, error) { if c.tlsHandshake == nil { return nil, os.ErrInvalid } @@ -72,7 +77,7 @@ func (c *Client) DialContext(ctx context.Context) (net.Conn, error) { default: fallthrough case 1: - err = c.tlsHandshake(ctx, conn, nil) + err := c.tlsHandshake(ctx, conn, nil) if err != nil { return nil, err } @@ -80,7 +85,7 @@ func (c *Client) DialContext(ctx context.Context) (net.Conn, error) { return conn, nil case 2: hashConn := newHashReadConn(conn, c.password) - err = c.tlsHandshake(ctx, hashConn, nil) + err := c.tlsHandshake(ctx, hashConn, nil) if err != nil { return nil, err } @@ -88,7 +93,7 @@ func (c *Client) DialContext(ctx context.Context) (net.Conn, error) { return newClientConn(hashConn), nil case 3: stream := newStreamWrapper(conn, c.password) - err = c.tlsHandshake(ctx, stream, generateSessionID(c.password)) + err := c.tlsHandshake(ctx, stream, generateSessionID(c.password)) if err != nil { return nil, err }