mirror of
https://github.com/SagerNet/sing.git
synced 2025-04-03 20:07:38 +03:00
Refactor shadowsocks
This commit is contained in:
parent
3f23b25edf
commit
00cd0d4b8f
75 changed files with 3169 additions and 1318 deletions
192
protocol/http/listener.go
Normal file
192
protocol/http/listener.go
Normal file
|
@ -0,0 +1,192 @@
|
|||
package http
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
_ "unsafe"
|
||||
|
||||
"github.com/sagernet/sing/common/auth"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
"github.com/sagernet/sing/transport/tcp"
|
||||
)
|
||||
|
||||
type Handler interface {
|
||||
tcp.Handler
|
||||
}
|
||||
|
||||
func HandleConnection(conn *buf.BufferedConn, authenticator auth.Authenticator, handler Handler) error {
|
||||
var httpClient *http.Client
|
||||
for {
|
||||
request, err := readRequest(conn.Reader())
|
||||
if err != nil {
|
||||
return E.Cause(err, "read http request")
|
||||
}
|
||||
|
||||
if authenticator != nil {
|
||||
var authOk bool
|
||||
authorization := request.Header.Get("Proxy-Authorization")
|
||||
if strings.HasPrefix(authorization, "BASIC ") {
|
||||
userPassword, _ := base64.URLEncoding.DecodeString(authorization[6:])
|
||||
userPswdArr := strings.SplitN(string(userPassword), ":", 2)
|
||||
authOk = authenticator.Verify(userPswdArr[0], userPswdArr[1])
|
||||
}
|
||||
if !authOk {
|
||||
err = responseWith(request, http.StatusProxyAuthRequired).Write(conn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if request.Method == "CONNECT" {
|
||||
portStr := request.URL.Port()
|
||||
if portStr == "" {
|
||||
portStr = "80"
|
||||
}
|
||||
destination, err := M.ParseAddrPort(request.URL.Hostname(), portStr)
|
||||
if err != nil {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
_, err = fmt.Fprintf(conn, "HTTP/%d.%d %03d %s\r\n\r\n", request.ProtoMajor, request.ProtoMinor, http.StatusOK, "Connection established")
|
||||
if err != nil {
|
||||
return E.Cause(err, "write http response")
|
||||
}
|
||||
return handler.NewConnection(conn, M.Metadata{
|
||||
Destination: destination,
|
||||
})
|
||||
}
|
||||
|
||||
keepAlive := strings.TrimSpace(strings.ToLower(request.Header.Get("Proxy-Connection"))) == "keep-alive"
|
||||
|
||||
host := request.Header.Get("Host")
|
||||
if host != "" {
|
||||
request.Host = host
|
||||
}
|
||||
|
||||
request.RequestURI = ""
|
||||
|
||||
removeHopByHopHeaders(request.Header)
|
||||
removeExtraHTTPHostPort(request)
|
||||
|
||||
if request.URL.Scheme == "" || request.URL.Host == "" {
|
||||
return responseWith(request, http.StatusBadRequest).Write(conn)
|
||||
}
|
||||
|
||||
if httpClient == nil {
|
||||
httpClient = &http.Client{
|
||||
Transport: &http.Transport{
|
||||
MaxIdleConns: 100,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
DialContext: func(context context.Context, network, address string) (net.Conn, error) {
|
||||
if network != "tcp" && network != "tcp4" && network != "tcp6" {
|
||||
return nil, E.New("unsupported network ", network)
|
||||
}
|
||||
|
||||
destination, err := M.ParseAddress(address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
left, right := net.Pipe()
|
||||
go func() {
|
||||
err = handler.NewConnection(right, M.Metadata{
|
||||
Destination: destination,
|
||||
})
|
||||
if err != nil {
|
||||
handler.HandleError(err)
|
||||
}
|
||||
}()
|
||||
return left, nil
|
||||
},
|
||||
},
|
||||
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||
return http.ErrUseLastResponse
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
response, err := httpClient.Do(request)
|
||||
if err != nil {
|
||||
handler.HandleError(err)
|
||||
return responseWith(request, http.StatusBadGateway).Write(conn)
|
||||
}
|
||||
|
||||
removeHopByHopHeaders(response.Header)
|
||||
|
||||
if keepAlive {
|
||||
response.Header.Set("Proxy-Connection", "keep-alive")
|
||||
response.Header.Set("Connection", "keep-alive")
|
||||
response.Header.Set("Keep-Alive", "timeout=4")
|
||||
}
|
||||
|
||||
response.Close = !keepAlive
|
||||
|
||||
err = response.Write(conn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//go:linkname readRequest net/http.ReadRequest
|
||||
func readRequest(b *bufio.Reader) (req *http.Request, err error)
|
||||
|
||||
func removeHopByHopHeaders(header http.Header) {
|
||||
// Strip hop-by-hop header based on RFC:
|
||||
// http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html#sec13.5.1
|
||||
// https://www.mnot.net/blog/2011/07/11/what_proxies_must_do
|
||||
|
||||
header.Del("Proxy-Connection")
|
||||
header.Del("Proxy-Authenticate")
|
||||
header.Del("Proxy-Authorization")
|
||||
header.Del("TE")
|
||||
header.Del("Trailers")
|
||||
header.Del("Transfer-Encoding")
|
||||
header.Del("Upgrade")
|
||||
|
||||
connections := header.Get("Connection")
|
||||
header.Del("Connection")
|
||||
if len(connections) == 0 {
|
||||
return
|
||||
}
|
||||
for _, h := range strings.Split(connections, ",") {
|
||||
header.Del(strings.TrimSpace(h))
|
||||
}
|
||||
}
|
||||
|
||||
func removeExtraHTTPHostPort(req *http.Request) {
|
||||
host := req.Host
|
||||
if host == "" {
|
||||
host = req.URL.Host
|
||||
}
|
||||
|
||||
if pHost, port, err := net.SplitHostPort(host); err == nil && port == "80" {
|
||||
host = pHost
|
||||
}
|
||||
|
||||
req.Host = host
|
||||
req.URL.Host = host
|
||||
}
|
||||
|
||||
func responseWith(request *http.Request, statusCode int) *http.Response {
|
||||
return &http.Response{
|
||||
StatusCode: statusCode,
|
||||
Status: http.StatusText(statusCode),
|
||||
Proto: request.Proto,
|
||||
ProtoMajor: request.ProtoMajor,
|
||||
ProtoMinor: request.ProtoMinor,
|
||||
Header: http.Header{},
|
||||
}
|
||||
}
|
0
protocol/http/stub.s
Normal file
0
protocol/http/stub.s
Normal file
|
@ -1,47 +0,0 @@
|
|||
package shadowsocks
|
||||
|
||||
import (
|
||||
"io"
|
||||
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
"github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/list"
|
||||
)
|
||||
|
||||
type Cipher interface {
|
||||
KeySize() int
|
||||
SaltSize() int
|
||||
CreateReader(key []byte, salt []byte, reader io.Reader) io.Reader
|
||||
CreateWriter(key []byte, salt []byte, writer io.Writer) io.Writer
|
||||
EncodePacket(key []byte, buffer *buf.Buffer) error
|
||||
DecodePacket(key []byte, buffer *buf.Buffer) error
|
||||
}
|
||||
|
||||
type CipherCreator func() Cipher
|
||||
|
||||
var (
|
||||
cipherList *list.List[string]
|
||||
cipherMap map[string]CipherCreator
|
||||
)
|
||||
|
||||
func init() {
|
||||
cipherList = new(list.List[string])
|
||||
cipherMap = make(map[string]CipherCreator)
|
||||
}
|
||||
|
||||
func RegisterCipher(method string, creator CipherCreator) {
|
||||
cipherList.PushBack(method)
|
||||
cipherMap[method] = creator
|
||||
}
|
||||
|
||||
func CreateCipher(method string) (Cipher, error) {
|
||||
creator := cipherMap[method]
|
||||
if creator != nil {
|
||||
return creator(), nil
|
||||
}
|
||||
return nil, exceptions.New("unsupported method: ", method)
|
||||
}
|
||||
|
||||
func ListCiphers() []string {
|
||||
return cipherList.Array()
|
||||
}
|
|
@ -1,39 +0,0 @@
|
|||
package shadowsocks
|
||||
|
||||
import (
|
||||
"io"
|
||||
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
)
|
||||
|
||||
func init() {
|
||||
RegisterCipher("none", func() Cipher {
|
||||
return (*NoneCipher)(nil)
|
||||
})
|
||||
}
|
||||
|
||||
type NoneCipher struct{}
|
||||
|
||||
func (c *NoneCipher) KeySize() int {
|
||||
return 16
|
||||
}
|
||||
|
||||
func (c *NoneCipher) SaltSize() int {
|
||||
return 0
|
||||
}
|
||||
|
||||
func (c *NoneCipher) CreateReader(_ []byte, _ []byte, reader io.Reader) io.Reader {
|
||||
return reader
|
||||
}
|
||||
|
||||
func (c *NoneCipher) CreateWriter(key []byte, iv []byte, writer io.Writer) io.Writer {
|
||||
return writer
|
||||
}
|
||||
|
||||
func (c *NoneCipher) EncodePacket([]byte, *buf.Buffer) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *NoneCipher) DecodePacket([]byte, *buf.Buffer) error {
|
||||
return nil
|
||||
}
|
|
@ -1,178 +0,0 @@
|
|||
package shadowsocks
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"strconv"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
"github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/rw"
|
||||
"github.com/sagernet/sing/common/socksaddr"
|
||||
"github.com/sagernet/sing/protocol/socks"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrBadKey = exceptions.New("bad key")
|
||||
ErrMissingPassword = exceptions.New("password not specified")
|
||||
)
|
||||
|
||||
type ClientConfig struct {
|
||||
Server string `json:"server"`
|
||||
ServerPort uint16 `json:"server_port"`
|
||||
Method string `json:"method"`
|
||||
Password []byte `json:"password"`
|
||||
Key []byte `json:"key"`
|
||||
}
|
||||
|
||||
type Client struct {
|
||||
dialer *net.Dialer
|
||||
cipher Cipher
|
||||
server string
|
||||
key []byte
|
||||
}
|
||||
|
||||
func NewClient(dialer *net.Dialer, config *ClientConfig) (*Client, error) {
|
||||
if config.Server == "" {
|
||||
return nil, exceptions.New("missing server address")
|
||||
}
|
||||
if config.ServerPort == 0 {
|
||||
return nil, exceptions.New("missing server port")
|
||||
}
|
||||
if config.Method == "" {
|
||||
return nil, exceptions.New("missing server method")
|
||||
}
|
||||
|
||||
cipher, err := CreateCipher(config.Method)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
client := &Client{
|
||||
dialer: dialer,
|
||||
cipher: cipher,
|
||||
server: net.JoinHostPort(config.Server, strconv.Itoa(int(config.ServerPort))),
|
||||
}
|
||||
if keyLen := len(config.Key); keyLen > 0 {
|
||||
if keyLen == cipher.KeySize() {
|
||||
client.key = config.Key
|
||||
} else {
|
||||
return nil, ErrBadKey
|
||||
}
|
||||
} else if len(config.Password) > 0 {
|
||||
client.key = Key(config.Password, cipher.KeySize())
|
||||
} else {
|
||||
return nil, ErrMissingPassword
|
||||
}
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (c *Client) DialContextTCP(ctx context.Context, addr socksaddr.Addr, port uint16) (net.Conn, error) {
|
||||
conn, err := c.dialer.DialContext(ctx, "tcp", c.server)
|
||||
if err != nil {
|
||||
return nil, exceptions.Cause(err, "connect to server")
|
||||
}
|
||||
return c.DialConn(conn, addr, port), nil
|
||||
}
|
||||
|
||||
func (c *Client) DialConn(conn net.Conn, addr socksaddr.Addr, port uint16) net.Conn {
|
||||
header := buf.New()
|
||||
header.WriteRandom(c.cipher.SaltSize())
|
||||
writer := &buf.BufferedWriter{
|
||||
Writer: conn,
|
||||
Buffer: header,
|
||||
}
|
||||
protocolWriter := c.cipher.CreateWriter(c.key, header.Bytes(), writer)
|
||||
requestBuffer := buf.New()
|
||||
contentWriter := &buf.BufferedWriter{
|
||||
Writer: protocolWriter,
|
||||
Buffer: requestBuffer,
|
||||
}
|
||||
common.Must(AddressSerializer.WriteAddressAndPort(contentWriter, addr, port))
|
||||
return &shadowsocksConn{
|
||||
Client: c,
|
||||
Conn: conn,
|
||||
Writer: &common.FlushOnceWriter{Writer: contentWriter},
|
||||
}
|
||||
}
|
||||
|
||||
type shadowsocksConn struct {
|
||||
*Client
|
||||
net.Conn
|
||||
io.Writer
|
||||
reader io.Reader
|
||||
}
|
||||
|
||||
func (c *shadowsocksConn) Read(p []byte) (n int, err error) {
|
||||
if c.reader == nil {
|
||||
buffer := buf.Or(p, c.cipher.SaltSize())
|
||||
defer buffer.Release()
|
||||
_, err = buffer.ReadFullFrom(c.Conn, c.cipher.SaltSize())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
c.reader = c.cipher.CreateReader(c.key, buffer.Bytes(), c.Conn)
|
||||
}
|
||||
return c.reader.Read(p)
|
||||
}
|
||||
|
||||
func (c *shadowsocksConn) WriteTo(w io.Writer) (n int64, err error) {
|
||||
if c.reader == nil {
|
||||
buffer := buf.NewSize(c.cipher.SaltSize())
|
||||
defer buffer.Release()
|
||||
_, err = buffer.ReadFullFrom(c.Conn, c.cipher.SaltSize())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
c.reader = c.cipher.CreateReader(c.key, buffer.Bytes(), c.Conn)
|
||||
}
|
||||
return c.reader.(io.WriterTo).WriteTo(w)
|
||||
}
|
||||
|
||||
func (c *shadowsocksConn) Write(p []byte) (n int, err error) {
|
||||
return c.Writer.Write(p)
|
||||
}
|
||||
|
||||
func (c *shadowsocksConn) ReadFrom(r io.Reader) (n int64, err error) {
|
||||
return rw.ReadFromVar(&c.Writer, r)
|
||||
}
|
||||
|
||||
func (c *Client) DialContextUDP(ctx context.Context) socks.PacketConn {
|
||||
conn, err := c.dialer.DialContext(ctx, "udp", c.server)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return &shadowsocksPacketConn{c, conn}
|
||||
}
|
||||
|
||||
type shadowsocksPacketConn struct {
|
||||
*Client
|
||||
net.Conn
|
||||
}
|
||||
|
||||
func (c *shadowsocksPacketConn) WritePacket(buffer *buf.Buffer, addr socksaddr.Addr, port uint16) error {
|
||||
defer buffer.Release()
|
||||
header := buf.New()
|
||||
header.WriteRandom(c.cipher.SaltSize())
|
||||
common.Must(AddressSerializer.WriteAddressAndPort(header, addr, port))
|
||||
buffer = buffer.WriteBufferAtFirst(header)
|
||||
err := c.cipher.EncodePacket(c.key, buffer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return common.Error(c.Conn.Write(buffer.Bytes()))
|
||||
}
|
||||
|
||||
func (c *shadowsocksPacketConn) ReadPacket(buffer *buf.Buffer) (socksaddr.Addr, uint16, error) {
|
||||
n, err := c.Read(buffer.FreeBytes())
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
buffer.Truncate(n)
|
||||
err = c.cipher.DecodePacket(c.key, buffer)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
return AddressSerializer.ReadAddressAndPort(buffer)
|
||||
}
|
153
protocol/shadowsocks/none.go
Normal file
153
protocol/shadowsocks/none.go
Normal file
|
@ -0,0 +1,153 @@
|
|||
package shadowsocks
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
"github.com/sagernet/sing/protocol/socks"
|
||||
)
|
||||
|
||||
const MethodNone = "none"
|
||||
|
||||
type NoneMethod struct{}
|
||||
|
||||
func NewNone() Method {
|
||||
return &NoneMethod{}
|
||||
}
|
||||
|
||||
func (m *NoneMethod) Name() string {
|
||||
return MethodNone
|
||||
}
|
||||
|
||||
func (m *NoneMethod) KeyLength() int {
|
||||
return 0
|
||||
}
|
||||
|
||||
func (m *NoneMethod) NewSession(key []byte) Session {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *NoneMethod) DialConn(_ Session, conn net.Conn, destination *M.AddrPort) (net.Conn, error) {
|
||||
shadowsocksConn := &noneConn{
|
||||
Conn: conn,
|
||||
handshake: true,
|
||||
destination: destination,
|
||||
}
|
||||
return shadowsocksConn, shadowsocksConn.clientHandshake()
|
||||
}
|
||||
|
||||
func (m *NoneMethod) DialEarlyConn(_ Session, conn net.Conn, destination *M.AddrPort) net.Conn {
|
||||
return &noneConn{
|
||||
Conn: conn,
|
||||
destination: destination,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *NoneMethod) DialPacketConn(_ Session, conn net.Conn) socks.PacketConn {
|
||||
return &nonePacketConn{conn}
|
||||
}
|
||||
|
||||
type noneConn struct {
|
||||
net.Conn
|
||||
|
||||
access sync.Mutex
|
||||
handshake bool
|
||||
destination *M.AddrPort
|
||||
}
|
||||
|
||||
func (c *noneConn) clientHandshake() error {
|
||||
err := socks.AddressSerializer.WriteAddrPort(c.Conn, c.destination)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.handshake = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *noneConn) Write(b []byte) (n int, err error) {
|
||||
if c.handshake {
|
||||
goto direct
|
||||
}
|
||||
|
||||
c.access.Lock()
|
||||
defer c.access.Unlock()
|
||||
|
||||
if c.handshake {
|
||||
goto direct
|
||||
}
|
||||
|
||||
{
|
||||
if len(b) == 0 {
|
||||
return 0, c.clientHandshake()
|
||||
}
|
||||
|
||||
buffer := buf.New()
|
||||
defer buffer.Release()
|
||||
|
||||
err = socks.AddressSerializer.WriteAddrPort(buffer, c.destination)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
bufN, _ := buffer.Write(b)
|
||||
_, err = c.Conn.Write(buffer.Bytes())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if bufN < len(b) {
|
||||
_, err = c.Conn.Write(b[bufN:])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
n = len(b)
|
||||
}
|
||||
|
||||
direct:
|
||||
return c.Conn.Write(b)
|
||||
}
|
||||
|
||||
func (c *noneConn) ReadFrom(r io.Reader) (n int64, err error) {
|
||||
if !c.handshake {
|
||||
panic("missing client handshake")
|
||||
}
|
||||
return c.Conn.(io.ReaderFrom).ReadFrom(r)
|
||||
}
|
||||
|
||||
func (c *noneConn) WriteTo(w io.Writer) (n int64, err error) {
|
||||
return c.Conn.(io.WriterTo).WriteTo(w)
|
||||
}
|
||||
|
||||
func (c *noneConn) RemoteAddr() net.Addr {
|
||||
return c.destination.TCPAddr()
|
||||
}
|
||||
|
||||
type nonePacketConn struct {
|
||||
net.Conn
|
||||
}
|
||||
|
||||
func (c *nonePacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
|
||||
_, err := buffer.ReadFrom(c)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return socks.AddressSerializer.ReadAddrPort(buffer)
|
||||
}
|
||||
|
||||
func (c *nonePacketConn) WritePacket(buffer *buf.Buffer, addrPort *M.AddrPort) error {
|
||||
defer buffer.Release()
|
||||
header := buf.New()
|
||||
err := socks.AddressSerializer.WriteAddrPort(header, addrPort)
|
||||
if err != nil {
|
||||
header.Release()
|
||||
return err
|
||||
}
|
||||
buffer = buffer.WriteBufferAtFirst(header)
|
||||
return common.Error(buffer.WriteTo(c))
|
||||
}
|
|
@ -2,23 +2,26 @@ package shadowsocks
|
|||
|
||||
import (
|
||||
"crypto/md5"
|
||||
"crypto/sha1"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
"github.com/sagernet/sing/common/replay"
|
||||
"github.com/sagernet/sing/protocol/socks"
|
||||
"hash/crc32"
|
||||
"io"
|
||||
"math/rand"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/socksaddr"
|
||||
"golang.org/x/crypto/hkdf"
|
||||
"net"
|
||||
)
|
||||
|
||||
const MaxPacketSize = 16*1024 - 1
|
||||
type Session interface {
|
||||
Key() []byte
|
||||
ReplayFilter() replay.Filter
|
||||
}
|
||||
|
||||
func Kdf(key, iv []byte, keyLength int) []byte {
|
||||
subKey := make([]byte, keyLength)
|
||||
kdf := hkdf.New(sha1.New, key, iv, []byte("ss-subkey"))
|
||||
common.Must1(io.ReadFull(kdf, subKey))
|
||||
return subKey
|
||||
type Method interface {
|
||||
Name() string
|
||||
KeyLength() int
|
||||
DialConn(session Session, conn net.Conn, destination *M.AddrPort) (net.Conn, error)
|
||||
DialEarlyConn(session Session, conn net.Conn, destination *M.AddrPort) net.Conn
|
||||
DialPacketConn(session Session, conn net.Conn) socks.PacketConn
|
||||
}
|
||||
|
||||
func Key(password []byte, keySize int) []byte {
|
||||
|
@ -43,19 +46,18 @@ func Key(password []byte, keySize int) []byte {
|
|||
return m[:keySize]
|
||||
}
|
||||
|
||||
func RemapToPrintable(input []byte) {
|
||||
const charSet = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789!#$%&()*+,./:;<=>?@[]^_`{|}~\\\""
|
||||
seed := rand.New(rand.NewSource(int64(crc32.ChecksumIEEE(input))))
|
||||
for i := range input {
|
||||
input[i] = charSet[seed.Intn(len(charSet))]
|
||||
}
|
||||
type ReducedEntropyReader struct {
|
||||
io.Reader
|
||||
}
|
||||
|
||||
var AddressSerializer = socksaddr.NewSerializer(
|
||||
socksaddr.AddressFamilyByte(0x01, socksaddr.AddressFamilyIPv4),
|
||||
socksaddr.AddressFamilyByte(0x04, socksaddr.AddressFamilyIPv6),
|
||||
socksaddr.AddressFamilyByte(0x03, socksaddr.AddressFamilyFqdn),
|
||||
socksaddr.WithFamilyParser(func(b byte) byte {
|
||||
return b & 0x0F
|
||||
}),
|
||||
)
|
||||
func (r *ReducedEntropyReader) Read(p []byte) (n int, err error) {
|
||||
n, err = r.Reader.Read(p)
|
||||
if n > 6 {
|
||||
const charSet = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789!#$%&()*+,./:;<=>?@[]^_`{|}~\\\""
|
||||
seed := rand.New(rand.NewSource(int64(crc32.ChecksumIEEE(p[:6]))))
|
||||
for i := range p[:6] {
|
||||
p[i] = charSet[seed.Intn(len(charSet))]
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
|
|
@ -1,138 +1,20 @@
|
|||
package shadowsocks
|
||||
package shadowaead
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"net"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
"github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/rw"
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
)
|
||||
|
||||
const PacketLengthBufferSize = 2
|
||||
const (
|
||||
MaxPacketSize = 16*1024 - 1
|
||||
PacketLengthBufferSize = 2
|
||||
)
|
||||
|
||||
func init() {
|
||||
RegisterCipher("aes-128-gcm", func() Cipher {
|
||||
return &AEADCipher{
|
||||
KeyLength: 16,
|
||||
SaltLength: 16,
|
||||
Constructor: aesGcm,
|
||||
}
|
||||
})
|
||||
RegisterCipher("aes-192-gcm", func() Cipher {
|
||||
return &AEADCipher{
|
||||
KeyLength: 24,
|
||||
SaltLength: 24,
|
||||
Constructor: aesGcm,
|
||||
}
|
||||
})
|
||||
RegisterCipher("aes-256-gcm", func() Cipher {
|
||||
return &AEADCipher{
|
||||
KeyLength: 32,
|
||||
SaltLength: 32,
|
||||
Constructor: aesGcm,
|
||||
}
|
||||
})
|
||||
RegisterCipher("chacha20-ietf-poly1305", func() Cipher {
|
||||
return &AEADCipher{
|
||||
KeyLength: 32,
|
||||
SaltLength: 32,
|
||||
Constructor: chacha20Poly1305,
|
||||
}
|
||||
})
|
||||
RegisterCipher("xchacha20-ietf-poly1305", func() Cipher {
|
||||
return &AEADCipher{
|
||||
KeyLength: 32,
|
||||
SaltLength: 32,
|
||||
Constructor: xchacha20Poly1305,
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func aesGcm(key []byte) cipher.AEAD {
|
||||
block, err := aes.NewCipher(key)
|
||||
common.Must(err)
|
||||
aead, err := cipher.NewGCM(block)
|
||||
common.Must(err)
|
||||
return aead
|
||||
}
|
||||
|
||||
func chacha20Poly1305(key []byte) cipher.AEAD {
|
||||
aead, err := chacha20poly1305.New(key)
|
||||
common.Must(err)
|
||||
return aead
|
||||
}
|
||||
|
||||
func xchacha20Poly1305(key []byte) cipher.AEAD {
|
||||
aead, err := chacha20poly1305.NewX(key)
|
||||
common.Must(err)
|
||||
return aead
|
||||
}
|
||||
|
||||
type AEADCipher struct {
|
||||
KeyLength int
|
||||
SaltLength int
|
||||
Constructor func(key []byte) cipher.AEAD
|
||||
}
|
||||
|
||||
func (c *AEADCipher) KeySize() int {
|
||||
return c.KeyLength
|
||||
}
|
||||
|
||||
func (c *AEADCipher) SaltSize() int {
|
||||
return c.SaltLength
|
||||
}
|
||||
|
||||
func (c *AEADCipher) CreateReader(key []byte, salt []byte, reader io.Reader) io.Reader {
|
||||
return NewAEADReader(reader, c.Constructor(Kdf(key, salt, c.KeyLength)))
|
||||
}
|
||||
|
||||
func (c *AEADCipher) CreateWriter(key []byte, salt []byte, writer io.Writer) io.Writer {
|
||||
protocolWriter := NewAEADWriter(writer, c.Constructor(Kdf(key, salt, c.KeyLength)))
|
||||
return protocolWriter
|
||||
}
|
||||
|
||||
func (c *AEADCipher) EncodePacket(key []byte, buffer *buf.Buffer) error {
|
||||
aead := c.Constructor(Kdf(key, buffer.To(c.SaltLength), c.KeyLength))
|
||||
aead.Seal(buffer.From(c.SaltLength)[:0], rw.ZeroBytes[:aead.NonceSize()], buffer.From(c.SaltLength), nil)
|
||||
buffer.Extend(aead.Overhead())
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *AEADCipher) DecodePacket(key []byte, buffer *buf.Buffer) error {
|
||||
if buffer.Len() < c.SaltLength {
|
||||
return exceptions.New("bad packet")
|
||||
}
|
||||
aead := c.Constructor(Kdf(key, buffer.To(c.SaltLength), c.KeyLength))
|
||||
packet, err := aead.Open(buffer.Index(c.SaltLength), rw.ZeroBytes[:aead.NonceSize()], buffer.From(c.SaltLength), nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
buffer.Advance(c.SaltLength)
|
||||
buffer.Truncate(len(packet))
|
||||
return nil
|
||||
}
|
||||
|
||||
type AEADConn struct {
|
||||
net.Conn
|
||||
Reader *AEADReader
|
||||
Writer *AEADWriter
|
||||
}
|
||||
|
||||
func (c *AEADConn) Read(p []byte) (n int, err error) {
|
||||
return c.Reader.Read(p)
|
||||
}
|
||||
|
||||
func (c *AEADConn) Write(p []byte) (n int, err error) {
|
||||
return c.Writer.Write(p)
|
||||
}
|
||||
|
||||
type AEADReader struct {
|
||||
type Reader struct {
|
||||
upstream io.Reader
|
||||
cipher cipher.AEAD
|
||||
data []byte
|
||||
|
@ -141,8 +23,8 @@ type AEADReader struct {
|
|||
cached int
|
||||
}
|
||||
|
||||
func NewAEADReader(upstream io.Reader, cipher cipher.AEAD) *AEADReader {
|
||||
return &AEADReader{
|
||||
func NewReader(upstream io.Reader, cipher cipher.AEAD) *Reader {
|
||||
return &Reader{
|
||||
upstream: upstream,
|
||||
cipher: cipher,
|
||||
data: make([]byte, MaxPacketSize+PacketLengthBufferSize+cipher.Overhead()*2),
|
||||
|
@ -150,19 +32,19 @@ func NewAEADReader(upstream io.Reader, cipher cipher.AEAD) *AEADReader {
|
|||
}
|
||||
}
|
||||
|
||||
func (r *AEADReader) Upstream() io.Reader {
|
||||
func (r *Reader) Upstream() io.Reader {
|
||||
return r.upstream
|
||||
}
|
||||
|
||||
func (r *AEADReader) Replaceable() bool {
|
||||
func (r *Reader) Replaceable() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (r *AEADReader) SetUpstream(reader io.Reader) {
|
||||
func (r *Reader) SetUpstream(reader io.Reader) {
|
||||
r.upstream = reader
|
||||
}
|
||||
|
||||
func (r *AEADReader) WriteTo(writer io.Writer) (n int64, err error) {
|
||||
func (r *Reader) WriteTo(writer io.Writer) (n int64, err error) {
|
||||
if r.cached > 0 {
|
||||
writeN, writeErr := writer.Write(r.data[r.index : r.index+r.cached])
|
||||
if writeErr != nil {
|
||||
|
@ -200,7 +82,7 @@ func (r *AEADReader) WriteTo(writer io.Writer) (n int64, err error) {
|
|||
}
|
||||
}
|
||||
|
||||
func (r *AEADReader) Read(b []byte) (n int, err error) {
|
||||
func (r *Reader) Read(b []byte) (n int, err error) {
|
||||
if r.cached > 0 {
|
||||
n = copy(b, r.data[r.index:r.index+r.cached])
|
||||
r.cached -= n
|
332
protocol/shadowsocks/shadowaead/method.go
Normal file
332
protocol/shadowsocks/shadowaead/method.go
Normal file
|
@ -0,0 +1,332 @@
|
|||
package shadowaead
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/sha1"
|
||||
"github.com/sagernet/sing/common/replay"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
"github.com/sagernet/sing/common/rw"
|
||||
"github.com/sagernet/sing/protocol/shadowsocks"
|
||||
"github.com/sagernet/sing/protocol/socks"
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
"golang.org/x/crypto/hkdf"
|
||||
)
|
||||
|
||||
var List = []string{
|
||||
"aes-128-gcm",
|
||||
"aes-192-gcm",
|
||||
"aes-256-gcm",
|
||||
"chacha20-ietf-poly1305",
|
||||
"xchacha20-ietf-poly1305",
|
||||
}
|
||||
|
||||
func New(method string, secureRNG io.Reader) shadowsocks.Method {
|
||||
m := &Method{
|
||||
name: method,
|
||||
secureRNG: secureRNG,
|
||||
}
|
||||
switch method {
|
||||
case "aes-128-gcm":
|
||||
m.keySaltLength = 16
|
||||
m.constructor = newAESGCM
|
||||
case "aes-192-gcm":
|
||||
m.keySaltLength = 24
|
||||
m.constructor = newAESGCM
|
||||
case "aes-256-gcm":
|
||||
m.keySaltLength = 32
|
||||
m.constructor = newAESGCM
|
||||
case "chacha20-ietf-poly1305":
|
||||
m.keySaltLength = 32
|
||||
m.constructor = func(key []byte) cipher.AEAD {
|
||||
cipher, err := chacha20poly1305.New(key)
|
||||
common.Must(err)
|
||||
return cipher
|
||||
}
|
||||
case "xchacha20-ietf-poly1305":
|
||||
m.keySaltLength = 32
|
||||
m.constructor = func(key []byte) cipher.AEAD {
|
||||
cipher, err := chacha20poly1305.NewX(key)
|
||||
common.Must(err)
|
||||
return cipher
|
||||
}
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
func NewSession(key []byte, replayFilter bool) shadowsocks.Session {
|
||||
var filter replay.Filter
|
||||
if replayFilter {
|
||||
filter = replay.NewBloomRing()
|
||||
}
|
||||
return &session{key, filter}
|
||||
}
|
||||
|
||||
func Kdf(key, iv []byte, keyLength int) []byte {
|
||||
subKey := make([]byte, keyLength)
|
||||
kdf := hkdf.New(sha1.New, key, iv, []byte("ss-subkey"))
|
||||
common.Must1(io.ReadFull(kdf, subKey))
|
||||
return subKey
|
||||
}
|
||||
|
||||
func newAESGCM(key []byte) cipher.AEAD {
|
||||
block, err := aes.NewCipher(key)
|
||||
common.Must(err)
|
||||
aead, err := cipher.NewGCM(block)
|
||||
common.Must(err)
|
||||
return aead
|
||||
}
|
||||
|
||||
type Method struct {
|
||||
name string
|
||||
keySaltLength int
|
||||
constructor func(key []byte) cipher.AEAD
|
||||
secureRNG io.Reader
|
||||
}
|
||||
|
||||
func (m *Method) Name() string {
|
||||
return m.name
|
||||
}
|
||||
|
||||
func (m *Method) KeyLength() int {
|
||||
return m.keySaltLength
|
||||
}
|
||||
|
||||
func (m *Method) DialConn(account shadowsocks.Session, conn net.Conn, destination *M.AddrPort) (net.Conn, error) {
|
||||
shadowsocksConn := &aeadConn{
|
||||
Conn: conn,
|
||||
method: m,
|
||||
key: account.Key(),
|
||||
replayFilter: account.ReplayFilter(),
|
||||
destination: destination,
|
||||
}
|
||||
return shadowsocksConn, shadowsocksConn.clientHandshake()
|
||||
}
|
||||
|
||||
func (m *Method) DialEarlyConn(account shadowsocks.Session, conn net.Conn, destination *M.AddrPort) net.Conn {
|
||||
return &aeadConn{
|
||||
Conn: conn,
|
||||
method: m,
|
||||
key: account.Key(),
|
||||
replayFilter: account.ReplayFilter(),
|
||||
destination: destination,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Method) DialPacketConn(account shadowsocks.Session, conn net.Conn) socks.PacketConn {
|
||||
return &aeadPacketConn{conn, account.Key(), m}
|
||||
}
|
||||
|
||||
func (m *Method) EncodePacket(key []byte, buffer *buf.Buffer) error {
|
||||
cipher := m.constructor(Kdf(key, buffer.To(m.keySaltLength), m.keySaltLength))
|
||||
cipher.Seal(buffer.From(m.keySaltLength)[:0], rw.ZeroBytes[:cipher.NonceSize()], buffer.From(m.keySaltLength), nil)
|
||||
buffer.Extend(cipher.Overhead())
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Method) DecodePacket(key []byte, buffer *buf.Buffer) error {
|
||||
if buffer.Len() < m.keySaltLength {
|
||||
return E.New("bad packet")
|
||||
}
|
||||
aead := m.constructor(Kdf(key, buffer.To(m.keySaltLength), m.keySaltLength))
|
||||
packet, err := aead.Open(buffer.Index(m.keySaltLength), rw.ZeroBytes[:aead.NonceSize()], buffer.From(m.keySaltLength), nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
buffer.Advance(m.keySaltLength)
|
||||
buffer.Truncate(len(packet))
|
||||
return nil
|
||||
}
|
||||
|
||||
type session struct {
|
||||
key []byte
|
||||
replayFilter replay.Filter
|
||||
}
|
||||
|
||||
func (a *session) Key() []byte {
|
||||
return a.key
|
||||
}
|
||||
|
||||
func (a *session) ReplayFilter() replay.Filter {
|
||||
return a.replayFilter
|
||||
}
|
||||
|
||||
type aeadConn struct {
|
||||
net.Conn
|
||||
|
||||
method *Method
|
||||
key []byte
|
||||
destination *M.AddrPort
|
||||
|
||||
access sync.Mutex
|
||||
reader io.Reader
|
||||
writer io.Writer
|
||||
replayFilter replay.Filter
|
||||
}
|
||||
|
||||
func (c *aeadConn) clientHandshake() error {
|
||||
header := buf.New()
|
||||
defer header.Release()
|
||||
|
||||
common.Must1(header.ReadFullFrom(c.method.secureRNG, c.method.keySaltLength))
|
||||
if c.replayFilter != nil {
|
||||
c.replayFilter.Check(header.Bytes())
|
||||
}
|
||||
|
||||
c.writer = NewAEADWriter(
|
||||
&buf.BufferedWriter{
|
||||
Writer: c.Conn,
|
||||
Buffer: header,
|
||||
},
|
||||
c.method.constructor(Kdf(c.key, header.Bytes(), c.method.keySaltLength)),
|
||||
)
|
||||
|
||||
err := socks.AddressSerializer.WriteAddrPort(c.writer, c.destination)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return common.FlushVar(&c.writer)
|
||||
}
|
||||
|
||||
func (c *aeadConn) serverHandshake() error {
|
||||
if c.reader == nil {
|
||||
salt := make([]byte, c.method.keySaltLength)
|
||||
_, err := io.ReadFull(c.Conn, salt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if c.replayFilter != nil {
|
||||
if !c.replayFilter.Check(salt) {
|
||||
return E.New("salt is not unique")
|
||||
}
|
||||
}
|
||||
c.reader = NewReader(c.Conn, c.method.constructor(Kdf(c.key, salt, c.method.keySaltLength)))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *aeadConn) Read(p []byte) (n int, err error) {
|
||||
if err = c.serverHandshake(); err != nil {
|
||||
return
|
||||
}
|
||||
return c.reader.Read(p)
|
||||
}
|
||||
|
||||
func (c *aeadConn) WriteTo(w io.Writer) (n int64, err error) {
|
||||
if err = c.serverHandshake(); err != nil {
|
||||
return
|
||||
}
|
||||
return c.reader.(io.WriterTo).WriteTo(w)
|
||||
}
|
||||
|
||||
func (c *aeadConn) Write(p []byte) (n int, err error) {
|
||||
if c.writer != nil {
|
||||
goto direct
|
||||
}
|
||||
|
||||
c.access.Lock()
|
||||
defer c.access.Unlock()
|
||||
|
||||
if c.writer != nil {
|
||||
goto direct
|
||||
}
|
||||
|
||||
// client handshake
|
||||
|
||||
{
|
||||
header := buf.New()
|
||||
defer header.Release()
|
||||
|
||||
request := buf.New()
|
||||
defer request.Release()
|
||||
|
||||
common.Must1(header.ReadFullFrom(c.method.secureRNG, c.method.keySaltLength))
|
||||
if c.replayFilter != nil {
|
||||
c.replayFilter.Check(header.Bytes())
|
||||
}
|
||||
|
||||
var writer io.Writer = c.Conn
|
||||
writer = &buf.BufferedWriter{
|
||||
Writer: writer,
|
||||
Buffer: header,
|
||||
}
|
||||
writer = NewAEADWriter(writer, c.method.constructor(Kdf(c.key, header.Bytes(), c.method.keySaltLength)))
|
||||
writer = &buf.BufferedWriter{
|
||||
Writer: writer,
|
||||
Buffer: request,
|
||||
}
|
||||
|
||||
err = socks.AddressSerializer.WriteAddrPort(writer, c.destination)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if len(p) > 0 {
|
||||
_, err = writer.Write(p)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
err = common.FlushVar(&writer)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
c.writer = writer
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
direct:
|
||||
return c.writer.Write(p)
|
||||
}
|
||||
|
||||
func (c *aeadConn) ReadFrom(r io.Reader) (n int64, err error) {
|
||||
if c.writer == nil {
|
||||
panic("missing client handshake")
|
||||
}
|
||||
return c.writer.(io.ReaderFrom).ReadFrom(r)
|
||||
}
|
||||
|
||||
type aeadPacketConn struct {
|
||||
net.Conn
|
||||
key []byte
|
||||
method *Method
|
||||
}
|
||||
|
||||
func (c *aeadPacketConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error {
|
||||
defer buffer.Release()
|
||||
header := buf.New()
|
||||
common.Must1(header.ReadFullFrom(c.method.secureRNG, c.method.keySaltLength))
|
||||
err := socks.AddressSerializer.WriteAddrPort(header, destination)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
buffer = buffer.WriteBufferAtFirst(header)
|
||||
err = c.method.EncodePacket(c.key, buffer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return common.Error(c.Write(buffer.Bytes()))
|
||||
}
|
||||
|
||||
func (c *aeadPacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
|
||||
n, err := c.Read(buffer.FreeBytes())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
buffer.Truncate(n)
|
||||
err = c.method.DecodePacket(c.key, buffer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return socks.AddressSerializer.ReadAddrPort(buffer)
|
||||
}
|
|
@ -6,12 +6,12 @@ import (
|
|||
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
"github.com/sagernet/sing/common/socksaddr"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
)
|
||||
|
||||
type PacketConn interface {
|
||||
ReadPacket(buffer *buf.Buffer) (socksaddr.Addr, uint16, error)
|
||||
WritePacket(buffer *buf.Buffer, addr socksaddr.Addr, port uint16) error
|
||||
ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error)
|
||||
WritePacket(buffer *buf.Buffer, addrPort *M.AddrPort) error
|
||||
|
||||
Close() error
|
||||
LocalAddr() net.Addr
|
||||
|
@ -21,23 +21,45 @@ type PacketConn interface {
|
|||
SetWriteDeadline(t time.Time) error
|
||||
}
|
||||
|
||||
func CopyPacketConn(dest PacketConn, conn PacketConn, onAction func(size int)) error {
|
||||
type UDPConnectionHandler interface {
|
||||
NewPacketConnection(conn PacketConn, metadata M.Metadata) error
|
||||
}
|
||||
|
||||
type PacketConnStub struct{}
|
||||
|
||||
func (s *PacketConnStub) RemoteAddr() net.Addr {
|
||||
return &common.DummyAddr{}
|
||||
}
|
||||
|
||||
func (s *PacketConnStub) SetDeadline(t time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *PacketConnStub) SetReadDeadline(t time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *PacketConnStub) SetWriteDeadline(t time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func CopyPacketConn(dest PacketConn, conn PacketConn, onAction func(destination *M.AddrPort, n int)) error {
|
||||
for {
|
||||
buffer := buf.New()
|
||||
addr, port, err := conn.ReadPacket(buffer)
|
||||
destination, err := conn.ReadPacket(buffer)
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
return err
|
||||
}
|
||||
size := buffer.Len()
|
||||
err = dest.WritePacket(buffer, addr, port)
|
||||
err = dest.WritePacket(buffer, destination)
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
return err
|
||||
}
|
||||
if onAction != nil {
|
||||
onAction(size)
|
||||
onAction(destination, size)
|
||||
}
|
||||
buffer.Reset()
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -58,22 +80,22 @@ func (c *associatePacketConn) RemoteAddr() net.Addr {
|
|||
return c.addr
|
||||
}
|
||||
|
||||
func (c *associatePacketConn) ReadPacket(buffer *buf.Buffer) (socksaddr.Addr, uint16, error) {
|
||||
func (c *associatePacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
|
||||
n, addr, err := c.PacketConn.ReadFrom(buffer.FreeBytes())
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
return nil, err
|
||||
}
|
||||
c.addr = addr
|
||||
buffer.Truncate(n)
|
||||
buffer.Advance(3)
|
||||
return AddressSerializer.ReadAddressAndPort(buffer)
|
||||
return AddressSerializer.ReadAddrPort(buffer)
|
||||
}
|
||||
|
||||
func (c *associatePacketConn) WritePacket(buffer *buf.Buffer, addr socksaddr.Addr, port uint16) error {
|
||||
func (c *associatePacketConn) WritePacket(buffer *buf.Buffer, addrPort *M.AddrPort) error {
|
||||
defer buffer.Release()
|
||||
header := buf.New()
|
||||
common.Must(header.WriteZeroN(3))
|
||||
common.Must(AddressSerializer.WriteAddressAndPort(header, addr, port))
|
||||
common.Must(AddressSerializer.WriteAddrPort(header, addrPort))
|
||||
buffer = buffer.WriteBufferAtFirst(header)
|
||||
return common.Error(c.PacketConn.WriteTo(buffer.Bytes(), c.addr))
|
||||
}
|
||||
|
|
|
@ -3,7 +3,7 @@ package socks
|
|||
import (
|
||||
"strconv"
|
||||
|
||||
"github.com/sagernet/sing/common/socksaddr"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -69,8 +69,8 @@ func (code ReplyCode) String() string {
|
|||
}
|
||||
}
|
||||
|
||||
var AddressSerializer = socksaddr.NewSerializer(
|
||||
socksaddr.AddressFamilyByte(0x01, socksaddr.AddressFamilyIPv4),
|
||||
socksaddr.AddressFamilyByte(0x04, socksaddr.AddressFamilyIPv6),
|
||||
socksaddr.AddressFamilyByte(0x03, socksaddr.AddressFamilyFqdn),
|
||||
var AddressSerializer = M.NewSerializer(
|
||||
M.AddressFamilyByte(0x01, M.AddressFamilyIPv4),
|
||||
M.AddressFamilyByte(0x04, M.AddressFamilyIPv6),
|
||||
M.AddressFamilyByte(0x03, M.AddressFamilyFqdn),
|
||||
)
|
||||
|
|
|
@ -4,11 +4,11 @@ import (
|
|||
"io"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/socksaddr"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
)
|
||||
|
||||
func ClientHandshake(conn io.ReadWriter, version byte, command byte, addr socksaddr.Addr, port uint16, username string, password string) (*Response, error) {
|
||||
func ClientHandshake(conn io.ReadWriter, version byte, command byte, destination *M.AddrPort, username string, password string) (*Response, error) {
|
||||
var method byte
|
||||
if common.IsBlank(username) {
|
||||
method = AuthTypeNotRequired
|
||||
|
@ -27,7 +27,7 @@ func ClientHandshake(conn io.ReadWriter, version byte, command byte, addr socksa
|
|||
return nil, err
|
||||
}
|
||||
if authResponse.Method != method {
|
||||
return nil, exceptions.New("not requested method, request ", method, ", return ", method)
|
||||
return nil, E.New("not requested method, request ", method, ", return ", method)
|
||||
}
|
||||
if method == AuthTypeUsernamePassword {
|
||||
err = WriteUsernamePasswordAuthRequest(conn, &UsernamePasswordAuthRequest{
|
||||
|
@ -46,10 +46,9 @@ func ClientHandshake(conn io.ReadWriter, version byte, command byte, addr socksa
|
|||
}
|
||||
}
|
||||
err = WriteRequest(conn, &Request{
|
||||
Version: version,
|
||||
Command: command,
|
||||
Addr: addr,
|
||||
Port: port,
|
||||
Version: version,
|
||||
Command: command,
|
||||
Destination: destination,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -57,7 +56,7 @@ func ClientHandshake(conn io.ReadWriter, version byte, command byte, addr socksa
|
|||
return ReadResponse(conn)
|
||||
}
|
||||
|
||||
func ClientFastHandshake(writer io.Writer, version byte, command byte, addr socksaddr.Addr, port uint16, username string, password string) error {
|
||||
func ClientFastHandshake(writer io.Writer, version byte, command byte, destination *M.AddrPort, username string, password string) error {
|
||||
var method byte
|
||||
if common.IsBlank(username) {
|
||||
method = AuthTypeNotRequired
|
||||
|
@ -81,10 +80,9 @@ func ClientFastHandshake(writer io.Writer, version byte, command byte, addr sock
|
|||
}
|
||||
}
|
||||
return WriteRequest(writer, &Request{
|
||||
Version: version,
|
||||
Command: command,
|
||||
Addr: addr,
|
||||
Port: port,
|
||||
Version: version,
|
||||
Command: command,
|
||||
Destination: destination,
|
||||
})
|
||||
}
|
||||
|
||||
|
|
148
protocol/socks/listener.go
Normal file
148
protocol/socks/listener.go
Normal file
|
@ -0,0 +1,148 @@
|
|||
package socks
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/auth"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
"github.com/sagernet/sing/transport/tcp"
|
||||
)
|
||||
|
||||
type Handler interface {
|
||||
tcp.Handler
|
||||
UDPConnectionHandler
|
||||
}
|
||||
|
||||
type Listener struct {
|
||||
tcpListener *tcp.Listener
|
||||
authenticator auth.Authenticator
|
||||
handler Handler
|
||||
}
|
||||
|
||||
func NewListener(bind netip.AddrPort, authenticator auth.Authenticator, handler Handler) *Listener {
|
||||
listener := &Listener{
|
||||
handler: handler,
|
||||
authenticator: authenticator,
|
||||
}
|
||||
listener.tcpListener = tcp.NewTCPListener(bind, listener)
|
||||
return listener
|
||||
}
|
||||
|
||||
func (l *Listener) NewConnection(conn net.Conn, metadata M.Metadata) error {
|
||||
return HandleConnection(conn, l.authenticator, l.handler)
|
||||
}
|
||||
|
||||
func (l *Listener) Start() error {
|
||||
return l.tcpListener.Start()
|
||||
}
|
||||
|
||||
func (l *Listener) Close() error {
|
||||
return l.tcpListener.Close()
|
||||
}
|
||||
|
||||
func (l *Listener) HandleError(err error) {
|
||||
l.handler.HandleError(err)
|
||||
}
|
||||
|
||||
func HandleConnection(conn net.Conn, authenticator auth.Authenticator, handler Handler) error {
|
||||
authRequest, err := ReadAuthRequest(conn)
|
||||
if err != nil {
|
||||
return E.Cause(err, "read socks auth request")
|
||||
}
|
||||
var authMethod byte
|
||||
if authenticator == nil {
|
||||
authMethod = AuthTypeNotRequired
|
||||
} else {
|
||||
authMethod = AuthTypeUsernamePassword
|
||||
}
|
||||
if !common.Contains(authRequest.Methods, authMethod) {
|
||||
err = WriteAuthResponse(conn, &AuthResponse{
|
||||
Version: authRequest.Version,
|
||||
Method: AuthTypeNoAcceptedMethods,
|
||||
})
|
||||
if err != nil {
|
||||
return E.Cause(err, "write socks auth response")
|
||||
}
|
||||
}
|
||||
err = WriteAuthResponse(conn, &AuthResponse{
|
||||
Version: authRequest.Version,
|
||||
Method: AuthTypeNotRequired,
|
||||
})
|
||||
if err != nil {
|
||||
return E.Cause(err, "write socks auth response")
|
||||
}
|
||||
|
||||
if authMethod == AuthTypeUsernamePassword {
|
||||
usernamePasswordAuthRequest, err := ReadUsernamePasswordAuthRequest(conn)
|
||||
if err != nil {
|
||||
return E.Cause(err, "read user auth request")
|
||||
}
|
||||
response := new(UsernamePasswordAuthResponse)
|
||||
if authenticator.Verify(usernamePasswordAuthRequest.Username, usernamePasswordAuthRequest.Password) {
|
||||
response.Status = UsernamePasswordStatusSuccess
|
||||
} else {
|
||||
response.Status = UsernamePasswordStatusFailure
|
||||
}
|
||||
err = WriteUsernamePasswordAuthResponse(conn, response)
|
||||
if err != nil {
|
||||
return E.Cause(err, "write user auth response")
|
||||
}
|
||||
}
|
||||
|
||||
request, err := ReadRequest(conn)
|
||||
if err != nil {
|
||||
return E.Cause(err, "read socks request")
|
||||
}
|
||||
switch request.Command {
|
||||
case CommandConnect:
|
||||
err = WriteResponse(conn, &Response{
|
||||
Version: request.Version,
|
||||
ReplyCode: ReplyCodeSuccess,
|
||||
Bind: M.AddrPortFromNetAddr(conn.LocalAddr()),
|
||||
})
|
||||
if err != nil {
|
||||
return E.Cause(err, "write socks response")
|
||||
}
|
||||
return handler.NewConnection(conn, M.Metadata{
|
||||
Destination: request.Destination,
|
||||
})
|
||||
case CommandUDPAssociate:
|
||||
udpConn, err := net.ListenUDP("udp", nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer udpConn.Close()
|
||||
err = WriteResponse(conn, &Response{
|
||||
Version: request.Version,
|
||||
ReplyCode: ReplyCodeSuccess,
|
||||
Bind: M.AddrPortFromNetAddr(udpConn.LocalAddr()),
|
||||
})
|
||||
if err != nil {
|
||||
return E.Cause(err, "write socks response")
|
||||
}
|
||||
go func() {
|
||||
err := handler.NewPacketConnection(NewPacketConn(conn, udpConn), M.Metadata{
|
||||
Source: M.AddrPortFromNetAddr(conn.RemoteAddr()),
|
||||
Destination: request.Destination,
|
||||
})
|
||||
if err != nil {
|
||||
handler.HandleError(err)
|
||||
}
|
||||
conn.Close()
|
||||
}()
|
||||
return common.Error(io.Copy(io.Discard, conn))
|
||||
default:
|
||||
err = WriteResponse(conn, &Response{
|
||||
Version: request.Version,
|
||||
ReplyCode: ReplyCodeUnsupported,
|
||||
})
|
||||
if err != nil {
|
||||
return E.Cause(err, "write response")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -7,9 +7,9 @@ import (
|
|||
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
"github.com/sagernet/sing/common/exceptions"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
"github.com/sagernet/sing/common/rw"
|
||||
"github.com/sagernet/sing/common/socksaddr"
|
||||
)
|
||||
|
||||
//+----+----------+----------+
|
||||
|
@ -45,11 +45,11 @@ func ReadAuthRequest(reader io.Reader) (*AuthRequest, error) {
|
|||
}
|
||||
methodLen, err := rw.ReadByte(reader)
|
||||
if err != nil {
|
||||
return nil, exceptions.Cause(err, "read socks auth methods length")
|
||||
return nil, E.Cause(err, "read socks auth methods length")
|
||||
}
|
||||
methods, err := rw.ReadBytes(reader, int(methodLen))
|
||||
if err != nil {
|
||||
return nil, exceptions.CauseF(err, "read socks auth methods, length ", methodLen)
|
||||
return nil, E.CauseF(err, "read socks auth methods, length ", methodLen)
|
||||
}
|
||||
request := &AuthRequest{
|
||||
version,
|
||||
|
@ -112,11 +112,11 @@ func WriteUsernamePasswordAuthRequest(writer io.Writer, request *UsernamePasswor
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = socksaddr.WriteString(writer, "username", request.Username)
|
||||
err = M.WriteString(writer, "username", request.Username)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return socksaddr.WriteString(writer, "password", request.Password)
|
||||
return M.WriteString(writer, "password", request.Password)
|
||||
}
|
||||
|
||||
func ReadUsernamePasswordAuthRequest(reader io.Reader) (*UsernamePasswordAuthRequest, error) {
|
||||
|
@ -127,11 +127,11 @@ func ReadUsernamePasswordAuthRequest(reader io.Reader) (*UsernamePasswordAuthReq
|
|||
if version != UsernamePasswordVersion1 {
|
||||
return nil, &UnsupportedVersionException{version}
|
||||
}
|
||||
username, err := socksaddr.ReadString(reader)
|
||||
username, err := M.ReadString(reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
password, err := socksaddr.ReadString(reader)
|
||||
password, err := M.ReadString(reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -185,10 +185,9 @@ func ReadUsernamePasswordAuthResponse(reader io.Reader) (*UsernamePasswordAuthRe
|
|||
//+----+-----+-------+------+----------+----------+
|
||||
|
||||
type Request struct {
|
||||
Version byte
|
||||
Command byte
|
||||
Addr socksaddr.Addr
|
||||
Port uint16
|
||||
Version byte
|
||||
Command byte
|
||||
Destination *M.AddrPort
|
||||
}
|
||||
|
||||
func WriteRequest(writer io.Writer, request *Request) error {
|
||||
|
@ -204,7 +203,7 @@ func WriteRequest(writer io.Writer, request *Request) error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return AddressSerializer.WriteAddressAndPort(writer, request.Addr, request.Port)
|
||||
return AddressSerializer.WriteAddrPort(writer, request.Destination)
|
||||
}
|
||||
|
||||
func ReadRequest(reader io.Reader) (*Request, error) {
|
||||
|
@ -226,15 +225,14 @@ func ReadRequest(reader io.Reader) (*Request, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
addr, port, err := AddressSerializer.ReadAddressAndPort(reader)
|
||||
addrPort, err := AddressSerializer.ReadAddrPort(reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
request := &Request{
|
||||
Version: version,
|
||||
Command: command,
|
||||
Addr: addr,
|
||||
Port: port,
|
||||
Version: version,
|
||||
Command: command,
|
||||
Destination: addrPort,
|
||||
}
|
||||
return request, nil
|
||||
}
|
||||
|
@ -248,8 +246,7 @@ func ReadRequest(reader io.Reader) (*Request, error) {
|
|||
type Response struct {
|
||||
Version byte
|
||||
ReplyCode ReplyCode
|
||||
BindAddr socksaddr.Addr
|
||||
BindPort uint16
|
||||
Bind *M.AddrPort
|
||||
}
|
||||
|
||||
func WriteResponse(writer io.Writer, response *Response) error {
|
||||
|
@ -265,10 +262,10 @@ func WriteResponse(writer io.Writer, response *Response) error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if response.BindAddr == nil {
|
||||
return AddressSerializer.WriteAddressAndPort(writer, socksaddr.AddrFromIP(net.IPv4zero), response.BindPort)
|
||||
if response.Bind == nil {
|
||||
return AddressSerializer.WriteAddrPort(writer, M.AddrPortFrom(M.AddrFromIP(net.IPv4zero), 0))
|
||||
}
|
||||
return AddressSerializer.WriteAddressAndPort(writer, response.BindAddr, response.BindPort)
|
||||
return AddressSerializer.WriteAddrPort(writer, response.Bind)
|
||||
}
|
||||
|
||||
func ReadResponse(reader io.Reader) (*Response, error) {
|
||||
|
@ -287,15 +284,14 @@ func ReadResponse(reader io.Reader) (*Response, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
addr, port, err := AddressSerializer.ReadAddressAndPort(reader)
|
||||
addrPort, err := AddressSerializer.ReadAddrPort(reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
response := &Response{
|
||||
Version: version,
|
||||
ReplyCode: ReplyCode(replyCode),
|
||||
BindAddr: addr,
|
||||
BindPort: port,
|
||||
Bind: addrPort,
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
|
@ -307,15 +303,14 @@ func ReadResponse(reader io.Reader) (*Response, error) {
|
|||
//+----+------+------+----------+----------+----------+
|
||||
|
||||
type AssociatePacket struct {
|
||||
Fragment byte
|
||||
Addr socksaddr.Addr
|
||||
Port uint16
|
||||
Data []byte
|
||||
Fragment byte
|
||||
Destination *M.AddrPort
|
||||
Data []byte
|
||||
}
|
||||
|
||||
func DecodeAssociatePacket(buffer *buf.Buffer) (*AssociatePacket, error) {
|
||||
if buffer.Len() < 5 {
|
||||
return nil, exceptions.New("insufficient length")
|
||||
return nil, E.New("insufficient length")
|
||||
}
|
||||
fragment := buffer.Byte(2)
|
||||
reader := bytes.NewReader(buffer.Bytes())
|
||||
|
@ -323,16 +318,15 @@ func DecodeAssociatePacket(buffer *buf.Buffer) (*AssociatePacket, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
addr, port, err := AddressSerializer.ReadAddressAndPort(reader)
|
||||
addrPort, err := AddressSerializer.ReadAddrPort(reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
buffer.Advance(reader.Len())
|
||||
packet := &AssociatePacket{
|
||||
Fragment: fragment,
|
||||
Addr: addr,
|
||||
Port: port,
|
||||
Data: buffer.Bytes(),
|
||||
Fragment: fragment,
|
||||
Destination: addrPort,
|
||||
Data: buffer.Bytes(),
|
||||
}
|
||||
return packet, nil
|
||||
}
|
||||
|
@ -346,7 +340,7 @@ func EncodeAssociatePacket(packet *AssociatePacket, buffer *buf.Buffer) error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = AddressSerializer.WriteAddressAndPort(buffer, packet.Addr, packet.Port)
|
||||
err = AddressSerializer.WriteAddrPort(buffer, packet.Destination)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -5,7 +5,7 @@ import (
|
|||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/sagernet/sing/common/socksaddr"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
"github.com/sagernet/sing/protocol/socks"
|
||||
)
|
||||
|
||||
|
@ -20,7 +20,7 @@ func TestHandshake(t *testing.T) {
|
|||
method := socks.AuthTypeUsernamePassword
|
||||
|
||||
go func() {
|
||||
response, err := socks.ClientHandshake(client, socks.Version5, socks.CommandConnect, socksaddr.AddrFromFqdn("test"), 80, "user", "pswd")
|
||||
response, err := socks.ClientHandshake(client, socks.Version5, socks.CommandConnect, M.AddrPortFrom(M.AddrFromFqdn("test"), 80), "user", "pswd")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -60,14 +60,13 @@ func TestHandshake(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if request.Version != socks.Version5 || request.Command != socks.CommandConnect || request.Addr.Fqdn() != "test" || request.Port != 80 {
|
||||
if request.Version != socks.Version5 || request.Command != socks.CommandConnect || request.Destination.Addr.Fqdn() != "test" || request.Destination.Port != 80 {
|
||||
t.Fatal(request)
|
||||
}
|
||||
err = socks.WriteResponse(server, &socks.Response{
|
||||
Version: socks.Version5,
|
||||
ReplyCode: socks.ReplyCodeSuccess,
|
||||
BindAddr: socksaddr.AddrFromIP(net.IPv4zero),
|
||||
BindPort: 0,
|
||||
Bind: M.AddrPortFrom(M.AddrFromIP(net.IPv4zero), 0),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue