Refactor socks

This commit is contained in:
世界 2022-05-10 22:24:09 +08:00
parent 788e4e2658
commit ed28a5714f
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
30 changed files with 980 additions and 978 deletions

View file

@ -17,7 +17,8 @@ import (
"github.com/sagernet/sing/common/buf"
_ "github.com/sagernet/sing/common/log"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/protocol/socks5"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/protocol/socks"
"github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"golang.org/x/net/dns/dnsmessage"
@ -25,7 +26,7 @@ import (
func main() {
command := &cobra.Command{
Use: "socks-chk address:port",
Use: "socks-chk [socks4/4a/5://]address:port",
Args: cobra.ExactArgs(1),
Run: run,
}
@ -35,33 +36,29 @@ func main() {
}
func run(cmd *cobra.Command, args []string) {
server := M.ParseSocksaddr(args[0])
err := testSocksTCP(server)
client, err := socks.NewClientFromURL(N.SystemDialer, args[0])
if err != nil {
logrus.Fatal(err)
}
err = testSocksUDP(server)
err = testSocksTCP(client)
if err != nil {
logrus.Fatal(err)
}
err = testSocksQuic(server)
err = testSocksUDP(client)
if err != nil {
logrus.Fatal(err)
}
err = testSocksQuic(client)
if err != nil {
logrus.Fatal(err)
}
}
func testSocksTCP(server M.Socksaddr) error {
tcpConn, err := net.Dial("tcp", server.String())
func testSocksTCP(client *socks.Client) error {
tcpConn, err := client.DialContext(context.Background(), "tcp", M.ParseSocksaddrHostPort("1.0.0.1", "53"))
if err != nil {
return err
}
response, err := socks5.ClientHandshake(tcpConn, socks5.Version5, socks5.CommandConnect, M.ParseSocksaddrHostPort("1.0.0.1", "53"), "", "")
if err != nil {
return err
}
if response.ReplyCode != socks5.ReplyCodeSuccess {
logrus.Fatal("socks tcp handshake failure: ", response.ReplyCode)
}
message := &dnsmessage.Message{}
message.Header.ID = 1
@ -105,25 +102,12 @@ func testSocksTCP(server M.Socksaddr) error {
return nil
}
func testSocksUDP(server M.Socksaddr) error {
tcpConn, err := net.Dial("tcp", server.String())
func testSocksUDP(client *socks.Client) error {
udpConn, err := client.DialContext(context.Background(), "udp", M.ParseSocksaddrHostPort("1.0.0.1", "53"))
if err != nil {
return err
}
dest := M.ParseSocksaddrHostPort("1.0.0.1", "53")
response, err := socks5.ClientHandshake(tcpConn, socks5.Version5, socks5.CommandUDPAssociate, dest, "", "")
if err != nil {
return err
}
if response.ReplyCode != socks5.ReplyCodeSuccess {
logrus.Fatal("socks tcp handshake failure: ", response.ReplyCode)
}
var dialer net.Dialer
udpConn, err := dialer.DialContext(context.Background(), "udp", response.Bind.String())
if err != nil {
return err
}
assConn := socks5.NewAssociateConn(tcpConn, udpConn, dest)
message := &dnsmessage.Message{}
message.Header.ID = 1
message.Header.RecursionDesired = true
@ -134,14 +118,11 @@ func testSocksUDP(server M.Socksaddr) error {
})
packet, err := message.Pack()
common.Must(err)
common.Must1(assConn.WriteTo(packet, &net.UDPAddr{
IP: net.IPv4(1, 0, 0, 1),
Port: 53,
}))
common.Must1(udpConn.Write(packet))
_buffer := buf.StackNew()
defer runtime.KeepAlive(_buffer)
buffer := common.Dup(_buffer)
common.Must2(buffer.ReadPacketFrom(assConn))
common.Must1(buffer.ReadFrom(udpConn))
common.Must(message.Unpack(buffer.Bytes()))
for _, answer := range message.Answers {
@ -149,44 +130,26 @@ func testSocksUDP(server M.Socksaddr) error {
}
udpConn.Close()
tcpConn.Close()
return nil
}
func testSocksQuic(server M.Socksaddr) error {
client := &http.Client{
func testSocksQuic(client *socks.Client) error {
httpClient := &http.Client{
Transport: &http3.RoundTripper{
Dial: func(ctx context.Context, network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
udpAddr, err := net.ResolveUDPAddr(network, addr)
if err != nil {
return nil, err
}
tcpConn, err := net.Dial("tcp", server.String())
conn, err := client.DialContext(context.Background(), network, M.SocksaddrFromNet(udpAddr))
if err != nil {
return nil, err
}
destination := M.SocksaddrFromNetIP(udpAddr.AddrPort())
response, err := socks5.ClientHandshake(tcpConn, socks5.Version5, socks5.CommandUDPAssociate, destination, "", "")
if err != nil {
return nil, err
}
if response.ReplyCode != socks5.ReplyCodeSuccess {
logrus.Fatal("socks tcp handshake failure: ", response.ReplyCode)
}
var dialer net.Dialer
udpConn, err := dialer.DialContext(context.Background(), "udp", response.Bind.String())
if err != nil {
return nil, err
}
assConn := socks5.NewAssociateConn(tcpConn, udpConn, destination)
host := M.ParseSocksaddr(addr).AddrString()
logrus.Trace(host)
return quic.DialEarlyContext(ctx, assConn, udpAddr, host, tlsCfg, cfg)
return quic.DialEarlyContext(ctx, conn.(net.PacketConn), udpAddr, M.ParseSocksaddr(addr).AddrString(), tlsCfg, cfg)
},
},
}
// qResponse, err := client.Get("https://cloudflare.com/cdn-cgi/trace")
qResponse, err := client.Get("https://cloudflare.com/cdn-cgi/trace")
qResponse, err := httpClient.Get("https://cloudflare.com/cdn-cgi/trace")
if err != nil {
return err
}

View file

@ -17,7 +17,7 @@ import (
"github.com/sagernet/sing/common/redir"
"github.com/sagernet/sing/common/rw"
"github.com/sagernet/sing/common/uot"
"github.com/sagernet/sing/protocol/socks5"
"github.com/sagernet/sing/protocol/socks"
"github.com/sagernet/sing/transport/mixed"
"github.com/sirupsen/logrus"
"github.com/spf13/cobra"
@ -73,7 +73,12 @@ func run(cmd *cobra.Command, args []string) {
logrus.Fatal("unknown transproxy mode ", transproxy)
}
client := &localClient{upstream: args[1]}
socks, err := socks.NewClientFromURL(N.SystemDialer, args[1])
if err != nil {
logrus.Fatal(err)
}
client := &localClient{upstream: socks}
client.Listener = mixed.NewListener(bind, nil, transproxyMode, 300, client)
err = client.Start()
@ -92,38 +97,29 @@ func run(cmd *cobra.Command, args []string) {
type localClient struct {
*mixed.Listener
upstream string
upstream *socks.Client
}
func (c *localClient) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
logrus.Info("CONNECT ", conn.RemoteAddr(), " ==> ", metadata.Destination)
logrus.Info("inbound ", metadata.Protocol, " TCP ", metadata.Source, " ==> ", metadata.Destination)
upstream, err := net.Dial("tcp", c.upstream)
upstream, err := c.upstream.DialContext(ctx, "tcp", metadata.Destination)
if err != nil {
return E.Cause(err, "connect to upstream")
return err
}
_, err = socks5.ClientHandshake(upstream, socks5.Version5, socks5.CommandConnect, metadata.Destination, "", "")
if err != nil {
return E.Cause(err, "upstream handshake failed")
}
return rw.CopyConn(context.Background(), upstream, conn)
return rw.CopyConn(context.Background(), conn, upstream)
}
func (c *localClient) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata M.Metadata) error {
upstream, err := net.Dial("tcp", c.upstream)
logrus.Info("inbound ", metadata.Protocol, " UDP ", metadata.Source, " ==> ", metadata.Destination)
upstream, err := c.upstream.DialContext(ctx, "tcp", metadata.Destination)
if err != nil {
return E.Cause(err, "connect to upstream")
return err
}
_, err = socks5.ClientHandshake(upstream, socks5.Version5, socks5.CommandConnect, M.ParseSocksaddrHostPort(uot.UOTMagicAddress, "443"), "", "")
if err != nil {
return E.Cause(err, "upstream handshake failed")
}
client := uot.NewClientConn(upstream)
return N.CopyPacketConn(ctx, client, conn)
return N.CopyPacketConn(ctx, conn, uot.NewClientConn(upstream))
}
func (c *localClient) OnError(err error) {

View file

@ -43,7 +43,7 @@ func New(message ...any) error {
}
func Cause(cause error, message ...any) Exception {
return exception{fmt.Sprint(message), cause}
return exception{fmt.Sprint(message...), cause}
}
func IsClosed(err error) bool {

View file

@ -30,19 +30,13 @@ func (ap Socksaddr) IsValid() bool {
func (ap Socksaddr) Family() Family {
if ap.Addr.IsValid() {
if ap.Addr.Is4() {
if ap.Addr.Is4() || ap.Addr.Is4In6() {
return AddressFamilyIPv4
} else {
return AddressFamilyIPv6
}
}
if ap.Fqdn != "" {
return AddressFamilyFqdn
} else if ap.Addr.Is4() || ap.Addr.Is4In6() {
return AddressFamilyIPv4
} else {
return AddressFamilyIPv6
}
return AddressFamilyFqdn
}
func (ap Socksaddr) AddrString() string {

View file

@ -1,12 +0,0 @@
package metadata
import "fmt"
type StringTooLongException struct {
Op string
Len int
}
func (e StringTooLongException) Error() string {
return fmt.Sprint(e.Op, " too long: length ", e.Len, ", max 255")
}

View file

@ -50,7 +50,7 @@ func (s *Serializer) WriteAddress(writer io.Writer, addr Socksaddr) error {
if addr.Addr.IsValid() {
err = rw.WriteBytes(writer, addr.Addr.AsSlice())
} else {
err = WriteString(writer, "fqdn", addr.Fqdn)
err = WriteSocksString(writer, addr.Fqdn)
}
return err
}
@ -100,7 +100,7 @@ func (s *Serializer) ReadAddress(reader io.Reader) (Socksaddr, error) {
family := s.familyMap[af]
switch family {
case AddressFamilyFqdn:
fqdn, err := ReadString(reader)
fqdn, err := ReadSockString(reader)
if err != nil {
return Socksaddr{}, E.Cause(err, "read fqdn")
}
@ -160,7 +160,7 @@ func (s *Serializer) ReadAddrPort(reader io.Reader) (destination Socksaddr, err
return addr, nil
}
func ReadString(reader io.Reader) (string, error) {
func ReadSockString(reader io.Reader) (string, error) {
strLen, err := rw.ReadByte(reader)
if err != nil {
return "", err
@ -168,10 +168,10 @@ func ReadString(reader io.Reader) (string, error) {
return rw.ReadString(reader, int(strLen))
}
func WriteString(writer io.Writer, op string, str string) error {
func WriteSocksString(writer io.Writer, str string) error {
strLen := len(str)
if strLen > 255 {
return &StringTooLongException{op, strLen}
return E.New("fqdn too long")
}
err := rw.WriteByte(writer, byte(strLen))
if err != nil {

View file

@ -0,0 +1,7 @@
package metadata
var SocksaddrSerializer = NewSerializer(
AddressFamilyByte(0x01, AddressFamilyIPv4),
AddressFamilyByte(0x04, AddressFamilyIPv6),
AddressFamilyByte(0x03, AddressFamilyFqdn),
)

View file

@ -14,7 +14,6 @@ import (
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/rw"
"github.com/sagernet/sing/common/udpnat"
"github.com/sagernet/sing/protocol/socks5"
)
const MethodNone = "none"
@ -62,7 +61,7 @@ type noneConn struct {
}
func (c *noneConn) clientHandshake() error {
err := socks5.AddressSerializer.WriteAddrPort(c.Conn, c.destination)
err := M.SocksaddrSerializer.WriteAddrPort(c.Conn, c.destination)
if err != nil {
return err
}
@ -90,7 +89,7 @@ func (c *noneConn) Write(b []byte) (n int, err error) {
_buffer := buf.StackNew()
buffer := common.Dup(_buffer)
err = socks5.AddressSerializer.WriteAddrPort(buffer, c.destination)
err = M.SocksaddrSerializer.WriteAddrPort(buffer, c.destination)
if err != nil {
return
}
@ -140,12 +139,12 @@ func (c *nonePacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
if err != nil {
return M.Socksaddr{}, err
}
return socks5.AddressSerializer.ReadAddrPort(buffer)
return M.SocksaddrSerializer.ReadAddrPort(buffer)
}
func (c *nonePacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
header := buf.With(buffer.ExtendHeader(socks5.AddressSerializer.AddrPortLen(destination)))
err := socks5.AddressSerializer.WriteAddrPort(header, destination)
header := buf.With(buffer.ExtendHeader(M.SocksaddrSerializer.AddrPortLen(destination)))
err := M.SocksaddrSerializer.WriteAddrPort(header, destination)
if err != nil {
return err
}
@ -158,7 +157,7 @@ func (c *nonePacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
return
}
buffer := buf.With(p[:n])
destination, err := socks5.AddressSerializer.ReadAddrPort(buffer)
destination, err := M.SocksaddrSerializer.ReadAddrPort(buffer)
if err != nil {
return
}
@ -169,10 +168,10 @@ func (c *nonePacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
func (c *nonePacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
destination := M.SocksaddrFromNet(addr)
_buffer := buf.Make(socks5.AddressSerializer.AddrPortLen(destination) + len(p))
_buffer := buf.Make(M.SocksaddrSerializer.AddrPortLen(destination) + len(p))
defer runtime.KeepAlive(_buffer)
buffer := buf.With(common.Dup(_buffer))
err = socks5.AddressSerializer.WriteAddrPort(buffer, destination)
err = M.SocksaddrSerializer.WriteAddrPort(buffer, destination)
if err != nil {
return
}
@ -197,7 +196,7 @@ func NewNoneService(udpTimeout int64, handler Handler) Service {
}
func (s *NoneService) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
destination, err := socks5.AddressSerializer.ReadAddrPort(conn)
destination, err := M.SocksaddrSerializer.ReadAddrPort(conn)
if err != nil {
return err
}
@ -207,7 +206,7 @@ func (s *NoneService) NewConnection(ctx context.Context, conn net.Conn, metadata
}
func (s *NoneService) NewPacket(conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error {
destination, err := socks5.AddressSerializer.ReadAddrPort(buffer)
destination, err := M.SocksaddrSerializer.ReadAddrPort(buffer)
if err != nil {
return err
}
@ -225,8 +224,8 @@ type nonePacketWriter struct {
}
func (s *nonePacketWriter) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
header := buf.With(buffer.ExtendHeader(socks5.AddressSerializer.AddrPortLen(destination)))
err := socks5.AddressSerializer.WriteAddrPort(header, destination)
header := buf.With(buffer.ExtendHeader(M.SocksaddrSerializer.AddrPortLen(destination)))
err := M.SocksaddrSerializer.WriteAddrPort(header, destination)
if err != nil {
return err
}

View file

@ -15,7 +15,6 @@ import (
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/rw"
"github.com/sagernet/sing/protocol/shadowsocks"
"github.com/sagernet/sing/protocol/socks5"
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/crypto/hkdf"
)
@ -208,7 +207,7 @@ func (c *clientConn) writeRequest(payload []byte) error {
bufferedWriter := writer.BufferedWriter(header.Len())
if len(payload) > 0 {
err := socks5.AddressSerializer.WriteAddrPort(bufferedWriter, c.destination)
err := M.SocksaddrSerializer.WriteAddrPort(bufferedWriter, c.destination)
if err != nil {
return err
}
@ -218,7 +217,7 @@ func (c *clientConn) writeRequest(payload []byte) error {
return err
}
} else {
err := socks5.AddressSerializer.WriteAddrPort(bufferedWriter, c.destination)
err := M.SocksaddrSerializer.WriteAddrPort(bufferedWriter, c.destination)
if err != nil {
return err
}
@ -315,9 +314,9 @@ type clientPacketConn struct {
}
func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
header := buffer.ExtendHeader(c.keySaltLength + socks5.AddressSerializer.AddrPortLen(destination))
header := buffer.ExtendHeader(c.keySaltLength + M.SocksaddrSerializer.AddrPortLen(destination))
common.Must1(io.ReadFull(c.secureRNG, header[:c.keySaltLength]))
err := socks5.AddressSerializer.WriteAddrPort(buf.With(header[c.keySaltLength:]), destination)
err := M.SocksaddrSerializer.WriteAddrPort(buf.With(header[c.keySaltLength:]), destination)
if err != nil {
return err
}
@ -338,7 +337,7 @@ func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
if err != nil {
return M.Socksaddr{}, err
}
return socks5.AddressSerializer.ReadAddrPort(buffer)
return M.SocksaddrSerializer.ReadAddrPort(buffer)
}
func (c *clientPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
@ -351,7 +350,7 @@ func (c *clientPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error)
if err != nil {
return
}
destination, err := socks5.AddressSerializer.ReadAddrPort(b)
destination, err := M.SocksaddrSerializer.ReadAddrPort(b)
if err != nil {
return
}
@ -364,7 +363,7 @@ func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
_buffer := buf.StackNew()
defer runtime.KeepAlive(_buffer)
buffer := common.Dup(_buffer)
err = socks5.AddressSerializer.WriteAddrPort(buffer, M.SocksaddrFromNet(addr))
err = M.SocksaddrSerializer.WriteAddrPort(buffer, M.SocksaddrFromNet(addr))
if err != nil {
return
}

View file

@ -17,7 +17,6 @@ import (
"github.com/sagernet/sing/common/rw"
"github.com/sagernet/sing/common/udpnat"
"github.com/sagernet/sing/protocol/shadowsocks"
"github.com/sagernet/sing/protocol/socks5"
"golang.org/x/crypto/chacha20poly1305"
)
@ -95,7 +94,7 @@ func (s *Service) newConnection(ctx context.Context, conn net.Conn, metadata M.M
key := Kdf(s.key, salt, s.keySaltLength)
reader := NewReader(conn, s.constructor(common.Dup(key)), MaxPacketSize)
destination, err := socks5.AddressSerializer.ReadAddrPort(reader)
destination, err := M.SocksaddrSerializer.ReadAddrPort(reader)
if err != nil {
return err
}
@ -221,7 +220,7 @@ func (s *Service) newPacket(conn N.PacketConn, buffer *buf.Buffer, metadata M.Me
buffer.Advance(s.keySaltLength)
buffer.Truncate(len(packet))
destination, err := socks5.AddressSerializer.ReadAddrPort(buffer)
destination, err := M.SocksaddrSerializer.ReadAddrPort(buffer)
if err != nil {
return err
}
@ -241,9 +240,9 @@ type serverPacketWriter struct {
}
func (w *serverPacketWriter) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
header := buffer.ExtendHeader(w.keySaltLength + socks5.AddressSerializer.AddrPortLen(destination))
header := buffer.ExtendHeader(w.keySaltLength + M.SocksaddrSerializer.AddrPortLen(destination))
common.Must1(io.ReadFull(w.secureRNG, header[:w.keySaltLength]))
err := socks5.AddressSerializer.WriteAddrPort(buf.With(header[w.keySaltLength:]), destination)
err := M.SocksaddrSerializer.WriteAddrPort(buf.With(header[w.keySaltLength:]), destination)
if err != nil {
return err
}

View file

@ -22,7 +22,6 @@ import (
"github.com/sagernet/sing/common/rw"
"github.com/sagernet/sing/protocol/shadowsocks"
"github.com/sagernet/sing/protocol/shadowsocks/shadowaead"
"github.com/sagernet/sing/protocol/socks5"
"golang.org/x/crypto/chacha20poly1305"
wgReplay "golang.zx2c4.com/wireguard/replay"
"lukechampine.com/blake3"
@ -238,7 +237,7 @@ func (c *clientConn) writeRequest(payload []byte) error {
common.Must(rw.WriteByte(bufferedWriter, HeaderTypeClient))
common.Must(binary.Write(bufferedWriter, binary.BigEndian, uint64(time.Now().Unix())))
err := socks5.AddressSerializer.WriteAddrPort(bufferedWriter, c.destination)
err := M.SocksaddrSerializer.WriteAddrPort(bufferedWriter, c.destination)
if err != nil {
return E.Cause(err, "write destination")
}
@ -409,7 +408,7 @@ func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksad
hdrLen += 1 // header type
hdrLen += 8 // timestamp
hdrLen += 2 // padding length
hdrLen += socks5.AddressSerializer.AddrPortLen(destination)
hdrLen += M.SocksaddrSerializer.AddrPortLen(destination)
header := buf.With(buffer.ExtendHeader(hdrLen))
var dataIndex int
@ -449,7 +448,7 @@ func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksad
binary.Write(header, binary.BigEndian, uint64(time.Now().Unix())),
binary.Write(header, binary.BigEndian, uint16(0)), // padding length
)
err := socks5.AddressSerializer.WriteAddrPort(header, destination)
err := M.SocksaddrSerializer.WriteAddrPort(header, destination)
if err != nil {
return err
}
@ -580,7 +579,7 @@ func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
}
buffer.Advance(int(paddingLength))
destination, err := socks5.AddressSerializer.ReadAddrPort(buffer)
destination, err := M.SocksaddrSerializer.ReadAddrPort(buffer)
if err != nil {
return M.Socksaddr{}, err
}
@ -614,7 +613,7 @@ func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
overHead += 1 // header type
overHead += 8 // timestamp
overHead += 2 // padding length
overHead += socks5.AddressSerializer.AddrPortLen(destination)
overHead += M.SocksaddrSerializer.AddrPortLen(destination)
_buffer := buf.Make(overHead + len(p))
defer runtime.KeepAlive(_buffer)
@ -657,7 +656,7 @@ func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
binary.Write(buffer, binary.BigEndian, uint64(time.Now().Unix())),
binary.Write(buffer, binary.BigEndian, uint16(0)), // padding length
)
err = socks5.AddressSerializer.WriteAddrPort(buffer, destination)
err = M.SocksaddrSerializer.WriteAddrPort(buffer, destination)
if err != nil {
return
}

View file

@ -24,7 +24,6 @@ import (
"github.com/sagernet/sing/common/udpnat"
"github.com/sagernet/sing/protocol/shadowsocks"
"github.com/sagernet/sing/protocol/shadowsocks/shadowaead"
"github.com/sagernet/sing/protocol/socks5"
wgReplay "golang.zx2c4.com/wireguard/replay"
)
@ -124,7 +123,7 @@ func (s *Service) newConnection(ctx context.Context, conn net.Conn, metadata M.M
return ErrBadTimestamp
}
destination, err := socks5.AddressSerializer.ReadAddrPort(reader)
destination, err := M.SocksaddrSerializer.ReadAddrPort(reader)
if err != nil {
return E.Cause(err, "read destination")
}
@ -341,7 +340,7 @@ process:
}
buffer.Advance(int(paddingLength))
destination, err := socks5.AddressSerializer.ReadAddrPort(buffer)
destination, err := M.SocksaddrSerializer.ReadAddrPort(buffer)
if err != nil {
goto returnErr
}
@ -370,7 +369,7 @@ func (w *serverPacketWriter) WritePacket(buffer *buf.Buffer, destination M.Socks
hdrLen += 8 // timestamp
hdrLen += 8 // remote session id
hdrLen += 2 // padding length
hdrLen += socks5.AddressSerializer.AddrPortLen(destination)
hdrLen += M.SocksaddrSerializer.AddrPortLen(destination)
header := buf.With(buffer.ExtendHeader(hdrLen))
var dataIndex int
@ -390,7 +389,7 @@ func (w *serverPacketWriter) WritePacket(buffer *buf.Buffer, destination M.Socks
binary.Write(header, binary.BigEndian, uint16(0)), // padding length
)
err := socks5.AddressSerializer.WriteAddrPort(header, destination)
err := M.SocksaddrSerializer.WriteAddrPort(header, destination)
if err != nil {
return err
}

View file

@ -18,7 +18,6 @@ import (
"github.com/sagernet/sing/common/rw"
"github.com/sagernet/sing/protocol/shadowsocks"
"github.com/sagernet/sing/protocol/shadowsocks/shadowaead"
"github.com/sagernet/sing/protocol/socks5"
"lukechampine.com/blake3"
)
@ -147,7 +146,7 @@ func (s *MultiService[U]) newConnection(ctx context.Context, conn net.Conn, meta
return ErrBadTimestamp
}
destination, err := socks5.AddressSerializer.ReadAddrPort(reader)
destination, err := M.SocksaddrSerializer.ReadAddrPort(reader)
if err != nil {
return E.Cause(err, "read destination")
}
@ -282,7 +281,7 @@ process:
}
buffer.Advance(int(paddingLength))
destination, err := socks5.AddressSerializer.ReadAddrPort(buffer)
destination, err := M.SocksaddrSerializer.ReadAddrPort(buffer)
if err != nil {
goto returnErr
}

View file

@ -18,7 +18,6 @@ import (
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/protocol/shadowsocks"
"github.com/sagernet/sing/protocol/shadowsocks/shadowaead"
"github.com/sagernet/sing/protocol/socks5"
"golang.org/x/crypto/blowfish"
"golang.org/x/crypto/cast5"
"golang.org/x/crypto/chacha20"
@ -227,7 +226,7 @@ type clientConn struct {
}
func (c *clientConn) writeRequest(payload []byte) error {
_buffer := buf.Make(c.method.keyLength + socks5.AddressSerializer.AddrPortLen(c.destination) + len(payload))
_buffer := buf.Make(c.method.keyLength + M.SocksaddrSerializer.AddrPortLen(c.destination) + len(payload))
defer runtime.KeepAlive(_buffer)
buffer := buf.With(common.Dup(_buffer))
@ -241,7 +240,7 @@ func (c *clientConn) writeRequest(payload []byte) error {
}
runtime.KeepAlive(key)
err = socks5.AddressSerializer.WriteAddrPort(buffer, c.destination)
err = M.SocksaddrSerializer.WriteAddrPort(buffer, c.destination)
if err != nil {
return err
}
@ -326,9 +325,9 @@ type clientPacketConn struct {
}
func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
header := buf.With(buffer.ExtendHeader(c.keyLength + socks5.AddressSerializer.AddrPortLen(destination)))
header := buf.With(buffer.ExtendHeader(c.keyLength + M.SocksaddrSerializer.AddrPortLen(destination)))
common.Must1(header.ReadFullFrom(c.secureRNG, c.keyLength))
err := socks5.AddressSerializer.WriteAddrPort(header, destination)
err := M.SocksaddrSerializer.WriteAddrPort(header, destination)
if err != nil {
return err
}
@ -352,7 +351,7 @@ func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
}
stream.XORKeyStream(buffer.From(c.keyLength), buffer.From(c.keyLength))
buffer.Advance(c.keyLength)
return socks5.AddressSerializer.ReadAddrPort(buffer)
return M.SocksaddrSerializer.ReadAddrPort(buffer)
}
func (c *clientPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
@ -366,7 +365,7 @@ func (c *clientPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error)
}
buffer := buf.With(p[c.keyLength:n])
stream.XORKeyStream(buffer.Bytes(), buffer.Bytes())
destination, err := socks5.AddressSerializer.ReadAddrPort(buffer)
destination, err := M.SocksaddrSerializer.ReadAddrPort(buffer)
if err != nil {
return
}
@ -377,11 +376,11 @@ func (c *clientPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error)
func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
destination := M.SocksaddrFromNet(addr)
_buffer := buf.Make(c.keyLength + socks5.AddressSerializer.AddrPortLen(destination) + len(p))
_buffer := buf.Make(c.keyLength + M.SocksaddrSerializer.AddrPortLen(destination) + len(p))
defer runtime.KeepAlive(_buffer)
buffer := buf.With(common.Dup(_buffer))
common.Must1(buffer.ReadFullFrom(c.secureRNG, c.keyLength))
err = socks5.AddressSerializer.WriteAddrPort(buffer, M.SocksaddrFromNet(addr))
err = M.SocksaddrSerializer.WriteAddrPort(buffer, M.SocksaddrFromNet(addr))
if err != nil {
return
}

146
protocol/socks/client.go Normal file
View file

@ -0,0 +1,146 @@
package socks
import (
"context"
"net"
"net/url"
"os"
"strings"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/protocol/socks/socks4"
"github.com/sagernet/sing/protocol/socks/socks5"
)
type Version uint8
const (
Version4 Version = iota
Version4A
Version5
)
type Client struct {
version Version
dialer N.ContextDialer
serverAddr M.Socksaddr
username string
password string
}
func NewClient(dialer N.ContextDialer, serverAddr M.Socksaddr, version Version, username string, password string) *Client {
return &Client{
version: version,
dialer: dialer,
serverAddr: serverAddr,
username: username,
password: password,
}
}
func NewClientFromURL(dialer N.ContextDialer, rawURL string) (*Client, error) {
var client Client
if !strings.Contains(rawURL, "://") {
rawURL = "socks://" + rawURL
}
proxyURL, err := url.Parse(rawURL)
if err != nil {
return nil, err
}
client.dialer = dialer
client.serverAddr = M.ParseSocksaddr(proxyURL.Host)
switch proxyURL.Scheme {
case "socks4":
client.version = Version4
case "socks4a":
client.version = Version4A
case "socks", "socks5", "":
client.version = Version5
default:
return nil, E.New("socks: unknown scheme: ", proxyURL.Scheme)
}
if proxyURL.User != nil {
if client.version == Version5 {
client.username = proxyURL.User.Username()
client.password, _ = proxyURL.User.Password()
} else {
client.username = proxyURL.User.String()
}
}
return &client, nil
}
func (c *Client) DialContext(ctx context.Context, network string, address M.Socksaddr) (net.Conn, error) {
var command byte
if strings.HasPrefix(network, "tcp") {
command = socks4.CommandConnect
} else {
if c.version != Version5 {
return nil, E.New("socks4: udp unsupported")
}
command = socks5.CommandUDPAssociate
}
tcpConn, err := c.dialer.DialContext(ctx, "tcp", c.serverAddr)
if err != nil {
return nil, err
}
if c.version == Version4 && address.Family().IsFqdn() {
tcpAddr, err := net.ResolveTCPAddr(network, address.String())
if err != nil {
tcpConn.Close()
return nil, err
}
address = M.SocksaddrFromNet(tcpAddr)
}
switch c.version {
case Version4, Version4A:
_, err = ClientHandshake4(tcpConn, command, address, c.username)
if err != nil {
tcpConn.Close()
return nil, err
}
return tcpConn, nil
case Version5:
response, err := ClientHandshake5(tcpConn, command, address, c.username, c.password)
if err != nil {
tcpConn.Close()
return nil, err
}
if command == socks5.CommandConnect {
return tcpConn, nil
}
udpConn, err := c.dialer.DialContext(ctx, "udp", response.Bind)
if err != nil {
tcpConn.Close()
return nil, err
}
return NewAssociateConn(tcpConn, udpConn, address), nil
}
return nil, os.ErrInvalid
}
func (c *Client) BindContext(ctx context.Context, address M.Socksaddr) (net.Conn, error) {
tcpConn, err := c.dialer.DialContext(ctx, "tcp", c.serverAddr)
if err != nil {
return nil, err
}
switch c.version {
case Version4, Version4A:
_, err = ClientHandshake4(tcpConn, socks4.CommandBind, address, c.username)
if err != nil {
tcpConn.Close()
return nil, err
}
return tcpConn, nil
case Version5:
_, err = ClientHandshake5(tcpConn, socks5.CommandBind, address, c.username, c.password)
if err != nil {
tcpConn.Close()
return nil, err
}
return tcpConn, nil
}
return nil, os.ErrInvalid
}

232
protocol/socks/handshake.go Normal file
View file

@ -0,0 +1,232 @@
package socks
import (
"context"
"io"
"net"
"net/netip"
"os"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/auth"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/common/rw"
"github.com/sagernet/sing/protocol/socks/socks4"
"github.com/sagernet/sing/protocol/socks/socks5"
)
func ClientHandshake4(conn io.ReadWriter, command byte, destination M.Socksaddr, username string) (socks4.Response, error) {
err := socks4.WriteRequest(conn, socks4.Request{
Command: command,
Destination: destination,
Username: username,
})
if err != nil {
return socks4.Response{}, err
}
response, err := socks4.ReadResponse(conn)
if err != nil {
return socks4.Response{}, err
}
if response.ReplyCode != socks4.ReplyCodeGranted {
err = E.New("socks4: request rejected, code= ", response.ReplyCode)
}
return response, err
}
func ClientHandshake5(conn io.ReadWriter, command byte, destination M.Socksaddr, username string, password string) (socks5.Response, error) {
var method byte
if common.IsBlank(username) {
method = socks5.AuthTypeNotRequired
} else {
method = socks5.AuthTypeUsernamePassword
}
err := socks5.WriteAuthRequest(conn, socks5.AuthRequest{
Methods: []byte{method},
})
if err != nil {
return socks5.Response{}, err
}
authResponse, err := socks5.ReadAuthResponse(conn)
if err != nil {
return socks5.Response{}, err
}
if authResponse.Method == socks5.AuthTypeUsernamePassword {
err = socks5.WriteUsernamePasswordAuthRequest(conn, socks5.UsernamePasswordAuthRequest{
Username: username,
Password: password,
})
if err != nil {
return socks5.Response{}, err
}
usernamePasswordResponse, err := socks5.ReadUsernamePasswordAuthResponse(conn)
if err != nil {
return socks5.Response{}, err
}
if usernamePasswordResponse.Status != socks5.UsernamePasswordStatusSuccess {
return socks5.Response{}, E.New("socks5: incorrect user name or password")
}
} else if authResponse.Method != socks5.AuthTypeNotRequired {
return socks5.Response{}, E.New("socks5: unsupported auth method: ", authResponse.Method)
}
err = socks5.WriteRequest(conn, socks5.Request{
Command: command,
Destination: destination,
})
if err != nil {
return socks5.Response{}, err
}
response, err := socks5.ReadResponse(conn)
if err != nil {
return socks5.Response{}, err
}
if response.ReplyCode != socks5.ReplyCodeSuccess {
err = E.New("socks5: request rejected, code=", response.ReplyCode)
}
return response, err
}
func HandleConnection(ctx context.Context, conn net.Conn, authenticator auth.Authenticator, handler Handler, metadata M.Metadata) error {
version, err := rw.ReadByte(conn)
if err != nil {
return err
}
return HandleConnection0(ctx, conn, version, authenticator, handler, metadata)
}
func HandleConnection0(ctx context.Context, conn net.Conn, version byte, authenticator auth.Authenticator, handler Handler, metadata M.Metadata) error {
switch version {
case socks4.Version:
request, err := socks4.ReadRequest0(conn)
if err != nil {
return err
}
switch request.Command {
case socks4.CommandConnect:
responseAddr := request.Destination
if !responseAddr.Family().IsIPv4() {
responseAddr = M.SocksaddrFromAddrPort(netip.IPv4Unspecified(), responseAddr.Port)
}
err = socks4.WriteResponse(conn, socks4.Response{
ReplyCode: socks4.ReplyCodeGranted,
Destination: responseAddr,
})
if err != nil {
return err
}
metadata.Protocol = "socks4"
metadata.Destination = request.Destination
ctx = &socks4.UserContext{
Context: ctx,
Username: request.Username,
}
return handler.NewConnection(ctx, conn, metadata)
default:
err = socks4.WriteResponse(conn, socks4.Response{
ReplyCode: socks4.ReplyCodeRejectedOrFailed,
Destination: request.Destination,
})
if err != nil {
return err
}
return E.New("socks4: unsupported command ", request.Command)
}
case socks5.Version:
authRequest, err := socks5.ReadAuthRequest0(conn)
if err != nil {
return err
}
var authMethod byte
if authenticator != nil && !common.Contains(authRequest.Methods, socks5.AuthTypeUsernamePassword) {
err = socks5.WriteAuthResponse(conn, socks5.AuthResponse{
Method: socks5.AuthTypeNoAcceptedMethods,
})
if err != nil {
return err
}
}
if authenticator != nil {
authMethod = socks5.AuthTypeUsernamePassword
} else {
authMethod = socks5.AuthTypeNotRequired
}
err = socks5.WriteAuthResponse(conn, socks5.AuthResponse{
Method: authMethod,
})
if err != nil {
return err
}
userCtx := &socks5.UserContext{
Context: ctx,
}
if authMethod == socks5.AuthTypeUsernamePassword {
usernamePasswordAuthRequest, err := socks5.ReadUsernamePasswordAuthRequest(conn)
if err != nil {
return err
}
userCtx.Username = usernamePasswordAuthRequest.Username
userCtx.Password = usernamePasswordAuthRequest.Password
response := socks5.UsernamePasswordAuthResponse{}
if authenticator.Verify(usernamePasswordAuthRequest.Username, usernamePasswordAuthRequest.Password) {
response.Status = socks5.UsernamePasswordStatusSuccess
} else {
response.Status = socks5.UsernamePasswordStatusFailure
}
err = socks5.WriteUsernamePasswordAuthResponse(conn, response)
if err != nil {
return err
}
}
ctx = userCtx
request, err := socks5.ReadRequest(conn)
if err != nil {
return err
}
switch request.Command {
case socks5.CommandConnect:
err = socks5.WriteResponse(conn, socks5.Response{
ReplyCode: socks5.ReplyCodeSuccess,
Bind: request.Destination,
})
if err != nil {
return err
}
metadata.Protocol = "socks5"
metadata.Destination = request.Destination
return handler.NewConnection(ctx, conn, metadata)
case socks5.CommandUDPAssociate:
udpConn, err := net.ListenUDP(M.NetworkFromNetAddr("udp", M.AddrFromNetAddr(conn.LocalAddr())), net.UDPAddrFromAddrPort(netip.AddrPortFrom(M.AddrFromNetAddr(conn.LocalAddr()), 0)))
if err != nil {
return err
}
defer udpConn.Close()
err = socks5.WriteResponse(conn, socks5.Response{
ReplyCode: socks5.ReplyCodeSuccess,
Bind: M.SocksaddrFromNet(udpConn.LocalAddr()),
})
if err != nil {
return err
}
metadata.Protocol = "socks5"
metadata.Destination = request.Destination
go func() {
err = handler.NewPacketConnection(ctx, NewAssociatePacketConn(conn, udpConn, request.Destination), metadata)
if err != nil {
handler.HandleError(err)
conn.Close()
}
}()
return common.Error(io.Copy(io.Discard, conn))
default:
err = socks5.WriteResponse(conn, socks5.Response{
ReplyCode: socks5.ReplyCodeUnsupported,
})
if err != nil {
return err
}
return E.New("socks5: unsupported command ", request.Command)
}
}
return os.ErrInvalid
}

View file

@ -1,4 +1,4 @@
package socks5
package socks
import (
"context"
@ -34,7 +34,7 @@ func NewListener(bind netip.AddrPort, authenticator auth.Authenticator, handler
}
func (l *Listener) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
return HandleConnection(ctx, conn, l.authenticator, M.AddrFromNetAddr(conn.LocalAddr()), l.handler, metadata)
return HandleConnection(ctx, conn, l.authenticator, l.handler, metadata)
}
func (l *Listener) Start() error {

View file

@ -1,4 +1,4 @@
package socks5
package socks
import (
"net"
@ -9,6 +9,12 @@ import (
M "github.com/sagernet/sing/common/metadata"
)
//+----+------+------+----------+----------+----------+
//|RSV | FRAG | ATYP | DST.ADDR | DST.PORT | DATA |
//+----+------+------+----------+----------+----------+
//| 2 | 1 | 1 | Variable | 2 | Variable |
//+----+------+------+----------+----------+----------+
type AssociateConn struct {
net.Conn
conn net.Conn
@ -40,7 +46,7 @@ func (c *AssociateConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
return
}
reader := buf.As(p[3:n])
destination, err := AddressSerializer.ReadAddrPort(reader)
destination, err := M.SocksaddrSerializer.ReadAddrPort(reader)
if err != nil {
return
}
@ -54,7 +60,7 @@ func (c *AssociateConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
defer runtime.KeepAlive(_buffer)
buffer := common.Dup(_buffer)
common.Must(buffer.WriteZeroN(3))
err = AddressSerializer.WriteAddrPort(buffer, M.SocksaddrFromNet(addr))
err = M.SocksaddrSerializer.WriteAddrPort(buffer, M.SocksaddrFromNet(addr))
if err != nil {
return
}
@ -77,7 +83,7 @@ func (c *AssociateConn) Write(b []byte) (n int, err error) {
defer runtime.KeepAlive(_buffer)
buffer := common.Dup(_buffer)
common.Must(buffer.WriteZeroN(3))
err = AddressSerializer.WriteAddrPort(buffer, c.dest)
err = M.SocksaddrSerializer.WriteAddrPort(buffer, c.dest)
if err != nil {
return
}
@ -96,14 +102,14 @@ func (c *AssociateConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
}
buffer.Truncate(int(n))
buffer.Advance(3)
return AddressSerializer.ReadAddrPort(buffer)
return M.SocksaddrSerializer.ReadAddrPort(buffer)
}
func (c *AssociateConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
defer buffer.Release()
header := buf.With(buffer.ExtendHeader(3 + AddressSerializer.AddrPortLen(destination)))
header := buf.With(buffer.ExtendHeader(3 + M.SocksaddrSerializer.AddrPortLen(destination)))
common.Must(header.WriteZeroN(3))
common.Must(AddressSerializer.WriteAddrPort(header, destination))
common.Must(M.SocksaddrSerializer.WriteAddrPort(header, destination))
return common.Error(c.Conn.Write(buffer.Bytes()))
}
@ -132,7 +138,7 @@ func (c *AssociatePacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err erro
return
}
reader := buf.As(p[3:n])
destination, err := AddressSerializer.ReadAddrPort(reader)
destination, err := M.SocksaddrSerializer.ReadAddrPort(reader)
if err != nil {
return
}
@ -147,7 +153,7 @@ func (c *AssociatePacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error
buffer := common.Dup(_buffer)
common.Must(buffer.WriteZeroN(3))
err = AddressSerializer.WriteAddrPort(buffer, M.SocksaddrFromNet(addr))
err = M.SocksaddrSerializer.WriteAddrPort(buffer, M.SocksaddrFromNet(addr))
if err != nil {
return
}
@ -170,7 +176,7 @@ func (c *AssociatePacketConn) Write(b []byte) (n int, err error) {
buffer := common.Dup(_buffer)
common.Must(buffer.WriteZeroN(3))
err = AddressSerializer.WriteAddrPort(buffer, c.dest)
err = M.SocksaddrSerializer.WriteAddrPort(buffer, c.dest)
if err != nil {
return
}
@ -190,14 +196,14 @@ func (c *AssociatePacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error
c.addr = addr
buffer.Truncate(n)
buffer.Advance(3)
dest, err := AddressSerializer.ReadAddrPort(buffer)
dest, err := M.SocksaddrSerializer.ReadAddrPort(buffer)
return dest, err
}
func (c *AssociatePacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
defer buffer.Release()
header := buf.With(buffer.ExtendHeader(3 + AddressSerializer.AddrPortLen(destination)))
header := buf.With(buffer.ExtendHeader(3 + M.SocksaddrSerializer.AddrPortLen(destination)))
common.Must(header.WriteZeroN(3))
common.Must(AddressSerializer.WriteAddrPort(header, destination))
common.Must(M.SocksaddrSerializer.WriteAddrPort(header, destination))
return common.Error(c.PacketConn.WriteTo(buffer.Bytes(), c.addr))
}

View file

@ -0,0 +1,8 @@
package socks4
import "context"
type UserContext struct {
context.Context
Username string
}

View file

@ -0,0 +1,170 @@
package socks4
import (
"bytes"
"encoding/binary"
"io"
"net/netip"
"os"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/common/rw"
)
const (
Version byte = 4
CommandConnect byte = 1
CommandBind byte = 2
ReplyCodeGranted byte = 90
ReplyCodeRejectedOrFailed byte = 91
ReplyCodeCannotConnectToIdentd byte = 92
ReplyCodeIdentdReportDifferentUserID byte = 93
)
type Request struct {
Command byte
Destination M.Socksaddr
Username string
}
func ReadRequest(reader io.Reader) (request Request, err error) {
version, err := rw.ReadByte(reader)
if err != nil {
return
}
if version != 4 {
err = E.New("excepted socks version 4, got ", version)
return
}
return ReadRequest0(reader)
}
func ReadRequest0(reader io.Reader) (request Request, err error) {
request.Command, err = rw.ReadByte(reader)
if err != nil {
return
}
err = binary.Read(reader, binary.BigEndian, &request.Destination.Port)
if err != nil {
return
}
var dstIP [4]byte
_, err = io.ReadFull(reader, dstIP[:])
if err != nil {
return
}
var readHostName bool
if dstIP[0] == 0 && dstIP[1] == 0 && dstIP[2] == 0 {
readHostName = true
} else {
request.Destination.Addr = netip.AddrFrom4(dstIP)
}
request.Username, err = readString(reader)
if readHostName {
request.Destination.Fqdn, err = readString(reader)
}
return
}
func WriteRequest(writer io.Writer, request Request) error {
if request.Command != CommandConnect && request.Command != CommandBind {
return os.ErrInvalid
}
_, err := writer.Write([]byte{Version, request.Command})
if err != nil {
return err
}
err = binary.Write(writer, binary.BigEndian, request.Destination.Port)
if err != nil {
return err
}
if request.Destination.Family().IsIPv4() {
dstIP := request.Destination.Addr.As4()
_, err = writer.Write(dstIP[:])
if err != nil {
return err
}
} else {
err = rw.WriteZeroN(writer, 4)
if err != nil {
return err
}
_, err = writer.Write([]byte(request.Destination.AddrString()))
if err != nil {
return err
}
err = rw.WriteZero(writer)
if err != nil {
return err
}
}
if request.Username != "" {
_, err = writer.Write([]byte(request.Username))
if err != nil {
return err
}
}
return rw.WriteZero(writer)
}
type Response struct {
ReplyCode byte
Destination M.Socksaddr
}
func ReadResponse(reader io.Reader) (response Response, err error) {
version, err := rw.ReadByte(reader)
if err != nil {
return
}
if version != 4 {
err = E.New("excepted socks version 4, got ", version)
return
}
response.ReplyCode, err = rw.ReadByte(reader)
if err != nil {
return
}
err = binary.Read(reader, binary.BigEndian, &response.Destination.Port)
if err != nil {
return
}
var dstIP [4]byte
_, err = io.ReadFull(reader, dstIP[:])
if err != nil {
return
}
response.Destination.Addr = netip.AddrFrom4(dstIP)
return
}
func WriteResponse(writer io.Writer, response Response) error {
_, err := writer.Write([]byte{Version, response.ReplyCode})
if err != nil {
return err
}
err = binary.Write(writer, binary.BigEndian, response.Destination.Port)
if err != nil {
return err
}
dstIP := response.Destination.Addr.As4()
return rw.WriteBytes(writer, dstIP[:])
}
func readString(reader io.Reader) (string, error) {
buffer := bytes.Buffer{}
for {
b, err := rw.ReadByte(reader)
if err != nil {
return "", err
}
if b == 0 {
break
}
buffer.WriteByte(b)
}
return buffer.String(), nil
}

View file

@ -0,0 +1,9 @@
package socks5
import "context"
type UserContext struct {
context.Context
Username string
Password string
}

View file

@ -0,0 +1,281 @@
package socks5
import (
"io"
"net/netip"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/common/rw"
)
const (
Version byte = 5
AuthTypeNotRequired byte = 0x00
AuthTypeGSSAPI byte = 0x01
AuthTypeUsernamePassword byte = 0x02
AuthTypeNoAcceptedMethods byte = 0xFF
UsernamePasswordStatusSuccess byte = 0x00
UsernamePasswordStatusFailure byte = 0x01
CommandConnect byte = 0x01
CommandBind byte = 0x02
CommandUDPAssociate byte = 0x03
ReplyCodeSuccess byte = 0
ReplyCodeFailure byte = 1
ReplyCodeNotAllowed byte = 2
ReplyCodeNetworkUnreachable byte = 3
ReplyCodeHostUnreachable byte = 4
ReplyCodeConnectionRefused byte = 5
ReplyCodeTTLExpired byte = 6
ReplyCodeUnsupported byte = 7
ReplyCodeAddressTypeUnsupported byte = 8
)
//+----+----------+----------+
//|VER | NMETHODS | METHODS |
//+----+----------+----------+
//| 1 | 1 | 1 to 255 |
//+----+----------+----------+
type AuthRequest struct {
Methods []byte
}
func WriteAuthRequest(writer io.Writer, request AuthRequest) error {
err := rw.WriteByte(writer, Version)
if err != nil {
return err
}
err = rw.WriteByte(writer, byte(len(request.Methods)))
if err != nil {
return err
}
return rw.WriteBytes(writer, request.Methods)
}
func ReadAuthRequest(reader io.Reader) (request AuthRequest, err error) {
version, err := rw.ReadByte(reader)
if err != nil {
return
}
if version != Version {
err = E.New("expected socks version 5, got ", version)
return
}
return ReadAuthRequest0(reader)
}
func ReadAuthRequest0(reader io.Reader) (request AuthRequest, err error) {
methodLen, err := rw.ReadByte(reader)
if err != nil {
return
}
request.Methods, err = rw.ReadBytes(reader, int(methodLen))
return
}
//+----+--------+
//|VER | METHOD |
//+----+--------+
//| 1 | 1 |
//+----+--------+
type AuthResponse struct {
Method byte
}
func WriteAuthResponse(writer io.Writer, response AuthResponse) error {
return rw.WriteBytes(writer, []byte{Version, response.Method})
}
func ReadAuthResponse(reader io.Reader) (response AuthResponse, err error) {
version, err := rw.ReadByte(reader)
if err != nil {
return
}
if version != Version {
err = E.New("expected socks version 5, got ", version)
return
}
response.Method, err = rw.ReadByte(reader)
return
}
//+----+------+----------+------+----------+
//|VER | ULEN | UNAME | PLEN | PASSWD |
//+----+------+----------+------+----------+
//| 1 | 1 | 1 to 255 | 1 | 1 to 255 |
//+----+------+----------+------+----------+
type UsernamePasswordAuthRequest struct {
Username string
Password string
}
func WriteUsernamePasswordAuthRequest(writer io.Writer, request UsernamePasswordAuthRequest) error {
err := rw.WriteByte(writer, 1)
if err != nil {
return err
}
err = M.WriteSocksString(writer, request.Username)
if err != nil {
return err
}
return M.WriteSocksString(writer, request.Password)
}
func ReadUsernamePasswordAuthRequest(reader io.Reader) (request UsernamePasswordAuthRequest, err error) {
version, err := rw.ReadByte(reader)
if err != nil {
return
}
if version != 1 {
err = E.New("excepted password request version 1, got ", version)
return
}
request.Username, err = M.ReadSockString(reader)
if err != nil {
return
}
request.Password, err = M.ReadSockString(reader)
if err != nil {
return
}
return
}
//+----+--------+
//|VER | STATUS |
//+----+--------+
//| 1 | 1 |
//+----+--------+
type UsernamePasswordAuthResponse struct {
Status byte
}
func WriteUsernamePasswordAuthResponse(writer io.Writer, response UsernamePasswordAuthResponse) error {
err := rw.WriteByte(writer, 1)
if err != nil {
return err
}
return rw.WriteByte(writer, response.Status)
}
func ReadUsernamePasswordAuthResponse(reader io.Reader) (response UsernamePasswordAuthResponse, err error) {
version, err := rw.ReadByte(reader)
if err != nil {
return
}
if version != 1 {
err = E.New("excepted password request version 1, got ", version)
return
}
response.Status, err = rw.ReadByte(reader)
return
}
//+----+-----+-------+------+----------+----------+
//|VER | CMD | RSV | ATYP | DST.ADDR | DST.PORT |
//+----+-----+-------+------+----------+----------+
//| 1 | 1 | X'00' | 1 | Variable | 2 |
//+----+-----+-------+------+----------+----------+
type Request struct {
Command byte
Destination M.Socksaddr
}
func WriteRequest(writer io.Writer, request Request) error {
err := rw.WriteByte(writer, Version)
if err != nil {
return err
}
err = rw.WriteByte(writer, request.Command)
if err != nil {
return err
}
err = rw.WriteZero(writer)
if err != nil {
return err
}
return M.SocksaddrSerializer.WriteAddrPort(writer, request.Destination)
}
func ReadRequest(reader io.Reader) (request Request, err error) {
version, err := rw.ReadByte(reader)
if err != nil {
return
}
if version != Version {
err = E.New("expected socks version 5, got ", version)
return
}
request.Command, err = rw.ReadByte(reader)
if err != nil {
return
}
err = rw.Skip(reader)
if err != nil {
return
}
request.Destination, err = M.SocksaddrSerializer.ReadAddrPort(reader)
return
}
//+----+-----+-------+------+----------+----------+
//|VER | REP | RSV | ATYP | BND.ADDR | BND.PORT |
//+----+-----+-------+------+----------+----------+
//| 1 | 1 | X'00' | 1 | Variable | 2 |
//+----+-----+-------+------+----------+----------+
type Response struct {
ReplyCode byte
Bind M.Socksaddr
}
func WriteResponse(writer io.Writer, response Response) error {
err := rw.WriteByte(writer, Version)
if err != nil {
return err
}
err = rw.WriteByte(writer, response.ReplyCode)
if err != nil {
return err
}
err = rw.WriteZero(writer)
if err != nil {
return err
}
if !response.Bind.IsValid() {
return M.SocksaddrSerializer.WriteAddrPort(writer, M.Socksaddr{
Addr: netip.IPv4Unspecified(),
})
}
return M.SocksaddrSerializer.WriteAddrPort(writer, response.Bind)
}
func ReadResponse(reader io.Reader) (response Response, err error) {
version, err := rw.ReadByte(reader)
if err != nil {
return
}
if version != Version {
err = E.New("expected socks version 5, got ", version)
return
}
response.ReplyCode, err = rw.ReadByte(reader)
if err != nil {
return
}
err = rw.Skip(reader)
if err != nil {
return
}
response.Bind, err = M.SocksaddrSerializer.ReadAddrPort(reader)
return
}

View file

@ -1,76 +0,0 @@
package socks5
import (
"strconv"
M "github.com/sagernet/sing/common/metadata"
)
const (
Version4 byte = 0x04
Version5 byte = 0x05
)
const (
AuthTypeNotRequired byte = 0x00
AuthTypeGSSAPI byte = 0x01
AuthTypeUsernamePassword byte = 0x02
AuthTypeNoAcceptedMethods byte = 0xFF
)
const (
UsernamePasswordVersion1 byte = 0x01
UsernamePasswordStatusSuccess byte = 0x00
UsernamePasswordStatusFailure byte = 0x01
)
const (
CommandConnect byte = 0x01
CommandBind byte = 0x02
CommandUDPAssociate byte = 0x03
)
type ReplyCode byte
const (
ReplyCodeSuccess ReplyCode = iota
ReplyCodeFailure
ReplyCodeNotAllowed
ReplyCodeNetworkUnreachable
ReplyCodeHostUnreachable
ReplyCodeConnectionRefused
ReplyCodeTTLExpired
ReplyCodeUnsupported
ReplyCodeAddressTypeUnsupported
)
func (code ReplyCode) String() string {
switch code {
case ReplyCodeSuccess:
return "succeeded"
case ReplyCodeFailure:
return "general SOCKS server failure"
case ReplyCodeNotAllowed:
return "connection not allowed by ruleset"
case ReplyCodeNetworkUnreachable:
return "network unreachable"
case ReplyCodeHostUnreachable:
return "host unreachable"
case ReplyCodeConnectionRefused:
return "connection refused"
case ReplyCodeTTLExpired:
return "TTL expired"
case ReplyCodeUnsupported:
return "command not supported"
case ReplyCodeAddressTypeUnsupported:
return "address type not supported"
default:
return "unassigned code: " + strconv.Itoa(int(code))
}
}
var AddressSerializer = M.NewSerializer(
M.AddressFamilyByte(0x01, M.AddressFamilyIPv4),
M.AddressFamilyByte(0x04, M.AddressFamilyIPv6),
M.AddressFamilyByte(0x03, M.AddressFamilyFqdn),
)

View file

@ -1,25 +0,0 @@
package socks5
import "fmt"
type UnsupportedVersionException struct {
Version byte
}
func (e UnsupportedVersionException) Error() string {
return fmt.Sprint("unsupported version: ", e.Version)
}
type UnsupportedCommandException struct {
Command byte
}
func (e UnsupportedCommandException) Error() string {
return fmt.Sprint("unsupported command: ", e.Command)
}
type UsernamePasswordAuthFailureException struct{}
func (e UsernamePasswordAuthFailureException) Error() string {
return "username/password auth failed"
}

View file

@ -1,246 +0,0 @@
package socks5
import (
"context"
"io"
"net"
"net/netip"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/auth"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
)
func ClientHandshake(conn io.ReadWriter, version byte, command byte, destination M.Socksaddr, username string, password string) (*Response, error) {
var method byte
if common.IsBlank(username) {
method = AuthTypeNotRequired
} else {
method = AuthTypeUsernamePassword
}
err := WriteAuthRequest(conn, &AuthRequest{
Version: version,
Methods: []byte{method},
})
if err != nil {
return nil, err
}
authResponse, err := ReadAuthResponse(conn)
if err != nil {
return nil, err
}
if authResponse.Method != method {
return nil, E.New("not requested method, request ", method, ", return ", method)
}
if method == AuthTypeUsernamePassword {
err = WriteUsernamePasswordAuthRequest(conn, &UsernamePasswordAuthRequest{
Username: username,
Password: password,
})
if err != nil {
return nil, err
}
usernamePasswordResponse, err := ReadUsernamePasswordAuthResponse(conn)
if err != nil {
return nil, err
}
if usernamePasswordResponse.Status == UsernamePasswordStatusFailure {
return nil, &UsernamePasswordAuthFailureException{}
}
}
err = WriteRequest(conn, &Request{
Version: version,
Command: command,
Destination: destination,
})
if err != nil {
return nil, err
}
return ReadResponse(conn)
}
func ClientFastHandshake(writer io.Writer, version byte, command byte, destination M.Socksaddr, username string, password string) error {
var method byte
if common.IsBlank(username) {
method = AuthTypeNotRequired
} else {
method = AuthTypeUsernamePassword
}
err := WriteAuthRequest(writer, &AuthRequest{
Version: version,
Methods: []byte{method},
})
if err != nil {
return err
}
if method == AuthTypeUsernamePassword {
err = WriteUsernamePasswordAuthRequest(writer, &UsernamePasswordAuthRequest{
Username: username,
Password: password,
})
if err != nil {
return err
}
}
return WriteRequest(writer, &Request{
Version: version,
Command: command,
Destination: destination,
})
}
func ClientFastHandshakeFinish(reader io.Reader) (*Response, error) {
response, err := ReadAuthResponse(reader)
if err != nil {
return nil, err
}
if response.Method == AuthTypeUsernamePassword {
usernamePasswordResponse, err := ReadUsernamePasswordAuthResponse(reader)
if err != nil {
return nil, err
}
if usernamePasswordResponse.Status == UsernamePasswordStatusFailure {
return nil, &UsernamePasswordAuthFailureException{}
}
}
return ReadResponse(reader)
}
func HandleConnection(ctx context.Context, conn net.Conn, authenticator auth.Authenticator, bind netip.Addr, handler Handler, metadata M.Metadata) error {
authRequest, err := ReadAuthRequest(conn)
if err != nil {
return E.Cause(err, "read socks auth request")
}
return handleConnection(authRequest, ctx, conn, authenticator, bind, handler, metadata)
}
func HandleConnection0(ctx context.Context, conn net.Conn, authenticator auth.Authenticator, bind netip.Addr, handler Handler, metadata M.Metadata) error {
authRequest, err := ReadAuthRequest0(conn)
if err != nil {
return E.Cause(err, "read socks auth request")
}
return handleConnection(authRequest, ctx, conn, authenticator, bind, handler, metadata)
}
func handleConnection(authRequest *AuthRequest, ctx context.Context, conn net.Conn, authenticator auth.Authenticator, bind netip.Addr, handler Handler, metadata M.Metadata) error {
request, err := serverHandshake(authRequest, conn, authenticator)
if err != nil {
return E.Cause(err, "read socks request")
}
switch request.Command {
case CommandConnect:
err = WriteResponse(conn, &Response{
Version: request.Version,
ReplyCode: ReplyCodeSuccess,
Bind: M.SocksaddrFromNet(conn.LocalAddr()),
})
if err != nil {
return E.Cause(err, "write socks response")
}
metadata.Protocol = "socks5"
metadata.Destination = request.Destination
return handler.NewConnection(ctx, conn, metadata)
case CommandUDPAssociate:
network := "udp"
if bind.Is4() {
network = "udp4"
}
udpConn, err := net.ListenUDP(network, net.UDPAddrFromAddrPort(netip.AddrPortFrom(bind, 0)))
if err != nil {
return err
}
defer udpConn.Close()
err = WriteResponse(conn, &Response{
Version: request.Version,
ReplyCode: ReplyCodeSuccess,
Bind: M.SocksaddrFromNet(udpConn.LocalAddr()),
})
if err != nil {
return E.Cause(err, "write socks response")
}
metadata.Protocol = "socks5"
metadata.Destination = request.Destination
go func() {
err := handler.NewPacketConnection(ctx, NewAssociatePacketConn(conn, udpConn, request.Destination), metadata)
if err != nil {
handler.HandleError(err)
}
conn.Close()
}()
return common.Error(io.Copy(io.Discard, conn))
default:
err = WriteResponse(conn, &Response{
Version: request.Version,
ReplyCode: ReplyCodeUnsupported,
})
if err != nil {
return E.Cause(err, "write response")
}
}
return nil
}
func ServerHandshake(conn net.Conn, authenticator auth.Authenticator) (*Request, error) {
authRequest, err := ReadAuthRequest(conn)
if err != nil {
return nil, E.Cause(err, "read socks auth request")
}
return serverHandshake(authRequest, conn, authenticator)
}
func ServerHandshake0(conn net.Conn, authenticator auth.Authenticator) (*Request, error) {
authRequest, err := ReadAuthRequest0(conn)
if err != nil {
return nil, E.Cause(err, "read socks auth request")
}
return serverHandshake(authRequest, conn, authenticator)
}
func serverHandshake(authRequest *AuthRequest, conn net.Conn, authenticator auth.Authenticator) (*Request, error) {
var authMethod byte
if authenticator == nil {
authMethod = AuthTypeNotRequired
} else {
authMethod = AuthTypeUsernamePassword
}
if !common.Contains(authRequest.Methods, authMethod) {
err := WriteAuthResponse(conn, &AuthResponse{
Version: authRequest.Version,
Method: AuthTypeNoAcceptedMethods,
})
if err != nil {
return nil, E.Cause(err, "write socks auth response")
}
}
err := WriteAuthResponse(conn, &AuthResponse{
Version: authRequest.Version,
Method: authMethod,
})
if err != nil {
return nil, E.Cause(err, "write socks auth response")
}
if authMethod == AuthTypeUsernamePassword {
usernamePasswordAuthRequest, err := ReadUsernamePasswordAuthRequest(conn)
if err != nil {
return nil, E.Cause(err, "read user auth request")
}
response := &UsernamePasswordAuthResponse{}
if authenticator.Verify(usernamePasswordAuthRequest.Username, usernamePasswordAuthRequest.Password) {
response.Status = UsernamePasswordStatusSuccess
} else {
response.Status = UsernamePasswordStatusFailure
}
err = WriteUsernamePasswordAuthResponse(conn, response)
if err != nil {
return nil, E.Cause(err, "write user auth response")
}
}
request, err := ReadRequest(conn)
if err != nil {
return nil, E.Cause(err, "read socks request")
}
return request, nil
}

View file

@ -1,367 +0,0 @@
package socks5
import (
"bytes"
"io"
"net/netip"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/common/rw"
)
//+----+----------+----------+
//|VER | NMETHODS | METHODS |
//+----+----------+----------+
//| 1 | 1 | 1 to 255 |
//+----+----------+----------+
type AuthRequest struct {
Version byte
Methods []byte
}
func WriteAuthRequest(writer io.Writer, request *AuthRequest) error {
err := rw.WriteByte(writer, request.Version)
if err != nil {
return err
}
err = rw.WriteByte(writer, byte(len(request.Methods)))
if err != nil {
return err
}
return rw.WriteBytes(writer, request.Methods)
}
func ReadAuthRequest(reader io.Reader) (*AuthRequest, error) {
version, err := rw.ReadByte(reader)
if err != nil {
return nil, err
}
if version != Version5 {
return nil, &UnsupportedVersionException{version}
}
methodLen, err := rw.ReadByte(reader)
if err != nil {
return nil, E.Cause(err, "read socks auth methods length")
}
methods, err := rw.ReadBytes(reader, int(methodLen))
if err != nil {
return nil, E.Cause(err, "read socks auth methods, length ", methodLen)
}
request := &AuthRequest{
version,
methods,
}
return request, nil
}
func ReadAuthRequest0(reader io.Reader) (*AuthRequest, error) {
methodLen, err := rw.ReadByte(reader)
if err != nil {
return nil, E.Cause(err, "read socks auth methods length")
}
methods, err := rw.ReadBytes(reader, int(methodLen))
if err != nil {
return nil, E.Cause(err, "read socks auth methods, length ", methodLen)
}
request := &AuthRequest{
Version5,
methods,
}
return request, nil
}
//+----+--------+
//|VER | METHOD |
//+----+--------+
//| 1 | 1 |
//+----+--------+
type AuthResponse struct {
Version byte
Method byte
}
func WriteAuthResponse(writer io.Writer, response *AuthResponse) error {
err := rw.WriteByte(writer, response.Version)
if err != nil {
return err
}
return rw.WriteByte(writer, response.Method)
}
func ReadAuthResponse(reader io.Reader) (*AuthResponse, error) {
version, err := rw.ReadByte(reader)
if err != nil {
return nil, err
}
if version != Version5 {
return nil, &UnsupportedVersionException{version}
}
method, err := rw.ReadByte(reader)
if err != nil {
return nil, err
}
response := &AuthResponse{
Version: version,
Method: method,
}
return response, nil
}
//+----+------+----------+------+----------+
//|VER | ULEN | UNAME | PLEN | PASSWD |
//+----+------+----------+------+----------+
//| 1 | 1 | 1 to 255 | 1 | 1 to 255 |
//+----+------+----------+------+----------+
type UsernamePasswordAuthRequest struct {
Username string
Password string
}
func WriteUsernamePasswordAuthRequest(writer io.Writer, request *UsernamePasswordAuthRequest) error {
err := rw.WriteByte(writer, UsernamePasswordVersion1)
if err != nil {
return err
}
err = M.WriteString(writer, "username", request.Username)
if err != nil {
return err
}
return M.WriteString(writer, "password", request.Password)
}
func ReadUsernamePasswordAuthRequest(reader io.Reader) (*UsernamePasswordAuthRequest, error) {
version, err := rw.ReadByte(reader)
if err != nil {
return nil, err
}
if version != UsernamePasswordVersion1 {
return nil, &UnsupportedVersionException{version}
}
username, err := M.ReadString(reader)
if err != nil {
return nil, err
}
password, err := M.ReadString(reader)
if err != nil {
return nil, err
}
request := &UsernamePasswordAuthRequest{
Username: username,
Password: password,
}
return request, nil
}
//+----+--------+
//|VER | STATUS |
//+----+--------+
//| 1 | 1 |
//+----+--------+
type UsernamePasswordAuthResponse struct {
Status byte
}
func WriteUsernamePasswordAuthResponse(writer io.Writer, response *UsernamePasswordAuthResponse) error {
err := rw.WriteByte(writer, UsernamePasswordVersion1)
if err != nil {
return err
}
return rw.WriteByte(writer, response.Status)
}
func ReadUsernamePasswordAuthResponse(reader io.Reader) (*UsernamePasswordAuthResponse, error) {
version, err := rw.ReadByte(reader)
if err != nil {
return nil, err
}
if version != UsernamePasswordVersion1 {
return nil, &UnsupportedVersionException{version}
}
status, err := rw.ReadByte(reader)
if status != UsernamePasswordStatusSuccess {
status = UsernamePasswordStatusFailure
}
response := &UsernamePasswordAuthResponse{
Status: status,
}
return response, nil
}
//+----+-----+-------+------+----------+----------+
//|VER | CMD | RSV | ATYP | DST.ADDR | DST.PORT |
//+----+-----+-------+------+----------+----------+
//| 1 | 1 | X'00' | 1 | Variable | 2 |
//+----+-----+-------+------+----------+----------+
type Request struct {
Version byte
Command byte
Destination M.Socksaddr
}
func WriteRequest(writer io.Writer, request *Request) error {
err := rw.WriteByte(writer, request.Version)
if err != nil {
return err
}
err = rw.WriteByte(writer, request.Command)
if err != nil {
return err
}
err = rw.WriteZero(writer)
if err != nil {
return err
}
return AddressSerializer.WriteAddrPort(writer, request.Destination)
}
func ReadRequest(reader io.Reader) (*Request, error) {
version, err := rw.ReadByte(reader)
if err != nil {
return nil, err
}
if !(version == Version4 || version == Version5) {
return nil, &UnsupportedVersionException{version}
}
command, err := rw.ReadByte(reader)
if err != nil {
return nil, err
}
if command != CommandConnect && command != CommandUDPAssociate {
return nil, &UnsupportedCommandException{command}
}
err = rw.Skip(reader)
if err != nil {
return nil, err
}
addrPort, err := AddressSerializer.ReadAddrPort(reader)
if err != nil {
return nil, err
}
request := &Request{
Version: version,
Command: command,
Destination: addrPort,
}
return request, nil
}
//+----+-----+-------+------+----------+----------+
//|VER | REP | RSV | ATYP | BND.ADDR | BND.PORT |
//+----+-----+-------+------+----------+----------+
//| 1 | 1 | X'00' | 1 | Variable | 2 |
//+----+-----+-------+------+----------+----------+
type Response struct {
Version byte
ReplyCode ReplyCode
Bind M.Socksaddr
}
func WriteResponse(writer io.Writer, response *Response) error {
err := rw.WriteByte(writer, response.Version)
if err != nil {
return err
}
err = rw.WriteByte(writer, byte(response.ReplyCode))
if err != nil {
return err
}
err = rw.WriteZero(writer)
if err != nil {
return err
}
if !response.Bind.IsValid() {
return AddressSerializer.WriteAddrPort(writer, M.Socksaddr{
Addr: netip.IPv4Unspecified(),
})
}
return AddressSerializer.WriteAddrPort(writer, response.Bind)
}
func ReadResponse(reader io.Reader) (*Response, error) {
version, err := rw.ReadByte(reader)
if err != nil {
return nil, err
}
if !(version == Version4 || version == Version5) {
return nil, &UnsupportedVersionException{version}
}
replyCode, err := rw.ReadByte(reader)
if err != nil {
return nil, err
}
err = rw.Skip(reader)
if err != nil {
return nil, err
}
addrPort, err := AddressSerializer.ReadAddrPort(reader)
if err != nil {
return nil, err
}
response := &Response{
Version: version,
ReplyCode: ReplyCode(replyCode),
Bind: addrPort,
}
return response, nil
}
//+----+------+------+----------+----------+----------+
//|RSV | FRAG | ATYP | DST.ADDR | DST.PORT | DATA |
//+----+------+------+----------+----------+----------+
//| 2 | 1 | 1 | Variable | 2 | Variable |
//+----+------+------+----------+----------+----------+
type AssociatePacket struct {
Fragment byte
Destination M.Socksaddr
Data []byte
}
func DecodeAssociatePacket(buffer *buf.Buffer) (*AssociatePacket, error) {
if buffer.Len() < 5 {
return nil, E.New("insufficient length")
}
fragment := buffer.Byte(2)
reader := bytes.NewReader(buffer.Bytes())
err := common.Error(reader.Seek(3, io.SeekStart))
if err != nil {
return nil, err
}
addrPort, err := AddressSerializer.ReadAddrPort(reader)
if err != nil {
return nil, err
}
buffer.Advance(reader.Len())
packet := &AssociatePacket{
Fragment: fragment,
Destination: addrPort,
Data: buffer.Bytes(),
}
return packet, nil
}
func EncodeAssociatePacket(packet *AssociatePacket, buffer *buf.Buffer) error {
err := rw.WriteZeroN(buffer, 2)
if err != nil {
return err
}
err = rw.WriteByte(buffer, packet.Fragment)
if err != nil {
return err
}
err = AddressSerializer.WriteAddrPort(buffer, packet.Destination)
if err != nil {
return err
}
_, err = buffer.Write(packet.Data)
return err
}

View file

@ -1,75 +0,0 @@
package socks5_test
import (
"net"
"sync"
"testing"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/protocol/socks5"
)
func TestHandshake(t *testing.T) {
server, client := net.Pipe()
defer server.Close()
defer client.Close()
wg := new(sync.WaitGroup)
wg.Add(1)
method := socks5.AuthTypeUsernamePassword
go func() {
response, err := socks5.ClientHandshake(client, socks5.Version5, socks5.CommandConnect, M.AddrPortFrom(M.AddrFromFqdn("test"), 80), "user", "pswd")
if err != nil {
t.Fatal(err)
}
if response.ReplyCode != socks5.ReplyCodeSuccess {
t.Fatal(response)
}
wg.Done()
}()
authRequest, err := socks5.ReadAuthRequest(server)
if err != nil {
t.Fatal(err)
}
if len(authRequest.Methods) != 1 || authRequest.Methods[0] != method {
t.Fatal("bad methods: ", authRequest.Methods)
}
err = socks5.WriteAuthResponse(server, &socks5.AuthResponse{
Version: socks5.Version5,
Method: method,
})
if err != nil {
t.Fatal(err)
}
usernamePasswordAuthRequest, err := socks5.ReadUsernamePasswordAuthRequest(server)
if err != nil {
t.Fatal(err)
}
if usernamePasswordAuthRequest.Username != "user" || usernamePasswordAuthRequest.Password != "pswd" {
t.Fatal(authRequest)
}
err = socks5.WriteUsernamePasswordAuthResponse(server, &socks5.UsernamePasswordAuthResponse{
Status: socks5.UsernamePasswordStatusSuccess,
})
if err != nil {
t.Fatal(err)
}
request, err := socks5.ReadRequest(server)
if err != nil {
t.Fatal(err)
}
if request.Version != socks5.Version5 || request.Command != socks5.CommandConnect || request.Destination.Addr.Fqdn() != "test" || request.Destination.Port != 80 {
t.Fatal(request)
}
err = socks5.WriteResponse(server, &socks5.Response{
Version: socks5.Version5,
ReplyCode: socks5.ReplyCodeSuccess,
Bind: M.AddrPortFrom(M.AddrFromIP(net.IPv4zero), 0),
})
if err != nil {
t.Fatal(err)
}
wg.Wait()
}

View file

@ -13,7 +13,6 @@ import (
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/common/rw"
"github.com/sagernet/sing/protocol/socks5"
)
const (
@ -127,7 +126,7 @@ func ClientHandshakeRaw(conn net.Conn, key [KeyLength]byte, command byte, destin
if err != nil {
return err
}
err = socks5.AddressSerializer.WriteAddrPort(conn, destination)
err = M.SocksaddrSerializer.WriteAddrPort(conn, destination)
if err != nil {
return err
}
@ -145,7 +144,7 @@ func ClientHandshakeRaw(conn net.Conn, key [KeyLength]byte, command byte, destin
}
func ClientHandshake(conn net.Conn, key [KeyLength]byte, destination M.Socksaddr, payload []byte) error {
headerLen := KeyLength + socks5.AddressSerializer.AddrPortLen(destination) + 5
headerLen := KeyLength + M.SocksaddrSerializer.AddrPortLen(destination) + 5
var header *buf.Buffer
var writeHeader bool
if len(payload) > 0 && headerLen+len(payload) < 65535 {
@ -161,7 +160,7 @@ func ClientHandshake(conn net.Conn, key [KeyLength]byte, destination M.Socksaddr
common.Must1(header.Write(key[:]))
common.Must1(header.Write(CRLF))
common.Must(header.WriteByte(CommandTCP))
common.Must(socks5.AddressSerializer.WriteAddrPort(header, destination))
common.Must(M.SocksaddrSerializer.WriteAddrPort(header, destination))
common.Must1(header.Write(CRLF))
common.Must1(header.Write(payload))
@ -180,7 +179,7 @@ func ClientHandshake(conn net.Conn, key [KeyLength]byte, destination M.Socksaddr
}
func ClientHandshakePacket(conn net.Conn, key [KeyLength]byte, destination M.Socksaddr, payload *buf.Buffer) error {
headerLen := KeyLength + 2*socks5.AddressSerializer.AddrPortLen(destination) + 9
headerLen := KeyLength + 2*M.SocksaddrSerializer.AddrPortLen(destination) + 9
payloadLen := payload.Len()
var header *buf.Buffer
var writeHeader bool
@ -195,9 +194,9 @@ func ClientHandshakePacket(conn net.Conn, key [KeyLength]byte, destination M.Soc
common.Must1(header.Write(key[:]))
common.Must1(header.Write(CRLF))
common.Must(header.WriteByte(CommandUDP))
common.Must(socks5.AddressSerializer.WriteAddrPort(header, destination))
common.Must(M.SocksaddrSerializer.WriteAddrPort(header, destination))
common.Must1(header.Write(CRLF))
common.Must(socks5.AddressSerializer.WriteAddrPort(header, destination))
common.Must(M.SocksaddrSerializer.WriteAddrPort(header, destination))
common.Must(binary.Write(header, binary.BigEndian, uint16(payloadLen)))
common.Must1(header.Write(CRLF))
@ -216,7 +215,7 @@ func ClientHandshakePacket(conn net.Conn, key [KeyLength]byte, destination M.Soc
}
func ReadPacket(conn net.Conn, buffer *buf.Buffer) (M.Socksaddr, error) {
destination, err := socks5.AddressSerializer.ReadAddrPort(conn)
destination, err := M.SocksaddrSerializer.ReadAddrPort(conn)
if err != nil {
return M.Socksaddr{}, E.Cause(err, "read destination")
}
@ -241,7 +240,7 @@ func ReadPacket(conn net.Conn, buffer *buf.Buffer) (M.Socksaddr, error) {
}
func WritePacket(conn net.Conn, buffer *buf.Buffer, destination M.Socksaddr) error {
headerOverload := socks5.AddressSerializer.AddrPortLen(destination) + 4
headerOverload := M.SocksaddrSerializer.AddrPortLen(destination) + 4
var header *buf.Buffer
var writeHeader bool
bufferLen := buffer.Len()
@ -253,7 +252,7 @@ func WritePacket(conn net.Conn, buffer *buf.Buffer, destination M.Socksaddr) err
defer runtime.KeepAlive(_buffer)
header = buf.With(common.Dup(_buffer))
}
common.Must(socks5.AddressSerializer.WriteAddrPort(header, destination))
common.Must(M.SocksaddrSerializer.WriteAddrPort(header, destination))
common.Must(binary.Write(header, binary.BigEndian, uint16(bufferLen)))
common.Must1(header.Write(CRLF))
if writeHeader {

View file

@ -12,7 +12,6 @@ import (
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/rw"
"github.com/sagernet/sing/protocol/socks5"
)
type Handler interface {
@ -116,7 +115,7 @@ process:
goto returnErr
}
destination, err := socks5.AddressSerializer.ReadAddrPort(conn)
destination, err := M.SocksaddrSerializer.ReadAddrPort(conn)
if err != nil {
err = E.Cause(err, "read destination")
goto returnErr

View file

@ -21,13 +21,15 @@ import (
"github.com/sagernet/sing/common/rw"
"github.com/sagernet/sing/common/udpnat"
"github.com/sagernet/sing/protocol/http"
"github.com/sagernet/sing/protocol/socks5"
"github.com/sagernet/sing/protocol/socks"
"github.com/sagernet/sing/protocol/socks/socks4"
"github.com/sagernet/sing/protocol/socks/socks5"
"github.com/sagernet/sing/transport/tcp"
"github.com/sagernet/sing/transport/udp"
)
type Handler interface {
socks5.Handler
socks.Handler
}
type Listener struct {
@ -60,10 +62,8 @@ func (l *Listener) NewConnection(ctx context.Context, conn net.Conn, metadata M.
}
headerType, err := rw.ReadByte(conn)
switch headerType {
case socks5.Version4:
return E.New("socks4 request dropped (TODO)")
case socks5.Version5:
return socks5.HandleConnection0(ctx, conn, l.authenticator, M.AddrFromNetAddr(conn.LocalAddr()), l.handler, metadata)
case socks4.Version, socks5.Version:
return socks.HandleConnection0(ctx, conn, headerType, l.authenticator, l.handler, metadata)
}
reader := bufio.NewReader(&rw.BufferedReader{