Add shadowsocks service

This commit is contained in:
世界 2022-04-29 12:06:10 +08:00
parent 5be6eb2d64
commit df1e1cfafd
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
21 changed files with 1028 additions and 265 deletions

View file

@ -28,6 +28,12 @@ wget 'https://github.com/Dreamacro/maxmind-geoip/releases/latest/download/Countr
go install -v -trimpath -ldflags "-s -w -buildid=" ./cli/ss-local
```
### ss-server
```shell
go install -v -trimpath -ldflags "-s -w -buildid=" ./cli/ss-server
```
### ddns
```shell

145
cli/socks-chk/main.go Normal file
View file

@ -0,0 +1,145 @@
package main
import (
"context"
"encoding/binary"
"io"
"net"
"net/netip"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/protocol/socks"
"github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"golang.org/x/net/dns/dnsmessage"
)
func main() {
command := &cobra.Command{
Use: "socks-chk address:port",
Args: cobra.ExactArgs(1),
Run: run,
}
if err := command.Execute(); err != nil {
logrus.Fatal(err)
}
}
func run(cmd *cobra.Command, args []string) {
server, err := M.ParseAddress(args[0])
if err != nil {
logrus.Fatal("invalid server address ", args[0])
}
err = testSocksTCP(server)
if err != nil {
logrus.Fatal(err)
}
err = testSocksUDP(server)
if err != nil {
logrus.Fatal(err)
}
}
func testSocksTCP(server *M.AddrPort) error {
tcpConn, err := net.Dial("tcp", server.String())
if err != nil {
return err
}
response, err := socks.ClientHandshake(tcpConn, socks.Version5, socks.CommandConnect, M.AddrPortFrom(M.ParseAddr("1.0.0.1"), 53), "", "")
if err != nil {
return err
}
if response.ReplyCode != socks.ReplyCodeSuccess {
logrus.Fatal("socks tcp handshake failure: ", response.ReplyCode)
}
message := &dnsmessage.Message{}
message.Header.ID = 1
message.Header.RecursionDesired = true
message.Questions = append(message.Questions, dnsmessage.Question{
Name: dnsmessage.MustNewName("google.com."),
Type: dnsmessage.TypeA,
Class: dnsmessage.ClassINET,
})
packet, err := message.Pack()
err = binary.Write(tcpConn, binary.BigEndian, uint16(len(packet)))
if err != nil {
return err
}
_, err = tcpConn.Write(packet)
if err != nil {
return err
}
var respLen uint16
err = binary.Read(tcpConn, binary.BigEndian, &respLen)
if err != nil {
return err
}
respBuf := buf.Make(int(respLen))
_, err = io.ReadFull(tcpConn, respBuf)
if err != nil {
return err
}
common.Must(message.Unpack(respBuf))
for _, answer := range message.Answers {
logrus.Info("tcp got answer: ", netip.AddrFrom4(answer.Body.(*dnsmessage.AResource).A))
}
tcpConn.Close()
return nil
}
func testSocksUDP(server *M.AddrPort) error {
tcpConn, err := net.Dial("tcp", server.String())
if err != nil {
return err
}
dest := M.AddrPortFrom(M.ParseAddr("1.0.0.1"), 53)
response, err := socks.ClientHandshake(tcpConn, socks.Version5, socks.CommandUDPAssociate, dest, "", "")
if err != nil {
return err
}
if response.ReplyCode != socks.ReplyCodeSuccess {
logrus.Fatal("socks tcp handshake failure: ", response.ReplyCode)
}
var dialer net.Dialer
udpConn, err := dialer.DialContext(context.Background(), "udp", response.Bind.String())
if err != nil {
return err
}
assConn := socks.NewAssociateConn(tcpConn, udpConn, dest)
message := &dnsmessage.Message{}
message.Header.ID = 1
message.Header.RecursionDesired = true
message.Questions = append(message.Questions, dnsmessage.Question{
Name: dnsmessage.MustNewName("google.com."),
Type: dnsmessage.TypeA,
Class: dnsmessage.ClassINET,
})
packet, err := message.Pack()
common.Must(err)
common.Must1(assConn.WriteTo(packet, &net.UDPAddr{
IP: net.IPv4(1, 0, 0, 1),
Port: 53,
}))
_buffer := buf.StackNew()
buffer := common.Dup(_buffer)
common.Must2(buffer.ReadPacketFrom(assConn))
common.Must(message.Unpack(buffer.Bytes()))
for _, answer := range message.Answers {
logrus.Info("udp got answer: ", netip.AddrFrom4(answer.Body.(*dnsmessage.AResource).A))
}
udpConn.Close()
tcpConn.Close()
return nil
}

View file

