Add shadowsocks service

This commit is contained in:
世界 2022-04-29 12:06:10 +08:00
parent 5be6eb2d64
commit df1e1cfafd
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
21 changed files with 1028 additions and 265 deletions

View file

@ -28,6 +28,12 @@ wget 'https://github.com/Dreamacro/maxmind-geoip/releases/latest/download/Countr
go install -v -trimpath -ldflags "-s -w -buildid=" ./cli/ss-local go install -v -trimpath -ldflags "-s -w -buildid=" ./cli/ss-local
``` ```
### ss-server
```shell
go install -v -trimpath -ldflags "-s -w -buildid=" ./cli/ss-server
```
### ddns ### ddns
```shell ```shell

145
cli/socks-chk/main.go Normal file
View file

@ -0,0 +1,145 @@
package main
import (
"context"
"encoding/binary"
"io"
"net"
"net/netip"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/protocol/socks"
"github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"golang.org/x/net/dns/dnsmessage"
)
func main() {
command := &cobra.Command{
Use: "socks-chk address:port",
Args: cobra.ExactArgs(1),
Run: run,
}
if err := command.Execute(); err != nil {
logrus.Fatal(err)
}
}
func run(cmd *cobra.Command, args []string) {
server, err := M.ParseAddress(args[0])
if err != nil {
logrus.Fatal("invalid server address ", args[0])
}
err = testSocksTCP(server)
if err != nil {
logrus.Fatal(err)
}
err = testSocksUDP(server)
if err != nil {
logrus.Fatal(err)
}
}
func testSocksTCP(server *M.AddrPort) error {
tcpConn, err := net.Dial("tcp", server.String())
if err != nil {
return err
}
response, err := socks.ClientHandshake(tcpConn, socks.Version5, socks.CommandConnect, M.AddrPortFrom(M.ParseAddr("1.0.0.1"), 53), "", "")
if err != nil {
return err
}
if response.ReplyCode != socks.ReplyCodeSuccess {
logrus.Fatal("socks tcp handshake failure: ", response.ReplyCode)
}
message := &dnsmessage.Message{}
message.Header.ID = 1
message.Header.RecursionDesired = true
message.Questions = append(message.Questions, dnsmessage.Question{
Name: dnsmessage.MustNewName("google.com."),
Type: dnsmessage.TypeA,
Class: dnsmessage.ClassINET,
})
packet, err := message.Pack()
err = binary.Write(tcpConn, binary.BigEndian, uint16(len(packet)))
if err != nil {
return err
}
_, err = tcpConn.Write(packet)
if err != nil {
return err
}
var respLen uint16
err = binary.Read(tcpConn, binary.BigEndian, &respLen)
if err != nil {
return err
}
respBuf := buf.Make(int(respLen))
_, err = io.ReadFull(tcpConn, respBuf)
if err != nil {
return err
}
common.Must(message.Unpack(respBuf))
for _, answer := range message.Answers {
logrus.Info("tcp got answer: ", netip.AddrFrom4(answer.Body.(*dnsmessage.AResource).A))
}
tcpConn.Close()
return nil
}
func testSocksUDP(server *M.AddrPort) error {
tcpConn, err := net.Dial("tcp", server.String())
if err != nil {
return err
}
dest := M.AddrPortFrom(M.ParseAddr("1.0.0.1"), 53)
response, err := socks.ClientHandshake(tcpConn, socks.Version5, socks.CommandUDPAssociate, dest, "", "")
if err != nil {
return err
}
if response.ReplyCode != socks.ReplyCodeSuccess {
logrus.Fatal("socks tcp handshake failure: ", response.ReplyCode)
}
var dialer net.Dialer
udpConn, err := dialer.DialContext(context.Background(), "udp", response.Bind.String())
if err != nil {
return err
}
assConn := socks.NewAssociateConn(tcpConn, udpConn, dest)
message := &dnsmessage.Message{}
message.Header.ID = 1
message.Header.RecursionDesired = true
message.Questions = append(message.Questions, dnsmessage.Question{
Name: dnsmessage.MustNewName("google.com."),
Type: dnsmessage.TypeA,
Class: dnsmessage.ClassINET,
})
packet, err := message.Pack()
common.Must(err)
common.Must1(assConn.WriteTo(packet, &net.UDPAddr{
IP: net.IPv4(1, 0, 0, 1),
Port: 53,
}))
_buffer := buf.StackNew()
buffer := common.Dup(_buffer)
common.Must2(buffer.ReadPacketFrom(assConn))
common.Must(message.Unpack(buffer.Bytes()))
for _, answer := range message.Answers {
logrus.Info("udp got answer: ", netip.AddrFrom4(answer.Body.(*dnsmessage.AResource).A))
}
udpConn.Close()
tcpConn.Close()
return nil
}

View file

