mirror of
https://github.com/SagerNet/sing-shadowtls.git
synced 2025-03-31 10:47:35 +03:00
126 lines
3.1 KiB
Go
126 lines
3.1 KiB
Go
package shadowtls
|
|
|
|
import (
|
|
"context"
|
|
"crypto/hmac"
|
|
"crypto/sha1"
|
|
"encoding/hex"
|
|
"net"
|
|
"os"
|
|
|
|
"github.com/sagernet/sing/common/debug"
|
|
E "github.com/sagernet/sing/common/exceptions"
|
|
"github.com/sagernet/sing/common/logger"
|
|
M "github.com/sagernet/sing/common/metadata"
|
|
N "github.com/sagernet/sing/common/network"
|
|
)
|
|
|
|
type ClientConfig struct {
|
|
Version int
|
|
Password string
|
|
Server M.Socksaddr
|
|
Dialer N.Dialer
|
|
StrictMode bool
|
|
TLSHandshake TLSHandshakeFunc
|
|
Logger logger.ContextLogger
|
|
}
|
|
|
|
type Client struct {
|
|
version int
|
|
password string
|
|
strictMode bool
|
|
server M.Socksaddr
|
|
dialer N.Dialer
|
|
tlsHandshake TLSHandshakeFunc
|
|
logger logger.ContextLogger
|
|
}
|
|
|
|
func NewClient(config ClientConfig) (*Client, error) {
|
|
client := &Client{
|
|
version: config.Version,
|
|
password: config.Password,
|
|
strictMode: config.StrictMode,
|
|
server: config.Server,
|
|
dialer: config.Dialer,
|
|
tlsHandshake: config.TLSHandshake,
|
|
logger: config.Logger,
|
|
}
|
|
|
|
switch client.version {
|
|
case 1, 2, 3:
|
|
default:
|
|
return nil, E.New("unknown protocol version: ", client.version)
|
|
}
|
|
if client.dialer == nil {
|
|
client.dialer = N.SystemDialer
|
|
}
|
|
return client, nil
|
|
}
|
|
|
|
func (c *Client) SetHandshakeFunc(handshakeFunc TLSHandshakeFunc) {
|
|
c.tlsHandshake = handshakeFunc
|
|
}
|
|
|
|
func (c *Client) DialContext(ctx context.Context) (net.Conn, error) {
|
|
if !c.server.IsValid() {
|
|
return nil, os.ErrInvalid
|
|
}
|
|
conn, err := c.dialer.DialContext(ctx, N.NetworkTCP, c.server)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
shadowTLSConn, err := c.DialContextConn(ctx, conn)
|
|
if err != nil {
|
|
conn.Close()
|
|
return nil, err
|
|
}
|
|
return shadowTLSConn, nil
|
|
}
|
|
|
|
func (c *Client) DialContextConn(ctx context.Context, conn net.Conn) (net.Conn, error) {
|
|
if c.tlsHandshake == nil {
|
|
return nil, os.ErrInvalid
|
|
}
|
|
switch c.version {
|
|
default:
|
|
fallthrough
|
|
case 1:
|
|
err := c.tlsHandshake(ctx, conn, nil)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
c.logger.TraceContext(ctx, "clint handshake finished")
|
|
return conn, nil
|
|
case 2:
|
|
hashConn := newHashReadConn(conn, c.password)
|
|
err := c.tlsHandshake(ctx, hashConn, nil)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
c.logger.TraceContext(ctx, "clint handshake finished")
|
|
return newClientConn(hashConn), nil
|
|
case 3:
|
|
stream := newStreamWrapper(conn, c.password)
|
|
err := c.tlsHandshake(ctx, stream, generateSessionID(c.password))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
c.logger.TraceContext(ctx, "handshake success")
|
|
isTLS13, authorized, serverRandom, readHMAC := stream.Authorized()
|
|
if c.strictMode && !isTLS13 {
|
|
return nil, E.New("TLS1.3 is not supported")
|
|
} else if !authorized {
|
|
return nil, E.New("traffic hijacked")
|
|
}
|
|
if debug.Enabled {
|
|
c.logger.TraceContext(ctx, "authorized, server random extracted: ", hex.EncodeToString(serverRandom))
|
|
}
|
|
hmacAdd := hmac.New(sha1.New, []byte(c.password))
|
|
hmacAdd.Write(serverRandom)
|
|
hmacAdd.Write([]byte("C"))
|
|
hmacVerify := hmac.New(sha1.New, []byte(c.password))
|
|
hmacVerify.Write(serverRandom)
|
|
hmacVerify.Write([]byte("S"))
|
|
return newVerifiedConn(conn, hmacAdd, hmacVerify, readHMAC), nil
|
|
}
|
|
}
|