diff --git a/client.go b/client.go index fe44c8a..9825f03 100644 --- a/client.go +++ b/client.go @@ -42,7 +42,7 @@ func NewClient(config ClientConfig) (*Client, error) { tlsHandshake: config.TLSHandshake, logger: config.Logger, } - if !client.server.IsValid() || client.dialer == nil || client.tlsHandshake == nil { + if !client.server.IsValid() || client.dialer == nil { return nil, os.ErrInvalid } switch client.version { @@ -56,11 +56,18 @@ func NewClient(config ClientConfig) (*Client, error) { return client, nil } +func (c *Client) SetHandshakeFunc(handshakeFunc TLSHandshakeFunc) { + c.tlsHandshake = handshakeFunc +} + func (c *Client) DialContext(ctx context.Context) (net.Conn, error) { conn, err := c.dialer.DialContext(ctx, N.NetworkTCP, c.server) if err != nil { return nil, err } + if c.tlsHandshake == nil { + return nil, os.ErrInvalid + } switch c.version { default: fallthrough