diff --git a/README.md b/README.md index d98b1d6..61cf930 100644 --- a/README.md +++ b/README.md @@ -28,6 +28,12 @@ wget 'https://github.com/Dreamacro/maxmind-geoip/releases/latest/download/Countr go install -v -trimpath -ldflags "-s -w -buildid=" ./cli/ss-local ``` +### ss-server + +```shell +go install -v -trimpath -ldflags "-s -w -buildid=" ./cli/ss-server +``` + ### ddns ```shell diff --git a/cli/socks-chk/main.go b/cli/socks-chk/main.go new file mode 100644 index 0000000..b9bc9b6 --- /dev/null +++ b/cli/socks-chk/main.go @@ -0,0 +1,145 @@ +package main + +import ( + "context" + "encoding/binary" + "io" + "net" + "net/netip" + + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" + M "github.com/sagernet/sing/common/metadata" + "github.com/sagernet/sing/protocol/socks" + "github.com/sirupsen/logrus" + "github.com/spf13/cobra" + "golang.org/x/net/dns/dnsmessage" +) + +func main() { + command := &cobra.Command{ + Use: "socks-chk address:port", + Args: cobra.ExactArgs(1), + Run: run, + } + if err := command.Execute(); err != nil { + logrus.Fatal(err) + } +} + +func run(cmd *cobra.Command, args []string) { + server, err := M.ParseAddress(args[0]) + if err != nil { + logrus.Fatal("invalid server address ", args[0]) + } + err = testSocksTCP(server) + if err != nil { + logrus.Fatal(err) + } + err = testSocksUDP(server) + if err != nil { + logrus.Fatal(err) + } +} + +func testSocksTCP(server *M.AddrPort) error { + tcpConn, err := net.Dial("tcp", server.String()) + if err != nil { + return err + } + response, err := socks.ClientHandshake(tcpConn, socks.Version5, socks.CommandConnect, M.AddrPortFrom(M.ParseAddr("1.0.0.1"), 53), "", "") + if err != nil { + return err + } + if response.ReplyCode != socks.ReplyCodeSuccess { + logrus.Fatal("socks tcp handshake failure: ", response.ReplyCode) + } + + message := &dnsmessage.Message{} + message.Header.ID = 1 + message.Header.RecursionDesired = true + message.Questions = append(message.Questions, dnsmessage.Question{ + Name: dnsmessage.MustNewName("google.com."), + Type: dnsmessage.TypeA, + Class: dnsmessage.ClassINET, + }) + packet, err := message.Pack() + + err = binary.Write(tcpConn, binary.BigEndian, uint16(len(packet))) + if err != nil { + return err + } + + _, err = tcpConn.Write(packet) + if err != nil { + return err + } + + var respLen uint16 + err = binary.Read(tcpConn, binary.BigEndian, &respLen) + if err != nil { + return err + } + + respBuf := buf.Make(int(respLen)) + _, err = io.ReadFull(tcpConn, respBuf) + if err != nil { + return err + } + + common.Must(message.Unpack(respBuf)) + for _, answer := range message.Answers { + logrus.Info("tcp got answer: ", netip.AddrFrom4(answer.Body.(*dnsmessage.AResource).A)) + } + + tcpConn.Close() + + return nil +} + +func testSocksUDP(server *M.AddrPort) error { + tcpConn, err := net.Dial("tcp", server.String()) + if err != nil { + return err + } + dest := M.AddrPortFrom(M.ParseAddr("1.0.0.1"), 53) + response, err := socks.ClientHandshake(tcpConn, socks.Version5, socks.CommandUDPAssociate, dest, "", "") + if err != nil { + return err + } + if response.ReplyCode != socks.ReplyCodeSuccess { + logrus.Fatal("socks tcp handshake failure: ", response.ReplyCode) + } + var dialer net.Dialer + udpConn, err := dialer.DialContext(context.Background(), "udp", response.Bind.String()) + if err != nil { + return err + } + assConn := socks.NewAssociateConn(tcpConn, udpConn, dest) + message := &dnsmessage.Message{} + message.Header.ID = 1 + message.Header.RecursionDesired = true + message.Questions = append(message.Questions, dnsmessage.Question{ + Name: dnsmessage.MustNewName("google.com."), + Type: dnsmessage.TypeA, + Class: dnsmessage.ClassINET, + }) + packet, err := message.Pack() + common.Must(err) + common.Must1(assConn.WriteTo(packet, &net.UDPAddr{ + IP: net.IPv4(1, 0, 0, 1), + Port: 53, + })) + _buffer := buf.StackNew() + buffer := common.Dup(_buffer) + common.Must2(buffer.ReadPacketFrom(assConn)) + common.Must(message.Unpack(buffer.Bytes())) + + for _, answer := range message.Answers { + logrus.Info("udp got answer: ", netip.AddrFrom4(answer.Body.(*dnsmessage.AResource).A)) + } + + udpConn.Close() + tcpConn.Close() + return nil +} diff --git a/cli/ss-local/main.go b/cli/ss-local/main.go index 80ef7e4..671bac8 100644 --- a/cli/ss-local/main.go +++ b/cli/ss-local/main.go @@ -345,28 +345,15 @@ func (c *client) NewConnection(ctx context.Context, conn net.Conn, metadata M.Me return rw.CopyConn(ctx, serverConn, conn) } -func (c *client) NewPacketConnection(conn socks.PacketConn, _ M.Metadata) error { +func (c *client) NewPacketConnection(conn socks.PacketConn, metadata M.Metadata) error { + logrus.Info("outbound ", metadata.Protocol, " UDP ", metadata.Source, " ==> ", metadata.Destination) ctx := context.Background() udpConn, err := c.dialer.DialContext(ctx, "udp", c.server.String()) if err != nil { return err } serverConn := c.method.DialPacketConn(udpConn) - return task.Run(ctx, func() error { - var init bool - return socks.CopyPacketConn0(serverConn, conn, func(destination *M.AddrPort, n int) { - if !init { - init = true - logrus.Info("UDP ", conn.LocalAddr(), " ==> ", destination) - } else { - logrus.Trace("UDP ", conn.LocalAddr(), " ==> ", destination) - } - }) - }, func() error { - return socks.CopyPacketConn0(conn, serverConn, func(destination *M.AddrPort, n int) { - logrus.Trace("UDP ", conn.LocalAddr(), " <== ", destination) - }) - }) + return socks.CopyPacketConn(ctx, serverConn, conn) } func run(cmd *cobra.Command, flags *flags) { diff --git a/cli/ss-server/main.go b/cli/ss-server/main.go index d6b493f..a0cf28f 100644 --- a/cli/ss-server/main.go +++ b/cli/ss-server/main.go @@ -3,6 +3,8 @@ package main import ( "context" "encoding/base64" + "encoding/json" + "io/ioutil" "net" "net/netip" "os" @@ -12,22 +14,29 @@ import ( "github.com/sagernet/sing" "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/gsync" M "github.com/sagernet/sing/common/metadata" "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/common/random" "github.com/sagernet/sing/common/rw" "github.com/sagernet/sing/protocol/shadowsocks" + "github.com/sagernet/sing/protocol/shadowsocks/shadowaead" "github.com/sagernet/sing/protocol/shadowsocks/shadowaead_2022" + "github.com/sagernet/sing/protocol/socks" "github.com/sagernet/sing/transport/tcp" + "github.com/sagernet/sing/transport/udp" "github.com/sirupsen/logrus" "github.com/spf13/cobra" ) type flags struct { - Bind string `json:"local_address"` - LocalPort uint16 `json:"local_port"` - // Password string `json:"password"` + Server string `json:"server"` + ServerPort uint16 `json:"server_port"` + Bind string `json:"local_address"` + LocalPort uint16 `json:"local_port"` + Password string `json:"password"` Key string `json:"key"` Method string `json:"method"` Verbose bool `json:"verbose"` @@ -35,8 +44,6 @@ type flags struct { } func main() { - logrus.SetLevel(logrus.TraceLevel) - f := new(flags) command := &cobra.Command{ @@ -48,12 +55,16 @@ func main() { }, } + command.Flags().StringVarP(&f.Server, "server", "s", "", "Set the server’s hostname or IP.") + command.Flags().Uint16VarP(&f.ServerPort, "server-port", "p", 0, "Set the server’s port number.") command.Flags().StringVarP(&f.Bind, "local-address", "b", "", "Set the local address.") command.Flags().Uint16VarP(&f.LocalPort, "local-port", "l", 0, "Set the local port number.") - command.Flags().StringVarP(&f.Key, "key", "k", "", "Set the key directly. The key should be encoded with URL-safe Base64.") + command.Flags().StringVar(&f.Key, "key", "", "Set the key directly. The key should be encoded with URL-safe Base64.") + command.Flags().StringVarP(&f.Password, "password", "k", "", "Set the password. The server and the client should use the same password.") var supportedCiphers []string supportedCiphers = append(supportedCiphers, shadowsocks.MethodNone) + supportedCiphers = append(supportedCiphers, shadowaead.List...) supportedCiphers = append(supportedCiphers, shadowaead_2022.List...) command.Flags().StringVarP(&f.Method, "encrypt-method", "m", "", "Set the cipher.\n\nSupported ciphers:\n\n"+strings.Join(supportedCiphers, "\n")) @@ -75,6 +86,10 @@ func run(cmd *cobra.Command, f *flags) { if err != nil { logrus.Fatal(err) } + err = s.udpIn.Start() + if err != nil { + logrus.Fatal(err) + } logrus.Info("server started at ", s.tcpIn.TCPListener.Addr()) osSignals := make(chan os.Signal, 1) @@ -82,33 +97,101 @@ func run(cmd *cobra.Command, f *flags) { <-osSignals s.tcpIn.Close() + s.udpIn.Close() } type server struct { tcpIn *tcp.Listener + udpIn *udp.Listener service shadowsocks.Service + udpNat gsync.Map[string, *net.UDPConn] +} + +func (s *server) Start() error { + err := s.tcpIn.Start() + if err != nil { + return err + } + err = s.udpIn.Start() + return err +} + +func (s *server) Close() error { + s.tcpIn.Close() + s.udpIn.Close() + return nil } func newServer(f *flags) (*server, error) { s := new(server) + if f.ConfigFile != "" { + configFile, err := ioutil.ReadFile(f.ConfigFile) + if err != nil { + return nil, E.Cause(err, "read config file") + } + flagsNew := new(flags) + err = json.Unmarshal(configFile, flagsNew) + if err != nil { + return nil, E.Cause(err, "decode config file") + } + if flagsNew.Server != "" && f.Server == "" { + f.Server = flagsNew.Server + } + if flagsNew.ServerPort != 0 && f.ServerPort == 0 { + f.ServerPort = flagsNew.ServerPort + } + if flagsNew.Bind != "" && f.Bind == "" { + f.Bind = flagsNew.Bind + } + if flagsNew.LocalPort != 0 && f.LocalPort == 0 { + f.LocalPort = flagsNew.LocalPort + } + if flagsNew.Password != "" && f.Password == "" { + f.Password = flagsNew.Password + } + if flagsNew.Key != "" && f.Key == "" { + f.Key = flagsNew.Key + } + if flagsNew.Method != "" && f.Method == "" { + f.Method = flagsNew.Method + } + if flagsNew.Verbose { + f.Verbose = true + } + } + + if f.Verbose { + logrus.SetLevel(logrus.TraceLevel) + } + + if f.Server == "" { + return nil, E.New("missing server address") + } else if f.ServerPort == 0 { + return nil, E.New("missing server port") + } else if f.Method == "" { + return nil, E.New("missing method") + } + + var key []byte + if f.Key != "" { + kb, err := base64.StdEncoding.DecodeString(f.Key) + if err != nil { + return nil, E.Cause(err, "decode key") + } + key = kb + } + if f.Method == shadowsocks.MethodNone { s.service = shadowsocks.NewNoneService(s) - } else if common.Contains(shadowaead_2022.List, f.Method) { - var pskList [][]byte - if f.Key != "" { - keyStrList := strings.Split(f.Key, ":") - pskList = make([][]byte, len(keyStrList)) - for i, keyStr := range keyStrList { - key, err := base64.StdEncoding.DecodeString(keyStr) - if err != nil { - return nil, E.Cause(err, "decode key") - } - pskList[i] = key - } + } else if common.Contains(shadowaead.List, f.Method) { + service, err := shadowaead.NewService(f.Method, key, []byte(f.Password), random.Blake3KeyedHash(), false, s) + if err != nil { + return nil, err } - rng := random.System - service, err := shadowaead_2022.NewService(f.Method, pskList[0], rng, s) + s.service = service + } else if common.Contains(shadowaead_2022.List, f.Method) { + service, err := shadowaead_2022.NewService(f.Method, key, random.Blake3KeyedHash(), s) if err != nil { return nil, err } @@ -118,16 +201,17 @@ func newServer(f *flags) (*server, error) { } var bind netip.Addr - if f.Bind != "" { - addr, err := netip.ParseAddr(f.Bind) + if f.Server != "" { + addr, err := netip.ParseAddr(f.Server) if err != nil { - return nil, E.Cause(err, "bad local address") + return nil, E.Cause(err, "bad server address") } bind = addr } else { bind = netip.IPv6Unspecified() } - s.tcpIn = tcp.NewTCPListener(netip.AddrPortFrom(bind, f.LocalPort), s) + s.tcpIn = tcp.NewTCPListener(netip.AddrPortFrom(bind, f.ServerPort), s) + s.udpIn = udp.NewUDPListener(netip.AddrPortFrom(bind, f.ServerPort), s) return s, nil } @@ -143,6 +227,19 @@ func (s *server) NewConnection(ctx context.Context, conn net.Conn, metadata M.Me return rw.CopyConn(ctx, conn, destConn) } +func (s *server) NewPacketConnection(conn socks.PacketConn, metadata M.Metadata) error { + logrus.Info("inbound UDP ", metadata.Source, " ==> ", metadata.Destination) + udpConn, err := net.ListenUDP("udp", nil) + if err != nil { + return err + } + return socks.CopyNetPacketConn(context.Background(), udpConn, conn) +} + +func (s *server) NewPacket(conn socks.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error { + return s.service.NewPacket(conn, buffer, metadata) +} + func (s *server) HandleError(err error) { if E.IsClosed(err) { return diff --git a/cli/uot-local/main.go b/cli/uot-local/main.go index 612355d..67e32f5 100644 --- a/cli/uot-local/main.go +++ b/cli/uot-local/main.go @@ -14,7 +14,6 @@ import ( M "github.com/sagernet/sing/common/metadata" "github.com/sagernet/sing/common/redir" "github.com/sagernet/sing/common/rw" - "github.com/sagernet/sing/common/task" "github.com/sagernet/sing/common/uot" "github.com/sagernet/sing/protocol/socks" "github.com/sagernet/sing/transport/mixed" @@ -122,15 +121,7 @@ func (c *localClient) NewPacketConnection(conn socks.PacketConn, _ M.Metadata) e } client := uot.NewClientConn(upstream) - return task.Run(context.Background(), func() error { - return socks.CopyPacketConn0(client, conn, func(destination *M.AddrPort, n int) { - logrus.Trace("UDP ", conn.LocalAddr(), " ==> ", destination) - }) - }, func() error { - return socks.CopyPacketConn0(conn, client, func(destination *M.AddrPort, n int) { - logrus.Trace("UDP ", conn.LocalAddr(), " <== ", destination) - }) - }) + return socks.CopyPacketConn(context.Background(), client, conn) } func (c *localClient) OnError(err error) { diff --git a/common/buf/buffer.go b/common/buf/buffer.go index 89879bc..2553d60 100644 --- a/common/buf/buffer.go +++ b/common/buf/buffer.go @@ -104,11 +104,15 @@ func (b *Buffer) ExtendHeader(size int) []byte { b.start -= size return b.data[b.start-size : b.start] } else { - offset := size - b.start + /*offset := size - b.start end := b.end + size + if end > len(b.data) { + panic("buffer overflow") + } copy(b.data[offset:end], b.data[b.start:b.end]) b.end = end - return b.data[:offset] + return b.data[:offset]*/ + panic("no header available") } } @@ -119,10 +123,18 @@ func (b *Buffer) WriteBufferAtFirst(buffer *Buffer) *Buffer { b.start -= n buffer.Release() return b + } else if buffer.FreeLen() >= b.Len() { + common.Must1(buffer.Write(b.Bytes())) + b.Release() + return buffer + } else if b.FreeLen() >= size { + copy(b.data[b.start+size:b.end+size], b.data[b.start:b.end]) + copy(b.data, buffer.data) + buffer.Release() + return b + } else { + panic("buffer overflow") } - common.Must1(buffer.Write(b.Bytes())) - b.Release() - return buffer } func (b *Buffer) WriteAtFirst(data []byte) (n int, err error) { @@ -305,6 +317,10 @@ func (b *Buffer) Cut(start int, end int) *Buffer { } } +func (b Buffer) Start() int { + return b.start +} + func (b Buffer) Len() int { return b.end - b.start } diff --git a/common/metadata/metadata.go b/common/metadata/metadata.go index 3278850..f91646a 100644 --- a/common/metadata/metadata.go +++ b/common/metadata/metadata.go @@ -3,8 +3,6 @@ package metadata import ( "context" "net" - - "github.com/sagernet/sing/common/buf" ) type Metadata struct { @@ -16,7 +14,3 @@ type Metadata struct { type TCPConnectionHandler interface { NewConnection(ctx context.Context, conn net.Conn, metadata Metadata) error } - -type UDPHandler interface { - NewPacket(packet *buf.Buffer, metadata Metadata) error -} diff --git a/common/metadata/serializer.go b/common/metadata/serializer.go index 045966d..ef764e7 100644 --- a/common/metadata/serializer.go +++ b/common/metadata/serializer.go @@ -55,28 +55,43 @@ func (s *Serializer) WriteAddress(writer io.Writer, addr Addr) error { return err } +func (s *Serializer) AddressLen(addr Addr) int { + switch addr.Family() { + case AddressFamilyIPv4: + return 5 + case AddressFamilyIPv6: + return 17 + default: + return 1 + len(addr.Fqdn()) + } +} + func (s *Serializer) WritePort(writer io.Writer, port uint16) error { return binary.Write(writer, binary.BigEndian, port) } -func (s *Serializer) WriteAddrPort(writer io.Writer, addrPort *AddrPort) error { +func (s *Serializer) WriteAddrPort(writer io.Writer, destination *AddrPort) error { var err error if !s.portFirst { - err = s.WriteAddress(writer, addrPort.Addr) + err = s.WriteAddress(writer, destination.Addr) } else { - err = s.WritePort(writer, addrPort.Port) + err = s.WritePort(writer, destination.Port) } if err != nil { return err } if s.portFirst { - err = s.WriteAddress(writer, addrPort.Addr) + err = s.WriteAddress(writer, destination.Addr) } else { - err = s.WritePort(writer, addrPort.Port) + err = s.WritePort(writer, destination.Port) } return err } +func (s *Serializer) AddrPortLen(destination *AddrPort) int { + return s.AddressLen(destination.Addr) + 2 +} + func (s *Serializer) ReadAddress(reader io.Reader) (Addr, error) { af, err := rw.ReadByte(reader) if err != nil { @@ -120,7 +135,7 @@ func (s *Serializer) ReadPort(reader io.Reader) (uint16, error) { return binary.BigEndian.Uint16(port), nil } -func (s *Serializer) ReadAddrPort(reader io.Reader) (addrPort *AddrPort, err error) { +func (s *Serializer) ReadAddrPort(reader io.Reader) (destination *AddrPort, err error) { var addr Addr var port uint16 if !s.portFirst { diff --git a/common/network/dialer.go b/common/network/dialer.go index efb1411..fe1b359 100644 --- a/common/network/dialer.go +++ b/common/network/dialer.go @@ -17,6 +17,17 @@ type DefaultDialer struct { net.Dialer } +func (d *DefaultDialer) ListenUDP(network string, laddr *net.UDPAddr) (*net.UDPConn, error) { + return net.ListenUDP(network, laddr) +} + func (d *DefaultDialer) DialContext(ctx context.Context, network string, address *M.AddrPort) (net.Conn, error) { return d.Dialer.DialContext(ctx, network, address.String()) } + +type Listener interface { + Listen(ctx context.Context, network, address string) (net.Listener, error) + ListenPacket(ctx context.Context, network, address string) (net.PacketConn, error) +} + +var SystemListener Listener = &net.ListenConfig{} diff --git a/common/udpnat/server.go b/common/udpnat/server.go deleted file mode 100644 index 41938b6..0000000 --- a/common/udpnat/server.go +++ /dev/null @@ -1,108 +0,0 @@ -package udpnat - -import ( - "io" - "net" - - "github.com/sagernet/sing/common" - "github.com/sagernet/sing/common/buf" - E "github.com/sagernet/sing/common/exceptions" - "github.com/sagernet/sing/common/gsync" - M "github.com/sagernet/sing/common/metadata" - "github.com/sagernet/sing/common/redir" - "github.com/sagernet/sing/protocol/socks" -) - -type Handler interface { - socks.UDPConnectionHandler - E.Handler -} - -type Server struct { - udpNat gsync.Map[string, *packetConn] - handler Handler -} - -func NewServer(handler Handler) *Server { - return &Server{handler: handler} -} - -func (s *Server) HandleUDP(buffer *buf.Buffer, metadata M.Metadata) error { - conn, loaded := s.udpNat.LoadOrStore(metadata.Source.String(), func() *packetConn { - return &packetConn{source: metadata.Source.UDPAddr(), in: make(chan *udpPacket)} - }) - if !loaded { - go func() { - err := s.handler.NewPacketConnection(conn, metadata) - if err != nil { - s.handler.HandleError(err) - } - }() - } - conn.in <- &udpPacket{ - buffer: buffer, - destination: metadata.Destination, - } - return nil -} - -func (s *Server) OnError(err error) { - s.handler.HandleError(err) -} - -func (s *Server) Close() error { - s.udpNat.Range(func(key string, conn *packetConn) bool { - conn.Close() - return true - }) - s.udpNat = gsync.Map[string, *packetConn]{} - return nil -} - -type packetConn struct { - socks.PacketConnStub - source *net.UDPAddr - in chan *udpPacket -} - -type udpPacket struct { - buffer *buf.Buffer - destination *M.AddrPort -} - -func (c *packetConn) LocalAddr() net.Addr { - return c.source -} - -func (c *packetConn) Close() error { - select { - case <-c.in: - return io.ErrClosedPipe - default: - close(c.in) - } - return nil -} - -func (c *packetConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) { - select { - case packet, ok := <-c.in: - if !ok { - return nil, io.ErrClosedPipe - } - defer packet.buffer.Release() - if buffer.FreeLen() < packet.buffer.Len() { - return nil, io.ErrShortBuffer - } - return packet.destination, common.Error(buffer.Write(packet.buffer.Bytes())) - } -} - -func (c *packetConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error { - udpConn, err := redir.DialUDP("udp", destination.UDPAddr(), c.source) - if err != nil { - return E.Cause(err, "tproxy udp write back") - } - defer udpConn.Close() - return common.Error(udpConn.Write(buffer.Bytes())) -} diff --git a/common/udpnat/service.go b/common/udpnat/service.go new file mode 100644 index 0000000..7634c43 --- /dev/null +++ b/common/udpnat/service.go @@ -0,0 +1,123 @@ +package udpnat + +import ( + "context" + "io" + "net" + "os" + "time" + + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/gsync" + M "github.com/sagernet/sing/common/metadata" + "github.com/sagernet/sing/protocol/socks" +) + +type Handler interface { + socks.UDPConnectionHandler + E.Handler +} + +type Service[K comparable] struct { + nat gsync.Map[K, *conn] + handler Handler +} + +func New[T comparable](handler Handler) *Service[T] { + return &Service[T]{ + handler: handler, + } +} + +func (s *Service[T]) NewPacket(key T, writer func() socks.PacketWriter, buffer *buf.Buffer, metadata M.Metadata) error { + c, loaded := s.nat.LoadOrStore(key, func() *conn { + c := &conn{ + data: make(chan packet), + remoteAddr: metadata.Source.UDPAddr(), + source: writer(), + } + c.ctx, c.cancel = context.WithCancel(context.Background()) + return c + }) + if !loaded { + go func() { + err := s.handler.NewPacketConnection(c, metadata) + if err != nil { + s.handler.HandleError(err) + } + }() + } + ctx, done := context.WithCancel(c.ctx) + p := packet{ + done: done, + data: buffer, + destination: metadata.Destination, + } + c.data <- p + <-ctx.Done() + return nil +} + +type packet struct { + data *buf.Buffer + destination *M.AddrPort + done context.CancelFunc +} + +type conn struct { + ctx context.Context + cancel context.CancelFunc + data chan packet + remoteAddr *net.UDPAddr + source socks.PacketWriter +} + +func (c *conn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) { + select { + case p, ok := <-c.data: + if !ok { + return nil, io.ErrClosedPipe + } + defer p.data.Release() + _, err := buffer.ReadFrom(p.data) + p.done() + return p.destination, err + } +} + +func (c *conn) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error { + return c.source.WritePacket(buffer, destination) +} + +func (c *conn) Close() error { + c.cancel() + select { + case <-c.data: + return os.ErrClosed + default: + close(c.data) + return nil + } +} + +func (c *conn) LocalAddr() net.Addr { + return &common.DummyAddr{} +} + +func (c *conn) RemoteAddr() net.Addr { + return c.remoteAddr +} + +func (c *conn) SetDeadline(t time.Time) error { + return nil +} + +func (c *conn) SetReadDeadline(t time.Time) error { + return nil +} + +func (c *conn) SetWriteDeadline(t time.Time) error { + return nil +} diff --git a/common/uot/uot_test.go b/common/uot/uot_test.go index 337088d..7eb960f 100644 --- a/common/uot/uot_test.go +++ b/common/uot/uot_test.go @@ -30,7 +30,6 @@ func TestServerConn(t *testing.T) { Port: 53, })) _buffer := buf.StackNew() - common.Use(_buffer) buffer := common.Dup(_buffer) common.Must2(buffer.ReadPacketFrom(clientConn)) common.Must(message.Unpack(buffer.Bytes())) diff --git a/protocol/shadowsocks/service.go b/protocol/shadowsocks/service.go index 071791a..8c204e7 100644 --- a/protocol/shadowsocks/service.go +++ b/protocol/shadowsocks/service.go @@ -4,32 +4,35 @@ import ( "context" "net" + "github.com/sagernet/sing/common/buf" + E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" + "github.com/sagernet/sing/common/udpnat" "github.com/sagernet/sing/protocol/socks" ) type Service interface { M.TCPConnectionHandler -} - -type MultiUserService interface { - Service - AddUser(key []byte) - RemoveUser(key []byte) + socks.UDPHandler } type Handler interface { M.TCPConnectionHandler + socks.UDPConnectionHandler + E.Handler } type NoneService struct { handler Handler + udp *udpnat.Service[string] } func NewNoneService(handler Handler) Service { - return &NoneService{ + s := &NoneService{ handler: handler, } + s.udp = udpnat.New[string](s) + return s } func (s *NoneService) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error { @@ -41,3 +44,37 @@ func (s *NoneService) NewConnection(ctx context.Context, conn net.Conn, metadata metadata.Destination = destination return s.handler.NewConnection(ctx, conn, metadata) } + +func (s *NoneService) NewPacket(conn socks.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error { + destination, err := socks.AddressSerializer.ReadAddrPort(buffer) + if err != nil { + return err + } + metadata.Protocol = "shadowsocks" + metadata.Destination = destination + return s.udp.NewPacket(metadata.Source.String(), func() socks.PacketWriter { + return &serverPacketWriter{conn, metadata.Source} + }, buffer, metadata) +} + +type serverPacketWriter struct { + socks.PacketConn + sourceAddr *M.AddrPort +} + +func (s *serverPacketWriter) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error { + header := buf.With(buffer.ExtendHeader(socks.AddressSerializer.AddrPortLen(destination))) + err := socks.AddressSerializer.WriteAddrPort(header, destination) + if err != nil { + return err + } + return s.PacketConn.WritePacket(buffer, s.sourceAddr) +} + +func (s *NoneService) NewPacketConnection(conn socks.PacketConn, metadata M.Metadata) error { + return s.handler.NewPacketConnection(conn, metadata) +} + +func (s *NoneService) HandleError(err error) { + s.handler.HandleError(err) +} diff --git a/protocol/shadowsocks/shadowaead/protocol.go b/protocol/shadowsocks/shadowaead/protocol.go index ca27969..5867417 100644 --- a/protocol/shadowsocks/shadowaead/protocol.go +++ b/protocol/shadowsocks/shadowaead/protocol.go @@ -156,13 +156,13 @@ func (m *Method) DialEarlyConn(conn net.Conn, destination *M.AddrPort) net.Conn } func (m *Method) DialPacketConn(conn net.Conn) socks.PacketConn { - return &clientPacketConn{conn, m} + return &clientPacketConn{m, conn} } func (m *Method) EncodePacket(buffer *buf.Buffer) error { key := Kdf(m.key, buffer.To(m.keySaltLength), m.keySaltLength) c := m.constructor(common.Dup(key)) - c.Seal(buffer.From(m.keySaltLength)[:0], rw.ZeroBytes[:c.NonceSize()], buffer.From(m.keySaltLength), nil) + c.Seal(buffer.Index(m.keySaltLength), rw.ZeroBytes[:c.NonceSize()], buffer.From(m.keySaltLength), nil) buffer.Extend(c.Overhead()) return nil } @@ -299,20 +299,18 @@ func (c *clientConn) ReadFrom(r io.Reader) (n int64, err error) { } type clientPacketConn struct { + *Method net.Conn - method *Method } func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error { - _header := buf.StackNew() - header := common.Dup(_header) - common.Must1(header.ReadFullFrom(c.method.secureRNG, c.method.keySaltLength)) - err := socks.AddressSerializer.WriteAddrPort(header, destination) + header := buffer.ExtendHeader(c.keySaltLength + socks.AddressSerializer.AddrPortLen(destination)) + common.Must1(io.ReadFull(c.secureRNG, header[:c.keySaltLength])) + err := socks.AddressSerializer.WriteAddrPort(buf.With(header[c.keySaltLength:]), destination) if err != nil { return err } - buffer = buffer.WriteBufferAtFirst(header) - err = c.method.EncodePacket(buffer) + err = c.EncodePacket(buffer) if err != nil { return err } @@ -325,7 +323,7 @@ func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) { return nil, err } buffer.Truncate(n) - err = c.method.DecodePacket(buffer) + err = c.DecodePacket(buffer) if err != nil { return nil, err } diff --git a/protocol/shadowsocks/shadowaead/service.go b/protocol/shadowsocks/shadowaead/service.go index f5f36f6..461302e 100644 --- a/protocol/shadowsocks/shadowaead/service.go +++ b/protocol/shadowsocks/shadowaead/service.go @@ -13,6 +13,7 @@ import ( M "github.com/sagernet/sing/common/metadata" "github.com/sagernet/sing/common/replay" "github.com/sagernet/sing/common/rw" + "github.com/sagernet/sing/common/udpnat" "github.com/sagernet/sing/protocol/shadowsocks" "github.com/sagernet/sing/protocol/socks" "golang.org/x/crypto/chacha20poly1305" @@ -25,6 +26,7 @@ type Service struct { key []byte secureRNG io.Reader replayFilter replay.Filter + udp *udpnat.Service[string] handler shadowsocks.Handler } @@ -34,6 +36,7 @@ func NewService(method string, key []byte, password []byte, secureRNG io.Reader, secureRNG: secureRNG, handler: handler, } + s.udp = udpnat.New[string](s) if replayFilter { s.replayFilter = replay.NewBloomRing() } @@ -163,3 +166,59 @@ func (c *serverConn) ReadFrom(r io.Reader) (n int64, err error) { func (c *serverConn) WriteTo(w io.Writer) (n int64, err error) { return c.reader.WriteTo(w) } + +func (s *Service) NewPacket(conn socks.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error { + if buffer.Len() < s.keySaltLength { + return E.New("bad packet") + } + key := Kdf(s.key, buffer.To(s.keySaltLength), s.keySaltLength) + c := s.constructor(common.Dup(key)) + /*data := buf.New() + packet, err := c.Open(data.Index(0), rw.ZeroBytes[:c.NonceSize()], buffer.From(s.keySaltLength), nil) + if err != nil { + return err + } + data.Truncate(len(packet)) + metadata.Protocol = "shadowsocks" + return s.udp.NewPacket(metadata.Source.String(), func() socks.PacketWriter { + return &serverPacketWriter{s, conn, metadata.Source} + }, data, metadata)*/ + packet, err := c.Open(buffer.Index(s.keySaltLength), rw.ZeroBytes[:c.NonceSize()], buffer.From(s.keySaltLength), nil) + if err != nil { + return err + } + buffer.Advance(s.keySaltLength) + buffer.Truncate(len(packet)) + metadata.Protocol = "shadowsocks" + return s.udp.NewPacket(metadata.Source.String(), func() socks.PacketWriter { + return &serverPacketWriter{s, conn, metadata.Source} + }, buffer, metadata) +} + +type serverPacketWriter struct { + *Service + socks.PacketConn + source *M.AddrPort +} + +func (w *serverPacketWriter) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error { + header := buffer.ExtendHeader(w.keySaltLength + socks.AddressSerializer.AddrPortLen(destination)) + common.Must1(io.ReadFull(w.secureRNG, header[:w.keySaltLength])) + err := socks.AddressSerializer.WriteAddrPort(buf.With(header[w.keySaltLength:]), destination) + if err != nil { + return err + } + key := Kdf(w.key, buffer.To(w.keySaltLength), w.keySaltLength) + c := w.constructor(common.Dup(key)) + c.Seal(buffer.From(w.keySaltLength)[:0], rw.ZeroBytes[:c.NonceSize()], buffer.From(w.keySaltLength), nil) + buffer.Extend(c.Overhead()) + return w.PacketConn.WritePacket(buffer, w.source) +} + +func (s *Service) NewPacketConnection(conn socks.PacketConn, metadata M.Metadata) error { + return s.handler.NewPacketConnection(conn, metadata) +} + +func (s *Service) HandleError(err error) { + s.handler.HandleError(err) +} diff --git a/protocol/shadowsocks/shadowaead_2022/protocol.go b/protocol/shadowsocks/shadowaead_2022/protocol.go index 73e74dd..4ff7fa8 100644 --- a/protocol/shadowsocks/shadowaead_2022/protocol.go +++ b/protocol/shadowsocks/shadowaead_2022/protocol.go @@ -109,9 +109,9 @@ func New(method string, pskList [][]byte, secureRNG io.Reader) (shadowsocks.Meth } func Blake3DeriveKey(psk, salt []byte, keyLength int) []byte { - sessionKey := make([]byte, 2*KeySaltSize) + sessionKey := buf.Make(len(psk) + len(salt)) copy(sessionKey, psk) - copy(sessionKey[KeySaltSize:], salt) + copy(sessionKey[len(psk):], salt) outKey := buf.Make(keyLength) blake3.DeriveKey(outKey, "shadowsocks 2022 session subkey", sessionKey) return outKey @@ -434,6 +434,9 @@ func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPo return err } buffer = buffer.WriteBufferAtFirst(header) + if err != nil { + return err + } if c.method.udpCipher != nil { c.method.udpCipher.Seal(buffer.Index(dataIndex), buffer.To(dataIndex), buffer.From(dataIndex), nil) buffer.Extend(c.method.udpCipher.Overhead()) @@ -574,9 +577,9 @@ func (s *udpSession) nextPacketId() uint64 { } func (m *Method) newUDPSession() *udpSession { - session := &udpSession{ - sessionId: rand.Uint64(), - } + session := &udpSession{} + common.Must(binary.Read(m.secureRNG, binary.BigEndian, &session.sessionId)) + session.packetId-- if m.udpCipher == nil { sessionId := make([]byte, 8) binary.BigEndian.PutUint64(sessionId, session.sessionId) diff --git a/protocol/shadowsocks/shadowaead_2022/service.go b/protocol/shadowsocks/shadowaead_2022/service.go index 16fdfce..179e19e 100644 --- a/protocol/shadowsocks/shadowaead_2022/service.go +++ b/protocol/shadowsocks/shadowaead_2022/service.go @@ -2,32 +2,43 @@ package shadowaead_2022 import ( "context" + "crypto/aes" "crypto/cipher" "encoding/binary" "io" "math" "net" "sync" + "sync/atomic" "time" "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/gsync" M "github.com/sagernet/sing/common/metadata" "github.com/sagernet/sing/common/replay" "github.com/sagernet/sing/common/rw" + "github.com/sagernet/sing/common/udpnat" "github.com/sagernet/sing/protocol/shadowsocks" "github.com/sagernet/sing/protocol/shadowsocks/shadowaead" "github.com/sagernet/sing/protocol/socks" + wgReplay "golang.zx2c4.com/wireguard/replay" ) type Service struct { - name string - secureRNG io.Reader - keyLength int - constructor func(key []byte) cipher.AEAD - psk []byte - replayFilter replay.Filter - handler shadowsocks.Handler + name string + secureRNG io.Reader + keyLength int + constructor func(key []byte) cipher.AEAD + blockConstructor func(key []byte) cipher.Block + udpCipher cipher.AEAD + udpBlockCipher cipher.Block + psk []byte + replayFilter replay.Filter + handler shadowsocks.Handler + udpNat *udpnat.Service[uint64] + sessions gsync.Map[uint64, *serverUDPSession] } func NewService(method string, psk []byte, secureRNG io.Reader, handler shadowsocks.Handler) (shadowsocks.Service, error) { @@ -47,18 +58,20 @@ func NewService(method string, psk []byte, secureRNG io.Reader, handler shadowso case "2022-blake3-aes-128-gcm": s.keyLength = 16 s.constructor = newAESGCM - // m.blockConstructor = newAES - // m.udpBlockCipher = newAES(m.psk) + s.blockConstructor = newAES + s.udpBlockCipher = newAES(s.psk) case "2022-blake3-aes-256-gcm": s.keyLength = 32 s.constructor = newAESGCM - // m.blockConstructor = newAES - // m.udpBlockCipher = newAES(m.psk) + s.blockConstructor = newAES + s.udpBlockCipher = newAES(s.psk) case "2022-blake3-chacha20-poly1305": s.keyLength = 32 s.constructor = newChacha20Poly1305 - // m.udpCipher = newXChacha20Poly1305(m.psk) + s.udpCipher = newXChacha20Poly1305(s.psk) } + + s.udpNat = udpnat.New[uint64](s) return s, nil } @@ -194,3 +207,169 @@ func (c *serverConn) ReadFrom(r io.Reader) (n int64, err error) { func (c *serverConn) WriteTo(w io.Writer) (n int64, err error) { return c.reader.WriteTo(w) } + +func (s *Service) NewPacket(conn socks.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error { + var packetHeader []byte + if s.udpCipher != nil { + _, err := s.udpCipher.Open(buffer.Index(PacketNonceSize), buffer.To(PacketNonceSize), buffer.From(PacketNonceSize), nil) + if err != nil { + return E.Cause(err, "decrypt packet header") + } + buffer.Advance(PacketNonceSize) + } else { + packetHeader = buffer.To(aes.BlockSize) + s.udpBlockCipher.Decrypt(packetHeader, packetHeader) + } + + var sessionId, packetId uint64 + err := binary.Read(buffer, binary.BigEndian, &sessionId) + if err != nil { + return err + } + err = binary.Read(buffer, binary.BigEndian, &packetId) + if err != nil { + return err + } + + session, loaded := s.sessions.LoadOrStore(sessionId, s.newUDPSession) + if !loaded { + session.remoteSessionId = sessionId + if packetHeader != nil { + key := Blake3DeriveKey(s.psk, packetHeader[:8], s.keyLength) + session.remoteCipher = s.constructor(common.Dup(key)) + } + } + + if !session.filter.ValidateCounter(packetId, math.MaxUint64) { + return ErrPacketIdNotUnique + } + + if packetHeader != nil { + _, err = session.remoteCipher.Open(buffer.Index(0), packetHeader[4:16], buffer.Bytes(), nil) + if err != nil { + return E.Cause(err, "decrypt packet") + } + } + + var headerType byte + headerType, err = buffer.ReadByte() + if err != nil { + return err + } + + if headerType != HeaderTypeClient { + return ErrBadHeaderType + } + + var epoch uint64 + err = binary.Read(buffer, binary.BigEndian, &epoch) + if err != nil { + return err + } + if math.Abs(float64(uint64(time.Now().Unix())-epoch)) > 30 { + return ErrBadTimestamp + } + + var paddingLength uint16 + err = binary.Read(buffer, binary.BigEndian, &paddingLength) + if err != nil { + return E.Cause(err, "read padding length") + } + buffer.Advance(int(paddingLength)) + + destination, err := socks.AddressSerializer.ReadAddrPort(buffer) + if err != nil { + return err + } + metadata.Destination = destination + + return s.udpNat.NewPacket(sessionId, func() socks.PacketWriter { + return &serverPacketWriter{s, conn, session, metadata.Source} + }, buffer, metadata) +} + +type serverPacketWriter struct { + *Service + socks.PacketConn + session *serverUDPSession + source *M.AddrPort +} + +func (w *serverPacketWriter) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error { + defer buffer.Release() + + _header := buf.StackNew() + header := common.Dup(_header) + + var dataIndex int + if w.udpCipher != nil { + common.Must1(header.ReadFullFrom(w.secureRNG, PacketNonceSize)) + dataIndex = buffer.Len() + } else { + dataIndex = aes.BlockSize + } + + common.Must( + binary.Write(header, binary.BigEndian, w.session.sessionId), + binary.Write(header, binary.BigEndian, w.session.nextPacketId()), + header.WriteByte(HeaderTypeServer), + binary.Write(header, binary.BigEndian, uint64(time.Now().Unix())), + binary.Write(header, binary.BigEndian, w.session.remoteSessionId), + binary.Write(header, binary.BigEndian, uint16(0)), // padding length + ) + + err := socks.AddressSerializer.WriteAddrPort(header, destination) + if err != nil { + return err + } + + _, err = header.Write(buffer.Bytes()) + if err != nil { + return err + } + + if w.udpCipher != nil { + w.udpCipher.Seal(header.Index(dataIndex), header.To(dataIndex), header.From(dataIndex), nil) + header.Extend(w.udpCipher.Overhead()) + } else { + packetHeader := header.To(aes.BlockSize) + w.session.cipher.Seal(header.Index(dataIndex), packetHeader[4:16], header.From(dataIndex), nil) + header.Extend(w.session.cipher.Overhead()) + w.udpBlockCipher.Encrypt(packetHeader, packetHeader) + } + return w.PacketConn.WritePacket(header, w.source) +} + +type serverUDPSession struct { + sessionId uint64 + remoteSessionId uint64 + packetId uint64 + cipher cipher.AEAD + remoteCipher cipher.AEAD + filter wgReplay.Filter +} + +func (s *serverUDPSession) nextPacketId() uint64 { + return atomic.AddUint64(&s.packetId, 1) +} + +func (m *Service) newUDPSession() *serverUDPSession { + session := &serverUDPSession{} + common.Must(binary.Read(m.secureRNG, binary.BigEndian, &session.sessionId)) + session.packetId-- + if m.udpCipher == nil { + sessionId := make([]byte, 8) + binary.BigEndian.PutUint64(sessionId, session.sessionId) + key := Blake3DeriveKey(m.psk, sessionId, m.keyLength) + session.cipher = m.constructor(common.Dup(key)) + } + return session +} + +func (s *Service) NewPacketConnection(conn socks.PacketConn, metadata M.Metadata) error { + return s.handler.NewPacketConnection(conn, metadata) +} + +func (s *Service) HandleError(err error) { + s.handler.HandleError(err) +} diff --git a/protocol/socks/conn.go b/protocol/socks/conn.go index 7f4ed06..0d57524 100644 --- a/protocol/socks/conn.go +++ b/protocol/socks/conn.go @@ -8,12 +8,21 @@ import ( "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" M "github.com/sagernet/sing/common/metadata" + "github.com/sagernet/sing/common/rw" "github.com/sagernet/sing/common/task" ) -type PacketConn interface { +type PacketReader interface { ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) - WritePacket(buffer *buf.Buffer, addrPort *M.AddrPort) error +} + +type PacketWriter interface { + WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error +} + +type PacketConn interface { + PacketReader + PacketWriter Close() error LocalAddr() net.Addr @@ -23,6 +32,10 @@ type PacketConn interface { SetWriteDeadline(t time.Time) error } +type UDPHandler interface { + NewPacket(conn PacketConn, buffer *buf.Buffer, metadata M.Metadata) error +} + type UDPConnectionHandler interface { NewPacketConnection(conn PacketConn, metadata M.Metadata) error } @@ -47,6 +60,8 @@ func (s *PacketConnStub) SetWriteDeadline(t time.Time) error { func CopyPacketConn(ctx context.Context, dest PacketConn, conn PacketConn) error { return task.Run(ctx, func() error { + defer rw.CloseRead(conn) + defer rw.CloseWrite(dest) _buffer := buf.StackNewMax() buffer := common.Dup(_buffer) data := buffer.Cut(buf.ReversedHeader, buf.ReversedHeader) @@ -56,13 +71,15 @@ func CopyPacketConn(ctx context.Context, dest PacketConn, conn PacketConn) error if err != nil { return err } - buffer.Truncate(data.Len()) + buffer.Resize(buf.ReversedHeader+data.Start(), data.Len()) err = dest.WritePacket(buffer, destination) if err != nil { return err } } }, func() error { + defer rw.CloseRead(dest) + defer rw.CloseWrite(conn) _buffer := buf.StackNewMax() buffer := common.Dup(_buffer) data := buffer.Cut(buf.ReversedHeader, buf.ReversedHeader) @@ -72,7 +89,7 @@ func CopyPacketConn(ctx context.Context, dest PacketConn, conn PacketConn) error if err != nil { return err } - buffer.Truncate(data.Len()) + buffer.Resize(buf.ReversedHeader+data.Start(), data.Len()) err = conn.WritePacket(buffer, destination) if err != nil { return err @@ -81,44 +98,211 @@ func CopyPacketConn(ctx context.Context, dest PacketConn, conn PacketConn) error }) } -func CopyPacketConn0(dest PacketConn, conn PacketConn, onAction func(destination *M.AddrPort, n int)) error { - for { - buffer := buf.New() - destination, err := conn.ReadPacket(buffer) - if err != nil { - buffer.Release() - return err +func CopyNetPacketConn(ctx context.Context, dest net.PacketConn, conn PacketConn) error { + return task.Run(ctx, func() error { + defer rw.CloseRead(conn) + defer rw.CloseWrite(dest) + + _buffer := buf.StackNew() + buffer := common.Dup(_buffer) + for { + buffer.FullReset() + destination, err := conn.ReadPacket(buffer) + if err != nil { + return err + } + + _, err = dest.WriteTo(buffer.Bytes(), destination.UDPAddr()) + if err != nil { + return err + } } - size := buffer.Len() - err = dest.WritePacket(buffer, destination) - if err != nil { - buffer.Release() - return err + }, func() error { + defer rw.CloseRead(dest) + defer rw.CloseWrite(conn) + + _buffer := buf.StackNew() + buffer := common.Dup(_buffer) + data := buffer.Cut(buf.ReversedHeader, buf.ReversedHeader) + for { + data.FullReset() + n, addr, err := dest.ReadFrom(data.FreeBytes()) + if err != nil { + return err + } + buffer.Resize(buf.ReversedHeader, n) + err = conn.WritePacket(buffer, M.AddrPortFromNetAddr(addr)) + if err != nil { + return err + } } - if onAction != nil { - onAction(destination, size) - } - } + }) } -type associatePacketConn struct { - net.PacketConn +type AssociateConn struct { + net.Conn conn net.Conn addr net.Addr + dest *M.AddrPort } -func NewPacketConn(conn net.Conn, packetConn net.PacketConn) PacketConn { - return &associatePacketConn{ - PacketConn: packetConn, - conn: conn, +func NewAssociateConn(conn net.Conn, packetConn net.Conn, destination *M.AddrPort) net.PacketConn { + return &AssociateConn{ + Conn: packetConn, + conn: conn, + dest: destination, } } -func (c *associatePacketConn) RemoteAddr() net.Addr { +func (c *AssociateConn) RemoteAddr() net.Addr { return c.addr } -func (c *associatePacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) { +func (c *AssociateConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + n, err = c.Conn.Read(p) + if err != nil { + return + } + reader := buf.As(p[3:n]) + destination, err := AddressSerializer.ReadAddrPort(reader) + if err != nil { + return + } + addr = destination.UDPAddr() + n = copy(p, reader.Bytes()) + return +} + +func (c *AssociateConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { + _buffer := buf.StackNew() + buffer := common.Dup(_buffer) + common.Must(buffer.WriteZeroN(3)) + err = AddressSerializer.WriteAddrPort(buffer, M.AddrPortFromNetAddr(addr)) + if err != nil { + return + } + _, err = buffer.Write(p) + if err != nil { + return + } + + _, err = c.Conn.Write(buffer.Bytes()) + return +} + +func (c *AssociateConn) Read(b []byte) (n int, err error) { + n, _, err = c.ReadFrom(b) + return +} + +func (c *AssociateConn) Write(b []byte) (n int, err error) { + _buffer := buf.StackNew() + buffer := common.Dup(_buffer) + common.Must(buffer.WriteZeroN(3)) + err = AddressSerializer.WriteAddrPort(buffer, c.dest) + if err != nil { + return + } + _, err = buffer.Write(b) + if err != nil { + return + } + _, err = c.Conn.Write(buffer.Bytes()) + return +} + +func (c *AssociateConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) { + n, err := buffer.ReadFrom(c.conn) + if err != nil { + return nil, err + } + buffer.Truncate(int(n)) + buffer.Advance(3) + return AddressSerializer.ReadAddrPort(buffer) +} + +func (c *AssociateConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error { + defer buffer.Release() + header := buf.With(buffer.ExtendHeader(3 + AddressSerializer.AddrPortLen(destination))) + common.Must(header.WriteZeroN(3)) + common.Must(AddressSerializer.WriteAddrPort(header, destination)) + return common.Error(c.Conn.Write(buffer.Bytes())) +} + +type AssociatePacketConn struct { + net.PacketConn + conn net.Conn + addr net.Addr + dest *M.AddrPort +} + +func NewAssociatePacketConn(conn net.Conn, packetConn net.PacketConn, destination *M.AddrPort) *AssociatePacketConn { + return &AssociatePacketConn{ + PacketConn: packetConn, + conn: conn, + dest: destination, + } +} + +func (c *AssociatePacketConn) RemoteAddr() net.Addr { + return c.addr +} + +func (c *AssociatePacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + n, addr, err = c.PacketConn.ReadFrom(p) + if err != nil { + return + } + reader := buf.As(p[3:n]) + destination, err := AddressSerializer.ReadAddrPort(reader) + if err != nil { + return + } + addr = destination.UDPAddr() + n = copy(p, reader.Bytes()) + return +} + +func (c *AssociatePacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { + _buffer := buf.StackNew() + buffer := common.Dup(_buffer) + common.Must(buffer.WriteZeroN(3)) + + err = AddressSerializer.WriteAddrPort(buffer, M.AddrPortFromNetAddr(addr)) + if err != nil { + return + } + _, err = buffer.Write(p) + if err != nil { + return + } + _, err = c.PacketConn.WriteTo(buffer.Bytes(), c.addr) + return +} + +func (c *AssociatePacketConn) Read(b []byte) (n int, err error) { + n, _, err = c.ReadFrom(b) + return +} + +func (c *AssociatePacketConn) Write(b []byte) (n int, err error) { + _buffer := buf.StackNew() + buffer := common.Dup(_buffer) + common.Must(buffer.WriteZeroN(3)) + + err = AddressSerializer.WriteAddrPort(buffer, c.dest) + if err != nil { + return + } + _, err = buffer.Write(b) + if err != nil { + return + } + _, err = c.PacketConn.WriteTo(buffer.Bytes(), c.addr) + return +} + +func (c *AssociatePacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) { n, addr, err := c.PacketConn.ReadFrom(buffer.FreeBytes()) if err != nil { return nil, err @@ -126,15 +310,14 @@ func (c *associatePacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error c.addr = addr buffer.Truncate(n) buffer.Advance(3) - return AddressSerializer.ReadAddrPort(buffer) + dest, err := AddressSerializer.ReadAddrPort(buffer) + return dest, err } -func (c *associatePacketConn) WritePacket(buffer *buf.Buffer, addrPort *M.AddrPort) error { +func (c *AssociatePacketConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error { defer buffer.Release() - _header := buf.StackNew() - header := common.Dup(_header) + header := buf.With(buffer.ExtendHeader(3 + AddressSerializer.AddrPortLen(destination))) common.Must(header.WriteZeroN(3)) - common.Must(AddressSerializer.WriteAddrPort(header, addrPort)) - buffer = buffer.WriteBufferAtFirst(header) + common.Must(AddressSerializer.WriteAddrPort(header, destination)) return common.Error(c.PacketConn.WriteTo(buffer.Bytes(), c.addr)) } diff --git a/protocol/socks/listener.go b/protocol/socks/listener.go index 22deef0..3abff6c 100644 --- a/protocol/socks/listener.go +++ b/protocol/socks/listener.go @@ -36,7 +36,7 @@ func NewListener(bind netip.AddrPort, authenticator auth.Authenticator, handler } func (l *Listener) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error { - return HandleConnection(ctx, conn, l.authenticator, l.bindAddr, l.handler, metadata) + return HandleConnection(ctx, conn, l.authenticator, M.AddrPortFromNetAddr(conn.LocalAddr()).Addr.Addr(), l.handler, metadata) } func (l *Listener) Start() error { @@ -131,9 +131,10 @@ func HandleConnection(ctx context.Context, conn net.Conn, authenticator auth.Aut if err != nil { return E.Cause(err, "write socks response") } + metadata.Protocol = "socks" metadata.Destination = request.Destination go func() { - err := handler.NewPacketConnection(NewPacketConn(conn, udpConn), metadata) + err := handler.NewPacketConnection(NewAssociatePacketConn(conn, udpConn, request.Destination), metadata) if err != nil { handler.HandleError(err) } diff --git a/transport/mixed/listener.go b/transport/mixed/listener.go index 9b4a4b2..7e5646a 100644 --- a/transport/mixed/listener.go +++ b/transport/mixed/listener.go @@ -32,7 +32,7 @@ type Listener struct { bindAddr netip.Addr handler Handler authenticator auth.Authenticator - udpNat *udpnat.Server + udpNat *udpnat.Service[string] } func NewListener(bind netip.AddrPort, authenticator auth.Authenticator, transproxy redir.TransproxyMode, handler Handler) *Listener { @@ -45,7 +45,7 @@ func NewListener(bind netip.AddrPort, authenticator auth.Authenticator, transpro listener.TCPListener = tcp.NewTCPListener(bind, listener, tcp.WithTransproxyMode(transproxy)) if transproxy == redir.ModeTProxy { listener.UDPListener = udp.NewUDPListener(bind, listener, udp.WithTransproxyMode(transproxy)) - listener.udpNat = udpnat.NewServer(handler) + listener.udpNat = udpnat.New[string](handler) } return listener } @@ -63,7 +63,7 @@ func (l *Listener) NewConnection(ctx context.Context, conn net.Conn, metadata M. case socks.Version4: return E.New("socks4 request dropped (TODO)") case socks.Version5: - return socks.HandleConnection(ctx, bufConn, l.authenticator, l.bindAddr, l.handler, metadata) + return socks.HandleConnection(ctx, bufConn, l.authenticator, M.AddrPortFromNetAddr(conn.LocalAddr()).Addr.Addr(), l.handler, metadata) } request, err := http.ReadRequest(bufConn.Reader()) @@ -96,8 +96,23 @@ func (l *Listener) NewConnection(ctx context.Context, conn net.Conn, metadata M. return http.HandleRequest(ctx, request, bufConn, l.authenticator, l.handler, metadata) } -func (l *Listener) NewPacket(packet *buf.Buffer, metadata M.Metadata) error { - return l.udpNat.HandleUDP(packet, metadata) +func (l *Listener) NewPacket(conn socks.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error { + return l.udpNat.NewPacket(metadata.Source.String(), func() socks.PacketWriter { + return &tproxyPacketWriter{metadata.Source.UDPAddr()} + }, buffer, metadata) +} + +type tproxyPacketWriter struct { + source *net.UDPAddr +} + +func (w *tproxyPacketWriter) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error { + udpConn, err := redir.DialUDP("udp", destination.UDPAddr(), w.source) + if err != nil { + return E.Cause(err, "tproxy udp write back") + } + defer udpConn.Close() + return common.Error(udpConn.Write(buffer.Bytes())) } func (l *Listener) HandleError(err error) { diff --git a/transport/udp/udp.go b/transport/udp/udp.go index 3f2889a..133a615 100644 --- a/transport/udp/udp.go +++ b/transport/udp/udp.go @@ -9,10 +9,11 @@ import ( E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" "github.com/sagernet/sing/common/redir" + "github.com/sagernet/sing/protocol/socks" ) type Handler interface { - M.UDPHandler + socks.UDPHandler E.Handler } @@ -24,6 +25,19 @@ type Listener struct { tproxy bool } +func (l *Listener) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) { + n, addr, err := l.ReadFromUDP(buffer.FreeBytes()) + if err != nil { + return nil, err + } + buffer.Truncate(n) + return M.AddrPortFromNetAddr(addr), nil +} + +func (l *Listener) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error { + return common.Error(l.UDPConn.WriteTo(buffer.Bytes(), destination.UDPAddr())) +} + func NewUDPListener(listen netip.AddrPort, handler Handler, options ...Option) *Listener { listener := &Listener{ handler: handler, @@ -69,32 +83,31 @@ func (l *Listener) Close() error { } func (l *Listener) loop() { + _buffer := buf.StackNewMax() + buffer := common.Dup(_buffer) + data := buffer.Cut(buf.ReversedHeader, buf.ReversedHeader).Slice() if !l.tproxy { for { - buffer := buf.New() - n, addr, err := l.ReadFromUDP(buffer.Extend(buf.UDPBufferSize)) + n, addr, err := l.ReadFromUDP(data) if err != nil { - buffer.Release() l.handler.HandleError(err) return } - buffer.Truncate(n) - err = l.handler.NewPacket(buffer, M.Metadata{ + buffer.Resize(buf.ReversedHeader, n) + err = l.handler.NewPacket(l, buffer, M.Metadata{ Protocol: "udp", Source: M.AddrPortFromNetAddr(addr), }) if err != nil { - buffer.Release() l.handler.HandleError(err) } } } else { - oob := make([]byte, 1024) + _oob := make([]byte, 1024) + oob := common.Dup(_oob) for { - buffer := buf.New() - n, oobN, _, addr, err := l.ReadMsgUDPAddrPort(buffer.FreeBytes(), oob) + n, oobN, _, addr, err := l.ReadMsgUDPAddrPort(data, oob) if err != nil { - buffer.Release() l.handler.HandleError(err) return } @@ -103,14 +116,13 @@ func (l *Listener) loop() { l.handler.HandleError(E.Cause(err, "get original destination")) return } - buffer.Truncate(n) - err = l.handler.NewPacket(buffer, M.Metadata{ + buffer.Resize(buf.ReversedHeader, n) + err = l.handler.NewPacket(l, buffer, M.Metadata{ Protocol: "tproxy", Source: M.AddrPortFromAddrPort(addr), Destination: destination, }) if err != nil { - buffer.Release() l.handler.HandleError(err) } }