mirror of
https://github.com/SagerNet/sing.git
synced 2025-04-03 03:47:38 +03:00
Add shadowsocks service
This commit is contained in:
parent
5be6eb2d64
commit
df1e1cfafd
21 changed files with 1028 additions and 265 deletions
|
@ -28,6 +28,12 @@ wget 'https://github.com/Dreamacro/maxmind-geoip/releases/latest/download/Countr
|
|||
go install -v -trimpath -ldflags "-s -w -buildid=" ./cli/ss-local
|
||||
```
|
||||
|
||||
### ss-server
|
||||
|
||||
```shell
|
||||
go install -v -trimpath -ldflags "-s -w -buildid=" ./cli/ss-server
|
||||
```
|
||||
|
||||
### ddns
|
||||
|
||||
```shell
|
||||
|
|
145
cli/socks-chk/main.go
Normal file
145
cli/socks-chk/main.go
Normal file
|
@ -0,0 +1,145 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
"github.com/sagernet/sing/protocol/socks"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
"golang.org/x/net/dns/dnsmessage"
|
||||
)
|
||||
|
||||
func main() {
|
||||
command := &cobra.Command{
|
||||
Use: "socks-chk address:port",
|
||||
Args: cobra.ExactArgs(1),
|
||||
Run: run,
|
||||
}
|
||||
if err := command.Execute(); err != nil {
|
||||
logrus.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func run(cmd *cobra.Command, args []string) {
|
||||
server, err := M.ParseAddress(args[0])
|
||||
if err != nil {
|
||||
logrus.Fatal("invalid server address ", args[0])
|
||||
}
|
||||
err = testSocksTCP(server)
|
||||
if err != nil {
|
||||
logrus.Fatal(err)
|
||||
}
|
||||
err = testSocksUDP(server)
|
||||
if err != nil {
|
||||
logrus.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func testSocksTCP(server *M.AddrPort) error {
|
||||
tcpConn, err := net.Dial("tcp", server.String())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
response, err := socks.ClientHandshake(tcpConn, socks.Version5, socks.CommandConnect, M.AddrPortFrom(M.ParseAddr("1.0.0.1"), 53), "", "")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if response.ReplyCode != socks.ReplyCodeSuccess {
|
||||
logrus.Fatal("socks tcp handshake failure: ", response.ReplyCode)
|
||||
}
|
||||
|
||||
message := &dnsmessage.Message{}
|
||||
message.Header.ID = 1
|
||||
message.Header.RecursionDesired = true
|
||||
message.Questions = append(message.Questions, dnsmessage.Question{
|
||||
Name: dnsmessage.MustNewName("google.com."),
|
||||
Type: dnsmessage.TypeA,
|
||||
Class: dnsmessage.ClassINET,
|
||||
})
|
||||
packet, err := message.Pack()
|
||||
|
||||
err = binary.Write(tcpConn, binary.BigEndian, uint16(len(packet)))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = tcpConn.Write(packet)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var respLen uint16
|
||||
err = binary.Read(tcpConn, binary.BigEndian, &respLen)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
respBuf := buf.Make(int(respLen))
|
||||
_, err = io.ReadFull(tcpConn, respBuf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
common.Must(message.Unpack(respBuf))
|
||||
for _, answer := range message.Answers {
|
||||
logrus.Info("tcp got answer: ", netip.AddrFrom4(answer.Body.(*dnsmessage.AResource).A))
|
||||
}
|
||||
|
||||
tcpConn.Close()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func testSocksUDP(server *M.AddrPort) error {
|
||||
tcpConn, err := net.Dial("tcp", server.String())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
dest := M.AddrPortFrom(M.ParseAddr("1.0.0.1"), 53)
|
||||
response, err := socks.ClientHandshake(tcpConn, socks.Version5, socks.CommandUDPAssociate, dest, "", "")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if response.ReplyCode != socks.ReplyCodeSuccess {
|
||||
logrus.Fatal("socks tcp handshake failure: ", response.ReplyCode)
|
||||
}
|
||||
var dialer net.Dialer
|
||||
udpConn, err := dialer.DialContext(context.Background(), "udp", response.Bind.String())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
assConn := socks.NewAssociateConn(tcpConn, udpConn, dest)
|
||||
message := &dnsmessage.Message{}
|
||||
message.Header.ID = 1
|
||||
message.Header.RecursionDesired = true
|
||||
message.Questions = append(message.Questions, dnsmessage.Question{
|
||||
Name: dnsmessage.MustNewName("google.com."),
|
||||
Type: dnsmessage.TypeA,
|
||||
Class: dnsmessage.ClassINET,
|
||||
})
|
||||
packet, err := message.Pack()
|
||||
common.Must(err)
|
||||
common.Must1(assConn.WriteTo(packet, &net.UDPAddr{
|
||||
IP: net.IPv4(1, 0, 0, 1),
|
||||
Port: 53,
|
||||
}))
|
||||
_buffer := buf.StackNew()
|
||||
buffer := common.Dup(_buffer)
|
||||
common.Must2(buffer.ReadPacketFrom(assConn))
|
||||
common.Must(message.Unpack(buffer.Bytes()))
|
||||
|
||||
for _, answer := range message.Answers {
|
||||
logrus.Info("udp got answer: ", netip.AddrFrom4(answer.Body.(*dnsmessage.AResource).A))
|
||||
}
|
||||
|
||||
udpConn.Close()
|
||||
tcpConn.Close()
|
||||
return nil
|
||||
}
|
|
@ -345,28 +345,15 @@ func (c *client) NewConnection(ctx context.Context, conn net.Conn, metadata M.Me
|
|||
return rw.CopyConn(ctx, serverConn, conn)
|
||||
}
|
||||
|
||||
func (c *client) NewPacketConnection(conn socks.PacketConn, _ M.Metadata) error {
|
||||
func (c *client) NewPacketConnection(conn socks.PacketConn, metadata M.Metadata) error {
|
||||
logrus.Info("outbound ", metadata.Protocol, " UDP ", metadata.Source, " ==> ", metadata.Destination)
|
||||
ctx := context.Background()
|
||||
udpConn, err := c.dialer.DialContext(ctx, "udp", c.server.String())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
serverConn := c.method.DialPacketConn(udpConn)
|
||||
return task.Run(ctx, func() error {
|
||||
var init bool
|
||||
return socks.CopyPacketConn0(serverConn, conn, func(destination *M.AddrPort, n int) {
|
||||
if !init {
|
||||
init = true
|
||||
logrus.Info("UDP ", conn.LocalAddr(), " ==> ", destination)
|
||||
} else {
|
||||
logrus.Trace("UDP ", conn.LocalAddr(), " ==> ", destination)
|
||||
}
|
||||
})
|
||||
}, func() error {
|
||||
return socks.CopyPacketConn0(conn, serverConn, func(destination *M.AddrPort, n int) {
|
||||
logrus.Trace("UDP ", conn.LocalAddr(), " <== ", destination)
|
||||
})
|
||||
})
|
||||
return socks.CopyPacketConn(ctx, serverConn, conn)
|
||||
}
|
||||
|
||||
func run(cmd *cobra.Command, flags *flags) {
|
||||
|
|
|
@ -3,6 +3,8 @@ package main
|
|||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
|
@ -12,22 +14,29 @@ import (
|
|||
|
||||
"github.com/sagernet/sing"
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/gsync"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
"github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/sing/common/random"
|
||||
"github.com/sagernet/sing/common/rw"
|
||||
"github.com/sagernet/sing/protocol/shadowsocks"
|
||||
"github.com/sagernet/sing/protocol/shadowsocks/shadowaead"
|
||||
"github.com/sagernet/sing/protocol/shadowsocks/shadowaead_2022"
|
||||
"github.com/sagernet/sing/protocol/socks"
|
||||
"github.com/sagernet/sing/transport/tcp"
|
||||
"github.com/sagernet/sing/transport/udp"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
type flags struct {
|
||||
Bind string `json:"local_address"`
|
||||
LocalPort uint16 `json:"local_port"`
|
||||
// Password string `json:"password"`
|
||||
Server string `json:"server"`
|
||||
ServerPort uint16 `json:"server_port"`
|
||||
Bind string `json:"local_address"`
|
||||
LocalPort uint16 `json:"local_port"`
|
||||
Password string `json:"password"`
|
||||
Key string `json:"key"`
|
||||
Method string `json:"method"`
|
||||
Verbose bool `json:"verbose"`
|
||||
|
@ -35,8 +44,6 @@ type flags struct {
|
|||
}
|
||||
|
||||
func main() {
|
||||
logrus.SetLevel(logrus.TraceLevel)
|
||||
|
||||
f := new(flags)
|
||||
|
||||
command := &cobra.Command{
|
||||
|
@ -48,12 +55,16 @@ func main() {
|
|||
},
|
||||
}
|
||||
|
||||
command.Flags().StringVarP(&f.Server, "server", "s", "", "Set the server’s hostname or IP.")
|
||||
command.Flags().Uint16VarP(&f.ServerPort, "server-port", "p", 0, "Set the server’s port number.")
|
||||
command.Flags().StringVarP(&f.Bind, "local-address", "b", "", "Set the local address.")
|
||||
command.Flags().Uint16VarP(&f.LocalPort, "local-port", "l", 0, "Set the local port number.")
|
||||
command.Flags().StringVarP(&f.Key, "key", "k", "", "Set the key directly. The key should be encoded with URL-safe Base64.")
|
||||
command.Flags().StringVar(&f.Key, "key", "", "Set the key directly. The key should be encoded with URL-safe Base64.")
|
||||
command.Flags().StringVarP(&f.Password, "password", "k", "", "Set the password. The server and the client should use the same password.")
|
||||
|
||||
var supportedCiphers []string
|
||||
supportedCiphers = append(supportedCiphers, shadowsocks.MethodNone)
|
||||
supportedCiphers = append(supportedCiphers, shadowaead.List...)
|
||||
supportedCiphers = append(supportedCiphers, shadowaead_2022.List...)
|
||||
|
||||
command.Flags().StringVarP(&f.Method, "encrypt-method", "m", "", "Set the cipher.\n\nSupported ciphers:\n\n"+strings.Join(supportedCiphers, "\n"))
|
||||
|
@ -75,6 +86,10 @@ func run(cmd *cobra.Command, f *flags) {
|
|||
if err != nil {
|
||||
logrus.Fatal(err)
|
||||
}
|
||||
err = s.udpIn.Start()
|
||||
if err != nil {
|
||||
logrus.Fatal(err)
|
||||
}
|
||||
logrus.Info("server started at ", s.tcpIn.TCPListener.Addr())
|
||||
|
||||
osSignals := make(chan os.Signal, 1)
|
||||
|
@ -82,33 +97,101 @@ func run(cmd *cobra.Command, f *flags) {
|
|||
<-osSignals
|
||||
|
||||
s.tcpIn.Close()
|
||||
s.udpIn.Close()
|
||||
}
|
||||
|
||||
type server struct {
|
||||
tcpIn *tcp.Listener
|
||||
udpIn *udp.Listener
|
||||
service shadowsocks.Service
|
||||
udpNat gsync.Map[string, *net.UDPConn]
|
||||
}
|
||||
|
||||
func (s *server) Start() error {
|
||||
err := s.tcpIn.Start()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = s.udpIn.Start()
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *server) Close() error {
|
||||
s.tcpIn.Close()
|
||||
s.udpIn.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
func newServer(f *flags) (*server, error) {
|
||||
s := new(server)
|
||||
|
||||
if f.ConfigFile != "" {
|
||||
configFile, err := ioutil.ReadFile(f.ConfigFile)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "read config file")
|
||||
}
|
||||
flagsNew := new(flags)
|
||||
err = json.Unmarshal(configFile, flagsNew)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "decode config file")
|
||||
}
|
||||
if flagsNew.Server != "" && f.Server == "" {
|
||||
f.Server = flagsNew.Server
|
||||
}
|
||||
if flagsNew.ServerPort != 0 && f.ServerPort == 0 {
|
||||
f.ServerPort = flagsNew.ServerPort
|
||||
}
|
||||
if flagsNew.Bind != "" && f.Bind == "" {
|
||||
f.Bind = flagsNew.Bind
|
||||
}
|
||||
if flagsNew.LocalPort != 0 && f.LocalPort == 0 {
|
||||
f.LocalPort = flagsNew.LocalPort
|
||||
}
|
||||
if flagsNew.Password != "" && f.Password == "" {
|
||||
f.Password = flagsNew.Password
|
||||
}
|
||||
if flagsNew.Key != "" && f.Key == "" {
|
||||
f.Key = flagsNew.Key
|
||||
}
|
||||
if flagsNew.Method != "" && f.Method == "" {
|
||||
f.Method = flagsNew.Method
|
||||
}
|
||||
if flagsNew.Verbose {
|
||||
f.Verbose = true
|
||||
}
|
||||
}
|
||||
|
||||
if f.Verbose {
|
||||
logrus.SetLevel(logrus.TraceLevel)
|
||||
}
|
||||
|
||||
if f.Server == "" {
|
||||
return nil, E.New("missing server address")
|
||||
} else if f.ServerPort == 0 {
|
||||
return nil, E.New("missing server port")
|
||||
} else if f.Method == "" {
|
||||
return nil, E.New("missing method")
|
||||
}
|
||||
|
||||
var key []byte
|
||||
if f.Key != "" {
|
||||
kb, err := base64.StdEncoding.DecodeString(f.Key)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "decode key")
|
||||
}
|
||||
key = kb
|
||||
}
|
||||
|
||||
if f.Method == shadowsocks.MethodNone {
|
||||
s.service = shadowsocks.NewNoneService(s)
|
||||
} else if common.Contains(shadowaead_2022.List, f.Method) {
|
||||
var pskList [][]byte
|
||||
if f.Key != "" {
|
||||
keyStrList := strings.Split(f.Key, ":")
|
||||
pskList = make([][]byte, len(keyStrList))
|
||||
for i, keyStr := range keyStrList {
|
||||
key, err := base64.StdEncoding.DecodeString(keyStr)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "decode key")
|
||||
}
|
||||
pskList[i] = key
|
||||
}
|
||||
} else if common.Contains(shadowaead.List, f.Method) {
|
||||
service, err := shadowaead.NewService(f.Method, key, []byte(f.Password), random.Blake3KeyedHash(), false, s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rng := random.System
|
||||
service, err := shadowaead_2022.NewService(f.Method, pskList[0], rng, s)
|
||||
s.service = service
|
||||
} else if common.Contains(shadowaead_2022.List, f.Method) {
|
||||
service, err := shadowaead_2022.NewService(f.Method, key, random.Blake3KeyedHash(), s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -118,16 +201,17 @@ func newServer(f *flags) (*server, error) {
|
|||
}
|
||||
|
||||
var bind netip.Addr
|
||||
if f.Bind != "" {
|
||||
addr, err := netip.ParseAddr(f.Bind)
|
||||
if f.Server != "" {
|
||||
addr, err := netip.ParseAddr(f.Server)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "bad local address")
|
||||
return nil, E.Cause(err, "bad server address")
|
||||
}
|
||||
bind = addr
|
||||
} else {
|
||||
bind = netip.IPv6Unspecified()
|
||||
}
|
||||
s.tcpIn = tcp.NewTCPListener(netip.AddrPortFrom(bind, f.LocalPort), s)
|
||||
s.tcpIn = tcp.NewTCPListener(netip.AddrPortFrom(bind, f.ServerPort), s)
|
||||
s.udpIn = udp.NewUDPListener(netip.AddrPortFrom(bind, f.ServerPort), s)
|
||||
return s, nil
|
||||
}
|
||||
|
||||
|
@ -143,6 +227,19 @@ func (s *server) NewConnection(ctx context.Context, conn net.Conn, metadata M.Me
|
|||
return rw.CopyConn(ctx, conn, destConn)
|
||||
}
|
||||
|
||||
func (s *server) NewPacketConnection(conn socks.PacketConn, metadata M.Metadata) error {
|
||||
logrus.Info("inbound UDP ", metadata.Source, " ==> ", metadata.Destination)
|
||||
udpConn, err := net.ListenUDP("udp", nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return socks.CopyNetPacketConn(context.Background(), udpConn, conn)
|
||||
}
|
||||
|
||||
func (s *server) NewPacket(conn socks.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error {
|
||||
return s.service.NewPacket(conn, buffer, metadata)
|
||||
}
|
||||
|
||||
func (s *server) HandleError(err error) {
|
||||
if E.IsClosed(err) {
|
||||
return
|
||||
|
|
|
@ -14,7 +14,6 @@ import (
|
|||
M "github.com/sagernet/sing/common/metadata"
|
||||
"github.com/sagernet/sing/common/redir"
|
||||
"github.com/sagernet/sing/common/rw"
|
||||
"github.com/sagernet/sing/common/task"
|
||||
"github.com/sagernet/sing/common/uot"
|
||||
"github.com/sagernet/sing/protocol/socks"
|
||||
"github.com/sagernet/sing/transport/mixed"
|
||||
|
@ -122,15 +121,7 @@ func (c *localClient) NewPacketConnection(conn socks.PacketConn, _ M.Metadata) e
|
|||
}
|
||||
|
||||
client := uot.NewClientConn(upstream)
|
||||
return task.Run(context.Background(), func() error {
|
||||
return socks.CopyPacketConn0(client, conn, func(destination *M.AddrPort, n int) {
|
||||
logrus.Trace("UDP ", conn.LocalAddr(), " ==> ", destination)
|
||||
})
|
||||
}, func() error {
|
||||
return socks.CopyPacketConn0(conn, client, func(destination *M.AddrPort, n int) {
|
||||
logrus.Trace("UDP ", conn.LocalAddr(), " <== ", destination)
|
||||
})
|
||||
})
|
||||
return socks.CopyPacketConn(context.Background(), client, conn)
|
||||
}
|
||||
|
||||
func (c *localClient) OnError(err error) {
|
||||
|
|
|
@ -104,11 +104,15 @@ func (b *Buffer) ExtendHeader(size int) []byte {
|
|||
b.start -= size
|
||||
return b.data[b.start-size : b.start]
|
||||
} else {
|
||||
offset := size - b.start
|
||||
/*offset := size - b.start
|
||||
end := b.end + size
|
||||
if end > len(b.data) {
|
||||
panic("buffer overflow")
|
||||
}
|
||||
copy(b.data[offset:end], b.data[b.start:b.end])
|
||||
b.end = end
|
||||
return b.data[:offset]
|
||||
return b.data[:offset]*/
|
||||
panic("no header available")
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -119,10 +123,18 @@ func (b *Buffer) WriteBufferAtFirst(buffer *Buffer) *Buffer {
|
|||
b.start -= n
|
||||
buffer.Release()
|
||||
return b
|
||||
} else if buffer.FreeLen() >= b.Len() {
|
||||
common.Must1(buffer.Write(b.Bytes()))
|
||||
b.Release()
|
||||
return buffer
|
||||
} else if b.FreeLen() >= size {
|
||||
copy(b.data[b.start+size:b.end+size], b.data[b.start:b.end])
|
||||
copy(b.data, buffer.data)
|
||||
buffer.Release()
|
||||
return b
|
||||
} else {
|
||||
panic("buffer overflow")
|
||||
}
|
||||
common.Must1(buffer.Write(b.Bytes()))
|
||||
b.Release()
|
||||
return buffer
|
||||
}
|
||||
|
||||
func (b *Buffer) WriteAtFirst(data []byte) (n int, err error) {
|
||||
|
@ -305,6 +317,10 @@ func (b *Buffer) Cut(start int, end int) *Buffer {
|
|||
}
|
||||
}
|
||||
|
||||
func (b Buffer) Start() int {
|
||||
return b.start
|
||||
}
|
||||
|
||||
func (b Buffer) Len() int {
|
||||
return b.end - b.start
|
||||
}
|
||||
|
|
|
@ -3,8 +3,6 @@ package metadata
|
|||
import (
|
||||
"context"
|
||||
"net"
|
||||
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
)
|
||||
|
||||
type Metadata struct {
|
||||
|
@ -16,7 +14,3 @@ type Metadata struct {
|
|||
type TCPConnectionHandler interface {
|
||||
NewConnection(ctx context.Context, conn net.Conn, metadata Metadata) error
|
||||
}
|
||||
|
||||
type UDPHandler interface {
|
||||
NewPacket(packet *buf.Buffer, metadata Metadata) error
|
||||
}
|
||||
|
|
|
@ -55,28 +55,43 @@ func (s *Serializer) WriteAddress(writer io.Writer, addr Addr) error {
|
|||
return err
|
||||
}
|
||||
|
||||
func (s *Serializer) AddressLen(addr Addr) int {
|
||||
switch addr.Family() {
|
||||
case AddressFamilyIPv4:
|
||||
return 5
|
||||
case AddressFamilyIPv6:
|
||||
return 17
|
||||
default:
|
||||
return 1 + len(addr.Fqdn())
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Serializer) WritePort(writer io.Writer, port uint16) error {
|
||||
return binary.Write(writer, binary.BigEndian, port)
|
||||
}
|
||||
|
||||
func (s *Serializer) WriteAddrPort(writer io.Writer, addrPort *AddrPort) error {
|
||||
func (s *Serializer) WriteAddrPort(writer io.Writer, destination *AddrPort) error {
|
||||
var err error
|
||||
if !s.portFirst {
|
||||
err = s.WriteAddress(writer, addrPort.Addr)
|
||||
err = s.WriteAddress(writer, destination.Addr)
|
||||
} else {
|
||||
err = s.WritePort(writer, addrPort.Port)
|
||||
err = s.WritePort(writer, destination.Port)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if s.portFirst {
|
||||
err = s.WriteAddress(writer, addrPort.Addr)
|
||||
err = s.WriteAddress(writer, destination.Addr)
|
||||
} else {
|
||||
err = s.WritePort(writer, addrPort.Port)
|
||||
err = s.WritePort(writer, destination.Port)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Serializer) AddrPortLen(destination *AddrPort) int {
|
||||
return s.AddressLen(destination.Addr) + 2
|
||||
}
|
||||
|
||||
func (s *Serializer) ReadAddress(reader io.Reader) (Addr, error) {
|
||||
af, err := rw.ReadByte(reader)
|
||||
if err != nil {
|
||||
|
@ -120,7 +135,7 @@ func (s *Serializer) ReadPort(reader io.Reader) (uint16, error) {
|
|||
return binary.BigEndian.Uint16(port), nil
|
||||
}
|
||||
|
||||
func (s *Serializer) ReadAddrPort(reader io.Reader) (addrPort *AddrPort, err error) {
|
||||
func (s *Serializer) ReadAddrPort(reader io.Reader) (destination *AddrPort, err error) {
|
||||
var addr Addr
|
||||
var port uint16
|
||||
if !s.portFirst {
|
||||
|
|
|
@ -17,6 +17,17 @@ type DefaultDialer struct {
|
|||
net.Dialer
|
||||
}
|
||||
|
||||
func (d *DefaultDialer) ListenUDP(network string, laddr *net.UDPAddr) (*net.UDPConn, error) {
|
||||
return net.ListenUDP(network, laddr)
|
||||
}
|
||||
|
||||
func (d *DefaultDialer) DialContext(ctx context.Context, network string, address *M.AddrPort) (net.Conn, error) {
|
||||
return d.Dialer.DialContext(ctx, network, address.String())
|
||||
}
|
||||
|
||||
type Listener interface {
|
||||
Listen(ctx context.Context, network, address string) (net.Listener, error)
|
||||
ListenPacket(ctx context.Context, network, address string) (net.PacketConn, error)
|
||||
}
|
||||
|
||||
var SystemListener Listener = &net.ListenConfig{}
|
||||
|
|
|
@ -1,108 +0,0 @@
|
|||
package udpnat
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/gsync"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
"github.com/sagernet/sing/common/redir"
|
||||
"github.com/sagernet/sing/protocol/socks"
|
||||
)
|
||||
|
||||
type Handler interface {
|
||||
socks.UDPConnectionHandler
|
||||
E.Handler
|
||||
}
|
||||
|
||||
type Server struct {
|
||||
udpNat gsync.Map[string, *packetConn]
|
||||
handler Handler
|
||||
}
|
||||
|
||||
func NewServer(handler Handler) *Server {
|
||||
return &Server{handler: handler}
|
||||
}
|
||||
|
||||
func (s *Server) HandleUDP(buffer *buf.Buffer, metadata M.Metadata) error {
|
||||
conn, loaded := s.udpNat.LoadOrStore(metadata.Source.String(), func() *packetConn {
|
||||
return &packetConn{source: metadata.Source.UDPAddr(), in: make(chan *udpPacket)}
|
||||
})
|
||||
if !loaded {
|
||||
go func() {
|
||||
err := s.handler.NewPacketConnection(conn, metadata)
|
||||
if err != nil {
|
||||
s.handler.HandleError(err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
conn.in <- &udpPacket{
|
||||
buffer: buffer,
|
||||
destination: metadata.Destination,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) OnError(err error) {
|
||||
s.handler.HandleError(err)
|
||||
}
|
||||
|
||||
func (s *Server) Close() error {
|
||||
s.udpNat.Range(func(key string, conn *packetConn) bool {
|
||||
conn.Close()
|
||||
return true
|
||||
})
|
||||
s.udpNat = gsync.Map[string, *packetConn]{}
|
||||
return nil
|
||||
}
|
||||
|
||||
type packetConn struct {
|
||||
socks.PacketConnStub
|
||||
source *net.UDPAddr
|
||||
in chan *udpPacket
|
||||
}
|
||||
|
||||
type udpPacket struct {
|
||||
buffer *buf.Buffer
|
||||
destination *M.AddrPort
|
||||
}
|
||||
|
||||
func (c *packetConn) LocalAddr() net.Addr {
|
||||
return c.source
|
||||
}
|
||||
|
||||
func (c *packetConn) Close() error {
|
||||
select {
|
||||
case <-c.in:
|
||||
return io.ErrClosedPipe
|
||||
default:
|
||||
close(c.in)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *packetConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
|
||||
select {
|
||||
case packet, ok := <-c.in:
|
||||
if !ok {
|
||||
return nil, io.ErrClosedPipe
|
||||
}
|
||||
defer packet.buffer.Release()
|
||||
if buffer.FreeLen() < packet.buffer.Len() {
|
||||
return nil, io.ErrShortBuffer
|
||||
}
|
||||
return packet.destination, common.Error(buffer.Write(packet.buffer.Bytes()))
|
||||
}
|
||||
}
|
||||
|
||||
func (c *packetConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error {
|
||||
udpConn, err := redir.DialUDP("udp", destination.UDPAddr(), c.source)
|
||||
if err != nil {
|
||||
return E.Cause(err, "tproxy udp write back")
|
||||
}
|
||||
defer udpConn.Close()
|
||||
return common.Error(udpConn.Write(buffer.Bytes()))
|
||||
}
|
123
common/udpnat/service.go
Normal file
123
common/udpnat/service.go
Normal file
|
@ -0,0 +1,123 @@
|
|||
package udpnat
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/gsync"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
"github.com/sagernet/sing/protocol/socks"
|
||||
)
|
||||
|
||||
type Handler interface {
|
||||
socks.UDPConnectionHandler
|
||||
E.Handler
|
||||
}
|
||||
|
||||
type Service[K comparable] struct {
|
||||
nat gsync.Map[K, *conn]
|
||||
handler Handler
|
||||
}
|
||||
|
||||
func New[T comparable](handler Handler) *Service[T] {
|
||||
return &Service[T]{
|
||||
handler: handler,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service[T]) NewPacket(key T, writer func() socks.PacketWriter, buffer *buf.Buffer, metadata M.Metadata) error {
|
||||
c, loaded := s.nat.LoadOrStore(key, func() *conn {
|
||||
c := &conn{
|
||||
data: make(chan packet),
|
||||
remoteAddr: metadata.Source.UDPAddr(),
|
||||
source: writer(),
|
||||
}
|
||||
c.ctx, c.cancel = context.WithCancel(context.Background())
|
||||
return c
|
||||
})
|
||||
if !loaded {
|
||||
go func() {
|
||||
err := s.handler.NewPacketConnection(c, metadata)
|
||||
if err != nil {
|
||||
s.handler.HandleError(err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
ctx, done := context.WithCancel(c.ctx)
|
||||
p := packet{
|
||||
done: done,
|
||||
data: buffer,
|
||||
destination: metadata.Destination,
|
||||
}
|
||||
c.data <- p
|
||||
<-ctx.Done()
|
||||
return nil
|
||||
}
|
||||
|
||||
type packet struct {
|
||||
data *buf.Buffer
|
||||
destination *M.AddrPort
|
||||
done context.CancelFunc
|
||||
}
|
||||
|
||||
type conn struct {
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
data chan packet
|
||||
remoteAddr *net.UDPAddr
|
||||
source socks.PacketWriter
|
||||
}
|
||||
|
||||
func (c *conn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
|
||||
select {
|
||||
case p, ok := <-c.data:
|
||||
if !ok {
|
||||
return nil, io.ErrClosedPipe
|
||||
}
|
||||
defer p.data.Release()
|
||||
_, err := buffer.ReadFrom(p.data)
|
||||
p.done()
|
||||
return p.destination, err
|
||||
}
|
||||
}
|
||||
|
||||
func (c *conn) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error {
|
||||
return c.source.WritePacket(buffer, destination)
|
||||
}
|
||||
|
||||
func (c *conn) Close() error {
|
||||
c.cancel()
|
||||
select {
|
||||
case <-c.data:
|
||||
return os.ErrClosed
|
||||
default:
|
||||
close(c.data)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *conn) LocalAddr() net.Addr {
|
||||
return &common.DummyAddr{}
|
||||
}
|
||||
|
||||
func (c *conn) RemoteAddr() net.Addr {
|
||||
return c.remoteAddr
|
||||
}
|
||||
|
||||
func (c *conn) SetDeadline(t time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *conn) SetReadDeadline(t time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *conn) SetWriteDeadline(t time.Time) error {
|
||||
return nil
|
||||
}
|
|
@ -30,7 +30,6 @@ func TestServerConn(t *testing.T) {
|
|||
Port: 53,
|
||||
}))
|
||||
_buffer := buf.StackNew()
|
||||
common.Use(_buffer)
|
||||
buffer := common.Dup(_buffer)
|
||||
common.Must2(buffer.ReadPacketFrom(clientConn))
|
||||
common.Must(message.Unpack(buffer.Bytes()))
|
||||
|
|
|
@ -4,32 +4,35 @@ import (
|
|||
"context"
|
||||
"net"
|
||||
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
"github.com/sagernet/sing/common/udpnat"
|
||||
"github.com/sagernet/sing/protocol/socks"
|
||||
)
|
||||
|
||||
type Service interface {
|
||||
M.TCPConnectionHandler
|
||||
}
|
||||
|
||||
type MultiUserService interface {
|
||||
Service
|
||||
AddUser(key []byte)
|
||||
RemoveUser(key []byte)
|
||||
socks.UDPHandler
|
||||
}
|
||||
|
||||
type Handler interface {
|
||||
M.TCPConnectionHandler
|
||||
socks.UDPConnectionHandler
|
||||
E.Handler
|
||||
}
|
||||
|
||||
type NoneService struct {
|
||||
handler Handler
|
||||
udp *udpnat.Service[string]
|
||||
}
|
||||
|
||||
func NewNoneService(handler Handler) Service {
|
||||
return &NoneService{
|
||||
s := &NoneService{
|
||||
handler: handler,
|
||||
}
|
||||
s.udp = udpnat.New[string](s)
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *NoneService) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
|
||||
|
@ -41,3 +44,37 @@ func (s *NoneService) NewConnection(ctx context.Context, conn net.Conn, metadata
|
|||
metadata.Destination = destination
|
||||
return s.handler.NewConnection(ctx, conn, metadata)
|
||||
}
|
||||
|
||||
func (s *NoneService) NewPacket(conn socks.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error {
|
||||
destination, err := socks.AddressSerializer.ReadAddrPort(buffer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
metadata.Protocol = "shadowsocks"
|
||||
metadata.Destination = destination
|
||||
return s.udp.NewPacket(metadata.Source.String(), func() socks.PacketWriter {
|
||||
return &serverPacketWriter{conn, metadata.Source}
|
||||
}, buffer, metadata)
|
||||
}
|
||||
|
||||
type serverPacketWriter struct {
|
||||
socks.PacketConn
|
||||
sourceAddr *M.AddrPort
|
||||
}
|
||||
|
||||
func (s *serverPacketWriter) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error {
|
||||
header := buf.With(buffer.ExtendHeader(socks.AddressSerializer.AddrPortLen(destination)))
|
||||
err := socks.AddressSerializer.WriteAddrPort(header, destination)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return s.PacketConn.WritePacket(buffer, s.sourceAddr)
|
||||
}
|
||||
|
||||
func (s *NoneService) NewPacketConnection(conn socks.PacketConn, metadata M.Metadata) error {
|
||||
return s.handler.NewPacketConnection(conn, metadata)
|
||||
}
|
||||
|
||||
func (s *NoneService) HandleError(err error) {
|
||||
s.handler.HandleError(err)
|
||||
}
|
||||
|
|
|
@ -156,13 +156,13 @@ func (m *Method) DialEarlyConn(conn net.Conn, destination *M.AddrPort) net.Conn
|
|||
}
|
||||
|
||||
func (m *Method) DialPacketConn(conn net.Conn) socks.PacketConn {
|
||||
return &clientPacketConn{conn, m}
|
||||
return &clientPacketConn{m, conn}
|
||||
}
|
||||
|
||||
func (m *Method) EncodePacket(buffer *buf.Buffer) error {
|
||||
key := Kdf(m.key, buffer.To(m.keySaltLength), m.keySaltLength)
|
||||
c := m.constructor(common.Dup(key))
|
||||
c.Seal(buffer.From(m.keySaltLength)[:0], rw.ZeroBytes[:c.NonceSize()], buffer.From(m.keySaltLength), nil)
|
||||
c.Seal(buffer.Index(m.keySaltLength), rw.ZeroBytes[:c.NonceSize()], buffer.From(m.keySaltLength), nil)
|
||||
buffer.Extend(c.Overhead())
|
||||
return nil
|
||||
}
|
||||
|
@ -299,20 +299,18 @@ func (c *clientConn) ReadFrom(r io.Reader) (n int64, err error) {
|
|||
}
|
||||
|
||||
type clientPacketConn struct {
|
||||
*Method
|
||||
net.Conn
|
||||
method *Method
|
||||
}
|
||||
|
||||
func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error {
|
||||
_header := buf.StackNew()
|
||||
header := common.Dup(_header)
|
||||
common.Must1(header.ReadFullFrom(c.method.secureRNG, c.method.keySaltLength))
|
||||
err := socks.AddressSerializer.WriteAddrPort(header, destination)
|
||||
header := buffer.ExtendHeader(c.keySaltLength + socks.AddressSerializer.AddrPortLen(destination))
|
||||
common.Must1(io.ReadFull(c.secureRNG, header[:c.keySaltLength]))
|
||||
err := socks.AddressSerializer.WriteAddrPort(buf.With(header[c.keySaltLength:]), destination)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
buffer = buffer.WriteBufferAtFirst(header)
|
||||
err = c.method.EncodePacket(buffer)
|
||||
err = c.EncodePacket(buffer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -325,7 +323,7 @@ func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
|
|||
return nil, err
|
||||
}
|
||||
buffer.Truncate(n)
|
||||
err = c.method.DecodePacket(buffer)
|
||||
err = c.DecodePacket(buffer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -13,6 +13,7 @@ import (
|
|||
M "github.com/sagernet/sing/common/metadata"
|
||||
"github.com/sagernet/sing/common/replay"
|
||||
"github.com/sagernet/sing/common/rw"
|
||||
"github.com/sagernet/sing/common/udpnat"
|
||||
"github.com/sagernet/sing/protocol/shadowsocks"
|
||||
"github.com/sagernet/sing/protocol/socks"
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
|
@ -25,6 +26,7 @@ type Service struct {
|
|||
key []byte
|
||||
secureRNG io.Reader
|
||||
replayFilter replay.Filter
|
||||
udp *udpnat.Service[string]
|
||||
handler shadowsocks.Handler
|
||||
}
|
||||
|
||||
|
@ -34,6 +36,7 @@ func NewService(method string, key []byte, password []byte, secureRNG io.Reader,
|
|||
secureRNG: secureRNG,
|
||||
handler: handler,
|
||||
}
|
||||
s.udp = udpnat.New[string](s)
|
||||
if replayFilter {
|
||||
s.replayFilter = replay.NewBloomRing()
|
||||
}
|
||||
|
@ -163,3 +166,59 @@ func (c *serverConn) ReadFrom(r io.Reader) (n int64, err error) {
|
|||
func (c *serverConn) WriteTo(w io.Writer) (n int64, err error) {
|
||||
return c.reader.WriteTo(w)
|
||||
}
|
||||
|
||||
func (s *Service) NewPacket(conn socks.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error {
|
||||
if buffer.Len() < s.keySaltLength {
|
||||
return E.New("bad packet")
|
||||
}
|
||||
key := Kdf(s.key, buffer.To(s.keySaltLength), s.keySaltLength)
|
||||
c := s.constructor(common.Dup(key))
|
||||
/*data := buf.New()
|
||||
packet, err := c.Open(data.Index(0), rw.ZeroBytes[:c.NonceSize()], buffer.From(s.keySaltLength), nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
data.Truncate(len(packet))
|
||||
metadata.Protocol = "shadowsocks"
|
||||
return s.udp.NewPacket(metadata.Source.String(), func() socks.PacketWriter {
|
||||
return &serverPacketWriter{s, conn, metadata.Source}
|
||||
}, data, metadata)*/
|
||||
packet, err := c.Open(buffer.Index(s.keySaltLength), rw.ZeroBytes[:c.NonceSize()], buffer.From(s.keySaltLength), nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
buffer.Advance(s.keySaltLength)
|
||||
buffer.Truncate(len(packet))
|
||||
metadata.Protocol = "shadowsocks"
|
||||
return s.udp.NewPacket(metadata.Source.String(), func() socks.PacketWriter {
|
||||
return &serverPacketWriter{s, conn, metadata.Source}
|
||||
}, buffer, metadata)
|
||||
}
|
||||
|
||||
type serverPacketWriter struct {
|
||||
*Service
|
||||
socks.PacketConn
|
||||
source *M.AddrPort
|
||||
}
|
||||
|
||||
func (w *serverPacketWriter) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error {
|
||||
header := buffer.ExtendHeader(w.keySaltLength + socks.AddressSerializer.AddrPortLen(destination))
|
||||
common.Must1(io.ReadFull(w.secureRNG, header[:w.keySaltLength]))
|
||||
err := socks.AddressSerializer.WriteAddrPort(buf.With(header[w.keySaltLength:]), destination)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
key := Kdf(w.key, buffer.To(w.keySaltLength), w.keySaltLength)
|
||||
c := w.constructor(common.Dup(key))
|
||||
c.Seal(buffer.From(w.keySaltLength)[:0], rw.ZeroBytes[:c.NonceSize()], buffer.From(w.keySaltLength), nil)
|
||||
buffer.Extend(c.Overhead())
|
||||
return w.PacketConn.WritePacket(buffer, w.source)
|
||||
}
|
||||
|
||||
func (s *Service) NewPacketConnection(conn socks.PacketConn, metadata M.Metadata) error {
|
||||
return s.handler.NewPacketConnection(conn, metadata)
|
||||
}
|
||||
|
||||
func (s *Service) HandleError(err error) {
|
||||
s.handler.HandleError(err)
|
||||
}
|
||||
|
|
|
@ -109,9 +109,9 @@ func New(method string, pskList [][]byte, secureRNG io.Reader) (shadowsocks.Meth
|
|||
}
|
||||
|
||||
func Blake3DeriveKey(psk, salt []byte, keyLength int) []byte {
|
||||
sessionKey := make([]byte, 2*KeySaltSize)
|
||||
sessionKey := buf.Make(len(psk) + len(salt))
|
||||
copy(sessionKey, psk)
|
||||
copy(sessionKey[KeySaltSize:], salt)
|
||||
copy(sessionKey[len(psk):], salt)
|
||||
outKey := buf.Make(keyLength)
|
||||
blake3.DeriveKey(outKey, "shadowsocks 2022 session subkey", sessionKey)
|
||||
return outKey
|
||||
|
@ -434,6 +434,9 @@ func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPo
|
|||
return err
|
||||
}
|
||||
buffer = buffer.WriteBufferAtFirst(header)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if c.method.udpCipher != nil {
|
||||
c.method.udpCipher.Seal(buffer.Index(dataIndex), buffer.To(dataIndex), buffer.From(dataIndex), nil)
|
||||
buffer.Extend(c.method.udpCipher.Overhead())
|
||||
|
@ -574,9 +577,9 @@ func (s *udpSession) nextPacketId() uint64 {
|
|||
}
|
||||
|
||||
func (m *Method) newUDPSession() *udpSession {
|
||||
session := &udpSession{
|
||||
sessionId: rand.Uint64(),
|
||||
}
|
||||
session := &udpSession{}
|
||||
common.Must(binary.Read(m.secureRNG, binary.BigEndian, &session.sessionId))
|
||||
session.packetId--
|
||||
if m.udpCipher == nil {
|
||||
sessionId := make([]byte, 8)
|
||||
binary.BigEndian.PutUint64(sessionId, session.sessionId)
|
||||
|
|
|
@ -2,32 +2,43 @@ package shadowaead_2022
|
|||
|
||||
import (
|
||||
"context"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"math"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/gsync"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
"github.com/sagernet/sing/common/replay"
|
||||
"github.com/sagernet/sing/common/rw"
|
||||
"github.com/sagernet/sing/common/udpnat"
|
||||
"github.com/sagernet/sing/protocol/shadowsocks"
|
||||
"github.com/sagernet/sing/protocol/shadowsocks/shadowaead"
|
||||
"github.com/sagernet/sing/protocol/socks"
|
||||
wgReplay "golang.zx2c4.com/wireguard/replay"
|
||||
)
|
||||
|
||||
type Service struct {
|
||||
name string
|
||||
secureRNG io.Reader
|
||||
keyLength int
|
||||
constructor func(key []byte) cipher.AEAD
|
||||
psk []byte
|
||||
replayFilter replay.Filter
|
||||
handler shadowsocks.Handler
|
||||
name string
|
||||
secureRNG io.Reader
|
||||
keyLength int
|
||||
constructor func(key []byte) cipher.AEAD
|
||||
blockConstructor func(key []byte) cipher.Block
|
||||
udpCipher cipher.AEAD
|
||||
udpBlockCipher cipher.Block
|
||||
psk []byte
|
||||
replayFilter replay.Filter
|
||||
handler shadowsocks.Handler
|
||||
udpNat *udpnat.Service[uint64]
|
||||
sessions gsync.Map[uint64, *serverUDPSession]
|
||||
}
|
||||
|
||||
func NewService(method string, psk []byte, secureRNG io.Reader, handler shadowsocks.Handler) (shadowsocks.Service, error) {
|
||||
|
@ -47,18 +58,20 @@ func NewService(method string, psk []byte, secureRNG io.Reader, handler shadowso
|
|||
case "2022-blake3-aes-128-gcm":
|
||||
s.keyLength = 16
|
||||
s.constructor = newAESGCM
|
||||
// m.blockConstructor = newAES
|
||||
// m.udpBlockCipher = newAES(m.psk)
|
||||
s.blockConstructor = newAES
|
||||
s.udpBlockCipher = newAES(s.psk)
|
||||
case "2022-blake3-aes-256-gcm":
|
||||
s.keyLength = 32
|
||||
s.constructor = newAESGCM
|
||||
// m.blockConstructor = newAES
|
||||
// m.udpBlockCipher = newAES(m.psk)
|
||||
s.blockConstructor = newAES
|
||||
s.udpBlockCipher = newAES(s.psk)
|
||||
case "2022-blake3-chacha20-poly1305":
|
||||
s.keyLength = 32
|
||||
s.constructor = newChacha20Poly1305
|
||||
// m.udpCipher = newXChacha20Poly1305(m.psk)
|
||||
s.udpCipher = newXChacha20Poly1305(s.psk)
|
||||
}
|
||||
|
||||
s.udpNat = udpnat.New[uint64](s)
|
||||
return s, nil
|
||||
}
|
||||
|
||||
|
@ -194,3 +207,169 @@ func (c *serverConn) ReadFrom(r io.Reader) (n int64, err error) {
|
|||
func (c *serverConn) WriteTo(w io.Writer) (n int64, err error) {
|
||||
return c.reader.WriteTo(w)
|
||||
}
|
||||
|
||||
func (s *Service) NewPacket(conn socks.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error {
|
||||
var packetHeader []byte
|
||||
if s.udpCipher != nil {
|
||||
_, err := s.udpCipher.Open(buffer.Index(PacketNonceSize), buffer.To(PacketNonceSize), buffer.From(PacketNonceSize), nil)
|
||||
if err != nil {
|
||||
return E.Cause(err, "decrypt packet header")
|
||||
}
|
||||
buffer.Advance(PacketNonceSize)
|
||||
} else {
|
||||
packetHeader = buffer.To(aes.BlockSize)
|
||||
s.udpBlockCipher.Decrypt(packetHeader, packetHeader)
|
||||
}
|
||||
|
||||
var sessionId, packetId uint64
|
||||
err := binary.Read(buffer, binary.BigEndian, &sessionId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = binary.Read(buffer, binary.BigEndian, &packetId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
session, loaded := s.sessions.LoadOrStore(sessionId, s.newUDPSession)
|
||||
if !loaded {
|
||||
session.remoteSessionId = sessionId
|
||||
if packetHeader != nil {
|
||||
key := Blake3DeriveKey(s.psk, packetHeader[:8], s.keyLength)
|
||||
session.remoteCipher = s.constructor(common.Dup(key))
|
||||
}
|
||||
}
|
||||
|
||||
if !session.filter.ValidateCounter(packetId, math.MaxUint64) {
|
||||
return ErrPacketIdNotUnique
|
||||
}
|
||||
|
||||
if packetHeader != nil {
|
||||
_, err = session.remoteCipher.Open(buffer.Index(0), packetHeader[4:16], buffer.Bytes(), nil)
|
||||
if err != nil {
|
||||
return E.Cause(err, "decrypt packet")
|
||||
}
|
||||
}
|
||||
|
||||
var headerType byte
|
||||
headerType, err = buffer.ReadByte()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if headerType != HeaderTypeClient {
|
||||
return ErrBadHeaderType
|
||||
}
|
||||
|
||||
var epoch uint64
|
||||
err = binary.Read(buffer, binary.BigEndian, &epoch)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if math.Abs(float64(uint64(time.Now().Unix())-epoch)) > 30 {
|
||||
return ErrBadTimestamp
|
||||
}
|
||||
|
||||
var paddingLength uint16
|
||||
err = binary.Read(buffer, binary.BigEndian, &paddingLength)
|
||||
if err != nil {
|
||||
return E.Cause(err, "read padding length")
|
||||
}
|
||||
buffer.Advance(int(paddingLength))
|
||||
|
||||
destination, err := socks.AddressSerializer.ReadAddrPort(buffer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
metadata.Destination = destination
|
||||
|
||||
return s.udpNat.NewPacket(sessionId, func() socks.PacketWriter {
|
||||
return &serverPacketWriter{s, conn, session, metadata.Source}
|
||||
}, buffer, metadata)
|
||||
}
|
||||
|
||||
type serverPacketWriter struct {
|
||||
*Service
|
||||
socks.PacketConn
|
||||
session *serverUDPSession
|
||||
source *M.AddrPort
|
||||
}
|
||||
|
||||
func (w *serverPacketWriter) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error {
|
||||
defer buffer.Release()
|
||||
|
||||
_header := buf.StackNew()
|
||||
header := common.Dup(_header)
|
||||
|
||||
var dataIndex int
|
||||
if w.udpCipher != nil {
|
||||
common.Must1(header.ReadFullFrom(w.secureRNG, PacketNonceSize))
|
||||
dataIndex = buffer.Len()
|
||||
} else {
|
||||
dataIndex = aes.BlockSize
|
||||
}
|
||||
|
||||
common.Must(
|
||||
binary.Write(header, binary.BigEndian, w.session.sessionId),
|
||||
binary.Write(header, binary.BigEndian, w.session.nextPacketId()),
|
||||
header.WriteByte(HeaderTypeServer),
|
||||
binary.Write(header, binary.BigEndian, uint64(time.Now().Unix())),
|
||||
binary.Write(header, binary.BigEndian, w.session.remoteSessionId),
|
||||
binary.Write(header, binary.BigEndian, uint16(0)), // padding length
|
||||
)
|
||||
|
||||
err := socks.AddressSerializer.WriteAddrPort(header, destination)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = header.Write(buffer.Bytes())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if w.udpCipher != nil {
|
||||
w.udpCipher.Seal(header.Index(dataIndex), header.To(dataIndex), header.From(dataIndex), nil)
|
||||
header.Extend(w.udpCipher.Overhead())
|
||||
} else {
|
||||
packetHeader := header.To(aes.BlockSize)
|
||||
w.session.cipher.Seal(header.Index(dataIndex), packetHeader[4:16], header.From(dataIndex), nil)
|
||||
header.Extend(w.session.cipher.Overhead())
|
||||
w.udpBlockCipher.Encrypt(packetHeader, packetHeader)
|
||||
}
|
||||
return w.PacketConn.WritePacket(header, w.source)
|
||||
}
|
||||
|
||||
type serverUDPSession struct {
|
||||
sessionId uint64
|
||||
remoteSessionId uint64
|
||||
packetId uint64
|
||||
cipher cipher.AEAD
|
||||
remoteCipher cipher.AEAD
|
||||
filter wgReplay.Filter
|
||||
}
|
||||
|
||||
func (s *serverUDPSession) nextPacketId() uint64 {
|
||||
return atomic.AddUint64(&s.packetId, 1)
|
||||
}
|
||||
|
||||
func (m *Service) newUDPSession() *serverUDPSession {
|
||||
session := &serverUDPSession{}
|
||||
common.Must(binary.Read(m.secureRNG, binary.BigEndian, &session.sessionId))
|
||||
session.packetId--
|
||||
if m.udpCipher == nil {
|
||||
sessionId := make([]byte, 8)
|
||||
binary.BigEndian.PutUint64(sessionId, session.sessionId)
|
||||
key := Blake3DeriveKey(m.psk, sessionId, m.keyLength)
|
||||
session.cipher = m.constructor(common.Dup(key))
|
||||
}
|
||||
return session
|
||||
}
|
||||
|
||||
func (s *Service) NewPacketConnection(conn socks.PacketConn, metadata M.Metadata) error {
|
||||
return s.handler.NewPacketConnection(conn, metadata)
|
||||
}
|
||||
|
||||
func (s *Service) HandleError(err error) {
|
||||
s.handler.HandleError(err)
|
||||
}
|
||||
|
|
|
@ -8,12 +8,21 @@ import (
|
|||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
"github.com/sagernet/sing/common/rw"
|
||||
"github.com/sagernet/sing/common/task"
|
||||
)
|
||||
|
||||
type PacketConn interface {
|
||||
type PacketReader interface {
|
||||
ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error)
|
||||
WritePacket(buffer *buf.Buffer, addrPort *M.AddrPort) error
|
||||
}
|
||||
|
||||
type PacketWriter interface {
|
||||
WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error
|
||||
}
|
||||
|
||||
type PacketConn interface {
|
||||
PacketReader
|
||||
PacketWriter
|
||||
|
||||
Close() error
|
||||
LocalAddr() net.Addr
|
||||
|
@ -23,6 +32,10 @@ type PacketConn interface {
|
|||
SetWriteDeadline(t time.Time) error
|
||||
}
|
||||
|
||||
type UDPHandler interface {
|
||||
NewPacket(conn PacketConn, buffer *buf.Buffer, metadata M.Metadata) error
|
||||
}
|
||||
|
||||
type UDPConnectionHandler interface {
|
||||
NewPacketConnection(conn PacketConn, metadata M.Metadata) error
|
||||
}
|
||||
|
@ -47,6 +60,8 @@ func (s *PacketConnStub) SetWriteDeadline(t time.Time) error {
|
|||
|
||||
func CopyPacketConn(ctx context.Context, dest PacketConn, conn PacketConn) error {
|
||||
return task.Run(ctx, func() error {
|
||||
defer rw.CloseRead(conn)
|
||||
defer rw.CloseWrite(dest)
|
||||
_buffer := buf.StackNewMax()
|
||||
buffer := common.Dup(_buffer)
|
||||
data := buffer.Cut(buf.ReversedHeader, buf.ReversedHeader)
|
||||
|
@ -56,13 +71,15 @@ func CopyPacketConn(ctx context.Context, dest PacketConn, conn PacketConn) error
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
buffer.Truncate(data.Len())
|
||||
buffer.Resize(buf.ReversedHeader+data.Start(), data.Len())
|
||||
err = dest.WritePacket(buffer, destination)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}, func() error {
|
||||
defer rw.CloseRead(dest)
|
||||
defer rw.CloseWrite(conn)
|
||||
_buffer := buf.StackNewMax()
|
||||
buffer := common.Dup(_buffer)
|
||||
data := buffer.Cut(buf.ReversedHeader, buf.ReversedHeader)
|
||||
|
@ -72,7 +89,7 @@ func CopyPacketConn(ctx context.Context, dest PacketConn, conn PacketConn) error
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
buffer.Truncate(data.Len())
|
||||
buffer.Resize(buf.ReversedHeader+data.Start(), data.Len())
|
||||
err = conn.WritePacket(buffer, destination)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -81,44 +98,211 @@ func CopyPacketConn(ctx context.Context, dest PacketConn, conn PacketConn) error
|
|||
})
|
||||
}
|
||||
|
||||
func CopyPacketConn0(dest PacketConn, conn PacketConn, onAction func(destination *M.AddrPort, n int)) error {
|
||||
for {
|
||||
buffer := buf.New()
|
||||
destination, err := conn.ReadPacket(buffer)
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
return err
|
||||
func CopyNetPacketConn(ctx context.Context, dest net.PacketConn, conn PacketConn) error {
|
||||
return task.Run(ctx, func() error {
|
||||
defer rw.CloseRead(conn)
|
||||
defer rw.CloseWrite(dest)
|
||||
|
||||
_buffer := buf.StackNew()
|
||||
buffer := common.Dup(_buffer)
|
||||
for {
|
||||
buffer.FullReset()
|
||||
destination, err := conn.ReadPacket(buffer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = dest.WriteTo(buffer.Bytes(), destination.UDPAddr())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
size := buffer.Len()
|
||||
err = dest.WritePacket(buffer, destination)
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
return err
|
||||
}, func() error {
|
||||
defer rw.CloseRead(dest)
|
||||
defer rw.CloseWrite(conn)
|
||||
|
||||
_buffer := buf.StackNew()
|
||||
buffer := common.Dup(_buffer)
|
||||
data := buffer.Cut(buf.ReversedHeader, buf.ReversedHeader)
|
||||
for {
|
||||
data.FullReset()
|
||||
n, addr, err := dest.ReadFrom(data.FreeBytes())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
buffer.Resize(buf.ReversedHeader, n)
|
||||
err = conn.WritePacket(buffer, M.AddrPortFromNetAddr(addr))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if onAction != nil {
|
||||
onAction(destination, size)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
type associatePacketConn struct {
|
||||
net.PacketConn
|
||||
type AssociateConn struct {
|
||||
net.Conn
|
||||
conn net.Conn
|
||||
addr net.Addr
|
||||
dest *M.AddrPort
|
||||
}
|
||||
|
||||
func NewPacketConn(conn net.Conn, packetConn net.PacketConn) PacketConn {
|
||||
return &associatePacketConn{
|
||||
PacketConn: packetConn,
|
||||
conn: conn,
|
||||
func NewAssociateConn(conn net.Conn, packetConn net.Conn, destination *M.AddrPort) net.PacketConn {
|
||||
return &AssociateConn{
|
||||
Conn: packetConn,
|
||||
conn: conn,
|
||||
dest: destination,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *associatePacketConn) RemoteAddr() net.Addr {
|
||||
func (c *AssociateConn) RemoteAddr() net.Addr {
|
||||
return c.addr
|
||||
}
|
||||
|
||||
func (c *associatePacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
|
||||
func (c *AssociateConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
||||
n, err = c.Conn.Read(p)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
reader := buf.As(p[3:n])
|
||||
destination, err := AddressSerializer.ReadAddrPort(reader)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
addr = destination.UDPAddr()
|
||||
n = copy(p, reader.Bytes())
|
||||
return
|
||||
}
|
||||
|
||||
func (c *AssociateConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
||||
_buffer := buf.StackNew()
|
||||
buffer := common.Dup(_buffer)
|
||||
common.Must(buffer.WriteZeroN(3))
|
||||
err = AddressSerializer.WriteAddrPort(buffer, M.AddrPortFromNetAddr(addr))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, err = buffer.Write(p)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
_, err = c.Conn.Write(buffer.Bytes())
|
||||
return
|
||||
}
|
||||
|
||||
func (c *AssociateConn) Read(b []byte) (n int, err error) {
|
||||
n, _, err = c.ReadFrom(b)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *AssociateConn) Write(b []byte) (n int, err error) {
|
||||
_buffer := buf.StackNew()
|
||||
buffer := common.Dup(_buffer)
|
||||
common.Must(buffer.WriteZeroN(3))
|
||||
err = AddressSerializer.WriteAddrPort(buffer, c.dest)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, err = buffer.Write(b)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, err = c.Conn.Write(buffer.Bytes())
|
||||
return
|
||||
}
|
||||
|
||||
func (c *AssociateConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
|
||||
n, err := buffer.ReadFrom(c.conn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
buffer.Truncate(int(n))
|
||||
buffer.Advance(3)
|
||||
return AddressSerializer.ReadAddrPort(buffer)
|
||||
}
|
||||
|
||||
func (c *AssociateConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error {
|
||||
defer buffer.Release()
|
||||
header := buf.With(buffer.ExtendHeader(3 + AddressSerializer.AddrPortLen(destination)))
|
||||
common.Must(header.WriteZeroN(3))
|
||||
common.Must(AddressSerializer.WriteAddrPort(header, destination))
|
||||
return common.Error(c.Conn.Write(buffer.Bytes()))
|
||||
}
|
||||
|
||||
type AssociatePacketConn struct {
|
||||
net.PacketConn
|
||||
conn net.Conn
|
||||
addr net.Addr
|
||||
dest *M.AddrPort
|
||||
}
|
||||
|
||||
func NewAssociatePacketConn(conn net.Conn, packetConn net.PacketConn, destination *M.AddrPort) *AssociatePacketConn {
|
||||
return &AssociatePacketConn{
|
||||
PacketConn: packetConn,
|
||||
conn: conn,
|
||||
dest: destination,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *AssociatePacketConn) RemoteAddr() net.Addr {
|
||||
return c.addr
|
||||
}
|
||||
|
||||
func (c *AssociatePacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
||||
n, addr, err = c.PacketConn.ReadFrom(p)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
reader := buf.As(p[3:n])
|
||||
destination, err := AddressSerializer.ReadAddrPort(reader)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
addr = destination.UDPAddr()
|
||||
n = copy(p, reader.Bytes())
|
||||
return
|
||||
}
|
||||
|
||||
func (c *AssociatePacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
||||
_buffer := buf.StackNew()
|
||||
buffer := common.Dup(_buffer)
|
||||
common.Must(buffer.WriteZeroN(3))
|
||||
|
||||
err = AddressSerializer.WriteAddrPort(buffer, M.AddrPortFromNetAddr(addr))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, err = buffer.Write(p)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, err = c.PacketConn.WriteTo(buffer.Bytes(), c.addr)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *AssociatePacketConn) Read(b []byte) (n int, err error) {
|
||||
n, _, err = c.ReadFrom(b)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *AssociatePacketConn) Write(b []byte) (n int, err error) {
|
||||
_buffer := buf.StackNew()
|
||||
buffer := common.Dup(_buffer)
|
||||
common.Must(buffer.WriteZeroN(3))
|
||||
|
||||
err = AddressSerializer.WriteAddrPort(buffer, c.dest)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, err = buffer.Write(b)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, err = c.PacketConn.WriteTo(buffer.Bytes(), c.addr)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *AssociatePacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
|
||||
n, addr, err := c.PacketConn.ReadFrom(buffer.FreeBytes())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -126,15 +310,14 @@ func (c *associatePacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error
|
|||
c.addr = addr
|
||||
buffer.Truncate(n)
|
||||
buffer.Advance(3)
|
||||
return AddressSerializer.ReadAddrPort(buffer)
|
||||
dest, err := AddressSerializer.ReadAddrPort(buffer)
|
||||
return dest, err
|
||||
}
|
||||
|
||||
func (c *associatePacketConn) WritePacket(buffer *buf.Buffer, addrPort *M.AddrPort) error {
|
||||
func (c *AssociatePacketConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error {
|
||||
defer buffer.Release()
|
||||
_header := buf.StackNew()
|
||||
header := common.Dup(_header)
|
||||
header := buf.With(buffer.ExtendHeader(3 + AddressSerializer.AddrPortLen(destination)))
|
||||
common.Must(header.WriteZeroN(3))
|
||||
common.Must(AddressSerializer.WriteAddrPort(header, addrPort))
|
||||
buffer = buffer.WriteBufferAtFirst(header)
|
||||
common.Must(AddressSerializer.WriteAddrPort(header, destination))
|
||||
return common.Error(c.PacketConn.WriteTo(buffer.Bytes(), c.addr))
|
||||
}
|
||||
|
|
|
@ -36,7 +36,7 @@ func NewListener(bind netip.AddrPort, authenticator auth.Authenticator, handler
|
|||
}
|
||||
|
||||
func (l *Listener) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
|
||||
return HandleConnection(ctx, conn, l.authenticator, l.bindAddr, l.handler, metadata)
|
||||
return HandleConnection(ctx, conn, l.authenticator, M.AddrPortFromNetAddr(conn.LocalAddr()).Addr.Addr(), l.handler, metadata)
|
||||
}
|
||||
|
||||
func (l *Listener) Start() error {
|
||||
|
@ -131,9 +131,10 @@ func HandleConnection(ctx context.Context, conn net.Conn, authenticator auth.Aut
|
|||
if err != nil {
|
||||
return E.Cause(err, "write socks response")
|
||||
}
|
||||
metadata.Protocol = "socks"
|
||||
metadata.Destination = request.Destination
|
||||
go func() {
|
||||
err := handler.NewPacketConnection(NewPacketConn(conn, udpConn), metadata)
|
||||
err := handler.NewPacketConnection(NewAssociatePacketConn(conn, udpConn, request.Destination), metadata)
|
||||
if err != nil {
|
||||
handler.HandleError(err)
|
||||
}
|
||||
|
|
|
@ -32,7 +32,7 @@ type Listener struct {
|
|||
bindAddr netip.Addr
|
||||
handler Handler
|
||||
authenticator auth.Authenticator
|
||||
udpNat *udpnat.Server
|
||||
udpNat *udpnat.Service[string]
|
||||
}
|
||||
|
||||
func NewListener(bind netip.AddrPort, authenticator auth.Authenticator, transproxy redir.TransproxyMode, handler Handler) *Listener {
|
||||
|
@ -45,7 +45,7 @@ func NewListener(bind netip.AddrPort, authenticator auth.Authenticator, transpro
|
|||
listener.TCPListener = tcp.NewTCPListener(bind, listener, tcp.WithTransproxyMode(transproxy))
|
||||
if transproxy == redir.ModeTProxy {
|
||||
listener.UDPListener = udp.NewUDPListener(bind, listener, udp.WithTransproxyMode(transproxy))
|
||||
listener.udpNat = udpnat.NewServer(handler)
|
||||
listener.udpNat = udpnat.New[string](handler)
|
||||
}
|
||||
return listener
|
||||
}
|
||||
|
@ -63,7 +63,7 @@ func (l *Listener) NewConnection(ctx context.Context, conn net.Conn, metadata M.
|
|||
case socks.Version4:
|
||||
return E.New("socks4 request dropped (TODO)")
|
||||
case socks.Version5:
|
||||
return socks.HandleConnection(ctx, bufConn, l.authenticator, l.bindAddr, l.handler, metadata)
|
||||
return socks.HandleConnection(ctx, bufConn, l.authenticator, M.AddrPortFromNetAddr(conn.LocalAddr()).Addr.Addr(), l.handler, metadata)
|
||||
}
|
||||
|
||||
request, err := http.ReadRequest(bufConn.Reader())
|
||||
|
@ -96,8 +96,23 @@ func (l *Listener) NewConnection(ctx context.Context, conn net.Conn, metadata M.
|
|||
return http.HandleRequest(ctx, request, bufConn, l.authenticator, l.handler, metadata)
|
||||
}
|
||||
|
||||
func (l *Listener) NewPacket(packet *buf.Buffer, metadata M.Metadata) error {
|
||||
return l.udpNat.HandleUDP(packet, metadata)
|
||||
func (l *Listener) NewPacket(conn socks.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error {
|
||||
return l.udpNat.NewPacket(metadata.Source.String(), func() socks.PacketWriter {
|
||||
return &tproxyPacketWriter{metadata.Source.UDPAddr()}
|
||||
}, buffer, metadata)
|
||||
}
|
||||
|
||||
type tproxyPacketWriter struct {
|
||||
source *net.UDPAddr
|
||||
}
|
||||
|
||||
func (w *tproxyPacketWriter) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error {
|
||||
udpConn, err := redir.DialUDP("udp", destination.UDPAddr(), w.source)
|
||||
if err != nil {
|
||||
return E.Cause(err, "tproxy udp write back")
|
||||
}
|
||||
defer udpConn.Close()
|
||||
return common.Error(udpConn.Write(buffer.Bytes()))
|
||||
}
|
||||
|
||||
func (l *Listener) HandleError(err error) {
|
||||
|
|
|
@ -9,10 +9,11 @@ import (
|
|||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
"github.com/sagernet/sing/common/redir"
|
||||
"github.com/sagernet/sing/protocol/socks"
|
||||
)
|
||||
|
||||
type Handler interface {
|
||||
M.UDPHandler
|
||||
socks.UDPHandler
|
||||
E.Handler
|
||||
}
|
||||
|
||||
|
@ -24,6 +25,19 @@ type Listener struct {
|
|||
tproxy bool
|
||||
}
|
||||
|
||||
func (l *Listener) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
|
||||
n, addr, err := l.ReadFromUDP(buffer.FreeBytes())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
buffer.Truncate(n)
|
||||
return M.AddrPortFromNetAddr(addr), nil
|
||||
}
|
||||
|
||||
func (l *Listener) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error {
|
||||
return common.Error(l.UDPConn.WriteTo(buffer.Bytes(), destination.UDPAddr()))
|
||||
}
|
||||
|
||||
func NewUDPListener(listen netip.AddrPort, handler Handler, options ...Option) *Listener {
|
||||
listener := &Listener{
|
||||
handler: handler,
|
||||
|
@ -69,32 +83,31 @@ func (l *Listener) Close() error {
|
|||
}
|
||||
|
||||
func (l *Listener) loop() {
|
||||
_buffer := buf.StackNewMax()
|
||||
buffer := common.Dup(_buffer)
|
||||
data := buffer.Cut(buf.ReversedHeader, buf.ReversedHeader).Slice()
|
||||
if !l.tproxy {
|
||||
for {
|
||||
buffer := buf.New()
|
||||
n, addr, err := l.ReadFromUDP(buffer.Extend(buf.UDPBufferSize))
|
||||
n, addr, err := l.ReadFromUDP(data)
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
l.handler.HandleError(err)
|
||||
return
|
||||
}
|
||||
buffer.Truncate(n)
|
||||
err = l.handler.NewPacket(buffer, M.Metadata{
|
||||
buffer.Resize(buf.ReversedHeader, n)
|
||||
err = l.handler.NewPacket(l, buffer, M.Metadata{
|
||||
Protocol: "udp",
|
||||
Source: M.AddrPortFromNetAddr(addr),
|
||||
})
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
l.handler.HandleError(err)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
oob := make([]byte, 1024)
|
||||
_oob := make([]byte, 1024)
|
||||
oob := common.Dup(_oob)
|
||||
for {
|
||||
buffer := buf.New()
|
||||
n, oobN, _, addr, err := l.ReadMsgUDPAddrPort(buffer.FreeBytes(), oob)
|
||||
n, oobN, _, addr, err := l.ReadMsgUDPAddrPort(data, oob)
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
l.handler.HandleError(err)
|
||||
return
|
||||
}
|
||||
|
@ -103,14 +116,13 @@ func (l *Listener) loop() {
|
|||
l.handler.HandleError(E.Cause(err, "get original destination"))
|
||||
return
|
||||
}
|
||||
buffer.Truncate(n)
|
||||
err = l.handler.NewPacket(buffer, M.Metadata{
|
||||
buffer.Resize(buf.ReversedHeader, n)
|
||||
err = l.handler.NewPacket(l, buffer, M.Metadata{
|
||||
Protocol: "tproxy",
|
||||
Source: M.AddrPortFromAddrPort(addr),
|
||||
Destination: destination,
|
||||
})
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
l.handler.HandleError(err)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue