diff --git a/common/baderror/baderror.go b/common/baderror/baderror.go index c5ab530..d90ddeb 100644 --- a/common/baderror/baderror.go +++ b/common/baderror/baderror.go @@ -2,11 +2,10 @@ package baderror import ( "context" + "errors" "io" "net" "strings" - - E "github.com/sagernet/sing/common/exceptions" ) func Contains(err error, msgList ...string) bool { @@ -22,8 +21,7 @@ func WrapH2(err error) error { if err == nil { return nil } - err = E.Unwrap(err) - if err == io.ErrUnexpectedEOF { + if errors.Is(err, io.ErrUnexpectedEOF) { return io.EOF } if Contains(err, "client disconnected", "body closed by handler", "response body closed", "; CANCEL") { diff --git a/common/bufio/copy_direct_windows.go b/common/bufio/copy_direct_windows.go index 482b649..7da9213 100644 --- a/common/bufio/copy_direct_windows.go +++ b/common/bufio/copy_direct_windows.go @@ -120,16 +120,16 @@ func (w *syscallPacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions var readN int var from windows.Sockaddr readN, from, w.readErr = windows.Recvfrom(windows.Handle(fd), buffer.FreeBytes(), 0) + //goland:noinspection GoDirectComparisonOfErrors + if w.readErr != nil { + buffer.Release() + return w.readErr != windows.WSAEWOULDBLOCK + } if readN > 0 { buffer.Truncate(readN) - w.options.PostReturn(buffer) - w.buffer = buffer - } else { - buffer.Release() - } - if w.readErr == windows.WSAEWOULDBLOCK { - return false } + w.options.PostReturn(buffer) + w.buffer = buffer if from != nil { switch fromAddr := from.(type) { case *windows.SockaddrInet4: diff --git a/common/bufio/nat.go b/common/bufio/nat.go index cafeb06..6e5ab64 100644 --- a/common/bufio/nat.go +++ b/common/bufio/nat.go @@ -30,6 +30,14 @@ func NewNATPacketConn(conn N.NetPacketConn, origin M.Socksaddr, destination M.So } } +func NewDestinationNATPacketConn(conn N.NetPacketConn, origin M.Socksaddr, destination M.Socksaddr) NATPacketConn { + return &destinationNATPacketConn{ + NetPacketConn: conn, + origin: origin, + destination: destination, + } +} + type unidirectionalNATPacketConn struct { N.NetPacketConn origin M.Socksaddr @@ -144,6 +152,60 @@ func (c *bidirectionalNATPacketConn) RemoteAddr() net.Addr { return c.destination.UDPAddr() } +type destinationNATPacketConn struct { + N.NetPacketConn + origin M.Socksaddr + destination M.Socksaddr +} + +func (c *destinationNATPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + n, addr, err = c.NetPacketConn.ReadFrom(p) + if err != nil { + return + } + if M.SocksaddrFromNet(addr) == c.origin { + addr = c.destination.UDPAddr() + } + return +} + +func (c *destinationNATPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { + if M.SocksaddrFromNet(addr) == c.destination { + addr = c.origin.UDPAddr() + } + return c.NetPacketConn.WriteTo(p, addr) +} + +func (c *destinationNATPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { + destination, err = c.NetPacketConn.ReadPacket(buffer) + if err != nil { + return + } + if destination == c.origin { + destination = c.destination + } + return +} + +func (c *destinationNATPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { + if destination == c.destination { + destination = c.origin + } + return c.NetPacketConn.WritePacket(buffer, destination) +} + +func (c *destinationNATPacketConn) UpdateDestination(destinationAddress netip.Addr) { + c.destination = M.SocksaddrFrom(destinationAddress, c.destination.Port) +} + +func (c *destinationNATPacketConn) Upstream() any { + return c.NetPacketConn +} + +func (c *destinationNATPacketConn) RemoteAddr() net.Addr { + return c.destination.UDPAddr() +} + func socksaddrWithoutPort(destination M.Socksaddr) M.Socksaddr { destination.Port = 0 return destination diff --git a/common/exceptions/cause.go b/common/exceptions/cause.go index fe7adf3..b48b6e2 100644 --- a/common/exceptions/cause.go +++ b/common/exceptions/cause.go @@ -12,3 +12,16 @@ func (e *causeError) Error() string { func (e *causeError) Unwrap() error { return e.cause } + +type causeError1 struct { + error + cause error +} + +func (e *causeError1) Error() string { + return e.error.Error() + ": " + e.cause.Error() +} + +func (e *causeError1) Unwrap() []error { + return []error{e.error, e.cause} +} diff --git a/common/exceptions/error.go b/common/exceptions/error.go index 3f08ac4..24f0c29 100644 --- a/common/exceptions/error.go +++ b/common/exceptions/error.go @@ -32,6 +32,13 @@ func Cause(cause error, message ...any) error { return &causeError{F.ToString(message...), cause} } +func Cause1(err error, cause error) error { + if cause == nil { + panic("cause on an nil error") + } + return &causeError1{err, cause} +} + func Extend(cause error, message ...any) error { if cause == nil { panic("extend on an nil error") @@ -40,11 +47,11 @@ func Extend(cause error, message ...any) error { } func IsClosedOrCanceled(err error) bool { - return IsMulti(err, io.EOF, net.ErrClosed, io.ErrClosedPipe, os.ErrClosed, syscall.EPIPE, syscall.ECONNRESET, context.Canceled, context.DeadlineExceeded) || IsTimeout(err) + return IsClosed(err) || IsCanceled(err) || IsTimeout(err) } func IsClosed(err error) bool { - return IsMulti(err, io.EOF, net.ErrClosed, io.ErrClosedPipe, os.ErrClosed, syscall.EPIPE, syscall.ECONNRESET) + return IsMulti(err, io.EOF, net.ErrClosed, io.ErrClosedPipe, os.ErrClosed, syscall.EPIPE, syscall.ECONNRESET, syscall.ENOTCONN) } func IsCanceled(err error) bool { diff --git a/common/exceptions/inner.go b/common/exceptions/inner.go index 3799af9..58bcc5a 100644 --- a/common/exceptions/inner.go +++ b/common/exceptions/inner.go @@ -1,24 +1,14 @@ package exceptions -import "github.com/sagernet/sing/common" +import ( + "errors" -type HasInnerError interface { - Unwrap() error -} + "github.com/sagernet/sing/common" +) +// Deprecated: Use errors.Unwrap instead. func Unwrap(err error) error { - for { - inner, ok := err.(HasInnerError) - if !ok { - break - } - innerErr := inner.Unwrap() - if innerErr == nil { - break - } - err = innerErr - } - return err + return errors.Unwrap(err) } func Cast[T any](err error) (T, bool) { diff --git a/common/exceptions/multi.go b/common/exceptions/multi.go index 2cdec05..78fb3b6 100644 --- a/common/exceptions/multi.go +++ b/common/exceptions/multi.go @@ -63,12 +63,5 @@ func IsMulti(err error, targetList ...error) bool { return true } } - err = Unwrap(err) - multiErr, isMulti := err.(MultiError) - if !isMulti { - return false - } - return common.All(multiErr.Unwrap(), func(it error) bool { - return IsMulti(it, targetList...) - }) + return false } diff --git a/common/json/badjson/merge_objects.go b/common/json/badjson/merge_objects.go index fa6c2d4..5b23209 100644 --- a/common/json/badjson/merge_objects.go +++ b/common/json/badjson/merge_objects.go @@ -2,9 +2,11 @@ package badjson import ( "context" + "reflect" E "github.com/sagernet/sing/common/exceptions" "github.com/sagernet/sing/common/json" + cJSON "github.com/sagernet/sing/common/json/internal/contextjson" ) func MarshallObjects(objects ...any) ([]byte, error) { @@ -31,16 +33,12 @@ func UnmarshallExcluded(inputContent []byte, parentObject any, object any) error } func UnmarshallExcludedContext(ctx context.Context, inputContent []byte, parentObject any, object any) error { - parentContent, err := newJSONObject(ctx, parentObject) - if err != nil { - return err - } var content JSONObject - err = content.UnmarshalJSONContext(ctx, inputContent) + err := content.UnmarshalJSONContext(ctx, inputContent) if err != nil { return err } - for _, key := range parentContent.Keys() { + for _, key := range cJSON.ObjectKeys(reflect.TypeOf(parentObject)) { content.Remove(key) } if object == nil { diff --git a/common/json/badoption/listable.go b/common/json/badoption/listable.go index df02217..c48df12 100644 --- a/common/json/badoption/listable.go +++ b/common/json/badoption/listable.go @@ -18,6 +18,9 @@ func (l Listable[T]) MarshalJSONContext(ctx context.Context) ([]byte, error) { } func (l *Listable[T]) UnmarshalJSONContext(ctx context.Context, content []byte) error { + if string(content) == "null" { + return nil + } var singleItem T err := json.UnmarshalContextDisallowUnknownFields(ctx, content, &singleItem) if err == nil { diff --git a/common/json/internal/contextjson/keys.go b/common/json/internal/contextjson/keys.go new file mode 100644 index 0000000..589007f --- /dev/null +++ b/common/json/internal/contextjson/keys.go @@ -0,0 +1,20 @@ +package json + +import ( + "reflect" + + "github.com/sagernet/sing/common" +) + +func ObjectKeys(object reflect.Type) []string { + switch object.Kind() { + case reflect.Pointer: + return ObjectKeys(object.Elem()) + case reflect.Struct: + default: + panic("invalid non-struct input") + } + return common.Map(cachedTypeFields(object).list, func(field field) string { + return field.name + }) +} diff --git a/common/json/internal/contextjson/keys_test.go b/common/json/internal/contextjson/keys_test.go new file mode 100644 index 0000000..11c72cb --- /dev/null +++ b/common/json/internal/contextjson/keys_test.go @@ -0,0 +1,26 @@ +package json_test + +import ( + "reflect" + "testing" + + json "github.com/sagernet/sing/common/json/internal/contextjson" + + "github.com/stretchr/testify/require" +) + +type MyObject struct { + Hello string `json:"hello,omitempty"` + MyWorld + MyWorld2 string `json:"-"` +} + +type MyWorld struct { + World string `json:"world,omitempty"` +} + +func TestObjectKeys(t *testing.T) { + t.Parallel() + keys := json.ObjectKeys(reflect.TypeOf(&MyObject{})) + require.Equal(t, []string{"hello", "world"}, keys) +} diff --git a/common/ntp/time_windows.go b/common/ntp/time_windows.go index 5aab23e..82d8798 100644 --- a/common/ntp/time_windows.go +++ b/common/ntp/time_windows.go @@ -7,6 +7,7 @@ import ( ) func SetSystemTime(nowTime time.Time) error { + nowTime = nowTime.UTC() var systemTime windows.Systemtime systemTime.Year = uint16(nowTime.Year()) systemTime.Month = uint16(nowTime.Month()) diff --git a/common/udpnat2/conn.go b/common/udpnat2/conn.go index 4ed7b74..e8987f4 100644 --- a/common/udpnat2/conn.go +++ b/common/udpnat2/conn.go @@ -28,6 +28,7 @@ type natConn struct { cache freelru.Cache[netip.AddrPort, *natConn] writer N.PacketWriter localAddr M.Socksaddr + handlerAccess sync.RWMutex handler N.UDPHandlerEx packetChan chan *N.PacketBuffer closeOnce sync.Once @@ -75,12 +76,10 @@ func (c *natConn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, } func (c *natConn) SetHandler(handler N.UDPHandlerEx) { - select { - case <-c.doneChan: - default: - } + c.handlerAccess.Lock() c.handler = handler c.readWaitOptions = N.NewReadWaitOptions(c.writer, handler) + c.handlerAccess.Unlock() fetch: for { select { diff --git a/common/udpnat2/service.go b/common/udpnat2/service.go index 07e786d..3367bd7 100644 --- a/common/udpnat2/service.go +++ b/common/udpnat2/service.go @@ -27,7 +27,7 @@ func New(handler N.UDPConnectionHandlerEx, prepare PrepareFunc, timeout time.Dur } var cache freelru.Cache[netip.AddrPort, *natConn] if !shared { - cache = common.Must1(freelru.New[netip.AddrPort, *natConn](1024, maphash.NewHasher[netip.AddrPort]().Hash32)) + cache = common.Must1(freelru.NewSynced[netip.AddrPort, *natConn](1024, maphash.NewHasher[netip.AddrPort]().Hash32)) } else { cache = common.Must1(freelru.NewSharded[netip.AddrPort, *natConn](1024, maphash.NewHasher[netip.AddrPort]().Hash32)) } @@ -51,7 +51,7 @@ func New(handler N.UDPConnectionHandlerEx, prepare PrepareFunc, timeout time.Dur } func (s *Service) NewPacket(bufferSlices [][]byte, source M.Socksaddr, destination M.Socksaddr, userData any) { - conn, loaded := s.cache.GetAndRefreshOrAdd(source.AddrPort(), func() (*natConn, bool) { + conn, _, ok := s.cache.GetAndRefreshOrAdd(source.AddrPort(), func() (*natConn, bool) { ok, ctx, writer, onClose := s.prepare(source, destination, userData) if !ok { return nil, false @@ -67,17 +67,18 @@ func (s *Service) NewPacket(bufferSlices [][]byte, source M.Socksaddr, destinati go s.handler.NewPacketConnectionEx(ctx, newConn, source, destination, onClose) return newConn, true }) - if !loaded { - if conn == nil { - return - } + if !ok { + return } buffer := conn.readWaitOptions.NewPacketBuffer() for _, bufferSlice := range bufferSlices { buffer.Write(bufferSlice) } - if conn.handler != nil { - conn.handler.NewPacketEx(buffer, destination) + conn.handlerAccess.RLock() + handler := conn.handler + conn.handlerAccess.RUnlock() + if handler != nil { + handler.NewPacketEx(buffer, destination) return } packet := N.NewPacketBuffer() diff --git a/common/windnsapi/dnsapi_test.go b/common/windnsapi/dnsapi_test.go index adf582d..c5ea831 100644 --- a/common/windnsapi/dnsapi_test.go +++ b/common/windnsapi/dnsapi_test.go @@ -1,16 +1,14 @@ +//go:build windows + package windnsapi import ( - "runtime" "testing" "github.com/stretchr/testify/require" ) func TestDNSAPI(t *testing.T) { - if runtime.GOOS != "windows" { - t.SkipNow() - } t.Parallel() require.NoError(t, FlushResolverCache()) } diff --git a/common/winiphlpapi/helper.go b/common/winiphlpapi/helper.go new file mode 100644 index 0000000..6bd4e8f --- /dev/null +++ b/common/winiphlpapi/helper.go @@ -0,0 +1,217 @@ +//go:build windows + +package winiphlpapi + +import ( + "context" + "encoding/binary" + "net" + "net/netip" + "os" + "time" + "unsafe" + + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" +) + +func LoadEStats() error { + err := modiphlpapi.Load() + if err != nil { + return err + } + err = procGetTcpTable.Find() + if err != nil { + return err + } + err = procGetTcp6Table.Find() + if err != nil { + return err + } + err = procGetPerTcp6ConnectionEStats.Find() + if err != nil { + return err + } + err = procGetPerTcp6ConnectionEStats.Find() + if err != nil { + return err + } + err = procSetPerTcpConnectionEStats.Find() + if err != nil { + return err + } + err = procSetPerTcp6ConnectionEStats.Find() + if err != nil { + return err + } + return nil +} + +func LoadExtendedTable() error { + err := modiphlpapi.Load() + if err != nil { + return err + } + err = procGetExtendedTcpTable.Find() + if err != nil { + return err + } + err = procGetExtendedUdpTable.Find() + if err != nil { + return err + } + return nil +} + +func FindPid(network string, source netip.AddrPort) (uint32, error) { + switch N.NetworkName(network) { + case N.NetworkTCP: + if source.Addr().Is4() { + tcpTable, err := GetExtendedTcpTable() + if err != nil { + return 0, err + } + for _, row := range tcpTable { + if source == netip.AddrPortFrom(DwordToAddr(row.DwLocalAddr), DwordToPort(row.DwLocalPort)) { + return row.DwOwningPid, nil + } + } + } else { + tcpTable, err := GetExtendedTcp6Table() + if err != nil { + return 0, err + } + for _, row := range tcpTable { + if source == netip.AddrPortFrom(netip.AddrFrom16(row.UcLocalAddr), DwordToPort(row.DwLocalPort)) { + return row.DwOwningPid, nil + } + } + } + case N.NetworkUDP: + if source.Addr().Is4() { + udpTable, err := GetExtendedUdpTable() + if err != nil { + return 0, err + } + for _, row := range udpTable { + if source == netip.AddrPortFrom(DwordToAddr(row.DwLocalAddr), DwordToPort(row.DwLocalPort)) { + return row.DwOwningPid, nil + } + } + } else { + udpTable, err := GetExtendedUdp6Table() + if err != nil { + return 0, err + } + for _, row := range udpTable { + if source == netip.AddrPortFrom(netip.AddrFrom16(row.UcLocalAddr), DwordToPort(row.DwLocalPort)) { + return row.DwOwningPid, nil + } + } + } + } + return 0, E.New("process not found for ", source) +} + +func WriteAndWaitAck(ctx context.Context, conn net.Conn, payload []byte) error { + source := M.AddrPortFromNet(conn.LocalAddr()) + destination := M.AddrPortFromNet(conn.RemoteAddr()) + if source.Addr().Is4() { + tcpTable, err := GetTcpTable() + if err != nil { + return err + } + var tcpRow *MibTcpRow + for _, row := range tcpTable { + if source == netip.AddrPortFrom(DwordToAddr(row.DwLocalAddr), DwordToPort(row.DwLocalPort)) || + destination == netip.AddrPortFrom(DwordToAddr(row.DwRemoteAddr), DwordToPort(row.DwRemotePort)) { + tcpRow = &row + break + } + } + if tcpRow == nil { + return E.New("row not found for: ", source) + } + err = SetPerTcpConnectionEStatsSendBuffer(tcpRow, &TcpEstatsSendBuffRwV0{ + EnableCollection: true, + }) + if err != nil { + return os.NewSyscallError("SetPerTcpConnectionEStatsSendBufferV0", err) + } + defer SetPerTcpConnectionEStatsSendBuffer(tcpRow, &TcpEstatsSendBuffRwV0{ + EnableCollection: false, + }) + _, err = conn.Write(payload) + if err != nil { + return err + } + for { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + eStstsSendBuffer, err := GetPerTcpConnectionEStatsSendBuffer(tcpRow) + if err != nil { + return err + } + if eStstsSendBuffer.CurRetxQueue == 0 { + return nil + } + time.Sleep(10 * time.Millisecond) + } + } else { + tcpTable, err := GetTcp6Table() + if err != nil { + return err + } + var tcpRow *MibTcp6Row + for _, row := range tcpTable { + if source == netip.AddrPortFrom(netip.AddrFrom16(row.LocalAddr), DwordToPort(row.LocalPort)) || + destination == netip.AddrPortFrom(netip.AddrFrom16(row.RemoteAddr), DwordToPort(row.RemotePort)) { + tcpRow = &row + break + } + } + if tcpRow == nil { + return E.New("row not found for: ", source) + } + err = SetPerTcp6ConnectionEStatsSendBuffer(tcpRow, &TcpEstatsSendBuffRwV0{ + EnableCollection: true, + }) + if err != nil { + return os.NewSyscallError("SetPerTcpConnectionEStatsSendBufferV0", err) + } + defer SetPerTcp6ConnectionEStatsSendBuffer(tcpRow, &TcpEstatsSendBuffRwV0{ + EnableCollection: false, + }) + _, err = conn.Write(payload) + if err != nil { + return err + } + for { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + eStstsSendBuffer, err := GetPerTcp6ConnectionEStatsSendBuffer(tcpRow) + if err != nil { + return err + } + if eStstsSendBuffer.CurRetxQueue == 0 { + return nil + } + time.Sleep(10 * time.Millisecond) + } + } +} + +func DwordToAddr(addr uint32) netip.Addr { + return netip.AddrFrom4(*(*[4]byte)(unsafe.Pointer(&addr))) +} + +func DwordToPort(dword uint32) uint16 { + return binary.BigEndian.Uint16((*[4]byte)(unsafe.Pointer(&dword))[:]) +} diff --git a/common/winiphlpapi/iphlpapi.go b/common/winiphlpapi/iphlpapi.go new file mode 100644 index 0000000..74e5b90 --- /dev/null +++ b/common/winiphlpapi/iphlpapi.go @@ -0,0 +1,313 @@ +//go:build windows + +package winiphlpapi + +import ( + "errors" + "os" + "unsafe" + + "golang.org/x/sys/windows" +) + +const ( + TcpTableBasicListener uint32 = iota + TcpTableBasicConnections + TcpTableBasicAll + TcpTableOwnerPidListener + TcpTableOwnerPidConnections + TcpTableOwnerPidAll + TcpTableOwnerModuleListener + TcpTableOwnerModuleConnections + TcpTableOwnerModuleAll +) + +const ( + UdpTableBasic uint32 = iota + UdpTableOwnerPid + UdpTableOwnerModule +) + +const ( + TcpConnectionEstatsSynOpts uint32 = iota + TcpConnectionEstatsData + TcpConnectionEstatsSndCong + TcpConnectionEstatsPath + TcpConnectionEstatsSendBuff + TcpConnectionEstatsRec + TcpConnectionEstatsObsRec + TcpConnectionEstatsBandwidth + TcpConnectionEstatsFineRtt + TcpConnectionEstatsMaximum +) + +type MibTcpTable struct { + DwNumEntries uint32 + Table [1]MibTcpRow +} + +type MibTcpRow struct { + DwState uint32 + DwLocalAddr uint32 + DwLocalPort uint32 + DwRemoteAddr uint32 + DwRemotePort uint32 +} + +type MibTcp6Table struct { + DwNumEntries uint32 + Table [1]MibTcp6Row +} + +type MibTcp6Row struct { + State uint32 + LocalAddr [16]byte + LocalScopeId uint32 + LocalPort uint32 + RemoteAddr [16]byte + RemoteScopeId uint32 + RemotePort uint32 +} + +type MibTcpTableOwnerPid struct { + DwNumEntries uint32 + Table [1]MibTcpRowOwnerPid +} + +type MibTcpRowOwnerPid struct { + DwState uint32 + DwLocalAddr uint32 + DwLocalPort uint32 + DwRemoteAddr uint32 + DwRemotePort uint32 + DwOwningPid uint32 +} + +type MibTcp6TableOwnerPid struct { + DwNumEntries uint32 + Table [1]MibTcp6RowOwnerPid +} + +type MibTcp6RowOwnerPid struct { + UcLocalAddr [16]byte + DwLocalScopeId uint32 + DwLocalPort uint32 + UcRemoteAddr [16]byte + DwRemoteScopeId uint32 + DwRemotePort uint32 + DwState uint32 + DwOwningPid uint32 +} + +type MibUdpTableOwnerPid struct { + DwNumEntries uint32 + Table [1]MibUdpRowOwnerPid +} + +type MibUdpRowOwnerPid struct { + DwLocalAddr uint32 + DwLocalPort uint32 + DwOwningPid uint32 +} + +type MibUdp6TableOwnerPid struct { + DwNumEntries uint32 + Table [1]MibUdp6RowOwnerPid +} + +type MibUdp6RowOwnerPid struct { + UcLocalAddr [16]byte + DwLocalScopeId uint32 + DwLocalPort uint32 + DwOwningPid uint32 +} + +type TcpEstatsSendBufferRodV0 struct { + CurRetxQueue uint64 + MaxRetxQueue uint64 + CurAppWQueue uint64 + MaxAppWQueue uint64 +} + +type TcpEstatsSendBuffRwV0 struct { + EnableCollection bool +} + +const ( + offsetOfMibTcpTable = unsafe.Offsetof(MibTcpTable{}.Table) + offsetOfMibTcp6Table = unsafe.Offsetof(MibTcp6Table{}.Table) + offsetOfMibTcpTableOwnerPid = unsafe.Offsetof(MibTcpTableOwnerPid{}.Table) + offsetOfMibTcp6TableOwnerPid = unsafe.Offsetof(MibTcpTableOwnerPid{}.Table) + offsetOfMibUdpTableOwnerPid = unsafe.Offsetof(MibUdpTableOwnerPid{}.Table) + offsetOfMibUdp6TableOwnerPid = unsafe.Offsetof(MibUdp6TableOwnerPid{}.Table) + sizeOfTcpEstatsSendBuffRwV0 = unsafe.Sizeof(TcpEstatsSendBuffRwV0{}) + sizeOfTcpEstatsSendBufferRodV0 = unsafe.Sizeof(TcpEstatsSendBufferRodV0{}) +) + +func GetTcpTable() ([]MibTcpRow, error) { + var size uint32 + err := getTcpTable(nil, &size, false) + if !errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) { + return nil, err + } + for { + table := make([]byte, size) + err = getTcpTable(&table[0], &size, false) + if err != nil { + if errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) { + continue + } + return nil, err + } + dwNumEntries := int(*(*uint32)(unsafe.Pointer(&table[0]))) + return unsafe.Slice((*MibTcpRow)(unsafe.Pointer(&table[offsetOfMibTcpTable])), dwNumEntries), nil + } +} + +func GetTcp6Table() ([]MibTcp6Row, error) { + var size uint32 + err := getTcp6Table(nil, &size, false) + if !errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) { + return nil, err + } + for { + table := make([]byte, size) + err = getTcp6Table(&table[0], &size, false) + if err != nil { + if errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) { + continue + } + return nil, err + } + dwNumEntries := int(*(*uint32)(unsafe.Pointer(&table[0]))) + return unsafe.Slice((*MibTcp6Row)(unsafe.Pointer(&table[offsetOfMibTcp6Table])), dwNumEntries), nil + } +} + +func GetExtendedTcpTable() ([]MibTcpRowOwnerPid, error) { + var size uint32 + err := getExtendedTcpTable(nil, &size, false, windows.AF_INET, TcpTableOwnerPidConnections, 0) + if !errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) { + return nil, os.NewSyscallError("GetExtendedTcpTable", err) + } + for { + table := make([]byte, size) + err = getExtendedTcpTable(&table[0], &size, false, windows.AF_INET, TcpTableOwnerPidConnections, 0) + if err != nil { + if errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) { + continue + } + return nil, os.NewSyscallError("GetExtendedTcpTable", err) + } + dwNumEntries := int(*(*uint32)(unsafe.Pointer(&table[0]))) + return unsafe.Slice((*MibTcpRowOwnerPid)(unsafe.Pointer(&table[offsetOfMibTcpTableOwnerPid])), dwNumEntries), nil + } +} + +func GetExtendedTcp6Table() ([]MibTcp6RowOwnerPid, error) { + var size uint32 + err := getExtendedTcpTable(nil, &size, false, windows.AF_INET6, TcpTableOwnerPidConnections, 0) + if !errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) { + return nil, os.NewSyscallError("GetExtendedTcpTable", err) + } + for { + table := make([]byte, size) + err = getExtendedTcpTable(&table[0], &size, false, windows.AF_INET6, TcpTableOwnerPidConnections, 0) + if err != nil { + if errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) { + continue + } + return nil, os.NewSyscallError("GetExtendedTcpTable", err) + } + dwNumEntries := int(*(*uint32)(unsafe.Pointer(&table[0]))) + return unsafe.Slice((*MibTcp6RowOwnerPid)(unsafe.Pointer(&table[offsetOfMibTcp6TableOwnerPid])), dwNumEntries), nil + } +} + +func GetExtendedUdpTable() ([]MibUdpRowOwnerPid, error) { + var size uint32 + err := getExtendedUdpTable(nil, &size, false, windows.AF_INET, UdpTableOwnerPid, 0) + if !errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) { + return nil, os.NewSyscallError("GetExtendedUdpTable", err) + } + for { + table := make([]byte, size) + err = getExtendedUdpTable(&table[0], &size, false, windows.AF_INET, UdpTableOwnerPid, 0) + if err != nil { + if errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) { + continue + } + return nil, os.NewSyscallError("GetExtendedUdpTable", err) + } + dwNumEntries := int(*(*uint32)(unsafe.Pointer(&table[0]))) + return unsafe.Slice((*MibUdpRowOwnerPid)(unsafe.Pointer(&table[offsetOfMibUdpTableOwnerPid])), dwNumEntries), nil + } +} + +func GetExtendedUdp6Table() ([]MibUdp6RowOwnerPid, error) { + var size uint32 + err := getExtendedUdpTable(nil, &size, false, windows.AF_INET6, UdpTableOwnerPid, 0) + if !errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) { + return nil, os.NewSyscallError("GetExtendedUdpTable", err) + } + for { + table := make([]byte, size) + err = getExtendedUdpTable(&table[0], &size, false, windows.AF_INET6, UdpTableOwnerPid, 0) + if err != nil { + if errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) { + continue + } + return nil, os.NewSyscallError("GetExtendedUdpTable", err) + } + dwNumEntries := int(*(*uint32)(unsafe.Pointer(&table[0]))) + return unsafe.Slice((*MibUdp6RowOwnerPid)(unsafe.Pointer(&table[offsetOfMibUdp6TableOwnerPid])), dwNumEntries), nil + } +} + +func GetPerTcpConnectionEStatsSendBuffer(row *MibTcpRow) (*TcpEstatsSendBufferRodV0, error) { + var rod TcpEstatsSendBufferRodV0 + err := getPerTcpConnectionEStats(row, + TcpConnectionEstatsSendBuff, + 0, + 0, + 0, + 0, + 0, + 0, + uintptr(unsafe.Pointer(&rod)), + 0, + uint64(sizeOfTcpEstatsSendBufferRodV0), + ) + if err != nil { + return nil, err + } + return &rod, nil +} + +func GetPerTcp6ConnectionEStatsSendBuffer(row *MibTcp6Row) (*TcpEstatsSendBufferRodV0, error) { + var rod TcpEstatsSendBufferRodV0 + err := getPerTcp6ConnectionEStats(row, + TcpConnectionEstatsSendBuff, + 0, + 0, + 0, + 0, + 0, + 0, + uintptr(unsafe.Pointer(&rod)), + 0, + uint64(sizeOfTcpEstatsSendBufferRodV0), + ) + if err != nil { + return nil, err + } + return &rod, nil +} + +func SetPerTcpConnectionEStatsSendBuffer(row *MibTcpRow, rw *TcpEstatsSendBuffRwV0) error { + return setPerTcpConnectionEStats(row, TcpConnectionEstatsSendBuff, uintptr(unsafe.Pointer(&rw)), 0, uint64(sizeOfTcpEstatsSendBuffRwV0), 0) +} + +func SetPerTcp6ConnectionEStatsSendBuffer(row *MibTcp6Row, rw *TcpEstatsSendBuffRwV0) error { + return setPerTcp6ConnectionEStats(row, TcpConnectionEstatsSendBuff, uintptr(unsafe.Pointer(&rw)), 0, uint64(sizeOfTcpEstatsSendBuffRwV0), 0) +} diff --git a/common/winiphlpapi/iphlpapi_test.go b/common/winiphlpapi/iphlpapi_test.go new file mode 100644 index 0000000..5fc3b74 --- /dev/null +++ b/common/winiphlpapi/iphlpapi_test.go @@ -0,0 +1,90 @@ +//go:build windows + +package winiphlpapi_test + +import ( + "context" + "net" + "syscall" + "testing" + + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/common/winiphlpapi" + + "github.com/stretchr/testify/require" +) + +func TestFindPidTcp4(t *testing.T) { + t.Parallel() + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer listener.Close() + go listener.Accept() + conn, err := net.Dial("tcp", listener.Addr().String()) + require.NoError(t, err) + defer conn.Close() + pid, err := winiphlpapi.FindPid(N.NetworkTCP, M.AddrPortFromNet(conn.LocalAddr())) + require.NoError(t, err) + require.Equal(t, uint32(syscall.Getpid()), pid) +} + +func TestFindPidTcp6(t *testing.T) { + t.Parallel() + listener, err := net.Listen("tcp", "[::1]:0") + require.NoError(t, err) + defer listener.Close() + go listener.Accept() + conn, err := net.Dial("tcp", listener.Addr().String()) + require.NoError(t, err) + defer conn.Close() + pid, err := winiphlpapi.FindPid(N.NetworkTCP, M.AddrPortFromNet(conn.LocalAddr())) + require.NoError(t, err) + require.Equal(t, uint32(syscall.Getpid()), pid) +} + +func TestFindPidUdp4(t *testing.T) { + t.Parallel() + conn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer conn.Close() + pid, err := winiphlpapi.FindPid(N.NetworkUDP, M.AddrPortFromNet(conn.LocalAddr())) + require.NoError(t, err) + require.Equal(t, uint32(syscall.Getpid()), pid) +} + +func TestFindPidUdp6(t *testing.T) { + t.Parallel() + conn, err := net.ListenPacket("udp", "[::1]:0") + require.NoError(t, err) + defer conn.Close() + pid, err := winiphlpapi.FindPid(N.NetworkUDP, M.AddrPortFromNet(conn.LocalAddr())) + require.NoError(t, err) + require.Equal(t, uint32(syscall.Getpid()), pid) +} + +func TestWaitAck4(t *testing.T) { + t.Parallel() + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer listener.Close() + go listener.Accept() + conn, err := net.Dial("tcp", listener.Addr().String()) + require.NoError(t, err) + defer conn.Close() + err = winiphlpapi.WriteAndWaitAck(context.Background(), conn, []byte("hello")) + require.NoError(t, err) +} + +func TestWaitAck6(t *testing.T) { + t.Parallel() + listener, err := net.Listen("tcp", "[::1]:0") + require.NoError(t, err) + defer listener.Close() + go listener.Accept() + conn, err := net.Dial("tcp", listener.Addr().String()) + require.NoError(t, err) + defer conn.Close() + err = winiphlpapi.WriteAndWaitAck(context.Background(), conn, []byte("hello")) + require.NoError(t, err) +} diff --git a/common/winiphlpapi/syscall_windows.go b/common/winiphlpapi/syscall_windows.go new file mode 100644 index 0000000..f6aab14 --- /dev/null +++ b/common/winiphlpapi/syscall_windows.go @@ -0,0 +1,27 @@ +package winiphlpapi + +//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go syscall_windows.go + +// https://learn.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-gettcptable +//sys getTcpTable(tcpTable *byte, sizePointer *uint32, order bool) (errcode error) = iphlpapi.GetTcpTable + +// https://learn.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-gettcp6table +//sys getTcp6Table(tcpTable *byte, sizePointer *uint32, order bool) (errcode error) = iphlpapi.GetTcp6Table + +// https://learn.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-getpertcpconnectionestats +//sys getPerTcpConnectionEStats(row *MibTcpRow, estatsType uint32, rw uintptr, rwVersion uint64, rwSize uint64, ros uintptr, rosVersion uint64, rosSize uint64, rod uintptr, rodVersion uint64, rodSize uint64) (errcode error) = iphlpapi.GetPerTcpConnectionEStats + +// https://learn.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-getpertcp6connectionestats +//sys getPerTcp6ConnectionEStats(row *MibTcp6Row, estatsType uint32, rw uintptr, rwVersion uint64, rwSize uint64, ros uintptr, rosVersion uint64, rosSize uint64, rod uintptr, rodVersion uint64, rodSize uint64) (errcode error) = iphlpapi.GetPerTcp6ConnectionEStats + +// https://learn.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-setpertcpconnectionestats +//sys setPerTcpConnectionEStats(row *MibTcpRow, estatsType uint32, rw uintptr, rwVersion uint64, rwSize uint64, offset uint64) (errcode error) = iphlpapi.SetPerTcpConnectionEStats + +// https://learn.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-setpertcp6connectionestats +//sys setPerTcp6ConnectionEStats(row *MibTcp6Row, estatsType uint32, rw uintptr, rwVersion uint64, rwSize uint64, offset uint64) (errcode error) = iphlpapi.SetPerTcp6ConnectionEStats + +// https://learn.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-getextendedtcptable +//sys getExtendedTcpTable(pTcpTable *byte, pdwSize *uint32, bOrder bool, ulAf uint64, tableClass uint32, reserved uint64) = (errcode error) = iphlpapi.GetExtendedTcpTable + +// https://learn.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-getextendedudptable +//sys getExtendedUdpTable(pUdpTable *byte, pdwSize *uint32, bOrder bool, ulAf uint64, tableClass uint32, reserved uint64) = (errcode error) = iphlpapi.GetExtendedUdpTable diff --git a/common/winiphlpapi/zsyscall_windows.go b/common/winiphlpapi/zsyscall_windows.go new file mode 100644 index 0000000..e5e9308 --- /dev/null +++ b/common/winiphlpapi/zsyscall_windows.go @@ -0,0 +1,131 @@ +// Code generated by 'go generate'; DO NOT EDIT. + +package winiphlpapi + +import ( + "syscall" + "unsafe" + + "golang.org/x/sys/windows" +) + +var _ unsafe.Pointer + +// Do the interface allocations only once for common +// Errno values. +const ( + errnoERROR_IO_PENDING = 997 +) + +var ( + errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING) + errERROR_EINVAL error = syscall.EINVAL +) + +// errnoErr returns common boxed Errno values, to prevent +// allocations at runtime. +func errnoErr(e syscall.Errno) error { + switch e { + case 0: + return errERROR_EINVAL + case errnoERROR_IO_PENDING: + return errERROR_IO_PENDING + } + // TODO: add more here, after collecting data on the common + // error values see on Windows. (perhaps when running + // all.bat?) + return e +} + +var ( + modiphlpapi = windows.NewLazySystemDLL("iphlpapi.dll") + + procGetExtendedTcpTable = modiphlpapi.NewProc("GetExtendedTcpTable") + procGetExtendedUdpTable = modiphlpapi.NewProc("GetExtendedUdpTable") + procGetPerTcp6ConnectionEStats = modiphlpapi.NewProc("GetPerTcp6ConnectionEStats") + procGetPerTcpConnectionEStats = modiphlpapi.NewProc("GetPerTcpConnectionEStats") + procGetTcp6Table = modiphlpapi.NewProc("GetTcp6Table") + procGetTcpTable = modiphlpapi.NewProc("GetTcpTable") + procSetPerTcp6ConnectionEStats = modiphlpapi.NewProc("SetPerTcp6ConnectionEStats") + procSetPerTcpConnectionEStats = modiphlpapi.NewProc("SetPerTcpConnectionEStats") +) + +func getExtendedTcpTable(pTcpTable *byte, pdwSize *uint32, bOrder bool, ulAf uint64, tableClass uint32, reserved uint64) (errcode error) { + var _p0 uint32 + if bOrder { + _p0 = 1 + } + r0, _, _ := syscall.Syscall6(procGetExtendedTcpTable.Addr(), 6, uintptr(unsafe.Pointer(pTcpTable)), uintptr(unsafe.Pointer(pdwSize)), uintptr(_p0), uintptr(ulAf), uintptr(tableClass), uintptr(reserved)) + if r0 != 0 { + errcode = syscall.Errno(r0) + } + return +} + +func getExtendedUdpTable(pUdpTable *byte, pdwSize *uint32, bOrder bool, ulAf uint64, tableClass uint32, reserved uint64) (errcode error) { + var _p0 uint32 + if bOrder { + _p0 = 1 + } + r0, _, _ := syscall.Syscall6(procGetExtendedUdpTable.Addr(), 6, uintptr(unsafe.Pointer(pUdpTable)), uintptr(unsafe.Pointer(pdwSize)), uintptr(_p0), uintptr(ulAf), uintptr(tableClass), uintptr(reserved)) + if r0 != 0 { + errcode = syscall.Errno(r0) + } + return +} + +func getPerTcp6ConnectionEStats(row *MibTcp6Row, estatsType uint32, rw uintptr, rwVersion uint64, rwSize uint64, ros uintptr, rosVersion uint64, rosSize uint64, rod uintptr, rodVersion uint64, rodSize uint64) (errcode error) { + r0, _, _ := syscall.Syscall12(procGetPerTcp6ConnectionEStats.Addr(), 11, uintptr(unsafe.Pointer(row)), uintptr(estatsType), uintptr(rw), uintptr(rwVersion), uintptr(rwSize), uintptr(ros), uintptr(rosVersion), uintptr(rosSize), uintptr(rod), uintptr(rodVersion), uintptr(rodSize), 0) + if r0 != 0 { + errcode = syscall.Errno(r0) + } + return +} + +func getPerTcpConnectionEStats(row *MibTcpRow, estatsType uint32, rw uintptr, rwVersion uint64, rwSize uint64, ros uintptr, rosVersion uint64, rosSize uint64, rod uintptr, rodVersion uint64, rodSize uint64) (errcode error) { + r0, _, _ := syscall.Syscall12(procGetPerTcpConnectionEStats.Addr(), 11, uintptr(unsafe.Pointer(row)), uintptr(estatsType), uintptr(rw), uintptr(rwVersion), uintptr(rwSize), uintptr(ros), uintptr(rosVersion), uintptr(rosSize), uintptr(rod), uintptr(rodVersion), uintptr(rodSize), 0) + if r0 != 0 { + errcode = syscall.Errno(r0) + } + return +} + +func getTcp6Table(tcpTable *byte, sizePointer *uint32, order bool) (errcode error) { + var _p0 uint32 + if order { + _p0 = 1 + } + r0, _, _ := syscall.Syscall(procGetTcp6Table.Addr(), 3, uintptr(unsafe.Pointer(tcpTable)), uintptr(unsafe.Pointer(sizePointer)), uintptr(_p0)) + if r0 != 0 { + errcode = syscall.Errno(r0) + } + return +} + +func getTcpTable(tcpTable *byte, sizePointer *uint32, order bool) (errcode error) { + var _p0 uint32 + if order { + _p0 = 1 + } + r0, _, _ := syscall.Syscall(procGetTcpTable.Addr(), 3, uintptr(unsafe.Pointer(tcpTable)), uintptr(unsafe.Pointer(sizePointer)), uintptr(_p0)) + if r0 != 0 { + errcode = syscall.Errno(r0) + } + return +} + +func setPerTcp6ConnectionEStats(row *MibTcp6Row, estatsType uint32, rw uintptr, rwVersion uint64, rwSize uint64, offset uint64) (errcode error) { + r0, _, _ := syscall.Syscall6(procSetPerTcp6ConnectionEStats.Addr(), 6, uintptr(unsafe.Pointer(row)), uintptr(estatsType), uintptr(rw), uintptr(rwVersion), uintptr(rwSize), uintptr(offset)) + if r0 != 0 { + errcode = syscall.Errno(r0) + } + return +} + +func setPerTcpConnectionEStats(row *MibTcpRow, estatsType uint32, rw uintptr, rwVersion uint64, rwSize uint64, offset uint64) (errcode error) { + r0, _, _ := syscall.Syscall6(procSetPerTcpConnectionEStats.Addr(), 6, uintptr(unsafe.Pointer(row)), uintptr(estatsType), uintptr(rw), uintptr(rwVersion), uintptr(rwSize), uintptr(offset)) + if r0 != 0 { + errcode = syscall.Errno(r0) + } + return +} diff --git a/contrab/freelru/cache.go b/contrab/freelru/cache.go index 319e662..22488e0 100644 --- a/contrab/freelru/cache.go +++ b/contrab/freelru/cache.go @@ -49,12 +49,14 @@ type Cache[K comparable, V comparable] interface { GetWithLifetime(key K) (V, time.Time, bool) + GetWithLifetimeNoExpire(key K) (V, time.Time, bool) + // GetAndRefresh returns the value associated with the key, setting it as the most // recently used item. // The lifetime of the found cache item is refreshed, even if it was already expired. GetAndRefresh(key K) (V, bool) - GetAndRefreshOrAdd(key K, constructor func() (V, bool)) (V, bool) + GetAndRefreshOrAdd(key K, constructor func() (V, bool)) (V, bool, bool) // Peek looks up a key's value from the cache, without changing its recent-ness. // If the found entry is already expired, the evict function is called. diff --git a/contrab/freelru/lru.go b/contrab/freelru/lru.go index 55e8937..055057c 100644 --- a/contrab/freelru/lru.go +++ b/contrab/freelru/lru.go @@ -500,6 +500,24 @@ func (lru *LRU[K, V]) getWithLifetime(hash uint32, key K) (value V, lifetime tim return } +func (lru *LRU[K, V]) GetWithLifetimeNoExpire(key K) (value V, lifetime time.Time, ok bool) { + return lru.getWithLifetimeNoExpire(lru.hash(key), key) +} + +func (lru *LRU[K, V]) getWithLifetimeNoExpire(hash uint32, key K) (value V, lifetime time.Time, ok bool) { + if pos, ok := lru.findKeyNoExpire(hash, key); ok { + if pos != lru.head { + lru.unlinkElement(pos) + lru.setHead(pos) + } + lru.metrics.Hits++ + return lru.elements[pos].value, time.UnixMilli(lru.elements[pos].expire), ok + } + + lru.metrics.Misses++ + return +} + // GetAndRefresh returns the value associated with the key, setting it as the most // recently used item. // The lifetime of the found cache item is refreshed, even if it was already expired. @@ -522,11 +540,15 @@ func (lru *LRU[K, V]) getAndRefresh(hash uint32, key K) (value V, ok bool) { return } -func (lru *LRU[K, V]) GetAndRefreshOrAdd(key K, constructor func() (V, bool)) (V, bool) { - return lru.getAndRefreshOrAdd(lru.hash(key), key, constructor) +func (lru *LRU[K, V]) GetAndRefreshOrAdd(key K, constructor func() (V, bool)) (V, bool, bool) { + value, updated, ok := lru.getAndRefreshOrAdd(lru.hash(key), key, constructor) + if !updated && ok { + lru.PurgeExpired() + } + return value, updated, ok } -func (lru *LRU[K, V]) getAndRefreshOrAdd(hash uint32, key K, constructor func() (V, bool)) (value V, ok bool) { +func (lru *LRU[K, V]) getAndRefreshOrAdd(hash uint32, key K, constructor func() (V, bool)) (value V, updated bool, ok bool) { if pos, ok := lru.findKeyNoExpire(hash, key); ok { if pos != lru.head { lru.unlinkElement(pos) @@ -534,17 +556,15 @@ func (lru *LRU[K, V]) getAndRefreshOrAdd(hash uint32, key K, constructor func() } lru.metrics.Hits++ lru.elements[pos].expire = expire(lru.lifetime) - return lru.elements[pos].value, ok + return lru.elements[pos].value, true, true } - lru.metrics.Misses++ value, ok = constructor() if !ok { return } lru.addWithLifetime(hash, key, value, lru.lifetime) - lru.PurgeExpired() - return value, false + return value, false, true } // Peek looks up a key's value from the cache, without changing its recent-ness. diff --git a/contrab/freelru/shardedlru.go b/contrab/freelru/shardedlru.go index e6aca65..fa325d7 100644 --- a/contrab/freelru/shardedlru.go +++ b/contrab/freelru/shardedlru.go @@ -187,6 +187,17 @@ func (lru *ShardedLRU[K, V]) GetWithLifetime(key K) (value V, lifetime time.Time return } +func (lru *ShardedLRU[K, V]) GetWithLifetimeNoExpire(key K) (value V, lifetime time.Time, ok bool) { + hash := lru.hash(key) + shard := (hash >> 16) & lru.mask + + lru.mus[shard].RLock() + value, lifetime, ok = lru.lrus[shard].getWithLifetimeNoExpire(hash, key) + lru.mus[shard].RUnlock() + + return +} + // GetAndRefresh returns the value associated with the key, setting it as the most // recently used item. // The lifetime of the found cache item is refreshed, even if it was already expired. @@ -201,14 +212,17 @@ func (lru *ShardedLRU[K, V]) GetAndRefresh(key K) (value V, ok bool) { return } -func (lru *ShardedLRU[K, V]) GetAndRefreshOrAdd(key K, constructor func() (V, bool)) (value V, updated bool) { +func (lru *ShardedLRU[K, V]) GetAndRefreshOrAdd(key K, constructor func() (V, bool)) (value V, updated bool, ok bool) { hash := lru.hash(key) shard := (hash >> 16) & lru.mask lru.mus[shard].Lock() - value, updated = lru.lrus[shard].getAndRefreshOrAdd(hash, key, constructor) + value, updated, ok = lru.lrus[shard].getAndRefreshOrAdd(hash, key, constructor) lru.mus[shard].Unlock() + if !updated && ok { + lru.PurgeExpired() + } return } diff --git a/contrab/freelru/syncedlru.go b/contrab/freelru/syncedlru.go index d42da70..4364907 100644 --- a/contrab/freelru/syncedlru.go +++ b/contrab/freelru/syncedlru.go @@ -108,6 +108,16 @@ func (lru *SyncedLRU[K, V]) GetWithLifetime(key K) (value V, lifetime time.Time, return } +func (lru *SyncedLRU[K, V]) GetWithLifetimeNoExpire(key K) (value V, lifetime time.Time, ok bool) { + hash := lru.lru.hash(key) + + lru.mu.Lock() + value, lifetime, ok = lru.lru.getWithLifetimeNoExpire(hash, key) + lru.mu.Unlock() + + return +} + // GetAndRefresh returns the value associated with the key, setting it as the most // recently used item. // The lifetime of the found cache item is refreshed, even if it was already expired. @@ -121,11 +131,14 @@ func (lru *SyncedLRU[K, V]) GetAndRefresh(key K) (value V, ok bool) { return } -func (lru *SyncedLRU[K, V]) GetAndRefreshOrAdd(key K, constructor func() (V, bool)) (value V, updated bool) { +func (lru *SyncedLRU[K, V]) GetAndRefreshOrAdd(key K, constructor func() (V, bool)) (value V, updated bool, ok bool) { hash := lru.lru.hash(key) lru.mu.Lock() - value, updated = lru.lru.getAndRefreshOrAdd(hash, key, constructor) + value, updated, ok = lru.lru.getAndRefreshOrAdd(hash, key, constructor) + if !updated && ok { + lru.lru.PurgeExpired() + } lru.mu.Unlock() return diff --git a/protocol/socks/handshake.go b/protocol/socks/handshake.go index 8a0ff86..dc9c057 100644 --- a/protocol/socks/handshake.go +++ b/protocol/socks/handshake.go @@ -10,6 +10,7 @@ import ( "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/auth" + "github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/bufio" E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" @@ -24,6 +25,10 @@ type HandlerEx interface { N.UDPConnectionHandlerEx } +type PacketListener interface { + ListenPacket(listenConfig net.ListenConfig, ctx context.Context, network string, address string) (net.PacketConn, error) +} + func ClientHandshake4(conn io.ReadWriter, command byte, destination M.Socksaddr, username string) (socks4.Response, error) { err := socks4.WriteRequest(conn, socks4.Request{ Command: command, @@ -120,6 +125,8 @@ func HandleConnectionEx( ctx context.Context, conn net.Conn, reader *std_bufio.Reader, authenticator *auth.Authenticator, handler HandlerEx, + packetListener PacketListener, + // resolver TorResolver, source M.Socksaddr, onClose N.CloseHandlerFunc, ) error { @@ -147,6 +154,11 @@ func HandleConnectionEx( } handler.NewConnectionEx(auth.ContextWithUser(ctx, request.Username), NewLazyConn(conn, version), source, request.Destination, onClose) return nil + /*case CommandTorResolve, CommandTorResolvePTR: + if resolver == nil { + return E.New("socks4: torsocks: commands not implemented") + } + return handleTorSocks4(ctx, conn, request, resolver)*/ default: err = socks4.WriteResponse(conn, socks4.Response{ ReplyCode: socks4.ReplyCodeRejectedOrFailed, @@ -213,12 +225,40 @@ func HandleConnectionEx( handler.NewConnectionEx(ctx, NewLazyConn(conn, version), source, request.Destination, onClose) return nil case socks5.CommandUDPAssociate: - var udpConn *net.UDPConn - udpConn, err = net.ListenUDP(M.NetworkFromNetAddr("udp", M.AddrFromNet(conn.LocalAddr())), net.UDPAddrFromAddrPort(netip.AddrPortFrom(M.AddrFromNet(conn.LocalAddr()), 0))) - if err != nil { - return err + var ( + listenConfig net.ListenConfig + udpConn net.PacketConn + ) + if packetListener != nil { + udpConn, err = packetListener.ListenPacket(listenConfig, ctx, M.NetworkFromNetAddr("udp", M.AddrFromNet(conn.LocalAddr())), M.SocksaddrFrom(M.AddrFromNet(conn.LocalAddr()), 0).String()) + } else { + udpConn, err = listenConfig.ListenPacket(ctx, M.NetworkFromNetAddr("udp", M.AddrFromNet(conn.LocalAddr())), M.SocksaddrFrom(M.AddrFromNet(conn.LocalAddr()), 0).String()) } - handler.NewPacketConnectionEx(ctx, NewLazyAssociatePacketConn(bufio.NewServerPacketConn(udpConn), conn), source, M.Socksaddr{}, onClose) + if err != nil { + return E.Cause(err, "socks5: listen udp") + } + err = socks5.WriteResponse(conn, socks5.Response{ + ReplyCode: socks5.ReplyCodeSuccess, + Bind: M.SocksaddrFromNet(udpConn.LocalAddr()), + }) + if err != nil { + return E.Cause(err, "socks5: write response") + } + var socksPacketConn N.PacketConn = NewAssociatePacketConn(bufio.NewServerPacketConn(udpConn), M.Socksaddr{}, conn) + firstPacket := buf.NewPacket() + var destination M.Socksaddr + destination, err = socksPacketConn.ReadPacket(firstPacket) + if err != nil { + return E.Cause(err, "socks5: read first packet") + } + socksPacketConn = bufio.NewCachedPacketConn(socksPacketConn, firstPacket, destination) + handler.NewPacketConnectionEx(ctx, socksPacketConn, source, destination, onClose) + return nil + /*case CommandTorResolve, CommandTorResolvePTR: + if resolver == nil { + return E.New("socks4: torsocks: commands not implemented") + } + return handleTorSocks5(ctx, conn, request, resolver)*/ default: err = socks5.WriteResponse(conn, socks5.Response{ ReplyCode: socks5.ReplyCodeUnsupported, diff --git a/protocol/socks/handshake_tor.go b/protocol/socks/handshake_tor.go new file mode 100644 index 0000000..5d66322 --- /dev/null +++ b/protocol/socks/handshake_tor.go @@ -0,0 +1,146 @@ +package socks + +import ( + "context" + "net" + "net/netip" + "os" + "strings" + + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + "github.com/sagernet/sing/protocol/socks/socks4" + "github.com/sagernet/sing/protocol/socks/socks5" +) + +const ( + CommandTorResolve byte = 0xF0 + CommandTorResolvePTR byte = 0xF1 +) + +type TorResolver interface { + LookupIP(ctx context.Context, host string) (netip.Addr, error) + LookupPTR(ctx context.Context, addr netip.Addr) (string, error) +} + +func handleTorSocks4(ctx context.Context, conn net.Conn, request socks4.Request, resolver TorResolver) error { + switch request.Command { + case CommandTorResolve: + if !request.Destination.IsFqdn() { + return E.New("socks4: torsocks: invalid destination") + } + ipAddr, err := resolver.LookupIP(ctx, request.Destination.Fqdn) + if err != nil { + err = socks4.WriteResponse(conn, socks4.Response{ + ReplyCode: socks4.ReplyCodeRejectedOrFailed, + }) + if err != nil { + return err + } + return E.Cause(err, "socks4: torsocks: lookup failed for domain: ", request.Destination.Fqdn) + } + err = socks4.WriteResponse(conn, socks4.Response{ + ReplyCode: socks4.ReplyCodeGranted, + Destination: M.SocksaddrFrom(ipAddr, 0), + }) + if err != nil { + return E.Cause(err, "socks4: torsocks: write response") + } + return nil + case CommandTorResolvePTR: + var ipAddr netip.Addr + if request.Destination.IsIP() { + ipAddr = request.Destination.Addr + } else if strings.HasSuffix(request.Destination.Fqdn, ".in-addr.arpa") { + ipAddr, _ = netip.ParseAddr(request.Destination.Fqdn[:len(request.Destination.Fqdn)-len(".in-addr.arpa")]) + } else if strings.HasSuffix(request.Destination.Fqdn, ".ip6.arpa") { + ipAddr, _ = netip.ParseAddr(strings.ReplaceAll(request.Destination.Fqdn[:len(request.Destination.Fqdn)-len(".ip6.arpa")], ".", ":")) + } + if !ipAddr.IsValid() { + return E.New("socks4: torsocks: invalid destination") + } + host, err := resolver.LookupPTR(ctx, ipAddr) + if err != nil { + err = socks4.WriteResponse(conn, socks4.Response{ + ReplyCode: socks4.ReplyCodeRejectedOrFailed, + }) + if err != nil { + return err + } + return E.Cause(err, "socks4: torsocks: lookup PTR failed for ip: ", ipAddr) + } + err = socks4.WriteResponse(conn, socks4.Response{ + ReplyCode: socks4.ReplyCodeGranted, + Destination: M.Socksaddr{ + Fqdn: host, + }, + }) + if err != nil { + return E.Cause(err, "socks4: torsocks: write response") + } + return nil + default: + return os.ErrInvalid + } +} + +func handleTorSocks5(ctx context.Context, conn net.Conn, request socks5.Request, resolver TorResolver) error { + switch request.Command { + case CommandTorResolve: + if !request.Destination.IsFqdn() { + return E.New("socks5: torsocks: invalid destination") + } + ipAddr, err := resolver.LookupIP(ctx, request.Destination.Fqdn) + if err != nil { + err = socks5.WriteResponse(conn, socks5.Response{ + ReplyCode: socks5.ReplyCodeFailure, + }) + if err != nil { + return err + } + return E.Cause(err, "socks5: torsocks: lookup failed for domain: ", request.Destination.Fqdn) + } + err = socks5.WriteResponse(conn, socks5.Response{ + ReplyCode: socks5.ReplyCodeSuccess, + Bind: M.SocksaddrFrom(ipAddr, 0), + }) + if err != nil { + return E.Cause(err, "socks5: torsocks: write response") + } + return nil + case CommandTorResolvePTR: + var ipAddr netip.Addr + if request.Destination.IsIP() { + ipAddr = request.Destination.Addr + } else if strings.HasSuffix(request.Destination.Fqdn, ".in-addr.arpa") { + ipAddr, _ = netip.ParseAddr(request.Destination.Fqdn[:len(request.Destination.Fqdn)-len(".in-addr.arpa")]) + } else if strings.HasSuffix(request.Destination.Fqdn, ".ip6.arpa") { + ipAddr, _ = netip.ParseAddr(strings.ReplaceAll(request.Destination.Fqdn[:len(request.Destination.Fqdn)-len(".ip6.arpa")], ".", ":")) + } + if !ipAddr.IsValid() { + return E.New("socks5: torsocks: invalid destination") + } + host, err := resolver.LookupPTR(ctx, ipAddr) + if err != nil { + err = socks5.WriteResponse(conn, socks5.Response{ + ReplyCode: socks5.ReplyCodeFailure, + }) + if err != nil { + return err + } + return E.Cause(err, "socks5: torsocks: lookup PTR failed for ip: ", ipAddr) + } + err = socks5.WriteResponse(conn, socks5.Response{ + ReplyCode: socks5.ReplyCodeSuccess, + Bind: M.Socksaddr{ + Fqdn: host, + }, + }) + if err != nil { + return E.Cause(err, "socks5: torsocks: write response") + } + return nil + default: + return os.ErrInvalid + } +}