From 6f4c52841d2a60ecf687e3939a631d44f5b74d44 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Wed, 29 Jan 2025 19:59:59 +0800 Subject: [PATCH 1/4] Add winiphlpapi --- common/windnsapi/dnsapi_test.go | 6 +- common/winiphlpapi/helper.go | 217 +++++++++++++++++ common/winiphlpapi/iphlpapi.go | 313 +++++++++++++++++++++++++ common/winiphlpapi/iphlpapi_test.go | 90 +++++++ common/winiphlpapi/syscall_windows.go | 27 +++ common/winiphlpapi/zsyscall_windows.go | 131 +++++++++++ 6 files changed, 780 insertions(+), 4 deletions(-) create mode 100644 common/winiphlpapi/helper.go create mode 100644 common/winiphlpapi/iphlpapi.go create mode 100644 common/winiphlpapi/iphlpapi_test.go create mode 100644 common/winiphlpapi/syscall_windows.go create mode 100644 common/winiphlpapi/zsyscall_windows.go diff --git a/common/windnsapi/dnsapi_test.go b/common/windnsapi/dnsapi_test.go index adf582d..c5ea831 100644 --- a/common/windnsapi/dnsapi_test.go +++ b/common/windnsapi/dnsapi_test.go @@ -1,16 +1,14 @@ +//go:build windows + package windnsapi import ( - "runtime" "testing" "github.com/stretchr/testify/require" ) func TestDNSAPI(t *testing.T) { - if runtime.GOOS != "windows" { - t.SkipNow() - } t.Parallel() require.NoError(t, FlushResolverCache()) } diff --git a/common/winiphlpapi/helper.go b/common/winiphlpapi/helper.go new file mode 100644 index 0000000..6bd4e8f --- /dev/null +++ b/common/winiphlpapi/helper.go @@ -0,0 +1,217 @@ +//go:build windows + +package winiphlpapi + +import ( + "context" + "encoding/binary" + "net" + "net/netip" + "os" + "time" + "unsafe" + + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" +) + +func LoadEStats() error { + err := modiphlpapi.Load() + if err != nil { + return err + } + err = procGetTcpTable.Find() + if err != nil { + return err + } + err = procGetTcp6Table.Find() + if err != nil { + return err + } + err = procGetPerTcp6ConnectionEStats.Find() + if err != nil { + return err + } + err = procGetPerTcp6ConnectionEStats.Find() + if err != nil { + return err + } + err = procSetPerTcpConnectionEStats.Find() + if err != nil { + return err + } + err = procSetPerTcp6ConnectionEStats.Find() + if err != nil { + return err + } + return nil +} + +func LoadExtendedTable() error { + err := modiphlpapi.Load() + if err != nil { + return err + } + err = procGetExtendedTcpTable.Find() + if err != nil { + return err + } + err = procGetExtendedUdpTable.Find() + if err != nil { + return err + } + return nil +} + +func FindPid(network string, source netip.AddrPort) (uint32, error) { + switch N.NetworkName(network) { + case N.NetworkTCP: + if source.Addr().Is4() { + tcpTable, err := GetExtendedTcpTable() + if err != nil { + return 0, err + } + for _, row := range tcpTable { + if source == netip.AddrPortFrom(DwordToAddr(row.DwLocalAddr), DwordToPort(row.DwLocalPort)) { + return row.DwOwningPid, nil + } + } + } else { + tcpTable, err := GetExtendedTcp6Table() + if err != nil { + return 0, err + } + for _, row := range tcpTable { + if source == netip.AddrPortFrom(netip.AddrFrom16(row.UcLocalAddr), DwordToPort(row.DwLocalPort)) { + return row.DwOwningPid, nil + } + } + } + case N.NetworkUDP: + if source.Addr().Is4() { + udpTable, err := GetExtendedUdpTable() + if err != nil { + return 0, err + } + for _, row := range udpTable { + if source == netip.AddrPortFrom(DwordToAddr(row.DwLocalAddr), DwordToPort(row.DwLocalPort)) { + return row.DwOwningPid, nil + } + } + } else { + udpTable, err := GetExtendedUdp6Table() + if err != nil { + return 0, err + } + for _, row := range udpTable { + if source == netip.AddrPortFrom(netip.AddrFrom16(row.UcLocalAddr), DwordToPort(row.DwLocalPort)) { + return row.DwOwningPid, nil + } + } + } + } + return 0, E.New("process not found for ", source) +} + +func WriteAndWaitAck(ctx context.Context, conn net.Conn, payload []byte) error { + source := M.AddrPortFromNet(conn.LocalAddr()) + destination := M.AddrPortFromNet(conn.RemoteAddr()) + if source.Addr().Is4() { + tcpTable, err := GetTcpTable() + if err != nil { + return err + } + var tcpRow *MibTcpRow + for _, row := range tcpTable { + if source == netip.AddrPortFrom(DwordToAddr(row.DwLocalAddr), DwordToPort(row.DwLocalPort)) || + destination == netip.AddrPortFrom(DwordToAddr(row.DwRemoteAddr), DwordToPort(row.DwRemotePort)) { + tcpRow = &row + break + } + } + if tcpRow == nil { + return E.New("row not found for: ", source) + } + err = SetPerTcpConnectionEStatsSendBuffer(tcpRow, &TcpEstatsSendBuffRwV0{ + EnableCollection: true, + }) + if err != nil { + return os.NewSyscallError("SetPerTcpConnectionEStatsSendBufferV0", err) + } + defer SetPerTcpConnectionEStatsSendBuffer(tcpRow, &TcpEstatsSendBuffRwV0{ + EnableCollection: false, + }) + _, err = conn.Write(payload) + if err != nil { + return err + } + for { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + eStstsSendBuffer, err := GetPerTcpConnectionEStatsSendBuffer(tcpRow) + if err != nil { + return err + } + if eStstsSendBuffer.CurRetxQueue == 0 { + return nil + } + time.Sleep(10 * time.Millisecond) + } + } else { + tcpTable, err := GetTcp6Table() + if err != nil { + return err + } + var tcpRow *MibTcp6Row + for _, row := range tcpTable { + if source == netip.AddrPortFrom(netip.AddrFrom16(row.LocalAddr), DwordToPort(row.LocalPort)) || + destination == netip.AddrPortFrom(netip.AddrFrom16(row.RemoteAddr), DwordToPort(row.RemotePort)) { + tcpRow = &row + break + } + } + if tcpRow == nil { + return E.New("row not found for: ", source) + } + err = SetPerTcp6ConnectionEStatsSendBuffer(tcpRow, &TcpEstatsSendBuffRwV0{ + EnableCollection: true, + }) + if err != nil { + return os.NewSyscallError("SetPerTcpConnectionEStatsSendBufferV0", err) + } + defer SetPerTcp6ConnectionEStatsSendBuffer(tcpRow, &TcpEstatsSendBuffRwV0{ + EnableCollection: false, + }) + _, err = conn.Write(payload) + if err != nil { + return err + } + for { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + eStstsSendBuffer, err := GetPerTcp6ConnectionEStatsSendBuffer(tcpRow) + if err != nil { + return err + } + if eStstsSendBuffer.CurRetxQueue == 0 { + return nil + } + time.Sleep(10 * time.Millisecond) + } + } +} + +func DwordToAddr(addr uint32) netip.Addr { + return netip.AddrFrom4(*(*[4]byte)(unsafe.Pointer(&addr))) +} + +func DwordToPort(dword uint32) uint16 { + return binary.BigEndian.Uint16((*[4]byte)(unsafe.Pointer(&dword))[:]) +} diff --git a/common/winiphlpapi/iphlpapi.go b/common/winiphlpapi/iphlpapi.go new file mode 100644 index 0000000..74e5b90 --- /dev/null +++ b/common/winiphlpapi/iphlpapi.go @@ -0,0 +1,313 @@ +//go:build windows + +package winiphlpapi + +import ( + "errors" + "os" + "unsafe" + + "golang.org/x/sys/windows" +) + +const ( + TcpTableBasicListener uint32 = iota + TcpTableBasicConnections + TcpTableBasicAll + TcpTableOwnerPidListener + TcpTableOwnerPidConnections + TcpTableOwnerPidAll + TcpTableOwnerModuleListener + TcpTableOwnerModuleConnections + TcpTableOwnerModuleAll +) + +const ( + UdpTableBasic uint32 = iota + UdpTableOwnerPid + UdpTableOwnerModule +) + +const ( + TcpConnectionEstatsSynOpts uint32 = iota + TcpConnectionEstatsData + TcpConnectionEstatsSndCong + TcpConnectionEstatsPath + TcpConnectionEstatsSendBuff + TcpConnectionEstatsRec + TcpConnectionEstatsObsRec + TcpConnectionEstatsBandwidth + TcpConnectionEstatsFineRtt + TcpConnectionEstatsMaximum +) + +type MibTcpTable struct { + DwNumEntries uint32 + Table [1]MibTcpRow +} + +type MibTcpRow struct { + DwState uint32 + DwLocalAddr uint32 + DwLocalPort uint32 + DwRemoteAddr uint32 + DwRemotePort uint32 +} + +type MibTcp6Table struct { + DwNumEntries uint32 + Table [1]MibTcp6Row +} + +type MibTcp6Row struct { + State uint32 + LocalAddr [16]byte + LocalScopeId uint32 + LocalPort uint32 + RemoteAddr [16]byte + RemoteScopeId uint32 + RemotePort uint32 +} + +type MibTcpTableOwnerPid struct { + DwNumEntries uint32 + Table [1]MibTcpRowOwnerPid +} + +type MibTcpRowOwnerPid struct { + DwState uint32 + DwLocalAddr uint32 + DwLocalPort uint32 + DwRemoteAddr uint32 + DwRemotePort uint32 + DwOwningPid uint32 +} + +type MibTcp6TableOwnerPid struct { + DwNumEntries uint32 + Table [1]MibTcp6RowOwnerPid +} + +type MibTcp6RowOwnerPid struct { + UcLocalAddr [16]byte + DwLocalScopeId uint32 + DwLocalPort uint32 + UcRemoteAddr [16]byte + DwRemoteScopeId uint32 + DwRemotePort uint32 + DwState uint32 + DwOwningPid uint32 +} + +type MibUdpTableOwnerPid struct { + DwNumEntries uint32 + Table [1]MibUdpRowOwnerPid +} + +type MibUdpRowOwnerPid struct { + DwLocalAddr uint32 + DwLocalPort uint32 + DwOwningPid uint32 +} + +type MibUdp6TableOwnerPid struct { + DwNumEntries uint32 + Table [1]MibUdp6RowOwnerPid +} + +type MibUdp6RowOwnerPid struct { + UcLocalAddr [16]byte + DwLocalScopeId uint32 + DwLocalPort uint32 + DwOwningPid uint32 +} + +type TcpEstatsSendBufferRodV0 struct { + CurRetxQueue uint64 + MaxRetxQueue uint64 + CurAppWQueue uint64 + MaxAppWQueue uint64 +} + +type TcpEstatsSendBuffRwV0 struct { + EnableCollection bool +} + +const ( + offsetOfMibTcpTable = unsafe.Offsetof(MibTcpTable{}.Table) + offsetOfMibTcp6Table = unsafe.Offsetof(MibTcp6Table{}.Table) + offsetOfMibTcpTableOwnerPid = unsafe.Offsetof(MibTcpTableOwnerPid{}.Table) + offsetOfMibTcp6TableOwnerPid = unsafe.Offsetof(MibTcpTableOwnerPid{}.Table) + offsetOfMibUdpTableOwnerPid = unsafe.Offsetof(MibUdpTableOwnerPid{}.Table) + offsetOfMibUdp6TableOwnerPid = unsafe.Offsetof(MibUdp6TableOwnerPid{}.Table) + sizeOfTcpEstatsSendBuffRwV0 = unsafe.Sizeof(TcpEstatsSendBuffRwV0{}) + sizeOfTcpEstatsSendBufferRodV0 = unsafe.Sizeof(TcpEstatsSendBufferRodV0{}) +) + +func GetTcpTable() ([]MibTcpRow, error) { + var size uint32 + err := getTcpTable(nil, &size, false) + if !errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) { + return nil, err + } + for { + table := make([]byte, size) + err = getTcpTable(&table[0], &size, false) + if err != nil { + if errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) { + continue + } + return nil, err + } + dwNumEntries := int(*(*uint32)(unsafe.Pointer(&table[0]))) + return unsafe.Slice((*MibTcpRow)(unsafe.Pointer(&table[offsetOfMibTcpTable])), dwNumEntries), nil + } +} + +func GetTcp6Table() ([]MibTcp6Row, error) { + var size uint32 + err := getTcp6Table(nil, &size, false) + if !errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) { + return nil, err + } + for { + table := make([]byte, size) + err = getTcp6Table(&table[0], &size, false) + if err != nil { + if errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) { + continue + } + return nil, err + } + dwNumEntries := int(*(*uint32)(unsafe.Pointer(&table[0]))) + return unsafe.Slice((*MibTcp6Row)(unsafe.Pointer(&table[offsetOfMibTcp6Table])), dwNumEntries), nil + } +} + +func GetExtendedTcpTable() ([]MibTcpRowOwnerPid, error) { + var size uint32 + err := getExtendedTcpTable(nil, &size, false, windows.AF_INET, TcpTableOwnerPidConnections, 0) + if !errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) { + return nil, os.NewSyscallError("GetExtendedTcpTable", err) + } + for { + table := make([]byte, size) + err = getExtendedTcpTable(&table[0], &size, false, windows.AF_INET, TcpTableOwnerPidConnections, 0) + if err != nil { + if errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) { + continue + } + return nil, os.NewSyscallError("GetExtendedTcpTable", err) + } + dwNumEntries := int(*(*uint32)(unsafe.Pointer(&table[0]))) + return unsafe.Slice((*MibTcpRowOwnerPid)(unsafe.Pointer(&table[offsetOfMibTcpTableOwnerPid])), dwNumEntries), nil + } +} + +func GetExtendedTcp6Table() ([]MibTcp6RowOwnerPid, error) { + var size uint32 + err := getExtendedTcpTable(nil, &size, false, windows.AF_INET6, TcpTableOwnerPidConnections, 0) + if !errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) { + return nil, os.NewSyscallError("GetExtendedTcpTable", err) + } + for { + table := make([]byte, size) + err = getExtendedTcpTable(&table[0], &size, false, windows.AF_INET6, TcpTableOwnerPidConnections, 0) + if err != nil { + if errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) { + continue + } + return nil, os.NewSyscallError("GetExtendedTcpTable", err) + } + dwNumEntries := int(*(*uint32)(unsafe.Pointer(&table[0]))) + return unsafe.Slice((*MibTcp6RowOwnerPid)(unsafe.Pointer(&table[offsetOfMibTcp6TableOwnerPid])), dwNumEntries), nil + } +} + +func GetExtendedUdpTable() ([]MibUdpRowOwnerPid, error) { + var size uint32 + err := getExtendedUdpTable(nil, &size, false, windows.AF_INET, UdpTableOwnerPid, 0) + if !errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) { + return nil, os.NewSyscallError("GetExtendedUdpTable", err) + } + for { + table := make([]byte, size) + err = getExtendedUdpTable(&table[0], &size, false, windows.AF_INET, UdpTableOwnerPid, 0) + if err != nil { + if errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) { + continue + } + return nil, os.NewSyscallError("GetExtendedUdpTable", err) + } + dwNumEntries := int(*(*uint32)(unsafe.Pointer(&table[0]))) + return unsafe.Slice((*MibUdpRowOwnerPid)(unsafe.Pointer(&table[offsetOfMibUdpTableOwnerPid])), dwNumEntries), nil + } +} + +func GetExtendedUdp6Table() ([]MibUdp6RowOwnerPid, error) { + var size uint32 + err := getExtendedUdpTable(nil, &size, false, windows.AF_INET6, UdpTableOwnerPid, 0) + if !errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) { + return nil, os.NewSyscallError("GetExtendedUdpTable", err) + } + for { + table := make([]byte, size) + err = getExtendedUdpTable(&table[0], &size, false, windows.AF_INET6, UdpTableOwnerPid, 0) + if err != nil { + if errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) { + continue + } + return nil, os.NewSyscallError("GetExtendedUdpTable", err) + } + dwNumEntries := int(*(*uint32)(unsafe.Pointer(&table[0]))) + return unsafe.Slice((*MibUdp6RowOwnerPid)(unsafe.Pointer(&table[offsetOfMibUdp6TableOwnerPid])), dwNumEntries), nil + } +} + +func GetPerTcpConnectionEStatsSendBuffer(row *MibTcpRow) (*TcpEstatsSendBufferRodV0, error) { + var rod TcpEstatsSendBufferRodV0 + err := getPerTcpConnectionEStats(row, + TcpConnectionEstatsSendBuff, + 0, + 0, + 0, + 0, + 0, + 0, + uintptr(unsafe.Pointer(&rod)), + 0, + uint64(sizeOfTcpEstatsSendBufferRodV0), + ) + if err != nil { + return nil, err + } + return &rod, nil +} + +func GetPerTcp6ConnectionEStatsSendBuffer(row *MibTcp6Row) (*TcpEstatsSendBufferRodV0, error) { + var rod TcpEstatsSendBufferRodV0 + err := getPerTcp6ConnectionEStats(row, + TcpConnectionEstatsSendBuff, + 0, + 0, + 0, + 0, + 0, + 0, + uintptr(unsafe.Pointer(&rod)), + 0, + uint64(sizeOfTcpEstatsSendBufferRodV0), + ) + if err != nil { + return nil, err + } + return &rod, nil +} + +func SetPerTcpConnectionEStatsSendBuffer(row *MibTcpRow, rw *TcpEstatsSendBuffRwV0) error { + return setPerTcpConnectionEStats(row, TcpConnectionEstatsSendBuff, uintptr(unsafe.Pointer(&rw)), 0, uint64(sizeOfTcpEstatsSendBuffRwV0), 0) +} + +func SetPerTcp6ConnectionEStatsSendBuffer(row *MibTcp6Row, rw *TcpEstatsSendBuffRwV0) error { + return setPerTcp6ConnectionEStats(row, TcpConnectionEstatsSendBuff, uintptr(unsafe.Pointer(&rw)), 0, uint64(sizeOfTcpEstatsSendBuffRwV0), 0) +} diff --git a/common/winiphlpapi/iphlpapi_test.go b/common/winiphlpapi/iphlpapi_test.go new file mode 100644 index 0000000..5fc3b74 --- /dev/null +++ b/common/winiphlpapi/iphlpapi_test.go @@ -0,0 +1,90 @@ +//go:build windows + +package winiphlpapi_test + +import ( + "context" + "net" + "syscall" + "testing" + + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/common/winiphlpapi" + + "github.com/stretchr/testify/require" +) + +func TestFindPidTcp4(t *testing.T) { + t.Parallel() + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer listener.Close() + go listener.Accept() + conn, err := net.Dial("tcp", listener.Addr().String()) + require.NoError(t, err) + defer conn.Close() + pid, err := winiphlpapi.FindPid(N.NetworkTCP, M.AddrPortFromNet(conn.LocalAddr())) + require.NoError(t, err) + require.Equal(t, uint32(syscall.Getpid()), pid) +} + +func TestFindPidTcp6(t *testing.T) { + t.Parallel() + listener, err := net.Listen("tcp", "[::1]:0") + require.NoError(t, err) + defer listener.Close() + go listener.Accept() + conn, err := net.Dial("tcp", listener.Addr().String()) + require.NoError(t, err) + defer conn.Close() + pid, err := winiphlpapi.FindPid(N.NetworkTCP, M.AddrPortFromNet(conn.LocalAddr())) + require.NoError(t, err) + require.Equal(t, uint32(syscall.Getpid()), pid) +} + +func TestFindPidUdp4(t *testing.T) { + t.Parallel() + conn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer conn.Close() + pid, err := winiphlpapi.FindPid(N.NetworkUDP, M.AddrPortFromNet(conn.LocalAddr())) + require.NoError(t, err) + require.Equal(t, uint32(syscall.Getpid()), pid) +} + +func TestFindPidUdp6(t *testing.T) { + t.Parallel() + conn, err := net.ListenPacket("udp", "[::1]:0") + require.NoError(t, err) + defer conn.Close() + pid, err := winiphlpapi.FindPid(N.NetworkUDP, M.AddrPortFromNet(conn.LocalAddr())) + require.NoError(t, err) + require.Equal(t, uint32(syscall.Getpid()), pid) +} + +func TestWaitAck4(t *testing.T) { + t.Parallel() + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer listener.Close() + go listener.Accept() + conn, err := net.Dial("tcp", listener.Addr().String()) + require.NoError(t, err) + defer conn.Close() + err = winiphlpapi.WriteAndWaitAck(context.Background(), conn, []byte("hello")) + require.NoError(t, err) +} + +func TestWaitAck6(t *testing.T) { + t.Parallel() + listener, err := net.Listen("tcp", "[::1]:0") + require.NoError(t, err) + defer listener.Close() + go listener.Accept() + conn, err := net.Dial("tcp", listener.Addr().String()) + require.NoError(t, err) + defer conn.Close() + err = winiphlpapi.WriteAndWaitAck(context.Background(), conn, []byte("hello")) + require.NoError(t, err) +} diff --git a/common/winiphlpapi/syscall_windows.go b/common/winiphlpapi/syscall_windows.go new file mode 100644 index 0000000..f6aab14 --- /dev/null +++ b/common/winiphlpapi/syscall_windows.go @@ -0,0 +1,27 @@ +package winiphlpapi + +//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go syscall_windows.go + +// https://learn.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-gettcptable +//sys getTcpTable(tcpTable *byte, sizePointer *uint32, order bool) (errcode error) = iphlpapi.GetTcpTable + +// https://learn.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-gettcp6table +//sys getTcp6Table(tcpTable *byte, sizePointer *uint32, order bool) (errcode error) = iphlpapi.GetTcp6Table + +// https://learn.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-getpertcpconnectionestats +//sys getPerTcpConnectionEStats(row *MibTcpRow, estatsType uint32, rw uintptr, rwVersion uint64, rwSize uint64, ros uintptr, rosVersion uint64, rosSize uint64, rod uintptr, rodVersion uint64, rodSize uint64) (errcode error) = iphlpapi.GetPerTcpConnectionEStats + +// https://learn.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-getpertcp6connectionestats +//sys getPerTcp6ConnectionEStats(row *MibTcp6Row, estatsType uint32, rw uintptr, rwVersion uint64, rwSize uint64, ros uintptr, rosVersion uint64, rosSize uint64, rod uintptr, rodVersion uint64, rodSize uint64) (errcode error) = iphlpapi.GetPerTcp6ConnectionEStats + +// https://learn.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-setpertcpconnectionestats +//sys setPerTcpConnectionEStats(row *MibTcpRow, estatsType uint32, rw uintptr, rwVersion uint64, rwSize uint64, offset uint64) (errcode error) = iphlpapi.SetPerTcpConnectionEStats + +// https://learn.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-setpertcp6connectionestats +//sys setPerTcp6ConnectionEStats(row *MibTcp6Row, estatsType uint32, rw uintptr, rwVersion uint64, rwSize uint64, offset uint64) (errcode error) = iphlpapi.SetPerTcp6ConnectionEStats + +// https://learn.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-getextendedtcptable +//sys getExtendedTcpTable(pTcpTable *byte, pdwSize *uint32, bOrder bool, ulAf uint64, tableClass uint32, reserved uint64) = (errcode error) = iphlpapi.GetExtendedTcpTable + +// https://learn.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-getextendedudptable +//sys getExtendedUdpTable(pUdpTable *byte, pdwSize *uint32, bOrder bool, ulAf uint64, tableClass uint32, reserved uint64) = (errcode error) = iphlpapi.GetExtendedUdpTable diff --git a/common/winiphlpapi/zsyscall_windows.go b/common/winiphlpapi/zsyscall_windows.go new file mode 100644 index 0000000..e5e9308 --- /dev/null +++ b/common/winiphlpapi/zsyscall_windows.go @@ -0,0 +1,131 @@ +// Code generated by 'go generate'; DO NOT EDIT. + +package winiphlpapi + +import ( + "syscall" + "unsafe" + + "golang.org/x/sys/windows" +) + +var _ unsafe.Pointer + +// Do the interface allocations only once for common +// Errno values. +const ( + errnoERROR_IO_PENDING = 997 +) + +var ( + errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING) + errERROR_EINVAL error = syscall.EINVAL +) + +// errnoErr returns common boxed Errno values, to prevent +// allocations at runtime. +func errnoErr(e syscall.Errno) error { + switch e { + case 0: + return errERROR_EINVAL + case errnoERROR_IO_PENDING: + return errERROR_IO_PENDING + } + // TODO: add more here, after collecting data on the common + // error values see on Windows. (perhaps when running + // all.bat?) + return e +} + +var ( + modiphlpapi = windows.NewLazySystemDLL("iphlpapi.dll") + + procGetExtendedTcpTable = modiphlpapi.NewProc("GetExtendedTcpTable") + procGetExtendedUdpTable = modiphlpapi.NewProc("GetExtendedUdpTable") + procGetPerTcp6ConnectionEStats = modiphlpapi.NewProc("GetPerTcp6ConnectionEStats") + procGetPerTcpConnectionEStats = modiphlpapi.NewProc("GetPerTcpConnectionEStats") + procGetTcp6Table = modiphlpapi.NewProc("GetTcp6Table") + procGetTcpTable = modiphlpapi.NewProc("GetTcpTable") + procSetPerTcp6ConnectionEStats = modiphlpapi.NewProc("SetPerTcp6ConnectionEStats") + procSetPerTcpConnectionEStats = modiphlpapi.NewProc("SetPerTcpConnectionEStats") +) + +func getExtendedTcpTable(pTcpTable *byte, pdwSize *uint32, bOrder bool, ulAf uint64, tableClass uint32, reserved uint64) (errcode error) { + var _p0 uint32 + if bOrder { + _p0 = 1 + } + r0, _, _ := syscall.Syscall6(procGetExtendedTcpTable.Addr(), 6, uintptr(unsafe.Pointer(pTcpTable)), uintptr(unsafe.Pointer(pdwSize)), uintptr(_p0), uintptr(ulAf), uintptr(tableClass), uintptr(reserved)) + if r0 != 0 { + errcode = syscall.Errno(r0) + } + return +} + +func getExtendedUdpTable(pUdpTable *byte, pdwSize *uint32, bOrder bool, ulAf uint64, tableClass uint32, reserved uint64) (errcode error) { + var _p0 uint32 + if bOrder { + _p0 = 1 + } + r0, _, _ := syscall.Syscall6(procGetExtendedUdpTable.Addr(), 6, uintptr(unsafe.Pointer(pUdpTable)), uintptr(unsafe.Pointer(pdwSize)), uintptr(_p0), uintptr(ulAf), uintptr(tableClass), uintptr(reserved)) + if r0 != 0 { + errcode = syscall.Errno(r0) + } + return +} + +func getPerTcp6ConnectionEStats(row *MibTcp6Row, estatsType uint32, rw uintptr, rwVersion uint64, rwSize uint64, ros uintptr, rosVersion uint64, rosSize uint64, rod uintptr, rodVersion uint64, rodSize uint64) (errcode error) { + r0, _, _ := syscall.Syscall12(procGetPerTcp6ConnectionEStats.Addr(), 11, uintptr(unsafe.Pointer(row)), uintptr(estatsType), uintptr(rw), uintptr(rwVersion), uintptr(rwSize), uintptr(ros), uintptr(rosVersion), uintptr(rosSize), uintptr(rod), uintptr(rodVersion), uintptr(rodSize), 0) + if r0 != 0 { + errcode = syscall.Errno(r0) + } + return +} + +func getPerTcpConnectionEStats(row *MibTcpRow, estatsType uint32, rw uintptr, rwVersion uint64, rwSize uint64, ros uintptr, rosVersion uint64, rosSize uint64, rod uintptr, rodVersion uint64, rodSize uint64) (errcode error) { + r0, _, _ := syscall.Syscall12(procGetPerTcpConnectionEStats.Addr(), 11, uintptr(unsafe.Pointer(row)), uintptr(estatsType), uintptr(rw), uintptr(rwVersion), uintptr(rwSize), uintptr(ros), uintptr(rosVersion), uintptr(rosSize), uintptr(rod), uintptr(rodVersion), uintptr(rodSize), 0) + if r0 != 0 { + errcode = syscall.Errno(r0) + } + return +} + +func getTcp6Table(tcpTable *byte, sizePointer *uint32, order bool) (errcode error) { + var _p0 uint32 + if order { + _p0 = 1 + } + r0, _, _ := syscall.Syscall(procGetTcp6Table.Addr(), 3, uintptr(unsafe.Pointer(tcpTable)), uintptr(unsafe.Pointer(sizePointer)), uintptr(_p0)) + if r0 != 0 { + errcode = syscall.Errno(r0) + } + return +} + +func getTcpTable(tcpTable *byte, sizePointer *uint32, order bool) (errcode error) { + var _p0 uint32 + if order { + _p0 = 1 + } + r0, _, _ := syscall.Syscall(procGetTcpTable.Addr(), 3, uintptr(unsafe.Pointer(tcpTable)), uintptr(unsafe.Pointer(sizePointer)), uintptr(_p0)) + if r0 != 0 { + errcode = syscall.Errno(r0) + } + return +} + +func setPerTcp6ConnectionEStats(row *MibTcp6Row, estatsType uint32, rw uintptr, rwVersion uint64, rwSize uint64, offset uint64) (errcode error) { + r0, _, _ := syscall.Syscall6(procSetPerTcp6ConnectionEStats.Addr(), 6, uintptr(unsafe.Pointer(row)), uintptr(estatsType), uintptr(rw), uintptr(rwVersion), uintptr(rwSize), uintptr(offset)) + if r0 != 0 { + errcode = syscall.Errno(r0) + } + return +} + +func setPerTcpConnectionEStats(row *MibTcpRow, estatsType uint32, rw uintptr, rwVersion uint64, rwSize uint64, offset uint64) (errcode error) { + r0, _, _ := syscall.Syscall6(procSetPerTcpConnectionEStats.Addr(), 6, uintptr(unsafe.Pointer(row)), uintptr(estatsType), uintptr(rw), uintptr(rwVersion), uintptr(rwSize), uintptr(offset)) + if r0 != 0 { + errcode = syscall.Errno(r0) + } + return +} From 27203c52d57d3097ecfce50d7793f05d2d81e570 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Mon, 10 Feb 2025 18:59:17 +0800 Subject: [PATCH 2/4] Fix merge objects --- common/json/badjson/merge_objects.go | 10 +++---- common/json/internal/contextjson/keys.go | 20 ++++++++++++++ common/json/internal/contextjson/keys_test.go | 26 +++++++++++++++++++ 3 files changed, 50 insertions(+), 6 deletions(-) create mode 100644 common/json/internal/contextjson/keys.go create mode 100644 common/json/internal/contextjson/keys_test.go diff --git a/common/json/badjson/merge_objects.go b/common/json/badjson/merge_objects.go index fa6c2d4..5b23209 100644 --- a/common/json/badjson/merge_objects.go +++ b/common/json/badjson/merge_objects.go @@ -2,9 +2,11 @@ package badjson import ( "context" + "reflect" E "github.com/sagernet/sing/common/exceptions" "github.com/sagernet/sing/common/json" + cJSON "github.com/sagernet/sing/common/json/internal/contextjson" ) func MarshallObjects(objects ...any) ([]byte, error) { @@ -31,16 +33,12 @@ func UnmarshallExcluded(inputContent []byte, parentObject any, object any) error } func UnmarshallExcludedContext(ctx context.Context, inputContent []byte, parentObject any, object any) error { - parentContent, err := newJSONObject(ctx, parentObject) - if err != nil { - return err - } var content JSONObject - err = content.UnmarshalJSONContext(ctx, inputContent) + err := content.UnmarshalJSONContext(ctx, inputContent) if err != nil { return err } - for _, key := range parentContent.Keys() { + for _, key := range cJSON.ObjectKeys(reflect.TypeOf(parentObject)) { content.Remove(key) } if object == nil { diff --git a/common/json/internal/contextjson/keys.go b/common/json/internal/contextjson/keys.go new file mode 100644 index 0000000..589007f --- /dev/null +++ b/common/json/internal/contextjson/keys.go @@ -0,0 +1,20 @@ +package json + +import ( + "reflect" + + "github.com/sagernet/sing/common" +) + +func ObjectKeys(object reflect.Type) []string { + switch object.Kind() { + case reflect.Pointer: + return ObjectKeys(object.Elem()) + case reflect.Struct: + default: + panic("invalid non-struct input") + } + return common.Map(cachedTypeFields(object).list, func(field field) string { + return field.name + }) +} diff --git a/common/json/internal/contextjson/keys_test.go b/common/json/internal/contextjson/keys_test.go new file mode 100644 index 0000000..11c72cb --- /dev/null +++ b/common/json/internal/contextjson/keys_test.go @@ -0,0 +1,26 @@ +package json_test + +import ( + "reflect" + "testing" + + json "github.com/sagernet/sing/common/json/internal/contextjson" + + "github.com/stretchr/testify/require" +) + +type MyObject struct { + Hello string `json:"hello,omitempty"` + MyWorld + MyWorld2 string `json:"-"` +} + +type MyWorld struct { + World string `json:"world,omitempty"` +} + +func TestObjectKeys(t *testing.T) { + t.Parallel() + keys := json.ObjectKeys(reflect.TypeOf(&MyObject{})) + require.Equal(t, []string{"hello", "world"}, keys) +} From 1c3b777fe509bde0b8c451c68a054a182f9a5ba2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Mon, 24 Feb 2025 18:25:10 +0800 Subject: [PATCH 3/4] Add freelru.GetWithLifetimeNoExpire --- contrab/freelru/cache.go | 2 ++ contrab/freelru/lru.go | 18 ++++++++++++++++++ contrab/freelru/shardedlru.go | 11 +++++++++++ contrab/freelru/syncedlru.go | 10 ++++++++++ 4 files changed, 41 insertions(+) diff --git a/contrab/freelru/cache.go b/contrab/freelru/cache.go index e1877fb..22488e0 100644 --- a/contrab/freelru/cache.go +++ b/contrab/freelru/cache.go @@ -49,6 +49,8 @@ type Cache[K comparable, V comparable] interface { GetWithLifetime(key K) (V, time.Time, bool) + GetWithLifetimeNoExpire(key K) (V, time.Time, bool) + // GetAndRefresh returns the value associated with the key, setting it as the most // recently used item. // The lifetime of the found cache item is refreshed, even if it was already expired. diff --git a/contrab/freelru/lru.go b/contrab/freelru/lru.go index 6c31857..055057c 100644 --- a/contrab/freelru/lru.go +++ b/contrab/freelru/lru.go @@ -500,6 +500,24 @@ func (lru *LRU[K, V]) getWithLifetime(hash uint32, key K) (value V, lifetime tim return } +func (lru *LRU[K, V]) GetWithLifetimeNoExpire(key K) (value V, lifetime time.Time, ok bool) { + return lru.getWithLifetimeNoExpire(lru.hash(key), key) +} + +func (lru *LRU[K, V]) getWithLifetimeNoExpire(hash uint32, key K) (value V, lifetime time.Time, ok bool) { + if pos, ok := lru.findKeyNoExpire(hash, key); ok { + if pos != lru.head { + lru.unlinkElement(pos) + lru.setHead(pos) + } + lru.metrics.Hits++ + return lru.elements[pos].value, time.UnixMilli(lru.elements[pos].expire), ok + } + + lru.metrics.Misses++ + return +} + // GetAndRefresh returns the value associated with the key, setting it as the most // recently used item. // The lifetime of the found cache item is refreshed, even if it was already expired. diff --git a/contrab/freelru/shardedlru.go b/contrab/freelru/shardedlru.go index db97efa..fa325d7 100644 --- a/contrab/freelru/shardedlru.go +++ b/contrab/freelru/shardedlru.go @@ -187,6 +187,17 @@ func (lru *ShardedLRU[K, V]) GetWithLifetime(key K) (value V, lifetime time.Time return } +func (lru *ShardedLRU[K, V]) GetWithLifetimeNoExpire(key K) (value V, lifetime time.Time, ok bool) { + hash := lru.hash(key) + shard := (hash >> 16) & lru.mask + + lru.mus[shard].RLock() + value, lifetime, ok = lru.lrus[shard].getWithLifetimeNoExpire(hash, key) + lru.mus[shard].RUnlock() + + return +} + // GetAndRefresh returns the value associated with the key, setting it as the most // recently used item. // The lifetime of the found cache item is refreshed, even if it was already expired. diff --git a/contrab/freelru/syncedlru.go b/contrab/freelru/syncedlru.go index 38a854a..4364907 100644 --- a/contrab/freelru/syncedlru.go +++ b/contrab/freelru/syncedlru.go @@ -108,6 +108,16 @@ func (lru *SyncedLRU[K, V]) GetWithLifetime(key K) (value V, lifetime time.Time, return } +func (lru *SyncedLRU[K, V]) GetWithLifetimeNoExpire(key K) (value V, lifetime time.Time, ok bool) { + hash := lru.lru.hash(key) + + lru.mu.Lock() + value, lifetime, ok = lru.lru.getWithLifetimeNoExpire(hash, key) + lru.mu.Unlock() + + return +} + // GetAndRefresh returns the value associated with the key, setting it as the most // recently used item. // The lifetime of the found cache item is refreshed, even if it was already expired. From 5704b6ce4d670cbadc29cf26c57fe0f01a144fac Mon Sep 17 00:00:00 2001 From: "renovate[bot]" <29139614+renovate[bot]@users.noreply.github.com> Date: Tue, 11 Mar 2025 17:40:15 +0000 Subject: [PATCH 4/4] [dependencies] Update dependency go to ~1.24.0 --- .github/workflows/test.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 3afe968..cf50653 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -41,7 +41,7 @@ jobs: - name: Setup Go uses: actions/setup-go@v5 with: - go-version: ~1.20 + go-version: ~1.24.0 continue-on-error: true - name: Build run: | @@ -57,7 +57,7 @@ jobs: - name: Setup Go uses: actions/setup-go@v5 with: - go-version: ~1.21 + go-version: ~1.24.0 continue-on-error: true - name: Build run: | @@ -73,7 +73,7 @@ jobs: - name: Setup Go uses: actions/setup-go@v5 with: - go-version: ~1.22 + go-version: ~1.24.0 continue-on-error: true - name: Build run: |