diff --git a/.github/workflows/debug.yml b/.github/workflows/debug.yml index 8b23dfe..ebe20f4 100644 --- a/.github/workflows/debug.yml +++ b/.github/workflows/debug.yml @@ -48,4 +48,4 @@ jobs: with: go-version: 1.18.1 - name: Build - run: go build -v ./cli/ss-local \ No newline at end of file + run: go build -v ./... \ No newline at end of file diff --git a/cli/cloudflare-ddns/main.go b/cli/cloudflare-ddns/main.go index 9f3f829..91de5e8 100644 --- a/cli/cloudflare-ddns/main.go +++ b/cli/cloudflare-ddns/main.go @@ -1,3 +1,5 @@ +//go:build linux + package main import ( diff --git a/cli/portal/portal-v2board/main.go b/cli/portal/portal-v2board/main.go index c13ca27..a9c80aa 100644 --- a/cli/portal/portal-v2board/main.go +++ b/cli/portal/portal-v2board/main.go @@ -18,9 +18,8 @@ import ( E "github.com/sagernet/sing/common/exceptions" _ "github.com/sagernet/sing/common/log" M "github.com/sagernet/sing/common/metadata" - "github.com/sagernet/sing/common/network" + N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/common/rw" - "github.com/sagernet/sing/protocol/socks" "github.com/sagernet/sing/protocol/trojan" transTLS "github.com/sagernet/sing/transport/tls" "github.com/sirupsen/logrus" @@ -177,14 +176,14 @@ func (i *TrojanInstance) NewConnection(ctx context.Context, conn net.Conn, metad userCtx := ctx.(*trojan.Context[int]) conn = i.user.TrackConnection(userCtx.User, conn) logrus.Info(i.id, ": user ", userCtx.User, " TCP ", metadata.Source, " ==> ", metadata.Destination) - destConn, err := network.SystemDialer.DialContext(context.Background(), "tcp", metadata.Destination) + destConn, err := N.SystemDialer.DialContext(context.Background(), "tcp", metadata.Destination) if err != nil { return err } return rw.CopyConn(ctx, conn, destConn) } -func (i *TrojanInstance) NewPacketConnection(ctx context.Context, conn socks.PacketConn, metadata M.Metadata) error { +func (i *TrojanInstance) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata M.Metadata) error { userCtx := ctx.(*trojan.Context[int]) conn = i.user.TrackPacketConnection(userCtx.User, conn) logrus.Info(i.id, ": user ", userCtx.User, " UDP ", metadata.Source, " ==> ", metadata.Destination) @@ -192,7 +191,7 @@ func (i *TrojanInstance) NewPacketConnection(ctx context.Context, conn socks.Pac if err != nil { return err } - return socks.CopyNetPacketConn(ctx, conn, udpConn) + return N.CopyNetPacketConn(ctx, conn, udpConn) } func (i *TrojanInstance) loopRequests() { @@ -205,7 +204,7 @@ func (i *TrojanInstance) loopRequests() { go func() { hErr := i.service.NewConnection(context.Background(), conn, M.Metadata{ Protocol: "tls", - Source: M.AddrPortFromNetAddr(conn.RemoteAddr()), + Source: M.SocksaddrFromNet(conn.RemoteAddr()), }) if hErr != nil { i.HandleError(hErr) diff --git a/cli/portal/portal-v2board/traffic.go b/cli/portal/portal-v2board/traffic.go index 2149c23..cfe708f 100644 --- a/cli/portal/portal-v2board/traffic.go +++ b/cli/portal/portal-v2board/traffic.go @@ -8,7 +8,7 @@ import ( "github.com/sagernet/sing/common/buf" M "github.com/sagernet/sing/common/metadata" - "github.com/sagernet/sing/protocol/socks" + N "github.com/sagernet/sing/common/network" ) type UserManager struct { @@ -40,7 +40,7 @@ func (m *UserManager) TrackConnection(userId int, conn net.Conn) net.Conn { return &TrackConn{conn, user} } -func (m *UserManager) TrackPacketConnection(userId int, conn socks.PacketConn) socks.PacketConn { +func (m *UserManager) TrackPacketConnection(userId int, conn N.PacketConn) N.PacketConn { m.access.Lock() defer m.access.Unlock() var user *User @@ -112,11 +112,11 @@ func (c *TrackConn) ReadFrom(r io.Reader) (n int64, err error) { } type TrackPacketConn struct { - socks.PacketConn + N.PacketConn *User } -func (c *TrackPacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) { +func (c *TrackPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) { destination, err := c.PacketConn.ReadPacket(buffer) if err == nil { atomic.AddUint64(&c.Upload, uint64(buffer.Len())) @@ -124,7 +124,7 @@ func (c *TrackPacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) { return destination, err } -func (c *TrackPacketConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error { +func (c *TrackPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { n := buffer.Len() err := c.PacketConn.WritePacket(buffer, destination) if err == nil { diff --git a/cli/socks-chk/main.go b/cli/socks-chk/main.go index ff54fa5..7ee6c53 100644 --- a/cli/socks-chk/main.go +++ b/cli/socks-chk/main.go @@ -11,7 +11,7 @@ import ( "github.com/sagernet/sing/common/buf" _ "github.com/sagernet/sing/common/log" M "github.com/sagernet/sing/common/metadata" - "github.com/sagernet/sing/protocol/socks" + "github.com/sagernet/sing/protocol/socks5" "github.com/sirupsen/logrus" "github.com/spf13/cobra" "golang.org/x/net/dns/dnsmessage" @@ -29,11 +29,8 @@ func main() { } func run(cmd *cobra.Command, args []string) { - server, err := M.ParseAddress(args[0]) - if err != nil { - logrus.Fatal("invalid server address ", args[0]) - } - err = testSocksTCP(server) + server := M.ParseSocksaddr(args[0]) + err := testSocksTCP(server) if err != nil { logrus.Fatal(err) } @@ -43,16 +40,16 @@ func run(cmd *cobra.Command, args []string) { } } -func testSocksTCP(server *M.AddrPort) error { +func testSocksTCP(server M.Socksaddr) error { tcpConn, err := net.Dial("tcp", server.String()) if err != nil { return err } - response, err := socks.ClientHandshake(tcpConn, socks.Version5, socks.CommandConnect, M.AddrPortFrom(M.ParseAddr("1.0.0.1"), 53), "", "") + response, err := socks5.ClientHandshake(tcpConn, socks5.Version5, socks5.CommandConnect, M.ParseSocksaddrHostPort("1.0.0.1", "53"), "", "") if err != nil { return err } - if response.ReplyCode != socks.ReplyCodeSuccess { + if response.ReplyCode != socks5.ReplyCodeSuccess { logrus.Fatal("socks tcp handshake failure: ", response.ReplyCode) } @@ -98,17 +95,17 @@ func testSocksTCP(server *M.AddrPort) error { return nil } -func testSocksUDP(server *M.AddrPort) error { +func testSocksUDP(server M.Socksaddr) error { tcpConn, err := net.Dial("tcp", server.String()) if err != nil { return err } - dest := M.AddrPortFrom(M.ParseAddr("1.0.0.1"), 53) - response, err := socks.ClientHandshake(tcpConn, socks.Version5, socks.CommandUDPAssociate, dest, "", "") + dest := M.ParseSocksaddrHostPort("1.0.0.1", "53") + response, err := socks5.ClientHandshake(tcpConn, socks5.Version5, socks5.CommandUDPAssociate, dest, "", "") if err != nil { return err } - if response.ReplyCode != socks.ReplyCodeSuccess { + if response.ReplyCode != socks5.ReplyCodeSuccess { logrus.Fatal("socks tcp handshake failure: ", response.ReplyCode) } var dialer net.Dialer @@ -116,7 +113,7 @@ func testSocksUDP(server *M.AddrPort) error { if err != nil { return err } - assConn := socks.NewAssociateConn(tcpConn, udpConn, dest) + assConn := socks5.NewAssociateConn(tcpConn, udpConn, dest) message := &dnsmessage.Message{} message.Header.ID = 1 message.Header.RecursionDesired = true diff --git a/cli/ss-local/main.go b/cli/ss-local/main.go index 0a72171..ed3d389 100644 --- a/cli/ss-local/main.go +++ b/cli/ss-local/main.go @@ -23,6 +23,7 @@ import ( "github.com/sagernet/sing/common/geosite" _ "github.com/sagernet/sing/common/log" M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/common/random" "github.com/sagernet/sing/common/redir" "github.com/sagernet/sing/common/rw" @@ -30,7 +31,6 @@ import ( "github.com/sagernet/sing/protocol/shadowsocks" "github.com/sagernet/sing/protocol/shadowsocks/shadowaead" "github.com/sagernet/sing/protocol/shadowsocks/shadowaead_2022" - "github.com/sagernet/sing/protocol/socks" "github.com/sagernet/sing/transport/mixed" "github.com/sagernet/sing/transport/system" "github.com/sirupsen/logrus" @@ -101,7 +101,7 @@ Only available with Linux kernel > 3.7.0.`) type client struct { *mixed.Listener *geosite.Matcher - server *M.AddrPort + server M.Socksaddr method shadowsocks.Method dialer net.Dialer bypass string @@ -163,7 +163,7 @@ func newClient(f *flags) (*client, error) { } c := &client{ - server: M.AddrPortFrom(M.ParseAddr(f.Server), f.ServerPort), + server: M.SocksaddrFromAddrPort(M.ParseAddr(f.Server), f.ServerPort), bypass: f.Bypass, } @@ -294,7 +294,7 @@ func newClient(f *flags) (*client, error) { return c, nil } -func bypass(conn net.Conn, destination *M.AddrPort) error { +func bypass(conn net.Conn, destination M.Socksaddr) error { logrus.Info("BYPASS ", conn.RemoteAddr(), " ==> ", destination) serverConn, err := net.Dial("tcp", destination.String()) if err != nil { @@ -313,12 +313,12 @@ func bypass(conn net.Conn, destination *M.AddrPort) error { func (c *client) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error { if c.bypass != "" { - if metadata.Destination.Addr.Family().IsFqdn() { - if c.Match(metadata.Destination.Addr.Fqdn()) { + if metadata.Destination.Family().IsFqdn() { + if c.Match(metadata.Destination.Fqdn) { return bypass(conn, metadata.Destination) } } else { - if geoip.Match(c.bypass, metadata.Destination.Addr.Addr().AsSlice()) { + if geoip.Match(c.bypass, metadata.Destination.Addr.AsSlice()) { return bypass(conn, metadata.Destination) } } @@ -354,14 +354,14 @@ func (c *client) NewConnection(ctx context.Context, conn net.Conn, metadata M.Me return rw.CopyConn(ctx, serverConn, conn) } -func (c *client) NewPacketConnection(ctx context.Context, conn socks.PacketConn, metadata M.Metadata) error { +func (c *client) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata M.Metadata) error { logrus.Info("outbound ", metadata.Protocol, " UDP ", metadata.Source, " ==> ", metadata.Destination) udpConn, err := c.dialer.DialContext(ctx, "udp", c.server.String()) if err != nil { return err } serverConn := c.method.DialPacketConn(udpConn) - return socks.CopyPacketConn(ctx, serverConn, conn) + return N.CopyPacketConn(ctx, serverConn, conn) } func run(cmd *cobra.Command, flags *flags) { diff --git a/cli/ss-server/main.go b/cli/ss-server/main.go index c092704..12475d7 100644 --- a/cli/ss-server/main.go +++ b/cli/ss-server/main.go @@ -17,13 +17,12 @@ import ( E "github.com/sagernet/sing/common/exceptions" _ "github.com/sagernet/sing/common/log" M "github.com/sagernet/sing/common/metadata" - "github.com/sagernet/sing/common/network" + N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/common/random" "github.com/sagernet/sing/common/rw" "github.com/sagernet/sing/protocol/shadowsocks" "github.com/sagernet/sing/protocol/shadowsocks/shadowaead" "github.com/sagernet/sing/protocol/shadowsocks/shadowaead_2022" - "github.com/sagernet/sing/protocol/socks" "github.com/sagernet/sing/transport/tcp" "github.com/sagernet/sing/transport/udp" "github.com/sirupsen/logrus" @@ -191,23 +190,23 @@ func (s *server) NewConnection(ctx context.Context, conn net.Conn, metadata M.Me return s.service.NewConnection(ctx, conn, metadata) } logrus.Info("inbound TCP ", conn.RemoteAddr(), " ==> ", metadata.Destination) - destConn, err := network.SystemDialer.DialContext(context.Background(), "tcp", metadata.Destination) + destConn, err := N.SystemDialer.DialContext(context.Background(), "tcp", metadata.Destination) if err != nil { return err } return rw.CopyConn(ctx, conn, destConn) } -func (s *server) NewPacketConnection(ctx context.Context, conn socks.PacketConn, metadata M.Metadata) error { +func (s *server) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata M.Metadata) error { logrus.Info("inbound UDP ", metadata.Source, " ==> ", metadata.Destination) udpConn, err := net.ListenUDP("udp", nil) if err != nil { return err } - return socks.CopyNetPacketConn(ctx, conn, udpConn) + return N.CopyNetPacketConn(ctx, conn, udpConn) } -func (s *server) NewPacket(conn socks.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error { +func (s *server) NewPacket(conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error { logrus.Trace("inbound raw UDP from ", metadata.Source) return s.service.NewPacket(conn, buffer, metadata) } diff --git a/cli/trojan-local/main.go b/cli/trojan-local/main.go index 50de22a..928f550 100644 --- a/cli/trojan-local/main.go +++ b/cli/trojan-local/main.go @@ -19,9 +19,9 @@ import ( E "github.com/sagernet/sing/common/exceptions" _ "github.com/sagernet/sing/common/log" M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/common/redir" "github.com/sagernet/sing/common/rw" - "github.com/sagernet/sing/protocol/socks" "github.com/sagernet/sing/protocol/trojan" "github.com/sagernet/sing/transport/mixed" "github.com/sirupsen/logrus" @@ -148,7 +148,7 @@ func newClient(f *flags) (*client, error) { } c := &client{ - server: M.AddrPortFrom(M.ParseAddr(f.Server), f.ServerPort).String(), + server: netip.AddrPortFrom(M.ParseAddr(f.Server), f.ServerPort).String(), key: trojan.Key(f.Password), sni: f.ServerName, insecure: f.Insecure, @@ -319,7 +319,7 @@ func (c *client) NewConnection(ctx context.Context, conn net.Conn, metadata M.Me return rw.CopyConn(ctx, clientConn, conn) } -func (c *client) NewPacketConnection(ctx context.Context, conn socks.PacketConn, metadata M.Metadata) error { +func (c *client) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata M.Metadata) error { logrus.Info("outbound ", metadata.Protocol, " UDP ", metadata.Source, " ==> ", metadata.Destination) tlsConn, err := c.connect(ctx) @@ -332,7 +332,7 @@ func (c *client) NewPacketConnection(ctx context.Context, conn socks.PacketConn, } return socks.CopyPacketConn(ctx, &trojan.PacketConn{Conn: tlsConn}, conn)*/ clientConn := trojan.NewClientPacketConn(tlsConn, c.key) - return socks.CopyPacketConn(ctx, clientConn, conn) + return N.CopyPacketConn(ctx, clientConn, conn) } func (c *client) HandleError(err error) { diff --git a/cli/trojan-server/main.go b/cli/trojan-server/main.go index 6cd3d15..07392f6 100644 --- a/cli/trojan-server/main.go +++ b/cli/trojan-server/main.go @@ -17,9 +17,8 @@ import ( E "github.com/sagernet/sing/common/exceptions" _ "github.com/sagernet/sing/common/log" M "github.com/sagernet/sing/common/metadata" - "github.com/sagernet/sing/common/network" + N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/common/rw" - "github.com/sagernet/sing/protocol/socks" "github.com/sagernet/sing/protocol/trojan" "github.com/sagernet/sing/transport/tcp" transTLS "github.com/sagernet/sing/transport/tls" @@ -193,7 +192,7 @@ func (s *server) NewConnection(ctx context.Context, conn net.Conn, metadata M.Me } return s.service.NewConnection(ctx, tls.Server(conn, &s.tlsConfig), metadata) } - destConn, err := network.SystemDialer.DialContext(context.Background(), "tcp", metadata.Destination) + destConn, err := N.SystemDialer.DialContext(context.Background(), "tcp", metadata.Destination) if err != nil { return err } @@ -201,13 +200,13 @@ func (s *server) NewConnection(ctx context.Context, conn net.Conn, metadata M.Me return rw.CopyConn(ctx, conn, destConn) } -func (s *server) NewPacketConnection(ctx context.Context, conn socks.PacketConn, metadata M.Metadata) error { +func (s *server) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata M.Metadata) error { logrus.Info("inbound UDP ", metadata.Source, " ==> ", metadata.Destination) udpConn, err := net.ListenUDP("udp", nil) if err != nil { return err } - return socks.CopyNetPacketConn(ctx, conn, udpConn) + return N.CopyNetPacketConn(ctx, conn, udpConn) } func (s *server) HandleError(err error) { diff --git a/cli/uot-local/main.go b/cli/uot-local/main.go index 87dcdc1..e4ba805 100644 --- a/cli/uot-local/main.go +++ b/cli/uot-local/main.go @@ -13,10 +13,11 @@ import ( E "github.com/sagernet/sing/common/exceptions" _ "github.com/sagernet/sing/common/log" M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/common/redir" "github.com/sagernet/sing/common/rw" "github.com/sagernet/sing/common/uot" - "github.com/sagernet/sing/protocol/socks" + "github.com/sagernet/sing/protocol/socks5" "github.com/sagernet/sing/transport/mixed" "github.com/sirupsen/logrus" "github.com/spf13/cobra" @@ -102,7 +103,7 @@ func (c *localClient) NewConnection(ctx context.Context, conn net.Conn, metadata return E.Cause(err, "connect to upstream") } - _, err = socks.ClientHandshake(upstream, socks.Version5, socks.CommandConnect, metadata.Destination, "", "") + _, err = socks5.ClientHandshake(upstream, socks5.Version5, socks5.CommandConnect, metadata.Destination, "", "") if err != nil { return E.Cause(err, "upstream handshake failed") } @@ -110,19 +111,19 @@ func (c *localClient) NewConnection(ctx context.Context, conn net.Conn, metadata return rw.CopyConn(context.Background(), upstream, conn) } -func (c *localClient) NewPacketConnection(ctx context.Context, conn socks.PacketConn, metadata M.Metadata) error { +func (c *localClient) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata M.Metadata) error { upstream, err := net.Dial("tcp", c.upstream) if err != nil { return E.Cause(err, "connect to upstream") } - _, err = socks.ClientHandshake(upstream, socks.Version5, socks.CommandConnect, M.AddrPortFrom(M.AddrFromFqdn(uot.UOTMagicAddress), 443), "", "") + _, err = socks5.ClientHandshake(upstream, socks5.Version5, socks5.CommandConnect, M.ParseSocksaddrHostPort(uot.UOTMagicAddress, "443"), "", "") if err != nil { return E.Cause(err, "upstream handshake failed") } client := uot.NewClientConn(upstream) - return socks.CopyPacketConn(ctx, client, conn) + return N.CopyPacketConn(ctx, client, conn) } func (c *localClient) OnError(err error) { diff --git a/common/metadata/addr.go b/common/metadata/addr.go index 9b930f6..c83001c 100644 --- a/common/metadata/addr.go +++ b/common/metadata/addr.go @@ -6,93 +6,144 @@ import ( "strconv" ) -type Addr interface { - Family() Family - Addr() netip.Addr - Fqdn() string - String() string -} - -type AddrPort struct { - Addr Addr +type Socksaddr struct { + Addr netip.Addr + Fqdn string Port uint16 } -func (ap AddrPort) IPAddr() *net.IPAddr { +func (ap Socksaddr) Network() string { + return "socks" +} + +func (ap Socksaddr) IsIP() bool { + return ap.Addr.IsValid() +} + +func (ap Socksaddr) IsFqdn() bool { + return !ap.IsIP() +} + +func (ap Socksaddr) IsValid() bool { + return ap.Addr.IsValid() || ap.Fqdn != "" +} + +func (ap Socksaddr) Family() Family { + if ap.Addr.IsValid() { + if ap.Addr.Is4() { + return AddressFamilyIPv4 + } else { + return AddressFamilyIPv6 + } + } + if ap.Fqdn != "" { + return AddressFamilyFqdn + } else if ap.Addr.Is4() || ap.Addr.Is4In6() { + return AddressFamilyIPv4 + } else { + return AddressFamilyIPv6 + } +} + +func (ap Socksaddr) AddrString() string { + if ap.Addr.IsValid() { + return ap.Addr.String() + } else { + return ap.Fqdn + } +} + +func (ap Socksaddr) IPAddr() *net.IPAddr { return &net.IPAddr{ - IP: ap.Addr.Addr().AsSlice(), + IP: ap.Addr.AsSlice(), } } -func (ap AddrPort) TCPAddr() *net.TCPAddr { +func (ap Socksaddr) TCPAddr() *net.TCPAddr { return &net.TCPAddr{ - IP: ap.Addr.Addr().AsSlice(), + IP: ap.Addr.AsSlice(), Port: int(ap.Port), } } -func (ap AddrPort) UDPAddr() *net.UDPAddr { +func (ap Socksaddr) UDPAddr() *net.UDPAddr { return &net.UDPAddr{ - IP: ap.Addr.Addr().AsSlice(), + IP: ap.Addr.AsSlice(), Port: int(ap.Port), } } -func (ap AddrPort) AddrPort() netip.AddrPort { - return netip.AddrPortFrom(ap.Addr.Addr(), ap.Port) +func (ap Socksaddr) AddrPort() netip.AddrPort { + return netip.AddrPortFrom(ap.Addr, ap.Port) } -func (ap AddrPort) String() string { - return net.JoinHostPort(ap.Addr.String(), strconv.Itoa(int(ap.Port))) +func (ap Socksaddr) String() string { + return net.JoinHostPort(ap.AddrString(), strconv.Itoa(int(ap.Port))) } -func ParseAddr(address string) Addr { - addr, err := netip.ParseAddr(address) - if err == nil { - return AddrFromAddr(addr) +func TCPAddr(ap netip.AddrPort) *net.TCPAddr { + return &net.TCPAddr{ + IP: ap.Addr().AsSlice(), + Port: int(ap.Port()), } - return AddrFromFqdn(address) } -func AddrPortFrom(addr Addr, port uint16) *AddrPort { - return &AddrPort{addr, port} -} - -func ParseAddress(address string) (*AddrPort, error) { - host, port, err := net.SplitHostPort(address) - if err != nil { - return nil, err +func UDPAddr(ap netip.AddrPort) *net.UDPAddr { + return &net.UDPAddr{ + IP: ap.Addr().AsSlice(), + Port: int(ap.Port()), } - portInt, err := strconv.Atoi(port) - if err != nil { - return nil, err - } - return AddrPortFrom(ParseAddr(host), uint16(portInt)), nil } -func ParseAddrPort(address string, port string) (*AddrPort, error) { - portInt, err := strconv.Atoi(port) - if err != nil { - return nil, err - } - return AddrPortFrom(ParseAddr(address), uint16(portInt)), nil +func AddrPortFrom(ip net.IP, port uint16) netip.AddrPort { + addr, _ := netip.AddrFromSlice(ip) + return netip.AddrPortFrom(addr, port) } -func AddrFromNetAddr(netAddr net.Addr) Addr { +func SocksaddrFrom(ip net.IP, port uint16) Socksaddr { + return SocksaddrFromNetIP(AddrPortFrom(ip, port)) +} + +func SocksaddrFromAddrPort(addr netip.Addr, port uint16) Socksaddr { + return SocksaddrFromNetIP(netip.AddrPortFrom(addr, port)) +} + +func SocksaddrFromNetIP(ap netip.AddrPort) Socksaddr { + return Socksaddr{ + Addr: ap.Addr(), + Port: ap.Port(), + } +} + +func SocksaddrFromNet(ap net.Addr) Socksaddr { + if socksAddr, ok := ap.(Socksaddr); ok { + return socksAddr + } + return SocksaddrFromNetIP(AddrPortFromNet(ap)) +} + +func AddrFromNetAddr(netAddr net.Addr) netip.Addr { + if addr := AddrPortFromNet(netAddr); addr.Addr().IsValid() { + return addr.Addr() + } switch addr := netAddr.(type) { + case Socksaddr: + return addr.Addr case *net.IPAddr: return AddrFromIP(addr.IP) case *net.IPNet: return AddrFromIP(addr.IP) default: - return nil + return netip.Addr{} } } -func AddrPortFromNetAddr(netAddr net.Addr) *AddrPort { +func AddrPortFromNet(netAddr net.Addr) netip.AddrPort { var ip net.IP var port uint16 switch addr := netAddr.(type) { + case Socksaddr: + return addr.AddrPort() case *net.TCPAddr: ip = addr.IP port = uint16(addr.Port) @@ -102,84 +153,39 @@ func AddrPortFromNetAddr(netAddr net.Addr) *AddrPort { case *net.IPAddr: ip = addr.IP } - return AddrPortFrom(AddrFromIP(ip), port) + return netip.AddrPortFrom(AddrFromIP(ip), port) } -func AddrFromIP(ip net.IP) Addr { +func AddrFromIP(ip net.IP) netip.Addr { addr, _ := netip.AddrFromSlice(ip) - if addr.Is4() || addr.Is4In6() { - return Addr4(addr.As4()) + return addr +} + +func ParseAddr(s string) netip.Addr { + addr, _ := netip.ParseAddr(s) + return addr +} + +func ParseSocksaddr(address string) Socksaddr { + host, port, err := net.SplitHostPort(address) + if err != nil { + return Socksaddr{} + } + return ParseSocksaddrHostPort(host, port) +} + +func ParseSocksaddrHostPort(host string, portStr string) Socksaddr { + port, _ := strconv.Atoi(portStr) + netAddr, err := netip.ParseAddr(host) + if err != nil { + return Socksaddr{ + Fqdn: host, + Port: uint16(port), + } } else { - return Addr16(addr.As16()) + return Socksaddr{ + Addr: netAddr, + Port: uint16(port), + } } } - -func AddrFromAddr(addr netip.Addr) Addr { - if addr.Is4() && addr.Is4In6() { - return Addr4(addr.As4()) - } else { - return Addr16(addr.As16()) - } -} - -func AddrPortFromAddrPort(addrPort netip.AddrPort) *AddrPort { - return AddrPortFrom(AddrFromAddr(addrPort.Addr()), addrPort.Port()) -} - -func AddrFromFqdn(fqdn string) Addr { - return AddrFqdn(fqdn) -} - -type Addr4 [4]byte - -func (a Addr4) Family() Family { - return AddressFamilyIPv4 -} - -func (a Addr4) Addr() netip.Addr { - return netip.AddrFrom4(a) -} - -func (a Addr4) Fqdn() string { - return "" -} - -func (a Addr4) String() string { - return netip.AddrFrom4(a).String() -} - -type Addr16 [16]byte - -func (a Addr16) Family() Family { - return AddressFamilyIPv6 -} - -func (a Addr16) Addr() netip.Addr { - return netip.AddrFrom16(a) -} - -func (a Addr16) Fqdn() string { - return "" -} - -func (a Addr16) String() string { - return netip.AddrFrom16(a).String() -} - -type AddrFqdn string - -func (f AddrFqdn) Family() Family { - return AddressFamilyFqdn -} - -func (f AddrFqdn) Addr() netip.Addr { - return netip.Addr{} -} - -func (f AddrFqdn) Fqdn() string { - return string(f) -} - -func (f AddrFqdn) String() string { - return string(f) -} diff --git a/common/metadata/metadata.go b/common/metadata/metadata.go index f91646a..10c8d45 100644 --- a/common/metadata/metadata.go +++ b/common/metadata/metadata.go @@ -7,8 +7,8 @@ import ( type Metadata struct { Protocol string - Source *AddrPort - Destination *AddrPort + Source Socksaddr + Destination Socksaddr } type TCPConnectionHandler interface { diff --git a/common/metadata/serializer.go b/common/metadata/serializer.go index d0c8e25..099b70b 100644 --- a/common/metadata/serializer.go +++ b/common/metadata/serializer.go @@ -3,6 +3,7 @@ package metadata import ( "encoding/binary" "io" + "net/netip" "github.com/sagernet/sing/common" E "github.com/sagernet/sing/common/exceptions" @@ -41,28 +42,27 @@ func NewSerializer(options ...SerializerOption) *Serializer { return s } -func (s *Serializer) WriteAddress(writer io.Writer, addr Addr) error { +func (s *Serializer) WriteAddress(writer io.Writer, addr Socksaddr) error { err := rw.WriteByte(writer, s.familyByteMap[addr.Family()]) if err != nil { return err } - if addr.Family().IsIP() { - err = rw.WriteBytes(writer, addr.Addr().AsSlice()) + if addr.Addr.IsValid() { + err = rw.WriteBytes(writer, addr.Addr.AsSlice()) } else { - domain := addr.Fqdn() - err = WriteString(writer, "fqdn", domain) + err = WriteString(writer, "fqdn", addr.Fqdn) } return err } -func (s *Serializer) AddressLen(addr Addr) int { +func (s *Serializer) AddressLen(addr Socksaddr) int { switch addr.Family() { case AddressFamilyIPv4: return 5 case AddressFamilyIPv6: return 17 default: - return 2 + len(addr.Fqdn()) + return 2 + len(addr.Fqdn) } } @@ -70,10 +70,10 @@ func (s *Serializer) WritePort(writer io.Writer, port uint16) error { return binary.Write(writer, binary.BigEndian, port) } -func (s *Serializer) WriteAddrPort(writer io.Writer, destination *AddrPort) error { +func (s *Serializer) WriteAddrPort(writer io.Writer, destination Socksaddr) error { var err error if !s.portFirst { - err = s.WriteAddress(writer, destination.Addr) + err = s.WriteAddress(writer, destination) } else { err = s.WritePort(writer, destination.Port) } @@ -81,48 +81,50 @@ func (s *Serializer) WriteAddrPort(writer io.Writer, destination *AddrPort) erro return err } if s.portFirst { - err = s.WriteAddress(writer, destination.Addr) + err = s.WriteAddress(writer, destination) } else { err = s.WritePort(writer, destination.Port) } return err } -func (s *Serializer) AddrPortLen(destination *AddrPort) int { - return s.AddressLen(destination.Addr) + 2 +func (s *Serializer) AddrPortLen(destination Socksaddr) int { + return s.AddressLen(destination) + 2 } -func (s *Serializer) ReadAddress(reader io.Reader) (Addr, error) { +func (s *Serializer) ReadAddress(reader io.Reader) (Socksaddr, error) { af, err := rw.ReadByte(reader) if err != nil { - return nil, err + return Socksaddr{}, err } family := s.familyMap[af] switch family { case AddressFamilyFqdn: fqdn, err := ReadString(reader) if err != nil { - return nil, E.Cause(err, "read fqdn") + return Socksaddr{}, E.Cause(err, "read fqdn") } - return AddrFqdn(fqdn), nil + return Socksaddr{ + Fqdn: fqdn, + }, nil default: switch family { case AddressFamilyIPv4: var addr [4]byte err = common.Error(reader.Read(addr[:])) if err != nil { - return nil, E.Cause(err, "read ipv4 address") + return Socksaddr{}, E.Cause(err, "read ipv4 address") } - return Addr4(addr), nil + return Socksaddr{Addr: netip.AddrFrom4(addr)}, nil case AddressFamilyIPv6: var addr [16]byte err = common.Error(reader.Read(addr[:])) if err != nil { - return nil, E.Cause(err, "read ipv6 address") + return Socksaddr{}, E.Cause(err, "read ipv6 address") } - return Addr16(addr), nil + return Socksaddr{Addr: netip.AddrFrom16(addr)}, nil default: - return nil, E.New("unknown address family: ", af) + return Socksaddr{}, E.New("unknown address family: ", af) } } } @@ -135,8 +137,8 @@ func (s *Serializer) ReadPort(reader io.Reader) (uint16, error) { return binary.BigEndian.Uint16(port), nil } -func (s *Serializer) ReadAddrPort(reader io.Reader) (destination *AddrPort, err error) { - var addr Addr +func (s *Serializer) ReadAddrPort(reader io.Reader) (destination Socksaddr, err error) { + var addr Socksaddr var port uint16 if !s.portFirst { addr, err = s.ReadAddress(reader) @@ -154,7 +156,8 @@ func (s *Serializer) ReadAddrPort(reader io.Reader) (destination *AddrPort, err if err != nil { return } - return AddrPortFrom(addr, port), nil + addr.Port = port + return addr, nil } func ReadString(reader io.Reader) (string, error) { diff --git a/common/network/addr.go b/common/network/addr.go index 10a20c7..73f523b 100644 --- a/common/network/addr.go +++ b/common/network/addr.go @@ -13,12 +13,8 @@ func LocalAddrs() ([]netip.Addr, error) { if err != nil { return nil, err } - return common.Map(common.Filter(common.Map(interfaceAddrs, func(addr net.Addr) M.Addr { + return common.Map(interfaceAddrs, func(addr net.Addr) netip.Addr { return M.AddrFromNetAddr(addr) - }), func(addr M.Addr) bool { - return addr != nil - }), func(it M.Addr) netip.Addr { - return it.Addr() }), nil } diff --git a/protocol/socks/packet_conn.go b/common/network/conn.go similarity index 67% rename from protocol/socks/packet_conn.go rename to common/network/conn.go index 125c6e0..01c6946 100644 --- a/protocol/socks/packet_conn.go +++ b/common/network/conn.go @@ -1,4 +1,4 @@ -package socks +package network import ( "context" @@ -14,11 +14,11 @@ import ( ) type PacketReader interface { - ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) + ReadPacket(buffer *buf.Buffer) (addr M.Socksaddr, err error) } type PacketWriter interface { - WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error + WritePacket(buffer *buf.Buffer, addr M.Socksaddr) error } type PacketConn interface { @@ -27,7 +27,6 @@ type PacketConn interface { Close() error LocalAddr() net.Addr - RemoteAddr() net.Addr SetDeadline(t time.Time) error SetReadDeadline(t time.Time) error SetWriteDeadline(t time.Time) error @@ -100,25 +99,49 @@ func CopyPacketConn(ctx context.Context, dest PacketConn, conn PacketConn) error } func CopyNetPacketConn(ctx context.Context, conn PacketConn, dest net.PacketConn) error { - return CopyPacketConn(ctx, conn, &PacketConnWrapper{dest}) + if udpConn, ok := dest.(*net.UDPConn); ok { + return CopyPacketConn(ctx, conn, &UDPConnWrapper{udpConn}) + } else { + return CopyPacketConn(ctx, conn, &PacketConnWrapper{dest}) + } +} + +type UDPConnWrapper struct { + *net.UDPConn +} + +func (w *UDPConnWrapper) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) { + n, addr, err := w.ReadFromUDPAddrPort(buffer.FreeBytes()) + if err != nil { + return M.Socksaddr{}, err + } + buffer.Truncate(n) + return M.SocksaddrFromNetIP(addr), nil +} + +func (w *UDPConnWrapper) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { + if destination.Family().IsFqdn() { + udpAddr, err := net.ResolveUDPAddr("udp", destination.String()) + if err != nil { + return err + } + return common.Error(w.UDPConn.WriteTo(buffer.Bytes(), udpAddr)) + } + return common.Error(w.UDPConn.WriteToUDPAddrPort(buffer.Bytes(), destination.AddrPort())) } type PacketConnWrapper struct { net.PacketConn } -func (p *PacketConnWrapper) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) { +func (p *PacketConnWrapper) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) { _, addr, err := buffer.ReadPacketFrom(p) if err != nil { - return nil, err + return M.Socksaddr{}, err } - return M.AddrPortFromNetAddr(addr), err + return M.SocksaddrFromNet(addr), err } -func (p *PacketConnWrapper) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error { +func (p *PacketConnWrapper) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { return common.Error(p.WriteTo(buffer.Bytes(), destination.UDPAddr())) } - -func (p *PacketConnWrapper) RemoteAddr() net.Addr { - return &common.DummyAddr{} -} diff --git a/common/network/dialer.go b/common/network/dialer.go index fe1b359..59e4388 100644 --- a/common/network/dialer.go +++ b/common/network/dialer.go @@ -8,7 +8,7 @@ import ( ) type ContextDialer interface { - DialContext(ctx context.Context, network string, address *M.AddrPort) (net.Conn, error) + DialContext(ctx context.Context, network string, address M.Socksaddr) (net.Conn, error) } var SystemDialer ContextDialer = &DefaultDialer{} @@ -21,7 +21,7 @@ func (d *DefaultDialer) ListenUDP(network string, laddr *net.UDPAddr) (*net.UDPC return net.ListenUDP(network, laddr) } -func (d *DefaultDialer) DialContext(ctx context.Context, network string, address *M.AddrPort) (net.Conn, error) { +func (d *DefaultDialer) DialContext(ctx context.Context, network string, address M.Socksaddr) (net.Conn, error) { return d.Dialer.DialContext(ctx, network, address.String()) } diff --git a/common/redir/redir_linux.go b/common/redir/redir_linux.go index 61ea559..abb1b1a 100644 --- a/common/redir/redir_linux.go +++ b/common/redir/redir_linux.go @@ -2,12 +2,13 @@ package redir import ( "net" + "net/netip" "syscall" M "github.com/sagernet/sing/common/metadata" ) -func GetOriginalDestination(conn net.Conn) (destination *M.AddrPort, err error) { +func GetOriginalDestination(conn net.Conn) (destination netip.AddrPort, err error) { rawConn, err := conn.(syscall.Conn).SyscallConn() if err != nil { return @@ -23,14 +24,14 @@ func GetOriginalDestination(conn net.Conn) (destination *M.AddrPort, err error) if conn.RemoteAddr().(*net.TCPAddr).IP.To4() != nil { raw, err := syscall.GetsockoptIPv6Mreq(int(rawFd), syscall.IPPROTO_IP, SO_ORIGINAL_DST) if err != nil { - return nil, err + return netip.AddrPort{}, err } - return M.AddrPortFrom(M.AddrFromIP(raw.Multiaddr[4:8]), uint16(raw.Multiaddr[2])<<8+uint16(raw.Multiaddr[3])), nil + return netip.AddrPortFrom(M.AddrFromIP(raw.Multiaddr[4:8]), uint16(raw.Multiaddr[2])<<8+uint16(raw.Multiaddr[3])), nil } else { raw, err := syscall.GetsockoptIPv6MTUInfo(int(rawFd), syscall.IPPROTO_IPV6, SO_ORIGINAL_DST) if err != nil { - return nil, err + return netip.AddrPort{}, err } - return M.AddrPortFrom(M.AddrFromIP(raw.Addr.Addr[:]), raw.Addr.Port), nil + return netip.AddrPortFrom(M.AddrFromIP(raw.Addr.Addr[:]), raw.Addr.Port), nil } } diff --git a/common/redir/redir_other.go b/common/redir/redir_other.go index 379bc98..a24a270 100644 --- a/common/redir/redir_other.go +++ b/common/redir/redir_other.go @@ -5,10 +5,9 @@ package redir import ( "errors" "net" - - M "github.com/sagernet/sing/common/metadata" + "net/netip" ) -func GetOriginalDestination(conn net.Conn) (destination *M.AddrPort, err error) { - return nil, errors.New("unsupported platform") +func GetOriginalDestination(conn net.Conn) (destination netip.AddrPort, err error) { + return netip.AddrPort{}, errors.New("unsupported platform") } diff --git a/common/redir/tproxy_linux.go b/common/redir/tproxy_linux.go index e41159c..3c310c3 100644 --- a/common/redir/tproxy_linux.go +++ b/common/redir/tproxy_linux.go @@ -4,6 +4,7 @@ import ( "encoding/binary" "fmt" "net" + "net/netip" "os" "strconv" "syscall" @@ -36,19 +37,19 @@ func FWMark(fd uintptr, mark int) error { return syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_MARK, mark) } -func GetOriginalDestinationFromOOB(oob []byte) (*M.AddrPort, error) { +func GetOriginalDestinationFromOOB(oob []byte) (netip.AddrPort, error) { controlMessages, err := unix.ParseSocketControlMessage(oob) if err != nil { - return nil, err + return netip.AddrPort{}, err } for _, message := range controlMessages { if message.Header.Level == unix.SOL_IP && message.Header.Type == unix.IP_RECVORIGDSTADDR { - return M.AddrPortFrom(M.AddrFromIP(message.Data[4:8]), binary.BigEndian.Uint16(message.Data[2:4])), nil + return netip.AddrPortFrom(M.AddrFromIP(message.Data[4:8]), binary.BigEndian.Uint16(message.Data[2:4])), nil } else if message.Header.Level == unix.SOL_IPV6 && message.Header.Type == unix.IPV6_RECVORIGDSTADDR { - return M.AddrPortFrom(M.AddrFromIP(message.Data[8:24]), binary.BigEndian.Uint16(message.Data[2:4])), nil + return netip.AddrPortFrom(M.AddrFromIP(message.Data[8:24]), binary.BigEndian.Uint16(message.Data[2:4])), nil } } - return nil, E.New("not found") + return netip.AddrPort{}, E.New("not found") } func DialUDP(network string, lAddr *net.UDPAddr, rAddr *net.UDPAddr) (*net.UDPConn, error) { diff --git a/common/redir/tproxy_other.go b/common/redir/tproxy_other.go index 9d9090f..26778e6 100644 --- a/common/redir/tproxy_other.go +++ b/common/redir/tproxy_other.go @@ -4,9 +4,9 @@ package redir import ( "net" + "net/netip" E "github.com/sagernet/sing/common/exceptions" - M "github.com/sagernet/sing/common/metadata" ) func TProxy(fd uintptr, isIPv6 bool) error { @@ -21,8 +21,8 @@ func FWMark(fd uintptr, mark int) error { return E.New("only available on linux") } -func GetOriginalDestinationFromOOB(oob []byte) (*M.AddrPort, error) { - return nil, E.New("only available on linux") +func GetOriginalDestinationFromOOB(oob []byte) (netip.AddrPort, error) { + return netip.AddrPort{}, E.New("only available on linux") } func DialUDP(network string, lAddr *net.UDPAddr, rAddr *net.UDPAddr) (*net.UDPConn, error) { diff --git a/common/session/context.go b/common/session/context.go deleted file mode 100644 index 2e85fab..0000000 --- a/common/session/context.go +++ /dev/null @@ -1,66 +0,0 @@ -package session - -import ( - "net" - "strconv" - - "github.com/sagernet/sing/common/buf" - M "github.com/sagernet/sing/common/metadata" -) - -type Network int - -const ( - NetworkTCP Network = iota - NetworkUDP -) - -type InstanceContext struct{} - -type Context struct { - InstanceContext - Network Network - Source M.Addr - Destination M.Addr - SourcePort uint16 - DestinationPort uint16 -} - -func (c Context) DestinationNetAddr() string { - return net.JoinHostPort(c.Destination.String(), strconv.Itoa(int(c.DestinationPort))) -} - -func AddressFromNetAddr(netAddr net.Addr) (addr M.Addr, port uint16) { - var ip net.IP - switch addr := netAddr.(type) { - case *net.TCPAddr: - ip = addr.IP - port = uint16(addr.Port) - case *net.UDPAddr: - ip = addr.IP - port = uint16(addr.Port) - } - return M.AddrFromIP(ip), port -} - -type Conn struct { - Conn net.Conn - Context *Context -} - -type PacketConn struct { - Conn net.PacketConn - Context *Context -} - -type Packet struct { - Context *Context - Data *buf.Buffer - WriteBack func(buffer *buf.Buffer, addr *net.UDPAddr) error - Release func() -} - -type Handler interface { - HandleConnection(conn *Conn) - HandlePacket(packet *Packet) -} diff --git a/common/session/pool.go b/common/session/pool.go deleted file mode 100644 index b522d1e..0000000 --- a/common/session/pool.go +++ /dev/null @@ -1,63 +0,0 @@ -package session - -import ( - "container/list" - "sync" - - "github.com/sagernet/sing/common" -) - -var ( - connectionPool list.List - connectionPoolEnabled bool - connectionAccess sync.Mutex -) - -func EnableConnectionPool() { - connectionPoolEnabled = true -} - -func DisableConnectionPool() { - connectionAccess.Lock() - defer connectionAccess.Unlock() - connectionPoolEnabled = false - clearConnections() -} - -func AddConnection(connection any) any { - if !connectionPoolEnabled { - return connection - } - connectionAccess.Lock() - defer connectionAccess.Unlock() - return connectionPool.PushBack(connection) -} - -func RemoveConnection(anyElement any) { - element, ok := anyElement.(*list.Element) - if !ok { - common.Close(anyElement) - return - } - if element.Value == nil { - return - } - common.Close(element.Value) - element.Value = nil - connectionAccess.Lock() - defer connectionAccess.Unlock() - connectionPool.Remove(element) -} - -func ResetConnections() { - connectionAccess.Lock() - defer connectionAccess.Unlock() - clearConnections() -} - -func clearConnections() { - for element := connectionPool.Front(); element != nil; element = element.Next() { - common.Close(element) - } - connectionPool.Init() -} diff --git a/common/tun/system/tun.go b/common/tun/system/tun.go new file mode 100644 index 0000000..04d5c13 --- /dev/null +++ b/common/tun/system/tun.go @@ -0,0 +1,420 @@ +package system + +import ( + "context" + "net" + "net/netip" + "os" + + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" + "github.com/sagernet/sing/common/cache" + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/log" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/common/rw" + "github.com/sagernet/sing/common/tun" + "github.com/sagernet/sing/common/udpnat" + "gvisor.dev/gvisor/pkg/tcpip" + tcpipBuffer "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/header/parse" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +var logger = log.NewLogger("tun ") + +type Stack struct { + tunFd uintptr + tunMtu int + inetAddress netip.Prefix + inet6Address netip.Prefix + + handler tun.Handler + + tunFile *os.File + tcpForwarder *net.TCPListener + tcpPort uint16 + tcpSessions *cache.LruCache[netip.AddrPort, netip.AddrPort] + udpNat *udpnat.Service[netip.AddrPort] +} + +func New(tunFd uintptr, tunMtu int, inetAddress netip.Prefix, inet6Address netip.Prefix, packetTimeout int64, handler tun.Handler) tun.Stack { + return &Stack{ + tunFd: tunFd, + tunMtu: tunMtu, + inetAddress: inetAddress, + inet6Address: inet6Address, + handler: handler, + tunFile: os.NewFile(tunFd, "tun"), + tcpSessions: cache.New( + cache.WithAge[netip.AddrPort, netip.AddrPort](packetTimeout), + cache.WithUpdateAgeOnGet[netip.AddrPort, netip.AddrPort](), + ), + udpNat: udpnat.New[netip.AddrPort](packetTimeout, handler), + } +} + +func (t *Stack) Start() error { + var network string + var address net.TCPAddr + if !t.inet6Address.IsValid() { + network = "tcp4" + address.IP = t.inetAddress.Addr().AsSlice() + } else { + network = "tcp" + address.IP = net.IPv6zero + } + + tcpListener, err := net.ListenTCP(network, &address) + if err != nil { + return err + } + + t.tcpForwarder = tcpListener + + go t.tcpLoop() + go t.tunLoop() + + return nil +} + +func (t *Stack) Close() error { + t.tcpForwarder.Close() + t.tunFile.Close() + return nil +} + +func (t *Stack) tunLoop() { + _buffer := buf.Make(t.tunMtu) + buffer := common.Dup(_buffer) + for { + n, err := t.tunFile.Read(buffer) + if err != nil { + t.handler.HandleError(err) + break + } + packet := buffer[:n] + t.deliverPacket(packet) + } +} + +func (t *Stack) deliverPacket(packet []byte) { + var err error + switch header.IPVersion(packet) { + case header.IPv4Version: + ipHdr := header.IPv4(packet) + switch ipHdr.TransportProtocol() { + case header.TCPProtocolNumber: + err = t.processIPv4TCP(ipHdr, ipHdr.Payload()) + case header.UDPProtocolNumber: + err = t.processIPv4UDP(ipHdr, ipHdr.Payload()) + default: + _, err = t.tunFile.Write(packet) + } + case header.IPv6Version: + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: tcpipBuffer.View(packet).ToVectorisedView(), + }) + proto, _, _, _, ok := parse.IPv6(pkt) + pkt.DecRef() + if !ok { + return + } + ipHdr := header.IPv6(packet) + switch proto { + case header.TCPProtocolNumber: + err = t.processIPv6TCP(ipHdr, ipHdr.Payload()) + case header.UDPProtocolNumber: + err = t.processIPv6UDP(ipHdr, ipHdr.Payload()) + default: + _, err = t.tunFile.Write(packet) + } + } + if err != nil { + t.handler.HandleError(err) + } +} + +func (t *Stack) processIPv4TCP(ipHdr header.IPv4, tcpHdr header.TCP) error { + sourceAddress := ipHdr.SourceAddress() + destinationAddress := ipHdr.DestinationAddress() + sourcePort := tcpHdr.SourcePort() + destinationPort := tcpHdr.DestinationPort() + + logger.Trace(sourceAddress, ":", sourcePort, " => ", destinationAddress, ":", destinationPort) + + if sourcePort != t.tcpPort { + key := M.AddrPortFrom(net.IP(destinationAddress), sourcePort) + t.tcpSessions.LoadOrStore(key, func() netip.AddrPort { + return M.AddrPortFrom(net.IP(sourceAddress), destinationPort) + }) + ipHdr.SetSourceAddress(destinationAddress) + ipHdr.SetDestinationAddress(tcpip.Address(t.inetAddress.Addr().AsSlice())) + tcpHdr.SetDestinationPort(t.tcpPort) + } else { + key := M.AddrPortFrom(net.IP(destinationAddress), destinationPort) + session, loaded := t.tcpSessions.Load(key) + if !loaded { + return E.New("unknown tcp session with source port ", destinationPort, " to destination address ", destinationAddress) + } + ipHdr.SetSourceAddress(destinationAddress) + tcpHdr.SetSourcePort(session.Port()) + ipHdr.SetDestinationAddress(tcpip.Address(session.Addr().AsSlice())) + } + + ipHdr.SetChecksum(0) + ipHdr.SetChecksum(^ipHdr.CalculateChecksum()) + tcpHdr.SetChecksum(0) + tcpHdr.SetChecksum(^tcpHdr.CalculateChecksum(header.ChecksumCombine( + header.PseudoHeaderChecksum(header.TCPProtocolNumber, ipHdr.SourceAddress(), ipHdr.DestinationAddress(), uint16(len(tcpHdr))), + header.Checksum(tcpHdr.Payload(), 0), + ))) + + _, err := t.tunFile.Write(ipHdr) + return err +} + +func (t *Stack) processIPv6TCP(ipHdr header.IPv6, tcpHdr header.TCP) error { + sourceAddress := ipHdr.SourceAddress() + destinationAddress := ipHdr.DestinationAddress() + sourcePort := tcpHdr.SourcePort() + destinationPort := tcpHdr.DestinationPort() + + if sourcePort != t.tcpPort { + key := M.AddrPortFrom(net.IP(destinationAddress), sourcePort) + t.tcpSessions.LoadOrStore(key, func() netip.AddrPort { + return M.AddrPortFrom(net.IP(sourceAddress), destinationPort) + }) + ipHdr.SetSourceAddress(destinationAddress) + ipHdr.SetDestinationAddress(tcpip.Address(t.inet6Address.Addr().AsSlice())) + tcpHdr.SetDestinationPort(t.tcpPort) + } else { + key := M.AddrPortFrom(net.IP(destinationAddress), destinationPort) + session, loaded := t.tcpSessions.Load(key) + if !loaded { + return E.New("unknown tcp session with source port ", destinationPort, " to destination address ", destinationAddress) + } + ipHdr.SetSourceAddress(destinationAddress) + tcpHdr.SetSourcePort(session.Port()) + ipHdr.SetDestinationAddress(tcpip.Address(session.Addr().AsSlice())) + } + + tcpHdr.SetChecksum(0) + tcpHdr.SetChecksum(^tcpHdr.CalculateChecksum(header.ChecksumCombine( + header.PseudoHeaderChecksum(header.TCPProtocolNumber, ipHdr.SourceAddress(), ipHdr.DestinationAddress(), uint16(len(tcpHdr))), + header.Checksum(tcpHdr.Payload(), 0), + ))) + + _, err := t.tunFile.Write(ipHdr) + return err +} + +func (t *Stack) tcpLoop() { + for { + logger.Trace("tcp start") + tcpConn, err := t.tcpForwarder.AcceptTCP() + logger.Trace("tcp accept") + if err != nil { + t.handler.HandleError(err) + return + } + key := M.AddrPortFromNet(tcpConn.RemoteAddr()) + session, ok := t.tcpSessions.Load(key) + if !ok { + tcpConn.Close() + logger.Warn("dropped unknown tcp session from ", key) + continue + } + + var metadata M.Metadata + metadata.Protocol = "tun" + metadata.Source.Addr = session.Addr() + metadata.Source.Port = key.Port() + metadata.Destination.Addr = key.Addr() + metadata.Destination.Port = session.Port() + + go t.processConn(tcpConn, metadata, key) + } +} + +func (t *Stack) processConn(conn *net.TCPConn, metadata M.Metadata, key netip.AddrPort) { + err := t.handler.NewConnection(context.Background(), conn, metadata) + if err != nil { + t.handler.HandleError(err) + } + t.tcpSessions.Delete(key) +} + +func (t *Stack) processIPv4UDP(ipHdr header.IPv4, hdr header.UDP) error { + var metadata M.Metadata + metadata.Protocol = "tun" + metadata.Source = M.SocksaddrFrom(net.IP(ipHdr.SourceAddress()), hdr.SourcePort()) + metadata.Source = M.SocksaddrFrom(net.IP(ipHdr.DestinationAddress()), hdr.DestinationPort()) + + headerCache := buf.New() + _, err := headerCache.Write(ipHdr[:ipHdr.HeaderLength()+header.UDPMinimumSize]) + if err != nil { + return err + } + + logger.Trace("[UDP] ", metadata.Source, "=>", metadata.Destination) + + t.udpNat.NewPacket(metadata.Source.AddrPort(), func() N.PacketWriter { + return &inetPacketWriter{ + tun: t, + headerCache: headerCache, + sourceAddress: ipHdr.SourceAddress(), + destination: ipHdr.DestinationAddress(), + destinationPort: hdr.DestinationPort(), + } + }, buf.With(hdr), metadata) + return nil +} + +type inetPacketWriter struct { + tun *Stack + headerCache *buf.Buffer + sourceAddress tcpip.Address + destination tcpip.Address + destinationPort uint16 +} + +func (w *inetPacketWriter) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { + index := w.headerCache.Len() + newHeader := w.headerCache.Extend(w.headerCache.Len()) + copy(newHeader, w.headerCache.Bytes()) + w.headerCache.Advance(index) + + defer func() { + w.headerCache.FullReset() + w.headerCache.Resize(0, index) + }() + + var newSourceAddress tcpip.Address + var newSourcePort uint16 + + if destination.IsValid() { + newSourceAddress = tcpip.Address(destination.Addr.AsSlice()) + newSourcePort = destination.Port + } else { + newSourceAddress = w.destination + newSourcePort = w.destinationPort + } + + newIpHdr := header.IPv4(newHeader) + newIpHdr.SetSourceAddress(newSourceAddress) + newIpHdr.SetTotalLength(uint16(int(w.headerCache.Len()) + buffer.Len())) + newIpHdr.SetChecksum(0) + newIpHdr.SetChecksum(^newIpHdr.CalculateChecksum()) + + udpHdr := header.UDP(w.headerCache.From(w.headerCache.Len() - header.UDPMinimumSize)) + udpHdr.SetSourcePort(newSourcePort) + udpHdr.SetLength(uint16(header.UDPMinimumSize + buffer.Len())) + udpHdr.SetChecksum(0) + udpHdr.SetChecksum(^udpHdr.CalculateChecksum(header.Checksum(buffer.Bytes(), header.PseudoHeaderChecksum(header.UDPProtocolNumber, newSourceAddress, w.sourceAddress, uint16(header.UDPMinimumSize+buffer.Len()))))) + + replyVV := tcpipBuffer.VectorisedView{} + replyVV.AppendView(newHeader) + replyVV.AppendView(buffer.Bytes()) + + return w.tun.WriteVV(replyVV) +} + +func (w *inetPacketWriter) Close() error { + w.headerCache.Release() + return nil +} + +func (t *Stack) processIPv6UDP(ipHdr header.IPv6, hdr header.UDP) error { + var metadata M.Metadata + metadata.Protocol = "tun" + metadata.Source = M.SocksaddrFrom(net.IP(ipHdr.SourceAddress()), hdr.SourcePort()) + metadata.Destination = M.SocksaddrFrom(net.IP(ipHdr.DestinationAddress()), hdr.DestinationPort()) + + headerCache := buf.New() + _, err := headerCache.Write(ipHdr[:uint16(len(ipHdr))-ipHdr.PayloadLength()+header.UDPMinimumSize]) + if err != nil { + return err + } + + t.udpNat.NewPacket(metadata.Source.AddrPort(), func() N.PacketWriter { + return &inet6PacketWriter{ + tun: t, + headerCache: headerCache, + sourceAddress: ipHdr.SourceAddress(), + destination: ipHdr.DestinationAddress(), + destinationPort: hdr.DestinationPort(), + } + }, buf.With(hdr), metadata) + return nil +} + +type inet6PacketWriter struct { + tun *Stack + headerCache *buf.Buffer + sourceAddress tcpip.Address + destination tcpip.Address + destinationPort uint16 +} + +func (w *inet6PacketWriter) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { + index := w.headerCache.Len() + newHeader := w.headerCache.Extend(w.headerCache.Len()) + copy(newHeader, w.headerCache.Bytes()) + w.headerCache.Advance(index) + + defer func() { + w.headerCache.FullReset() + w.headerCache.Resize(0, index) + }() + + var newSourceAddress tcpip.Address + var newSourcePort uint16 + + if destination.IsValid() { + newSourceAddress = tcpip.Address(destination.Addr.AsSlice()) + newSourcePort = destination.Port + } else { + newSourceAddress = w.destination + newSourcePort = w.destinationPort + } + + newIpHdr := header.IPv6(newHeader) + newIpHdr.SetSourceAddress(newSourceAddress) + newIpHdr.SetPayloadLength(uint16(header.UDPMinimumSize + buffer.Len())) + + udpHdr := header.UDP(w.headerCache.From(w.headerCache.Len() - header.UDPMinimumSize)) + udpHdr.SetSourcePort(newSourcePort) + udpHdr.SetLength(uint16(header.UDPMinimumSize + buffer.Len())) + udpHdr.SetChecksum(0) + udpHdr.SetChecksum(^udpHdr.CalculateChecksum(header.Checksum(buffer.Bytes(), header.PseudoHeaderChecksum(header.UDPProtocolNumber, newSourceAddress, w.sourceAddress, uint16(header.UDPMinimumSize+buffer.Len()))))) + + replyVV := tcpipBuffer.VectorisedView{} + replyVV.AppendView(newHeader) + replyVV.AppendView(buffer.Bytes()) + + return w.tun.WriteVV(replyVV) +} + +func (t *Stack) WriteVV(vv tcpipBuffer.VectorisedView) error { + data := make([][]byte, 0, len(vv.Views())) + for _, view := range vv.Views() { + data = append(data, view) + } + return common.Error(rw.WriteV(t.tunFd, data...)) +} + +func (w *inet6PacketWriter) Close() error { + w.headerCache.Release() + return nil +} + +type tcpipError struct { + Err tcpip.Error +} + +func (e *tcpipError) Error() string { + return e.Err.String() +} diff --git a/common/tun/tun.go b/common/tun/tun.go new file mode 100644 index 0000000..d002ba8 --- /dev/null +++ b/common/tun/tun.go @@ -0,0 +1,18 @@ +package tun + +import ( + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" +) + +type Handler interface { + M.TCPConnectionHandler + N.UDPConnectionHandler + E.Handler +} + +type Stack interface { + Start() error + Close() error +} diff --git a/common/tun/tun_linux.go b/common/tun/tun_linux.go new file mode 100644 index 0000000..514774d --- /dev/null +++ b/common/tun/tun_linux.go @@ -0,0 +1,171 @@ +package tun + +/* +import ( + "bytes" + "net" + "syscall" + "unsafe" + + E "github.com/sagernet/sing/common/exceptions" + "golang.org/x/sys/unix" +) + +const ifReqSize = unix.IFNAMSIZ + 64 + +func (t *Interface) Name() (string, error) { + if t.tunName != "" { + return t.tunName, nil + } + var ifr [ifReqSize]byte + _, _, errno := unix.Syscall(unix.SYS_IOCTL, uintptr(t.tunFd), uintptr(unix.TUNGETIFF), uintptr(unsafe.Pointer(&ifr[0]))) + if errno != 0 { + return "", errno + } + name := ifr[:] + if i := bytes.IndexByte(name, 0); i != -1 { + name = name[:i] + } + t.tunName = string(name) + return t.tunName, nil +} + +func (t *Interface) MTU() (int, error) { + name, err := t.Name() + if err != nil { + return 0, err + } + fd, err := unix.Socket( + unix.AF_INET, + unix.SOCK_DGRAM, + 0, + ) + if err != nil { + return 0, err + } + defer unix.Close(fd) + var ifr [ifReqSize]byte + copy(ifr[:], name) + _, _, errno := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), uintptr(unix.SIOCGIFMTU), uintptr(unsafe.Pointer(&ifr[0]))) + if errno != 0 { + return 0, errno + } + return int(*(*int32)(unsafe.Pointer(&ifr[unix.IFNAMSIZ]))), nil +} + +func (t *Interface) SetMTU(mtu int) error { + name, err := t.Name() + if err != nil { + return err + } + fd, err := unix.Socket( + unix.AF_INET, + unix.SOCK_DGRAM, + 0, + ) + if err != nil { + return err + } + defer unix.Close(fd) + var ifr [ifReqSize]byte + copy(ifr[:], name) + *(*uint32)(unsafe.Pointer(&ifr[unix.IFNAMSIZ])) = uint32(mtu) + _, _, errno := unix.Syscall( + unix.SYS_IOCTL, + uintptr(fd), + uintptr(unix.SIOCSIFMTU), + uintptr(unsafe.Pointer(&ifr[0])), + ) + if errno != 0 { + return errno + } + return nil +} + +func (t *Interface) SetAddress() error { + name, err := t.Name() + if err != nil { + return err + } + fd, err := unix.Socket( + unix.AF_INET, + unix.SOCK_DGRAM, + 0, + ) + if err != nil { + return err + } + defer unix.Close(fd) + ifreq, err := unix.NewIfreq(name) + if err != nil { + return E.Cause(err, "failed to create ifreq for name ", name) + } + + ifreq.SetInet4Addr(t.inetAddress.Addr().AsSlice()) + err = unix.IoctlIfreq(fd, syscall.SIOCSIFADDR, ifreq) + if err == nil { + ifreq, _ = unix.NewIfreq(name) + ifreq.SetInet4Addr(net.CIDRMask(t.inetAddress.Bits(), 32)) + err = unix.IoctlIfreq(fd, syscall.SIOCSIFNETMASK, ifreq) + } + if err != nil { + return E.Cause(err, "failed to set ipv4 address on ", name) + } + if t.inet6Address.IsValid() { + ifreq, _ = unix.NewIfreq(name) + err = unix.IoctlIfreq(fd, syscall.SIOCGIFINDEX, ifreq) + if err != nil { + return E.Cause(err, "failed to get interface index for ", name) + } + + ifreq6 := in6_ifreq{ + ifr6_addr: in6_addr{ + addr: t.inet6Address.Addr().As16(), + }, + ifr6_prefixlen: uint32(t.inet6Address.Bits()), + ifr6_ifindex: ifreq.Uint32(), + } + + fd6, err := unix.Socket( + unix.AF_INET6, + unix.SOCK_DGRAM, + 0, + ) + if err != nil { + return err + } + defer unix.Close(fd6) + + if _, _, errno := syscall.Syscall( + syscall.SYS_IOCTL, + uintptr(fd6), + uintptr(syscall.SIOCSIFADDR), + uintptr(unsafe.Pointer(&ifreq6)), + ); errno != 0 { + return E.Cause(errno, "failed to set ipv6 address on ", name) + } + } + + ifreq, _ = unix.NewIfreq(name) + err = unix.IoctlIfreq(fd, syscall.SIOCGIFFLAGS, ifreq) + if err == nil { + ifreq.SetUint16(ifreq.Uint16() | syscall.IFF_UP | syscall.IFF_RUNNING) + err = unix.IoctlIfreq(fd, syscall.SIOCSIFFLAGS, ifreq) + } + if err != nil { + return E.Cause(err, "failed to bring tun device up") + } + + return nil +} + +type in6_addr struct { + addr [16]byte +} + +type in6_ifreq struct { + ifr6_addr in6_addr + ifr6_prefixlen uint32 + ifr6_ifindex uint32 +} +*/ diff --git a/common/udpnat/service.go b/common/udpnat/service.go index 34f30ea..4451741 100644 --- a/common/udpnat/service.go +++ b/common/udpnat/service.go @@ -13,11 +13,11 @@ import ( "github.com/sagernet/sing/common/cache" E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" - "github.com/sagernet/sing/protocol/socks" + N "github.com/sagernet/sing/common/network" ) type Handler interface { - socks.UDPConnectionHandler + N.UDPConnectionHandler E.Handler } @@ -36,15 +36,16 @@ func New[K comparable](maxAge int64, handler Handler) *Service[K] { } } -func (s *Service[T]) NewPacket(key T, writer func() socks.PacketWriter, buffer *buf.Buffer, metadata M.Metadata) { +func (s *Service[T]) NewPacket(key T, writer func() N.PacketWriter, buffer *buf.Buffer, metadata M.Metadata) { s.NewContextPacket(context.Background(), key, writer, buffer, metadata) } -func (s *Service[T]) NewContextPacket(ctx context.Context, key T, writer func() socks.PacketWriter, buffer *buf.Buffer, metadata M.Metadata) { +func (s *Service[T]) NewContextPacket(ctx context.Context, key T, writer func() N.PacketWriter, buffer *buf.Buffer, metadata M.Metadata) { c, loaded := s.nat.LoadOrStore(key, func() *conn { c := &conn{ data: make(chan packet), - remoteAddr: metadata.Source.UDPAddr(), + localAddr: metadata.Source, + remoteAddr: metadata.Destination, source: writer(), } c.ctx, c.cancel = context.WithCancel(ctx) @@ -80,7 +81,7 @@ func (s *Service[T]) NewContextPacket(ctx context.Context, key T, writer func() type packet struct { data *buf.Buffer - destination *M.AddrPort + destination M.Socksaddr done context.CancelFunc } @@ -89,15 +90,16 @@ type conn struct { ctx context.Context cancel context.CancelFunc data chan packet - remoteAddr *net.UDPAddr - source socks.PacketWriter + localAddr M.Socksaddr + remoteAddr M.Socksaddr + source N.PacketWriter } -func (c *conn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) { +func (c *conn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) { select { case p, ok := <-c.data: if !ok { - return nil, io.ErrClosedPipe + return M.Socksaddr{}, io.ErrClosedPipe } defer p.data.Release() _, err := buffer.ReadFrom(p.data) @@ -106,7 +108,7 @@ func (c *conn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) { } } -func (c *conn) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error { +func (c *conn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { return c.source.WritePacket(buffer, destination) } @@ -126,7 +128,7 @@ func (c *conn) Close() error { } func (c *conn) LocalAddr() net.Addr { - return &common.DummyAddr{} + return c.localAddr } func (c *conn) RemoteAddr() net.Addr { diff --git a/common/uot/client.go b/common/uot/client.go index d04d748..91cd471 100644 --- a/common/uot/client.go +++ b/common/uot/client.go @@ -18,23 +18,23 @@ func NewClientConn(conn net.Conn) *ClientConn { return &ClientConn{conn} } -func (c *ClientConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) { +func (c *ClientConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) { destination, err := AddrParser.ReadAddrPort(c) if err != nil { - return nil, err + return M.Socksaddr{}, err } var length uint16 err = binary.Read(c, binary.BigEndian, &length) if err != nil { - return nil, err + return M.Socksaddr{}, err } if buffer.FreeLen() < int(length) { - return nil, io.ErrShortBuffer + return M.Socksaddr{}, io.ErrShortBuffer } return destination, common.Error(buffer.ReadFullFrom(c, int(length))) } -func (c *ClientConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error { +func (c *ClientConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { err := AddrParser.WriteAddrPort(c, destination) if err != nil { return err @@ -68,7 +68,7 @@ func (c *ClientConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { } func (c *ClientConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { - err = AddrParser.WriteAddrPort(c, M.AddrPortFromNetAddr(addr)) + err = AddrParser.WriteAddrPort(c, M.SocksaddrFromNet(addr)) if err != nil { return } diff --git a/common/uot/server.go b/common/uot/server.go index 3eb815f..9dbd6d7 100644 --- a/common/uot/server.go +++ b/common/uot/server.go @@ -47,8 +47,8 @@ func (c *ServerConn) loopInput() { if err != nil { break } - if destination.Addr.Family().IsFqdn() { - ip, err := LookupAddress(destination.Addr.Fqdn()) + if destination.Family().IsFqdn() { + ip, err := LookupAddress(destination.Fqdn) if err != nil { break } @@ -81,8 +81,7 @@ func (c *ServerConn) loopOutput() { if err != nil { break } - destination := M.AddrPortFromNetAddr(addr) - err = AddrParser.WriteAddrPort(c.outputWriter, destination) + err = AddrParser.WriteAddrPort(c.outputWriter, M.SocksaddrFromNet(addr)) if err != nil { break } diff --git a/go.mod b/go.mod index 1a79c1e..ada71b7 100644 --- a/go.mod +++ b/go.mod @@ -23,6 +23,7 @@ require ( golang.org/x/sys v0.0.0-20220503163025-988cb79eb6c6 golang.zx2c4.com/wireguard v0.0.0-20220407013110-ef5c587f782d google.golang.org/protobuf v1.28.0 + gvisor.dev/gvisor v0.0.0-20220428010907-8082b77961ba lukechampine.com/blake3 v1.1.7 ) @@ -31,6 +32,7 @@ require ( github.com/cenkalti/backoff/v4 v4.1.1 // indirect github.com/dgryski/go-metro v0.0.0-20211217172704-adc40b04c140 // indirect github.com/golang/protobuf v1.5.2 // indirect + github.com/google/btree v1.0.1 // indirect github.com/google/go-querystring v1.1.0 // indirect github.com/inconshreveable/mousetrap v1.0.0 // indirect github.com/klauspost/cpuid/v2 v2.0.12 // indirect diff --git a/go.sum b/go.sum index 945b12b..1db8d99 100644 --- a/go.sum +++ b/go.sum @@ -154,6 +154,8 @@ github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiu github.com/golangci/lint-1 v0.0.0-20181222135242-d2cdd8c08219/go.mod h1:/X8TswGSh1pIozq4ZwCfxS0WA5JGXguxk94ar/4c87Y= github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= +github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4= +github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= @@ -754,6 +756,8 @@ gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C gopkg.in/yaml.v3 v3.0.0-20200605160147-a5ece683394c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b h1:h8qDotaEPuJATrMmW04NCwg7v22aHH28wwpauUhK9Oo= gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gvisor.dev/gvisor v0.0.0-20220428010907-8082b77961ba h1:qJ6jWSTl9q+/y4l8QCNpkNnasX/sHzhVnPRysee8PzY= +gvisor.dev/gvisor v0.0.0-20220428010907-8082b77961ba/go.mod h1:tWwEcFvJavs154OdjFCw78axNrsDlz4Zh8jvPqwcpGI= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= diff --git a/protocol/http/listener.go b/protocol/http/listener.go index cb6b46d..98570e3 100644 --- a/protocol/http/listener.go +++ b/protocol/http/listener.go @@ -45,13 +45,8 @@ func HandleRequest(ctx context.Context, request *http.Request, conn net.Conn, au if portStr == "" { portStr = "80" } - destination, err := M.ParseAddrPort(request.URL.Hostname(), portStr) - if err != nil { - if err != nil { - return err - } - } - _, err = fmt.Fprintf(conn, "HTTP/%d.%d %03d %s\r\n\r\n", request.ProtoMajor, request.ProtoMinor, http.StatusOK, "Connection established") + destination := M.ParseSocksaddrHostPort(request.URL.Hostname(), portStr) + _, err := fmt.Fprintf(conn, "HTTP/%d.%d %03d %s\r\n\r\n", request.ProtoMajor, request.ProtoMinor, http.StatusOK, "Connection established") if err != nil { return E.Cause(err, "write http response") } @@ -87,17 +82,11 @@ func HandleRequest(ctx context.Context, request *http.Request, conn net.Conn, au if network != "tcp" && network != "tcp4" && network != "tcp6" { return nil, E.New("unsupported network ", network) } - - destination, err := M.ParseAddress(address) - if err != nil { - return nil, err - } - + metadata.Destination = M.ParseSocksaddr(address) + metadata.Protocol = "http" left, right := net.Pipe() go func() { - metadata.Destination = destination - metadata.Protocol = "http" - err = handler.NewConnection(ctx, right, metadata) + err := handler.NewConnection(ctx, right, metadata) if err != nil { handler.HandleError(&tcp.Error{Conn: right, Cause: err}) } diff --git a/protocol/shadowsocks/none.go b/protocol/shadowsocks/none.go index fa33a1f..4910b4a 100644 --- a/protocol/shadowsocks/none.go +++ b/protocol/shadowsocks/none.go @@ -10,8 +10,9 @@ import ( "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/common/udpnat" - "github.com/sagernet/sing/protocol/socks" + "github.com/sagernet/sing/protocol/socks5" ) const MethodNone = "none" @@ -30,7 +31,7 @@ func (m *NoneMethod) KeyLength() int { return 0 } -func (m *NoneMethod) DialConn(conn net.Conn, destination *M.AddrPort) (net.Conn, error) { +func (m *NoneMethod) DialConn(conn net.Conn, destination M.Socksaddr) (net.Conn, error) { shadowsocksConn := &noneConn{ Conn: conn, handshake: true, @@ -39,14 +40,14 @@ func (m *NoneMethod) DialConn(conn net.Conn, destination *M.AddrPort) (net.Conn, return shadowsocksConn, shadowsocksConn.clientHandshake() } -func (m *NoneMethod) DialEarlyConn(conn net.Conn, destination *M.AddrPort) net.Conn { +func (m *NoneMethod) DialEarlyConn(conn net.Conn, destination M.Socksaddr) net.Conn { return &noneConn{ Conn: conn, destination: destination, } } -func (m *NoneMethod) DialPacketConn(conn net.Conn) socks.PacketConn { +func (m *NoneMethod) DialPacketConn(conn net.Conn) N.PacketConn { return &nonePacketConn{conn} } @@ -55,11 +56,11 @@ type noneConn struct { access sync.Mutex handshake bool - destination *M.AddrPort + destination M.Socksaddr } func (c *noneConn) clientHandshake() error { - err := socks.AddressSerializer.WriteAddrPort(c.Conn, c.destination) + err := socks5.AddressSerializer.WriteAddrPort(c.Conn, c.destination) if err != nil { return err } @@ -87,7 +88,7 @@ func (c *noneConn) Write(b []byte) (n int, err error) { _buffer := buf.StackNew() buffer := common.Dup(_buffer) - err = socks.AddressSerializer.WriteAddrPort(buffer, c.destination) + err = socks5.AddressSerializer.WriteAddrPort(buffer, c.destination) if err != nil { return } @@ -132,19 +133,19 @@ type nonePacketConn struct { net.Conn } -func (c *nonePacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) { +func (c *nonePacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) { _, err := buffer.ReadFrom(c) if err != nil { - return nil, err + return M.Socksaddr{}, err } - return socks.AddressSerializer.ReadAddrPort(buffer) + return socks5.AddressSerializer.ReadAddrPort(buffer) } -func (c *nonePacketConn) WritePacket(buffer *buf.Buffer, addrPort *M.AddrPort) error { +func (c *nonePacketConn) WritePacket(buffer *buf.Buffer, addrPort M.Socksaddr) error { defer buffer.Release() _header := buf.StackNewMax() header := common.Dup(_header) - err := socks.AddressSerializer.WriteAddrPort(header, addrPort) + err := socks5.AddressSerializer.WriteAddrPort(header, addrPort) if err != nil { header.Release() return err @@ -167,7 +168,7 @@ func NewNoneService(udpTimeout int64, handler Handler) Service { } func (s *NoneService) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error { - destination, err := socks.AddressSerializer.ReadAddrPort(conn) + destination, err := socks5.AddressSerializer.ReadAddrPort(conn) if err != nil { return err } @@ -176,34 +177,34 @@ func (s *NoneService) NewConnection(ctx context.Context, conn net.Conn, metadata return s.handler.NewConnection(ctx, conn, metadata) } -func (s *NoneService) NewPacket(conn socks.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error { - destination, err := socks.AddressSerializer.ReadAddrPort(buffer) +func (s *NoneService) NewPacket(conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error { + destination, err := socks5.AddressSerializer.ReadAddrPort(buffer) if err != nil { return err } metadata.Protocol = "shadowsocks" metadata.Destination = destination - s.udp.NewPacket(metadata.Source.AddrPort(), func() socks.PacketWriter { + s.udp.NewPacket(metadata.Source.AddrPort(), func() N.PacketWriter { return &nonePacketWriter{conn, metadata.Source} }, buffer, metadata) return nil } type nonePacketWriter struct { - socks.PacketConn - sourceAddr *M.AddrPort + N.PacketConn + sourceAddr M.Socksaddr } -func (s *nonePacketWriter) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error { - header := buf.With(buffer.ExtendHeader(socks.AddressSerializer.AddrPortLen(destination))) - err := socks.AddressSerializer.WriteAddrPort(header, destination) +func (s *nonePacketWriter) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { + header := buf.With(buffer.ExtendHeader(socks5.AddressSerializer.AddrPortLen(destination))) + err := socks5.AddressSerializer.WriteAddrPort(header, destination) if err != nil { return err } return s.PacketConn.WritePacket(buffer, s.sourceAddr) } -func (s *NoneService) NewPacketConnection(ctx context.Context, conn socks.PacketConn, metadata M.Metadata) error { +func (s *NoneService) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata M.Metadata) error { return s.handler.NewPacketConnection(ctx, conn, metadata) } diff --git a/protocol/shadowsocks/protocol.go b/protocol/shadowsocks/protocol.go index 5003640..4079ac9 100644 --- a/protocol/shadowsocks/protocol.go +++ b/protocol/shadowsocks/protocol.go @@ -8,15 +8,15 @@ import ( "net" M "github.com/sagernet/sing/common/metadata" - "github.com/sagernet/sing/protocol/socks" + N "github.com/sagernet/sing/common/network" ) type Method interface { Name() string KeyLength() int - DialConn(conn net.Conn, destination *M.AddrPort) (net.Conn, error) - DialEarlyConn(conn net.Conn, destination *M.AddrPort) net.Conn - DialPacketConn(conn net.Conn) socks.PacketConn + DialConn(conn net.Conn, destination M.Socksaddr) (net.Conn, error) + DialEarlyConn(conn net.Conn, destination M.Socksaddr) net.Conn + DialPacketConn(conn net.Conn) N.PacketConn } func Key(password []byte, keySize int) []byte { diff --git a/protocol/shadowsocks/service.go b/protocol/shadowsocks/service.go index 7c0fc5e..f09e030 100644 --- a/protocol/shadowsocks/service.go +++ b/protocol/shadowsocks/service.go @@ -7,17 +7,17 @@ import ( E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" - "github.com/sagernet/sing/protocol/socks" + N "github.com/sagernet/sing/common/network" ) type Service interface { M.TCPConnectionHandler - socks.UDPHandler + N.UDPHandler } type Handler interface { M.TCPConnectionHandler - socks.UDPConnectionHandler + N.UDPConnectionHandler E.Handler } @@ -34,7 +34,7 @@ type UserContext[U comparable] struct { type ServerConnError struct { net.Conn - Source *M.AddrPort + Source M.Socksaddr Cause error } @@ -47,8 +47,8 @@ func (e *ServerConnError) Error() string { } type ServerPacketError struct { - socks.PacketConn - Source *M.AddrPort + N.PacketConn + Source M.Socksaddr Cause error } diff --git a/protocol/shadowsocks/shadowaead/protocol.go b/protocol/shadowsocks/shadowaead/protocol.go index 340409b..79b480e 100644 --- a/protocol/shadowsocks/shadowaead/protocol.go +++ b/protocol/shadowsocks/shadowaead/protocol.go @@ -12,10 +12,11 @@ import ( "github.com/sagernet/sing/common/buf" E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/common/replay" "github.com/sagernet/sing/common/rw" "github.com/sagernet/sing/protocol/shadowsocks" - "github.com/sagernet/sing/protocol/socks" + "github.com/sagernet/sing/protocol/socks5" "golang.org/x/crypto/chacha20poly1305" "golang.org/x/crypto/hkdf" ) @@ -138,7 +139,7 @@ func (m *Method) WriteResponse(upstream io.Writer) (io.Writer, error) { return NewWriter(upstream, m.constructor(common.Dup(key)), MaxPacketSize), nil } -func (m *Method) DialConn(conn net.Conn, destination *M.AddrPort) (net.Conn, error) { +func (m *Method) DialConn(conn net.Conn, destination M.Socksaddr) (net.Conn, error) { shadowsocksConn := &clientConn{ Conn: conn, method: m, @@ -147,7 +148,7 @@ func (m *Method) DialConn(conn net.Conn, destination *M.AddrPort) (net.Conn, err return shadowsocksConn, shadowsocksConn.writeRequest(nil) } -func (m *Method) DialEarlyConn(conn net.Conn, destination *M.AddrPort) net.Conn { +func (m *Method) DialEarlyConn(conn net.Conn, destination M.Socksaddr) net.Conn { return &clientConn{ Conn: conn, method: m, @@ -155,7 +156,7 @@ func (m *Method) DialEarlyConn(conn net.Conn, destination *M.AddrPort) net.Conn } } -func (m *Method) DialPacketConn(conn net.Conn) socks.PacketConn { +func (m *Method) DialPacketConn(conn net.Conn) N.PacketConn { return &clientPacketConn{m, conn} } @@ -186,7 +187,7 @@ type clientConn struct { net.Conn method *Method - destination *M.AddrPort + destination M.Socksaddr access sync.Mutex reader *Reader @@ -209,7 +210,7 @@ func (c *clientConn) writeRequest(payload []byte) error { bufferedWriter := writer.BufferedWriter(header.Len()) if len(payload) > 0 { - err := socks.AddressSerializer.WriteAddrPort(bufferedWriter, c.destination) + err := socks5.AddressSerializer.WriteAddrPort(bufferedWriter, c.destination) if err != nil { return err } @@ -219,7 +220,7 @@ func (c *clientConn) writeRequest(payload []byte) error { return err } } else { - err := socks.AddressSerializer.WriteAddrPort(bufferedWriter, c.destination) + err := socks5.AddressSerializer.WriteAddrPort(bufferedWriter, c.destination) if err != nil { return err } @@ -325,10 +326,10 @@ type clientPacketConn struct { net.Conn } -func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error { - header := buffer.ExtendHeader(c.keySaltLength + socks.AddressSerializer.AddrPortLen(destination)) +func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { + header := buffer.ExtendHeader(c.keySaltLength + socks5.AddressSerializer.AddrPortLen(destination)) common.Must1(io.ReadFull(c.secureRNG, header[:c.keySaltLength])) - err := socks.AddressSerializer.WriteAddrPort(buf.With(header[c.keySaltLength:]), destination) + err := socks5.AddressSerializer.WriteAddrPort(buf.With(header[c.keySaltLength:]), destination) if err != nil { return err } @@ -339,17 +340,17 @@ func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPo return common.Error(c.Write(buffer.Bytes())) } -func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) { +func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) { n, err := c.Read(buffer.FreeBytes()) if err != nil { - return nil, err + return M.Socksaddr{}, err } buffer.Truncate(n) err = c.DecodePacket(buffer) if err != nil { - return nil, err + return M.Socksaddr{}, err } - return socks.AddressSerializer.ReadAddrPort(buffer) + return socks5.AddressSerializer.ReadAddrPort(buffer) } func (c *clientPacketConn) UpstreamReader() io.Reader { diff --git a/protocol/shadowsocks/shadowaead/service.go b/protocol/shadowsocks/shadowaead/service.go index 608542b..a1ccd2e 100644 --- a/protocol/shadowsocks/shadowaead/service.go +++ b/protocol/shadowsocks/shadowaead/service.go @@ -12,11 +12,12 @@ import ( "github.com/sagernet/sing/common/buf" E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/common/replay" "github.com/sagernet/sing/common/rw" "github.com/sagernet/sing/common/udpnat" "github.com/sagernet/sing/protocol/shadowsocks" - "github.com/sagernet/sing/protocol/socks" + "github.com/sagernet/sing/protocol/socks5" "golang.org/x/crypto/chacha20poly1305" ) @@ -97,7 +98,7 @@ func (s *Service) newConnection(ctx context.Context, conn net.Conn, metadata M.M key := Kdf(s.key, salt, s.keySaltLength) reader := NewReader(conn, s.constructor(common.Dup(key)), MaxPacketSize) - destination, err := socks.AddressSerializer.ReadAddrPort(reader) + destination, err := socks5.AddressSerializer.ReadAddrPort(reader) if err != nil { return err } @@ -198,7 +199,7 @@ func (c *serverConn) WriterReplaceable() bool { return c.writer != nil } -func (s *Service) NewPacket(conn socks.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error { +func (s *Service) NewPacket(conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error { err := s.newPacket(conn, buffer, metadata) if err != nil { err = &shadowsocks.ServerPacketError{PacketConn: conn, Source: metadata.Source, Cause: err} @@ -206,7 +207,7 @@ func (s *Service) NewPacket(conn socks.PacketConn, buffer *buf.Buffer, metadata return err } -func (s *Service) newPacket(conn socks.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error { +func (s *Service) newPacket(conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error { if buffer.Len() < s.keySaltLength { return E.New("bad packet") } @@ -219,7 +220,7 @@ func (s *Service) newPacket(conn socks.PacketConn, buffer *buf.Buffer, metadata buffer.Advance(s.keySaltLength) buffer.Truncate(len(packet)) metadata.Protocol = "shadowsocks" - s.udpNat.NewPacket(metadata.Source.AddrPort(), func() socks.PacketWriter { + s.udpNat.NewPacket(metadata.Source.AddrPort(), func() N.PacketWriter { return &serverPacketWriter{s, conn, metadata.Source} }, buffer, metadata) return nil @@ -227,14 +228,14 @@ func (s *Service) newPacket(conn socks.PacketConn, buffer *buf.Buffer, metadata type serverPacketWriter struct { *Service - socks.PacketConn - source *M.AddrPort + N.PacketConn + source M.Socksaddr } -func (w *serverPacketWriter) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error { - header := buffer.ExtendHeader(w.keySaltLength + socks.AddressSerializer.AddrPortLen(destination)) +func (w *serverPacketWriter) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { + header := buffer.ExtendHeader(w.keySaltLength + socks5.AddressSerializer.AddrPortLen(destination)) common.Must1(io.ReadFull(w.secureRNG, header[:w.keySaltLength])) - err := socks.AddressSerializer.WriteAddrPort(buf.With(header[w.keySaltLength:]), destination) + err := socks5.AddressSerializer.WriteAddrPort(buf.With(header[w.keySaltLength:]), destination) if err != nil { return err } diff --git a/protocol/shadowsocks/shadowaead_2022/protocol.go b/protocol/shadowsocks/shadowaead_2022/protocol.go index b78f735..b5526d6 100644 --- a/protocol/shadowsocks/shadowaead_2022/protocol.go +++ b/protocol/shadowsocks/shadowaead_2022/protocol.go @@ -19,11 +19,12 @@ import ( E "github.com/sagernet/sing/common/exceptions" "github.com/sagernet/sing/common/log" M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/common/replay" "github.com/sagernet/sing/common/rw" "github.com/sagernet/sing/protocol/shadowsocks" "github.com/sagernet/sing/protocol/shadowsocks/shadowaead" - "github.com/sagernet/sing/protocol/socks" + "github.com/sagernet/sing/protocol/socks5" "golang.org/x/crypto/chacha20poly1305" wgReplay "golang.zx2c4.com/wireguard/replay" "lukechampine.com/blake3" @@ -163,7 +164,7 @@ func (m *Method) KeyLength() int { return m.keyLength } -func (m *Method) DialConn(conn net.Conn, destination *M.AddrPort) (net.Conn, error) { +func (m *Method) DialConn(conn net.Conn, destination M.Socksaddr) (net.Conn, error) { shadowsocksConn := &clientConn{ Conn: conn, method: m, @@ -172,7 +173,7 @@ func (m *Method) DialConn(conn net.Conn, destination *M.AddrPort) (net.Conn, err return shadowsocksConn, shadowsocksConn.writeRequest(nil) } -func (m *Method) DialEarlyConn(conn net.Conn, destination *M.AddrPort) net.Conn { +func (m *Method) DialEarlyConn(conn net.Conn, destination M.Socksaddr) net.Conn { return &clientConn{ Conn: conn, method: m, @@ -180,7 +181,7 @@ func (m *Method) DialEarlyConn(conn net.Conn, destination *M.AddrPort) net.Conn } } -func (m *Method) DialPacketConn(conn net.Conn) socks.PacketConn { +func (m *Method) DialPacketConn(conn net.Conn) N.PacketConn { return &clientPacketConn{conn, m, m.newUDPSession()} } @@ -188,7 +189,7 @@ type clientConn struct { net.Conn method *Method - destination *M.AddrPort + destination M.Socksaddr request sync.Mutex response sync.Mutex @@ -267,7 +268,7 @@ func (c *clientConn) writeRequest(payload []byte) error { common.Must(rw.WriteByte(bufferedWriter, HeaderTypeClient)) common.Must(binary.Write(bufferedWriter, binary.BigEndian, uint64(time.Now().Unix()))) - err := socks.AddressSerializer.WriteAddrPort(bufferedWriter, c.destination) + err := socks5.AddressSerializer.WriteAddrPort(bufferedWriter, c.destination) if err != nil { return E.Cause(err, "write destination") } @@ -465,7 +466,7 @@ type clientPacketConn struct { session *udpSession } -func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error { +func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { if debug.Enabled { logger.Trace("begin client packet") } @@ -534,7 +535,7 @@ func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPo binary.Write(header, binary.BigEndian, uint64(time.Now().Unix())), binary.Write(header, binary.BigEndian, uint16(0)), // padding length ) - err := socks.AddressSerializer.WriteAddrPort(header, destination) + err := socks5.AddressSerializer.WriteAddrPort(header, destination) if err != nil { return err } @@ -551,14 +552,16 @@ func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPo buffer.Extend(c.session.cipher.Overhead()) c.method.udpBlockCipher.Encrypt(packetHeader, packetHeader) } - logger.Trace("ended client packet") + if debug.Enabled { + logger.Trace("ended client packet") + } return common.Error(c.Write(buffer.Bytes())) } -func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) { +func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) { n, err := c.Read(buffer.FreeBytes()) if err != nil { - return nil, err + return M.Socksaddr{}, err } buffer.Truncate(n) @@ -566,7 +569,7 @@ func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) { if c.method.udpCipher != nil { _, err = c.method.udpCipher.Open(buffer.Index(PacketNonceSize), buffer.To(PacketNonceSize), buffer.From(PacketNonceSize), nil) if err != nil { - return nil, E.Cause(err, "decrypt packet") + return M.Socksaddr{}, E.Cause(err, "decrypt packet") } buffer.Advance(PacketNonceSize) } else { @@ -577,11 +580,11 @@ func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) { var sessionId, packetId uint64 err = binary.Read(buffer, binary.BigEndian, &sessionId) if err != nil { - return nil, err + return M.Socksaddr{}, err } err = binary.Read(buffer, binary.BigEndian, &packetId) if err != nil { - return nil, err + return M.Socksaddr{}, err } var remoteCipher cipher.AEAD @@ -596,42 +599,42 @@ func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) { } _, err = remoteCipher.Open(buffer.Index(0), packetHeader[4:16], buffer.Bytes(), nil) if err != nil { - return nil, E.Cause(err, "decrypt packet") + return M.Socksaddr{}, E.Cause(err, "decrypt packet") } } var headerType byte headerType, err = buffer.ReadByte() if err != nil { - return nil, err + return M.Socksaddr{}, err } if headerType != HeaderTypeServer { - return nil, ErrBadHeaderType + return M.Socksaddr{}, ErrBadHeaderType } var epoch uint64 err = binary.Read(buffer, binary.BigEndian, &epoch) if err != nil { - return nil, err + return M.Socksaddr{}, err } if math.Abs(float64(uint64(time.Now().Unix())-epoch)) > 30 { - return nil, ErrBadTimestamp + return M.Socksaddr{}, ErrBadTimestamp } if sessionId == c.session.remoteSessionId { if !c.session.filter.ValidateCounter(packetId, math.MaxUint64) { - return nil, ErrPacketIdNotUnique + return M.Socksaddr{}, ErrPacketIdNotUnique } } else if sessionId == c.session.lastRemoteSessionId { if !c.session.lastFilter.ValidateCounter(packetId, math.MaxUint64) { - return nil, ErrPacketIdNotUnique + return M.Socksaddr{}, ErrPacketIdNotUnique } remoteCipher = c.session.lastRemoteCipher c.session.lastRemoteSeen = time.Now().Unix() } else { if c.session.remoteSessionId != 0 { if time.Now().Unix()-c.session.lastRemoteSeen < 60 { - return nil, ErrTooManyServerSessions + return M.Socksaddr{}, ErrTooManyServerSessions } else { c.session.lastRemoteSessionId = c.session.remoteSessionId c.session.lastFilter = c.session.filter @@ -648,20 +651,20 @@ func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) { var clientSessionId uint64 err = binary.Read(buffer, binary.BigEndian, &clientSessionId) if err != nil { - return nil, err + return M.Socksaddr{}, err } if clientSessionId != c.session.sessionId { - return nil, ErrBadClientSessionId + return M.Socksaddr{}, ErrBadClientSessionId } var paddingLength uint16 err = binary.Read(buffer, binary.BigEndian, &paddingLength) if err != nil { - return nil, E.Cause(err, "read padding length") + return M.Socksaddr{}, E.Cause(err, "read padding length") } buffer.Advance(int(paddingLength)) - return socks.AddressSerializer.ReadAddrPort(buffer) + return socks5.AddressSerializer.ReadAddrPort(buffer) } type udpSession struct { diff --git a/protocol/shadowsocks/shadowaead_2022/service.go b/protocol/shadowsocks/shadowaead_2022/service.go index c973f1c..0626720 100644 --- a/protocol/shadowsocks/shadowaead_2022/service.go +++ b/protocol/shadowsocks/shadowaead_2022/service.go @@ -18,12 +18,13 @@ import ( "github.com/sagernet/sing/common/debug" E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/common/replay" "github.com/sagernet/sing/common/rw" "github.com/sagernet/sing/common/udpnat" "github.com/sagernet/sing/protocol/shadowsocks" "github.com/sagernet/sing/protocol/shadowsocks/shadowaead" - "github.com/sagernet/sing/protocol/socks" + "github.com/sagernet/sing/protocol/socks5" wgReplay "golang.zx2c4.com/wireguard/replay" ) @@ -132,7 +133,7 @@ func (s *Service) newConnection(ctx context.Context, conn net.Conn, metadata M.M return ErrBadTimestamp } - destination, err := socks.AddressSerializer.ReadAddrPort(reader) + destination, err := socks5.AddressSerializer.ReadAddrPort(reader) if err != nil { return E.Cause(err, "read destination") } @@ -268,7 +269,7 @@ func (c *serverConn) WriterReplaceable() bool { return c.writer != nil } -func (s *Service) NewPacket(conn socks.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error { +func (s *Service) NewPacket(conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error { err := s.newPacket(conn, buffer, metadata) if err != nil { err = &shadowsocks.ServerPacketError{PacketConn: conn, Source: metadata.Source, Cause: err} @@ -276,7 +277,7 @@ func (s *Service) NewPacket(conn socks.PacketConn, buffer *buf.Buffer, metadata return err } -func (s *Service) newPacket(conn socks.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error { +func (s *Service) newPacket(conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error { var packetHeader []byte if s.udpCipher != nil { _, err := s.udpCipher.Open(buffer.Index(PacketNonceSize), buffer.To(PacketNonceSize), buffer.From(PacketNonceSize), nil) @@ -358,14 +359,14 @@ process: } buffer.Advance(int(paddingLength)) - destination, err := socks.AddressSerializer.ReadAddrPort(buffer) + destination, err := socks5.AddressSerializer.ReadAddrPort(buffer) if err != nil { goto returnErr } metadata.Destination = destination session.remoteAddr = metadata.Source - s.udpNat.NewPacket(sessionId, func() socks.PacketWriter { + s.udpNat.NewPacket(sessionId, func() N.PacketWriter { return &serverPacketWriter{s, conn, session} }, buffer, metadata) return nil @@ -373,11 +374,11 @@ process: type serverPacketWriter struct { *Service - socks.PacketConn + N.PacketConn session *serverUDPSession } -func (w *serverPacketWriter) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error { +func (w *serverPacketWriter) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { defer buffer.Release() _header := buf.StackNew() @@ -400,7 +401,7 @@ func (w *serverPacketWriter) WritePacket(buffer *buf.Buffer, destination *M.Addr binary.Write(header, binary.BigEndian, uint16(0)), // padding length ) - err := socks.AddressSerializer.WriteAddrPort(header, destination) + err := socks5.AddressSerializer.WriteAddrPort(header, destination) if err != nil { return err } @@ -425,7 +426,7 @@ func (w *serverPacketWriter) WritePacket(buffer *buf.Buffer, destination *M.Addr type serverUDPSession struct { sessionId uint64 remoteSessionId uint64 - remoteAddr *M.AddrPort + remoteAddr M.Socksaddr packetId uint64 cipher cipher.AEAD remoteCipher cipher.AEAD diff --git a/protocol/shadowsocks/shadowaead_2022/service_multi.go b/protocol/shadowsocks/shadowaead_2022/service_multi.go index 561448f..7898d09 100644 --- a/protocol/shadowsocks/shadowaead_2022/service_multi.go +++ b/protocol/shadowsocks/shadowaead_2022/service_multi.go @@ -13,10 +13,11 @@ import ( "github.com/sagernet/sing/common/buf" E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/common/rw" "github.com/sagernet/sing/protocol/shadowsocks" "github.com/sagernet/sing/protocol/shadowsocks/shadowaead" - "github.com/sagernet/sing/protocol/socks" + "github.com/sagernet/sing/protocol/socks5" "lukechampine.com/blake3" ) @@ -140,7 +141,7 @@ func (s *MultiService[U]) newConnection(ctx context.Context, conn net.Conn, meta return ErrBadTimestamp } - destination, err := socks.AddressSerializer.ReadAddrPort(reader) + destination, err := socks5.AddressSerializer.ReadAddrPort(reader) if err != nil { return E.Cause(err, "read destination") } @@ -173,7 +174,7 @@ func (s *MultiService[U]) newConnection(ctx context.Context, conn net.Conn, meta }, metadata) } -func (s *MultiService[U]) NewPacket(conn socks.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error { +func (s *MultiService[U]) NewPacket(conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error { err := s.newPacket(conn, buffer, metadata) if err != nil { err = &shadowsocks.ServerPacketError{PacketConn: conn, Source: metadata.Source, Cause: err} @@ -181,7 +182,7 @@ func (s *MultiService[U]) NewPacket(conn socks.PacketConn, buffer *buf.Buffer, m return err } -func (s *MultiService[U]) newPacket(conn socks.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error { +func (s *MultiService[U]) newPacket(conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error { packetHeader := buffer.To(aes.BlockSize) s.udpBlockCipher.Decrypt(packetHeader, packetHeader) @@ -272,7 +273,7 @@ process: } buffer.Advance(int(paddingLength)) - destination, err := socks.AddressSerializer.ReadAddrPort(buffer) + destination, err := socks5.AddressSerializer.ReadAddrPort(buffer) if err != nil { goto returnErr } @@ -284,7 +285,7 @@ process: userCtx.Context = context.Background() userCtx.User = user - s.udpNat.NewContextPacket(&userCtx, sessionId, func() socks.PacketWriter { + s.udpNat.NewContextPacket(&userCtx, sessionId, func() N.PacketWriter { return &serverPacketWriter{s.Service, conn, session} }, buffer, metadata) return nil diff --git a/protocol/socks/protocol_test.go b/protocol/socks/protocol_test.go deleted file mode 100644 index 0588cc2..0000000 --- a/protocol/socks/protocol_test.go +++ /dev/null @@ -1,75 +0,0 @@ -package socks_test - -import ( - "net" - "sync" - "testing" - - M "github.com/sagernet/sing/common/metadata" - "github.com/sagernet/sing/protocol/socks" -) - -func TestHandshake(t *testing.T) { - server, client := net.Pipe() - defer server.Close() - defer client.Close() - - wg := new(sync.WaitGroup) - wg.Add(1) - - method := socks.AuthTypeUsernamePassword - - go func() { - response, err := socks.ClientHandshake(client, socks.Version5, socks.CommandConnect, M.AddrPortFrom(M.AddrFromFqdn("test"), 80), "user", "pswd") - if err != nil { - t.Fatal(err) - } - if response.ReplyCode != socks.ReplyCodeSuccess { - t.Fatal(response) - } - wg.Done() - }() - authRequest, err := socks.ReadAuthRequest(server) - if err != nil { - t.Fatal(err) - } - if len(authRequest.Methods) != 1 || authRequest.Methods[0] != method { - t.Fatal("bad methods: ", authRequest.Methods) - } - err = socks.WriteAuthResponse(server, &socks.AuthResponse{ - Version: socks.Version5, - Method: method, - }) - if err != nil { - t.Fatal(err) - } - usernamePasswordAuthRequest, err := socks.ReadUsernamePasswordAuthRequest(server) - if err != nil { - t.Fatal(err) - } - if usernamePasswordAuthRequest.Username != "user" || usernamePasswordAuthRequest.Password != "pswd" { - t.Fatal(authRequest) - } - err = socks.WriteUsernamePasswordAuthResponse(server, &socks.UsernamePasswordAuthResponse{ - Status: socks.UsernamePasswordStatusSuccess, - }) - if err != nil { - t.Fatal(err) - } - request, err := socks.ReadRequest(server) - if err != nil { - t.Fatal(err) - } - if request.Version != socks.Version5 || request.Command != socks.CommandConnect || request.Destination.Addr.Fqdn() != "test" || request.Destination.Port != 80 { - t.Fatal(request) - } - err = socks.WriteResponse(server, &socks.Response{ - Version: socks.Version5, - ReplyCode: socks.ReplyCodeSuccess, - Bind: M.AddrPortFrom(M.AddrFromIP(net.IPv4zero), 0), - }) - if err != nil { - t.Fatal(err) - } - wg.Wait() -} diff --git a/protocol/socks/conn.go b/protocol/socks5/conn.go similarity index 87% rename from protocol/socks/conn.go rename to protocol/socks5/conn.go index ba6f5c6..72b87ca 100644 --- a/protocol/socks/conn.go +++ b/protocol/socks5/conn.go @@ -1,4 +1,4 @@ -package socks +package socks5 import ( "net" @@ -12,10 +12,10 @@ type AssociateConn struct { net.Conn conn net.Conn addr net.Addr - dest *M.AddrPort + dest M.Socksaddr } -func NewAssociateConn(conn net.Conn, packetConn net.Conn, destination *M.AddrPort) *AssociateConn { +func NewAssociateConn(conn net.Conn, packetConn net.Conn, destination M.Socksaddr) *AssociateConn { return &AssociateConn{ Conn: packetConn, conn: conn, @@ -46,7 +46,7 @@ func (c *AssociateConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { _buffer := buf.StackNew() buffer := common.Dup(_buffer) common.Must(buffer.WriteZeroN(3)) - err = AddressSerializer.WriteAddrPort(buffer, M.AddrPortFromNetAddr(addr)) + err = AddressSerializer.WriteAddrPort(buffer, M.SocksaddrFromNet(addr)) if err != nil { return } @@ -80,17 +80,17 @@ func (c *AssociateConn) Write(b []byte) (n int, err error) { return } -func (c *AssociateConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) { +func (c *AssociateConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) { n, err := buffer.ReadFrom(c.conn) if err != nil { - return nil, err + return M.Socksaddr{}, err } buffer.Truncate(int(n)) buffer.Advance(3) return AddressSerializer.ReadAddrPort(buffer) } -func (c *AssociateConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error { +func (c *AssociateConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { defer buffer.Release() header := buf.With(buffer.ExtendHeader(3 + AddressSerializer.AddrPortLen(destination))) common.Must(header.WriteZeroN(3)) @@ -102,10 +102,10 @@ type AssociatePacketConn struct { net.PacketConn conn net.Conn addr net.Addr - dest *M.AddrPort + dest M.Socksaddr } -func NewAssociatePacketConn(conn net.Conn, packetConn net.PacketConn, destination *M.AddrPort) *AssociatePacketConn { +func NewAssociatePacketConn(conn net.Conn, packetConn net.PacketConn, destination M.Socksaddr) *AssociatePacketConn { return &AssociatePacketConn{ PacketConn: packetConn, conn: conn, @@ -137,7 +137,7 @@ func (c *AssociatePacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error buffer := common.Dup(_buffer) common.Must(buffer.WriteZeroN(3)) - err = AddressSerializer.WriteAddrPort(buffer, M.AddrPortFromNetAddr(addr)) + err = AddressSerializer.WriteAddrPort(buffer, M.SocksaddrFromNet(addr)) if err != nil { return } @@ -171,10 +171,10 @@ func (c *AssociatePacketConn) Write(b []byte) (n int, err error) { return } -func (c *AssociatePacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) { +func (c *AssociatePacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) { n, addr, err := c.PacketConn.ReadFrom(buffer.FreeBytes()) if err != nil { - return nil, err + return M.Socksaddr{}, err } c.addr = addr buffer.Truncate(n) @@ -183,7 +183,7 @@ func (c *AssociatePacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error return dest, err } -func (c *AssociatePacketConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error { +func (c *AssociatePacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { defer buffer.Release() header := buf.With(buffer.ExtendHeader(3 + AddressSerializer.AddrPortLen(destination))) common.Must(header.WriteZeroN(3)) diff --git a/protocol/socks/constant.go b/protocol/socks5/constant.go similarity index 99% rename from protocol/socks/constant.go rename to protocol/socks5/constant.go index b9b75c5..240fca2 100644 --- a/protocol/socks/constant.go +++ b/protocol/socks5/constant.go @@ -1,4 +1,4 @@ -package socks +package socks5 import ( "strconv" diff --git a/protocol/socks/exceptions.go b/protocol/socks5/exceptions.go similarity index 97% rename from protocol/socks/exceptions.go rename to protocol/socks5/exceptions.go index 9f8b8e0..549ed31 100644 --- a/protocol/socks/exceptions.go +++ b/protocol/socks5/exceptions.go @@ -1,4 +1,4 @@ -package socks +package socks5 import "fmt" diff --git a/protocol/socks/handshake.go b/protocol/socks5/handshake.go similarity index 94% rename from protocol/socks/handshake.go rename to protocol/socks5/handshake.go index a7a119c..d6a2f34 100644 --- a/protocol/socks/handshake.go +++ b/protocol/socks5/handshake.go @@ -1,4 +1,4 @@ -package socks +package socks5 import ( "io" @@ -8,7 +8,7 @@ import ( M "github.com/sagernet/sing/common/metadata" ) -func ClientHandshake(conn io.ReadWriter, version byte, command byte, destination *M.AddrPort, username string, password string) (*Response, error) { +func ClientHandshake(conn io.ReadWriter, version byte, command byte, destination M.Socksaddr, username string, password string) (*Response, error) { var method byte if common.IsBlank(username) { method = AuthTypeNotRequired @@ -56,7 +56,7 @@ func ClientHandshake(conn io.ReadWriter, version byte, command byte, destination return ReadResponse(conn) } -func ClientFastHandshake(writer io.Writer, version byte, command byte, destination *M.AddrPort, username string, password string) error { +func ClientFastHandshake(writer io.Writer, version byte, command byte, destination M.Socksaddr, username string, password string) error { var method byte if common.IsBlank(username) { method = AuthTypeNotRequired diff --git a/protocol/socks/listener.go b/protocol/socks5/listener.go similarity index 93% rename from protocol/socks/listener.go rename to protocol/socks5/listener.go index 2fb4fb6..fdde004 100644 --- a/protocol/socks/listener.go +++ b/protocol/socks5/listener.go @@ -1,4 +1,4 @@ -package socks +package socks5 import ( "context" @@ -10,12 +10,13 @@ import ( "github.com/sagernet/sing/common/auth" E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/transport/tcp" ) type Handler interface { tcp.Handler - UDPConnectionHandler + N.UDPConnectionHandler } type Listener struct { @@ -36,7 +37,7 @@ func NewListener(bind netip.AddrPort, authenticator auth.Authenticator, handler } func (l *Listener) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error { - return HandleConnection(ctx, conn, l.authenticator, M.AddrPortFromNetAddr(conn.LocalAddr()).Addr.Addr(), l.handler, metadata) + return HandleConnection(ctx, conn, l.authenticator, M.AddrFromNetAddr(conn.LocalAddr()), l.handler, metadata) } func (l *Listener) Start() error { @@ -117,7 +118,7 @@ func handleConnection(authRequest *AuthRequest, ctx context.Context, conn net.Co err = WriteResponse(conn, &Response{ Version: request.Version, ReplyCode: ReplyCodeSuccess, - Bind: M.AddrPortFromNetAddr(conn.LocalAddr()), + Bind: M.SocksaddrFromNet(conn.LocalAddr()), }) if err != nil { return E.Cause(err, "write socks response") @@ -138,7 +139,7 @@ func handleConnection(authRequest *AuthRequest, ctx context.Context, conn net.Co err = WriteResponse(conn, &Response{ Version: request.Version, ReplyCode: ReplyCodeSuccess, - Bind: M.AddrPortFromNetAddr(udpConn.LocalAddr()), + Bind: M.SocksaddrFromNet(udpConn.LocalAddr()), }) if err != nil { return E.Cause(err, "write socks response") diff --git a/protocol/socks5/packet_conn.go b/protocol/socks5/packet_conn.go new file mode 100644 index 0000000..ab0aa7c --- /dev/null +++ b/protocol/socks5/packet_conn.go @@ -0,0 +1 @@ +package socks5 diff --git a/protocol/socks/protocol.go b/protocol/socks5/protocol.go similarity index 97% rename from protocol/socks/protocol.go rename to protocol/socks5/protocol.go index 08f2ded..4f514c9 100644 --- a/protocol/socks/protocol.go +++ b/protocol/socks5/protocol.go @@ -1,9 +1,9 @@ -package socks +package socks5 import ( "bytes" "io" - "net" + "net/netip" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" @@ -203,7 +203,7 @@ func ReadUsernamePasswordAuthResponse(reader io.Reader) (*UsernamePasswordAuthRe type Request struct { Version byte Command byte - Destination *M.AddrPort + Destination M.Socksaddr } func WriteRequest(writer io.Writer, request *Request) error { @@ -262,7 +262,7 @@ func ReadRequest(reader io.Reader) (*Request, error) { type Response struct { Version byte ReplyCode ReplyCode - Bind *M.AddrPort + Bind M.Socksaddr } func WriteResponse(writer io.Writer, response *Response) error { @@ -278,8 +278,10 @@ func WriteResponse(writer io.Writer, response *Response) error { if err != nil { return err } - if response.Bind == nil { - return AddressSerializer.WriteAddrPort(writer, M.AddrPortFrom(M.AddrFromIP(net.IPv4zero), 0)) + if !response.Bind.IsValid() { + return AddressSerializer.WriteAddrPort(writer, M.Socksaddr{ + Addr: netip.IPv4Unspecified(), + }) } return AddressSerializer.WriteAddrPort(writer, response.Bind) } @@ -320,7 +322,7 @@ func ReadResponse(reader io.Reader) (*Response, error) { type AssociatePacket struct { Fragment byte - Destination *M.AddrPort + Destination M.Socksaddr Data []byte } diff --git a/protocol/socks5/protocol_test.go b/protocol/socks5/protocol_test.go new file mode 100644 index 0000000..6a4c4e4 --- /dev/null +++ b/protocol/socks5/protocol_test.go @@ -0,0 +1,75 @@ +package socks5_test + +import ( + "net" + "sync" + "testing" + + M "github.com/sagernet/sing/common/metadata" + "github.com/sagernet/sing/protocol/socks5" +) + +func TestHandshake(t *testing.T) { + server, client := net.Pipe() + defer server.Close() + defer client.Close() + + wg := new(sync.WaitGroup) + wg.Add(1) + + method := socks5.AuthTypeUsernamePassword + + go func() { + response, err := socks5.ClientHandshake(client, socks5.Version5, socks5.CommandConnect, M.AddrPortFrom(M.AddrFromFqdn("test"), 80), "user", "pswd") + if err != nil { + t.Fatal(err) + } + if response.ReplyCode != socks5.ReplyCodeSuccess { + t.Fatal(response) + } + wg.Done() + }() + authRequest, err := socks5.ReadAuthRequest(server) + if err != nil { + t.Fatal(err) + } + if len(authRequest.Methods) != 1 || authRequest.Methods[0] != method { + t.Fatal("bad methods: ", authRequest.Methods) + } + err = socks5.WriteAuthResponse(server, &socks5.AuthResponse{ + Version: socks5.Version5, + Method: method, + }) + if err != nil { + t.Fatal(err) + } + usernamePasswordAuthRequest, err := socks5.ReadUsernamePasswordAuthRequest(server) + if err != nil { + t.Fatal(err) + } + if usernamePasswordAuthRequest.Username != "user" || usernamePasswordAuthRequest.Password != "pswd" { + t.Fatal(authRequest) + } + err = socks5.WriteUsernamePasswordAuthResponse(server, &socks5.UsernamePasswordAuthResponse{ + Status: socks5.UsernamePasswordStatusSuccess, + }) + if err != nil { + t.Fatal(err) + } + request, err := socks5.ReadRequest(server) + if err != nil { + t.Fatal(err) + } + if request.Version != socks5.Version5 || request.Command != socks5.CommandConnect || request.Destination.Addr.Fqdn() != "test" || request.Destination.Port != 80 { + t.Fatal(request) + } + err = socks5.WriteResponse(server, &socks5.Response{ + Version: socks5.Version5, + ReplyCode: socks5.ReplyCodeSuccess, + Bind: M.AddrPortFrom(M.AddrFromIP(net.IPv4zero), 0), + }) + if err != nil { + t.Fatal(err) + } + wg.Wait() +} diff --git a/protocol/trojan/protocol.go b/protocol/trojan/protocol.go index 1d564ca..21adb2a 100644 --- a/protocol/trojan/protocol.go +++ b/protocol/trojan/protocol.go @@ -12,7 +12,7 @@ import ( E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" "github.com/sagernet/sing/common/rw" - "github.com/sagernet/sing/protocol/socks" + "github.com/sagernet/sing/protocol/socks5" ) const ( @@ -26,11 +26,11 @@ var CRLF = []byte{'\r', '\n'} type ClientConn struct { net.Conn key [KeyLength]byte - destination *M.AddrPort + destination M.Socksaddr headerWritten bool } -func NewClientConn(conn net.Conn, key [KeyLength]byte, destination *M.AddrPort) *ClientConn { +func NewClientConn(conn net.Conn, key [KeyLength]byte, destination M.Socksaddr) *ClientConn { return &ClientConn{ Conn: conn, key: key, @@ -75,11 +75,11 @@ func NewClientPacketConn(conn net.Conn, key [KeyLength]byte) *ClientPacketConn { } } -func (c *ClientPacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) { +func (c *ClientPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) { return ReadPacket(c.Conn, buffer) } -func (c *ClientPacketConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error { +func (c *ClientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { if !c.headerWritten { return ClientHandshakePacket(c.Conn, c.key, destination, buffer) } @@ -98,7 +98,7 @@ func (c *ClientPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) } func (c *ClientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { - err = c.WritePacket(buf.With(p), M.AddrPortFromNetAddr(addr)) + err = c.WritePacket(buf.With(p), M.SocksaddrFromNet(addr)) if err == nil { n = len(p) } @@ -113,7 +113,7 @@ func Key(password string) [KeyLength]byte { return key } -func ClientHandshakeRaw(conn net.Conn, key [KeyLength]byte, command byte, destination *M.AddrPort, payload []byte) error { +func ClientHandshakeRaw(conn net.Conn, key [KeyLength]byte, command byte, destination M.Socksaddr, payload []byte) error { _, err := conn.Write(key[:]) if err != nil { return err @@ -126,7 +126,7 @@ func ClientHandshakeRaw(conn net.Conn, key [KeyLength]byte, command byte, destin if err != nil { return err } - err = socks.AddressSerializer.WriteAddrPort(conn, destination) + err = socks5.AddressSerializer.WriteAddrPort(conn, destination) if err != nil { return err } @@ -143,8 +143,8 @@ func ClientHandshakeRaw(conn net.Conn, key [KeyLength]byte, command byte, destin return nil } -func ClientHandshake(conn net.Conn, key [KeyLength]byte, destination *M.AddrPort, payload []byte) error { - headerLen := KeyLength + socks.AddressSerializer.AddrPortLen(destination) + 5 +func ClientHandshake(conn net.Conn, key [KeyLength]byte, destination M.Socksaddr, payload []byte) error { + headerLen := KeyLength + socks5.AddressSerializer.AddrPortLen(destination) + 5 var header *buf.Buffer var writeHeader bool if len(payload) > 0 && headerLen+len(payload) < 65535 { @@ -158,7 +158,7 @@ func ClientHandshake(conn net.Conn, key [KeyLength]byte, destination *M.AddrPort common.Must1(header.Write(key[:])) common.Must1(header.Write(CRLF)) common.Must(header.WriteByte(CommandTCP)) - common.Must(socks.AddressSerializer.WriteAddrPort(header, destination)) + common.Must(socks5.AddressSerializer.WriteAddrPort(header, destination)) common.Must1(header.Write(CRLF)) common.Must1(header.Write(payload)) @@ -176,8 +176,8 @@ func ClientHandshake(conn net.Conn, key [KeyLength]byte, destination *M.AddrPort return nil } -func ClientHandshakePacket(conn net.Conn, key [KeyLength]byte, destination *M.AddrPort, payload *buf.Buffer) error { - headerLen := KeyLength + 2*socks.AddressSerializer.AddrPortLen(destination) + 9 +func ClientHandshakePacket(conn net.Conn, key [KeyLength]byte, destination M.Socksaddr, payload *buf.Buffer) error { + headerLen := KeyLength + 2*socks5.AddressSerializer.AddrPortLen(destination) + 9 payloadLen := payload.Len() var header *buf.Buffer var writeHeader bool @@ -191,9 +191,9 @@ func ClientHandshakePacket(conn net.Conn, key [KeyLength]byte, destination *M.Ad common.Must1(header.Write(key[:])) common.Must1(header.Write(CRLF)) common.Must(header.WriteByte(CommandUDP)) - common.Must(socks.AddressSerializer.WriteAddrPort(header, destination)) + common.Must(socks5.AddressSerializer.WriteAddrPort(header, destination)) common.Must1(header.Write(CRLF)) - common.Must(socks.AddressSerializer.WriteAddrPort(header, destination)) + common.Must(socks5.AddressSerializer.WriteAddrPort(header, destination)) common.Must(binary.Write(header, binary.BigEndian, uint16(payloadLen))) common.Must1(header.Write(CRLF)) @@ -211,33 +211,33 @@ func ClientHandshakePacket(conn net.Conn, key [KeyLength]byte, destination *M.Ad return nil } -func ReadPacket(conn net.Conn, buffer *buf.Buffer) (*M.AddrPort, error) { - destination, err := socks.AddressSerializer.ReadAddrPort(conn) +func ReadPacket(conn net.Conn, buffer *buf.Buffer) (M.Socksaddr, error) { + destination, err := socks5.AddressSerializer.ReadAddrPort(conn) if err != nil { - return nil, E.Cause(err, "read destination") + return M.Socksaddr{}, E.Cause(err, "read destination") } var length uint16 err = binary.Read(conn, binary.BigEndian, &length) if err != nil { - return nil, E.Cause(err, "read chunk length") + return M.Socksaddr{}, E.Cause(err, "read chunk length") } if buffer.FreeLen() < int(length) { - return nil, io.ErrShortBuffer + return M.Socksaddr{}, io.ErrShortBuffer } err = rw.SkipN(conn, 2) if err != nil { - return nil, E.Cause(err, "skip crlf") + return M.Socksaddr{}, E.Cause(err, "skip crlf") } _, err = buffer.ReadFullFrom(conn, int(length)) return destination, err } -func WritePacket(conn net.Conn, buffer *buf.Buffer, destination *M.AddrPort) error { - headerOverload := socks.AddressSerializer.AddrPortLen(destination) + 4 +func WritePacket(conn net.Conn, buffer *buf.Buffer, destination M.Socksaddr) error { + headerOverload := socks5.AddressSerializer.AddrPortLen(destination) + 4 var header *buf.Buffer var writeHeader bool bufferLen := buffer.Len() @@ -248,7 +248,7 @@ func WritePacket(conn net.Conn, buffer *buf.Buffer, destination *M.AddrPort) err _buffer := buf.Make(headerOverload) header = buf.With(common.Dup(_buffer)) } - common.Must(socks.AddressSerializer.WriteAddrPort(header, destination)) + common.Must(socks5.AddressSerializer.WriteAddrPort(header, destination)) common.Must(binary.Write(header, binary.BigEndian, uint16(bufferLen))) common.Must1(header.Write(CRLF)) if writeHeader { diff --git a/protocol/trojan/service.go b/protocol/trojan/service.go index 3a47730..99f6171 100644 --- a/protocol/trojan/service.go +++ b/protocol/trojan/service.go @@ -10,13 +10,14 @@ import ( "github.com/sagernet/sing/common/buf" E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/common/rw" - "github.com/sagernet/sing/protocol/socks" + "github.com/sagernet/sing/protocol/socks5" ) type Handler interface { M.TCPConnectionHandler - socks.UDPConnectionHandler + N.UDPConnectionHandler } type Context[K comparable] struct { @@ -115,7 +116,7 @@ process: goto returnErr } - destination, err := socks.AddressSerializer.ReadAddrPort(conn) + destination, err := socks5.AddressSerializer.ReadAddrPort(conn) if err != nil { err = E.Cause(err, "read destination") goto returnErr @@ -141,11 +142,11 @@ type PacketConn struct { net.Conn } -func (c *PacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) { +func (c *PacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) { return ReadPacket(c.Conn, buffer) } -func (c *PacketConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error { +func (c *PacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { return WritePacket(c.Conn, buffer, destination) } diff --git a/transport/mixed/listener.go b/transport/mixed/listener.go index ad11f9a..faf4149 100644 --- a/transport/mixed/listener.go +++ b/transport/mixed/listener.go @@ -15,17 +15,18 @@ import ( "github.com/sagernet/sing/common/buf" E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/common/redir" "github.com/sagernet/sing/common/rw" "github.com/sagernet/sing/common/udpnat" "github.com/sagernet/sing/protocol/http" - "github.com/sagernet/sing/protocol/socks" + "github.com/sagernet/sing/protocol/socks5" "github.com/sagernet/sing/transport/tcp" "github.com/sagernet/sing/transport/udp" ) type Handler interface { - socks.Handler + socks5.Handler } type Listener struct { @@ -53,15 +54,15 @@ func NewListener(bind netip.AddrPort, authenticator auth.Authenticator, transpro } func (l *Listener) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error { - if metadata.Destination != nil { + if metadata.Destination.IsValid() { return l.handler.NewConnection(ctx, conn, metadata) } headerType, err := rw.ReadByte(conn) switch headerType { - case socks.Version4: + case socks5.Version4: return E.New("socks4 request dropped (TODO)") - case socks.Version5: - return socks.HandleConnection0(ctx, conn, l.authenticator, M.AddrPortFromNetAddr(conn.LocalAddr()).Addr.Addr(), l.handler, metadata) + case socks5.Version5: + return socks5.HandleConnection0(ctx, conn, l.authenticator, M.AddrFromNetAddr(conn.LocalAddr()), l.handler, metadata) } reader := bufio.NewReader(&rw.BufferedReader{ @@ -75,7 +76,7 @@ func (l *Listener) NewConnection(ctx context.Context, conn net.Conn, metadata M. } if request.Method == "GET" && request.URL.Path == "/proxy.pac" { - content := newPAC(M.AddrPortFromNetAddr(conn.LocalAddr())) + content := newPAC(M.AddrPortFromNet(conn.LocalAddr())) response := &netHttp.Response{ StatusCode: 200, Status: netHttp.StatusText(200), @@ -113,8 +114,8 @@ func (l *Listener) NewConnection(ctx context.Context, conn net.Conn, metadata M. return http.HandleRequest(ctx, request, conn, l.authenticator, l.handler, metadata) } -func (l *Listener) NewPacket(conn socks.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error { - l.udpNat.NewPacket(metadata.Source.AddrPort(), func() socks.PacketWriter { +func (l *Listener) NewPacket(conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error { + l.udpNat.NewPacket(metadata.Source.AddrPort(), func() N.PacketWriter { return &tproxyPacketWriter{metadata.Source.UDPAddr()} }, buffer, metadata) return nil @@ -124,7 +125,7 @@ type tproxyPacketWriter struct { source *net.UDPAddr } -func (w *tproxyPacketWriter) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error { +func (w *tproxyPacketWriter) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { udpConn, err := redir.DialUDP("udp", destination.UDPAddr(), w.source) if err != nil { return E.Cause(err, "tproxy udp write back") diff --git a/transport/mixed/pac.go b/transport/mixed/pac.go index 348cbc5..98e9a55 100644 --- a/transport/mixed/pac.go +++ b/transport/mixed/pac.go @@ -1,8 +1,10 @@ package mixed -import M "github.com/sagernet/sing/common/metadata" +import ( + "net/netip" +) -/*func newPAC(proxyAddr *M.AddrPort) string { +/*func newPAC(proxyAddr M.Socksaddr) string { return ` function FindProxyForURL(url, host) { return "SOCKS5 ` + proxyAddr.String() + `;SOCKS ` + proxyAddr.String() + `; PROXY ` + proxyAddr.String() + `"; @@ -10,7 +12,7 @@ function FindProxyForURL(url, host) { } */ -func newPAC(proxyAddr *M.AddrPort) string { +func newPAC(proxyAddr netip.AddrPort) string { // TODO: socks4 not supported return ` function FindProxyForURL(url, host) { diff --git a/transport/tcp/handler.go b/transport/tcp/handler.go index 68b8bbc..893474b 100644 --- a/transport/tcp/handler.go +++ b/transport/tcp/handler.go @@ -89,14 +89,14 @@ func (l *Listener) loop() { return } metadata := M.Metadata{ - Source: M.AddrPortFromNetAddr(tcpConn.RemoteAddr()), + Source: M.SocksaddrFromNet(tcpConn.RemoteAddr()), } switch l.trans { case redir.ModeRedirect: destination, err := redir.GetOriginalDestination(tcpConn) if err == nil { metadata.Protocol = "redirect" - metadata.Destination = destination + metadata.Destination = M.SocksaddrFromNetIP(destination) } case redir.ModeTProxy: lAddr := tcpConn.LocalAddr().(*net.TCPAddr) @@ -104,7 +104,7 @@ func (l *Listener) loop() { if lAddr.Port != l.lAddr.Port || !lAddr.IP.Equal(rAddr.IP) && !lAddr.IP.IsLoopback() && !lAddr.IP.IsPrivate() { metadata.Protocol = "tproxy" - metadata.Destination = M.AddrPortFromNetAddr(lAddr) + metadata.Destination = M.SocksaddrFromNet(lAddr) } } go func() { diff --git a/transport/udp/udp.go b/transport/udp/udp.go index 133a615..d26f94e 100644 --- a/transport/udp/udp.go +++ b/transport/udp/udp.go @@ -8,12 +8,12 @@ import ( "github.com/sagernet/sing/common/buf" E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/common/redir" - "github.com/sagernet/sing/protocol/socks" ) type Handler interface { - socks.UDPHandler + N.UDPHandler E.Handler } @@ -25,17 +25,24 @@ type Listener struct { tproxy bool } -func (l *Listener) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) { - n, addr, err := l.ReadFromUDP(buffer.FreeBytes()) +func (l *Listener) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) { + n, addr, err := l.ReadFromUDPAddrPort(buffer.FreeBytes()) if err != nil { - return nil, err + return M.Socksaddr{}, err } buffer.Truncate(n) - return M.AddrPortFromNetAddr(addr), nil + return M.SocksaddrFromNetIP(addr), nil } -func (l *Listener) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error { - return common.Error(l.UDPConn.WriteTo(buffer.Bytes(), destination.UDPAddr())) +func (l *Listener) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { + if destination.Family().IsFqdn() { + udpAddr, err := net.ResolveUDPAddr("udp", destination.String()) + if err != nil { + return err + } + return common.Error(l.UDPConn.WriteTo(buffer.Bytes(), udpAddr)) + } + return common.Error(l.UDPConn.WriteToUDPAddrPort(buffer.Bytes(), destination.AddrPort())) } func NewUDPListener(listen netip.AddrPort, handler Handler, options ...Option) *Listener { @@ -88,7 +95,7 @@ func (l *Listener) loop() { data := buffer.Cut(buf.ReversedHeader, buf.ReversedHeader).Slice() if !l.tproxy { for { - n, addr, err := l.ReadFromUDP(data) + n, addr, err := l.ReadFromUDPAddrPort(data) if err != nil { l.handler.HandleError(err) return @@ -96,7 +103,7 @@ func (l *Listener) loop() { buffer.Resize(buf.ReversedHeader, n) err = l.handler.NewPacket(l, buffer, M.Metadata{ Protocol: "udp", - Source: M.AddrPortFromNetAddr(addr), + Source: M.SocksaddrFromNetIP(addr), }) if err != nil { l.handler.HandleError(err) @@ -119,8 +126,8 @@ func (l *Listener) loop() { buffer.Resize(buf.ReversedHeader, n) err = l.handler.NewPacket(l, buffer, M.Metadata{ Protocol: "tproxy", - Source: M.AddrPortFromAddrPort(addr), - Destination: destination, + Source: M.SocksaddrFromNetIP(addr), + Destination: M.SocksaddrFromNetIP(destination), }) if err != nil { l.handler.HandleError(err)