diff --git a/common/buf/buffer.go b/common/buf/buffer.go index 084afef..c4a1961 100644 --- a/common/buf/buffer.go +++ b/common/buf/buffer.go @@ -4,9 +4,9 @@ import ( "crypto/rand" "fmt" "io" - "sing/common/list" "sing/common" + "sing/common/list" ) type Buffer struct { @@ -25,6 +25,13 @@ func New() *Buffer { } } +func FullNew() *Buffer { + return &Buffer{ + data: GetBytes(), + managed: true, + } +} + func StackNew() Buffer { return Buffer{ data: GetBytes(), @@ -226,6 +233,11 @@ func (b *Buffer) Reset() { b.end = ReversedHeader } +func (b *Buffer) FullReset() { + b.start = 0 + b.end = 0 +} + func (b *Buffer) Release() { if b == nil || b.data == nil || !b.managed { return diff --git a/common/buf/pool.go b/common/buf/pool.go index 913bda5..27bdfda 100644 --- a/common/buf/pool.go +++ b/common/buf/pool.go @@ -5,6 +5,7 @@ import "sync" const ( ReversedHeader = 1024 BufferSize = 20 * 1024 + UDPBufferSize = 16 * 1024 ) var pool = sync.Pool{ diff --git a/common/conn.go b/common/conn.go index 33466d3..9152786 100644 --- a/common/conn.go +++ b/common/conn.go @@ -6,15 +6,13 @@ import ( "time" ) -type ReadOnlyException struct { -} +type ReadOnlyException struct{} func (e *ReadOnlyException) Error() string { return "read only connection" } -type WriteOnlyException struct { -} +type WriteOnlyException struct{} func (e *WriteOnlyException) Error() string { return "write only connection" diff --git a/common/const.go b/common/const.go index 3dae261..7f0d777 100644 --- a/common/const.go +++ b/common/const.go @@ -2,8 +2,7 @@ package common const EmptyString = "" -type DummyAddr struct { -} +type DummyAddr struct{} func (d *DummyAddr) Network() string { return "dummy" diff --git a/common/genericsync/map.go b/common/gsync/map.go similarity index 99% rename from common/genericsync/map.go rename to common/gsync/map.go index 281d723..61af105 100644 --- a/common/genericsync/map.go +++ b/common/gsync/map.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package genericsync +package gsync import ( "sync" diff --git a/common/rw/copy.go b/common/rw/copy.go new file mode 100644 index 0000000..74c9e2f --- /dev/null +++ b/common/rw/copy.go @@ -0,0 +1,53 @@ +package rw + +import ( + "context" + "io" + "net" + + "sing/common" + "sing/common/buf" + "sing/common/task" +) + +func CopyConn(ctx context.Context, conn net.Conn, outConn net.Conn) error { + return task.Run(ctx, func() error { + return common.Error(io.Copy(conn, outConn)) + }, func() error { + return common.Error(io.Copy(outConn, conn)) + }) +} + +func CopyPacketConn(ctx context.Context, conn net.PacketConn, outPacketConn net.PacketConn) error { + return task.Run(ctx, func() error { + buffer := buf.FullNew() + defer buffer.Release() + for { + n, addr, err := conn.ReadFrom(buffer.FreeBytes()) + if err != nil { + return err + } + buffer.Truncate(n) + _, err = outPacketConn.WriteTo(buffer.Bytes(), addr) + if err != nil { + return err + } + buffer.FullReset() + } + }, func() error { + buffer := buf.FullNew() + defer buffer.Release() + for { + n, addr, err := outPacketConn.ReadFrom(buffer.FreeBytes()) + if err != nil { + return err + } + buffer.Truncate(n) + _, err = conn.WriteTo(buffer.Bytes(), addr) + if err != nil { + return err + } + buffer.FullReset() + } + }) +} diff --git a/common/rw/output.go b/common/rw/output.go index ebf99b5..2841bdb 100644 --- a/common/rw/output.go +++ b/common/rw/output.go @@ -2,6 +2,7 @@ package rw import ( "io" + "sing/common" "sing/common/buf" "sing/common/list" diff --git a/common/session/context.go b/common/session/context.go new file mode 100644 index 0000000..961a100 --- /dev/null +++ b/common/session/context.go @@ -0,0 +1,66 @@ +package session + +import ( + "net" + "strconv" + + "sing/common/buf" + "sing/common/socksaddr" +) + +type Network int + +const ( + NetworkTCP Network = iota + NetworkUDP +) + +type InstanceContext struct{} + +type Context struct { + InstanceContext + Network Network + Source socksaddr.Addr + Destination socksaddr.Addr + SourcePort uint16 + DestinationPort uint16 +} + +func (c Context) DestinationNetAddr() string { + return net.JoinHostPort(c.Destination.String(), strconv.Itoa(int(c.DestinationPort))) +} + +func AddressFromNetAddr(netAddr net.Addr) (addr socksaddr.Addr, port uint16) { + var ip net.IP + switch addr := netAddr.(type) { + case *net.TCPAddr: + ip = addr.IP + port = uint16(addr.Port) + case *net.UDPAddr: + ip = addr.IP + port = uint16(addr.Port) + } + return socksaddr.AddrFromIP(ip), port +} + +type Conn struct { + Conn net.Conn + Context *Context +} + +type PacketConn struct { + Conn net.PacketConn + Context *Context +} + +type Packet struct { + Context *Context + Data *buf.Buffer + WriteBack func(buffer *buf.Buffer, addr *net.UDPAddr) error + Release func() +} + +type Handler interface { + HandleConnection(conn *Conn) + HandlePacket(packet *Packet) +} diff --git a/common/session/pool.go b/common/session/pool.go new file mode 100644 index 0000000..211b5d5 --- /dev/null +++ b/common/session/pool.go @@ -0,0 +1,63 @@ +package session + +import ( + "container/list" + "sync" + + "sing/common" +) + +var ( + connectionPool list.List + connectionPoolEnabled bool + connectionAccess sync.Mutex +) + +func EnableConnectionPool() { + connectionPoolEnabled = true +} + +func DisableConnectionPool() { + connectionAccess.Lock() + defer connectionAccess.Unlock() + connectionPoolEnabled = false + clearConnections() +} + +func AddConnection(connection any) any { + if !connectionPoolEnabled { + return connection + } + connectionAccess.Lock() + defer connectionAccess.Unlock() + return connectionPool.PushBack(connection) +} + +func RemoveConnection(anyElement any) { + element, ok := anyElement.(*list.Element) + if !ok { + common.Close(anyElement) + return + } + if element.Value == nil { + return + } + common.Close(element.Value) + element.Value = nil + connectionAccess.Lock() + defer connectionAccess.Unlock() + connectionPool.Remove(element) +} + +func ResetConnections() { + connectionAccess.Lock() + defer connectionAccess.Unlock() + clearConnections() +} + +func clearConnections() { + for element := connectionPool.Front(); element != nil; element = element.Next() { + common.Close(element) + } + connectionPool.Init() +} diff --git a/common/socksaddr/addr.go b/common/socksaddr/addr.go index 0b4398c..a38879a 100644 --- a/common/socksaddr/addr.go +++ b/common/socksaddr/addr.go @@ -10,6 +10,7 @@ type Addr interface { Family() Family Addr() netip.Addr Fqdn() string + String() string } func AddrFromIP(ip net.IP) Addr { @@ -21,6 +22,19 @@ func AddrFromIP(ip net.IP) Addr { } } +func AddressFromNetAddr(netAddr net.Addr) (addr Addr, port uint16) { + var ip net.IP + switch addr := netAddr.(type) { + case *net.TCPAddr: + ip = addr.IP + port = uint16(addr.Port) + case *net.UDPAddr: + ip = addr.IP + port = uint16(addr.Port) + } + return AddrFromIP(ip), port +} + func AddrFromFqdn(fqdn string) Addr { return AddrFqdn(fqdn) } @@ -39,6 +53,10 @@ func (a Addr4) Fqdn() string { return "" } +func (a Addr4) String() string { + return net.IP(a[:]).String() +} + type Addr16 [16]byte func (a Addr16) Family() Family { @@ -53,6 +71,10 @@ func (a Addr16) Fqdn() string { return "" } +func (a Addr16) String() string { + return net.IP(a[:]).String() +} + type AddrFqdn string func (f AddrFqdn) Family() Family { @@ -66,3 +88,7 @@ func (f AddrFqdn) Addr() netip.Addr { func (f AddrFqdn) Fqdn() string { return string(f) } + +func (f AddrFqdn) String() string { + return string(f) +} diff --git a/common/socksaddr/serializer.go b/common/socksaddr/serializer.go index 32c2c35..48061d6 100644 --- a/common/socksaddr/serializer.go +++ b/common/socksaddr/serializer.go @@ -3,6 +3,7 @@ package socksaddr import ( "encoding/binary" "io" + "sing/common" "sing/common/exceptions" "sing/common/rw" diff --git a/common/task/task.go b/common/task/task.go new file mode 100644 index 0000000..96fe12f --- /dev/null +++ b/common/task/task.go @@ -0,0 +1,36 @@ +package task + +import ( + "context" + "sync" + + "sing/common" +) + +func Run(ctx context.Context, tasks ...func() error) error { + ctx, cancel := context.WithCancel(ctx) + wg := new(sync.WaitGroup) + wg.Add(len(tasks)) + var retErr error + for _, task := range tasks { + task := task + go func() { + if err := task(); err != nil { + if !common.Done(ctx) { + retErr = err + } + cancel() + } + wg.Done() + }() + } + go func() { + wg.Wait() + cancel() + }() + <-ctx.Done() + if retErr != nil { + return retErr + } + return ctx.Err() +} diff --git a/conf/config.go b/conf/config.go new file mode 100644 index 0000000..1100af7 --- /dev/null +++ b/conf/config.go @@ -0,0 +1,54 @@ +package conf + +import ( + "encoding/json" + + "sing/common/exceptions" + "sing/core" + "sing/transport" + "sing/transport/block" + "sing/transport/socks" + "sing/transport/system" +) + +type Config struct { + Inbounds []*InboundConfig `json:"inbounds,omitempty"` + Outbounds []*OutboundConfig `json:"outbounds,omitempty"` +} + +type InboundConfig struct { + Type string `json:"type"` + Tag string `json:"tag,omitempty"` + Settings json.RawMessage `json:"settings,omitempty"` +} + +func (c InboundConfig) Build(instance core.Instance) (transport.Inbound, error) { + switch c.Type { + case "socks": + config := new(socks.InboundConfig) + err := json.Unmarshal(c.Settings, config) + if err != nil { + return nil, err + } + return socks.NewListener(instance, config) + } + return nil, exceptions.New("unknown inbound type ", c.Type) +} + +type OutboundConfig struct { + Type string `json:"type"` + Settings json.RawMessage `json:"settings,omitempty"` +} + +func (c OutboundConfig) Build(instance core.Instance) (transport.Outbound, error) { + var outbound transport.Outbound + switch c.Type { + case "system": + outbound = new(system.Outbound) + case "block": + outbound = new(block.Outbound) + default: + return nil, exceptions.New("unknown outbound type: ", c.Type) + } + return outbound, nil +} diff --git a/core.go b/core.go new file mode 100644 index 0000000..f4c86b7 --- /dev/null +++ b/core.go @@ -0,0 +1,94 @@ +package sing + +import ( + "context" + "sing/common/session" + "sync" + + "sing/common/gsync" + "sing/common/list" + "sing/core" + "sing/transport" +) + +var _ core.Instance = (*Instance)(nil) + +type Instance struct { + access sync.Mutex + ctx context.Context + cancel context.CancelFunc + inbounds list.List[*transport.InboundContext] + inboundByName gsync.Map[string, *transport.InboundContext] + outbounds list.List[*transport.OutboundContext] + outboundByName gsync.Map[string, *transport.OutboundContext] + defaultOutbound *transport.OutboundContext +} + +func (i *Instance) AddInbound(inbound transport.Inbound, tag string) { + i.access.Lock() + defer i.access.Unlock() + + ic := new(transport.InboundContext) + ic.Context = i.ctx + ic.Tag = tag + ic.Inbound = inbound + + i.inbounds.InsertAfter(ic) + i.inboundByName.Store(tag, ic) +} + +func (i *Instance) Inbounds() *list.List[*transport.InboundContext] { + i.inboundByName.Range(func(tag string, inbound *transport.InboundContext) bool { + return true + }) + return &i.inbounds +} + +func (i *Instance) Inbound(tag string) *transport.InboundContext { + inbound, _ := i.inboundByName.Load(tag) + return inbound +} + +func (i *Instance) Outbounds() *list.List[*transport.OutboundContext] { + return &i.outbounds +} + +func (i *Instance) DefaultOutbound() *transport.OutboundContext { + i.access.Lock() + defer i.access.Unlock() + return i.defaultOutbound +} + +func (i *Instance) Outbound(tag string) *transport.OutboundContext { + outbound, _ := i.outboundByName.Load(tag) + return outbound +} + +func (i *Instance) HandleConnection(conn *session.Conn) { + i.defaultOutbound.Outbound.NewConnection(i.ctx, conn) +} + +func (i *Instance) HandlePacket(packet *session.Packet) { +} + +type InstanceContext interface { + context.Context + Instance() *Instance + Load(key string) (any, bool) + Store(key string, value any) +} + +type instanceContext struct { + context.Context + instance Instance + values gsync.Map[any, string] +} + +func (i *instanceContext) Load(key string) (any, bool) { + return i.values.Load(key) +} + +func (i *instanceContext) Store(key string, value any) { + // TODO implement me + panic("implement me") +} diff --git a/core/core.go b/core/core.go new file mode 100644 index 0000000..7629578 --- /dev/null +++ b/core/core.go @@ -0,0 +1,12 @@ +package core + +import ( + "sing/common/session" + "sing/transport" +) + +type Instance interface { + session.Handler + transport.InboundManager + transport.OutboundManager +} diff --git a/example/shadowboom/main.go b/example/shadowboom/main.go index 32714b6..23ac3e3 100644 --- a/example/shadowboom/main.go +++ b/example/shadowboom/main.go @@ -8,14 +8,14 @@ import ( "log" "net" "os" + + cObfs "github.com/Dreamacro/clash/transport/ssr/obfs" + cProtocol "github.com/Dreamacro/clash/transport/ssr/protocol" "sing/common" "sing/common/buf" "sing/common/socksaddr" "sing/protocol/shadowsocks" _ "sing/protocol/shadowsocks/shadowstream" - - cObfs "github.com/Dreamacro/clash/transport/ssr/obfs" - cProtocol "github.com/Dreamacro/clash/transport/ssr/protocol" ) var ( @@ -67,9 +67,9 @@ func main() { key := shadowsocks.Key([]byte(password), cipher.KeySize()) - if _, isAEAD := cipher.(*shadowsocks.AEADCipher); isAEAD { + /*if _, isAEAD := cipher.(*shadowsocks.AEADCipher); isAEAD { log.Fatal("not a stream cipher: ", method) - } + }*/ ipAddr, err := net.ResolveIPAddr("ip", address) if err != nil { diff --git a/format.go b/format.go new file mode 100644 index 0000000..bc48c25 --- /dev/null +++ b/format.go @@ -0,0 +1,6 @@ +package sing + +//go:generate go install -v mvdan.cc/gofumpt@latest +//go:generate go install -v github.com/daixiang0/gci@latest +//go:generate gofumpt -l -w . +//go:generate gci -w . diff --git a/go.mod b/go.mod index 7eeeb3b..99066ff 100644 --- a/go.mod +++ b/go.mod @@ -9,7 +9,7 @@ require ( github.com/dgryski/go-rc2 v0.0.0-20150621095337-8a9021637152 github.com/geeksbaek/seed v0.0.0-20180909040025-2a7f5fb92e22 github.com/kierdavis/cfb8 v0.0.0-20180105024805-3a17c36ee2f8 - golang.org/x/crypto v0.0.0-20220126234351-aa10faf2a1f8 + golang.org/x/crypto v0.0.0-20220214200702-86341886e292 ) // for testing and example only @@ -36,7 +36,7 @@ require ( github.com/sirupsen/logrus v1.8.1 // indirect github.com/v2fly/ss-bloomring v0.0.0-20210312155135-28617310f63e // indirect golang.org/x/mod v0.6.0-dev.0.20211013180041-c96bc1413d57 // indirect - golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd // indirect - golang.org/x/sys v0.0.0-20220114195835-da31bd327af9 // indirect + golang.org/x/net v0.0.0-20220225172249-27dd8689420f // indirect + golang.org/x/sys v0.0.0-20220227234510-4e6760a101f9 // indirect google.golang.org/protobuf v1.27.1 // indirect ) diff --git a/go.sum b/go.sum index 9210cdb..07946a7 100644 --- a/go.sum +++ b/go.sum @@ -76,16 +76,22 @@ go4.org/unsafe/assume-no-moving-gc v0.0.0-20211027215541-db492cf91b37 h1:Tx9kY6y golang.org/x/crypto v0.0.0-20210317152858-513c2a44f670/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.0.0-20220126234351-aa10faf2a1f8 h1:kACShD3qhmr/3rLmg1yXyt+N4HcwutKyPRB93s54TIU= golang.org/x/crypto v0.0.0-20220126234351-aa10faf2a1f8/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= +golang.org/x/crypto v0.0.0-20220214200702-86341886e292 h1:f+lwQ+GtmgoY+A2YaQxlSOnDjXcQ7ZRLWOHbC6HtRqE= +golang.org/x/crypto v0.0.0-20220214200702-86341886e292/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/mod v0.6.0-dev.0.20211013180041-c96bc1413d57 h1:LQmS1nU0twXLA96Kt7U9qtHJEbBk3z6Q0V4UXjZkpr4= golang.org/x/mod v0.6.0-dev.0.20211013180041-c96bc1413d57/go.mod h1:3p9vT2HGsQu2K1YbXdKPJLVgG5VJdoTa1poYQBtP1AY= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd h1:O7DYs+zxREGLKzKoMQrtrEacpb0ZVXA5rIwylE2Xchk= golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= +golang.org/x/net v0.0.0-20220225172249-27dd8689420f h1:oA4XRj0qtSt8Yo1Zms0CUlsT3KG69V2UGQWPBxujDmc= +golang.org/x/net v0.0.0-20220225172249-27dd8689420f/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c h1:5KslGYwFpkhGh+Q16bwMP3cOontH8FOep7tGV86Y7SQ= golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20220114195835-da31bd327af9 h1:XfKQ4OlFl8okEOr5UvAqFRVj8pY/4yfcXrddB8qAbU0= golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220227234510-4e6760a101f9 h1:nhht2DYV/Sn3qOayu8lM+cU1ii9sTLUeBQwQQfUHtrs= +golang.org/x/sys v0.0.0-20220227234510-4e6760a101f9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= diff --git a/protocol/shadowsocks/cipher.go b/protocol/shadowsocks/cipher.go index be94813..5f5c382 100644 --- a/protocol/shadowsocks/cipher.go +++ b/protocol/shadowsocks/cipher.go @@ -2,6 +2,7 @@ package shadowsocks import ( "io" + "sing/common/buf" "sing/common/exceptions" "sing/common/list" @@ -18,8 +19,10 @@ type Cipher interface { type CipherCreator func() Cipher -var cipherList *list.List[string] -var cipherMap map[string]CipherCreator +var ( + cipherList *list.List[string] + cipherMap map[string]CipherCreator +) func init() { cipherList = new(list.List[string]) diff --git a/protocol/shadowsocks/cipher_test.go b/protocol/shadowsocks/cipher_test.go index 55e5cff..dff7bc3 100644 --- a/protocol/shadowsocks/cipher_test.go +++ b/protocol/shadowsocks/cipher_test.go @@ -5,7 +5,6 @@ import ( "bytes" "io" "net" - "sing/common/rw" "strings" "sync" "testing" @@ -17,6 +16,7 @@ import ( "sing/common" "sing/common/buf" "sing/common/crypto" + "sing/common/rw" "sing/common/socksaddr" "sing/protocol/shadowsocks" _ "sing/protocol/shadowsocks/shadowstream" @@ -195,7 +195,6 @@ func benchmarkShadowsocksCipher(b *testing.B, method string, data int) { } else { writer.Write(buffer.Bytes()) } - } func testShadowsocksClientTCPWithCipher(t *testing.T, cipherType vs.CipherType, cipherName string) { diff --git a/protocol/socks/protocol.go b/protocol/socks/protocol.go index 2d6e33a..71e4184 100644 --- a/protocol/socks/protocol.go +++ b/protocol/socks/protocol.go @@ -5,6 +5,7 @@ import ( "io" "sing/common" + "sing/common/buf" "sing/common/exceptions" "sing/common/rw" "sing/common/socksaddr" @@ -316,28 +317,12 @@ type AssociatePacket struct { Data []byte } -func WriteAssociatePacket(writer io.Writer, packet *AssociatePacket) error { - err := rw.WriteZeroN(writer, 2) - if err != nil { - return err - } - err = rw.WriteByte(writer, packet.Fragment) - if err != nil { - return err - } - err = AddressSerializer.WriteAddressAndPort(writer, packet.Addr, packet.Port) - if err != nil { - return err - } - return rw.WriteBytes(writer, packet.Data) -} - -func DecodeAssociatePacket(buffer []byte) (*AssociatePacket, error) { - if len(buffer) < 5 { +func DecodeAssociatePacket(buffer *buf.Buffer) (*AssociatePacket, error) { + if buffer.Len() < 5 { return nil, exceptions.New("insufficient length") } - fragment := buffer[2] - reader := bytes.NewReader(buffer) + fragment := buffer.Byte(2) + reader := bytes.NewReader(buffer.Bytes()) err := common.Error(reader.Seek(3, io.SeekStart)) if err != nil { return nil, err @@ -346,39 +331,29 @@ func DecodeAssociatePacket(buffer []byte) (*AssociatePacket, error) { if err != nil { return nil, err } - index := len(buffer) - reader.Len() - data := buffer[index:] + buffer.Advance(reader.Len()) packet := &AssociatePacket{ Fragment: fragment, Addr: addr, Port: port, - Data: data, + Data: buffer.Bytes(), } return packet, nil } -func ReadAssociatePacket(reader io.Reader) (*AssociatePacket, error) { - err := rw.SkipN(reader, 2) +func EncodeAssociatePacket(packet *AssociatePacket, buffer *buf.Buffer) error { + err := rw.WriteZeroN(buffer, 2) if err != nil { - return nil, err + return err } - fragment, err := rw.ReadByte(reader) + err = rw.WriteByte(buffer, packet.Fragment) if err != nil { - return nil, err + return err } - addr, port, err := AddressSerializer.ReadAddressAndPort(reader) + err = AddressSerializer.WriteAddressAndPort(buffer, packet.Addr, packet.Port) if err != nil { - return nil, err + return err } - data, err := io.ReadAll(reader) - if err != nil { - return nil, err - } - packet := &AssociatePacket{ - Fragment: fragment, - Addr: addr, - Port: port, - Data: data, - } - return packet, nil + _, err = buffer.Write(packet.Data) + return err } diff --git a/transport/block/outbound.go b/transport/block/outbound.go new file mode 100644 index 0000000..5488df4 --- /dev/null +++ b/transport/block/outbound.go @@ -0,0 +1,30 @@ +package block + +import ( + "context" + + "sing/common/session" + "sing/transport" +) + +var _ transport.Outbound = (*Outbound)(nil) + +type Outbound struct { +} + +func (h *Outbound) Init(*transport.OutboundContext) { +} + +func (h *Outbound) Close() error { + return nil +} + +func (o *Outbound) NewConnection(ctx context.Context, conn *session.Conn) error { + conn.Conn.Close() + return nil +} + +func (o *Outbound) NewPacketConnection(ctx context.Context, packetConn *session.PacketConn) error { + packetConn.Conn.Close() + return nil +} diff --git a/transport/inbound.go b/transport/inbound.go new file mode 100644 index 0000000..be2a934 --- /dev/null +++ b/transport/inbound.go @@ -0,0 +1,25 @@ +package transport + +import ( + "context" + + "sing/common/list" +) + +type Inbound interface { + Init(ctx *InboundContext) + Start() error + Close() error +} + +type InboundContext struct { + Context context.Context + Tag string + Inbound Inbound +} + +type InboundManager interface { + AddInbound(inbound Inbound, tag string) + Inbounds() *list.List[*InboundContext] + Inbound(tag string) *InboundContext +} diff --git a/transport/outbound.go b/transport/outbound.go new file mode 100644 index 0000000..b41d316 --- /dev/null +++ b/transport/outbound.go @@ -0,0 +1,28 @@ +package transport + +import ( + "context" + + "sing/common/list" + "sing/common/session" +) + +type Outbound interface { + Init(ctx *OutboundContext) + Close() error + NewConnection(ctx context.Context, conn *session.Conn) error + NewPacketConnection(ctx context.Context, packetConn *session.PacketConn) error +} + +type OutboundContext struct { + Context context.Context + Tag string + Outbound Outbound +} + +type OutboundManager interface { + AddOutbound(outbound Outbound, tag string) + Outbounds() *list.List[*OutboundContext] + Outbound(tag string) *OutboundContext + DefaultOutbound() *OutboundContext +} diff --git a/transport/socks/inbound.go b/transport/socks/inbound.go new file mode 100644 index 0000000..7541904 --- /dev/null +++ b/transport/socks/inbound.go @@ -0,0 +1,187 @@ +package socks + +import ( + "bytes" + "io" + "net" + + "github.com/sirupsen/logrus" + "net/netip" + "sing/common" + "sing/common/buf" + "sing/common/exceptions" + "sing/common/session" + "sing/common/socksaddr" + "sing/protocol/socks" + "sing/transport" + "sing/transport/system" +) + +var _ transport.Inbound = (*Inbound)(nil) + +type Inbound struct { + lAddr netip.AddrPort + username, password string + tcpListener *system.TCPListener + udpListener *system.UDPListener + handler session.Handler +} + +func (h *Inbound) Init(ctx *transport.InboundContext) { +} + +type InboundConfig struct { + Listen string `json:"listen"` + Port uint16 `json:"port"` + Username string `json:"username,omitempty"` + Password string `json:"password,omitempty"` +} + +func NewListener(handler session.Handler, config *InboundConfig) (*Inbound, error) { + addr, err := netip.ParseAddr(config.Listen) + if err != nil { + return nil, exceptions.Cause(err, "invalid listen address: ", config.Listen) + } + lAddr := netip.AddrPortFrom(addr, config.Port) + inbound := new(Inbound) + inbound.username, inbound.password = config.Username, config.Password + inbound.handler = handler + inbound.tcpListener = system.NewTCPListener(lAddr, inbound) + inbound.udpListener = system.NewUDPListener(lAddr, inbound) + return inbound, nil +} + +func (h *Inbound) Start() error { + err := h.tcpListener.Start() + if err != nil { + return err + } + return h.udpListener.Start() +} + +func (h *Inbound) HandleTCP(conn net.Conn) error { + authRequest, err := socks.ReadAuthRequest(conn) + if err != nil { + return exceptions.Cause(err, "read socks auth request") + } + if h.username != "" { + if bytes.IndexByte(authRequest.Methods, socks.AuthTypeNotRequired) > 0 { + err = socks.WriteAuthResponse(conn, &socks.AuthResponse{ + Version: authRequest.Version, + Method: socks.AuthTypeNotRequired, + }) + if err != nil { + return exceptions.Cause(err, "write socks auth response") + } + } else { + socks.WriteAuthResponse(conn, &socks.AuthResponse{ + Version: authRequest.Version, + Method: socks.AuthTypeNoAcceptedMethods, + }) + return exceptions.New("no accepted methods, requested = ", authRequest.Methods, ", except no auth") + } + } else if bytes.IndexByte(authRequest.Methods, socks.AuthTypeNotRequired) == -1 { + socks.WriteAuthResponse(conn, &socks.AuthResponse{ + Version: authRequest.Version, + Method: socks.AuthTypeNoAcceptedMethods, + }) + return exceptions.New("no accepted methods, requested = ", authRequest.Methods, ", except password") + } else { + err = socks.WriteAuthResponse(conn, &socks.AuthResponse{ + Version: authRequest.Version, + Method: socks.AuthTypeUsernamePassword, + }) + if err != nil { + return exceptions.Cause(err, "write socks auth response: ", err) + } + usernamePasswordRequest, err := socks.ReadUsernamePasswordAuthRequest(conn) + if err != nil { + return exceptions.Cause(err, "read username-password request") + } + if usernamePasswordRequest.Username != h.username { + socks.WriteUsernamePasswordAuthResponse(conn, &socks.UsernamePasswordAuthResponse{Status: socks.UsernamePasswordStatusFailure}) + return exceptions.New("auth failed: excepted username ", h.username, ", got ", usernamePasswordRequest.Username) + } else if usernamePasswordRequest.Password != h.password { + socks.WriteUsernamePasswordAuthResponse(conn, &socks.UsernamePasswordAuthResponse{Status: socks.UsernamePasswordStatusFailure}) + return exceptions.New("auth failed: excepted password ", h.password, ", got ", usernamePasswordRequest.Password) + } + err = socks.WriteUsernamePasswordAuthResponse(conn, &socks.UsernamePasswordAuthResponse{Status: socks.UsernamePasswordStatusSuccess}) + if err != nil { + return exceptions.Cause(err, "write username-password response") + } + } + request, err := socks.ReadRequest(conn) + if err != nil { + return exceptions.Cause(err, "read request") + } + switch request.Command { + case socks.CommandBind: + socks.WriteResponse(conn, &socks.Response{ + Version: request.Version, + ReplyCode: socks.ReplyCodeUnsupported, + }) + return exceptions.New("bind unsupported") + case socks.CommandUDPAssociate: + addr, port := session.AddressFromNetAddr(h.udpListener.LocalAddr()) + err = socks.WriteResponse(conn, &socks.Response{ + Version: request.Version, + ReplyCode: socks.ReplyCodeSuccess, + BindAddr: addr, + BindPort: port, + }) + if err != nil { + return exceptions.Cause(err, "write response") + } + io.Copy(io.Discard, conn) + return nil + } + context := new(session.Context) + context.Network = session.NetworkTCP + context.Source, context.SourcePort = socksaddr.AddressFromNetAddr(conn.RemoteAddr()) + context.Destination, context.DestinationPort = request.Addr, request.Port + h.handler.HandleConnection(&session.Conn{ + Conn: conn, + Context: context, + }) + return nil +} + +func (h *Inbound) HandleUDP(buffer *buf.Buffer, sourceAddr net.Addr) error { + associatePacket, err := socks.DecodeAssociatePacket(buffer) + if err != nil { + return exceptions.Cause(err, "decode associate packet") + } + context := new(session.Context) + context.Network = session.NetworkUDP + context.Source, context.SourcePort = socksaddr.AddressFromNetAddr(sourceAddr) + context.Destination, context.DestinationPort = associatePacket.Addr, associatePacket.Port + h.handler.HandlePacket(&session.Packet{ + Context: context, + Data: buffer, + Release: nil, + WriteBack: func(buffer *buf.Buffer, addr *net.UDPAddr) error { + header := new(socks.AssociatePacket) + header.Addr, header.Port = socksaddr.AddressFromNetAddr(addr) + header.Data = buffer.Bytes() + packet := buf.FullNew() + defer packet.Release() + err := socks.EncodeAssociatePacket(header, packet) + buffer.Release() + if err != nil { + return err + } + return common.Error(h.udpListener.WriteTo(packet.Bytes(), sourceAddr)) + }, + }) + return nil +} + +func (h *Inbound) OnError(err error) { + logrus.Warn("socks: ", err) +} + +func (h *Inbound) Close() error { + h.tcpListener.Close() + h.udpListener.Close() + return nil +} diff --git a/transport/system/control.go b/transport/system/control.go new file mode 100644 index 0000000..20a222f --- /dev/null +++ b/transport/system/control.go @@ -0,0 +1,30 @@ +package system + +import "syscall" + +var ControlFunc func(fd uintptr) error + +func Control(conn syscall.Conn) error { + if ControlFunc == nil { + return nil + } + rawConn, err := conn.SyscallConn() + if err != nil { + return err + } + return ControlRaw(rawConn) +} + +func ControlRaw(conn syscall.RawConn) error { + if ControlFunc == nil { + return nil + } + var rawFd uintptr + err := conn.Control(func(fd uintptr) { + rawFd = fd + }) + if err != nil { + return err + } + return ControlFunc(rawFd) +} diff --git a/transport/system/outbound.go b/transport/system/outbound.go new file mode 100644 index 0000000..8284bf0 --- /dev/null +++ b/transport/system/outbound.go @@ -0,0 +1,52 @@ +package system + +import ( + "context" + "net" + "sing/transport" + "syscall" + + "sing/common/rw" + "sing/common/session" +) + +var _ transport.Outbound = (*Outbound)(nil) + +type Outbound struct{} + +func (h *Outbound) Init(*transport.OutboundContext) { +} + +func (h *Outbound) Close() error { + return nil +} + +func (h *Outbound) NewConnection(ctx context.Context, conn *session.Conn) error { + dialer := net.Dialer{ + Control: func(network, address string, c syscall.RawConn) error { + return ControlRaw(c) + }, + } + outConn, err := dialer.DialContext(ctx, "tcp", conn.Context.DestinationNetAddr()) + if err != nil { + return err + } + connElement := session.AddConnection(outConn) + defer session.RemoveConnection(connElement) + return rw.CopyConn(ctx, conn.Conn, outConn) +} + +func (h *Outbound) NewPacketConnection(ctx context.Context, packetConn *session.PacketConn) error { + dialer := net.Dialer{ + Control: func(network, address string, c syscall.RawConn) error { + return ControlRaw(c) + }, + } + outConn, err := dialer.DialContext(ctx, "udp", packetConn.Context.DestinationNetAddr()) + if err != nil { + return err + } + connElement := session.AddConnection(outConn) + defer session.RemoveConnection(connElement) + return rw.CopyPacketConn(ctx, packetConn.Conn, outConn.(net.PacketConn)) +} diff --git a/transport/system/tcp.go b/transport/system/tcp.go new file mode 100644 index 0000000..10d036f --- /dev/null +++ b/transport/system/tcp.go @@ -0,0 +1,58 @@ +package system + +import ( + "net" + + "net/netip" +) + +type TCPHandler interface { + HandleTCP(conn net.Conn) error + OnError(err error) +} + +type TCPListener struct { + Listen netip.AddrPort + Handler TCPHandler + *net.TCPListener +} + +func NewTCPListener(listen netip.AddrPort, handler TCPHandler) *TCPListener { + return &TCPListener{ + Listen: listen, + Handler: handler, + } +} + +func (l *TCPListener) Start() error { + tcpListener, err := net.ListenTCP("tcp", net.TCPAddrFromAddrPort(l.Listen)) + if err != nil { + return err + } + l.TCPListener = tcpListener + go l.loop() + return nil +} + +func (l *TCPListener) Close() error { + if l == nil || l.TCPListener == nil { + return nil + } + return l.TCPListener.Close() +} + +func (l *TCPListener) loop() { + for { + tcpConn, err := l.Accept() + if err != nil { + l.Close() + return + } + go func() { + err := l.Handler.HandleTCP(tcpConn) + if err != nil { + l.Handler.OnError(err) + } + }() + } +} diff --git a/transport/system/udp.go b/transport/system/udp.go new file mode 100644 index 0000000..244c5d1 --- /dev/null +++ b/transport/system/udp.go @@ -0,0 +1,62 @@ +package system + +import ( + "net" + + "net/netip" + "sing/common/buf" +) + +type UDPHandler interface { + HandleUDP(buffer *buf.Buffer, sourceAddr net.Addr) error + OnError(err error) +} + +type UDPListener struct { + Listen netip.AddrPort + Handler UDPHandler + *net.UDPConn +} + +func NewUDPListener(listen netip.AddrPort, handler UDPHandler) *UDPListener { + return &UDPListener{ + Listen: listen, + Handler: handler, + } +} + +func (l *UDPListener) Start() error { + udpConn, err := net.ListenUDP("udp", net.UDPAddrFromAddrPort(l.Listen)) + if err != nil { + return err + } + l.UDPConn = udpConn + go l.loop() + return nil +} + +func (l *UDPListener) Close() error { + if l == nil || l.UDPConn == nil { + return nil + } + return l.UDPConn.Close() +} + +func (l *UDPListener) loop() { + for { + buffer := buf.New() + n, addr, err := l.ReadFromUDP(buffer.Extend(buf.UDPBufferSize)) + if err != nil { + buffer.Release() + return + } + buffer.Truncate(n) + go func() { + err := l.Handler.HandleUDP(buffer, addr) + if err != nil { + buffer.Release() + l.Handler.OnError(err) + } + }() + } +}