Compare commits

...

8 commits
v0.6.3 ... dev

Author SHA1 Message Date
世界
159e489fc3
Add E.Cause1 2025-04-03 18:26:45 +08:00
世界
d39c2c2fdd
socks: Add custom udp listener 2025-03-26 13:18:24 +08:00
世界
ea82ac275f
Add freelru.GetWithLifetimeNoExpire 2025-03-26 13:18:18 +08:00
世界
ea0ac932ae
Add winiphlpapi 2025-03-26 13:18:17 +08:00
世界
2b41455f5a
Fix udpnat2 handler again 2025-03-26 12:46:15 +08:00
世界
23b0180a1b
Fix crash on udpnat2 handler 2025-03-24 18:11:10 +08:00
世界
ce1b4851a4
Fix socks5 UDP 2025-03-16 10:23:29 +08:00
世界
2238a05966
Fix merge objects 2025-03-16 10:23:29 +08:00
19 changed files with 1089 additions and 21 deletions

View file

@ -12,3 +12,16 @@ func (e *causeError) Error() string {
func (e *causeError) Unwrap() error {
return e.cause
}
type causeError1 struct {
error
cause error
}
func (e *causeError1) Error() string {
return e.error.Error() + ": " + e.cause.Error()
}
func (e *causeError1) Unwrap() []error {
return []error{e.error, e.cause}
}

View file

@ -32,6 +32,13 @@ func Cause(cause error, message ...any) error {
return &causeError{F.ToString(message...), cause}
}
func Cause1(err error, cause error) error {
if cause == nil {
panic("cause on an nil error")
}
return &causeError1{err, cause}
}
func Extend(cause error, message ...any) error {
if cause == nil {
panic("extend on an nil error")

View file

@ -2,9 +2,11 @@ package badjson
import (
"context"
"reflect"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/json"
cJSON "github.com/sagernet/sing/common/json/internal/contextjson"
)
func MarshallObjects(objects ...any) ([]byte, error) {
@ -31,16 +33,12 @@ func UnmarshallExcluded(inputContent []byte, parentObject any, object any) error
}
func UnmarshallExcludedContext(ctx context.Context, inputContent []byte, parentObject any, object any) error {
parentContent, err := newJSONObject(ctx, parentObject)
if err != nil {
return err
}
var content JSONObject
err = content.UnmarshalJSONContext(ctx, inputContent)
err := content.UnmarshalJSONContext(ctx, inputContent)
if err != nil {
return err
}
for _, key := range parentContent.Keys() {
for _, key := range cJSON.ObjectKeys(reflect.TypeOf(parentObject)) {
content.Remove(key)
}
if object == nil {

View file

@ -0,0 +1,20 @@
package json
import (
"reflect"
"github.com/sagernet/sing/common"
)
func ObjectKeys(object reflect.Type) []string {
switch object.Kind() {
case reflect.Pointer:
return ObjectKeys(object.Elem())
case reflect.Struct:
default:
panic("invalid non-struct input")
}
return common.Map(cachedTypeFields(object).list, func(field field) string {
return field.name
})
}

View file

@ -0,0 +1,26 @@
package json_test
import (
"reflect"
"testing"
json "github.com/sagernet/sing/common/json/internal/contextjson"
"github.com/stretchr/testify/require"
)
type MyObject struct {
Hello string `json:"hello,omitempty"`
MyWorld
MyWorld2 string `json:"-"`
}
type MyWorld struct {
World string `json:"world,omitempty"`
}
func TestObjectKeys(t *testing.T) {
t.Parallel()
keys := json.ObjectKeys(reflect.TypeOf(&MyObject{}))
require.Equal(t, []string{"hello", "world"}, keys)
}

View file

@ -28,6 +28,7 @@ type natConn struct {
cache freelru.Cache[netip.AddrPort, *natConn]
writer N.PacketWriter
localAddr M.Socksaddr
handlerAccess sync.RWMutex
handler N.UDPHandlerEx
packetChan chan *N.PacketBuffer
closeOnce sync.Once
@ -75,12 +76,10 @@ func (c *natConn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr,
}
func (c *natConn) SetHandler(handler N.UDPHandlerEx) {
select {
case <-c.doneChan:
default:
}
c.handlerAccess.Lock()
c.handler = handler
c.readWaitOptions = N.NewReadWaitOptions(c.writer, handler)
c.handlerAccess.Unlock()
fetch:
for {
select {

View file

@ -74,8 +74,11 @@ func (s *Service) NewPacket(bufferSlices [][]byte, source M.Socksaddr, destinati
for _, bufferSlice := range bufferSlices {
buffer.Write(bufferSlice)
}
if conn.handler != nil {
conn.handler.NewPacketEx(buffer, destination)
conn.handlerAccess.RLock()
handler := conn.handler
conn.handlerAccess.RUnlock()
if handler != nil {
handler.NewPacketEx(buffer, destination)
return
}
packet := N.NewPacketBuffer()

View file

@ -1,16 +1,14 @@
//go:build windows
package windnsapi
import (
"runtime"
"testing"
"github.com/stretchr/testify/require"
)
func TestDNSAPI(t *testing.T) {
if runtime.GOOS != "windows" {
t.SkipNow()
}
t.Parallel()
require.NoError(t, FlushResolverCache())
}

View file

@ -0,0 +1,217 @@
//go:build windows
package winiphlpapi
import (
"context"
"encoding/binary"
"net"
"net/netip"
"os"
"time"
"unsafe"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
func LoadEStats() error {
err := modiphlpapi.Load()
if err != nil {
return err
}
err = procGetTcpTable.Find()
if err != nil {
return err
}
err = procGetTcp6Table.Find()
if err != nil {
return err
}
err = procGetPerTcp6ConnectionEStats.Find()
if err != nil {
return err
}
err = procGetPerTcp6ConnectionEStats.Find()
if err != nil {
return err
}
err = procSetPerTcpConnectionEStats.Find()
if err != nil {
return err
}
err = procSetPerTcp6ConnectionEStats.Find()
if err != nil {
return err
}
return nil
}
func LoadExtendedTable() error {
err := modiphlpapi.Load()
if err != nil {
return err
}
err = procGetExtendedTcpTable.Find()
if err != nil {
return err
}
err = procGetExtendedUdpTable.Find()
if err != nil {
return err
}
return nil
}
func FindPid(network string, source netip.AddrPort) (uint32, error) {
switch N.NetworkName(network) {
case N.NetworkTCP:
if source.Addr().Is4() {
tcpTable, err := GetExtendedTcpTable()
if err != nil {
return 0, err
}
for _, row := range tcpTable {
if source == netip.AddrPortFrom(DwordToAddr(row.DwLocalAddr), DwordToPort(row.DwLocalPort)) {
return row.DwOwningPid, nil
}
}
} else {
tcpTable, err := GetExtendedTcp6Table()
if err != nil {
return 0, err
}
for _, row := range tcpTable {
if source == netip.AddrPortFrom(netip.AddrFrom16(row.UcLocalAddr), DwordToPort(row.DwLocalPort)) {
return row.DwOwningPid, nil
}
}
}
case N.NetworkUDP:
if source.Addr().Is4() {
udpTable, err := GetExtendedUdpTable()
if err != nil {
return 0, err
}
for _, row := range udpTable {
if source == netip.AddrPortFrom(DwordToAddr(row.DwLocalAddr), DwordToPort(row.DwLocalPort)) {
return row.DwOwningPid, nil
}
}
} else {
udpTable, err := GetExtendedUdp6Table()
if err != nil {
return 0, err
}
for _, row := range udpTable {
if source == netip.AddrPortFrom(netip.AddrFrom16(row.UcLocalAddr), DwordToPort(row.DwLocalPort)) {
return row.DwOwningPid, nil
}
}
}
}
return 0, E.New("process not found for ", source)
}
func WriteAndWaitAck(ctx context.Context, conn net.Conn, payload []byte) error {
source := M.AddrPortFromNet(conn.LocalAddr())
destination := M.AddrPortFromNet(conn.RemoteAddr())
if source.Addr().Is4() {
tcpTable, err := GetTcpTable()
if err != nil {
return err
}
var tcpRow *MibTcpRow
for _, row := range tcpTable {
if source == netip.AddrPortFrom(DwordToAddr(row.DwLocalAddr), DwordToPort(row.DwLocalPort)) ||
destination == netip.AddrPortFrom(DwordToAddr(row.DwRemoteAddr), DwordToPort(row.DwRemotePort)) {
tcpRow = &row
break
}
}
if tcpRow == nil {
return E.New("row not found for: ", source)
}
err = SetPerTcpConnectionEStatsSendBuffer(tcpRow, &TcpEstatsSendBuffRwV0{
EnableCollection: true,
})
if err != nil {
return os.NewSyscallError("SetPerTcpConnectionEStatsSendBufferV0", err)
}
defer SetPerTcpConnectionEStatsSendBuffer(tcpRow, &TcpEstatsSendBuffRwV0{
EnableCollection: false,
})
_, err = conn.Write(payload)
if err != nil {
return err
}
for {
select {
case <-ctx.Done():
return ctx.Err()
default:
}
eStstsSendBuffer, err := GetPerTcpConnectionEStatsSendBuffer(tcpRow)
if err != nil {
return err
}
if eStstsSendBuffer.CurRetxQueue == 0 {
return nil
}
time.Sleep(10 * time.Millisecond)
}
} else {
tcpTable, err := GetTcp6Table()
if err != nil {
return err
}
var tcpRow *MibTcp6Row
for _, row := range tcpTable {
if source == netip.AddrPortFrom(netip.AddrFrom16(row.LocalAddr), DwordToPort(row.LocalPort)) ||
destination == netip.AddrPortFrom(netip.AddrFrom16(row.RemoteAddr), DwordToPort(row.RemotePort)) {
tcpRow = &row
break
}
}
if tcpRow == nil {
return E.New("row not found for: ", source)
}
err = SetPerTcp6ConnectionEStatsSendBuffer(tcpRow, &TcpEstatsSendBuffRwV0{
EnableCollection: true,
})
if err != nil {
return os.NewSyscallError("SetPerTcpConnectionEStatsSendBufferV0", err)
}
defer SetPerTcp6ConnectionEStatsSendBuffer(tcpRow, &TcpEstatsSendBuffRwV0{
EnableCollection: false,
})
_, err = conn.Write(payload)
if err != nil {
return err
}
for {
select {
case <-ctx.Done():
return ctx.Err()
default:
}
eStstsSendBuffer, err := GetPerTcp6ConnectionEStatsSendBuffer(tcpRow)
if err != nil {
return err
}
if eStstsSendBuffer.CurRetxQueue == 0 {
return nil
}
time.Sleep(10 * time.Millisecond)
}
}
}
func DwordToAddr(addr uint32) netip.Addr {
return netip.AddrFrom4(*(*[4]byte)(unsafe.Pointer(&addr)))
}
func DwordToPort(dword uint32) uint16 {
return binary.BigEndian.Uint16((*[4]byte)(unsafe.Pointer(&dword))[:])
}

View file

@ -0,0 +1,313 @@
//go:build windows
package winiphlpapi
import (
"errors"
"os"
"unsafe"
"golang.org/x/sys/windows"
)
const (
TcpTableBasicListener uint32 = iota
TcpTableBasicConnections
TcpTableBasicAll
TcpTableOwnerPidListener
TcpTableOwnerPidConnections
TcpTableOwnerPidAll
TcpTableOwnerModuleListener
TcpTableOwnerModuleConnections
TcpTableOwnerModuleAll
)
const (
UdpTableBasic uint32 = iota
UdpTableOwnerPid
UdpTableOwnerModule
)
const (
TcpConnectionEstatsSynOpts uint32 = iota
TcpConnectionEstatsData
TcpConnectionEstatsSndCong
TcpConnectionEstatsPath
TcpConnectionEstatsSendBuff
TcpConnectionEstatsRec
TcpConnectionEstatsObsRec
TcpConnectionEstatsBandwidth
TcpConnectionEstatsFineRtt
TcpConnectionEstatsMaximum
)
type MibTcpTable struct {
DwNumEntries uint32
Table [1]MibTcpRow
}
type MibTcpRow struct {
DwState uint32
DwLocalAddr uint32
DwLocalPort uint32
DwRemoteAddr uint32
DwRemotePort uint32
}
type MibTcp6Table struct {
DwNumEntries uint32
Table [1]MibTcp6Row
}
type MibTcp6Row struct {
State uint32
LocalAddr [16]byte
LocalScopeId uint32
LocalPort uint32
RemoteAddr [16]byte
RemoteScopeId uint32
RemotePort uint32
}
type MibTcpTableOwnerPid struct {
DwNumEntries uint32
Table [1]MibTcpRowOwnerPid
}
type MibTcpRowOwnerPid struct {
DwState uint32
DwLocalAddr uint32
DwLocalPort uint32
DwRemoteAddr uint32
DwRemotePort uint32
DwOwningPid uint32
}
type MibTcp6TableOwnerPid struct {
DwNumEntries uint32
Table [1]MibTcp6RowOwnerPid
}
type MibTcp6RowOwnerPid struct {
UcLocalAddr [16]byte
DwLocalScopeId uint32
DwLocalPort uint32
UcRemoteAddr [16]byte
DwRemoteScopeId uint32
DwRemotePort uint32
DwState uint32
DwOwningPid uint32
}
type MibUdpTableOwnerPid struct {
DwNumEntries uint32
Table [1]MibUdpRowOwnerPid
}
type MibUdpRowOwnerPid struct {
DwLocalAddr uint32
DwLocalPort uint32
DwOwningPid uint32
}
type MibUdp6TableOwnerPid struct {
DwNumEntries uint32
Table [1]MibUdp6RowOwnerPid
}
type MibUdp6RowOwnerPid struct {
UcLocalAddr [16]byte
DwLocalScopeId uint32
DwLocalPort uint32
DwOwningPid uint32
}
type TcpEstatsSendBufferRodV0 struct {
CurRetxQueue uint64
MaxRetxQueue uint64
CurAppWQueue uint64
MaxAppWQueue uint64
}
type TcpEstatsSendBuffRwV0 struct {
EnableCollection bool
}
const (
offsetOfMibTcpTable = unsafe.Offsetof(MibTcpTable{}.Table)
offsetOfMibTcp6Table = unsafe.Offsetof(MibTcp6Table{}.Table)
offsetOfMibTcpTableOwnerPid = unsafe.Offsetof(MibTcpTableOwnerPid{}.Table)
offsetOfMibTcp6TableOwnerPid = unsafe.Offsetof(MibTcpTableOwnerPid{}.Table)
offsetOfMibUdpTableOwnerPid = unsafe.Offsetof(MibUdpTableOwnerPid{}.Table)
offsetOfMibUdp6TableOwnerPid = unsafe.Offsetof(MibUdp6TableOwnerPid{}.Table)
sizeOfTcpEstatsSendBuffRwV0 = unsafe.Sizeof(TcpEstatsSendBuffRwV0{})
sizeOfTcpEstatsSendBufferRodV0 = unsafe.Sizeof(TcpEstatsSendBufferRodV0{})
)
func GetTcpTable() ([]MibTcpRow, error) {
var size uint32
err := getTcpTable(nil, &size, false)
if !errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) {
return nil, err
}
for {
table := make([]byte, size)
err = getTcpTable(&table[0], &size, false)
if err != nil {
if errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) {
continue
}
return nil, err
}
dwNumEntries := int(*(*uint32)(unsafe.Pointer(&table[0])))
return unsafe.Slice((*MibTcpRow)(unsafe.Pointer(&table[offsetOfMibTcpTable])), dwNumEntries), nil
}
}
func GetTcp6Table() ([]MibTcp6Row, error) {
var size uint32
err := getTcp6Table(nil, &size, false)
if !errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) {
return nil, err
}
for {
table := make([]byte, size)
err = getTcp6Table(&table[0], &size, false)
if err != nil {
if errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) {
continue
}
return nil, err
}
dwNumEntries := int(*(*uint32)(unsafe.Pointer(&table[0])))
return unsafe.Slice((*MibTcp6Row)(unsafe.Pointer(&table[offsetOfMibTcp6Table])), dwNumEntries), nil
}
}
func GetExtendedTcpTable() ([]MibTcpRowOwnerPid, error) {
var size uint32
err := getExtendedTcpTable(nil, &size, false, windows.AF_INET, TcpTableOwnerPidConnections, 0)
if !errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) {
return nil, os.NewSyscallError("GetExtendedTcpTable", err)
}
for {
table := make([]byte, size)
err = getExtendedTcpTable(&table[0], &size, false, windows.AF_INET, TcpTableOwnerPidConnections, 0)
if err != nil {
if errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) {
continue
}
return nil, os.NewSyscallError("GetExtendedTcpTable", err)
}
dwNumEntries := int(*(*uint32)(unsafe.Pointer(&table[0])))
return unsafe.Slice((*MibTcpRowOwnerPid)(unsafe.Pointer(&table[offsetOfMibTcpTableOwnerPid])), dwNumEntries), nil
}
}
func GetExtendedTcp6Table() ([]MibTcp6RowOwnerPid, error) {
var size uint32
err := getExtendedTcpTable(nil, &size, false, windows.AF_INET6, TcpTableOwnerPidConnections, 0)
if !errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) {
return nil, os.NewSyscallError("GetExtendedTcpTable", err)
}
for {
table := make([]byte, size)
err = getExtendedTcpTable(&table[0], &size, false, windows.AF_INET6, TcpTableOwnerPidConnections, 0)
if err != nil {
if errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) {
continue
}
return nil, os.NewSyscallError("GetExtendedTcpTable", err)
}
dwNumEntries := int(*(*uint32)(unsafe.Pointer(&table[0])))
return unsafe.Slice((*MibTcp6RowOwnerPid)(unsafe.Pointer(&table[offsetOfMibTcp6TableOwnerPid])), dwNumEntries), nil
}
}
func GetExtendedUdpTable() ([]MibUdpRowOwnerPid, error) {
var size uint32
err := getExtendedUdpTable(nil, &size, false, windows.AF_INET, UdpTableOwnerPid, 0)
if !errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) {
return nil, os.NewSyscallError("GetExtendedUdpTable", err)
}
for {
table := make([]byte, size)
err = getExtendedUdpTable(&table[0], &size, false, windows.AF_INET, UdpTableOwnerPid, 0)
if err != nil {
if errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) {
continue
}
return nil, os.NewSyscallError("GetExtendedUdpTable", err)
}
dwNumEntries := int(*(*uint32)(unsafe.Pointer(&table[0])))
return unsafe.Slice((*MibUdpRowOwnerPid)(unsafe.Pointer(&table[offsetOfMibUdpTableOwnerPid])), dwNumEntries), nil
}
}
func GetExtendedUdp6Table() ([]MibUdp6RowOwnerPid, error) {
var size uint32
err := getExtendedUdpTable(nil, &size, false, windows.AF_INET6, UdpTableOwnerPid, 0)
if !errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) {
return nil, os.NewSyscallError("GetExtendedUdpTable", err)
}
for {
table := make([]byte, size)
err = getExtendedUdpTable(&table[0], &size, false, windows.AF_INET6, UdpTableOwnerPid, 0)
if err != nil {
if errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) {
continue
}
return nil, os.NewSyscallError("GetExtendedUdpTable", err)
}
dwNumEntries := int(*(*uint32)(unsafe.Pointer(&table[0])))
return unsafe.Slice((*MibUdp6RowOwnerPid)(unsafe.Pointer(&table[offsetOfMibUdp6TableOwnerPid])), dwNumEntries), nil
}
}
func GetPerTcpConnectionEStatsSendBuffer(row *MibTcpRow) (*TcpEstatsSendBufferRodV0, error) {
var rod TcpEstatsSendBufferRodV0
err := getPerTcpConnectionEStats(row,
TcpConnectionEstatsSendBuff,
0,
0,
0,
0,
0,
0,
uintptr(unsafe.Pointer(&rod)),
0,
uint64(sizeOfTcpEstatsSendBufferRodV0),
)
if err != nil {
return nil, err
}
return &rod, nil
}
func GetPerTcp6ConnectionEStatsSendBuffer(row *MibTcp6Row) (*TcpEstatsSendBufferRodV0, error) {
var rod TcpEstatsSendBufferRodV0
err := getPerTcp6ConnectionEStats(row,
TcpConnectionEstatsSendBuff,
0,
0,
0,
0,
0,
0,
uintptr(unsafe.Pointer(&rod)),
0,
uint64(sizeOfTcpEstatsSendBufferRodV0),
)
if err != nil {
return nil, err
}
return &rod, nil
}
func SetPerTcpConnectionEStatsSendBuffer(row *MibTcpRow, rw *TcpEstatsSendBuffRwV0) error {
return setPerTcpConnectionEStats(row, TcpConnectionEstatsSendBuff, uintptr(unsafe.Pointer(&rw)), 0, uint64(sizeOfTcpEstatsSendBuffRwV0), 0)
}
func SetPerTcp6ConnectionEStatsSendBuffer(row *MibTcp6Row, rw *TcpEstatsSendBuffRwV0) error {
return setPerTcp6ConnectionEStats(row, TcpConnectionEstatsSendBuff, uintptr(unsafe.Pointer(&rw)), 0, uint64(sizeOfTcpEstatsSendBuffRwV0), 0)
}

View file

@ -0,0 +1,90 @@
//go:build windows
package winiphlpapi_test
import (
"context"
"net"
"syscall"
"testing"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/winiphlpapi"
"github.com/stretchr/testify/require"
)
func TestFindPidTcp4(t *testing.T) {
t.Parallel()
listener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer listener.Close()
go listener.Accept()
conn, err := net.Dial("tcp", listener.Addr().String())
require.NoError(t, err)
defer conn.Close()
pid, err := winiphlpapi.FindPid(N.NetworkTCP, M.AddrPortFromNet(conn.LocalAddr()))
require.NoError(t, err)
require.Equal(t, uint32(syscall.Getpid()), pid)
}
func TestFindPidTcp6(t *testing.T) {
t.Parallel()
listener, err := net.Listen("tcp", "[::1]:0")
require.NoError(t, err)
defer listener.Close()
go listener.Accept()
conn, err := net.Dial("tcp", listener.Addr().String())
require.NoError(t, err)
defer conn.Close()
pid, err := winiphlpapi.FindPid(N.NetworkTCP, M.AddrPortFromNet(conn.LocalAddr()))
require.NoError(t, err)
require.Equal(t, uint32(syscall.Getpid()), pid)
}
func TestFindPidUdp4(t *testing.T) {
t.Parallel()
conn, err := net.ListenPacket("udp", "127.0.0.1:0")
require.NoError(t, err)
defer conn.Close()
pid, err := winiphlpapi.FindPid(N.NetworkUDP, M.AddrPortFromNet(conn.LocalAddr()))
require.NoError(t, err)
require.Equal(t, uint32(syscall.Getpid()), pid)
}
func TestFindPidUdp6(t *testing.T) {
t.Parallel()
conn, err := net.ListenPacket("udp", "[::1]:0")
require.NoError(t, err)
defer conn.Close()
pid, err := winiphlpapi.FindPid(N.NetworkUDP, M.AddrPortFromNet(conn.LocalAddr()))
require.NoError(t, err)
require.Equal(t, uint32(syscall.Getpid()), pid)
}
func TestWaitAck4(t *testing.T) {
t.Parallel()
listener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer listener.Close()
go listener.Accept()
conn, err := net.Dial("tcp", listener.Addr().String())
require.NoError(t, err)
defer conn.Close()
err = winiphlpapi.WriteAndWaitAck(context.Background(), conn, []byte("hello"))
require.NoError(t, err)
}
func TestWaitAck6(t *testing.T) {
t.Parallel()
listener, err := net.Listen("tcp", "[::1]:0")
require.NoError(t, err)
defer listener.Close()
go listener.Accept()
conn, err := net.Dial("tcp", listener.Addr().String())
require.NoError(t, err)
defer conn.Close()
err = winiphlpapi.WriteAndWaitAck(context.Background(), conn, []byte("hello"))
require.NoError(t, err)
}

View file

@ -0,0 +1,27 @@
package winiphlpapi
//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go syscall_windows.go
// https://learn.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-gettcptable
//sys getTcpTable(tcpTable *byte, sizePointer *uint32, order bool) (errcode error) = iphlpapi.GetTcpTable
// https://learn.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-gettcp6table
//sys getTcp6Table(tcpTable *byte, sizePointer *uint32, order bool) (errcode error) = iphlpapi.GetTcp6Table
// https://learn.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-getpertcpconnectionestats
//sys getPerTcpConnectionEStats(row *MibTcpRow, estatsType uint32, rw uintptr, rwVersion uint64, rwSize uint64, ros uintptr, rosVersion uint64, rosSize uint64, rod uintptr, rodVersion uint64, rodSize uint64) (errcode error) = iphlpapi.GetPerTcpConnectionEStats
// https://learn.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-getpertcp6connectionestats
//sys getPerTcp6ConnectionEStats(row *MibTcp6Row, estatsType uint32, rw uintptr, rwVersion uint64, rwSize uint64, ros uintptr, rosVersion uint64, rosSize uint64, rod uintptr, rodVersion uint64, rodSize uint64) (errcode error) = iphlpapi.GetPerTcp6ConnectionEStats
// https://learn.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-setpertcpconnectionestats
//sys setPerTcpConnectionEStats(row *MibTcpRow, estatsType uint32, rw uintptr, rwVersion uint64, rwSize uint64, offset uint64) (errcode error) = iphlpapi.SetPerTcpConnectionEStats
// https://learn.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-setpertcp6connectionestats
//sys setPerTcp6ConnectionEStats(row *MibTcp6Row, estatsType uint32, rw uintptr, rwVersion uint64, rwSize uint64, offset uint64) (errcode error) = iphlpapi.SetPerTcp6ConnectionEStats
// https://learn.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-getextendedtcptable
//sys getExtendedTcpTable(pTcpTable *byte, pdwSize *uint32, bOrder bool, ulAf uint64, tableClass uint32, reserved uint64) = (errcode error) = iphlpapi.GetExtendedTcpTable
// https://learn.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-getextendedudptable
//sys getExtendedUdpTable(pUdpTable *byte, pdwSize *uint32, bOrder bool, ulAf uint64, tableClass uint32, reserved uint64) = (errcode error) = iphlpapi.GetExtendedUdpTable

View file

@ -0,0 +1,131 @@
// Code generated by 'go generate'; DO NOT EDIT.
package winiphlpapi
import (
"syscall"
"unsafe"
"golang.org/x/sys/windows"
)
var _ unsafe.Pointer
// Do the interface allocations only once for common
// Errno values.
const (
errnoERROR_IO_PENDING = 997
)
var (
errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
errERROR_EINVAL error = syscall.EINVAL
)
// errnoErr returns common boxed Errno values, to prevent
// allocations at runtime.
func errnoErr(e syscall.Errno) error {
switch e {
case 0:
return errERROR_EINVAL
case errnoERROR_IO_PENDING:
return errERROR_IO_PENDING
}
// TODO: add more here, after collecting data on the common
// error values see on Windows. (perhaps when running
// all.bat?)
return e
}
var (
modiphlpapi = windows.NewLazySystemDLL("iphlpapi.dll")
procGetExtendedTcpTable = modiphlpapi.NewProc("GetExtendedTcpTable")
procGetExtendedUdpTable = modiphlpapi.NewProc("GetExtendedUdpTable")
procGetPerTcp6ConnectionEStats = modiphlpapi.NewProc("GetPerTcp6ConnectionEStats")
procGetPerTcpConnectionEStats = modiphlpapi.NewProc("GetPerTcpConnectionEStats")
procGetTcp6Table = modiphlpapi.NewProc("GetTcp6Table")
procGetTcpTable = modiphlpapi.NewProc("GetTcpTable")
procSetPerTcp6ConnectionEStats = modiphlpapi.NewProc("SetPerTcp6ConnectionEStats")
procSetPerTcpConnectionEStats = modiphlpapi.NewProc("SetPerTcpConnectionEStats")
)
func getExtendedTcpTable(pTcpTable *byte, pdwSize *uint32, bOrder bool, ulAf uint64, tableClass uint32, reserved uint64) (errcode error) {
var _p0 uint32
if bOrder {
_p0 = 1
}
r0, _, _ := syscall.Syscall6(procGetExtendedTcpTable.Addr(), 6, uintptr(unsafe.Pointer(pTcpTable)), uintptr(unsafe.Pointer(pdwSize)), uintptr(_p0), uintptr(ulAf), uintptr(tableClass), uintptr(reserved))
if r0 != 0 {
errcode = syscall.Errno(r0)
}
return
}
func getExtendedUdpTable(pUdpTable *byte, pdwSize *uint32, bOrder bool, ulAf uint64, tableClass uint32, reserved uint64) (errcode error) {
var _p0 uint32
if bOrder {
_p0 = 1
}
r0, _, _ := syscall.Syscall6(procGetExtendedUdpTable.Addr(), 6, uintptr(unsafe.Pointer(pUdpTable)), uintptr(unsafe.Pointer(pdwSize)), uintptr(_p0), uintptr(ulAf), uintptr(tableClass), uintptr(reserved))
if r0 != 0 {
errcode = syscall.Errno(r0)
}
return
}
func getPerTcp6ConnectionEStats(row *MibTcp6Row, estatsType uint32, rw uintptr, rwVersion uint64, rwSize uint64, ros uintptr, rosVersion uint64, rosSize uint64, rod uintptr, rodVersion uint64, rodSize uint64) (errcode error) {
r0, _, _ := syscall.Syscall12(procGetPerTcp6ConnectionEStats.Addr(), 11, uintptr(unsafe.Pointer(row)), uintptr(estatsType), uintptr(rw), uintptr(rwVersion), uintptr(rwSize), uintptr(ros), uintptr(rosVersion), uintptr(rosSize), uintptr(rod), uintptr(rodVersion), uintptr(rodSize), 0)
if r0 != 0 {
errcode = syscall.Errno(r0)
}
return
}
func getPerTcpConnectionEStats(row *MibTcpRow, estatsType uint32, rw uintptr, rwVersion uint64, rwSize uint64, ros uintptr, rosVersion uint64, rosSize uint64, rod uintptr, rodVersion uint64, rodSize uint64) (errcode error) {
r0, _, _ := syscall.Syscall12(procGetPerTcpConnectionEStats.Addr(), 11, uintptr(unsafe.Pointer(row)), uintptr(estatsType), uintptr(rw), uintptr(rwVersion), uintptr(rwSize), uintptr(ros), uintptr(rosVersion), uintptr(rosSize), uintptr(rod), uintptr(rodVersion), uintptr(rodSize), 0)
if r0 != 0 {
errcode = syscall.Errno(r0)
}
return
}
func getTcp6Table(tcpTable *byte, sizePointer *uint32, order bool) (errcode error) {
var _p0 uint32
if order {
_p0 = 1
}
r0, _, _ := syscall.Syscall(procGetTcp6Table.Addr(), 3, uintptr(unsafe.Pointer(tcpTable)), uintptr(unsafe.Pointer(sizePointer)), uintptr(_p0))
if r0 != 0 {
errcode = syscall.Errno(r0)
}
return
}
func getTcpTable(tcpTable *byte, sizePointer *uint32, order bool) (errcode error) {
var _p0 uint32
if order {
_p0 = 1
}
r0, _, _ := syscall.Syscall(procGetTcpTable.Addr(), 3, uintptr(unsafe.Pointer(tcpTable)), uintptr(unsafe.Pointer(sizePointer)), uintptr(_p0))
if r0 != 0 {
errcode = syscall.Errno(r0)
}
return
}
func setPerTcp6ConnectionEStats(row *MibTcp6Row, estatsType uint32, rw uintptr, rwVersion uint64, rwSize uint64, offset uint64) (errcode error) {
r0, _, _ := syscall.Syscall6(procSetPerTcp6ConnectionEStats.Addr(), 6, uintptr(unsafe.Pointer(row)), uintptr(estatsType), uintptr(rw), uintptr(rwVersion), uintptr(rwSize), uintptr(offset))
if r0 != 0 {
errcode = syscall.Errno(r0)
}
return
}
func setPerTcpConnectionEStats(row *MibTcpRow, estatsType uint32, rw uintptr, rwVersion uint64, rwSize uint64, offset uint64) (errcode error) {
r0, _, _ := syscall.Syscall6(procSetPerTcpConnectionEStats.Addr(), 6, uintptr(unsafe.Pointer(row)), uintptr(estatsType), uintptr(rw), uintptr(rwVersion), uintptr(rwSize), uintptr(offset))
if r0 != 0 {
errcode = syscall.Errno(r0)
}
return
}

View file

@ -49,6 +49,8 @@ type Cache[K comparable, V comparable] interface {
GetWithLifetime(key K) (V, time.Time, bool)
GetWithLifetimeNoExpire(key K) (V, time.Time, bool)
// GetAndRefresh returns the value associated with the key, setting it as the most
// recently used item.
// The lifetime of the found cache item is refreshed, even if it was already expired.

View file

@ -500,6 +500,24 @@ func (lru *LRU[K, V]) getWithLifetime(hash uint32, key K) (value V, lifetime tim
return
}
func (lru *LRU[K, V]) GetWithLifetimeNoExpire(key K) (value V, lifetime time.Time, ok bool) {
return lru.getWithLifetimeNoExpire(lru.hash(key), key)
}
func (lru *LRU[K, V]) getWithLifetimeNoExpire(hash uint32, key K) (value V, lifetime time.Time, ok bool) {
if pos, ok := lru.findKeyNoExpire(hash, key); ok {
if pos != lru.head {
lru.unlinkElement(pos)
lru.setHead(pos)
}
lru.metrics.Hits++
return lru.elements[pos].value, time.UnixMilli(lru.elements[pos].expire), ok
}
lru.metrics.Misses++
return
}
// GetAndRefresh returns the value associated with the key, setting it as the most
// recently used item.
// The lifetime of the found cache item is refreshed, even if it was already expired.

View file

@ -187,6 +187,17 @@ func (lru *ShardedLRU[K, V]) GetWithLifetime(key K) (value V, lifetime time.Time
return
}
func (lru *ShardedLRU[K, V]) GetWithLifetimeNoExpire(key K) (value V, lifetime time.Time, ok bool) {
hash := lru.hash(key)
shard := (hash >> 16) & lru.mask
lru.mus[shard].RLock()
value, lifetime, ok = lru.lrus[shard].getWithLifetimeNoExpire(hash, key)
lru.mus[shard].RUnlock()
return
}
// GetAndRefresh returns the value associated with the key, setting it as the most
// recently used item.
// The lifetime of the found cache item is refreshed, even if it was already expired.

View file

@ -108,6 +108,16 @@ func (lru *SyncedLRU[K, V]) GetWithLifetime(key K) (value V, lifetime time.Time,
return
}
func (lru *SyncedLRU[K, V]) GetWithLifetimeNoExpire(key K) (value V, lifetime time.Time, ok bool) {
hash := lru.lru.hash(key)
lru.mu.Lock()
value, lifetime, ok = lru.lru.getWithLifetimeNoExpire(hash, key)
lru.mu.Unlock()
return
}
// GetAndRefresh returns the value associated with the key, setting it as the most
// recently used item.
// The lifetime of the found cache item is refreshed, even if it was already expired.

View file

@ -10,6 +10,7 @@ import (
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/auth"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
@ -24,6 +25,10 @@ type HandlerEx interface {
N.UDPConnectionHandlerEx
}
type PacketListener interface {
ListenPacket(listenConfig net.ListenConfig, ctx context.Context, network string, address string) (net.PacketConn, error)
}
func ClientHandshake4(conn io.ReadWriter, command byte, destination M.Socksaddr, username string) (socks4.Response, error) {
err := socks4.WriteRequest(conn, socks4.Request{
Command: command,
@ -120,6 +125,8 @@ func HandleConnectionEx(
ctx context.Context, conn net.Conn, reader *std_bufio.Reader,
authenticator *auth.Authenticator,
handler HandlerEx,
packetListener PacketListener,
// resolver TorResolver,
source M.Socksaddr,
onClose N.CloseHandlerFunc,
) error {
@ -147,6 +154,11 @@ func HandleConnectionEx(
}
handler.NewConnectionEx(auth.ContextWithUser(ctx, request.Username), NewLazyConn(conn, version), source, request.Destination, onClose)
return nil
/*case CommandTorResolve, CommandTorResolvePTR:
if resolver == nil {
return E.New("socks4: torsocks: commands not implemented")
}
return handleTorSocks4(ctx, conn, request, resolver)*/
default:
err = socks4.WriteResponse(conn, socks4.Response{
ReplyCode: socks4.ReplyCodeRejectedOrFailed,
@ -213,13 +225,40 @@ func HandleConnectionEx(
handler.NewConnectionEx(ctx, NewLazyConn(conn, version), source, request.Destination, onClose)
return nil
case socks5.CommandUDPAssociate:
var udpConn *net.UDPConn
udpConn, err = net.ListenUDP(M.NetworkFromNetAddr("udp", M.AddrFromNet(conn.LocalAddr())), net.UDPAddrFromAddrPort(netip.AddrPortFrom(M.AddrFromNet(conn.LocalAddr()), 0)))
if err != nil {
return err
var (
listenConfig net.ListenConfig
udpConn net.PacketConn
)
if packetListener != nil {
udpConn, err = packetListener.ListenPacket(listenConfig, ctx, M.NetworkFromNetAddr("udp", M.AddrFromNet(conn.LocalAddr())), M.SocksaddrFrom(M.AddrFromNet(conn.LocalAddr()), 0).String())
} else {
udpConn, err = listenConfig.ListenPacket(ctx, M.NetworkFromNetAddr("udp", M.AddrFromNet(conn.LocalAddr())), M.SocksaddrFrom(M.AddrFromNet(conn.LocalAddr()), 0).String())
}
handler.NewPacketConnectionEx(ctx, NewLazyAssociatePacketConn(bufio.NewServerPacketConn(udpConn), conn), source, M.Socksaddr{}, onClose)
if err != nil {
return E.Cause(err, "socks5: listen udp")
}
err = socks5.WriteResponse(conn, socks5.Response{
ReplyCode: socks5.ReplyCodeSuccess,
Bind: M.SocksaddrFromNet(udpConn.LocalAddr()),
})
if err != nil {
return E.Cause(err, "socks5: write response")
}
var socksPacketConn N.PacketConn = NewAssociatePacketConn(bufio.NewServerPacketConn(udpConn), M.Socksaddr{}, conn)
firstPacket := buf.NewPacket()
var destination M.Socksaddr
destination, err = socksPacketConn.ReadPacket(firstPacket)
if err != nil {
return E.Cause(err, "socks5: read first packet")
}
socksPacketConn = bufio.NewCachedPacketConn(socksPacketConn, firstPacket, destination)
handler.NewPacketConnectionEx(ctx, socksPacketConn, source, destination, onClose)
return nil
/*case CommandTorResolve, CommandTorResolvePTR:
if resolver == nil {
return E.New("socks4: torsocks: commands not implemented")
}
return handleTorSocks5(ctx, conn, request, resolver)*/
default:
err = socks5.WriteResponse(conn, socks5.Response{
ReplyCode: socks5.ReplyCodeUnsupported,

View file

@ -0,0 +1,146 @@
package socks
import (
"context"
"net"
"net/netip"
"os"
"strings"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/protocol/socks/socks4"
"github.com/sagernet/sing/protocol/socks/socks5"
)
const (
CommandTorResolve byte = 0xF0
CommandTorResolvePTR byte = 0xF1
)
type TorResolver interface {
LookupIP(ctx context.Context, host string) (netip.Addr, error)
LookupPTR(ctx context.Context, addr netip.Addr) (string, error)
}
func handleTorSocks4(ctx context.Context, conn net.Conn, request socks4.Request, resolver TorResolver) error {
switch request.Command {
case CommandTorResolve:
if !request.Destination.IsFqdn() {
return E.New("socks4: torsocks: invalid destination")
}
ipAddr, err := resolver.LookupIP(ctx, request.Destination.Fqdn)
if err != nil {
err = socks4.WriteResponse(conn, socks4.Response{
ReplyCode: socks4.ReplyCodeRejectedOrFailed,
})
if err != nil {
return err
}
return E.Cause(err, "socks4: torsocks: lookup failed for domain: ", request.Destination.Fqdn)
}
err = socks4.WriteResponse(conn, socks4.Response{
ReplyCode: socks4.ReplyCodeGranted,
Destination: M.SocksaddrFrom(ipAddr, 0),
})
if err != nil {
return E.Cause(err, "socks4: torsocks: write response")
}
return nil
case CommandTorResolvePTR:
var ipAddr netip.Addr
if request.Destination.IsIP() {
ipAddr = request.Destination.Addr
} else if strings.HasSuffix(request.Destination.Fqdn, ".in-addr.arpa") {
ipAddr, _ = netip.ParseAddr(request.Destination.Fqdn[:len(request.Destination.Fqdn)-len(".in-addr.arpa")])
} else if strings.HasSuffix(request.Destination.Fqdn, ".ip6.arpa") {
ipAddr, _ = netip.ParseAddr(strings.ReplaceAll(request.Destination.Fqdn[:len(request.Destination.Fqdn)-len(".ip6.arpa")], ".", ":"))
}
if !ipAddr.IsValid() {
return E.New("socks4: torsocks: invalid destination")
}
host, err := resolver.LookupPTR(ctx, ipAddr)
if err != nil {
err = socks4.WriteResponse(conn, socks4.Response{
ReplyCode: socks4.ReplyCodeRejectedOrFailed,
})
if err != nil {
return err
}
return E.Cause(err, "socks4: torsocks: lookup PTR failed for ip: ", ipAddr)
}
err = socks4.WriteResponse(conn, socks4.Response{
ReplyCode: socks4.ReplyCodeGranted,
Destination: M.Socksaddr{
Fqdn: host,
},
})
if err != nil {
return E.Cause(err, "socks4: torsocks: write response")
}
return nil
default:
return os.ErrInvalid
}
}
func handleTorSocks5(ctx context.Context, conn net.Conn, request socks5.Request, resolver TorResolver) error {
switch request.Command {
case CommandTorResolve:
if !request.Destination.IsFqdn() {
return E.New("socks5: torsocks: invalid destination")
}
ipAddr, err := resolver.LookupIP(ctx, request.Destination.Fqdn)
if err != nil {
err = socks5.WriteResponse(conn, socks5.Response{
ReplyCode: socks5.ReplyCodeFailure,
})
if err != nil {
return err
}
return E.Cause(err, "socks5: torsocks: lookup failed for domain: ", request.Destination.Fqdn)
}
err = socks5.WriteResponse(conn, socks5.Response{
ReplyCode: socks5.ReplyCodeSuccess,
Bind: M.SocksaddrFrom(ipAddr, 0),
})
if err != nil {
return E.Cause(err, "socks5: torsocks: write response")
}
return nil
case CommandTorResolvePTR:
var ipAddr netip.Addr
if request.Destination.IsIP() {
ipAddr = request.Destination.Addr
} else if strings.HasSuffix(request.Destination.Fqdn, ".in-addr.arpa") {
ipAddr, _ = netip.ParseAddr(request.Destination.Fqdn[:len(request.Destination.Fqdn)-len(".in-addr.arpa")])
} else if strings.HasSuffix(request.Destination.Fqdn, ".ip6.arpa") {
ipAddr, _ = netip.ParseAddr(strings.ReplaceAll(request.Destination.Fqdn[:len(request.Destination.Fqdn)-len(".ip6.arpa")], ".", ":"))
}
if !ipAddr.IsValid() {
return E.New("socks5: torsocks: invalid destination")
}
host, err := resolver.LookupPTR(ctx, ipAddr)
if err != nil {
err = socks5.WriteResponse(conn, socks5.Response{
ReplyCode: socks5.ReplyCodeFailure,
})
if err != nil {
return err
}
return E.Cause(err, "socks5: torsocks: lookup PTR failed for ip: ", ipAddr)
}
err = socks5.WriteResponse(conn, socks5.Response{
ReplyCode: socks5.ReplyCodeSuccess,
Bind: M.Socksaddr{
Fqdn: host,
},
})
if err != nil {
return E.Cause(err, "socks5: torsocks: write response")
}
return nil
default:
return os.ErrInvalid
}
}