@ -345,28 +345,15 @@ func (c *client) NewConnection(ctx context.Context, conn net.Conn, metadata M.Me
return rw.CopyConn(ctx, serverConn, conn) return rw.CopyConn(ctx, serverConn, conn)
} }
func (c *client) NewPacketConnection(conn socks.PacketConn, _ M.Metadata) error { func (c *client) NewPacketConnection(conn socks.PacketConn, metadata M.Metadata) error {
logrus.Info("outbound ", metadata.Protocol, " UDP ", metadata.Source, " ==> ", metadata.Destination)
ctx := context.Background() ctx := context.Background()
udpConn, err := c.dialer.DialContext(ctx, "udp", c.server.String()) udpConn, err := c.dialer.DialContext(ctx, "udp", c.server.String())
if err != nil { if err != nil {
return err return err
} }
serverConn := c.method.DialPacketConn(udpConn) serverConn := c.method.DialPacketConn(udpConn)
return task.Run(ctx, func() error { return socks.CopyPacketConn(ctx, serverConn, conn)
var init bool
return socks.CopyPacketConn0(serverConn, conn, func(destination *M.AddrPort, n int) {
if !init {
init = true
logrus.Info("UDP ", conn.LocalAddr(), " ==> ", destination)
} else {
logrus.Trace("UDP ", conn.LocalAddr(), " ==> ", destination)
}
})
}, func() error {
return socks.CopyPacketConn0(conn, serverConn, func(destination *M.AddrPort, n int) {
logrus.Trace("UDP ", conn.LocalAddr(), " <== ", destination)
})
})
} }
func run(cmd *cobra.Command, flags *flags) { func run(cmd *cobra.Command, flags *flags) {

View file

@ -3,6 +3,8 @@ package main
import ( import (
"context" "context"
"encoding/base64" "encoding/base64"
"encoding/json"
"io/ioutil"
"net" "net"
"net/netip" "net/netip"
"os" "os"
@ -12,22 +14,29 @@ import (
"github.com/sagernet/sing" "github.com/sagernet/sing"
"github.com/sagernet/sing/common" "github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/gsync"
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/common/network" "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/random" "github.com/sagernet/sing/common/random"
"github.com/sagernet/sing/common/rw" "github.com/sagernet/sing/common/rw"
"github.com/sagernet/sing/protocol/shadowsocks" "github.com/sagernet/sing/protocol/shadowsocks"
"github.com/sagernet/sing/protocol/shadowsocks/shadowaead"
"github.com/sagernet/sing/protocol/shadowsocks/shadowaead_2022" "github.com/sagernet/sing/protocol/shadowsocks/shadowaead_2022"
"github.com/sagernet/sing/protocol/socks"
"github.com/sagernet/sing/transport/tcp" "github.com/sagernet/sing/transport/tcp"
"github.com/sagernet/sing/transport/udp"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
type flags struct { type flags struct {
Server string `json:"server"`
ServerPort uint16 `json:"server_port"`
Bind string `json:"local_address"` Bind string `json:"local_address"`
LocalPort uint16 `json:"local_port"` LocalPort uint16 `json:"local_port"`
// Password string `json:"password"` Password string `json:"password"`
Key string `json:"key"` Key string `json:"key"`
Method string `json:"method"` Method string `json:"method"`
Verbose bool `json:"verbose"` Verbose bool `json:"verbose"`
@ -35,8 +44,6 @@ type flags struct {
} }
func main() { func main() {
logrus.SetLevel(logrus.TraceLevel)
f := new(flags) f := new(flags)
command := &cobra.Command{ command := &cobra.Command{
@ -48,12 +55,16 @@ func main() {
}, },
} }
command.Flags().StringVarP(&f.Server, "server", "s", "", "Set the server’s hostname or IP.")
command.Flags().Uint16VarP(&f.ServerPort, "server-port", "p", 0, "Set the server’s port number.")
command.Flags().StringVarP(&f.Bind, "local-address", "b", "", "Set the local address.") command.Flags().StringVarP(&f.Bind, "local-address", "b", "", "Set the local address.")
command.Flags().Uint16VarP(&f.LocalPort, "local-port", "l", 0, "Set the local port number.") command.Flags().Uint16VarP(&f.LocalPort, "local-port", "l", 0, "Set the local port number.")
command.Flags().StringVarP(&f.Key, "key", "k", "", "Set the key directly. The key should be encoded with URL-safe Base64.") command.Flags().StringVar(&f.Key, "key", "", "Set the key directly. The key should be encoded with URL-safe Base64.")
command.Flags().StringVarP(&f.Password, "password", "k", "", "Set the password. The server and the client should use the same password.")
var supportedCiphers []string var supportedCiphers []string
supportedCiphers = append(supportedCiphers, shadowsocks.MethodNone) supportedCiphers = append(supportedCiphers, shadowsocks.MethodNone)
supportedCiphers = append(supportedCiphers, shadowaead.List...)
supportedCiphers = append(supportedCiphers, shadowaead_2022.List...) supportedCiphers = append(supportedCiphers, shadowaead_2022.List...)
command.Flags().StringVarP(&f.Method, "encrypt-method", "m", "", "Set the cipher.\n\nSupported ciphers:\n\n"+strings.Join(supportedCiphers, "\n")) command.Flags().StringVarP(&f.Method, "encrypt-method", "m", "", "Set the cipher.\n\nSupported ciphers:\n\n"+strings.Join(supportedCiphers, "\n"))
@ -75,6 +86,10 @@ func run(cmd *cobra.Command, f *flags) {
if err != nil { if err != nil {
logrus.Fatal(err) logrus.Fatal(err)
} }
err = s.udpIn.Start()
if err != nil {
logrus.Fatal(err)
}
logrus.Info("server started at ", s.tcpIn.TCPListener.Addr()) logrus.Info("server started at ", s.tcpIn.TCPListener.Addr())
osSignals := make(chan os.Signal, 1) osSignals := make(chan os.Signal, 1)
@ -82,33 +97,101 @@ func run(cmd *cobra.Command, f *flags) {
<-osSignals <-osSignals
s.tcpIn.Close() s.tcpIn.Close()
s.udpIn.Close()
} }
type server struct { type server struct {
tcpIn *tcp.Listener tcpIn *tcp.Listener
udpIn *udp.Listener
service shadowsocks.Service service shadowsocks.Service
udpNat gsync.Map[string, *net.UDPConn]
}
func (s *server) Start() error {
err := s.tcpIn.Start()
if err != nil {
return err
}
err = s.udpIn.Start()
return err
}
func (s *server) Close() error {
s.tcpIn.Close()
s.udpIn.Close()
return nil
} }
func newServer(f *flags) (*server, error) { func newServer(f *flags) (*server, error) {
s := new(server) s := new(server)
if f.Method == shadowsocks.MethodNone { if f.ConfigFile != "" {
s.service = shadowsocks.NewNoneService(s) configFile, err := ioutil.ReadFile(f.ConfigFile)
} else if common.Contains(shadowaead_2022.List, f.Method) { if err != nil {
var pskList [][]byte return nil, E.Cause(err, "read config file")
}
flagsNew := new(flags)
err = json.Unmarshal(configFile, flagsNew)
if err != nil {
return nil, E.Cause(err, "decode config file")
}
if flagsNew.Server != "" && f.Server == "" {
f.Server = flagsNew.Server
}
if flagsNew.ServerPort != 0 && f.ServerPort == 0 {
f.ServerPort = flagsNew.ServerPort
}
if flagsNew.Bind != "" && f.Bind == "" {
f.Bind = flagsNew.Bind
}
if flagsNew.LocalPort != 0 && f.LocalPort == 0 {
f.LocalPort = flagsNew.LocalPort
}
if flagsNew.Password != "" && f.Password == "" {
f.Password = flagsNew.Password
}
if flagsNew.Key != "" && f.Key == "" {
f.Key = flagsNew.Key
}
if flagsNew.Method != "" && f.Method == "" {
f.Method = flagsNew.Method
}
if flagsNew.Verbose {
f.Verbose = true
}
}
if f.Verbose {
logrus.SetLevel(logrus.TraceLevel)
}
if f.Server == "" {
return nil, E.New("missing server address")
} else if f.ServerPort == 0 {
return nil, E.New("missing server port")
} else if f.Method == "" {
return nil, E.New("missing method")
}
var key []byte
if f.Key != "" { if f.Key != "" {
keyStrList := strings.Split(f.Key, ":") kb, err := base64.StdEncoding.DecodeString(f.Key)
pskList = make([][]byte, len(keyStrList))
for i, keyStr := range keyStrList {
key, err := base64.StdEncoding.DecodeString(keyStr)
if err != nil { if err != nil {
return nil, E.Cause(err, "decode key") return nil, E.Cause(err, "decode key")
} }
pskList[i] = key key = kb
} }
if f.Method == shadowsocks.MethodNone {
s.service = shadowsocks.NewNoneService(s)
} else if common.Contains(shadowaead.List, f.Method) {
service, err := shadowaead.NewService(f.Method, key, []byte(f.Password), random.Blake3KeyedHash(), false, s)
if err != nil {
return nil, err
} }
rng := random.System s.service = service
service, err := shadowaead_2022.NewService(f.Method, pskList[0], rng, s) } else if common.Contains(shadowaead_2022.List, f.Method) {
service, err := shadowaead_2022.NewService(f.Method, key, random.Blake3KeyedHash(), s)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -118,16 +201,17 @@ func newServer(f *flags) (*server, error) {
} }
var bind netip.Addr var bind netip.Addr
if f.Bind != "" { if f.Server != "" {
addr, err := netip.ParseAddr(f.Bind) addr, err := netip.ParseAddr(f.Server)
if err != nil { if err != nil {
return nil, E.Cause(err, "bad local address") return nil, E.Cause(err, "bad server address")
} }
bind = addr bind = addr
} else { } else {
bind = netip.IPv6Unspecified() bind = netip.IPv6Unspecified()
} }
s.tcpIn = tcp.NewTCPListener(netip.AddrPortFrom(bind, f.LocalPort), s) s.tcpIn = tcp.NewTCPListener(netip.AddrPortFrom(bind, f.ServerPort), s)
s.udpIn = udp.NewUDPListener(netip.AddrPortFrom(bind, f.ServerPort), s)
return s, nil return s, nil
} }
@ -143,6 +227,19 @@ func (s *server) NewConnection(ctx context.Context, conn net.Conn, metadata M.Me
return rw.CopyConn(ctx, conn, destConn) return rw.CopyConn(ctx, conn, destConn)
} }
func (s *server) NewPacketConnection(conn socks.PacketConn, metadata M.Metadata) error {
logrus.Info("inbound UDP ", metadata.Source, " ==> ", metadata.Destination)
udpConn, err := net.ListenUDP("udp", nil)
if err != nil {
return err
}
return socks.CopyNetPacketConn(context.Background(), udpConn, conn)
}
func (s *server) NewPacket(conn socks.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error {
return s.service.NewPacket(conn, buffer, metadata)
}
func (s *server) HandleError(err error) { func (s *server) HandleError(err error) {
if E.IsClosed(err) { if E.IsClosed(err) {
return return

View file

@ -14,7 +14,6 @@ import (
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/common/redir" "github.com/sagernet/sing/common/redir"
"github.com/sagernet/sing/common/rw" "github.com/sagernet/sing/common/rw"
"github.com/sagernet/sing/common/task"
"github.com/sagernet/sing/common/uot" "github.com/sagernet/sing/common/uot"
"github.com/sagernet/sing/protocol/socks" "github.com/sagernet/sing/protocol/socks"
"github.com/sagernet/sing/transport/mixed" "github.com/sagernet/sing/transport/mixed"
@ -122,15 +121,7 @@ func (c *localClient) NewPacketConnection(conn socks.PacketConn, _ M.Metadata) e
} }
client := uot.NewClientConn(upstream) client := uot.NewClientConn(upstream)
return task.Run(context.Background(), func() error { return socks.CopyPacketConn(context.Background(), client, conn)
return socks.CopyPacketConn0(client, conn, func(destination *M.AddrPort, n int) {
logrus.Trace("UDP ", conn.LocalAddr(), " ==> ", destination)
})
}, func() error {
return socks.CopyPacketConn0(conn, client, func(destination *M.AddrPort, n int) {
logrus.Trace("UDP ", conn.LocalAddr(), " <== ", destination)
})
})
} }
func (c *localClient) OnError(err error) { func (c *localClient) OnError(err error) {

View file

@ -104,11 +104,15 @@ func (b *Buffer) ExtendHeader(size int) []byte {
b.start -= size b.start -= size
return b.data[b.start-size : b.start] return b.data[b.start-size : b.start]
} else { } else {
offset := size - b.start /*offset := size - b.start
end := b.end + size end := b.end + size
if end > len(b.data) {
panic("buffer overflow")
}
copy(b.data[offset:end], b.data[b.start:b.end]) copy(b.data[offset:end], b.data[b.start:b.end])
b.end = end b.end = end
return b.data[:offset] return b.data[:offset]*/
panic("no header available")
} }
} }
@ -119,10 +123,18 @@ func (b *Buffer) WriteBufferAtFirst(buffer *Buffer) *Buffer {
b.start -= n b.start -= n
buffer.Release() buffer.Release()
return b return b
} } else if buffer.FreeLen() >= b.Len() {
common.Must1(buffer.Write(b.Bytes())) common.Must1(buffer.Write(b.Bytes()))
b.Release() b.Release()
return buffer return buffer
} else if b.FreeLen() >= size {
copy(b.data[b.start+size:b.end+size], b.data[b.start:b.end])
copy(b.data, buffer.data)
buffer.Release()
return b
} else {
panic("buffer overflow")
}
} }
func (b *Buffer) WriteAtFirst(data []byte) (n int, err error) { func (b *Buffer) WriteAtFirst(data []byte) (n int, err error) {
@ -305,6 +317,10 @@ func (b *Buffer) Cut(start int, end int) *Buffer {
} }
} }
func (b Buffer) Start() int {
return b.start
}
func (b Buffer) Len() int { func (b Buffer) Len() int {
return b.end - b.start return b.end - b.start
} }

View file

@ -3,8 +3,6 @@ package metadata
import ( import (
"context" "context"
"net" "net"
"github.com/sagernet/sing/common/buf"
) )
type Metadata struct { type Metadata struct {
@ -16,7 +14,3 @@ type Metadata struct {
type TCPConnectionHandler interface { type TCPConnectionHandler interface {
NewConnection(ctx context.Context, conn net.Conn, metadata Metadata) error NewConnection(ctx context.Context, conn net.Conn, metadata Metadata) error
} }
type UDPHandler interface {
NewPacket(packet *buf.Buffer, metadata Metadata) error
}

View file

@ -55,28 +55,43 @@ func (s *Serializer) WriteAddress(writer io.Writer, addr Addr) error {
return err return err
} }
func (s *Serializer) AddressLen(addr Addr) int {
switch addr.Family() {
case AddressFamilyIPv4:
return 5
case AddressFamilyIPv6:
return 17
default:
return 1 + len(addr.Fqdn())
}
}
func (s *Serializer) WritePort(writer io.Writer, port uint16) error { func (s *Serializer) WritePort(writer io.Writer, port uint16) error {
return binary.Write(writer, binary.BigEndian, port) return binary.Write(writer, binary.BigEndian, port)
} }
func (s *Serializer) WriteAddrPort(writer io.Writer, addrPort *AddrPort) error { func (s *Serializer) WriteAddrPort(writer io.Writer, destination *AddrPort) error {
var err error var err error
if !s.portFirst { if !s.portFirst {
err = s.WriteAddress(writer, addrPort.Addr) err = s.WriteAddress(writer, destination.Addr)
} else { } else {
err = s.WritePort(writer, addrPort.Port) err = s.WritePort(writer, destination.Port)
} }
if err != nil { if err != nil {
return err return err
} }
if s.portFirst { if s.portFirst {
err = s.WriteAddress(writer, addrPort.Addr) err = s.WriteAddress(writer, destination.Addr)
} else { } else {
err = s.WritePort(writer, addrPort.Port) err = s.WritePort(writer, destination.Port)
} }
return err return err
} }
func (s *Serializer) AddrPortLen(destination *AddrPort) int {
return s.AddressLen(destination.Addr) + 2
}
func (s *Serializer) ReadAddress(reader io.Reader) (Addr, error) { func (s *Serializer) ReadAddress(reader io.Reader) (Addr, error) {
af, err := rw.ReadByte(reader) af, err := rw.ReadByte(reader)
if err != nil { if err != nil {
@ -120,7 +135,7 @@ func (s *Serializer) ReadPort(reader io.Reader) (uint16, error) {
return binary.BigEndian.Uint16(port), nil return binary.BigEndian.Uint16(port), nil
} }
func (s *Serializer) ReadAddrPort(reader io.Reader) (addrPort *AddrPort, err error) { func (s *Serializer) ReadAddrPort(reader io.Reader) (destination *AddrPort, err error) {
var addr Addr var addr Addr
var port uint16 var port uint16
if !s.portFirst { if !s.portFirst {

View file

@ -17,6 +17,17 @@ type DefaultDialer struct {
net.Dialer net.Dialer
} }
func (d *DefaultDialer) ListenUDP(network string, laddr *net.UDPAddr) (*net.UDPConn, error) {
return net.ListenUDP(network, laddr)
}
func (d *DefaultDialer) DialContext(ctx context.Context, network string, address *M.AddrPort) (net.Conn, error) { func (d *DefaultDialer) DialContext(ctx context.Context, network string, address *M.AddrPort) (net.Conn, error) {
return d.Dialer.DialContext(ctx, network, address.String()) return d.Dialer.DialContext(ctx, network, address.String())
} }
type Listener interface {
Listen(ctx context.Context, network, address string) (net.Listener, error)
ListenPacket(ctx context.Context, network, address string) (net.PacketConn, error)
}
var SystemListener Listener = &net.ListenConfig{}

View file

@ -1,108 +0,0 @@
package udpnat
import (
"io"
"net"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/gsync"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/common/redir"
"github.com/sagernet/sing/protocol/socks"
)
type Handler interface {
socks.UDPConnectionHandler
E.Handler
}
type Server struct {
udpNat gsync.Map[string, *packetConn]
handler Handler
}
func NewServer(handler Handler) *Server {
return &Server{handler: handler}
}
func (s *Server) HandleUDP(buffer *buf.Buffer, metadata M.Metadata) error {
conn, loaded := s.udpNat.LoadOrStore(metadata.Source.String(), func() *packetConn {
return &packetConn{source: metadata.Source.UDPAddr(), in: make(chan *udpPacket)}
})
if !loaded {
go func() {
err := s.handler.NewPacketConnection(conn, metadata)
if err != nil {
s.handler.HandleError(err)
}
}()
}
conn.in <- &udpPacket{
buffer: buffer,
destination: metadata.Destination,
}
return nil
}
func (s *Server) OnError(err error) {
s.handler.HandleError(err)
}
func (s *Server) Close() error {
s.udpNat.Range(func(key string, conn *packetConn) bool {
conn.Close()
return true
})
s.udpNat = gsync.Map[string, *packetConn]{}
return nil
}
type packetConn struct {
socks.PacketConnStub
source *net.UDPAddr
in chan *udpPacket
}
type udpPacket struct {
buffer *buf.Buffer
destination *M.AddrPort
}
func (c *packetConn) LocalAddr() net.Addr {
return c.source
}
func (c *packetConn) Close() error {
select {
case <-c.in:
return io.ErrClosedPipe
default:
close(c.in)
}
return nil
}
func (c *packetConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
select {
case packet, ok := <-c.in:
if !ok {
return nil, io.ErrClosedPipe
}
defer packet.buffer.Release()
if buffer.FreeLen() < packet.buffer.Len() {
return nil, io.ErrShortBuffer
}
return packet.destination, common.Error(buffer.Write(packet.buffer.Bytes()))
}
}
func (c *packetConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error {
udpConn, err := redir.DialUDP("udp", destination.UDPAddr(), c.source)
if err != nil {
return E.Cause(err, "tproxy udp write back")
}
defer udpConn.Close()
return common.Error(udpConn.Write(buffer.Bytes()))
}

123
common/udpnat/service.go Normal file
View file

@ -0,0 +1,123 @@
package udpnat
import (
"context"
"io"
"net"
"os"
"time"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/gsync"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/protocol/socks"
)
type Handler interface {
socks.UDPConnectionHandler
E.Handler
}
type Service[K comparable] struct {
nat gsync.Map[K, *conn]
handler Handler
}
func New[T comparable](handler Handler) *Service[T] {
return &Service[T]{
handler: handler,
}
}
func (s *Service[T]) NewPacket(key T, writer func() socks.PacketWriter, buffer *buf.Buffer, metadata M.Metadata) error {
c, loaded := s.nat.LoadOrStore(key, func() *conn {
c := &conn{
data: make(chan packet),
remoteAddr: metadata.Source.UDPAddr(),
source: writer(),
}
c.ctx, c.cancel = context.WithCancel(context.Background())
return c
})
if !loaded {
go func() {
err := s.handler.NewPacketConnection(c, metadata)
if err != nil {
s.handler.HandleError(err)
}
}()
}
ctx, done := context.WithCancel(c.ctx)
p := packet{
done: done,
data: buffer,
destination: metadata.Destination,
}
c.data <- p
<-ctx.Done()
return nil
}
type packet struct {
data *buf.Buffer
destination *M.AddrPort
done context.CancelFunc
}
type conn struct {
ctx context.Context
cancel context.CancelFunc
data chan packet
remoteAddr *net.UDPAddr
source socks.PacketWriter
}
func (c *conn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
select {
case p, ok := <-c.data:
if !ok {
return nil, io.ErrClosedPipe
}
defer p.data.Release()
_, err := buffer.ReadFrom(p.data)
p.done()
return p.destination, err
}
}
func (c *conn) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error {
return c.source.WritePacket(buffer, destination)
}
func (c *conn) Close() error {
c.cancel()
select {
case <-c.data:
return os.ErrClosed
default:
close(c.data)
return nil
}
}
func (c *conn) LocalAddr() net.Addr {
return &common.DummyAddr{}
}
func (c *conn) RemoteAddr() net.Addr {
return c.remoteAddr
}
func (c *conn) SetDeadline(t time.Time) error {
return nil
}
func (c *conn) SetReadDeadline(t time.Time) error {
return nil
}
func (c *conn) SetWriteDeadline(t time.Time) error {
return nil
}

View file

@ -30,7 +30,6 @@ func TestServerConn(t *testing.T) {
Port: 53, Port: 53,
})) }))
_buffer := buf.StackNew() _buffer := buf.StackNew()
common.Use(_buffer)
buffer := common.Dup(_buffer) buffer := common.Dup(_buffer)
common.Must2(buffer.ReadPacketFrom(clientConn)) common.Must2(buffer.ReadPacketFrom(clientConn))
common.Must(message.Unpack(buffer.Bytes())) common.Must(message.Unpack(buffer.Bytes()))

View file

@ -4,32 +4,35 @@ import (
"context" "context"
"net" "net"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/common/udpnat"
"github.com/sagernet/sing/protocol/socks" "github.com/sagernet/sing/protocol/socks"
) )
type Service interface { type Service interface {
M.TCPConnectionHandler M.TCPConnectionHandler
} socks.UDPHandler
type MultiUserService interface {
Service
AddUser(key []byte)
RemoveUser(key []byte)
} }
type Handler interface { type Handler interface {
M.TCPConnectionHandler M.TCPConnectionHandler
socks.UDPConnectionHandler
E.Handler
} }
type NoneService struct { type NoneService struct {
handler Handler handler Handler
udp *udpnat.Service[string]
} }
func NewNoneService(handler Handler) Service { func NewNoneService(handler Handler) Service {
return &NoneService{ s := &NoneService{
handler: handler, handler: handler,
} }
s.udp = udpnat.New[string](s)
return s
} }
func (s *NoneService) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error { func (s *NoneService) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
@ -41,3 +44,37 @@ func (s *NoneService) NewConnection(ctx context.Context, conn net.Conn, metadata
metadata.Destination = destination metadata.Destination = destination
return s.handler.NewConnection(ctx, conn, metadata) return s.handler.NewConnection(ctx, conn, metadata)
} }
func (s *NoneService) NewPacket(conn socks.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error {
destination, err := socks.AddressSerializer.ReadAddrPort(buffer)
if err != nil {
return err
}
metadata.Protocol = "shadowsocks"
metadata.Destination = destination
return s.udp.NewPacket(metadata.Source.String(), func() socks.PacketWriter {
return &serverPacketWriter{conn, metadata.Source}
}, buffer, metadata)
}
type serverPacketWriter struct {
socks.PacketConn
sourceAddr *M.AddrPort
}
func (s *serverPacketWriter) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error {
header := buf.With(buffer.ExtendHeader(socks.AddressSerializer.AddrPortLen(destination)))
err := socks.AddressSerializer.WriteAddrPort(header, destination)
if err != nil {
return err
}
return s.PacketConn.WritePacket(buffer, s.sourceAddr)
}
func (s *NoneService) NewPacketConnection(conn socks.PacketConn, metadata M.Metadata) error {
return s.handler.NewPacketConnection(conn, metadata)
}
func (s *NoneService) HandleError(err error) {
s.handler.HandleError(err)
}

View file

@ -156,13 +156,13 @@ func (m *Method) DialEarlyConn(conn net.Conn, destination *M.AddrPort) net.Conn
} }
func (m *Method) DialPacketConn(conn net.Conn) socks.PacketConn { func (m *Method) DialPacketConn(conn net.Conn) socks.PacketConn {
return &clientPacketConn{conn, m} return &clientPacketConn{m, conn}
} }
func (m *Method) EncodePacket(buffer *buf.Buffer) error { func (m *Method) EncodePacket(buffer *buf.Buffer) error {
key := Kdf(m.key, buffer.To(m.keySaltLength), m.keySaltLength) key := Kdf(m.key, buffer.To(m.keySaltLength), m.keySaltLength)
c := m.constructor(common.Dup(key)) c := m.constructor(common.Dup(key))
c.Seal(buffer.From(m.keySaltLength)[:0], rw.ZeroBytes[:c.NonceSize()], buffer.From(m.keySaltLength), nil) c.Seal(buffer.Index(m.keySaltLength), rw.ZeroBytes[:c.NonceSize()], buffer.From(m.keySaltLength), nil)
buffer.Extend(c.Overhead()) buffer.Extend(c.Overhead())
return nil return nil
} }
@ -299,20 +299,18 @@ func (c *clientConn) ReadFrom(r io.Reader) (n int64, err error) {
} }
type clientPacketConn struct { type clientPacketConn struct {
*Method
net.Conn net.Conn
method *Method
} }
func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error { func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error {
_header := buf.StackNew() header := buffer.ExtendHeader(c.keySaltLength + socks.AddressSerializer.AddrPortLen(destination))
header := common.Dup(_header) common.Must1(io.ReadFull(c.secureRNG, header[:c.keySaltLength]))
common.Must1(header.ReadFullFrom(c.method.secureRNG, c.method.keySaltLength)) err := socks.AddressSerializer.WriteAddrPort(buf.With(header[c.keySaltLength:]), destination)
err := socks.AddressSerializer.WriteAddrPort(header, destination)
if err != nil { if err != nil {
return err return err
} }
buffer = buffer.WriteBufferAtFirst(header) err = c.EncodePacket(buffer)
err = c.method.EncodePacket(buffer)
if err != nil { if err != nil {
return err return err
} }
@ -325,7 +323,7 @@ func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
return nil, err return nil, err
} }
buffer.Truncate(n) buffer.Truncate(n)
err = c.method.DecodePacket(buffer) err = c.DecodePacket(buffer)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -13,6 +13,7 @@ import (
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/common/replay" "github.com/sagernet/sing/common/replay"
"github.com/sagernet/sing/common/rw" "github.com/sagernet/sing/common/rw"
"github.com/sagernet/sing/common/udpnat"
"github.com/sagernet/sing/protocol/shadowsocks" "github.com/sagernet/sing/protocol/shadowsocks"
"github.com/sagernet/sing/protocol/socks" "github.com/sagernet/sing/protocol/socks"
"golang.org/x/crypto/chacha20poly1305" "golang.org/x/crypto/chacha20poly1305"
@ -25,6 +26,7 @@ type Service struct {
key []byte key []byte
secureRNG io.Reader secureRNG io.Reader
replayFilter replay.Filter replayFilter replay.Filter
udp *udpnat.Service[string]
handler shadowsocks.Handler handler shadowsocks.Handler
} }
@ -34,6 +36,7 @@ func NewService(method string, key []byte, password []byte, secureRNG io.Reader,
secureRNG: secureRNG, secureRNG: secureRNG,
handler: handler, handler: handler,
} }
s.udp = udpnat.New[string](s)
if replayFilter { if replayFilter {
s.replayFilter = replay.NewBloomRing() s.replayFilter = replay.NewBloomRing()
} }
@ -163,3 +166,59 @@ func (c *serverConn) ReadFrom(r io.Reader) (n int64, err error) {
func (c *serverConn) WriteTo(w io.Writer) (n int64, err error) { func (c *serverConn) WriteTo(w io.Writer) (n int64, err error) {
return c.reader.WriteTo(w) return c.reader.WriteTo(w)
} }
func (s *Service) NewPacket(conn socks.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error {
if buffer.Len() < s.keySaltLength {
return E.New("bad packet")
}
key := Kdf(s.key, buffer.To(s.keySaltLength), s.keySaltLength)
c := s.constructor(common.Dup(key))
/*data := buf.New()
packet, err := c.Open(data.Index(0), rw.ZeroBytes[:c.NonceSize()], buffer.From(s.keySaltLength), nil)
if err != nil {
return err
}
data.Truncate(len(packet))
metadata.Protocol = "shadowsocks"
return s.udp.NewPacket(metadata.Source.String(), func() socks.PacketWriter {
return &serverPacketWriter{s, conn, metadata.Source}
}, data, metadata)*/
packet, err := c.Open(buffer.Index(s.keySaltLength), rw.ZeroBytes[:c.NonceSize()], buffer.From(s.keySaltLength), nil)
if err != nil {
return err
}
buffer.Advance(s.keySaltLength)
buffer.Truncate(len(packet))
metadata.Protocol = "shadowsocks"
return s.udp.NewPacket(metadata.Source.String(), func() socks.PacketWriter {
return &serverPacketWriter{s, conn, metadata.Source}
}, buffer, metadata)
}
type serverPacketWriter struct {
*Service
socks.PacketConn
source *M.AddrPort
}
func (w *serverPacketWriter) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error {
header := buffer.ExtendHeader(w.keySaltLength + socks.AddressSerializer.AddrPortLen(destination))
common.Must1(io.ReadFull(w.secureRNG, header[:w.keySaltLength]))
err := socks.AddressSerializer.WriteAddrPort(buf.With(header[w.keySaltLength:]), destination)
if err != nil {
return err
}
key := Kdf(w.key, buffer.To(w.keySaltLength), w.keySaltLength)
c := w.constructor(common.Dup(key))
c.Seal(buffer.From(w.keySaltLength)[:0], rw.ZeroBytes[:c.NonceSize()], buffer.From(w.keySaltLength), nil)
buffer.Extend(c.Overhead())
return w.PacketConn.WritePacket(buffer, w.source)
}
func (s *Service) NewPacketConnection(conn socks.PacketConn, metadata M.Metadata) error {
return s.handler.NewPacketConnection(conn, metadata)
}
func (s *Service) HandleError(err error) {
s.handler.HandleError(err)
}

View file

@ -109,9 +109,9 @@ func New(method string, pskList [][]byte, secureRNG io.Reader) (shadowsocks.Meth
} }
func Blake3DeriveKey(psk, salt []byte, keyLength int) []byte { func Blake3DeriveKey(psk, salt []byte, keyLength int) []byte {
sessionKey := make([]byte, 2*KeySaltSize) sessionKey := buf.Make(len(psk) + len(salt))
copy(sessionKey, psk) copy(sessionKey, psk)
copy(sessionKey[KeySaltSize:], salt) copy(sessionKey[len(psk):], salt)
outKey := buf.Make(keyLength) outKey := buf.Make(keyLength)
blake3.DeriveKey(outKey, "shadowsocks 2022 session subkey", sessionKey) blake3.DeriveKey(outKey, "shadowsocks 2022 session subkey", sessionKey)
return outKey return outKey
@ -434,6 +434,9 @@ func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPo
return err return err
} }
buffer = buffer.WriteBufferAtFirst(header) buffer = buffer.WriteBufferAtFirst(header)
if err != nil {
return err
}
if c.method.udpCipher != nil { if c.method.udpCipher != nil {
c.method.udpCipher.Seal(buffer.Index(dataIndex), buffer.To(dataIndex), buffer.From(dataIndex), nil) c.method.udpCipher.Seal(buffer.Index(dataIndex), buffer.To(dataIndex), buffer.From(dataIndex), nil)
buffer.Extend(c.method.udpCipher.Overhead()) buffer.Extend(c.method.udpCipher.Overhead())
@ -574,9 +577,9 @@ func (s *udpSession) nextPacketId() uint64 {
} }
func (m *Method) newUDPSession() *udpSession { func (m *Method) newUDPSession() *udpSession {
session := &udpSession{ session := &udpSession{}
sessionId: rand.Uint64(), common.Must(binary.Read(m.secureRNG, binary.BigEndian, &session.sessionId))
} session.packetId--
if m.udpCipher == nil { if m.udpCipher == nil {
sessionId := make([]byte, 8) sessionId := make([]byte, 8)
binary.BigEndian.PutUint64(sessionId, session.sessionId) binary.BigEndian.PutUint64(sessionId, session.sessionId)

View file

@ -2,22 +2,28 @@ package shadowaead_2022
import ( import (
"context" "context"
"crypto/aes"
"crypto/cipher" "crypto/cipher"
"encoding/binary" "encoding/binary"
"io" "io"
"math" "math"
"net" "net"
"sync" "sync"
"sync/atomic"
"time" "time"
"github.com/sagernet/sing/common" "github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/gsync"
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/common/replay" "github.com/sagernet/sing/common/replay"
"github.com/sagernet/sing/common/rw" "github.com/sagernet/sing/common/rw"
"github.com/sagernet/sing/common/udpnat"
"github.com/sagernet/sing/protocol/shadowsocks" "github.com/sagernet/sing/protocol/shadowsocks"
"github.com/sagernet/sing/protocol/shadowsocks/shadowaead" "github.com/sagernet/sing/protocol/shadowsocks/shadowaead"
"github.com/sagernet/sing/protocol/socks" "github.com/sagernet/sing/protocol/socks"
wgReplay "golang.zx2c4.com/wireguard/replay"
) )
type Service struct { type Service struct {
@ -25,9 +31,14 @@ type Service struct {
secureRNG io.Reader secureRNG io.Reader
keyLength int keyLength int
constructor func(key []byte) cipher.AEAD constructor func(key []byte) cipher.AEAD
blockConstructor func(key []byte) cipher.Block
udpCipher cipher.AEAD
udpBlockCipher cipher.Block
psk []byte psk []byte
replayFilter replay.Filter replayFilter replay.Filter
handler shadowsocks.Handler handler shadowsocks.Handler
udpNat *udpnat.Service[uint64]
sessions gsync.Map[uint64, *serverUDPSession]
} }
func NewService(method string, psk []byte, secureRNG io.Reader, handler shadowsocks.Handler) (shadowsocks.Service, error) { func NewService(method string, psk []byte, secureRNG io.Reader, handler shadowsocks.Handler) (shadowsocks.Service, error) {
@ -47,18 +58,20 @@ func NewService(method string, psk []byte, secureRNG io.Reader, handler shadowso
case "2022-blake3-aes-128-gcm": case "2022-blake3-aes-128-gcm":
s.keyLength = 16 s.keyLength = 16
s.constructor = newAESGCM s.constructor = newAESGCM
// m.blockConstructor = newAES s.blockConstructor = newAES
// m.udpBlockCipher = newAES(m.psk) s.udpBlockCipher = newAES(s.psk)
case "2022-blake3-aes-256-gcm": case "2022-blake3-aes-256-gcm":
s.keyLength = 32 s.keyLength = 32
s.constructor = newAESGCM s.constructor = newAESGCM
// m.blockConstructor = newAES s.blockConstructor = newAES
// m.udpBlockCipher = newAES(m.psk) s.udpBlockCipher = newAES(s.psk)
case "2022-blake3-chacha20-poly1305": case "2022-blake3-chacha20-poly1305":
s.keyLength = 32 s.keyLength = 32
s.constructor = newChacha20Poly1305 s.constructor = newChacha20Poly1305
// m.udpCipher = newXChacha20Poly1305(m.psk) s.udpCipher = newXChacha20Poly1305(s.psk)
} }
s.udpNat = udpnat.New[uint64](s)
return s, nil return s, nil
} }
@ -194,3 +207,169 @@ func (c *serverConn) ReadFrom(r io.Reader) (n int64, err error) {
func (c *serverConn) WriteTo(w io.Writer) (n int64, err error) { func (c *serverConn) WriteTo(w io.Writer) (n int64, err error) {
return c.reader.WriteTo(w) return c.reader.WriteTo(w)
} }
func (s *Service) NewPacket(conn socks.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error {
var packetHeader []byte
if s.udpCipher != nil {
_, err := s.udpCipher.Open(buffer.Index(PacketNonceSize), buffer.To(PacketNonceSize), buffer.From(PacketNonceSize), nil)
if err != nil {
return E.Cause(err, "decrypt packet header")
}
buffer.Advance(PacketNonceSize)
} else {
packetHeader = buffer.To(aes.BlockSize)
s.udpBlockCipher.Decrypt(packetHeader, packetHeader)
}
var sessionId, packetId uint64
err := binary.Read(buffer, binary.BigEndian, &sessionId)
if err != nil {
return err
}
err = binary.Read(buffer, binary.BigEndian, &packetId)
if err != nil {
return err
}
session, loaded := s.sessions.LoadOrStore(sessionId, s.newUDPSession)
if !loaded {
session.remoteSessionId = sessionId
if packetHeader != nil {
key := Blake3DeriveKey(s.psk, packetHeader[:8], s.keyLength)
session.remoteCipher = s.constructor(common.Dup(key))
}
}
if !session.filter.ValidateCounter(packetId, math.MaxUint64) {
return ErrPacketIdNotUnique
}
if packetHeader != nil {
_, err = session.remoteCipher.Open(buffer.Index(0), packetHeader[4:16], buffer.Bytes(), nil)
if err != nil {
return E.Cause(err, "decrypt packet")
}
}
var headerType byte
headerType, err = buffer.ReadByte()
if err != nil {
return err
}
if headerType != HeaderTypeClient {
return ErrBadHeaderType
}
var epoch uint64
err = binary.Read(buffer, binary.BigEndian, &epoch)
if err != nil {
return err
}
if math.Abs(float64(uint64(time.Now().Unix())-epoch)) > 30 {
return ErrBadTimestamp
}
var paddingLength uint16
err = binary.Read(buffer, binary.BigEndian, &paddingLength)
if err != nil {
return E.Cause(err, "read padding length")
}
buffer.Advance(int(paddingLength))
destination, err := socks.AddressSerializer.ReadAddrPort(buffer)
if err != nil {
return err
}
metadata.Destination = destination
return s.udpNat.NewPacket(sessionId, func() socks.PacketWriter {
return &serverPacketWriter{s, conn, session, metadata.Source}
}, buffer, metadata)
}
type serverPacketWriter struct {
*Service
socks.PacketConn
session *serverUDPSession
source *M.AddrPort
}
func (w *serverPacketWriter) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error {
defer buffer.Release()
_header := buf.StackNew()
header := common.Dup(_header)
var dataIndex int
if w.udpCipher != nil {
common.Must1(header.ReadFullFrom(w.secureRNG, PacketNonceSize))
dataIndex = buffer.Len()
} else {
dataIndex = aes.BlockSize
}
common.Must(
binary.Write(header, binary.BigEndian, w.session.sessionId),
binary.Write(header, binary.BigEndian, w.session.nextPacketId()),
header.WriteByte(HeaderTypeServer),
binary.Write(header, binary.BigEndian, uint64(time.Now().Unix())),
binary.Write(header, binary.BigEndian, w.session.remoteSessionId),
binary.Write(header, binary.BigEndian, uint16(0)), // padding length
)
err := socks.AddressSerializer.WriteAddrPort(header, destination)
if err != nil {
return err
}
_, err = header.Write(buffer.Bytes())
if err != nil {
return err
}
if w.udpCipher != nil {
w.udpCipher.Seal(header.Index(dataIndex), header.To(dataIndex), header.From(dataIndex), nil)
header.Extend(w.udpCipher.Overhead())
} else {
packetHeader := header.To(aes.BlockSize)
w.session.cipher.Seal(header.Index(dataIndex), packetHeader[4:16], header.From(dataIndex), nil)
header.Extend(w.session.cipher.Overhead())
w.udpBlockCipher.Encrypt(packetHeader, packetHeader)
}
return w.PacketConn.WritePacket(header, w.source)
}
type serverUDPSession struct {
sessionId uint64
remoteSessionId uint64
packetId uint64
cipher cipher.AEAD
remoteCipher cipher.AEAD
filter wgReplay.Filter
}
func (s *serverUDPSession) nextPacketId() uint64 {
return atomic.AddUint64(&s.packetId, 1)
}
func (m *Service) newUDPSession() *serverUDPSession {
session := &serverUDPSession{}
common.Must(binary.Read(m.secureRNG, binary.BigEndian, &session.sessionId))
session.packetId--
if m.udpCipher == nil {
sessionId := make([]byte, 8)
binary.BigEndian.PutUint64(sessionId, session.sessionId)
key := Blake3DeriveKey(m.psk, sessionId, m.keyLength)
session.cipher = m.constructor(common.Dup(key))
}
return session
}
func (s *Service) NewPacketConnection(conn socks.PacketConn, metadata M.Metadata) error {
return s.handler.NewPacketConnection(conn, metadata)
}
func (s *Service) HandleError(err error) {
s.handler.HandleError(err)
}

View file

@ -8,12 +8,21 @@ import (
"github.com/sagernet/sing/common" "github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/buf"
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/common/rw"
"github.com/sagernet/sing/common/task" "github.com/sagernet/sing/common/task"
) )
type PacketConn interface { type PacketReader interface {
ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error)
WritePacket(buffer *buf.Buffer, addrPort *M.AddrPort) error }
type PacketWriter interface {
WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error
}
type PacketConn interface {
PacketReader
PacketWriter
Close() error Close() error
LocalAddr() net.Addr LocalAddr() net.Addr
@ -23,6 +32,10 @@ type PacketConn interface {
SetWriteDeadline(t time.Time) error SetWriteDeadline(t time.Time) error
} }
type UDPHandler interface {
NewPacket(conn PacketConn, buffer *buf.Buffer, metadata M.Metadata) error
}
type UDPConnectionHandler interface { type UDPConnectionHandler interface {
NewPacketConnection(conn PacketConn, metadata M.Metadata) error NewPacketConnection(conn PacketConn, metadata M.Metadata) error
} }
@ -47,6 +60,8 @@ func (s *PacketConnStub) SetWriteDeadline(t time.Time) error {
func CopyPacketConn(ctx context.Context, dest PacketConn, conn PacketConn) error { func CopyPacketConn(ctx context.Context, dest PacketConn, conn PacketConn) error {
return task.Run(ctx, func() error { return task.Run(ctx, func() error {
defer rw.CloseRead(conn)
defer rw.CloseWrite(dest)
_buffer := buf.StackNewMax() _buffer := buf.StackNewMax()
buffer := common.Dup(_buffer) buffer := common.Dup(_buffer)
data := buffer.Cut(buf.ReversedHeader, buf.ReversedHeader) data := buffer.Cut(buf.ReversedHeader, buf.ReversedHeader)
@ -56,13 +71,15 @@ func CopyPacketConn(ctx context.Context, dest PacketConn, conn PacketConn) error
if err != nil { if err != nil {
return err return err
} }
buffer.Truncate(data.Len()) buffer.Resize(buf.ReversedHeader+data.Start(), data.Len())
err = dest.WritePacket(buffer, destination) err = dest.WritePacket(buffer, destination)
if err != nil { if err != nil {
return err return err
} }
} }
}, func() error { }, func() error {
defer rw.CloseRead(dest)
defer rw.CloseWrite(conn)
_buffer := buf.StackNewMax() _buffer := buf.StackNewMax()
buffer := common.Dup(_buffer) buffer := common.Dup(_buffer)
data := buffer.Cut(buf.ReversedHeader, buf.ReversedHeader) data := buffer.Cut(buf.ReversedHeader, buf.ReversedHeader)
@ -72,7 +89,7 @@ func CopyPacketConn(ctx context.Context, dest PacketConn, conn PacketConn) error
if err != nil { if err != nil {
return err return err
} }
buffer.Truncate(data.Len()) buffer.Resize(buf.ReversedHeader+data.Start(), data.Len())
err = conn.WritePacket(buffer, destination) err = conn.WritePacket(buffer, destination)
if err != nil { if err != nil {
return err return err
@ -81,44 +98,211 @@ func CopyPacketConn(ctx context.Context, dest PacketConn, conn PacketConn) error
}) })
} }
func CopyPacketConn0(dest PacketConn, conn PacketConn, onAction func(destination *M.AddrPort, n int)) error { func CopyNetPacketConn(ctx context.Context, dest net.PacketConn, conn PacketConn) error {
return task.Run(ctx, func() error {
defer rw.CloseRead(conn)
defer rw.CloseWrite(dest)
_buffer := buf.StackNew()
buffer := common.Dup(_buffer)
for { for {
buffer := buf.New() buffer.FullReset()
destination, err := conn.ReadPacket(buffer) destination, err := conn.ReadPacket(buffer)
if err != nil { if err != nil {
buffer.Release()
return err return err
} }
size := buffer.Len()
err = dest.WritePacket(buffer, destination)
if err != nil {
buffer.Release()
return err
}
if onAction != nil {
onAction(destination, size)
}
}
}
type associatePacketConn struct { _, err = dest.WriteTo(buffer.Bytes(), destination.UDPAddr())
net.PacketConn if err != nil {
return err
}
}
}, func() error {
defer rw.CloseRead(dest)
defer rw.CloseWrite(conn)
_buffer := buf.StackNew()
buffer := common.Dup(_buffer)
data := buffer.Cut(buf.ReversedHeader, buf.ReversedHeader)
for {
data.FullReset()
n, addr, err := dest.ReadFrom(data.FreeBytes())
if err != nil {
return err
}
buffer.Resize(buf.ReversedHeader, n)
err = conn.WritePacket(buffer, M.AddrPortFromNetAddr(addr))
if err != nil {
return err
}
}
})
}
type AssociateConn struct {
net.Conn
conn net.Conn conn net.Conn
addr net.Addr addr net.Addr
dest *M.AddrPort
} }
func NewPacketConn(conn net.Conn, packetConn net.PacketConn) PacketConn { func NewAssociateConn(conn net.Conn, packetConn net.Conn, destination *M.AddrPort) net.PacketConn {
return &associatePacketConn{ return &AssociateConn{
PacketConn: packetConn, Conn: packetConn,
conn: conn, conn: conn,
dest: destination,
} }
} }
func (c *associatePacketConn) RemoteAddr() net.Addr { func (c *AssociateConn) RemoteAddr() net.Addr {
return c.addr return c.addr
} }
func (c *associatePacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) { func (c *AssociateConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
n, err = c.Conn.Read(p)
if err != nil {
return
}
reader := buf.As(p[3:n])
destination, err := AddressSerializer.ReadAddrPort(reader)
if err != nil {
return
}
addr = destination.UDPAddr()
n = copy(p, reader.Bytes())
return
}
func (c *AssociateConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
_buffer := buf.StackNew()
buffer := common.Dup(_buffer)
common.Must(buffer.WriteZeroN(3))
err = AddressSerializer.WriteAddrPort(buffer, M.AddrPortFromNetAddr(addr))
if err != nil {
return
}
_, err = buffer.Write(p)
if err != nil {
return
}
_, err = c.Conn.Write(buffer.Bytes())
return
}
func (c *AssociateConn) Read(b []byte) (n int, err error) {
n, _, err = c.ReadFrom(b)
return
}
func (c *AssociateConn) Write(b []byte) (n int, err error) {
_buffer := buf.StackNew()
buffer := common.Dup(_buffer)
common.Must(buffer.WriteZeroN(3))
err = AddressSerializer.WriteAddrPort(buffer, c.dest)
if err != nil {
return
}
_, err = buffer.Write(b)
if err != nil {
return
}
_, err = c.Conn.Write(buffer.Bytes())
return
}
func (c *AssociateConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
n, err := buffer.ReadFrom(c.conn)
if err != nil {
return nil, err
}
buffer.Truncate(int(n))
buffer.Advance(3)
return AddressSerializer.ReadAddrPort(buffer)
}
func (c *AssociateConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error {
defer buffer.Release()
header := buf.With(buffer.ExtendHeader(3 + AddressSerializer.AddrPortLen(destination)))
common.Must(header.WriteZeroN(3))
common.Must(AddressSerializer.WriteAddrPort(header, destination))
return common.Error(c.Conn.Write(buffer.Bytes()))
}
type AssociatePacketConn struct {
net.PacketConn
conn net.Conn
addr net.Addr
dest *M.AddrPort
}
func NewAssociatePacketConn(conn net.Conn, packetConn net.PacketConn, destination *M.AddrPort) *AssociatePacketConn {
return &AssociatePacketConn{
PacketConn: packetConn,
conn: conn,
dest: destination,
}
}
func (c *AssociatePacketConn) RemoteAddr() net.Addr {
return c.addr
}
func (c *AssociatePacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
n, addr, err = c.PacketConn.ReadFrom(p)
if err != nil {
return
}
reader := buf.As(p[3:n])
destination, err := AddressSerializer.ReadAddrPort(reader)
if err != nil {
return
}
addr = destination.UDPAddr()
n = copy(p, reader.Bytes())
return
}
func (c *AssociatePacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
_buffer := buf.StackNew()
buffer := common.Dup(_buffer)
common.Must(buffer.WriteZeroN(3))
err = AddressSerializer.WriteAddrPort(buffer, M.AddrPortFromNetAddr(addr))
if err != nil {
return
}
_, err = buffer.Write(p)
if err != nil {
return
}
_, err = c.PacketConn.WriteTo(buffer.Bytes(), c.addr)
return
}
func (c *AssociatePacketConn) Read(b []byte) (n int, err error) {
n, _, err = c.ReadFrom(b)
return
}
func (c *AssociatePacketConn) Write(b []byte) (n int, err error) {
_buffer := buf.StackNew()
buffer := common.Dup(_buffer)
common.Must(buffer.WriteZeroN(3))
err = AddressSerializer.WriteAddrPort(buffer, c.dest)
if err != nil {
return
}
_, err = buffer.Write(b)
if err != nil {
return
}
_, err = c.PacketConn.WriteTo(buffer.Bytes(), c.addr)
return
}
func (c *AssociatePacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
n, addr, err := c.PacketConn.ReadFrom(buffer.FreeBytes()) n, addr, err := c.PacketConn.ReadFrom(buffer.FreeBytes())
if err != nil { if err != nil {
return nil, err return nil, err
@ -126,15 +310,14 @@ func (c *associatePacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error
c.addr = addr c.addr = addr
buffer.Truncate(n) buffer.Truncate(n)
buffer.Advance(3) buffer.Advance(3)
return AddressSerializer.ReadAddrPort(buffer) dest, err := AddressSerializer.ReadAddrPort(buffer)
return dest, err
} }
func (c *associatePacketConn) WritePacket(buffer *buf.Buffer, addrPort *M.AddrPort) error { func (c *AssociatePacketConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error {
defer buffer.Release() defer buffer.Release()
_header := buf.StackNew() header := buf.With(buffer.ExtendHeader(3 + AddressSerializer.AddrPortLen(destination)))
header := common.Dup(_header)
common.Must(header.WriteZeroN(3)) common.Must(header.WriteZeroN(3))
common.Must(AddressSerializer.WriteAddrPort(header, addrPort)) common.Must(AddressSerializer.WriteAddrPort(header, destination))
buffer = buffer.WriteBufferAtFirst(header)
return common.Error(c.PacketConn.WriteTo(buffer.Bytes(), c.addr)) return common.Error(c.PacketConn.WriteTo(buffer.Bytes(), c.addr))
} }

View file

@ -36,7 +36,7 @@ func NewListener(bind netip.AddrPort, authenticator auth.Authenticator, handler
} }
func (l *Listener) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error { func (l *Listener) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
return HandleConnection(ctx, conn, l.authenticator, l.bindAddr, l.handler, metadata) return HandleConnection(ctx, conn, l.authenticator, M.AddrPortFromNetAddr(conn.LocalAddr()).Addr.Addr(), l.handler, metadata)
} }
func (l *Listener) Start() error { func (l *Listener) Start() error {
@ -131,9 +131,10 @@ func HandleConnection(ctx context.Context, conn net.Conn, authenticator auth.Aut
if err != nil { if err != nil {
return E.Cause(err, "write socks response") return E.Cause(err, "write socks response")
} }
metadata.Protocol = "socks"
metadata.Destination = request.Destination metadata.Destination = request.Destination
go func() { go func() {
err := handler.NewPacketConnection(NewPacketConn(conn, udpConn), metadata) err := handler.NewPacketConnection(NewAssociatePacketConn(conn, udpConn, request.Destination), metadata)
if err != nil { if err != nil {
handler.HandleError(err) handler.HandleError(err)
} }

View file

@ -32,7 +32,7 @@ type Listener struct {
bindAddr netip.Addr bindAddr netip.Addr
handler Handler handler Handler
authenticator auth.Authenticator authenticator auth.Authenticator
udpNat *udpnat.Server udpNat *udpnat.Service[string]
} }
func NewListener(bind netip.AddrPort, authenticator auth.Authenticator, transproxy redir.TransproxyMode, handler Handler) *Listener { func NewListener(bind netip.AddrPort, authenticator auth.Authenticator, transproxy redir.TransproxyMode, handler Handler) *Listener {
@ -45,7 +45,7 @@ func NewListener(bind netip.AddrPort, authenticator auth.Authenticator, transpro
listener.TCPListener = tcp.NewTCPListener(bind, listener, tcp.WithTransproxyMode(transproxy)) listener.TCPListener = tcp.NewTCPListener(bind, listener, tcp.WithTransproxyMode(transproxy))
if transproxy == redir.ModeTProxy { if transproxy == redir.ModeTProxy {
listener.UDPListener = udp.NewUDPListener(bind, listener, udp.WithTransproxyMode(transproxy)) listener.UDPListener = udp.NewUDPListener(bind, listener, udp.WithTransproxyMode(transproxy))
listener.udpNat = udpnat.NewServer(handler) listener.udpNat = udpnat.New[string](handler)
} }
return listener return listener
} }
@ -63,7 +63,7 @@ func (l *Listener) NewConnection(ctx context.Context, conn net.Conn, metadata M.
case socks.Version4: case socks.Version4:
return E.New("socks4 request dropped (TODO)") return E.New("socks4 request dropped (TODO)")
case socks.Version5: case socks.Version5:
return socks.HandleConnection(ctx, bufConn, l.authenticator, l.bindAddr, l.handler, metadata) return socks.HandleConnection(ctx, bufConn, l.authenticator, M.AddrPortFromNetAddr(conn.LocalAddr()).Addr.Addr(), l.handler, metadata)
} }
request, err := http.ReadRequest(bufConn.Reader()) request, err := http.ReadRequest(bufConn.Reader())
@ -96,8 +96,23 @@ func (l *Listener) NewConnection(ctx context.Context, conn net.Conn, metadata M.
return http.HandleRequest(ctx, request, bufConn, l.authenticator, l.handler, metadata) return http.HandleRequest(ctx, request, bufConn, l.authenticator, l.handler, metadata)
} }
func (l *Listener) NewPacket(packet *buf.Buffer, metadata M.Metadata) error { func (l *Listener) NewPacket(conn socks.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error {
return l.udpNat.HandleUDP(packet, metadata) return l.udpNat.NewPacket(metadata.Source.String(), func() socks.PacketWriter {
return &tproxyPacketWriter{metadata.Source.UDPAddr()}
}, buffer, metadata)
}
type tproxyPacketWriter struct {
source *net.UDPAddr
}
func (w *tproxyPacketWriter) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error {
udpConn, err := redir.DialUDP("udp", destination.UDPAddr(), w.source)
if err != nil {
return E.Cause(err, "tproxy udp write back")
}
defer udpConn.Close()
return common.Error(udpConn.Write(buffer.Bytes()))
} }
func (l *Listener) HandleError(err error) { func (l *Listener) HandleError(err error) {

View file

@ -9,10 +9,11 @@ import (
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/common/redir" "github.com/sagernet/sing/common/redir"
"github.com/sagernet/sing/protocol/socks"
) )
type Handler interface { type Handler interface {
M.UDPHandler socks.UDPHandler
E.Handler E.Handler
} }
@ -24,6 +25,19 @@ type Listener struct {
tproxy bool tproxy bool
} }
func (l *Listener) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
n, addr, err := l.ReadFromUDP(buffer.FreeBytes())
if err != nil {
return nil, err
}
buffer.Truncate(n)
return M.AddrPortFromNetAddr(addr), nil
}
func (l *Listener) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error {
return common.Error(l.UDPConn.WriteTo(buffer.Bytes(), destination.UDPAddr()))
}
func NewUDPListener(listen netip.AddrPort, handler Handler, options ...Option) *Listener { func NewUDPListener(listen netip.AddrPort, handler Handler, options ...Option) *Listener {
listener := &Listener{ listener := &Listener{
handler: handler, handler: handler,
@ -69,32 +83,31 @@ func (l *Listener) Close() error {
} }
func (l *Listener) loop() { func (l *Listener) loop() {
_buffer := buf.StackNewMax()
buffer := common.Dup(_buffer)
data := buffer.Cut(buf.ReversedHeader, buf.ReversedHeader).Slice()
if !l.tproxy { if !l.tproxy {
for { for {
buffer := buf.New() n, addr, err := l.ReadFromUDP(data)
n, addr, err := l.ReadFromUDP(buffer.Extend(buf.UDPBufferSize))
if err != nil { if err != nil {
buffer.Release()
l.handler.HandleError(err) l.handler.HandleError(err)
return return
} }
buffer.Truncate(n) buffer.Resize(buf.ReversedHeader, n)
err = l.handler.NewPacket(buffer, M.Metadata{ err = l.handler.NewPacket(l, buffer, M.Metadata{
Protocol: "udp", Protocol: "udp",
Source: M.AddrPortFromNetAddr(addr), Source: M.AddrPortFromNetAddr(addr),
}) })
if err != nil { if err != nil {
buffer.Release()
l.handler.HandleError(err) l.handler.HandleError(err)
} }
} }
} else { } else {
oob := make([]byte, 1024) _oob := make([]byte, 1024)
oob := common.Dup(_oob)
for { for {
buffer := buf.New() n, oobN, _, addr, err := l.ReadMsgUDPAddrPort(data, oob)
n, oobN, _, addr, err := l.ReadMsgUDPAddrPort(buffer.FreeBytes(), oob)
if err != nil { if err != nil {
buffer.Release()
l.handler.HandleError(err) l.handler.HandleError(err)
return return
} }
@ -103,14 +116,13 @@ func (l *Listener) loop() {
l.handler.HandleError(E.Cause(err, "get original destination")) l.handler.HandleError(E.Cause(err, "get original destination"))
return return
} }
buffer.Truncate(n) buffer.Resize(buf.ReversedHeader, n)
err = l.handler.NewPacket(buffer, M.Metadata{ err = l.handler.NewPacket(l, buffer, M.Metadata{
Protocol: "tproxy", Protocol: "tproxy",
Source: M.AddrPortFromAddrPort(addr), Source: M.AddrPortFromAddrPort(addr),
Destination: destination, Destination: destination,
}) })
if err != nil { if err != nil {
buffer.Release()
l.handler.HandleError(err) l.handler.HandleError(err)
} }
} }