@ -345,28 +345,15 @@ func (c *client) NewConnection(ctx context.Context, conn net.Conn, metadata M.Me
return rw.CopyConn(ctx, serverConn, conn)
}
func (c *client) NewPacketConnection(conn socks.PacketConn, _ M.Metadata) error {
func (c *client) NewPacketConnection(conn socks.PacketConn, metadata M.Metadata) error {
logrus.Info("outbound ", metadata.Protocol, " UDP ", metadata.Source, " ==> ", metadata.Destination)
ctx := context.Background()
udpConn, err := c.dialer.DialContext(ctx, "udp", c.server.String())
if err != nil {
return err
}
serverConn := c.method.DialPacketConn(udpConn)
return task.Run(ctx, func() error {
var init bool
return socks.CopyPacketConn0(serverConn, conn, func(destination *M.AddrPort, n int) {
if !init {
init = true
logrus.Info("UDP ", conn.LocalAddr(), " ==> ", destination)
} else {
logrus.Trace("UDP ", conn.LocalAddr(), " ==> ", destination)
}
})
}, func() error {
return socks.CopyPacketConn0(conn, serverConn, func(destination *M.AddrPort, n int) {
logrus.Trace("UDP ", conn.LocalAddr(), " <== ", destination)
})
})
return socks.CopyPacketConn(ctx, serverConn, conn)
}
func run(cmd *cobra.Command, flags *flags) {

View file

@ -3,6 +3,8 @@ package main
import (
"context"
"encoding/base64"
"encoding/json"
"io/ioutil"
"net"
"net/netip"
"os"
@ -12,22 +14,29 @@ import (
"github.com/sagernet/sing"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/gsync"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/random"
"github.com/sagernet/sing/common/rw"
"github.com/sagernet/sing/protocol/shadowsocks"
"github.com/sagernet/sing/protocol/shadowsocks/shadowaead"
"github.com/sagernet/sing/protocol/shadowsocks/shadowaead_2022"
"github.com/sagernet/sing/protocol/socks"
"github.com/sagernet/sing/transport/tcp"
"github.com/sagernet/sing/transport/udp"
"github.com/sirupsen/logrus"
"github.com/spf13/cobra"
)
type flags struct {
Bind string `json:"local_address"`
LocalPort uint16 `json:"local_port"`
// Password string `json:"password"`
Server string `json:"server"`
ServerPort uint16 `json:"server_port"`
Bind string `json:"local_address"`
LocalPort uint16 `json:"local_port"`
Password string `json:"password"`
Key string `json:"key"`
Method string `json:"method"`
Verbose bool `json:"verbose"`
@ -35,8 +44,6 @@ type flags struct {
}
func main() {
logrus.SetLevel(logrus.TraceLevel)
f := new(flags)
command := &cobra.Command{
@ -48,12 +55,16 @@ func main() {
},
}
command.Flags().StringVarP(&f.Server, "server", "s", "", "Set the server’s hostname or IP.")
command.Flags().Uint16VarP(&f.ServerPort, "server-port", "p", 0, "Set the server’s port number.")
command.Flags().StringVarP(&f.Bind, "local-address", "b", "", "Set the local address.")
command.Flags().Uint16VarP(&f.LocalPort, "local-port", "l", 0, "Set the local port number.")
command.Flags().StringVarP(&f.Key, "key", "k", "", "Set the key directly. The key should be encoded with URL-safe Base64.")
command.Flags().StringVar(&f.Key, "key", "", "Set the key directly. The key should be encoded with URL-safe Base64.")
command.Flags().StringVarP(&f.Password, "password", "k", "", "Set the password. The server and the client should use the same password.")
var supportedCiphers []string
supportedCiphers = append(supportedCiphers, shadowsocks.MethodNone)
supportedCiphers = append(supportedCiphers, shadowaead.List...)
supportedCiphers = append(supportedCiphers, shadowaead_2022.List...)
command.Flags().StringVarP(&f.Method, "encrypt-method", "m", "", "Set the cipher.\n\nSupported ciphers:\n\n"+strings.Join(supportedCiphers, "\n"))
@ -75,6 +86,10 @@ func run(cmd *cobra.Command, f *flags) {
if err != nil {
logrus.Fatal(err)
}
err = s.udpIn.Start()
if err != nil {
logrus.Fatal(err)
}
logrus.Info("server started at ", s.tcpIn.TCPListener.Addr())
osSignals := make(chan os.Signal, 1)
@ -82,33 +97,101 @@ func run(cmd *cobra.Command, f *flags) {
<-osSignals
s.tcpIn.Close()
s.udpIn.Close()
}
type server struct {
tcpIn *tcp.Listener
udpIn *udp.Listener
service shadowsocks.Service
udpNat gsync.Map[string, *net.UDPConn]
}
func (s *server) Start() error {
err := s.tcpIn.Start()
if err != nil {
return err
}
err = s.udpIn.Start()
return err
}
func (s *server) Close() error {
s.tcpIn.Close()
s.udpIn.Close()
return nil
}
func newServer(f *flags) (*server, error) {
s := new(server)
if f.ConfigFile != "" {
configFile, err := ioutil.ReadFile(f.ConfigFile)
if err != nil {
return nil, E.Cause(err, "read config file")
}
flagsNew := new(flags)
err = json.Unmarshal(configFile, flagsNew)
if err != nil {
return nil, E.Cause(err, "decode config file")
}
if flagsNew.Server != "" && f.Server == "" {
f.Server = flagsNew.Server
}
if flagsNew.ServerPort != 0 && f.ServerPort == 0 {
f.ServerPort = flagsNew.ServerPort
}
if flagsNew.Bind != "" && f.Bind == "" {
f.Bind = flagsNew.Bind
}
if flagsNew.LocalPort != 0 && f.LocalPort == 0 {
f.LocalPort = flagsNew.LocalPort
}
if flagsNew.Password != "" && f.Password == "" {
f.Password = flagsNew.Password
}
if flagsNew.Key != "" && f.Key == "" {
f.Key = flagsNew.Key
}
if flagsNew.Method != "" && f.Method == "" {
f.Method = flagsNew.Method
}
if flagsNew.Verbose {
f.Verbose = true
}
}
if f.Verbose {
logrus.SetLevel(logrus.TraceLevel)
}
if f.Server == "" {
return nil, E.New("missing server address")
} else if f.ServerPort == 0 {
return nil, E.New("missing server port")
} else if f.Method == "" {
return nil, E.New("missing method")
}
var key []byte
if f.Key != "" {
kb, err := base64.StdEncoding.DecodeString(f.Key)
if err != nil {
return nil, E.Cause(err, "decode key")
}
key = kb
}
if f.Method == shadowsocks.MethodNone {
s.service = shadowsocks.NewNoneService(s)
} else if common.Contains(shadowaead_2022.List, f.Method) {
var pskList [][]byte
if f.Key != "" {
keyStrList := strings.Split(f.Key, ":")
pskList = make([][]byte, len(keyStrList))
for i, keyStr := range keyStrList {
key, err := base64.StdEncoding.DecodeString(keyStr)
if err != nil {
return nil, E.Cause(err, "decode key")
}
pskList[i] = key
}
} else if common.Contains(shadowaead.List, f.Method) {
service, err := shadowaead.NewService(f.Method, key, []byte(f.Password), random.Blake3KeyedHash(), false, s)
if err != nil {
return nil, err
}
rng := random.System
service, err := shadowaead_2022.NewService(f.Method, pskList[0], rng, s)
s.service = service
} else if common.Contains(shadowaead_2022.List, f.Method) {
service, err := shadowaead_2022.NewService(f.Method, key, random.Blake3KeyedHash(), s)
if err != nil {
return nil, err
}
@ -118,16 +201,17 @@ func newServer(f *flags) (*server, error) {
}
var bind netip.Addr
if f.Bind != "" {
addr, err := netip.ParseAddr(f.Bind)
if f.Server != "" {
addr, err := netip.ParseAddr(f.Server)
if err != nil {
return nil, E.Cause(err, "bad local address")
return nil, E.Cause(err, "bad server address")
}
bind = addr
} else {
bind = netip.IPv6Unspecified()
}
s.tcpIn = tcp.NewTCPListener(netip.AddrPortFrom(bind, f.LocalPort), s)
s.tcpIn = tcp.NewTCPListener(netip.AddrPortFrom(bind, f.ServerPort), s)
s.udpIn = udp.NewUDPListener(netip.AddrPortFrom(bind, f.ServerPort), s)
return s, nil
}
@ -143,6 +227,19 @@ func (s *server) NewConnection(ctx context.Context, conn net.Conn, metadata M.Me
return rw.CopyConn(ctx, conn, destConn)
}
func (s *server) NewPacketConnection(conn socks.PacketConn, metadata M.Metadata) error {
logrus.Info("inbound UDP ", metadata.Source, " ==> ", metadata.Destination)
udpConn, err := net.ListenUDP("udp", nil)
if err != nil {
return err
}
return socks.CopyNetPacketConn(context.Background(), udpConn, conn)
}
func (s *server) NewPacket(conn socks.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error {
return s.service.NewPacket(conn, buffer, metadata)
}
func (s *server) HandleError(err error) {
if E.IsClosed(err) {
return

View file

@ -14,7 +14,6 @@ import (
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/common/redir"
"github.com/sagernet/sing/common/rw"
"github.com/sagernet/sing/common/task"
"github.com/sagernet/sing/common/uot"
"github.com/sagernet/sing/protocol/socks"
"github.com/sagernet/sing/transport/mixed"
@ -122,15 +121,7 @@ func (c *localClient) NewPacketConnection(conn socks.PacketConn, _ M.Metadata) e
}
client := uot.NewClientConn(upstream)
return task.Run(context.Background(), func() error {
return socks.CopyPacketConn0(client, conn, func(destination *M.AddrPort, n int) {
logrus.Trace("UDP ", conn.LocalAddr(), " ==> ", destination)
})
}, func() error {
return socks.CopyPacketConn0(conn, client, func(destination *M.AddrPort, n int) {
logrus.Trace("UDP ", conn.LocalAddr(), " <== ", destination)
})
})
return socks.CopyPacketConn(context.Background(), client, conn)
}
func (c *localClient) OnError(err error) {

View file

@ -104,11 +104,15 @@ func (b *Buffer) ExtendHeader(size int) []byte {
b.start -= size
return b.data[b.start-size : b.start]
} else {
offset := size - b.start
/*offset := size - b.start
end := b.end + size
if end > len(b.data) {
panic("buffer overflow")
}
copy(b.data[offset:end], b.data[b.start:b.end])
b.end = end
return b.data[:offset]
return b.data[:offset]*/
panic("no header available")
}
}
@ -119,10 +123,18 @@ func (b *Buffer) WriteBufferAtFirst(buffer *Buffer) *Buffer {
b.start -= n
buffer.Release()
return b
} else if buffer.FreeLen() >= b.Len() {
common.Must1(buffer.Write(b.Bytes()))
b.Release()
return buffer
} else if b.FreeLen() >= size {
copy(b.data[b.start+size:b.end+size], b.data[b.start:b.end])
copy(b.data, buffer.data)
buffer.Release()
return b
} else {
panic("buffer overflow")
}
common.Must1(buffer.Write(b.Bytes()))
b.Release()
return buffer
}
func (b *Buffer) WriteAtFirst(data []byte) (n int, err error) {
@ -305,6 +317,10 @@ func (b *Buffer) Cut(start int, end int) *Buffer {
}
}
func (b Buffer) Start() int {
return b.start
}
func (b Buffer) Len() int {
return b.end - b.start
}

View file

@ -3,8 +3,6 @@ package metadata
import (
"context"
"net"
"github.com/sagernet/sing/common/buf"
)
type Metadata struct {
@ -16,7 +14,3 @@ type Metadata struct {
type TCPConnectionHandler interface {
NewConnection(ctx context.Context, conn net.Conn, metadata Metadata) error
}
type UDPHandler interface {
NewPacket(packet *buf.Buffer, metadata Metadata) error
}

View file

@ -55,28 +55,43 @@ func (s *Serializer) WriteAddress(writer io.Writer, addr Addr) error {
return err
}
func (s *Serializer) AddressLen(addr Addr) int {
switch addr.Family() {
case AddressFamilyIPv4:
return 5
case AddressFamilyIPv6:
return 17
default:
return 1 + len(addr.Fqdn())
}
}
func (s *Serializer) WritePort(writer io.Writer, port uint16) error {
return binary.Write(writer, binary.BigEndian, port)
}
func (s *Serializer) WriteAddrPort(writer io.Writer, addrPort *AddrPort) error {
func (s *Serializer) WriteAddrPort(writer io.Writer, destination *AddrPort) error {
var err error
if !s.portFirst {
err = s.WriteAddress(writer, addrPort.Addr)
err = s.WriteAddress(writer, destination.Addr)
} else {
err = s.WritePort(writer, addrPort.Port)
err = s.WritePort(writer, destination.Port)
}
if err != nil {
return err
}
if s.portFirst {
err = s.WriteAddress(writer, addrPort.Addr)
err = s.WriteAddress(writer, destination.Addr)
} else {
err = s.WritePort(writer, addrPort.Port)
err = s.WritePort(writer, destination.Port)
}
return err
}
func (s *Serializer) AddrPortLen(destination *AddrPort) int {
return s.AddressLen(destination.Addr) + 2
}
func (s *Serializer) ReadAddress(reader io.Reader) (Addr, error) {
af, err := rw.ReadByte(reader)
if err != nil {
@ -120,7 +135,7 @@ func (s *Serializer) ReadPort(reader io.Reader) (uint16, error) {
return binary.BigEndian.Uint16(port), nil
}
func (s *Serializer) ReadAddrPort(reader io.Reader) (addrPort *AddrPort, err error) {
func (s *Serializer) ReadAddrPort(reader io.Reader) (destination *AddrPort, err error) {
var addr Addr
var port uint16
if !s.portFirst {

View file

@ -17,6 +17,17 @@ type DefaultDialer struct {
net.Dialer
}
func (d *DefaultDialer) ListenUDP(network string, laddr *net.UDPAddr) (*net.UDPConn, error) {
return net.ListenUDP(network, laddr)
}
func (d *DefaultDialer) DialContext(ctx context.Context, network string, address *M.AddrPort) (net.Conn, error) {
return d.Dialer.DialContext(ctx, network, address.String())
}
type Listener interface {
Listen(ctx context.Context, network, address string) (net.Listener, error)
ListenPacket(ctx context.Context, network, address string) (net.PacketConn, error)
}
var SystemListener Listener = &net.ListenConfig{}

View file

@ -1,108 +0,0 @@
package udpnat
import (
"io"
"net"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/gsync"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/common/redir"
"github.com/sagernet/sing/protocol/socks"
)
type Handler interface {
socks.UDPConnectionHandler
E.Handler
}
type Server struct {
udpNat gsync.Map[string, *packetConn]
handler Handler
}
func NewServer(handler Handler) *Server {
return &Server{handler: handler}
}
func (s *Server) HandleUDP(buffer *buf.Buffer, metadata M.Metadata) error {
conn, loaded := s.udpNat.LoadOrStore(metadata.Source.String(), func() *packetConn {
return &packetConn{source: metadata.Source.UDPAddr(), in: make(chan *udpPacket)}
})
if !loaded {
go func() {
err := s.handler.NewPacketConnection(conn, metadata)
if err != nil {
s.handler.HandleError(err)
}
}()
}
conn.in <- &udpPacket{
buffer: buffer,
destination: metadata.Destination,
}
return nil
}
func (s *Server) OnError(err error) {
s.handler.HandleError(err)
}
func (s *Server) Close() error {
s.udpNat.Range(func(key string, conn *packetConn) bool {
conn.Close()
return true
})
s.udpNat = gsync.Map[string, *packetConn]{}
return nil
}
type packetConn struct {
socks.PacketConnStub
source *net.UDPAddr
in chan *udpPacket
}
type udpPacket struct {
buffer *buf.Buffer
destination *M.AddrPort
}
func (c *packetConn) LocalAddr() net.Addr {
return c.source
}
func (c *packetConn) Close() error {
select {
case <-c.in:
return io.ErrClosedPipe
default:
close(c.in)
}
return nil
}
func (c *packetConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
select {
case packet, ok := <-c.in:
if !ok {
return nil, io.ErrClosedPipe
}
defer packet.buffer.Release()
if buffer.FreeLen() < packet.buffer.Len() {
return nil, io.ErrShortBuffer
}
return packet.destination, common.Error(buffer.Write(packet.buffer.Bytes()))
}
}
func (c *packetConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error {
udpConn, err := redir.DialUDP("udp", destination.UDPAddr(), c.source)
if err != nil {
return E.Cause(err, "tproxy udp write back")
}
defer udpConn.Close()
return common.Error(udpConn.Write(buffer.Bytes()))
}

123
common/udpnat/service.go Normal file
View file

@ -0,0 +1,123 @@
package udpnat
import (
"context"
"io"
"net"
"os"
"time"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/gsync"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/protocol/socks"
)
type Handler interface {
socks.UDPConnectionHandler
E.Handler
}
type Service[K comparable] struct {
nat gsync.Map[K, *conn]
handler Handler
}
func New[T comparable](handler Handler) *Service[T] {
return &Service[T]{
handler: handler,
}
}
func (s *Service[T]) NewPacket(key T, writer func() socks.PacketWriter, buffer *buf.Buffer, metadata M.Metadata) error {
c, loaded := s.nat.LoadOrStore(key, func() *conn {
c := &conn{
data: make(chan packet),
remoteAddr: metadata.Source.UDPAddr(),
source: writer(),
}
c.ctx, c.cancel = context.WithCancel(context.Background())
return c
})
if !loaded {
go func() {
err := s.handler.NewPacketConnection(c, metadata)
if err != nil {
s.handler.HandleError(err)
}
}()
}
ctx, done := context.WithCancel(c.ctx)
p := packet{
done: done,
data: buffer,
destination: metadata.Destination,
}
c.data <- p
<-ctx.Done()
return nil
}
type packet struct {
data *buf.Buffer
destination *M.AddrPort
done context.CancelFunc
}
type conn struct {
ctx context.Context
cancel context.CancelFunc
data chan packet
remoteAddr *net.UDPAddr
source socks.PacketWriter
}
func (c *conn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
select {
case p, ok := <-c.data:
if !ok {
return nil, io.ErrClosedPipe
}
defer p.data.Release()
_, err := buffer.ReadFrom(p.data)
p.done()
return p.destination, err
}
}
func (c *conn) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error {
return c.source.WritePacket(buffer, destination)
}
func (c *conn) Close() error {
c.cancel()
select {
case <-c.data:
return os.ErrClosed
default:
close(c.data)
return nil
}
}
func (c *conn) LocalAddr() net.Addr {
return &common.DummyAddr{}
}
func (c *conn) RemoteAddr() net.Addr {
return c.remoteAddr
}
func (c *conn) SetDeadline(t time.Time) error {
return nil
}
func (c *conn) SetReadDeadline(t time.Time) error {
return nil
}
func (c *conn) SetWriteDeadline(t time.Time) error {
return nil
}

View file

@ -30,7 +30,6 @@ func TestServerConn(t *testing.T) {
Port: 53,
}))
_buffer := buf.StackNew()
common.Use(_buffer)
buffer := common.Dup(_buffer)
common.Must2(buffer.ReadPacketFrom(clientConn))
common.Must(message.Unpack(buffer.Bytes()))

View file

@ -4,32 +4,35 @@ import (
"context"
"net"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/common/udpnat"
"github.com/sagernet/sing/protocol/socks"
)
type Service interface {
M.TCPConnectionHandler
}
type MultiUserService interface {
Service
AddUser(key []byte)
RemoveUser(key []byte)
socks.UDPHandler
}
type Handler interface {
M.TCPConnectionHandler
socks.UDPConnectionHandler
E.Handler
}
type NoneService struct {
handler Handler
udp *udpnat.Service[string]
}
func NewNoneService(handler Handler) Service {
return &NoneService{
s := &NoneService{
handler: handler,
}
s.udp = udpnat.New[string](s)
return s
}
func (s *NoneService) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
@ -41,3 +44,37 @@ func (s *NoneService) NewConnection(ctx context.Context, conn net.Conn, metadata
metadata.Destination = destination
return s.handler.NewConnection(ctx, conn, metadata)
}
func (s *NoneService) NewPacket(conn socks.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error {
destination, err := socks.AddressSerializer.ReadAddrPort(buffer)
if err != nil {
return err
}
metadata.Protocol = "shadowsocks"
metadata.Destination = destination
return s.udp.NewPacket(metadata.Source.String(), func() socks.PacketWriter {
return &serverPacketWriter{conn, metadata.Source}
}, buffer, metadata)
}
type serverPacketWriter struct {
socks.PacketConn
sourceAddr *M.AddrPort
}
func (s *serverPacketWriter) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error {
header := buf.With(buffer.ExtendHeader(socks.AddressSerializer.AddrPortLen(destination)))
err := socks.AddressSerializer.WriteAddrPort(header, destination)
if err != nil {
return err
}
return s.PacketConn.WritePacket(buffer, s.sourceAddr)
}
func (s *NoneService) NewPacketConnection(conn socks.PacketConn, metadata M.Metadata) error {
return s.handler.NewPacketConnection(conn, metadata)
}
func (s *NoneService) HandleError(err error) {
s.handler.HandleError(err)
}

View file

@ -156,13 +156,13 @@ func (m *Method) DialEarlyConn(conn net.Conn, destination *M.AddrPort) net.Conn
}
func (m *Method) DialPacketConn(conn net.Conn) socks.PacketConn {
return &clientPacketConn{conn, m}
return &clientPacketConn{m, conn}
}
func (m *Method) EncodePacket(buffer *buf.Buffer) error {
key := Kdf(m.key, buffer.To(m.keySaltLength), m.keySaltLength)
c := m.constructor(common.Dup(key))
c.Seal(buffer.From(m.keySaltLength)[:0], rw.ZeroBytes[:c.NonceSize()], buffer.From(m.keySaltLength), nil)
c.Seal(buffer.Index(m.keySaltLength), rw.ZeroBytes[:c.NonceSize()], buffer.From(m.keySaltLength), nil)
buffer.Extend(c.Overhead())
return nil
}
@ -299,20 +299,18 @@ func (c *clientConn) ReadFrom(r io.Reader) (n int64, err error) {
}
type clientPacketConn struct {
*Method
net.Conn
method *Method
}
func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error {
_header := buf.StackNew()
header := common.Dup(_header)
common.Must1(header.ReadFullFrom(c.method.secureRNG, c.method.keySaltLength))
err := socks.AddressSerializer.WriteAddrPort(header, destination)
header := buffer.ExtendHeader(c.keySaltLength + socks.AddressSerializer.AddrPortLen(destination))
common.Must1(io.ReadFull(c.secureRNG, header[:c.keySaltLength]))
err := socks.AddressSerializer.WriteAddrPort(buf.With(header[c.keySaltLength:]), destination)
if err != nil {
return err
}
buffer = buffer.WriteBufferAtFirst(header)
err = c.method.EncodePacket(buffer)
err = c.EncodePacket(buffer)
if err != nil {
return err
}
@ -325,7 +323,7 @@ func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
return nil, err
}
buffer.Truncate(n)
err = c.method.DecodePacket(buffer)
err = c.DecodePacket(buffer)
if err != nil {
return nil, err
}

View file

@ -13,6 +13,7 @@ import (
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/common/replay"
"github.com/sagernet/sing/common/rw"
"github.com/sagernet/sing/common/udpnat"
"github.com/sagernet/sing/protocol/shadowsocks"
"github.com/sagernet/sing/protocol/socks"
"golang.org/x/crypto/chacha20poly1305"
@ -25,6 +26,7 @@ type Service struct {
key []byte
secureRNG io.Reader
replayFilter replay.Filter
udp *udpnat.Service[string]
handler shadowsocks.Handler
}
@ -34,6 +36,7 @@ func NewService(method string, key []byte, password []byte, secureRNG io.Reader,
secureRNG: secureRNG,
handler: handler,
}
s.udp = udpnat.New[string](s)
if replayFilter {
s.replayFilter = replay.NewBloomRing()
}
@ -163,3 +166,59 @@ func (c *serverConn) ReadFrom(r io.Reader) (n int64, err error) {
func (c *serverConn) WriteTo(w io.Writer) (n int64, err error) {
return c.reader.WriteTo(w)
}
func (s *Service) NewPacket(conn socks.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error {
if buffer.Len() < s.keySaltLength {
return E.New("bad packet")
}
key := Kdf(s.key, buffer.To(s.keySaltLength), s.keySaltLength)
c := s.constructor(common.Dup(key))
/*data := buf.New()
packet, err := c.Open(data.Index(0), rw.ZeroBytes[:c.NonceSize()], buffer.From(s.keySaltLength), nil)
if err != nil {
return err
}
data.Truncate(len(packet))
metadata.Protocol = "shadowsocks"
return s.udp.NewPacket(metadata.Source.String(), func() socks.PacketWriter {
return &serverPacketWriter{s, conn, metadata.Source}
}, data, metadata)*/
packet, err := c.Open(buffer.Index(s.keySaltLength), rw.ZeroBytes[:c.NonceSize()], buffer.From(s.keySaltLength), nil)
if err != nil {
return err
}
buffer.Advance(s.keySaltLength)
buffer.Truncate(len(packet))
metadata.Protocol = "shadowsocks"
return s.udp.NewPacket(metadata.Source.String(), func() socks.PacketWriter {
return &serverPacketWriter{s, conn, metadata.Source}
}, buffer, metadata)
}
type serverPacketWriter struct {
*Service
socks.PacketConn
source *M.AddrPort
}
func (w *serverPacketWriter) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error {
header := buffer.ExtendHeader(w.keySaltLength + socks.AddressSerializer.AddrPortLen(destination))
common.Must1(io.ReadFull(w.secureRNG, header[:w.keySaltLength]))
err := socks.AddressSerializer.WriteAddrPort(buf.With(header[w.keySaltLength:]), destination)
if err != nil {
return err
}
key := Kdf(w.key, buffer.To(w.keySaltLength), w.keySaltLength)
c := w.constructor(common.Dup(key))
c.Seal(buffer.From(w.keySaltLength)[:0], rw.ZeroBytes[:c.NonceSize()], buffer.From(w.keySaltLength), nil)
buffer.Extend(c.Overhead())
return w.PacketConn.WritePacket(buffer, w.source)
}
func (s *Service) NewPacketConnection(conn socks.PacketConn, metadata M.Metadata) error {
return s.handler.NewPacketConnection(conn, metadata)
}
func (s *Service) HandleError(err error) {
s.handler.HandleError(err)
}

View file

@ -109,9 +109,9 @@ func New(method string, pskList [][]byte, secureRNG io.Reader) (shadowsocks.Meth
}
func Blake3DeriveKey(psk, salt []byte, keyLength int) []byte {
sessionKey := make([]byte, 2*KeySaltSize)
sessionKey := buf.Make(len(psk) + len(salt))
copy(sessionKey, psk)
copy(sessionKey[KeySaltSize:], salt)
copy(sessionKey[len(psk):], salt)
outKey := buf.Make(keyLength)
blake3.DeriveKey(outKey, "shadowsocks 2022 session subkey", sessionKey)
return outKey
@ -434,6 +434,9 @@ func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPo
return err
}
buffer = buffer.WriteBufferAtFirst(header)
if err != nil {
return err
}
if c.method.udpCipher != nil {
c.method.udpCipher.Seal(buffer.Index(dataIndex), buffer.To(dataIndex), buffer.From(dataIndex), nil)
buffer.Extend(c.method.udpCipher.Overhead())
@ -574,9 +577,9 @@ func (s *udpSession) nextPacketId() uint64 {
}
func (m *Method) newUDPSession() *udpSession {
session := &udpSession{
sessionId: rand.Uint64(),
}
session := &udpSession{}
common.Must(binary.Read(m.secureRNG, binary.BigEndian, &session.sessionId))
session.packetId--
if m.udpCipher == nil {
sessionId := make([]byte, 8)
binary.BigEndian.PutUint64(sessionId, session.sessionId)

View file

@ -2,32 +2,43 @@ package shadowaead_2022
import (
"context"
"crypto/aes"
"crypto/cipher"
"encoding/binary"
"io"
"math"
"net"
"sync"
"sync/atomic"
"time"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/gsync"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/common/replay"
"github.com/sagernet/sing/common/rw"
"github.com/sagernet/sing/common/udpnat"
"github.com/sagernet/sing/protocol/shadowsocks"
"github.com/sagernet/sing/protocol/shadowsocks/shadowaead"
"github.com/sagernet/sing/protocol/socks"
wgReplay "golang.zx2c4.com/wireguard/replay"
)
type Service struct {
name string
secureRNG io.Reader
keyLength int
constructor func(key []byte) cipher.AEAD
psk []byte
replayFilter replay.Filter
handler shadowsocks.Handler
name string
secureRNG io.Reader
keyLength int
constructor func(key []byte) cipher.AEAD
blockConstructor func(key []byte) cipher.Block
udpCipher cipher.AEAD
udpBlockCipher cipher.Block
psk []byte
replayFilter replay.Filter
handler shadowsocks.Handler
udpNat *udpnat.Service[uint64]
sessions gsync.Map[uint64, *serverUDPSession]
}
func NewService(method string, psk []byte, secureRNG io.Reader, handler shadowsocks.Handler) (shadowsocks.Service, error) {
@ -47,18 +58,20 @@ func NewService(method string, psk []byte, secureRNG io.Reader, handler shadowso
case "2022-blake3-aes-128-gcm":
s.keyLength = 16
s.constructor = newAESGCM
// m.blockConstructor = newAES
// m.udpBlockCipher = newAES(m.psk)
s.blockConstructor = newAES
s.udpBlockCipher = newAES(s.psk)
case "2022-blake3-aes-256-gcm":
s.keyLength = 32
s.constructor = newAESGCM
// m.blockConstructor = newAES
// m.udpBlockCipher = newAES(m.psk)
s.blockConstructor = newAES
s.udpBlockCipher = newAES(s.psk)
case "2022-blake3-chacha20-poly1305":
s.keyLength = 32
s.constructor = newChacha20Poly1305
// m.udpCipher = newXChacha20Poly1305(m.psk)
s.udpCipher = newXChacha20Poly1305(s.psk)
}
s.udpNat = udpnat.New[uint64](s)
return s, nil
}
@ -194,3 +207,169 @@ func (c *serverConn) ReadFrom(r io.Reader) (n int64, err error) {
func (c *serverConn) WriteTo(w io.Writer) (n int64, err error) {
return c.reader.WriteTo(w)
}
func (s *Service) NewPacket(conn socks.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error {
var packetHeader []byte
if s.udpCipher != nil {
_, err := s.udpCipher.Open(buffer.Index(PacketNonceSize), buffer.To(PacketNonceSize), buffer.From(PacketNonceSize), nil)
if err != nil {
return E.Cause(err, "decrypt packet header")
}
buffer.Advance(PacketNonceSize)
} else {
packetHeader = buffer.To(aes.BlockSize)
s.udpBlockCipher.Decrypt(packetHeader, packetHeader)
}
var sessionId, packetId uint64
err := binary.Read(buffer, binary.BigEndian, &sessionId)
if err != nil {
return err
}
err = binary.Read(buffer, binary.BigEndian, &packetId)
if err != nil {
return err
}
session, loaded := s.sessions.LoadOrStore(sessionId, s.newUDPSession)
if !loaded {
session.remoteSessionId = sessionId
if packetHeader != nil {
key := Blake3DeriveKey(s.psk, packetHeader[:8], s.keyLength)
session.remoteCipher = s.constructor(common.Dup(key))
}
}
if !session.filter.ValidateCounter(packetId, math.MaxUint64) {
return ErrPacketIdNotUnique
}
if packetHeader != nil {
_, err = session.remoteCipher.Open(buffer.Index(0), packetHeader[4:16], buffer.Bytes(), nil)
if err != nil {
return E.Cause(err, "decrypt packet")
}
}
var headerType byte
headerType, err = buffer.ReadByte()
if err != nil {
return err
}
if headerType != HeaderTypeClient {
return ErrBadHeaderType
}
var epoch uint64
err = binary.Read(buffer, binary.BigEndian, &epoch)
if err != nil {
return err
}
if math.Abs(float64(uint64(time.Now().Unix())-epoch)) > 30 {
return ErrBadTimestamp
}
var paddingLength uint16
err = binary.Read(buffer, binary.BigEndian, &paddingLength)
if err != nil {
return E.Cause(err, "read padding length")
}
buffer.Advance(int(paddingLength))
destination, err := socks.AddressSerializer.ReadAddrPort(buffer)
if err != nil {
return err
}
metadata.Destination = destination
return s.udpNat.NewPacket(sessionId, func() socks.PacketWriter {
return &serverPacketWriter{s, conn, session, metadata.Source}
}, buffer, metadata)
}
type serverPacketWriter struct {
*Service
socks.PacketConn
session *serverUDPSession
source *M.AddrPort
}
func (w *serverPacketWriter) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error {
defer buffer.Release()
_header := buf.StackNew()
header := common.Dup(_header)
var dataIndex int
if w.udpCipher != nil {
common.Must1(header.ReadFullFrom(w.secureRNG, PacketNonceSize))
dataIndex = buffer.Len()
} else {
dataIndex = aes.BlockSize
}
common.Must(
binary.Write(header, binary.BigEndian, w.session.sessionId),
binary.Write(header, binary.BigEndian, w.session.nextPacketId()),
header.WriteByte(HeaderTypeServer),
binary.Write(header, binary.BigEndian, uint64(time.Now().Unix())),
binary.Write(header, binary.BigEndian, w.session.remoteSessionId),
binary.Write(header, binary.BigEndian, uint16(0)), // padding length
)
err := socks.AddressSerializer.WriteAddrPort(header, destination)
if err != nil {
return err
}
_, err = header.Write(buffer.Bytes())
if err != nil {
return err
}
if w.udpCipher != nil {
w.udpCipher.Seal(header.Index(dataIndex), header.To(dataIndex), header.From(dataIndex), nil)
header.Extend(w.udpCipher.Overhead())
} else {
packetHeader := header.To(aes.BlockSize)
w.session.cipher.Seal(header.Index(dataIndex), packetHeader[4:16], header.From(dataIndex), nil)
header.Extend(w.session.cipher.Overhead())
w.udpBlockCipher.Encrypt(packetHeader, packetHeader)
}
return w.PacketConn.WritePacket(header, w.source)
}
type serverUDPSession struct {
sessionId uint64
remoteSessionId uint64
packetId uint64
cipher cipher.AEAD
remoteCipher cipher.AEAD
filter wgReplay.Filter
}
func (s *serverUDPSession) nextPacketId() uint64 {
return atomic.AddUint64(&s.packetId, 1)
}
func (m *Service) newUDPSession() *serverUDPSession {
session := &serverUDPSession{}
common.Must(binary.Read(m.secureRNG, binary.BigEndian, &session.sessionId))
session.packetId--
if m.udpCipher == nil {
sessionId := make([]byte, 8)
binary.BigEndian.PutUint64(sessionId, session.sessionId)
key := Blake3DeriveKey(m.psk, sessionId, m.keyLength)
session.cipher = m.constructor(common.Dup(key))
}
return session
}
func (s *Service) NewPacketConnection(conn socks.PacketConn, metadata M.Metadata) error {
return s.handler.NewPacketConnection(conn, metadata)
}
func (s *Service) HandleError(err error) {
s.handler.HandleError(err)
}

View file

@ -8,12 +8,21 @@ import (
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/common/rw"
"github.com/sagernet/sing/common/task"
)
type PacketConn interface {
type PacketReader interface {
ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error)
WritePacket(buffer *buf.Buffer, addrPort *M.AddrPort) error
}
type PacketWriter interface {
WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error
}
type PacketConn interface {
PacketReader
PacketWriter
Close() error
LocalAddr() net.Addr
@ -23,6 +32,10 @@ type PacketConn interface {
SetWriteDeadline(t time.Time) error
}
type UDPHandler interface {
NewPacket(conn PacketConn, buffer *buf.Buffer, metadata M.Metadata) error
}
type UDPConnectionHandler interface {
NewPacketConnection(conn PacketConn, metadata M.Metadata) error
}
@ -47,6 +60,8 @@ func (s *PacketConnStub) SetWriteDeadline(t time.Time) error {
func CopyPacketConn(ctx context.Context, dest PacketConn, conn PacketConn) error {
return task.Run(ctx, func() error {
defer rw.CloseRead(conn)
defer rw.CloseWrite(dest)
_buffer := buf.StackNewMax()
buffer := common.Dup(_buffer)
data := buffer.Cut(buf.ReversedHeader, buf.ReversedHeader)
@ -56,13 +71,15 @@ func CopyPacketConn(ctx context.Context, dest PacketConn, conn PacketConn) error
if err != nil {
return err
}
buffer.Truncate(data.Len())
buffer.Resize(buf.ReversedHeader+data.Start(), data.Len())
err = dest.WritePacket(buffer, destination)
if err != nil {
return err
}
}
}, func() error {
defer rw.CloseRead(dest)
defer rw.CloseWrite(conn)
_buffer := buf.StackNewMax()
buffer := common.Dup(_buffer)
data := buffer.Cut(buf.ReversedHeader, buf.ReversedHeader)
@ -72,7 +89,7 @@ func CopyPacketConn(ctx context.Context, dest PacketConn, conn PacketConn) error
if err != nil {
return err
}
buffer.Truncate(data.Len())
buffer.Resize(buf.ReversedHeader+data.Start(), data.Len())
err = conn.WritePacket(buffer, destination)
if err != nil {
return err
@ -81,44 +98,211 @@ func CopyPacketConn(ctx context.Context, dest PacketConn, conn PacketConn) error
})
}
func CopyPacketConn0(dest PacketConn, conn PacketConn, onAction func(destination *M.AddrPort, n int)) error {
for {
buffer := buf.New()
destination, err := conn.ReadPacket(buffer)
if err != nil {
buffer.Release()
return err
func CopyNetPacketConn(ctx context.Context, dest net.PacketConn, conn PacketConn) error {
return task.Run(ctx, func() error {
defer rw.CloseRead(conn)
defer rw.CloseWrite(dest)
_buffer := buf.StackNew()
buffer := common.Dup(_buffer)
for {
buffer.FullReset()
destination, err := conn.ReadPacket(buffer)
if err != nil {
return err
}
_, err = dest.WriteTo(buffer.Bytes(), destination.UDPAddr())
if err != nil {
return err
}
}
size := buffer.Len()
err = dest.WritePacket(buffer, destination)
if err != nil {
buffer.Release()
return err
}, func() error {
defer rw.CloseRead(dest)
defer rw.CloseWrite(conn)
_buffer := buf.StackNew()
buffer := common.Dup(_buffer)
data := buffer.Cut(buf.ReversedHeader, buf.ReversedHeader)
for {
data.FullReset()
n, addr, err := dest.ReadFrom(data.FreeBytes())
if err != nil {
return err
}
buffer.Resize(buf.ReversedHeader, n)
err = conn.WritePacket(buffer, M.AddrPortFromNetAddr(addr))
if err != nil {
return err
}
}
if onAction != nil {
onAction(destination, size)
}
}
})
}
type associatePacketConn struct {
net.PacketConn
type AssociateConn struct {
net.Conn
conn net.Conn
addr net.Addr
dest *M.AddrPort
}
func NewPacketConn(conn net.Conn, packetConn net.PacketConn) PacketConn {
return &associatePacketConn{
PacketConn: packetConn,
conn: conn,
func NewAssociateConn(conn net.Conn, packetConn net.Conn, destination *M.AddrPort) net.PacketConn {
return &AssociateConn{
Conn: packetConn,
conn: conn,
dest: destination,
}
}
func (c *associatePacketConn) RemoteAddr() net.Addr {
func (c *AssociateConn) RemoteAddr() net.Addr {
return c.addr
}
func (c *associatePacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
func (c *AssociateConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
n, err = c.Conn.Read(p)
if err != nil {
return
}
reader := buf.As(p[3:n])
destination, err := AddressSerializer.ReadAddrPort(reader)
if err != nil {
return
}
addr = destination.UDPAddr()
n = copy(p, reader.Bytes())
return
}
func (c *AssociateConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
_buffer := buf.StackNew()
buffer := common.Dup(_buffer)
common.Must(buffer.WriteZeroN(3))
err = AddressSerializer.WriteAddrPort(buffer, M.AddrPortFromNetAddr(addr))
if err != nil {
return
}
_, err = buffer.Write(p)
if err != nil {
return
}
_, err = c.Conn.Write(buffer.Bytes())
return
}
func (c *AssociateConn) Read(b []byte) (n int, err error) {
n, _, err = c.ReadFrom(b)
return
}
func (c *AssociateConn) Write(b []byte) (n int, err error) {
_buffer := buf.StackNew()
buffer := common.Dup(_buffer)
common.Must(buffer.WriteZeroN(3))
err = AddressSerializer.WriteAddrPort(buffer, c.dest)
if err != nil {
return
}
_, err = buffer.Write(b)
if err != nil {
return
}
_, err = c.Conn.Write(buffer.Bytes())
return
}
func (c *AssociateConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
n, err := buffer.ReadFrom(c.conn)
if err != nil {
return nil, err
}
buffer.Truncate(int(n))
buffer.Advance(3)
return AddressSerializer.ReadAddrPort(buffer)
}
func (c *AssociateConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error {
defer buffer.Release()
header := buf.With(buffer.ExtendHeader(3 + AddressSerializer.AddrPortLen(destination)))
common.Must(header.WriteZeroN(3))
common.Must(AddressSerializer.WriteAddrPort(header, destination))
return common.Error(c.Conn.Write(buffer.Bytes()))
}
type AssociatePacketConn struct {
net.PacketConn
conn net.Conn
addr net.Addr
dest *M.AddrPort
}
func NewAssociatePacketConn(conn net.Conn, packetConn net.PacketConn, destination *M.AddrPort) *AssociatePacketConn {
return &AssociatePacketConn{
PacketConn: packetConn,
conn: conn,
dest: destination,
}
}
func (c *AssociatePacketConn) RemoteAddr() net.Addr {
return c.addr
}
func (c *AssociatePacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
n, addr, err = c.PacketConn.ReadFrom(p)
if err != nil {
return
}
reader := buf.As(p[3:n])
destination, err := AddressSerializer.ReadAddrPort(reader)
if err != nil {
return
}
addr = destination.UDPAddr()
n = copy(p, reader.Bytes())
return
}
func (c *AssociatePacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
_buffer := buf.StackNew()
buffer := common.Dup(_buffer)
common.Must(buffer.WriteZeroN(3))
err = AddressSerializer.WriteAddrPort(buffer, M.AddrPortFromNetAddr(addr))
if err != nil {
return
}
_, err = buffer.Write(p)
if err != nil {
return
}
_, err = c.PacketConn.WriteTo(buffer.Bytes(), c.addr)
return
}
func (c *AssociatePacketConn) Read(b []byte) (n int, err error) {
n, _, err = c.ReadFrom(b)
return
}
func (c *AssociatePacketConn) Write(b []byte) (n int, err error) {
_buffer := buf.StackNew()
buffer := common.Dup(_buffer)
common.Must(buffer.WriteZeroN(3))
err = AddressSerializer.WriteAddrPort(buffer, c.dest)
if err != nil {
return
}
_, err = buffer.Write(b)
if err != nil {
return
}
_, err = c.PacketConn.WriteTo(buffer.Bytes(), c.addr)
return
}
func (c *AssociatePacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
n, addr, err := c.PacketConn.ReadFrom(buffer.FreeBytes())
if err != nil {
return nil, err
@ -126,15 +310,14 @@ func (c *associatePacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error
c.addr = addr
buffer.Truncate(n)
buffer.Advance(3)
return AddressSerializer.ReadAddrPort(buffer)
dest, err := AddressSerializer.ReadAddrPort(buffer)
return dest, err
}
func (c *associatePacketConn) WritePacket(buffer *buf.Buffer, addrPort *M.AddrPort) error {
func (c *AssociatePacketConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error {
defer buffer.Release()
_header := buf.StackNew()
header := common.Dup(_header)
header := buf.With(buffer.ExtendHeader(3 + AddressSerializer.AddrPortLen(destination)))
common.Must(header.WriteZeroN(3))
common.Must(AddressSerializer.WriteAddrPort(header, addrPort))
buffer = buffer.WriteBufferAtFirst(header)
common.Must(AddressSerializer.WriteAddrPort(header, destination))
return common.Error(c.PacketConn.WriteTo(buffer.Bytes(), c.addr))
}

View file

@ -36,7 +36,7 @@ func NewListener(bind netip.AddrPort, authenticator auth.Authenticator, handler
}
func (l *Listener) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
return HandleConnection(ctx, conn, l.authenticator, l.bindAddr, l.handler, metadata)
return HandleConnection(ctx, conn, l.authenticator, M.AddrPortFromNetAddr(conn.LocalAddr()).Addr.Addr(), l.handler, metadata)
}
func (l *Listener) Start() error {
@ -131,9 +131,10 @@ func HandleConnection(ctx context.Context, conn net.Conn, authenticator auth.Aut
if err != nil {
return E.Cause(err, "write socks response")
}
metadata.Protocol = "socks"
metadata.Destination = request.Destination
go func() {
err := handler.NewPacketConnection(NewPacketConn(conn, udpConn), metadata)
err := handler.NewPacketConnection(NewAssociatePacketConn(conn, udpConn, request.Destination), metadata)
if err != nil {
handler.HandleError(err)
}

View file

@ -32,7 +32,7 @@ type Listener struct {
bindAddr netip.Addr
handler Handler
authenticator auth.Authenticator
udpNat *udpnat.Server
udpNat *udpnat.Service[string]
}
func NewListener(bind netip.AddrPort, authenticator auth.Authenticator, transproxy redir.TransproxyMode, handler Handler) *Listener {
@ -45,7 +45,7 @@ func NewListener(bind netip.AddrPort, authenticator auth.Authenticator, transpro
listener.TCPListener = tcp.NewTCPListener(bind, listener, tcp.WithTransproxyMode(transproxy))
if transproxy == redir.ModeTProxy {
listener.UDPListener = udp.NewUDPListener(bind, listener, udp.WithTransproxyMode(transproxy))
listener.udpNat = udpnat.NewServer(handler)
listener.udpNat = udpnat.New[string](handler)
}
return listener
}
@ -63,7 +63,7 @@ func (l *Listener) NewConnection(ctx context.Context, conn net.Conn, metadata M.
case socks.Version4:
return E.New("socks4 request dropped (TODO)")
case socks.Version5:
return socks.HandleConnection(ctx, bufConn, l.authenticator, l.bindAddr, l.handler, metadata)
return socks.HandleConnection(ctx, bufConn, l.authenticator, M.AddrPortFromNetAddr(conn.LocalAddr()).Addr.Addr(), l.handler, metadata)
}
request, err := http.ReadRequest(bufConn.Reader())
@ -96,8 +96,23 @@ func (l *Listener) NewConnection(ctx context.Context, conn net.Conn, metadata M.
return http.HandleRequest(ctx, request, bufConn, l.authenticator, l.handler, metadata)
}
func (l *Listener) NewPacket(packet *buf.Buffer, metadata M.Metadata) error {
return l.udpNat.HandleUDP(packet, metadata)
func (l *Listener) NewPacket(conn socks.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error {
return l.udpNat.NewPacket(metadata.Source.String(), func() socks.PacketWriter {
return &tproxyPacketWriter{metadata.Source.UDPAddr()}
}, buffer, metadata)
}
type tproxyPacketWriter struct {
source *net.UDPAddr
}
func (w *tproxyPacketWriter) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error {
udpConn, err := redir.DialUDP("udp", destination.UDPAddr(), w.source)
if err != nil {
return E.Cause(err, "tproxy udp write back")
}
defer udpConn.Close()
return common.Error(udpConn.Write(buffer.Bytes()))
}
func (l *Listener) HandleError(err error) {

View file

@ -9,10 +9,11 @@ import (
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/common/redir"
"github.com/sagernet/sing/protocol/socks"
)
type Handler interface {
M.UDPHandler
socks.UDPHandler
E.Handler
}
@ -24,6 +25,19 @@ type Listener struct {
tproxy bool
}
func (l *Listener) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
n, addr, err := l.ReadFromUDP(buffer.FreeBytes())
if err != nil {
return nil, err
}
buffer.Truncate(n)
return M.AddrPortFromNetAddr(addr), nil
}
func (l *Listener) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error {
return common.Error(l.UDPConn.WriteTo(buffer.Bytes(), destination.UDPAddr()))
}
func NewUDPListener(listen netip.AddrPort, handler Handler, options ...Option) *Listener {
listener := &Listener{
handler: handler,
@ -69,32 +83,31 @@ func (l *Listener) Close() error {
}
func (l *Listener) loop() {
_buffer := buf.StackNewMax()
buffer := common.Dup(_buffer)
data := buffer.Cut(buf.ReversedHeader, buf.ReversedHeader).Slice()
if !l.tproxy {
for {
buffer := buf.New()
n, addr, err := l.ReadFromUDP(buffer.Extend(buf.UDPBufferSize))
n, addr, err := l.ReadFromUDP(data)
if err != nil {
buffer.Release()
l.handler.HandleError(err)
return
}
buffer.Truncate(n)
err = l.handler.NewPacket(buffer, M.Metadata{
buffer.Resize(buf.ReversedHeader, n)
err = l.handler.NewPacket(l, buffer, M.Metadata{
Protocol: "udp",
Source: M.AddrPortFromNetAddr(addr),
})
if err != nil {
buffer.Release()
l.handler.HandleError(err)
}
}
} else {
oob := make([]byte, 1024)
_oob := make([]byte, 1024)
oob := common.Dup(_oob)
for {
buffer := buf.New()
n, oobN, _, addr, err := l.ReadMsgUDPAddrPort(buffer.FreeBytes(), oob)
n, oobN, _, addr, err := l.ReadMsgUDPAddrPort(data, oob)
if err != nil {
buffer.Release()
l.handler.HandleError(err)
return
}
@ -103,14 +116,13 @@ func (l *Listener) loop() {
l.handler.HandleError(E.Cause(err, "get original destination"))
return
}
buffer.Truncate(n)
err = l.handler.NewPacket(buffer, M.Metadata{
buffer.Resize(buf.ReversedHeader, n)
err = l.handler.NewPacket(l, buffer, M.Metadata{
Protocol: "tproxy",
Source: M.AddrPortFromAddrPort(addr),
Destination: destination,
})
if err != nil {
buffer.Release()
l.handler.HandleError(err)
}
}