sing/protocol/shadowsocks/client.go
2022-04-08 11:14:06 +08:00

178 lines
4.4 KiB
Go

package shadowsocks
import (
"context"
"io"
"net"
"strconv"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/rw"
"github.com/sagernet/sing/common/socksaddr"
"github.com/sagernet/sing/protocol/socks"
)
var (
ErrBadKey = exceptions.New("bad key")
ErrMissingPassword = exceptions.New("password not specified")
)
type ClientConfig struct {
Server string `json:"server"`
ServerPort uint16 `json:"server_port"`
Method string `json:"method"`
Password []byte `json:"password"`
Key []byte `json:"key"`
}
type Client struct {
dialer *net.Dialer
cipher Cipher
server string
key []byte
}
func NewClient(dialer *net.Dialer, config *ClientConfig) (*Client, error) {
if config.Server == "" {
return nil, exceptions.New("missing server address")
}
if config.ServerPort == 0 {
return nil, exceptions.New("missing server port")
}
if config.Method == "" {
return nil, exceptions.New("missing server method")
}
cipher, err := CreateCipher(config.Method)
if err != nil {
return nil, err
}
client := &Client{
dialer: dialer,
cipher: cipher,
server: net.JoinHostPort(config.Server, strconv.Itoa(int(config.ServerPort))),
}
if keyLen := len(config.Key); keyLen > 0 {
if keyLen == cipher.KeySize() {
client.key = config.Key
} else {
return nil, ErrBadKey
}
} else if len(config.Password) > 0 {
client.key = Key(config.Password, cipher.KeySize())
} else {
return nil, ErrMissingPassword
}
return client, nil
}
func (c *Client) DialContextTCP(ctx context.Context, addr socksaddr.Addr, port uint16) (net.Conn, error) {
conn, err := c.dialer.DialContext(ctx, "tcp", c.server)
if err != nil {
return nil, exceptions.Cause(err, "connect to server")
}
return c.DialConn(conn, addr, port), nil
}
func (c *Client) DialConn(conn net.Conn, addr socksaddr.Addr, port uint16) net.Conn {
header := buf.New()
header.WriteRandom(c.cipher.SaltSize())
writer := &buf.BufferedWriter{
Writer: conn,
Buffer: header,
}
protocolWriter := c.cipher.CreateWriter(c.key, header.Bytes(), writer)
requestBuffer := buf.New()
contentWriter := &buf.BufferedWriter{
Writer: protocolWriter,
Buffer: requestBuffer,
}
common.Must(AddressSerializer.WriteAddressAndPort(contentWriter, addr, port))
return &shadowsocksConn{
Client: c,
Conn: conn,
Writer: &common.FlushOnceWriter{Writer: contentWriter},
}
}
type shadowsocksConn struct {
*Client
net.Conn
io.Writer
reader io.Reader
}
func (c *shadowsocksConn) Read(p []byte) (n int, err error) {
if c.reader == nil {
buffer := buf.Or(p, c.cipher.SaltSize())
defer buffer.Release()
_, err = buffer.ReadFullFrom(c.Conn, c.cipher.SaltSize())
if err != nil {
return
}
c.reader = c.cipher.CreateReader(c.key, buffer.Bytes(), c.Conn)
}
return c.reader.Read(p)
}
func (c *shadowsocksConn) WriteTo(w io.Writer) (n int64, err error) {
if c.reader == nil {
buffer := buf.NewSize(c.cipher.SaltSize())
defer buffer.Release()
_, err = buffer.ReadFullFrom(c.Conn, c.cipher.SaltSize())
if err != nil {
return
}
c.reader = c.cipher.CreateReader(c.key, buffer.Bytes(), c.Conn)
}
return c.reader.(io.WriterTo).WriteTo(w)
}
func (c *shadowsocksConn) Write(p []byte) (n int, err error) {
return c.Writer.Write(p)
}
func (c *shadowsocksConn) ReadFrom(r io.Reader) (n int64, err error) {
return rw.ReadFromVar(&c.Writer, r)
}
func (c *Client) DialContextUDP(ctx context.Context) socks.PacketConn {
conn, err := c.dialer.DialContext(ctx, "udp", c.server)
if err != nil {
return nil
}
return &shadowsocksPacketConn{c, conn}
}
type shadowsocksPacketConn struct {
*Client
net.Conn
}
func (c *shadowsocksPacketConn) WritePacket(buffer *buf.Buffer, addr socksaddr.Addr, port uint16) error {
defer buffer.Release()
header := buf.New()
header.WriteRandom(c.cipher.SaltSize())
common.Must(AddressSerializer.WriteAddressAndPort(header, addr, port))
buffer = buffer.WriteBufferAtFirst(header)
err := c.cipher.EncodePacket(c.key, buffer)
if err != nil {
return err
}
return common.Error(c.Conn.Write(buffer.Bytes()))
}
func (c *shadowsocksPacketConn) ReadPacket(buffer *buf.Buffer) (socksaddr.Addr, uint16, error) {
n, err := c.Read(buffer.FreeBytes())
if err != nil {
return nil, 0, err
}
buffer.Truncate(n)
err = c.cipher.DecodePacket(c.key, buffer)
if err != nil {
return nil, 0, err
}
return AddressSerializer.ReadAddressAndPort(buffer)
}