Refactor socksaddr

This commit is contained in:
世界 2022-05-04 19:12:27 +08:00
parent 9378ae739c
commit b35c53ca8f
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
54 changed files with 1191 additions and 666 deletions

View file

@ -48,4 +48,4 @@ jobs:
with:
go-version: 1.18.1
- name: Build
run: go build -v ./cli/ss-local
run: go build -v ./...

View file

@ -1,3 +1,5 @@
//go:build linux
package main
import (

View file

@ -18,9 +18,8 @@ import (
E "github.com/sagernet/sing/common/exceptions"
_ "github.com/sagernet/sing/common/log"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/common/network"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/rw"
"github.com/sagernet/sing/protocol/socks"
"github.com/sagernet/sing/protocol/trojan"
transTLS "github.com/sagernet/sing/transport/tls"
"github.com/sirupsen/logrus"
@ -177,14 +176,14 @@ func (i *TrojanInstance) NewConnection(ctx context.Context, conn net.Conn, metad
userCtx := ctx.(*trojan.Context[int])
conn = i.user.TrackConnection(userCtx.User, conn)
logrus.Info(i.id, ": user ", userCtx.User, " TCP ", metadata.Source, " ==> ", metadata.Destination)
destConn, err := network.SystemDialer.DialContext(context.Background(), "tcp", metadata.Destination)
destConn, err := N.SystemDialer.DialContext(context.Background(), "tcp", metadata.Destination)
if err != nil {
return err
}
return rw.CopyConn(ctx, conn, destConn)
}
func (i *TrojanInstance) NewPacketConnection(ctx context.Context, conn socks.PacketConn, metadata M.Metadata) error {
func (i *TrojanInstance) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata M.Metadata) error {
userCtx := ctx.(*trojan.Context[int])
conn = i.user.TrackPacketConnection(userCtx.User, conn)
logrus.Info(i.id, ": user ", userCtx.User, " UDP ", metadata.Source, " ==> ", metadata.Destination)
@ -192,7 +191,7 @@ func (i *TrojanInstance) NewPacketConnection(ctx context.Context, conn socks.Pac
if err != nil {
return err
}
return socks.CopyNetPacketConn(ctx, conn, udpConn)
return N.CopyNetPacketConn(ctx, conn, udpConn)
}
func (i *TrojanInstance) loopRequests() {
@ -205,7 +204,7 @@ func (i *TrojanInstance) loopRequests() {
go func() {
hErr := i.service.NewConnection(context.Background(), conn, M.Metadata{
Protocol: "tls",
Source: M.AddrPortFromNetAddr(conn.RemoteAddr()),
Source: M.SocksaddrFromNet(conn.RemoteAddr()),
})
if hErr != nil {
i.HandleError(hErr)

View file

@ -8,7 +8,7 @@ import (
"github.com/sagernet/sing/common/buf"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/protocol/socks"
N "github.com/sagernet/sing/common/network"
)
type UserManager struct {
@ -40,7 +40,7 @@ func (m *UserManager) TrackConnection(userId int, conn net.Conn) net.Conn {
return &TrackConn{conn, user}
}
func (m *UserManager) TrackPacketConnection(userId int, conn socks.PacketConn) socks.PacketConn {
func (m *UserManager) TrackPacketConnection(userId int, conn N.PacketConn) N.PacketConn {
m.access.Lock()
defer m.access.Unlock()
var user *User
@ -112,11 +112,11 @@ func (c *TrackConn) ReadFrom(r io.Reader) (n int64, err error) {
}
type TrackPacketConn struct {
socks.PacketConn
N.PacketConn
*User
}
func (c *TrackPacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
func (c *TrackPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
destination, err := c.PacketConn.ReadPacket(buffer)
if err == nil {
atomic.AddUint64(&c.Upload, uint64(buffer.Len()))
@ -124,7 +124,7 @@ func (c *TrackPacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
return destination, err
}
func (c *TrackPacketConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error {
func (c *TrackPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
n := buffer.Len()
err := c.PacketConn.WritePacket(buffer, destination)
if err == nil {

View file

@ -11,7 +11,7 @@ import (
"github.com/sagernet/sing/common/buf"
_ "github.com/sagernet/sing/common/log"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/protocol/socks"
"github.com/sagernet/sing/protocol/socks5"
"github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"golang.org/x/net/dns/dnsmessage"
@ -29,11 +29,8 @@ func main() {
}
func run(cmd *cobra.Command, args []string) {
server, err := M.ParseAddress(args[0])
if err != nil {
logrus.Fatal("invalid server address ", args[0])
}
err = testSocksTCP(server)
server := M.ParseSocksaddr(args[0])
err := testSocksTCP(server)
if err != nil {
logrus.Fatal(err)
}
@ -43,16 +40,16 @@ func run(cmd *cobra.Command, args []string) {
}
}
func testSocksTCP(server *M.AddrPort) error {
func testSocksTCP(server M.Socksaddr) error {
tcpConn, err := net.Dial("tcp", server.String())
if err != nil {
return err
}
response, err := socks.ClientHandshake(tcpConn, socks.Version5, socks.CommandConnect, M.AddrPortFrom(M.ParseAddr("1.0.0.1"), 53), "", "")
response, err := socks5.ClientHandshake(tcpConn, socks5.Version5, socks5.CommandConnect, M.ParseSocksaddrHostPort("1.0.0.1", "53"), "", "")
if err != nil {
return err
}
if response.ReplyCode != socks.ReplyCodeSuccess {
if response.ReplyCode != socks5.ReplyCodeSuccess {
logrus.Fatal("socks tcp handshake failure: ", response.ReplyCode)
}
@ -98,17 +95,17 @@ func testSocksTCP(server *M.AddrPort) error {
return nil
}
func testSocksUDP(server *M.AddrPort) error {
func testSocksUDP(server M.Socksaddr) error {
tcpConn, err := net.Dial("tcp", server.String())
if err != nil {
return err
}
dest := M.AddrPortFrom(M.ParseAddr("1.0.0.1"), 53)
response, err := socks.ClientHandshake(tcpConn, socks.Version5, socks.CommandUDPAssociate, dest, "", "")
dest := M.ParseSocksaddrHostPort("1.0.0.1", "53")
response, err := socks5.ClientHandshake(tcpConn, socks5.Version5, socks5.CommandUDPAssociate, dest, "", "")
if err != nil {
return err
}
if response.ReplyCode != socks.ReplyCodeSuccess {
if response.ReplyCode != socks5.ReplyCodeSuccess {
logrus.Fatal("socks tcp handshake failure: ", response.ReplyCode)
}
var dialer net.Dialer
@ -116,7 +113,7 @@ func testSocksUDP(server *M.AddrPort) error {
if err != nil {
return err
}
assConn := socks.NewAssociateConn(tcpConn, udpConn, dest)
assConn := socks5.NewAssociateConn(tcpConn, udpConn, dest)
message := &dnsmessage.Message{}
message.Header.ID = 1
message.Header.RecursionDesired = true

View file

@ -23,6 +23,7 @@ import (
"github.com/sagernet/sing/common/geosite"
_ "github.com/sagernet/sing/common/log"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/random"
"github.com/sagernet/sing/common/redir"
"github.com/sagernet/sing/common/rw"
@ -30,7 +31,6 @@ import (
"github.com/sagernet/sing/protocol/shadowsocks"
"github.com/sagernet/sing/protocol/shadowsocks/shadowaead"
"github.com/sagernet/sing/protocol/shadowsocks/shadowaead_2022"
"github.com/sagernet/sing/protocol/socks"
"github.com/sagernet/sing/transport/mixed"
"github.com/sagernet/sing/transport/system"
"github.com/sirupsen/logrus"
@ -101,7 +101,7 @@ Only available with Linux kernel > 3.7.0.`)
type client struct {
*mixed.Listener
*geosite.Matcher
server *M.AddrPort
server M.Socksaddr
method shadowsocks.Method
dialer net.Dialer
bypass string
@ -163,7 +163,7 @@ func newClient(f *flags) (*client, error) {
}
c := &client{
server: M.AddrPortFrom(M.ParseAddr(f.Server), f.ServerPort),
server: M.SocksaddrFromAddrPort(M.ParseAddr(f.Server), f.ServerPort),
bypass: f.Bypass,
}
@ -294,7 +294,7 @@ func newClient(f *flags) (*client, error) {
return c, nil
}
func bypass(conn net.Conn, destination *M.AddrPort) error {
func bypass(conn net.Conn, destination M.Socksaddr) error {
logrus.Info("BYPASS ", conn.RemoteAddr(), " ==> ", destination)
serverConn, err := net.Dial("tcp", destination.String())
if err != nil {
@ -313,12 +313,12 @@ func bypass(conn net.Conn, destination *M.AddrPort) error {
func (c *client) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
if c.bypass != "" {
if metadata.Destination.Addr.Family().IsFqdn() {
if c.Match(metadata.Destination.Addr.Fqdn()) {
if metadata.Destination.Family().IsFqdn() {
if c.Match(metadata.Destination.Fqdn) {
return bypass(conn, metadata.Destination)
}
} else {
if geoip.Match(c.bypass, metadata.Destination.Addr.Addr().AsSlice()) {
if geoip.Match(c.bypass, metadata.Destination.Addr.AsSlice()) {
return bypass(conn, metadata.Destination)
}
}
@ -354,14 +354,14 @@ func (c *client) NewConnection(ctx context.Context, conn net.Conn, metadata M.Me
return rw.CopyConn(ctx, serverConn, conn)
}
func (c *client) NewPacketConnection(ctx context.Context, conn socks.PacketConn, metadata M.Metadata) error {
func (c *client) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata M.Metadata) error {
logrus.Info("outbound ", metadata.Protocol, " UDP ", metadata.Source, " ==> ", metadata.Destination)
udpConn, err := c.dialer.DialContext(ctx, "udp", c.server.String())
if err != nil {
return err
}
serverConn := c.method.DialPacketConn(udpConn)
return socks.CopyPacketConn(ctx, serverConn, conn)
return N.CopyPacketConn(ctx, serverConn, conn)
}
func run(cmd *cobra.Command, flags *flags) {

View file

@ -17,13 +17,12 @@ import (
E "github.com/sagernet/sing/common/exceptions"
_ "github.com/sagernet/sing/common/log"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/common/network"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/random"
"github.com/sagernet/sing/common/rw"
"github.com/sagernet/sing/protocol/shadowsocks"
"github.com/sagernet/sing/protocol/shadowsocks/shadowaead"
"github.com/sagernet/sing/protocol/shadowsocks/shadowaead_2022"
"github.com/sagernet/sing/protocol/socks"
"github.com/sagernet/sing/transport/tcp"
"github.com/sagernet/sing/transport/udp"
"github.com/sirupsen/logrus"
@ -191,23 +190,23 @@ func (s *server) NewConnection(ctx context.Context, conn net.Conn, metadata M.Me
return s.service.NewConnection(ctx, conn, metadata)
}
logrus.Info("inbound TCP ", conn.RemoteAddr(), " ==> ", metadata.Destination)
destConn, err := network.SystemDialer.DialContext(context.Background(), "tcp", metadata.Destination)
destConn, err := N.SystemDialer.DialContext(context.Background(), "tcp", metadata.Destination)
if err != nil {
return err
}
return rw.CopyConn(ctx, conn, destConn)
}
func (s *server) NewPacketConnection(ctx context.Context, conn socks.PacketConn, metadata M.Metadata) error {
func (s *server) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata M.Metadata) error {
logrus.Info("inbound UDP ", metadata.Source, " ==> ", metadata.Destination)
udpConn, err := net.ListenUDP("udp", nil)
if err != nil {
return err
}
return socks.CopyNetPacketConn(ctx, conn, udpConn)
return N.CopyNetPacketConn(ctx, conn, udpConn)
}
func (s *server) NewPacket(conn socks.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error {
func (s *server) NewPacket(conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error {
logrus.Trace("inbound raw UDP from ", metadata.Source)
return s.service.NewPacket(conn, buffer, metadata)
}

View file

@ -19,9 +19,9 @@ import (
E "github.com/sagernet/sing/common/exceptions"
_ "github.com/sagernet/sing/common/log"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/redir"
"github.com/sagernet/sing/common/rw"
"github.com/sagernet/sing/protocol/socks"
"github.com/sagernet/sing/protocol/trojan"
"github.com/sagernet/sing/transport/mixed"
"github.com/sirupsen/logrus"
@ -148,7 +148,7 @@ func newClient(f *flags) (*client, error) {
}
c := &client{
server: M.AddrPortFrom(M.ParseAddr(f.Server), f.ServerPort).String(),
server: netip.AddrPortFrom(M.ParseAddr(f.Server), f.ServerPort).String(),
key: trojan.Key(f.Password),
sni: f.ServerName,
insecure: f.Insecure,
@ -319,7 +319,7 @@ func (c *client) NewConnection(ctx context.Context, conn net.Conn, metadata M.Me
return rw.CopyConn(ctx, clientConn, conn)
}
func (c *client) NewPacketConnection(ctx context.Context, conn socks.PacketConn, metadata M.Metadata) error {
func (c *client) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata M.Metadata) error {
logrus.Info("outbound ", metadata.Protocol, " UDP ", metadata.Source, " ==> ", metadata.Destination)
tlsConn, err := c.connect(ctx)
@ -332,7 +332,7 @@ func (c *client) NewPacketConnection(ctx context.Context, conn socks.PacketConn,
}
return socks.CopyPacketConn(ctx, &trojan.PacketConn{Conn: tlsConn}, conn)*/
clientConn := trojan.NewClientPacketConn(tlsConn, c.key)
return socks.CopyPacketConn(ctx, clientConn, conn)
return N.CopyPacketConn(ctx, clientConn, conn)
}
func (c *client) HandleError(err error) {

View file

@ -17,9 +17,8 @@ import (
E "github.com/sagernet/sing/common/exceptions"
_ "github.com/sagernet/sing/common/log"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/common/network"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/rw"
"github.com/sagernet/sing/protocol/socks"
"github.com/sagernet/sing/protocol/trojan"
"github.com/sagernet/sing/transport/tcp"
transTLS "github.com/sagernet/sing/transport/tls"
@ -193,7 +192,7 @@ func (s *server) NewConnection(ctx context.Context, conn net.Conn, metadata M.Me
}
return s.service.NewConnection(ctx, tls.Server(conn, &s.tlsConfig), metadata)
}
destConn, err := network.SystemDialer.DialContext(context.Background(), "tcp", metadata.Destination)
destConn, err := N.SystemDialer.DialContext(context.Background(), "tcp", metadata.Destination)
if err != nil {
return err
}
@ -201,13 +200,13 @@ func (s *server) NewConnection(ctx context.Context, conn net.Conn, metadata M.Me
return rw.CopyConn(ctx, conn, destConn)
}
func (s *server) NewPacketConnection(ctx context.Context, conn socks.PacketConn, metadata M.Metadata) error {
func (s *server) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata M.Metadata) error {
logrus.Info("inbound UDP ", metadata.Source, " ==> ", metadata.Destination)
udpConn, err := net.ListenUDP("udp", nil)
if err != nil {
return err
}
return socks.CopyNetPacketConn(ctx, conn, udpConn)
return N.CopyNetPacketConn(ctx, conn, udpConn)
}
func (s *server) HandleError(err error) {

View file

@ -13,10 +13,11 @@ import (
E "github.com/sagernet/sing/common/exceptions"
_ "github.com/sagernet/sing/common/log"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/redir"
"github.com/sagernet/sing/common/rw"
"github.com/sagernet/sing/common/uot"
"github.com/sagernet/sing/protocol/socks"
"github.com/sagernet/sing/protocol/socks5"
"github.com/sagernet/sing/transport/mixed"
"github.com/sirupsen/logrus"
"github.com/spf13/cobra"
@ -102,7 +103,7 @@ func (c *localClient) NewConnection(ctx context.Context, conn net.Conn, metadata
return E.Cause(err, "connect to upstream")
}
_, err = socks.ClientHandshake(upstream, socks.Version5, socks.CommandConnect, metadata.Destination, "", "")
_, err = socks5.ClientHandshake(upstream, socks5.Version5, socks5.CommandConnect, metadata.Destination, "", "")
if err != nil {
return E.Cause(err, "upstream handshake failed")
}
@ -110,19 +111,19 @@ func (c *localClient) NewConnection(ctx context.Context, conn net.Conn, metadata
return rw.CopyConn(context.Background(), upstream, conn)
}
func (c *localClient) NewPacketConnection(ctx context.Context, conn socks.PacketConn, metadata M.Metadata) error {
func (c *localClient) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata M.Metadata) error {
upstream, err := net.Dial("tcp", c.upstream)
if err != nil {
return E.Cause(err, "connect to upstream")
}
_, err = socks.ClientHandshake(upstream, socks.Version5, socks.CommandConnect, M.AddrPortFrom(M.AddrFromFqdn(uot.UOTMagicAddress), 443), "", "")
_, err = socks5.ClientHandshake(upstream, socks5.Version5, socks5.CommandConnect, M.ParseSocksaddrHostPort(uot.UOTMagicAddress, "443"), "", "")
if err != nil {
return E.Cause(err, "upstream handshake failed")
}
client := uot.NewClientConn(upstream)
return socks.CopyPacketConn(ctx, client, conn)
return N.CopyPacketConn(ctx, client, conn)
}
func (c *localClient) OnError(err error) {

View file

@ -6,93 +6,144 @@ import (
"strconv"
)
type Addr interface {
Family() Family
Addr() netip.Addr
Fqdn() string
String() string
}
type AddrPort struct {
Addr Addr
type Socksaddr struct {
Addr netip.Addr
Fqdn string
Port uint16
}
func (ap AddrPort) IPAddr() *net.IPAddr {
func (ap Socksaddr) Network() string {
return "socks"
}
func (ap Socksaddr) IsIP() bool {
return ap.Addr.IsValid()
}
func (ap Socksaddr) IsFqdn() bool {
return !ap.IsIP()
}
func (ap Socksaddr) IsValid() bool {
return ap.Addr.IsValid() || ap.Fqdn != ""
}
func (ap Socksaddr) Family() Family {
if ap.Addr.IsValid() {
if ap.Addr.Is4() {
return AddressFamilyIPv4
} else {
return AddressFamilyIPv6
}
}
if ap.Fqdn != "" {
return AddressFamilyFqdn
} else if ap.Addr.Is4() || ap.Addr.Is4In6() {
return AddressFamilyIPv4
} else {
return AddressFamilyIPv6
}
}
func (ap Socksaddr) AddrString() string {
if ap.Addr.IsValid() {
return ap.Addr.String()
} else {
return ap.Fqdn
}
}
func (ap Socksaddr) IPAddr() *net.IPAddr {
return &net.IPAddr{
IP: ap.Addr.Addr().AsSlice(),
IP: ap.Addr.AsSlice(),
}
}
func (ap AddrPort) TCPAddr() *net.TCPAddr {
func (ap Socksaddr) TCPAddr() *net.TCPAddr {
return &net.TCPAddr{
IP: ap.Addr.Addr().AsSlice(),
IP: ap.Addr.AsSlice(),
Port: int(ap.Port),
}
}
func (ap AddrPort) UDPAddr() *net.UDPAddr {
func (ap Socksaddr) UDPAddr() *net.UDPAddr {
return &net.UDPAddr{
IP: ap.Addr.Addr().AsSlice(),
IP: ap.Addr.AsSlice(),
Port: int(ap.Port),
}
}
func (ap AddrPort) AddrPort() netip.AddrPort {
return netip.AddrPortFrom(ap.Addr.Addr(), ap.Port)
func (ap Socksaddr) AddrPort() netip.AddrPort {
return netip.AddrPortFrom(ap.Addr, ap.Port)
}
func (ap AddrPort) String() string {
return net.JoinHostPort(ap.Addr.String(), strconv.Itoa(int(ap.Port)))
func (ap Socksaddr) String() string {
return net.JoinHostPort(ap.AddrString(), strconv.Itoa(int(ap.Port)))
}
func ParseAddr(address string) Addr {
addr, err := netip.ParseAddr(address)
if err == nil {
return AddrFromAddr(addr)
func TCPAddr(ap netip.AddrPort) *net.TCPAddr {
return &net.TCPAddr{
IP: ap.Addr().AsSlice(),
Port: int(ap.Port()),
}
return AddrFromFqdn(address)
}
func AddrPortFrom(addr Addr, port uint16) *AddrPort {
return &AddrPort{addr, port}
}
func ParseAddress(address string) (*AddrPort, error) {
host, port, err := net.SplitHostPort(address)
if err != nil {
return nil, err
func UDPAddr(ap netip.AddrPort) *net.UDPAddr {
return &net.UDPAddr{
IP: ap.Addr().AsSlice(),
Port: int(ap.Port()),
}
portInt, err := strconv.Atoi(port)
if err != nil {
return nil, err
}
return AddrPortFrom(ParseAddr(host), uint16(portInt)), nil
}
func ParseAddrPort(address string, port string) (*AddrPort, error) {
portInt, err := strconv.Atoi(port)
if err != nil {
return nil, err
}
return AddrPortFrom(ParseAddr(address), uint16(portInt)), nil
func AddrPortFrom(ip net.IP, port uint16) netip.AddrPort {
addr, _ := netip.AddrFromSlice(ip)
return netip.AddrPortFrom(addr, port)
}
func AddrFromNetAddr(netAddr net.Addr) Addr {
func SocksaddrFrom(ip net.IP, port uint16) Socksaddr {
return SocksaddrFromNetIP(AddrPortFrom(ip, port))
}
func SocksaddrFromAddrPort(addr netip.Addr, port uint16) Socksaddr {
return SocksaddrFromNetIP(netip.AddrPortFrom(addr, port))
}
func SocksaddrFromNetIP(ap netip.AddrPort) Socksaddr {
return Socksaddr{
Addr: ap.Addr(),
Port: ap.Port(),
}
}
func SocksaddrFromNet(ap net.Addr) Socksaddr {
if socksAddr, ok := ap.(Socksaddr); ok {
return socksAddr
}
return SocksaddrFromNetIP(AddrPortFromNet(ap))
}
func AddrFromNetAddr(netAddr net.Addr) netip.Addr {
if addr := AddrPortFromNet(netAddr); addr.Addr().IsValid() {
return addr.Addr()
}
switch addr := netAddr.(type) {
case Socksaddr:
return addr.Addr
case *net.IPAddr:
return AddrFromIP(addr.IP)
case *net.IPNet:
return AddrFromIP(addr.IP)
default:
return nil
return netip.Addr{}
}
}
func AddrPortFromNetAddr(netAddr net.Addr) *AddrPort {
func AddrPortFromNet(netAddr net.Addr) netip.AddrPort {
var ip net.IP
var port uint16
switch addr := netAddr.(type) {
case Socksaddr:
return addr.AddrPort()
case *net.TCPAddr:
ip = addr.IP
port = uint16(addr.Port)
@ -102,84 +153,39 @@ func AddrPortFromNetAddr(netAddr net.Addr) *AddrPort {
case *net.IPAddr:
ip = addr.IP
}
return AddrPortFrom(AddrFromIP(ip), port)
return netip.AddrPortFrom(AddrFromIP(ip), port)
}
func AddrFromIP(ip net.IP) Addr {
func AddrFromIP(ip net.IP) netip.Addr {
addr, _ := netip.AddrFromSlice(ip)
if addr.Is4() || addr.Is4In6() {
return Addr4(addr.As4())
return addr
}
func ParseAddr(s string) netip.Addr {
addr, _ := netip.ParseAddr(s)
return addr
}
func ParseSocksaddr(address string) Socksaddr {
host, port, err := net.SplitHostPort(address)
if err != nil {
return Socksaddr{}
}
return ParseSocksaddrHostPort(host, port)
}
func ParseSocksaddrHostPort(host string, portStr string) Socksaddr {
port, _ := strconv.Atoi(portStr)
netAddr, err := netip.ParseAddr(host)
if err != nil {
return Socksaddr{
Fqdn: host,
Port: uint16(port),
}
} else {
return Addr16(addr.As16())
return Socksaddr{
Addr: netAddr,
Port: uint16(port),
}
}
}
func AddrFromAddr(addr netip.Addr) Addr {
if addr.Is4() && addr.Is4In6() {
return Addr4(addr.As4())
} else {
return Addr16(addr.As16())
}
}
func AddrPortFromAddrPort(addrPort netip.AddrPort) *AddrPort {
return AddrPortFrom(AddrFromAddr(addrPort.Addr()), addrPort.Port())
}
func AddrFromFqdn(fqdn string) Addr {
return AddrFqdn(fqdn)
}
type Addr4 [4]byte
func (a Addr4) Family() Family {
return AddressFamilyIPv4
}
func (a Addr4) Addr() netip.Addr {
return netip.AddrFrom4(a)
}
func (a Addr4) Fqdn() string {
return ""
}
func (a Addr4) String() string {
return netip.AddrFrom4(a).String()
}
type Addr16 [16]byte
func (a Addr16) Family() Family {
return AddressFamilyIPv6
}
func (a Addr16) Addr() netip.Addr {
return netip.AddrFrom16(a)
}
func (a Addr16) Fqdn() string {
return ""
}
func (a Addr16) String() string {
return netip.AddrFrom16(a).String()
}
type AddrFqdn string
func (f AddrFqdn) Family() Family {
return AddressFamilyFqdn
}
func (f AddrFqdn) Addr() netip.Addr {
return netip.Addr{}
}
func (f AddrFqdn) Fqdn() string {
return string(f)
}
func (f AddrFqdn) String() string {
return string(f)
}

View file

@ -7,8 +7,8 @@ import (
type Metadata struct {
Protocol string
Source *AddrPort
Destination *AddrPort
Source Socksaddr
Destination Socksaddr
}
type TCPConnectionHandler interface {

View file

@ -3,6 +3,7 @@ package metadata
import (
"encoding/binary"
"io"
"net/netip"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
@ -41,28 +42,27 @@ func NewSerializer(options ...SerializerOption) *Serializer {
return s
}
func (s *Serializer) WriteAddress(writer io.Writer, addr Addr) error {
func (s *Serializer) WriteAddress(writer io.Writer, addr Socksaddr) error {
err := rw.WriteByte(writer, s.familyByteMap[addr.Family()])
if err != nil {
return err
}
if addr.Family().IsIP() {
err = rw.WriteBytes(writer, addr.Addr().AsSlice())
if addr.Addr.IsValid() {
err = rw.WriteBytes(writer, addr.Addr.AsSlice())
} else {
domain := addr.Fqdn()
err = WriteString(writer, "fqdn", domain)
err = WriteString(writer, "fqdn", addr.Fqdn)
}
return err
}
func (s *Serializer) AddressLen(addr Addr) int {
func (s *Serializer) AddressLen(addr Socksaddr) int {
switch addr.Family() {
case AddressFamilyIPv4:
return 5
case AddressFamilyIPv6:
return 17
default:
return 2 + len(addr.Fqdn())
return 2 + len(addr.Fqdn)
}
}
@ -70,10 +70,10 @@ func (s *Serializer) WritePort(writer io.Writer, port uint16) error {
return binary.Write(writer, binary.BigEndian, port)
}
func (s *Serializer) WriteAddrPort(writer io.Writer, destination *AddrPort) error {
func (s *Serializer) WriteAddrPort(writer io.Writer, destination Socksaddr) error {
var err error
if !s.portFirst {
err = s.WriteAddress(writer, destination.Addr)
err = s.WriteAddress(writer, destination)
} else {
err = s.WritePort(writer, destination.Port)
}
@ -81,48 +81,50 @@ func (s *Serializer) WriteAddrPort(writer io.Writer, destination *AddrPort) erro
return err
}
if s.portFirst {
err = s.WriteAddress(writer, destination.Addr)
err = s.WriteAddress(writer, destination)
} else {
err = s.WritePort(writer, destination.Port)
}
return err
}
func (s *Serializer) AddrPortLen(destination *AddrPort) int {
return s.AddressLen(destination.Addr) + 2
func (s *Serializer) AddrPortLen(destination Socksaddr) int {
return s.AddressLen(destination) + 2
}
func (s *Serializer) ReadAddress(reader io.Reader) (Addr, error) {
func (s *Serializer) ReadAddress(reader io.Reader) (Socksaddr, error) {
af, err := rw.ReadByte(reader)
if err != nil {
return nil, err
return Socksaddr{}, err
}
family := s.familyMap[af]
switch family {
case AddressFamilyFqdn:
fqdn, err := ReadString(reader)
if err != nil {
return nil, E.Cause(err, "read fqdn")
return Socksaddr{}, E.Cause(err, "read fqdn")
}
return AddrFqdn(fqdn), nil
return Socksaddr{
Fqdn: fqdn,
}, nil
default:
switch family {
case AddressFamilyIPv4:
var addr [4]byte
err = common.Error(reader.Read(addr[:]))
if err != nil {
return nil, E.Cause(err, "read ipv4 address")
return Socksaddr{}, E.Cause(err, "read ipv4 address")
}
return Addr4(addr), nil
return Socksaddr{Addr: netip.AddrFrom4(addr)}, nil
case AddressFamilyIPv6:
var addr [16]byte
err = common.Error(reader.Read(addr[:]))
if err != nil {
return nil, E.Cause(err, "read ipv6 address")
return Socksaddr{}, E.Cause(err, "read ipv6 address")
}
return Addr16(addr), nil
return Socksaddr{Addr: netip.AddrFrom16(addr)}, nil
default:
return nil, E.New("unknown address family: ", af)
return Socksaddr{}, E.New("unknown address family: ", af)
}
}
}
@ -135,8 +137,8 @@ func (s *Serializer) ReadPort(reader io.Reader) (uint16, error) {
return binary.BigEndian.Uint16(port), nil
}
func (s *Serializer) ReadAddrPort(reader io.Reader) (destination *AddrPort, err error) {
var addr Addr
func (s *Serializer) ReadAddrPort(reader io.Reader) (destination Socksaddr, err error) {
var addr Socksaddr
var port uint16
if !s.portFirst {
addr, err = s.ReadAddress(reader)
@ -154,7 +156,8 @@ func (s *Serializer) ReadAddrPort(reader io.Reader) (destination *AddrPort, err
if err != nil {
return
}
return AddrPortFrom(addr, port), nil
addr.Port = port
return addr, nil
}
func ReadString(reader io.Reader) (string, error) {

View file

@ -13,12 +13,8 @@ func LocalAddrs() ([]netip.Addr, error) {
if err != nil {
return nil, err
}
return common.Map(common.Filter(common.Map(interfaceAddrs, func(addr net.Addr) M.Addr {
return common.Map(interfaceAddrs, func(addr net.Addr) netip.Addr {
return M.AddrFromNetAddr(addr)
}), func(addr M.Addr) bool {
return addr != nil
}), func(it M.Addr) netip.Addr {
return it.Addr()
}), nil
}

View file

@ -1,4 +1,4 @@
package socks
package network
import (
"context"
@ -14,11 +14,11 @@ import (
)
type PacketReader interface {
ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error)
ReadPacket(buffer *buf.Buffer) (addr M.Socksaddr, err error)
}
type PacketWriter interface {
WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error
WritePacket(buffer *buf.Buffer, addr M.Socksaddr) error
}
type PacketConn interface {
@ -27,7 +27,6 @@ type PacketConn interface {
Close() error
LocalAddr() net.Addr
RemoteAddr() net.Addr
SetDeadline(t time.Time) error
SetReadDeadline(t time.Time) error
SetWriteDeadline(t time.Time) error
@ -100,25 +99,49 @@ func CopyPacketConn(ctx context.Context, dest PacketConn, conn PacketConn) error
}
func CopyNetPacketConn(ctx context.Context, conn PacketConn, dest net.PacketConn) error {
return CopyPacketConn(ctx, conn, &PacketConnWrapper{dest})
if udpConn, ok := dest.(*net.UDPConn); ok {
return CopyPacketConn(ctx, conn, &UDPConnWrapper{udpConn})
} else {
return CopyPacketConn(ctx, conn, &PacketConnWrapper{dest})
}
}
type UDPConnWrapper struct {
*net.UDPConn
}
func (w *UDPConnWrapper) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
n, addr, err := w.ReadFromUDPAddrPort(buffer.FreeBytes())
if err != nil {
return M.Socksaddr{}, err
}
buffer.Truncate(n)
return M.SocksaddrFromNetIP(addr), nil
}
func (w *UDPConnWrapper) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
if destination.Family().IsFqdn() {
udpAddr, err := net.ResolveUDPAddr("udp", destination.String())
if err != nil {
return err
}
return common.Error(w.UDPConn.WriteTo(buffer.Bytes(), udpAddr))
}
return common.Error(w.UDPConn.WriteToUDPAddrPort(buffer.Bytes(), destination.AddrPort()))
}
type PacketConnWrapper struct {
net.PacketConn
}
func (p *PacketConnWrapper) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
func (p *PacketConnWrapper) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
_, addr, err := buffer.ReadPacketFrom(p)
if err != nil {
return nil, err
return M.Socksaddr{}, err
}
return M.AddrPortFromNetAddr(addr), err
return M.SocksaddrFromNet(addr), err
}
func (p *PacketConnWrapper) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error {
func (p *PacketConnWrapper) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
return common.Error(p.WriteTo(buffer.Bytes(), destination.UDPAddr()))
}
func (p *PacketConnWrapper) RemoteAddr() net.Addr {
return &common.DummyAddr{}
}

View file

@ -8,7 +8,7 @@ import (
)
type ContextDialer interface {
DialContext(ctx context.Context, network string, address *M.AddrPort) (net.Conn, error)
DialContext(ctx context.Context, network string, address M.Socksaddr) (net.Conn, error)
}
var SystemDialer ContextDialer = &DefaultDialer{}
@ -21,7 +21,7 @@ func (d *DefaultDialer) ListenUDP(network string, laddr *net.UDPAddr) (*net.UDPC
return net.ListenUDP(network, laddr)
}
func (d *DefaultDialer) DialContext(ctx context.Context, network string, address *M.AddrPort) (net.Conn, error) {
func (d *DefaultDialer) DialContext(ctx context.Context, network string, address M.Socksaddr) (net.Conn, error) {
return d.Dialer.DialContext(ctx, network, address.String())
}

View file

@ -2,12 +2,13 @@ package redir
import (
"net"
"net/netip"
"syscall"
M "github.com/sagernet/sing/common/metadata"
)
func GetOriginalDestination(conn net.Conn) (destination *M.AddrPort, err error) {
func GetOriginalDestination(conn net.Conn) (destination netip.AddrPort, err error) {
rawConn, err := conn.(syscall.Conn).SyscallConn()
if err != nil {
return
@ -23,14 +24,14 @@ func GetOriginalDestination(conn net.Conn) (destination *M.AddrPort, err error)
if conn.RemoteAddr().(*net.TCPAddr).IP.To4() != nil {
raw, err := syscall.GetsockoptIPv6Mreq(int(rawFd), syscall.IPPROTO_IP, SO_ORIGINAL_DST)
if err != nil {
return nil, err
return netip.AddrPort{}, err
}
return M.AddrPortFrom(M.AddrFromIP(raw.Multiaddr[4:8]), uint16(raw.Multiaddr[2])<<8+uint16(raw.Multiaddr[3])), nil
return netip.AddrPortFrom(M.AddrFromIP(raw.Multiaddr[4:8]), uint16(raw.Multiaddr[2])<<8+uint16(raw.Multiaddr[3])), nil
} else {
raw, err := syscall.GetsockoptIPv6MTUInfo(int(rawFd), syscall.IPPROTO_IPV6, SO_ORIGINAL_DST)
if err != nil {
return nil, err
return netip.AddrPort{}, err
}
return M.AddrPortFrom(M.AddrFromIP(raw.Addr.Addr[:]), raw.Addr.Port), nil
return netip.AddrPortFrom(M.AddrFromIP(raw.Addr.Addr[:]), raw.Addr.Port), nil
}
}

View file

@ -5,10 +5,9 @@ package redir
import (
"errors"
"net"
M "github.com/sagernet/sing/common/metadata"
"net/netip"
)
func GetOriginalDestination(conn net.Conn) (destination *M.AddrPort, err error) {
return nil, errors.New("unsupported platform")
func GetOriginalDestination(conn net.Conn) (destination netip.AddrPort, err error) {
return netip.AddrPort{}, errors.New("unsupported platform")
}

View file

@ -4,6 +4,7 @@ import (
"encoding/binary"
"fmt"
"net"
"net/netip"
"os"
"strconv"
"syscall"
@ -36,19 +37,19 @@ func FWMark(fd uintptr, mark int) error {
return syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_MARK, mark)
}
func GetOriginalDestinationFromOOB(oob []byte) (*M.AddrPort, error) {
func GetOriginalDestinationFromOOB(oob []byte) (netip.AddrPort, error) {
controlMessages, err := unix.ParseSocketControlMessage(oob)
if err != nil {
return nil, err
return netip.AddrPort{}, err
}
for _, message := range controlMessages {
if message.Header.Level == unix.SOL_IP && message.Header.Type == unix.IP_RECVORIGDSTADDR {
return M.AddrPortFrom(M.AddrFromIP(message.Data[4:8]), binary.BigEndian.Uint16(message.Data[2:4])), nil
return netip.AddrPortFrom(M.AddrFromIP(message.Data[4:8]), binary.BigEndian.Uint16(message.Data[2:4])), nil
} else if message.Header.Level == unix.SOL_IPV6 && message.Header.Type == unix.IPV6_RECVORIGDSTADDR {
return M.AddrPortFrom(M.AddrFromIP(message.Data[8:24]), binary.BigEndian.Uint16(message.Data[2:4])), nil
return netip.AddrPortFrom(M.AddrFromIP(message.Data[8:24]), binary.BigEndian.Uint16(message.Data[2:4])), nil
}
}
return nil, E.New("not found")
return netip.AddrPort{}, E.New("not found")
}
func DialUDP(network string, lAddr *net.UDPAddr, rAddr *net.UDPAddr) (*net.UDPConn, error) {

View file

@ -4,9 +4,9 @@ package redir
import (
"net"
"net/netip"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
)
func TProxy(fd uintptr, isIPv6 bool) error {
@ -21,8 +21,8 @@ func FWMark(fd uintptr, mark int) error {
return E.New("only available on linux")
}
func GetOriginalDestinationFromOOB(oob []byte) (*M.AddrPort, error) {
return nil, E.New("only available on linux")
func GetOriginalDestinationFromOOB(oob []byte) (netip.AddrPort, error) {
return netip.AddrPort{}, E.New("only available on linux")
}
func DialUDP(network string, lAddr *net.UDPAddr, rAddr *net.UDPAddr) (*net.UDPConn, error) {

View file

@ -1,66 +0,0 @@
package session
import (
"net"
"strconv"
"github.com/sagernet/sing/common/buf"
M "github.com/sagernet/sing/common/metadata"
)
type Network int
const (
NetworkTCP Network = iota
NetworkUDP
)
type InstanceContext struct{}
type Context struct {
InstanceContext
Network Network
Source M.Addr
Destination M.Addr
SourcePort uint16
DestinationPort uint16
}
func (c Context) DestinationNetAddr() string {
return net.JoinHostPort(c.Destination.String(), strconv.Itoa(int(c.DestinationPort)))
}
func AddressFromNetAddr(netAddr net.Addr) (addr M.Addr, port uint16) {
var ip net.IP
switch addr := netAddr.(type) {
case *net.TCPAddr:
ip = addr.IP
port = uint16(addr.Port)
case *net.UDPAddr:
ip = addr.IP
port = uint16(addr.Port)
}
return M.AddrFromIP(ip), port
}
type Conn struct {
Conn net.Conn
Context *Context
}
type PacketConn struct {
Conn net.PacketConn
Context *Context
}
type Packet struct {
Context *Context
Data *buf.Buffer
WriteBack func(buffer *buf.Buffer, addr *net.UDPAddr) error
Release func()
}
type Handler interface {
HandleConnection(conn *Conn)
HandlePacket(packet *Packet)
}

View file

@ -1,63 +0,0 @@
package session
import (
"container/list"
"sync"
"github.com/sagernet/sing/common"
)
var (
connectionPool list.List
connectionPoolEnabled bool
connectionAccess sync.Mutex
)
func EnableConnectionPool() {
connectionPoolEnabled = true
}
func DisableConnectionPool() {
connectionAccess.Lock()
defer connectionAccess.Unlock()
connectionPoolEnabled = false
clearConnections()
}
func AddConnection(connection any) any {
if !connectionPoolEnabled {
return connection
}
connectionAccess.Lock()
defer connectionAccess.Unlock()
return connectionPool.PushBack(connection)
}
func RemoveConnection(anyElement any) {
element, ok := anyElement.(*list.Element)
if !ok {
common.Close(anyElement)
return
}
if element.Value == nil {
return
}
common.Close(element.Value)
element.Value = nil
connectionAccess.Lock()
defer connectionAccess.Unlock()
connectionPool.Remove(element)
}
func ResetConnections() {
connectionAccess.Lock()
defer connectionAccess.Unlock()
clearConnections()
}
func clearConnections() {
for element := connectionPool.Front(); element != nil; element = element.Next() {
common.Close(element)
}
connectionPool.Init()
}

420
common/tun/system/tun.go Normal file
View file

@ -0,0 +1,420 @@
package system
import (
"context"
"net"
"net/netip"
"os"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/cache"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/log"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/rw"
"github.com/sagernet/sing/common/tun"
"github.com/sagernet/sing/common/udpnat"
"gvisor.dev/gvisor/pkg/tcpip"
tcpipBuffer "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/header/parse"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
var logger = log.NewLogger("tun <system>")
type Stack struct {
tunFd uintptr
tunMtu int
inetAddress netip.Prefix
inet6Address netip.Prefix
handler tun.Handler
tunFile *os.File
tcpForwarder *net.TCPListener
tcpPort uint16
tcpSessions *cache.LruCache[netip.AddrPort, netip.AddrPort]
udpNat *udpnat.Service[netip.AddrPort]
}
func New(tunFd uintptr, tunMtu int, inetAddress netip.Prefix, inet6Address netip.Prefix, packetTimeout int64, handler tun.Handler) tun.Stack {
return &Stack{
tunFd: tunFd,
tunMtu: tunMtu,
inetAddress: inetAddress,
inet6Address: inet6Address,
handler: handler,
tunFile: os.NewFile(tunFd, "tun"),
tcpSessions: cache.New(
cache.WithAge[netip.AddrPort, netip.AddrPort](packetTimeout),
cache.WithUpdateAgeOnGet[netip.AddrPort, netip.AddrPort](),
),
udpNat: udpnat.New[netip.AddrPort](packetTimeout, handler),
}
}
func (t *Stack) Start() error {
var network string
var address net.TCPAddr
if !t.inet6Address.IsValid() {
network = "tcp4"
address.IP = t.inetAddress.Addr().AsSlice()
} else {
network = "tcp"
address.IP = net.IPv6zero
}
tcpListener, err := net.ListenTCP(network, &address)
if err != nil {
return err
}
t.tcpForwarder = tcpListener
go t.tcpLoop()
go t.tunLoop()
return nil
}
func (t *Stack) Close() error {
t.tcpForwarder.Close()
t.tunFile.Close()
return nil
}
func (t *Stack) tunLoop() {
_buffer := buf.Make(t.tunMtu)
buffer := common.Dup(_buffer)
for {
n, err := t.tunFile.Read(buffer)
if err != nil {
t.handler.HandleError(err)
break
}
packet := buffer[:n]
t.deliverPacket(packet)
}
}
func (t *Stack) deliverPacket(packet []byte) {
var err error
switch header.IPVersion(packet) {
case header.IPv4Version:
ipHdr := header.IPv4(packet)
switch ipHdr.TransportProtocol() {
case header.TCPProtocolNumber:
err = t.processIPv4TCP(ipHdr, ipHdr.Payload())
case header.UDPProtocolNumber:
err = t.processIPv4UDP(ipHdr, ipHdr.Payload())
default:
_, err = t.tunFile.Write(packet)
}
case header.IPv6Version:
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: tcpipBuffer.View(packet).ToVectorisedView(),
})
proto, _, _, _, ok := parse.IPv6(pkt)
pkt.DecRef()
if !ok {
return
}
ipHdr := header.IPv6(packet)
switch proto {
case header.TCPProtocolNumber:
err = t.processIPv6TCP(ipHdr, ipHdr.Payload())
case header.UDPProtocolNumber:
err = t.processIPv6UDP(ipHdr, ipHdr.Payload())
default:
_, err = t.tunFile.Write(packet)
}
}
if err != nil {
t.handler.HandleError(err)
}
}
func (t *Stack) processIPv4TCP(ipHdr header.IPv4, tcpHdr header.TCP) error {
sourceAddress := ipHdr.SourceAddress()
destinationAddress := ipHdr.DestinationAddress()
sourcePort := tcpHdr.SourcePort()
destinationPort := tcpHdr.DestinationPort()
logger.Trace(sourceAddress, ":", sourcePort, " => ", destinationAddress, ":", destinationPort)
if sourcePort != t.tcpPort {
key := M.AddrPortFrom(net.IP(destinationAddress), sourcePort)
t.tcpSessions.LoadOrStore(key, func() netip.AddrPort {
return M.AddrPortFrom(net.IP(sourceAddress), destinationPort)
})
ipHdr.SetSourceAddress(destinationAddress)
ipHdr.SetDestinationAddress(tcpip.Address(t.inetAddress.Addr().AsSlice()))
tcpHdr.SetDestinationPort(t.tcpPort)
} else {
key := M.AddrPortFrom(net.IP(destinationAddress), destinationPort)
session, loaded := t.tcpSessions.Load(key)
if !loaded {
return E.New("unknown tcp session with source port ", destinationPort, " to destination address ", destinationAddress)
}
ipHdr.SetSourceAddress(destinationAddress)
tcpHdr.SetSourcePort(session.Port())
ipHdr.SetDestinationAddress(tcpip.Address(session.Addr().AsSlice()))
}
ipHdr.SetChecksum(0)
ipHdr.SetChecksum(^ipHdr.CalculateChecksum())
tcpHdr.SetChecksum(0)
tcpHdr.SetChecksum(^tcpHdr.CalculateChecksum(header.ChecksumCombine(
header.PseudoHeaderChecksum(header.TCPProtocolNumber, ipHdr.SourceAddress(), ipHdr.DestinationAddress(), uint16(len(tcpHdr))),
header.Checksum(tcpHdr.Payload(), 0),
)))
_, err := t.tunFile.Write(ipHdr)
return err
}
func (t *Stack) processIPv6TCP(ipHdr header.IPv6, tcpHdr header.TCP) error {
sourceAddress := ipHdr.SourceAddress()
destinationAddress := ipHdr.DestinationAddress()
sourcePort := tcpHdr.SourcePort()
destinationPort := tcpHdr.DestinationPort()
if sourcePort != t.tcpPort {
key := M.AddrPortFrom(net.IP(destinationAddress), sourcePort)
t.tcpSessions.LoadOrStore(key, func() netip.AddrPort {
return M.AddrPortFrom(net.IP(sourceAddress), destinationPort)
})
ipHdr.SetSourceAddress(destinationAddress)
ipHdr.SetDestinationAddress(tcpip.Address(t.inet6Address.Addr().AsSlice()))
tcpHdr.SetDestinationPort(t.tcpPort)
} else {
key := M.AddrPortFrom(net.IP(destinationAddress), destinationPort)
session, loaded := t.tcpSessions.Load(key)
if !loaded {
return E.New("unknown tcp session with source port ", destinationPort, " to destination address ", destinationAddress)
}
ipHdr.SetSourceAddress(destinationAddress)
tcpHdr.SetSourcePort(session.Port())
ipHdr.SetDestinationAddress(tcpip.Address(session.Addr().AsSlice()))
}
tcpHdr.SetChecksum(0)
tcpHdr.SetChecksum(^tcpHdr.CalculateChecksum(header.ChecksumCombine(
header.PseudoHeaderChecksum(header.TCPProtocolNumber, ipHdr.SourceAddress(), ipHdr.DestinationAddress(), uint16(len(tcpHdr))),
header.Checksum(tcpHdr.Payload(), 0),
)))
_, err := t.tunFile.Write(ipHdr)
return err
}
func (t *Stack) tcpLoop() {
for {
logger.Trace("tcp start")
tcpConn, err := t.tcpForwarder.AcceptTCP()
logger.Trace("tcp accept")
if err != nil {
t.handler.HandleError(err)
return
}
key := M.AddrPortFromNet(tcpConn.RemoteAddr())
session, ok := t.tcpSessions.Load(key)
if !ok {
tcpConn.Close()
logger.Warn("dropped unknown tcp session from ", key)
continue
}
var metadata M.Metadata
metadata.Protocol = "tun"
metadata.Source.Addr = session.Addr()
metadata.Source.Port = key.Port()
metadata.Destination.Addr = key.Addr()
metadata.Destination.Port = session.Port()
go t.processConn(tcpConn, metadata, key)
}
}
func (t *Stack) processConn(conn *net.TCPConn, metadata M.Metadata, key netip.AddrPort) {
err := t.handler.NewConnection(context.Background(), conn, metadata)
if err != nil {
t.handler.HandleError(err)
}
t.tcpSessions.Delete(key)
}
func (t *Stack) processIPv4UDP(ipHdr header.IPv4, hdr header.UDP) error {
var metadata M.Metadata
metadata.Protocol = "tun"
metadata.Source = M.SocksaddrFrom(net.IP(ipHdr.SourceAddress()), hdr.SourcePort())
metadata.Source = M.SocksaddrFrom(net.IP(ipHdr.DestinationAddress()), hdr.DestinationPort())
headerCache := buf.New()
_, err := headerCache.Write(ipHdr[:ipHdr.HeaderLength()+header.UDPMinimumSize])
if err != nil {
return err
}
logger.Trace("[UDP] ", metadata.Source, "=>", metadata.Destination)
t.udpNat.NewPacket(metadata.Source.AddrPort(), func() N.PacketWriter {
return &inetPacketWriter{
tun: t,
headerCache: headerCache,
sourceAddress: ipHdr.SourceAddress(),
destination: ipHdr.DestinationAddress(),
destinationPort: hdr.DestinationPort(),
}
}, buf.With(hdr), metadata)
return nil
}
type inetPacketWriter struct {
tun *Stack
headerCache *buf.Buffer
sourceAddress tcpip.Address
destination tcpip.Address
destinationPort uint16
}
func (w *inetPacketWriter) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
index := w.headerCache.Len()
newHeader := w.headerCache.Extend(w.headerCache.Len())
copy(newHeader, w.headerCache.Bytes())
w.headerCache.Advance(index)
defer func() {
w.headerCache.FullReset()
w.headerCache.Resize(0, index)
}()
var newSourceAddress tcpip.Address
var newSourcePort uint16
if destination.IsValid() {
newSourceAddress = tcpip.Address(destination.Addr.AsSlice())
newSourcePort = destination.Port
} else {
newSourceAddress = w.destination
newSourcePort = w.destinationPort
}
newIpHdr := header.IPv4(newHeader)
newIpHdr.SetSourceAddress(newSourceAddress)
newIpHdr.SetTotalLength(uint16(int(w.headerCache.Len()) + buffer.Len()))
newIpHdr.SetChecksum(0)
newIpHdr.SetChecksum(^newIpHdr.CalculateChecksum())
udpHdr := header.UDP(w.headerCache.From(w.headerCache.Len() - header.UDPMinimumSize))
udpHdr.SetSourcePort(newSourcePort)
udpHdr.SetLength(uint16(header.UDPMinimumSize + buffer.Len()))
udpHdr.SetChecksum(0)
udpHdr.SetChecksum(^udpHdr.CalculateChecksum(header.Checksum(buffer.Bytes(), header.PseudoHeaderChecksum(header.UDPProtocolNumber, newSourceAddress, w.sourceAddress, uint16(header.UDPMinimumSize+buffer.Len())))))
replyVV := tcpipBuffer.VectorisedView{}
replyVV.AppendView(newHeader)
replyVV.AppendView(buffer.Bytes())
return w.tun.WriteVV(replyVV)
}
func (w *inetPacketWriter) Close() error {
w.headerCache.Release()
return nil
}
func (t *Stack) processIPv6UDP(ipHdr header.IPv6, hdr header.UDP) error {
var metadata M.Metadata
metadata.Protocol = "tun"
metadata.Source = M.SocksaddrFrom(net.IP(ipHdr.SourceAddress()), hdr.SourcePort())
metadata.Destination = M.SocksaddrFrom(net.IP(ipHdr.DestinationAddress()), hdr.DestinationPort())
headerCache := buf.New()
_, err := headerCache.Write(ipHdr[:uint16(len(ipHdr))-ipHdr.PayloadLength()+header.UDPMinimumSize])
if err != nil {
return err
}
t.udpNat.NewPacket(metadata.Source.AddrPort(), func() N.PacketWriter {
return &inet6PacketWriter{
tun: t,
headerCache: headerCache,
sourceAddress: ipHdr.SourceAddress(),
destination: ipHdr.DestinationAddress(),
destinationPort: hdr.DestinationPort(),
}
}, buf.With(hdr), metadata)
return nil
}
type inet6PacketWriter struct {
tun *Stack
headerCache *buf.Buffer
sourceAddress tcpip.Address
destination tcpip.Address
destinationPort uint16
}
func (w *inet6PacketWriter) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
index := w.headerCache.Len()
newHeader := w.headerCache.Extend(w.headerCache.Len())
copy(newHeader, w.headerCache.Bytes())
w.headerCache.Advance(index)
defer func() {
w.headerCache.FullReset()
w.headerCache.Resize(0, index)
}()
var newSourceAddress tcpip.Address
var newSourcePort uint16
if destination.IsValid() {
newSourceAddress = tcpip.Address(destination.Addr.AsSlice())
newSourcePort = destination.Port
} else {
newSourceAddress = w.destination
newSourcePort = w.destinationPort
}
newIpHdr := header.IPv6(newHeader)
newIpHdr.SetSourceAddress(newSourceAddress)
newIpHdr.SetPayloadLength(uint16(header.UDPMinimumSize + buffer.Len()))
udpHdr := header.UDP(w.headerCache.From(w.headerCache.Len() - header.UDPMinimumSize))
udpHdr.SetSourcePort(newSourcePort)
udpHdr.SetLength(uint16(header.UDPMinimumSize + buffer.Len()))
udpHdr.SetChecksum(0)
udpHdr.SetChecksum(^udpHdr.CalculateChecksum(header.Checksum(buffer.Bytes(), header.PseudoHeaderChecksum(header.UDPProtocolNumber, newSourceAddress, w.sourceAddress, uint16(header.UDPMinimumSize+buffer.Len())))))
replyVV := tcpipBuffer.VectorisedView{}
replyVV.AppendView(newHeader)
replyVV.AppendView(buffer.Bytes())
return w.tun.WriteVV(replyVV)
}
func (t *Stack) WriteVV(vv tcpipBuffer.VectorisedView) error {
data := make([][]byte, 0, len(vv.Views()))
for _, view := range vv.Views() {
data = append(data, view)
}
return common.Error(rw.WriteV(t.tunFd, data...))
}
func (w *inet6PacketWriter) Close() error {
w.headerCache.Release()
return nil
}
type tcpipError struct {
Err tcpip.Error
}
func (e *tcpipError) Error() string {
return e.Err.String()
}

18
common/tun/tun.go Normal file
View file

@ -0,0 +1,18 @@
package tun
import (
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
type Handler interface {
M.TCPConnectionHandler
N.UDPConnectionHandler
E.Handler
}
type Stack interface {
Start() error
Close() error
}

171
common/tun/tun_linux.go Normal file
View file

@ -0,0 +1,171 @@
package tun
/*
import (
"bytes"
"net"
"syscall"
"unsafe"
E "github.com/sagernet/sing/common/exceptions"
"golang.org/x/sys/unix"
)
const ifReqSize = unix.IFNAMSIZ + 64
func (t *Interface) Name() (string, error) {
if t.tunName != "" {
return t.tunName, nil
}
var ifr [ifReqSize]byte
_, _, errno := unix.Syscall(unix.SYS_IOCTL, uintptr(t.tunFd), uintptr(unix.TUNGETIFF), uintptr(unsafe.Pointer(&ifr[0])))
if errno != 0 {
return "", errno
}
name := ifr[:]
if i := bytes.IndexByte(name, 0); i != -1 {
name = name[:i]
}
t.tunName = string(name)
return t.tunName, nil
}
func (t *Interface) MTU() (int, error) {
name, err := t.Name()
if err != nil {
return 0, err
}
fd, err := unix.Socket(
unix.AF_INET,
unix.SOCK_DGRAM,
0,
)
if err != nil {
return 0, err
}
defer unix.Close(fd)
var ifr [ifReqSize]byte
copy(ifr[:], name)
_, _, errno := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), uintptr(unix.SIOCGIFMTU), uintptr(unsafe.Pointer(&ifr[0])))
if errno != 0 {
return 0, errno
}
return int(*(*int32)(unsafe.Pointer(&ifr[unix.IFNAMSIZ]))), nil
}
func (t *Interface) SetMTU(mtu int) error {
name, err := t.Name()
if err != nil {
return err
}
fd, err := unix.Socket(
unix.AF_INET,
unix.SOCK_DGRAM,
0,
)
if err != nil {
return err
}
defer unix.Close(fd)
var ifr [ifReqSize]byte
copy(ifr[:], name)
*(*uint32)(unsafe.Pointer(&ifr[unix.IFNAMSIZ])) = uint32(mtu)
_, _, errno := unix.Syscall(
unix.SYS_IOCTL,
uintptr(fd),
uintptr(unix.SIOCSIFMTU),
uintptr(unsafe.Pointer(&ifr[0])),
)
if errno != 0 {
return errno
}
return nil
}
func (t *Interface) SetAddress() error {
name, err := t.Name()
if err != nil {
return err
}
fd, err := unix.Socket(
unix.AF_INET,
unix.SOCK_DGRAM,
0,
)
if err != nil {
return err
}
defer unix.Close(fd)
ifreq, err := unix.NewIfreq(name)
if err != nil {
return E.Cause(err, "failed to create ifreq for name ", name)
}
ifreq.SetInet4Addr(t.inetAddress.Addr().AsSlice())
err = unix.IoctlIfreq(fd, syscall.SIOCSIFADDR, ifreq)
if err == nil {
ifreq, _ = unix.NewIfreq(name)
ifreq.SetInet4Addr(net.CIDRMask(t.inetAddress.Bits(), 32))
err = unix.IoctlIfreq(fd, syscall.SIOCSIFNETMASK, ifreq)
}
if err != nil {
return E.Cause(err, "failed to set ipv4 address on ", name)
}
if t.inet6Address.IsValid() {
ifreq, _ = unix.NewIfreq(name)
err = unix.IoctlIfreq(fd, syscall.SIOCGIFINDEX, ifreq)
if err != nil {
return E.Cause(err, "failed to get interface index for ", name)
}
ifreq6 := in6_ifreq{
ifr6_addr: in6_addr{
addr: t.inet6Address.Addr().As16(),
},
ifr6_prefixlen: uint32(t.inet6Address.Bits()),
ifr6_ifindex: ifreq.Uint32(),
}
fd6, err := unix.Socket(
unix.AF_INET6,
unix.SOCK_DGRAM,
0,
)
if err != nil {
return err
}
defer unix.Close(fd6)
if _, _, errno := syscall.Syscall(
syscall.SYS_IOCTL,
uintptr(fd6),
uintptr(syscall.SIOCSIFADDR),
uintptr(unsafe.Pointer(&ifreq6)),
); errno != 0 {
return E.Cause(errno, "failed to set ipv6 address on ", name)
}
}
ifreq, _ = unix.NewIfreq(name)
err = unix.IoctlIfreq(fd, syscall.SIOCGIFFLAGS, ifreq)
if err == nil {
ifreq.SetUint16(ifreq.Uint16() | syscall.IFF_UP | syscall.IFF_RUNNING)
err = unix.IoctlIfreq(fd, syscall.SIOCSIFFLAGS, ifreq)
}
if err != nil {
return E.Cause(err, "failed to bring tun device up")
}
return nil
}
type in6_addr struct {
addr [16]byte
}
type in6_ifreq struct {
ifr6_addr in6_addr
ifr6_prefixlen uint32
ifr6_ifindex uint32
}
*/

View file

@ -13,11 +13,11 @@ import (
"github.com/sagernet/sing/common/cache"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/protocol/socks"
N "github.com/sagernet/sing/common/network"
)
type Handler interface {
socks.UDPConnectionHandler
N.UDPConnectionHandler
E.Handler
}
@ -36,15 +36,16 @@ func New[K comparable](maxAge int64, handler Handler) *Service[K] {
}
}
func (s *Service[T]) NewPacket(key T, writer func() socks.PacketWriter, buffer *buf.Buffer, metadata M.Metadata) {
func (s *Service[T]) NewPacket(key T, writer func() N.PacketWriter, buffer *buf.Buffer, metadata M.Metadata) {
s.NewContextPacket(context.Background(), key, writer, buffer, metadata)
}
func (s *Service[T]) NewContextPacket(ctx context.Context, key T, writer func() socks.PacketWriter, buffer *buf.Buffer, metadata M.Metadata) {
func (s *Service[T]) NewContextPacket(ctx context.Context, key T, writer func() N.PacketWriter, buffer *buf.Buffer, metadata M.Metadata) {
c, loaded := s.nat.LoadOrStore(key, func() *conn {
c := &conn{
data: make(chan packet),
remoteAddr: metadata.Source.UDPAddr(),
localAddr: metadata.Source,
remoteAddr: metadata.Destination,
source: writer(),
}
c.ctx, c.cancel = context.WithCancel(ctx)
@ -80,7 +81,7 @@ func (s *Service[T]) NewContextPacket(ctx context.Context, key T, writer func()
type packet struct {
data *buf.Buffer
destination *M.AddrPort
destination M.Socksaddr
done context.CancelFunc
}
@ -89,15 +90,16 @@ type conn struct {
ctx context.Context
cancel context.CancelFunc
data chan packet
remoteAddr *net.UDPAddr
source socks.PacketWriter
localAddr M.Socksaddr
remoteAddr M.Socksaddr
source N.PacketWriter
}
func (c *conn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
func (c *conn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
select {
case p, ok := <-c.data:
if !ok {
return nil, io.ErrClosedPipe
return M.Socksaddr{}, io.ErrClosedPipe
}
defer p.data.Release()
_, err := buffer.ReadFrom(p.data)
@ -106,7 +108,7 @@ func (c *conn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
}
}
func (c *conn) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error {
func (c *conn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
return c.source.WritePacket(buffer, destination)
}
@ -126,7 +128,7 @@ func (c *conn) Close() error {
}
func (c *conn) LocalAddr() net.Addr {
return &common.DummyAddr{}
return c.localAddr
}
func (c *conn) RemoteAddr() net.Addr {

View file

@ -18,23 +18,23 @@ func NewClientConn(conn net.Conn) *ClientConn {
return &ClientConn{conn}
}
func (c *ClientConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
func (c *ClientConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
destination, err := AddrParser.ReadAddrPort(c)
if err != nil {
return nil, err
return M.Socksaddr{}, err
}
var length uint16
err = binary.Read(c, binary.BigEndian, &length)
if err != nil {
return nil, err
return M.Socksaddr{}, err
}
if buffer.FreeLen() < int(length) {
return nil, io.ErrShortBuffer
return M.Socksaddr{}, io.ErrShortBuffer
}
return destination, common.Error(buffer.ReadFullFrom(c, int(length)))
}
func (c *ClientConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error {
func (c *ClientConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
err := AddrParser.WriteAddrPort(c, destination)
if err != nil {
return err
@ -68,7 +68,7 @@ func (c *ClientConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
}
func (c *ClientConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
err = AddrParser.WriteAddrPort(c, M.AddrPortFromNetAddr(addr))
err = AddrParser.WriteAddrPort(c, M.SocksaddrFromNet(addr))
if err != nil {
return
}

View file

@ -47,8 +47,8 @@ func (c *ServerConn) loopInput() {
if err != nil {
break
}
if destination.Addr.Family().IsFqdn() {
ip, err := LookupAddress(destination.Addr.Fqdn())
if destination.Family().IsFqdn() {
ip, err := LookupAddress(destination.Fqdn)
if err != nil {
break
}
@ -81,8 +81,7 @@ func (c *ServerConn) loopOutput() {
if err != nil {
break
}
destination := M.AddrPortFromNetAddr(addr)
err = AddrParser.WriteAddrPort(c.outputWriter, destination)
err = AddrParser.WriteAddrPort(c.outputWriter, M.SocksaddrFromNet(addr))
if err != nil {
break
}

2
go.mod
View file

@ -23,6 +23,7 @@ require (
golang.org/x/sys v0.0.0-20220503163025-988cb79eb6c6
golang.zx2c4.com/wireguard v0.0.0-20220407013110-ef5c587f782d
google.golang.org/protobuf v1.28.0
gvisor.dev/gvisor v0.0.0-20220428010907-8082b77961ba
lukechampine.com/blake3 v1.1.7
)
@ -31,6 +32,7 @@ require (
github.com/cenkalti/backoff/v4 v4.1.1 // indirect
github.com/dgryski/go-metro v0.0.0-20211217172704-adc40b04c140 // indirect
github.com/golang/protobuf v1.5.2 // indirect
github.com/google/btree v1.0.1 // indirect
github.com/google/go-querystring v1.1.0 // indirect
github.com/inconshreveable/mousetrap v1.0.0 // indirect
github.com/klauspost/cpuid/v2 v2.0.12 // indirect

4
go.sum
View file

@ -154,6 +154,8 @@ github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiu
github.com/golangci/lint-1 v0.0.0-20181222135242-d2cdd8c08219/go.mod h1:/X8TswGSh1pIozq4ZwCfxS0WA5JGXguxk94ar/4c87Y=
github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4=
github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA=
github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M=
github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
@ -754,6 +756,8 @@ gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C
gopkg.in/yaml.v3 v3.0.0-20200605160147-a5ece683394c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b h1:h8qDotaEPuJATrMmW04NCwg7v22aHH28wwpauUhK9Oo=
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gvisor.dev/gvisor v0.0.0-20220428010907-8082b77961ba h1:qJ6jWSTl9q+/y4l8QCNpkNnasX/sHzhVnPRysee8PzY=
gvisor.dev/gvisor v0.0.0-20220428010907-8082b77961ba/go.mod h1:tWwEcFvJavs154OdjFCw78axNrsDlz4Zh8jvPqwcpGI=
honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=

View file

@ -45,13 +45,8 @@ func HandleRequest(ctx context.Context, request *http.Request, conn net.Conn, au
if portStr == "" {
portStr = "80"
}
destination, err := M.ParseAddrPort(request.URL.Hostname(), portStr)
if err != nil {
if err != nil {
return err
}
}
_, err = fmt.Fprintf(conn, "HTTP/%d.%d %03d %s\r\n\r\n", request.ProtoMajor, request.ProtoMinor, http.StatusOK, "Connection established")
destination := M.ParseSocksaddrHostPort(request.URL.Hostname(), portStr)
_, err := fmt.Fprintf(conn, "HTTP/%d.%d %03d %s\r\n\r\n", request.ProtoMajor, request.ProtoMinor, http.StatusOK, "Connection established")
if err != nil {
return E.Cause(err, "write http response")
}
@ -87,17 +82,11 @@ func HandleRequest(ctx context.Context, request *http.Request, conn net.Conn, au
if network != "tcp" && network != "tcp4" && network != "tcp6" {
return nil, E.New("unsupported network ", network)
}
destination, err := M.ParseAddress(address)
if err != nil {
return nil, err
}
metadata.Destination = M.ParseSocksaddr(address)
metadata.Protocol = "http"
left, right := net.Pipe()
go func() {
metadata.Destination = destination
metadata.Protocol = "http"
err = handler.NewConnection(ctx, right, metadata)
err := handler.NewConnection(ctx, right, metadata)
if err != nil {
handler.HandleError(&tcp.Error{Conn: right, Cause: err})
}

View file

@ -10,8 +10,9 @@ import (
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/udpnat"
"github.com/sagernet/sing/protocol/socks"
"github.com/sagernet/sing/protocol/socks5"
)
const MethodNone = "none"
@ -30,7 +31,7 @@ func (m *NoneMethod) KeyLength() int {
return 0
}
func (m *NoneMethod) DialConn(conn net.Conn, destination *M.AddrPort) (net.Conn, error) {
func (m *NoneMethod) DialConn(conn net.Conn, destination M.Socksaddr) (net.Conn, error) {
shadowsocksConn := &noneConn{
Conn: conn,
handshake: true,
@ -39,14 +40,14 @@ func (m *NoneMethod) DialConn(conn net.Conn, destination *M.AddrPort) (net.Conn,
return shadowsocksConn, shadowsocksConn.clientHandshake()
}
func (m *NoneMethod) DialEarlyConn(conn net.Conn, destination *M.AddrPort) net.Conn {
func (m *NoneMethod) DialEarlyConn(conn net.Conn, destination M.Socksaddr) net.Conn {
return &noneConn{
Conn: conn,
destination: destination,
}
}
func (m *NoneMethod) DialPacketConn(conn net.Conn) socks.PacketConn {
func (m *NoneMethod) DialPacketConn(conn net.Conn) N.PacketConn {
return &nonePacketConn{conn}
}
@ -55,11 +56,11 @@ type noneConn struct {
access sync.Mutex
handshake bool
destination *M.AddrPort
destination M.Socksaddr
}
func (c *noneConn) clientHandshake() error {
err := socks.AddressSerializer.WriteAddrPort(c.Conn, c.destination)
err := socks5.AddressSerializer.WriteAddrPort(c.Conn, c.destination)
if err != nil {
return err
}
@ -87,7 +88,7 @@ func (c *noneConn) Write(b []byte) (n int, err error) {
_buffer := buf.StackNew()
buffer := common.Dup(_buffer)
err = socks.AddressSerializer.WriteAddrPort(buffer, c.destination)
err = socks5.AddressSerializer.WriteAddrPort(buffer, c.destination)
if err != nil {
return
}
@ -132,19 +133,19 @@ type nonePacketConn struct {
net.Conn
}
func (c *nonePacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
func (c *nonePacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
_, err := buffer.ReadFrom(c)
if err != nil {
return nil, err
return M.Socksaddr{}, err
}
return socks.AddressSerializer.ReadAddrPort(buffer)
return socks5.AddressSerializer.ReadAddrPort(buffer)
}
func (c *nonePacketConn) WritePacket(buffer *buf.Buffer, addrPort *M.AddrPort) error {
func (c *nonePacketConn) WritePacket(buffer *buf.Buffer, addrPort M.Socksaddr) error {
defer buffer.Release()
_header := buf.StackNewMax()
header := common.Dup(_header)
err := socks.AddressSerializer.WriteAddrPort(header, addrPort)
err := socks5.AddressSerializer.WriteAddrPort(header, addrPort)
if err != nil {
header.Release()
return err
@ -167,7 +168,7 @@ func NewNoneService(udpTimeout int64, handler Handler) Service {
}
func (s *NoneService) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
destination, err := socks.AddressSerializer.ReadAddrPort(conn)
destination, err := socks5.AddressSerializer.ReadAddrPort(conn)
if err != nil {
return err
}
@ -176,34 +177,34 @@ func (s *NoneService) NewConnection(ctx context.Context, conn net.Conn, metadata
return s.handler.NewConnection(ctx, conn, metadata)
}
func (s *NoneService) NewPacket(conn socks.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error {
destination, err := socks.AddressSerializer.ReadAddrPort(buffer)
func (s *NoneService) NewPacket(conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error {
destination, err := socks5.AddressSerializer.ReadAddrPort(buffer)
if err != nil {
return err
}
metadata.Protocol = "shadowsocks"
metadata.Destination = destination
s.udp.NewPacket(metadata.Source.AddrPort(), func() socks.PacketWriter {
s.udp.NewPacket(metadata.Source.AddrPort(), func() N.PacketWriter {
return &nonePacketWriter{conn, metadata.Source}
}, buffer, metadata)
return nil
}
type nonePacketWriter struct {
socks.PacketConn
sourceAddr *M.AddrPort
N.PacketConn
sourceAddr M.Socksaddr
}
func (s *nonePacketWriter) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error {
header := buf.With(buffer.ExtendHeader(socks.AddressSerializer.AddrPortLen(destination)))
err := socks.AddressSerializer.WriteAddrPort(header, destination)
func (s *nonePacketWriter) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
header := buf.With(buffer.ExtendHeader(socks5.AddressSerializer.AddrPortLen(destination)))
err := socks5.AddressSerializer.WriteAddrPort(header, destination)
if err != nil {
return err
}
return s.PacketConn.WritePacket(buffer, s.sourceAddr)
}
func (s *NoneService) NewPacketConnection(ctx context.Context, conn socks.PacketConn, metadata M.Metadata) error {
func (s *NoneService) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata M.Metadata) error {
return s.handler.NewPacketConnection(ctx, conn, metadata)
}

View file

@ -8,15 +8,15 @@ import (
"net"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/protocol/socks"
N "github.com/sagernet/sing/common/network"
)
type Method interface {
Name() string
KeyLength() int
DialConn(conn net.Conn, destination *M.AddrPort) (net.Conn, error)
DialEarlyConn(conn net.Conn, destination *M.AddrPort) net.Conn
DialPacketConn(conn net.Conn) socks.PacketConn
DialConn(conn net.Conn, destination M.Socksaddr) (net.Conn, error)
DialEarlyConn(conn net.Conn, destination M.Socksaddr) net.Conn
DialPacketConn(conn net.Conn) N.PacketConn
}
func Key(password []byte, keySize int) []byte {

View file

@ -7,17 +7,17 @@ import (
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/protocol/socks"
N "github.com/sagernet/sing/common/network"
)
type Service interface {
M.TCPConnectionHandler
socks.UDPHandler
N.UDPHandler
}
type Handler interface {
M.TCPConnectionHandler
socks.UDPConnectionHandler
N.UDPConnectionHandler
E.Handler
}
@ -34,7 +34,7 @@ type UserContext[U comparable] struct {
type ServerConnError struct {
net.Conn
Source *M.AddrPort
Source M.Socksaddr
Cause error
}
@ -47,8 +47,8 @@ func (e *ServerConnError) Error() string {
}
type ServerPacketError struct {
socks.PacketConn
Source *M.AddrPort
N.PacketConn
Source M.Socksaddr
Cause error
}

View file

@ -12,10 +12,11 @@ import (
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/replay"
"github.com/sagernet/sing/common/rw"
"github.com/sagernet/sing/protocol/shadowsocks"
"github.com/sagernet/sing/protocol/socks"
"github.com/sagernet/sing/protocol/socks5"
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/crypto/hkdf"
)
@ -138,7 +139,7 @@ func (m *Method) WriteResponse(upstream io.Writer) (io.Writer, error) {
return NewWriter(upstream, m.constructor(common.Dup(key)), MaxPacketSize), nil
}
func (m *Method) DialConn(conn net.Conn, destination *M.AddrPort) (net.Conn, error) {
func (m *Method) DialConn(conn net.Conn, destination M.Socksaddr) (net.Conn, error) {
shadowsocksConn := &clientConn{
Conn: conn,
method: m,
@ -147,7 +148,7 @@ func (m *Method) DialConn(conn net.Conn, destination *M.AddrPort) (net.Conn, err
return shadowsocksConn, shadowsocksConn.writeRequest(nil)
}
func (m *Method) DialEarlyConn(conn net.Conn, destination *M.AddrPort) net.Conn {
func (m *Method) DialEarlyConn(conn net.Conn, destination M.Socksaddr) net.Conn {
return &clientConn{
Conn: conn,
method: m,
@ -155,7 +156,7 @@ func (m *Method) DialEarlyConn(conn net.Conn, destination *M.AddrPort) net.Conn
}
}
func (m *Method) DialPacketConn(conn net.Conn) socks.PacketConn {
func (m *Method) DialPacketConn(conn net.Conn) N.PacketConn {
return &clientPacketConn{m, conn}
}
@ -186,7 +187,7 @@ type clientConn struct {
net.Conn
method *Method
destination *M.AddrPort
destination M.Socksaddr
access sync.Mutex
reader *Reader
@ -209,7 +210,7 @@ func (c *clientConn) writeRequest(payload []byte) error {
bufferedWriter := writer.BufferedWriter(header.Len())
if len(payload) > 0 {
err := socks.AddressSerializer.WriteAddrPort(bufferedWriter, c.destination)
err := socks5.AddressSerializer.WriteAddrPort(bufferedWriter, c.destination)
if err != nil {
return err
}
@ -219,7 +220,7 @@ func (c *clientConn) writeRequest(payload []byte) error {
return err
}
} else {
err := socks.AddressSerializer.WriteAddrPort(bufferedWriter, c.destination)
err := socks5.AddressSerializer.WriteAddrPort(bufferedWriter, c.destination)
if err != nil {
return err
}
@ -325,10 +326,10 @@ type clientPacketConn struct {
net.Conn
}
func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error {
header := buffer.ExtendHeader(c.keySaltLength + socks.AddressSerializer.AddrPortLen(destination))
func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
header := buffer.ExtendHeader(c.keySaltLength + socks5.AddressSerializer.AddrPortLen(destination))
common.Must1(io.ReadFull(c.secureRNG, header[:c.keySaltLength]))
err := socks.AddressSerializer.WriteAddrPort(buf.With(header[c.keySaltLength:]), destination)
err := socks5.AddressSerializer.WriteAddrPort(buf.With(header[c.keySaltLength:]), destination)
if err != nil {
return err
}
@ -339,17 +340,17 @@ func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPo
return common.Error(c.Write(buffer.Bytes()))
}
func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
n, err := c.Read(buffer.FreeBytes())
if err != nil {
return nil, err
return M.Socksaddr{}, err
}
buffer.Truncate(n)
err = c.DecodePacket(buffer)
if err != nil {
return nil, err
return M.Socksaddr{}, err
}
return socks.AddressSerializer.ReadAddrPort(buffer)
return socks5.AddressSerializer.ReadAddrPort(buffer)
}
func (c *clientPacketConn) UpstreamReader() io.Reader {

View file

@ -12,11 +12,12 @@ import (
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/replay"
"github.com/sagernet/sing/common/rw"
"github.com/sagernet/sing/common/udpnat"
"github.com/sagernet/sing/protocol/shadowsocks"
"github.com/sagernet/sing/protocol/socks"
"github.com/sagernet/sing/protocol/socks5"
"golang.org/x/crypto/chacha20poly1305"
)
@ -97,7 +98,7 @@ func (s *Service) newConnection(ctx context.Context, conn net.Conn, metadata M.M
key := Kdf(s.key, salt, s.keySaltLength)
reader := NewReader(conn, s.constructor(common.Dup(key)), MaxPacketSize)
destination, err := socks.AddressSerializer.ReadAddrPort(reader)
destination, err := socks5.AddressSerializer.ReadAddrPort(reader)
if err != nil {
return err
}
@ -198,7 +199,7 @@ func (c *serverConn) WriterReplaceable() bool {
return c.writer != nil
}
func (s *Service) NewPacket(conn socks.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error {
func (s *Service) NewPacket(conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error {
err := s.newPacket(conn, buffer, metadata)
if err != nil {
err = &shadowsocks.ServerPacketError{PacketConn: conn, Source: metadata.Source, Cause: err}
@ -206,7 +207,7 @@ func (s *Service) NewPacket(conn socks.PacketConn, buffer *buf.Buffer, metadata
return err
}
func (s *Service) newPacket(conn socks.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error {
func (s *Service) newPacket(conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error {
if buffer.Len() < s.keySaltLength {
return E.New("bad packet")
}
@ -219,7 +220,7 @@ func (s *Service) newPacket(conn socks.PacketConn, buffer *buf.Buffer, metadata
buffer.Advance(s.keySaltLength)
buffer.Truncate(len(packet))
metadata.Protocol = "shadowsocks"
s.udpNat.NewPacket(metadata.Source.AddrPort(), func() socks.PacketWriter {
s.udpNat.NewPacket(metadata.Source.AddrPort(), func() N.PacketWriter {
return &serverPacketWriter{s, conn, metadata.Source}
}, buffer, metadata)
return nil
@ -227,14 +228,14 @@ func (s *Service) newPacket(conn socks.PacketConn, buffer *buf.Buffer, metadata
type serverPacketWriter struct {
*Service
socks.PacketConn
source *M.AddrPort
N.PacketConn
source M.Socksaddr
}
func (w *serverPacketWriter) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error {
header := buffer.ExtendHeader(w.keySaltLength + socks.AddressSerializer.AddrPortLen(destination))
func (w *serverPacketWriter) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
header := buffer.ExtendHeader(w.keySaltLength + socks5.AddressSerializer.AddrPortLen(destination))
common.Must1(io.ReadFull(w.secureRNG, header[:w.keySaltLength]))
err := socks.AddressSerializer.WriteAddrPort(buf.With(header[w.keySaltLength:]), destination)
err := socks5.AddressSerializer.WriteAddrPort(buf.With(header[w.keySaltLength:]), destination)
if err != nil {
return err
}

View file

@ -19,11 +19,12 @@ import (
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/log"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/replay"
"github.com/sagernet/sing/common/rw"
"github.com/sagernet/sing/protocol/shadowsocks"
"github.com/sagernet/sing/protocol/shadowsocks/shadowaead"
"github.com/sagernet/sing/protocol/socks"
"github.com/sagernet/sing/protocol/socks5"
"golang.org/x/crypto/chacha20poly1305"
wgReplay "golang.zx2c4.com/wireguard/replay"
"lukechampine.com/blake3"
@ -163,7 +164,7 @@ func (m *Method) KeyLength() int {
return m.keyLength
}
func (m *Method) DialConn(conn net.Conn, destination *M.AddrPort) (net.Conn, error) {
func (m *Method) DialConn(conn net.Conn, destination M.Socksaddr) (net.Conn, error) {
shadowsocksConn := &clientConn{
Conn: conn,
method: m,
@ -172,7 +173,7 @@ func (m *Method) DialConn(conn net.Conn, destination *M.AddrPort) (net.Conn, err
return shadowsocksConn, shadowsocksConn.writeRequest(nil)
}
func (m *Method) DialEarlyConn(conn net.Conn, destination *M.AddrPort) net.Conn {
func (m *Method) DialEarlyConn(conn net.Conn, destination M.Socksaddr) net.Conn {
return &clientConn{
Conn: conn,
method: m,
@ -180,7 +181,7 @@ func (m *Method) DialEarlyConn(conn net.Conn, destination *M.AddrPort) net.Conn
}
}
func (m *Method) DialPacketConn(conn net.Conn) socks.PacketConn {
func (m *Method) DialPacketConn(conn net.Conn) N.PacketConn {
return &clientPacketConn{conn, m, m.newUDPSession()}
}
@ -188,7 +189,7 @@ type clientConn struct {
net.Conn
method *Method
destination *M.AddrPort
destination M.Socksaddr
request sync.Mutex
response sync.Mutex
@ -267,7 +268,7 @@ func (c *clientConn) writeRequest(payload []byte) error {
common.Must(rw.WriteByte(bufferedWriter, HeaderTypeClient))
common.Must(binary.Write(bufferedWriter, binary.BigEndian, uint64(time.Now().Unix())))
err := socks.AddressSerializer.WriteAddrPort(bufferedWriter, c.destination)
err := socks5.AddressSerializer.WriteAddrPort(bufferedWriter, c.destination)
if err != nil {
return E.Cause(err, "write destination")
}
@ -465,7 +466,7 @@ type clientPacketConn struct {
session *udpSession
}
func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error {
func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
if debug.Enabled {
logger.Trace("begin client packet")
}
@ -534,7 +535,7 @@ func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPo
binary.Write(header, binary.BigEndian, uint64(time.Now().Unix())),
binary.Write(header, binary.BigEndian, uint16(0)), // padding length
)
err := socks.AddressSerializer.WriteAddrPort(header, destination)
err := socks5.AddressSerializer.WriteAddrPort(header, destination)
if err != nil {
return err
}
@ -551,14 +552,16 @@ func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPo
buffer.Extend(c.session.cipher.Overhead())
c.method.udpBlockCipher.Encrypt(packetHeader, packetHeader)
}
logger.Trace("ended client packet")
if debug.Enabled {
logger.Trace("ended client packet")
}
return common.Error(c.Write(buffer.Bytes()))
}
func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
n, err := c.Read(buffer.FreeBytes())
if err != nil {
return nil, err
return M.Socksaddr{}, err
}
buffer.Truncate(n)
@ -566,7 +569,7 @@ func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
if c.method.udpCipher != nil {
_, err = c.method.udpCipher.Open(buffer.Index(PacketNonceSize), buffer.To(PacketNonceSize), buffer.From(PacketNonceSize), nil)
if err != nil {
return nil, E.Cause(err, "decrypt packet")
return M.Socksaddr{}, E.Cause(err, "decrypt packet")
}
buffer.Advance(PacketNonceSize)
} else {
@ -577,11 +580,11 @@ func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
var sessionId, packetId uint64
err = binary.Read(buffer, binary.BigEndian, &sessionId)
if err != nil {
return nil, err
return M.Socksaddr{}, err
}
err = binary.Read(buffer, binary.BigEndian, &packetId)
if err != nil {
return nil, err
return M.Socksaddr{}, err
}
var remoteCipher cipher.AEAD
@ -596,42 +599,42 @@ func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
}
_, err = remoteCipher.Open(buffer.Index(0), packetHeader[4:16], buffer.Bytes(), nil)
if err != nil {
return nil, E.Cause(err, "decrypt packet")
return M.Socksaddr{}, E.Cause(err, "decrypt packet")
}
}
var headerType byte
headerType, err = buffer.ReadByte()
if err != nil {
return nil, err
return M.Socksaddr{}, err
}
if headerType != HeaderTypeServer {
return nil, ErrBadHeaderType
return M.Socksaddr{}, ErrBadHeaderType
}
var epoch uint64
err = binary.Read(buffer, binary.BigEndian, &epoch)
if err != nil {
return nil, err
return M.Socksaddr{}, err
}
if math.Abs(float64(uint64(time.Now().Unix())-epoch)) > 30 {
return nil, ErrBadTimestamp
return M.Socksaddr{}, ErrBadTimestamp
}
if sessionId == c.session.remoteSessionId {
if !c.session.filter.ValidateCounter(packetId, math.MaxUint64) {
return nil, ErrPacketIdNotUnique
return M.Socksaddr{}, ErrPacketIdNotUnique
}
} else if sessionId == c.session.lastRemoteSessionId {
if !c.session.lastFilter.ValidateCounter(packetId, math.MaxUint64) {
return nil, ErrPacketIdNotUnique
return M.Socksaddr{}, ErrPacketIdNotUnique
}
remoteCipher = c.session.lastRemoteCipher
c.session.lastRemoteSeen = time.Now().Unix()
} else {
if c.session.remoteSessionId != 0 {
if time.Now().Unix()-c.session.lastRemoteSeen < 60 {
return nil, ErrTooManyServerSessions
return M.Socksaddr{}, ErrTooManyServerSessions
} else {
c.session.lastRemoteSessionId = c.session.remoteSessionId
c.session.lastFilter = c.session.filter
@ -648,20 +651,20 @@ func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
var clientSessionId uint64
err = binary.Read(buffer, binary.BigEndian, &clientSessionId)
if err != nil {
return nil, err
return M.Socksaddr{}, err
}
if clientSessionId != c.session.sessionId {
return nil, ErrBadClientSessionId
return M.Socksaddr{}, ErrBadClientSessionId
}
var paddingLength uint16
err = binary.Read(buffer, binary.BigEndian, &paddingLength)
if err != nil {
return nil, E.Cause(err, "read padding length")
return M.Socksaddr{}, E.Cause(err, "read padding length")
}
buffer.Advance(int(paddingLength))
return socks.AddressSerializer.ReadAddrPort(buffer)
return socks5.AddressSerializer.ReadAddrPort(buffer)
}
type udpSession struct {

View file

@ -18,12 +18,13 @@ import (
"github.com/sagernet/sing/common/debug"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/replay"
"github.com/sagernet/sing/common/rw"
"github.com/sagernet/sing/common/udpnat"
"github.com/sagernet/sing/protocol/shadowsocks"
"github.com/sagernet/sing/protocol/shadowsocks/shadowaead"
"github.com/sagernet/sing/protocol/socks"
"github.com/sagernet/sing/protocol/socks5"
wgReplay "golang.zx2c4.com/wireguard/replay"
)
@ -132,7 +133,7 @@ func (s *Service) newConnection(ctx context.Context, conn net.Conn, metadata M.M
return ErrBadTimestamp
}
destination, err := socks.AddressSerializer.ReadAddrPort(reader)
destination, err := socks5.AddressSerializer.ReadAddrPort(reader)
if err != nil {
return E.Cause(err, "read destination")
}
@ -268,7 +269,7 @@ func (c *serverConn) WriterReplaceable() bool {
return c.writer != nil
}
func (s *Service) NewPacket(conn socks.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error {
func (s *Service) NewPacket(conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error {
err := s.newPacket(conn, buffer, metadata)
if err != nil {
err = &shadowsocks.ServerPacketError{PacketConn: conn, Source: metadata.Source, Cause: err}
@ -276,7 +277,7 @@ func (s *Service) NewPacket(conn socks.PacketConn, buffer *buf.Buffer, metadata
return err
}
func (s *Service) newPacket(conn socks.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error {
func (s *Service) newPacket(conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error {
var packetHeader []byte
if s.udpCipher != nil {
_, err := s.udpCipher.Open(buffer.Index(PacketNonceSize), buffer.To(PacketNonceSize), buffer.From(PacketNonceSize), nil)
@ -358,14 +359,14 @@ process:
}
buffer.Advance(int(paddingLength))
destination, err := socks.AddressSerializer.ReadAddrPort(buffer)
destination, err := socks5.AddressSerializer.ReadAddrPort(buffer)
if err != nil {
goto returnErr
}
metadata.Destination = destination
session.remoteAddr = metadata.Source
s.udpNat.NewPacket(sessionId, func() socks.PacketWriter {
s.udpNat.NewPacket(sessionId, func() N.PacketWriter {
return &serverPacketWriter{s, conn, session}
}, buffer, metadata)
return nil
@ -373,11 +374,11 @@ process:
type serverPacketWriter struct {
*Service
socks.PacketConn
N.PacketConn
session *serverUDPSession
}
func (w *serverPacketWriter) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error {
func (w *serverPacketWriter) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
defer buffer.Release()
_header := buf.StackNew()
@ -400,7 +401,7 @@ func (w *serverPacketWriter) WritePacket(buffer *buf.Buffer, destination *M.Addr
binary.Write(header, binary.BigEndian, uint16(0)), // padding length
)
err := socks.AddressSerializer.WriteAddrPort(header, destination)
err := socks5.AddressSerializer.WriteAddrPort(header, destination)
if err != nil {
return err
}
@ -425,7 +426,7 @@ func (w *serverPacketWriter) WritePacket(buffer *buf.Buffer, destination *M.Addr
type serverUDPSession struct {
sessionId uint64
remoteSessionId uint64
remoteAddr *M.AddrPort
remoteAddr M.Socksaddr
packetId uint64
cipher cipher.AEAD
remoteCipher cipher.AEAD

View file

@ -13,10 +13,11 @@ import (
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/rw"
"github.com/sagernet/sing/protocol/shadowsocks"
"github.com/sagernet/sing/protocol/shadowsocks/shadowaead"
"github.com/sagernet/sing/protocol/socks"
"github.com/sagernet/sing/protocol/socks5"
"lukechampine.com/blake3"
)
@ -140,7 +141,7 @@ func (s *MultiService[U]) newConnection(ctx context.Context, conn net.Conn, meta
return ErrBadTimestamp
}
destination, err := socks.AddressSerializer.ReadAddrPort(reader)
destination, err := socks5.AddressSerializer.ReadAddrPort(reader)
if err != nil {
return E.Cause(err, "read destination")
}
@ -173,7 +174,7 @@ func (s *MultiService[U]) newConnection(ctx context.Context, conn net.Conn, meta
}, metadata)
}
func (s *MultiService[U]) NewPacket(conn socks.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error {
func (s *MultiService[U]) NewPacket(conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error {
err := s.newPacket(conn, buffer, metadata)
if err != nil {
err = &shadowsocks.ServerPacketError{PacketConn: conn, Source: metadata.Source, Cause: err}
@ -181,7 +182,7 @@ func (s *MultiService[U]) NewPacket(conn socks.PacketConn, buffer *buf.Buffer, m
return err
}
func (s *MultiService[U]) newPacket(conn socks.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error {
func (s *MultiService[U]) newPacket(conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error {
packetHeader := buffer.To(aes.BlockSize)
s.udpBlockCipher.Decrypt(packetHeader, packetHeader)
@ -272,7 +273,7 @@ process:
}
buffer.Advance(int(paddingLength))
destination, err := socks.AddressSerializer.ReadAddrPort(buffer)
destination, err := socks5.AddressSerializer.ReadAddrPort(buffer)
if err != nil {
goto returnErr
}
@ -284,7 +285,7 @@ process:
userCtx.Context = context.Background()
userCtx.User = user
s.udpNat.NewContextPacket(&userCtx, sessionId, func() socks.PacketWriter {
s.udpNat.NewContextPacket(&userCtx, sessionId, func() N.PacketWriter {
return &serverPacketWriter{s.Service, conn, session}
}, buffer, metadata)
return nil

View file

@ -1,75 +0,0 @@
package socks_test
import (
"net"
"sync"
"testing"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/protocol/socks"
)
func TestHandshake(t *testing.T) {
server, client := net.Pipe()
defer server.Close()
defer client.Close()
wg := new(sync.WaitGroup)
wg.Add(1)
method := socks.AuthTypeUsernamePassword
go func() {
response, err := socks.ClientHandshake(client, socks.Version5, socks.CommandConnect, M.AddrPortFrom(M.AddrFromFqdn("test"), 80), "user", "pswd")
if err != nil {
t.Fatal(err)
}
if response.ReplyCode != socks.ReplyCodeSuccess {
t.Fatal(response)
}
wg.Done()
}()
authRequest, err := socks.ReadAuthRequest(server)
if err != nil {
t.Fatal(err)
}
if len(authRequest.Methods) != 1 || authRequest.Methods[0] != method {
t.Fatal("bad methods: ", authRequest.Methods)
}
err = socks.WriteAuthResponse(server, &socks.AuthResponse{
Version: socks.Version5,
Method: method,
})
if err != nil {
t.Fatal(err)
}
usernamePasswordAuthRequest, err := socks.ReadUsernamePasswordAuthRequest(server)
if err != nil {
t.Fatal(err)
}
if usernamePasswordAuthRequest.Username != "user" || usernamePasswordAuthRequest.Password != "pswd" {
t.Fatal(authRequest)
}
err = socks.WriteUsernamePasswordAuthResponse(server, &socks.UsernamePasswordAuthResponse{
Status: socks.UsernamePasswordStatusSuccess,
})
if err != nil {
t.Fatal(err)
}
request, err := socks.ReadRequest(server)
if err != nil {
t.Fatal(err)
}
if request.Version != socks.Version5 || request.Command != socks.CommandConnect || request.Destination.Addr.Fqdn() != "test" || request.Destination.Port != 80 {
t.Fatal(request)
}
err = socks.WriteResponse(server, &socks.Response{
Version: socks.Version5,
ReplyCode: socks.ReplyCodeSuccess,
Bind: M.AddrPortFrom(M.AddrFromIP(net.IPv4zero), 0),
})
if err != nil {
t.Fatal(err)
}
wg.Wait()
}

View file

@ -1,4 +1,4 @@
package socks
package socks5
import (
"net"
@ -12,10 +12,10 @@ type AssociateConn struct {
net.Conn
conn net.Conn
addr net.Addr
dest *M.AddrPort
dest M.Socksaddr
}
func NewAssociateConn(conn net.Conn, packetConn net.Conn, destination *M.AddrPort) *AssociateConn {
func NewAssociateConn(conn net.Conn, packetConn net.Conn, destination M.Socksaddr) *AssociateConn {
return &AssociateConn{
Conn: packetConn,
conn: conn,
@ -46,7 +46,7 @@ func (c *AssociateConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
_buffer := buf.StackNew()
buffer := common.Dup(_buffer)
common.Must(buffer.WriteZeroN(3))
err = AddressSerializer.WriteAddrPort(buffer, M.AddrPortFromNetAddr(addr))
err = AddressSerializer.WriteAddrPort(buffer, M.SocksaddrFromNet(addr))
if err != nil {
return
}
@ -80,17 +80,17 @@ func (c *AssociateConn) Write(b []byte) (n int, err error) {
return
}
func (c *AssociateConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
func (c *AssociateConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
n, err := buffer.ReadFrom(c.conn)
if err != nil {
return nil, err
return M.Socksaddr{}, err
}
buffer.Truncate(int(n))
buffer.Advance(3)
return AddressSerializer.ReadAddrPort(buffer)
}
func (c *AssociateConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error {
func (c *AssociateConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
defer buffer.Release()
header := buf.With(buffer.ExtendHeader(3 + AddressSerializer.AddrPortLen(destination)))
common.Must(header.WriteZeroN(3))
@ -102,10 +102,10 @@ type AssociatePacketConn struct {
net.PacketConn
conn net.Conn
addr net.Addr
dest *M.AddrPort
dest M.Socksaddr
}
func NewAssociatePacketConn(conn net.Conn, packetConn net.PacketConn, destination *M.AddrPort) *AssociatePacketConn {
func NewAssociatePacketConn(conn net.Conn, packetConn net.PacketConn, destination M.Socksaddr) *AssociatePacketConn {
return &AssociatePacketConn{
PacketConn: packetConn,
conn: conn,
@ -137,7 +137,7 @@ func (c *AssociatePacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error
buffer := common.Dup(_buffer)
common.Must(buffer.WriteZeroN(3))
err = AddressSerializer.WriteAddrPort(buffer, M.AddrPortFromNetAddr(addr))
err = AddressSerializer.WriteAddrPort(buffer, M.SocksaddrFromNet(addr))
if err != nil {
return
}
@ -171,10 +171,10 @@ func (c *AssociatePacketConn) Write(b []byte) (n int, err error) {
return
}
func (c *AssociatePacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
func (c *AssociatePacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
n, addr, err := c.PacketConn.ReadFrom(buffer.FreeBytes())
if err != nil {
return nil, err
return M.Socksaddr{}, err
}
c.addr = addr
buffer.Truncate(n)
@ -183,7 +183,7 @@ func (c *AssociatePacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error
return dest, err
}
func (c *AssociatePacketConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error {
func (c *AssociatePacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
defer buffer.Release()
header := buf.With(buffer.ExtendHeader(3 + AddressSerializer.AddrPortLen(destination)))
common.Must(header.WriteZeroN(3))

View file

@ -1,4 +1,4 @@
package socks
package socks5
import (
"strconv"

View file

@ -1,4 +1,4 @@
package socks
package socks5
import "fmt"

View file

@ -1,4 +1,4 @@
package socks
package socks5
import (
"io"
@ -8,7 +8,7 @@ import (
M "github.com/sagernet/sing/common/metadata"
)
func ClientHandshake(conn io.ReadWriter, version byte, command byte, destination *M.AddrPort, username string, password string) (*Response, error) {
func ClientHandshake(conn io.ReadWriter, version byte, command byte, destination M.Socksaddr, username string, password string) (*Response, error) {
var method byte
if common.IsBlank(username) {
method = AuthTypeNotRequired
@ -56,7 +56,7 @@ func ClientHandshake(conn io.ReadWriter, version byte, command byte, destination
return ReadResponse(conn)
}
func ClientFastHandshake(writer io.Writer, version byte, command byte, destination *M.AddrPort, username string, password string) error {
func ClientFastHandshake(writer io.Writer, version byte, command byte, destination M.Socksaddr, username string, password string) error {
var method byte
if common.IsBlank(username) {
method = AuthTypeNotRequired

View file

@ -1,4 +1,4 @@
package socks
package socks5
import (
"context"
@ -10,12 +10,13 @@ import (
"github.com/sagernet/sing/common/auth"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/transport/tcp"
)
type Handler interface {
tcp.Handler
UDPConnectionHandler
N.UDPConnectionHandler
}
type Listener struct {
@ -36,7 +37,7 @@ func NewListener(bind netip.AddrPort, authenticator auth.Authenticator, handler
}
func (l *Listener) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
return HandleConnection(ctx, conn, l.authenticator, M.AddrPortFromNetAddr(conn.LocalAddr()).Addr.Addr(), l.handler, metadata)
return HandleConnection(ctx, conn, l.authenticator, M.AddrFromNetAddr(conn.LocalAddr()), l.handler, metadata)
}
func (l *Listener) Start() error {
@ -117,7 +118,7 @@ func handleConnection(authRequest *AuthRequest, ctx context.Context, conn net.Co
err = WriteResponse(conn, &Response{
Version: request.Version,
ReplyCode: ReplyCodeSuccess,
Bind: M.AddrPortFromNetAddr(conn.LocalAddr()),
Bind: M.SocksaddrFromNet(conn.LocalAddr()),
})
if err != nil {
return E.Cause(err, "write socks response")
@ -138,7 +139,7 @@ func handleConnection(authRequest *AuthRequest, ctx context.Context, conn net.Co
err = WriteResponse(conn, &Response{
Version: request.Version,
ReplyCode: ReplyCodeSuccess,
Bind: M.AddrPortFromNetAddr(udpConn.LocalAddr()),
Bind: M.SocksaddrFromNet(udpConn.LocalAddr()),
})
if err != nil {
return E.Cause(err, "write socks response")

View file

@ -0,0 +1 @@
package socks5

View file

@ -1,9 +1,9 @@
package socks
package socks5
import (
"bytes"
"io"
"net"
"net/netip"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
@ -203,7 +203,7 @@ func ReadUsernamePasswordAuthResponse(reader io.Reader) (*UsernamePasswordAuthRe
type Request struct {
Version byte
Command byte
Destination *M.AddrPort
Destination M.Socksaddr
}
func WriteRequest(writer io.Writer, request *Request) error {
@ -262,7 +262,7 @@ func ReadRequest(reader io.Reader) (*Request, error) {
type Response struct {
Version byte
ReplyCode ReplyCode
Bind *M.AddrPort
Bind M.Socksaddr
}
func WriteResponse(writer io.Writer, response *Response) error {
@ -278,8 +278,10 @@ func WriteResponse(writer io.Writer, response *Response) error {
if err != nil {
return err
}
if response.Bind == nil {
return AddressSerializer.WriteAddrPort(writer, M.AddrPortFrom(M.AddrFromIP(net.IPv4zero), 0))
if !response.Bind.IsValid() {
return AddressSerializer.WriteAddrPort(writer, M.Socksaddr{
Addr: netip.IPv4Unspecified(),
})
}
return AddressSerializer.WriteAddrPort(writer, response.Bind)
}
@ -320,7 +322,7 @@ func ReadResponse(reader io.Reader) (*Response, error) {
type AssociatePacket struct {
Fragment byte
Destination *M.AddrPort
Destination M.Socksaddr
Data []byte
}

View file

@ -0,0 +1,75 @@
package socks5_test
import (
"net"
"sync"
"testing"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/protocol/socks5"
)
func TestHandshake(t *testing.T) {
server, client := net.Pipe()
defer server.Close()
defer client.Close()
wg := new(sync.WaitGroup)
wg.Add(1)
method := socks5.AuthTypeUsernamePassword
go func() {
response, err := socks5.ClientHandshake(client, socks5.Version5, socks5.CommandConnect, M.AddrPortFrom(M.AddrFromFqdn("test"), 80), "user", "pswd")
if err != nil {
t.Fatal(err)
}
if response.ReplyCode != socks5.ReplyCodeSuccess {
t.Fatal(response)
}
wg.Done()
}()
authRequest, err := socks5.ReadAuthRequest(server)
if err != nil {
t.Fatal(err)
}
if len(authRequest.Methods) != 1 || authRequest.Methods[0] != method {
t.Fatal("bad methods: ", authRequest.Methods)
}
err = socks5.WriteAuthResponse(server, &socks5.AuthResponse{
Version: socks5.Version5,
Method: method,
})
if err != nil {
t.Fatal(err)
}
usernamePasswordAuthRequest, err := socks5.ReadUsernamePasswordAuthRequest(server)
if err != nil {
t.Fatal(err)
}
if usernamePasswordAuthRequest.Username != "user" || usernamePasswordAuthRequest.Password != "pswd" {
t.Fatal(authRequest)
}
err = socks5.WriteUsernamePasswordAuthResponse(server, &socks5.UsernamePasswordAuthResponse{
Status: socks5.UsernamePasswordStatusSuccess,
})
if err != nil {
t.Fatal(err)
}
request, err := socks5.ReadRequest(server)
if err != nil {
t.Fatal(err)
}
if request.Version != socks5.Version5 || request.Command != socks5.CommandConnect || request.Destination.Addr.Fqdn() != "test" || request.Destination.Port != 80 {
t.Fatal(request)
}
err = socks5.WriteResponse(server, &socks5.Response{
Version: socks5.Version5,
ReplyCode: socks5.ReplyCodeSuccess,
Bind: M.AddrPortFrom(M.AddrFromIP(net.IPv4zero), 0),
})
if err != nil {
t.Fatal(err)
}
wg.Wait()
}

View file

@ -12,7 +12,7 @@ import (
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/common/rw"
"github.com/sagernet/sing/protocol/socks"
"github.com/sagernet/sing/protocol/socks5"
)
const (
@ -26,11 +26,11 @@ var CRLF = []byte{'\r', '\n'}
type ClientConn struct {
net.Conn
key [KeyLength]byte
destination *M.AddrPort
destination M.Socksaddr
headerWritten bool
}
func NewClientConn(conn net.Conn, key [KeyLength]byte, destination *M.AddrPort) *ClientConn {
func NewClientConn(conn net.Conn, key [KeyLength]byte, destination M.Socksaddr) *ClientConn {
return &ClientConn{
Conn: conn,
key: key,
@ -75,11 +75,11 @@ func NewClientPacketConn(conn net.Conn, key [KeyLength]byte) *ClientPacketConn {
}
}
func (c *ClientPacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
func (c *ClientPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
return ReadPacket(c.Conn, buffer)
}
func (c *ClientPacketConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error {
func (c *ClientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
if !c.headerWritten {
return ClientHandshakePacket(c.Conn, c.key, destination, buffer)
}
@ -98,7 +98,7 @@ func (c *ClientPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error)
}
func (c *ClientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
err = c.WritePacket(buf.With(p), M.AddrPortFromNetAddr(addr))
err = c.WritePacket(buf.With(p), M.SocksaddrFromNet(addr))
if err == nil {
n = len(p)
}
@ -113,7 +113,7 @@ func Key(password string) [KeyLength]byte {
return key
}
func ClientHandshakeRaw(conn net.Conn, key [KeyLength]byte, command byte, destination *M.AddrPort, payload []byte) error {
func ClientHandshakeRaw(conn net.Conn, key [KeyLength]byte, command byte, destination M.Socksaddr, payload []byte) error {
_, err := conn.Write(key[:])
if err != nil {
return err
@ -126,7 +126,7 @@ func ClientHandshakeRaw(conn net.Conn, key [KeyLength]byte, command byte, destin
if err != nil {
return err
}
err = socks.AddressSerializer.WriteAddrPort(conn, destination)
err = socks5.AddressSerializer.WriteAddrPort(conn, destination)
if err != nil {
return err
}
@ -143,8 +143,8 @@ func ClientHandshakeRaw(conn net.Conn, key [KeyLength]byte, command byte, destin
return nil
}
func ClientHandshake(conn net.Conn, key [KeyLength]byte, destination *M.AddrPort, payload []byte) error {
headerLen := KeyLength + socks.AddressSerializer.AddrPortLen(destination) + 5
func ClientHandshake(conn net.Conn, key [KeyLength]byte, destination M.Socksaddr, payload []byte) error {
headerLen := KeyLength + socks5.AddressSerializer.AddrPortLen(destination) + 5
var header *buf.Buffer
var writeHeader bool
if len(payload) > 0 && headerLen+len(payload) < 65535 {
@ -158,7 +158,7 @@ func ClientHandshake(conn net.Conn, key [KeyLength]byte, destination *M.AddrPort
common.Must1(header.Write(key[:]))
common.Must1(header.Write(CRLF))
common.Must(header.WriteByte(CommandTCP))
common.Must(socks.AddressSerializer.WriteAddrPort(header, destination))
common.Must(socks5.AddressSerializer.WriteAddrPort(header, destination))
common.Must1(header.Write(CRLF))
common.Must1(header.Write(payload))
@ -176,8 +176,8 @@ func ClientHandshake(conn net.Conn, key [KeyLength]byte, destination *M.AddrPort
return nil
}
func ClientHandshakePacket(conn net.Conn, key [KeyLength]byte, destination *M.AddrPort, payload *buf.Buffer) error {
headerLen := KeyLength + 2*socks.AddressSerializer.AddrPortLen(destination) + 9
func ClientHandshakePacket(conn net.Conn, key [KeyLength]byte, destination M.Socksaddr, payload *buf.Buffer) error {
headerLen := KeyLength + 2*socks5.AddressSerializer.AddrPortLen(destination) + 9
payloadLen := payload.Len()
var header *buf.Buffer
var writeHeader bool
@ -191,9 +191,9 @@ func ClientHandshakePacket(conn net.Conn, key [KeyLength]byte, destination *M.Ad
common.Must1(header.Write(key[:]))
common.Must1(header.Write(CRLF))
common.Must(header.WriteByte(CommandUDP))
common.Must(socks.AddressSerializer.WriteAddrPort(header, destination))
common.Must(socks5.AddressSerializer.WriteAddrPort(header, destination))
common.Must1(header.Write(CRLF))
common.Must(socks.AddressSerializer.WriteAddrPort(header, destination))
common.Must(socks5.AddressSerializer.WriteAddrPort(header, destination))
common.Must(binary.Write(header, binary.BigEndian, uint16(payloadLen)))
common.Must1(header.Write(CRLF))
@ -211,33 +211,33 @@ func ClientHandshakePacket(conn net.Conn, key [KeyLength]byte, destination *M.Ad
return nil
}
func ReadPacket(conn net.Conn, buffer *buf.Buffer) (*M.AddrPort, error) {
destination, err := socks.AddressSerializer.ReadAddrPort(conn)
func ReadPacket(conn net.Conn, buffer *buf.Buffer) (M.Socksaddr, error) {
destination, err := socks5.AddressSerializer.ReadAddrPort(conn)
if err != nil {
return nil, E.Cause(err, "read destination")
return M.Socksaddr{}, E.Cause(err, "read destination")
}
var length uint16
err = binary.Read(conn, binary.BigEndian, &length)
if err != nil {
return nil, E.Cause(err, "read chunk length")
return M.Socksaddr{}, E.Cause(err, "read chunk length")
}
if buffer.FreeLen() < int(length) {
return nil, io.ErrShortBuffer
return M.Socksaddr{}, io.ErrShortBuffer
}
err = rw.SkipN(conn, 2)
if err != nil {
return nil, E.Cause(err, "skip crlf")
return M.Socksaddr{}, E.Cause(err, "skip crlf")
}
_, err = buffer.ReadFullFrom(conn, int(length))
return destination, err
}
func WritePacket(conn net.Conn, buffer *buf.Buffer, destination *M.AddrPort) error {
headerOverload := socks.AddressSerializer.AddrPortLen(destination) + 4
func WritePacket(conn net.Conn, buffer *buf.Buffer, destination M.Socksaddr) error {
headerOverload := socks5.AddressSerializer.AddrPortLen(destination) + 4
var header *buf.Buffer
var writeHeader bool
bufferLen := buffer.Len()
@ -248,7 +248,7 @@ func WritePacket(conn net.Conn, buffer *buf.Buffer, destination *M.AddrPort) err
_buffer := buf.Make(headerOverload)
header = buf.With(common.Dup(_buffer))
}
common.Must(socks.AddressSerializer.WriteAddrPort(header, destination))
common.Must(socks5.AddressSerializer.WriteAddrPort(header, destination))
common.Must(binary.Write(header, binary.BigEndian, uint16(bufferLen)))
common.Must1(header.Write(CRLF))
if writeHeader {

View file

@ -10,13 +10,14 @@ import (
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/rw"
"github.com/sagernet/sing/protocol/socks"
"github.com/sagernet/sing/protocol/socks5"
)
type Handler interface {
M.TCPConnectionHandler
socks.UDPConnectionHandler
N.UDPConnectionHandler
}
type Context[K comparable] struct {
@ -115,7 +116,7 @@ process:
goto returnErr
}
destination, err := socks.AddressSerializer.ReadAddrPort(conn)
destination, err := socks5.AddressSerializer.ReadAddrPort(conn)
if err != nil {
err = E.Cause(err, "read destination")
goto returnErr
@ -141,11 +142,11 @@ type PacketConn struct {
net.Conn
}
func (c *PacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
func (c *PacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
return ReadPacket(c.Conn, buffer)
}
func (c *PacketConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error {
func (c *PacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
return WritePacket(c.Conn, buffer, destination)
}

View file

@ -15,17 +15,18 @@ import (
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/redir"
"github.com/sagernet/sing/common/rw"
"github.com/sagernet/sing/common/udpnat"
"github.com/sagernet/sing/protocol/http"
"github.com/sagernet/sing/protocol/socks"
"github.com/sagernet/sing/protocol/socks5"
"github.com/sagernet/sing/transport/tcp"
"github.com/sagernet/sing/transport/udp"
)
type Handler interface {
socks.Handler
socks5.Handler
}
type Listener struct {
@ -53,15 +54,15 @@ func NewListener(bind netip.AddrPort, authenticator auth.Authenticator, transpro
}
func (l *Listener) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
if metadata.Destination != nil {
if metadata.Destination.IsValid() {
return l.handler.NewConnection(ctx, conn, metadata)
}
headerType, err := rw.ReadByte(conn)
switch headerType {
case socks.Version4:
case socks5.Version4:
return E.New("socks4 request dropped (TODO)")
case socks.Version5:
return socks.HandleConnection0(ctx, conn, l.authenticator, M.AddrPortFromNetAddr(conn.LocalAddr()).Addr.Addr(), l.handler, metadata)
case socks5.Version5:
return socks5.HandleConnection0(ctx, conn, l.authenticator, M.AddrFromNetAddr(conn.LocalAddr()), l.handler, metadata)
}
reader := bufio.NewReader(&rw.BufferedReader{
@ -75,7 +76,7 @@ func (l *Listener) NewConnection(ctx context.Context, conn net.Conn, metadata M.
}
if request.Method == "GET" && request.URL.Path == "/proxy.pac" {
content := newPAC(M.AddrPortFromNetAddr(conn.LocalAddr()))
content := newPAC(M.AddrPortFromNet(conn.LocalAddr()))
response := &netHttp.Response{
StatusCode: 200,
Status: netHttp.StatusText(200),
@ -113,8 +114,8 @@ func (l *Listener) NewConnection(ctx context.Context, conn net.Conn, metadata M.
return http.HandleRequest(ctx, request, conn, l.authenticator, l.handler, metadata)
}
func (l *Listener) NewPacket(conn socks.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error {
l.udpNat.NewPacket(metadata.Source.AddrPort(), func() socks.PacketWriter {
func (l *Listener) NewPacket(conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error {
l.udpNat.NewPacket(metadata.Source.AddrPort(), func() N.PacketWriter {
return &tproxyPacketWriter{metadata.Source.UDPAddr()}
}, buffer, metadata)
return nil
@ -124,7 +125,7 @@ type tproxyPacketWriter struct {
source *net.UDPAddr
}
func (w *tproxyPacketWriter) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error {
func (w *tproxyPacketWriter) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
udpConn, err := redir.DialUDP("udp", destination.UDPAddr(), w.source)
if err != nil {
return E.Cause(err, "tproxy udp write back")

View file

@ -1,8 +1,10 @@
package mixed
import M "github.com/sagernet/sing/common/metadata"
import (
"net/netip"
)
/*func newPAC(proxyAddr *M.AddrPort) string {
/*func newPAC(proxyAddr M.Socksaddr) string {
return `
function FindProxyForURL(url, host) {
return "SOCKS5 ` + proxyAddr.String() + `;SOCKS ` + proxyAddr.String() + `; PROXY ` + proxyAddr.String() + `";
@ -10,7 +12,7 @@ function FindProxyForURL(url, host) {
}
*/
func newPAC(proxyAddr *M.AddrPort) string {
func newPAC(proxyAddr netip.AddrPort) string {
// TODO: socks4 not supported
return `
function FindProxyForURL(url, host) {

View file

@ -89,14 +89,14 @@ func (l *Listener) loop() {
return
}
metadata := M.Metadata{
Source: M.AddrPortFromNetAddr(tcpConn.RemoteAddr()),
Source: M.SocksaddrFromNet(tcpConn.RemoteAddr()),
}
switch l.trans {
case redir.ModeRedirect:
destination, err := redir.GetOriginalDestination(tcpConn)
if err == nil {
metadata.Protocol = "redirect"
metadata.Destination = destination
metadata.Destination = M.SocksaddrFromNetIP(destination)
}
case redir.ModeTProxy:
lAddr := tcpConn.LocalAddr().(*net.TCPAddr)
@ -104,7 +104,7 @@ func (l *Listener) loop() {
if lAddr.Port != l.lAddr.Port || !lAddr.IP.Equal(rAddr.IP) && !lAddr.IP.IsLoopback() && !lAddr.IP.IsPrivate() {
metadata.Protocol = "tproxy"
metadata.Destination = M.AddrPortFromNetAddr(lAddr)
metadata.Destination = M.SocksaddrFromNet(lAddr)
}
}
go func() {

View file

@ -8,12 +8,12 @@ import (
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/redir"
"github.com/sagernet/sing/protocol/socks"
)
type Handler interface {
socks.UDPHandler
N.UDPHandler
E.Handler
}
@ -25,17 +25,24 @@ type Listener struct {
tproxy bool
}
func (l *Listener) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
n, addr, err := l.ReadFromUDP(buffer.FreeBytes())
func (l *Listener) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
n, addr, err := l.ReadFromUDPAddrPort(buffer.FreeBytes())
if err != nil {
return nil, err
return M.Socksaddr{}, err
}
buffer.Truncate(n)
return M.AddrPortFromNetAddr(addr), nil
return M.SocksaddrFromNetIP(addr), nil
}
func (l *Listener) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error {
return common.Error(l.UDPConn.WriteTo(buffer.Bytes(), destination.UDPAddr()))
func (l *Listener) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
if destination.Family().IsFqdn() {
udpAddr, err := net.ResolveUDPAddr("udp", destination.String())
if err != nil {
return err
}
return common.Error(l.UDPConn.WriteTo(buffer.Bytes(), udpAddr))
}
return common.Error(l.UDPConn.WriteToUDPAddrPort(buffer.Bytes(), destination.AddrPort()))
}
func NewUDPListener(listen netip.AddrPort, handler Handler, options ...Option) *Listener {
@ -88,7 +95,7 @@ func (l *Listener) loop() {
data := buffer.Cut(buf.ReversedHeader, buf.ReversedHeader).Slice()
if !l.tproxy {
for {
n, addr, err := l.ReadFromUDP(data)
n, addr, err := l.ReadFromUDPAddrPort(data)
if err != nil {
l.handler.HandleError(err)
return
@ -96,7 +103,7 @@ func (l *Listener) loop() {
buffer.Resize(buf.ReversedHeader, n)
err = l.handler.NewPacket(l, buffer, M.Metadata{
Protocol: "udp",
Source: M.AddrPortFromNetAddr(addr),
Source: M.SocksaddrFromNetIP(addr),
})
if err != nil {
l.handler.HandleError(err)
@ -119,8 +126,8 @@ func (l *Listener) loop() {
buffer.Resize(buf.ReversedHeader, n)
err = l.handler.NewPacket(l, buffer, M.Metadata{
Protocol: "tproxy",
Source: M.AddrPortFromAddrPort(addr),
Destination: destination,
Source: M.SocksaddrFromNetIP(addr),
Destination: M.SocksaddrFromNetIP(destination),
})
if err != nil {
l.handler.HandleError(err)