Compare commits

...

78 commits

Author SHA1 Message Date
世界
159e489fc3
Add E.Cause1 2025-04-03 18:26:45 +08:00
世界
d39c2c2fdd
socks: Add custom udp listener 2025-03-26 13:18:24 +08:00
世界
ea82ac275f
Add freelru.GetWithLifetimeNoExpire 2025-03-26 13:18:18 +08:00
世界
ea0ac932ae
Add winiphlpapi 2025-03-26 13:18:17 +08:00
世界
2b41455f5a
Fix udpnat2 handler again 2025-03-26 12:46:15 +08:00
世界
23b0180a1b
Fix crash on udpnat2 handler 2025-03-24 18:11:10 +08:00
世界
ce1b4851a4
Fix socks5 UDP 2025-03-16 10:23:29 +08:00
世界
2238a05966
Fix merge objects 2025-03-16 10:23:29 +08:00
世界
b55d1c78b3
bufio: Add destination NAT packet conn 2025-03-09 15:20:32 +08:00
世界
d54716612c
Fix syscall packet read waiter for Windows 2025-02-28 12:07:45 +08:00
世界
9eafc7fc62
udpnat2: Fix crash 2025-02-10 15:08:18 +08:00
世界
d8153df67f
Add ENOTCONN to IsClosed 2025-02-06 08:41:32 +08:00
世界
d9f6eb136d
Fix set windows system time 2025-01-09 23:30:25 +08:00
世界
4dabb9be97
freelru: Fix GetAndRefreshOrAdd 2025-01-09 15:59:26 +08:00
世界
be9840c70f
listable: Fix incorrect unmarshaling of null to []T{null} 2025-01-09 15:57:12 +08:00
世界
aa7d2543a3
Fix errors usage 2024-12-16 09:20:34 +08:00
世界
33beacc053
Fix socks5 UDP handshake 2024-12-14 18:16:15 +08:00
世界
442cceb9fa
Fix disable UDP fragment 2024-12-12 20:43:56 +08:00
世界
3374a45475
Fix socks5 UDP implementation 2024-12-10 19:53:57 +08:00
世界
73776cf797
Fix lru test 2024-12-10 19:42:55 +08:00
世界
957166799e
Fix CloseOnHandshakeFailure 2024-12-04 17:14:58 +08:00
世界
809d8eca13
freelru: fix PurgeExpired 2024-12-04 11:36:20 +08:00
世界
9f69e7f9f7
E: IsClosedOrCanceled check IsTimeout 2024-12-01 20:19:37 +08:00
世界
478265cd45
badoption: Finish netip options 2024-12-01 14:33:23 +08:00
世界
3f30aaf25e
freelru: purge all expired items 2024-11-30 16:06:59 +08:00
世界
39040e06dc
udpnat2: Fix concurrency 2024-11-28 13:51:17 +08:00
世界
6edd2ce0ea
freelru: Update source and add GetAndRefreshOrAdd 2024-11-28 13:51:17 +08:00
世界
0a2e2a3eaf
udpnat2: Fix timeout 2024-11-27 18:02:22 +08:00
世界
4ba1eb123c
Fix set timeout 2024-11-27 17:28:18 +08:00
世界
c44912a861
freelru: Fix purge 2024-11-27 13:51:08 +08:00
世界
a8f5bf4eb0
udpnat2: Add timeout check 2024-11-26 19:08:35 +08:00
世界
30e9d91b57
Fix AppendClose 2024-11-26 12:21:37 +08:00
世界
7fd3517e4d
udpnat2: Add purge expire ticker 2024-11-26 12:21:37 +08:00
世界
a8285e06a5
udpnat2: Implement set timeout for nat conn 2024-11-26 12:21:37 +08:00
世界
3613ead480
freelru: Add PeekWithLifetime and UpdateLifetime 2024-11-26 11:29:14 +08:00
世界
c8f251c668
Fix copy count 2024-11-24 19:02:21 +08:00
世界
fa5355e99e
bufio: more copy funcs 2024-11-20 11:27:20 +08:00
世界
30fbafd954
udpnat2: Add cache funcs 2024-11-18 12:14:35 +08:00
世界
fdca9b3f8e
badjson: Fix Listable 2024-11-16 16:03:00 +08:00
世界
e52e04f721
Fix HandshakeFailure usages 2024-11-15 16:27:03 +08:00
世界
7f621fdd78
Add freelru.SetUpdateLifetimeOnGet/GetWithLifetime 2024-11-14 17:49:49 +08:00
世界
ae139d9ee1
Update N.PayloadDialer 2024-11-14 17:49:49 +08:00
世界
c432befd02
http: Fix proxying websocket 2024-11-13 19:02:07 +08:00
世界
cc7e630923
control: Refactor interface finder 2024-11-12 20:15:50 +08:00
世界
0998999911
udpnat2: Fix missing shared impl 2024-11-09 11:40:27 +08:00
世界
72ff654ee0
shared: Add SetHealthCheck to interface 2024-11-09 11:40:27 +08:00
世界
11ffb962ae
freelru: Fix impl 2024-11-09 11:40:27 +08:00
世界
fcb19641e6
freelru: Copy shared source 2024-11-09 11:40:27 +08:00
世界
524a6bd0d1
udpnat2: Set upstream to writer 2024-11-09 11:40:27 +08:00
世界
b5f9e70ffd
badjson: Fix Listable 2024-11-09 11:40:27 +08:00
世界
c80c8f907c
badjson: Add context marshaler/unmarshaler 2024-11-05 18:43:05 +08:00
世界
a4eb7fa900
udpnat2: Add SetHandler 2024-11-05 18:43:05 +08:00
世界
7ec09d6045
udpnat2: New synced udp nat service 2024-11-05 18:43:04 +08:00
世界
0641c71805
maphash: copy source from v0.1.0 2024-11-05 18:43:04 +08:00
世界
e7ec021b81
freelru: copy source from v0.14.0 2024-11-05 18:43:04 +08:00
世界
0f2447a95b
Crazy sekai overturns the small pond 2024-11-05 18:43:04 +08:00
世界
72db784fc7
Add bind.Interface.Flags 2024-11-04 11:05:38 +08:00
世界
d59ac57aaa
Add go1.21 compat funcs 2024-10-19 09:09:15 +08:00
世界
c63546470b
Add Update() error to control.InterfaceFinder 2024-09-22 22:15:12 +08:00
世界
55908bea36
Update linter configuration 2024-09-14 21:36:41 +08:00
世界
6567829958
Fix cached conn eats up read deadlines 2024-09-14 10:11:50 +08:00
世界
c324d4143d
json: Add badoption templates 2024-09-10 23:57:22 +08:00
世界
0acb36c118
Minor fixes 2024-09-10 23:46:03 +08:00
世界
26511a251f
udpnat: Fix read deadline not initialized 2024-08-19 17:56:31 +08:00
世界
afd8993773
windnsapi: Fix incorrect error checking 2024-08-19 17:47:42 +08:00
世界
96bef0733f
Fix bad group usages 2024-08-18 11:15:20 +08:00
世界
ec1df651e8
Update golangci-lint configuration 2024-08-18 11:15:15 +08:00
世界
e33b1d67d5
bufio: Add ReadBufferSize and ReadPacketSize 2024-08-18 09:14:52 +08:00
世界
ed6cde73f7
udpnat: Implement read deadline 2024-08-18 09:14:44 +08:00
世界
73cc65605e
pipe: Make pipeDeadline public for use 2024-08-18 08:59:04 +08:00
世界
6c19e0736d
windows: Migrate to mkwinsyscall 2024-08-09 12:00:21 +08:00
世界
08e8c02fb1
Fix usage of PowerUnregisterSuspendResumeNotification 2024-08-08 15:27:35 +08:00
世界
7beca62e4f
Improve winpowrprof callback 2024-08-06 13:19:09 +08:00
世界
e422e3d048
Reuse winpowrprof callback 2024-08-06 12:23:56 +08:00
世界
fa81eabc29
ntp: Fix a bad context usage 2024-08-06 12:23:56 +08:00
世界
4498e57839
task: Fix context not continuous 2024-07-31 18:06:42 +08:00
世界
f97054e917
ntp: Ignore setup error 2024-07-31 10:25:01 +08:00
世界
a2f9fef936
domain: Add adguard matcher 2024-07-26 08:00:09 +08:00
121 changed files with 5964 additions and 753 deletions

View file

@ -26,7 +26,7 @@ jobs:
- name: Setup Go - name: Setup Go
uses: actions/setup-go@v5 uses: actions/setup-go@v5
with: with:
go-version: ^1.22 go-version: ^1.23
- name: Cache go module - name: Cache go module
uses: actions/cache@v4 uses: actions/cache@v4
with: with:

View file

@ -26,7 +26,7 @@ jobs:
- name: Setup Go - name: Setup Go
uses: actions/setup-go@v5 uses: actions/setup-go@v5
with: with:
go-version: ^1.22 go-version: ^1.23
- name: Build - name: Build
run: | run: |
make test make test
@ -62,6 +62,22 @@ jobs:
- name: Build - name: Build
run: | run: |
make test make test
build_go122:
name: Linux (Go 1.22)
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Setup Go
uses: actions/setup-go@v5
with:
go-version: ~1.22
continue-on-error: true
- name: Build
run: |
make test
build_windows: build_windows:
name: Windows name: Windows
runs-on: windows-latest runs-on: windows-latest
@ -73,7 +89,7 @@ jobs:
- name: Setup Go - name: Setup Go
uses: actions/setup-go@v5 uses: actions/setup-go@v5
with: with:
go-version: ^1.22 go-version: ^1.23
continue-on-error: true continue-on-error: true
- name: Build - name: Build
run: | run: |
@ -89,7 +105,7 @@ jobs:
- name: Setup Go - name: Setup Go
uses: actions/setup-go@v5 uses: actions/setup-go@v5
with: with:
go-version: ^1.22 go-version: ^1.23
continue-on-error: true continue-on-error: true
- name: Build - name: Build
run: | run: |

View file

@ -5,6 +5,8 @@ linters:
- govet - govet
- gci - gci
- staticcheck - staticcheck
- paralleltest
- ineffassign
linters-settings: linters-settings:
gci: gci:
@ -13,3 +15,10 @@ linters-settings:
- standard - standard
- prefix(github.com/sagernet/) - prefix(github.com/sagernet/)
- default - default
staticcheck:
checks:
- all
- -SA1003
run:
go: "1.23"

View file

@ -8,14 +8,14 @@ fmt_install:
go install -v github.com/daixiang0/gci@latest go install -v github.com/daixiang0/gci@latest
lint: lint:
GOOS=linux golangci-lint run ./... GOOS=linux golangci-lint run
GOOS=android golangci-lint run ./... GOOS=android golangci-lint run
GOOS=windows golangci-lint run ./... GOOS=windows golangci-lint run
GOOS=darwin golangci-lint run ./... GOOS=darwin golangci-lint run
GOOS=freebsd golangci-lint run ./... GOOS=freebsd golangci-lint run
lint_install: lint_install:
go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest
test: test:
go test $(shell go list ./... | grep -v /internal/) go test ./...

View file

@ -2,11 +2,10 @@ package baderror
import ( import (
"context" "context"
"errors"
"io" "io"
"net" "net"
"strings" "strings"
E "github.com/sagernet/sing/common/exceptions"
) )
func Contains(err error, msgList ...string) bool { func Contains(err error, msgList ...string) bool {
@ -22,8 +21,7 @@ func WrapH2(err error) error {
if err == nil { if err == nil {
return nil return nil
} }
err = E.Unwrap(err) if errors.Is(err, io.ErrUnexpectedEOF) {
if err == io.ErrUnexpectedEOF {
return io.EOF return io.EOF
} }
if Contains(err, "client disconnected", "body closed by handler", "response body closed", "; CANCEL") { if Contains(err, "client disconnected", "body closed by handler", "response body closed", "; CANCEL") {

View file

@ -9,19 +9,20 @@ import (
type AddrConn struct { type AddrConn struct {
net.Conn net.Conn
M.Metadata Source M.Socksaddr
Destination M.Socksaddr
} }
func (c *AddrConn) LocalAddr() net.Addr { func (c *AddrConn) LocalAddr() net.Addr {
if c.Metadata.Destination.IsValid() { if c.Destination.IsValid() {
return c.Metadata.Destination.TCPAddr() return c.Destination.TCPAddr()
} }
return c.Conn.LocalAddr() return c.Conn.LocalAddr()
} }
func (c *AddrConn) RemoteAddr() net.Addr { func (c *AddrConn) RemoteAddr() net.Addr {
if c.Metadata.Source.IsValid() { if c.Source.IsValid() {
return c.Metadata.Source.TCPAddr() return c.Source.TCPAddr()
} }
return c.Conn.RemoteAddr() return c.Conn.RemoteAddr()
} }

View file

@ -3,7 +3,6 @@ package bufio
import ( import (
"io" "io"
"net" "net"
"time"
"github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/buf"
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
@ -60,13 +59,6 @@ func (c *CachedConn) WriteTo(w io.Writer) (n int64, err error) {
return return
} }
func (c *CachedConn) SetReadDeadline(t time.Time) error {
if c.buffer != nil && !c.buffer.IsEmpty() {
return nil
}
return c.Conn.SetReadDeadline(t)
}
func (c *CachedConn) ReadFrom(r io.Reader) (n int64, err error) { func (c *CachedConn) ReadFrom(r io.Reader) (n int64, err error) {
return Copy(c.Conn, r) return Copy(c.Conn, r)
} }
@ -192,10 +184,12 @@ func (c *CachedPacketConn) ReadCachedPacket() *N.PacketBuffer {
if buffer != nil { if buffer != nil {
buffer.DecRef() buffer.DecRef()
} }
return &N.PacketBuffer{ packet := N.NewPacketBuffer()
*packet = N.PacketBuffer{
Buffer: buffer, Buffer: buffer,
Destination: c.destination, Destination: c.destination,
} }
return packet
} }
func (c *CachedPacketConn) Upstream() any { func (c *CachedPacketConn) Upstream() any {

View file

@ -35,14 +35,7 @@ func (w *ExtendedUDPConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
func (w *ExtendedUDPConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { func (w *ExtendedUDPConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
defer buffer.Release() defer buffer.Release()
if destination.IsFqdn() { return common.Error(w.UDPConn.WriteToUDPAddrPort(buffer.Bytes(), destination.AddrPort()))
udpAddr, err := net.ResolveUDPAddr("udp", destination.String())
if err != nil {
return err
}
return common.Error(w.UDPConn.WriteTo(buffer.Bytes(), udpAddr))
}
return common.Error(w.UDPConn.WriteToUDP(buffer.Bytes(), destination.UDPAddr()))
} }
func (w *ExtendedUDPConn) Upstream() any { func (w *ExtendedUDPConn) Upstream() any {

View file

@ -29,28 +29,36 @@ func Copy(destination io.Writer, source io.Reader) (n int64, err error) {
if cachedSrc, isCached := source.(N.CachedReader); isCached { if cachedSrc, isCached := source.(N.CachedReader); isCached {
cachedBuffer := cachedSrc.ReadCached() cachedBuffer := cachedSrc.ReadCached()
if cachedBuffer != nil { if cachedBuffer != nil {
if !cachedBuffer.IsEmpty() { dataLen := cachedBuffer.Len()
_, err = destination.Write(cachedBuffer.Bytes()) _, err = destination.Write(cachedBuffer.Bytes())
if err != nil {
cachedBuffer.Release()
return
}
}
cachedBuffer.Release() cachedBuffer.Release()
if err != nil {
return
}
for _, counter := range readCounters {
counter(int64(dataLen))
}
for _, counter := range writeCounters {
counter(int64(dataLen))
}
continue continue
} }
} }
srcSyscallConn, srcIsSyscall := source.(syscall.Conn)
dstSyscallConn, dstIsSyscall := destination.(syscall.Conn)
if srcIsSyscall && dstIsSyscall {
var handled bool
handled, n, err = copyDirect(srcSyscallConn, dstSyscallConn, readCounters, writeCounters)
if handled {
return
}
}
break break
} }
return CopyWithCounters(destination, source, originSource, readCounters, writeCounters)
}
func CopyWithCounters(destination io.Writer, source io.Reader, originSource io.Reader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
srcSyscallConn, srcIsSyscall := source.(syscall.Conn)
dstSyscallConn, dstIsSyscall := destination.(syscall.Conn)
if srcIsSyscall && dstIsSyscall {
var handled bool
handled, n, err = copyDirect(srcSyscallConn, dstSyscallConn, readCounters, writeCounters)
if handled {
return
}
}
return CopyExtended(originSource, NewExtendedWriter(destination), NewExtendedReader(source), readCounters, writeCounters) return CopyExtended(originSource, NewExtendedWriter(destination), NewExtendedReader(source), readCounters, writeCounters)
} }
@ -75,6 +83,7 @@ func CopyExtended(originSource io.Reader, destination N.ExtendedWriter, source N
return CopyExtendedWithPool(originSource, destination, source, readCounters, writeCounters) return CopyExtendedWithPool(originSource, destination, source, readCounters, writeCounters)
} }
// Deprecated: not used
func CopyExtendedBuffer(originSource io.Writer, destination N.ExtendedWriter, source N.ExtendedReader, buffer *buf.Buffer, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) { func CopyExtendedBuffer(originSource io.Writer, destination N.ExtendedWriter, source N.ExtendedReader, buffer *buf.Buffer, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
buffer.IncRef() buffer.IncRef()
defer buffer.DecRef() defer buffer.DecRef()
@ -113,19 +122,10 @@ func CopyExtendedBuffer(originSource io.Writer, destination N.ExtendedWriter, so
} }
func CopyExtendedWithPool(originSource io.Reader, destination N.ExtendedWriter, source N.ExtendedReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) { func CopyExtendedWithPool(originSource io.Reader, destination N.ExtendedWriter, source N.ExtendedReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
frontHeadroom := N.CalculateFrontHeadroom(destination) options := N.NewReadWaitOptions(source, destination)
rearHeadroom := N.CalculateRearHeadroom(destination)
bufferSize := N.CalculateMTU(source, destination)
if bufferSize > 0 {
bufferSize += frontHeadroom + rearHeadroom
} else {
bufferSize = buf.BufferSize
}
var notFirstTime bool var notFirstTime bool
for { for {
buffer := buf.NewSize(bufferSize) buffer := options.NewBuffer()
buffer.Resize(frontHeadroom, 0)
buffer.Reserve(rearHeadroom)
err = source.ReadBuffer(buffer) err = source.ReadBuffer(buffer)
if err != nil { if err != nil {
buffer.Release() buffer.Release()
@ -136,7 +136,7 @@ func CopyExtendedWithPool(originSource io.Reader, destination N.ExtendedWriter,
return return
} }
dataLen := buffer.Len() dataLen := buffer.Len()
buffer.OverCap(rearHeadroom) options.PostReturn(buffer)
err = destination.WriteBuffer(buffer) err = destination.WriteBuffer(buffer)
if err != nil { if err != nil {
buffer.Leak() buffer.Leak()
@ -157,10 +157,6 @@ func CopyExtendedWithPool(originSource io.Reader, destination N.ExtendedWriter,
} }
func CopyConn(ctx context.Context, source net.Conn, destination net.Conn) error { func CopyConn(ctx context.Context, source net.Conn, destination net.Conn) error {
return CopyConnContextList([]context.Context{ctx}, source, destination)
}
func CopyConnContextList(contextList []context.Context, source net.Conn, destination net.Conn) error {
var group task.Group var group task.Group
if _, dstDuplex := common.Cast[N.WriteCloser](destination); dstDuplex { if _, dstDuplex := common.Cast[N.WriteCloser](destination); dstDuplex {
group.Append("upload", func(ctx context.Context) error { group.Append("upload", func(ctx context.Context) error {
@ -197,7 +193,7 @@ func CopyConnContextList(contextList []context.Context, source net.Conn, destina
group.Cleanup(func() { group.Cleanup(func() {
common.Close(source, destination) common.Close(source, destination)
}) })
return group.RunContextList(contextList) return group.Run(ctx)
} }
func CopyPacket(destinationConn N.PacketWriter, source N.PacketReader) (n int64, err error) { func CopyPacket(destinationConn N.PacketWriter, source N.PacketReader) (n int64, err error) {
@ -217,24 +213,24 @@ func CopyPacket(destinationConn N.PacketWriter, source N.PacketReader) (n int64,
break break
} }
if cachedPackets != nil { if cachedPackets != nil {
n, err = WritePacketWithPool(originSource, destinationConn, cachedPackets) n, err = WritePacketWithPool(originSource, destinationConn, cachedPackets, readCounters, writeCounters)
if err != nil { if err != nil {
return return
} }
} }
frontHeadroom := N.CalculateFrontHeadroom(destinationConn) copeN, err := CopyPacketWithCounters(destinationConn, source, originSource, readCounters, writeCounters)
rearHeadroom := N.CalculateRearHeadroom(destinationConn) n += copeN
return
}
func CopyPacketWithCounters(destinationConn N.PacketWriter, source N.PacketReader, originSource N.PacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
var ( var (
handled bool handled bool
copeN int64 copeN int64
) )
readWaiter, isReadWaiter := CreatePacketReadWaiter(source) readWaiter, isReadWaiter := CreatePacketReadWaiter(source)
if isReadWaiter { if isReadWaiter {
needCopy := readWaiter.InitializeReadWaiter(N.ReadWaitOptions{ needCopy := readWaiter.InitializeReadWaiter(N.NewReadWaitOptions(source, destinationConn))
FrontHeadroom: frontHeadroom,
RearHeadroom: rearHeadroom,
MTU: N.CalculateMTU(source, destinationConn),
})
if !needCopy || common.LowMemory { if !needCopy || common.LowMemory {
handled, copeN, err = copyPacketWaitWithPool(originSource, destinationConn, readWaiter, readCounters, writeCounters, n > 0) handled, copeN, err = copyPacketWaitWithPool(originSource, destinationConn, readWaiter, readCounters, writeCounters, n > 0)
if handled { if handled {
@ -248,28 +244,19 @@ func CopyPacket(destinationConn N.PacketWriter, source N.PacketReader) (n int64,
return return
} }
func CopyPacketWithPool(originSource N.PacketReader, destinationConn N.PacketWriter, source N.PacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (n int64, err error) { func CopyPacketWithPool(originSource N.PacketReader, destination N.PacketWriter, source N.PacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (n int64, err error) {
frontHeadroom := N.CalculateFrontHeadroom(destinationConn) options := N.NewReadWaitOptions(source, destination)
rearHeadroom := N.CalculateRearHeadroom(destinationConn) var destinationAddress M.Socksaddr
bufferSize := N.CalculateMTU(source, destinationConn)
if bufferSize > 0 {
bufferSize += frontHeadroom + rearHeadroom
} else {
bufferSize = buf.UDPBufferSize
}
var destination M.Socksaddr
for { for {
buffer := buf.NewSize(bufferSize) buffer := options.NewPacketBuffer()
buffer.Resize(frontHeadroom, 0) destinationAddress, err = source.ReadPacket(buffer)
buffer.Reserve(rearHeadroom)
destination, err = source.ReadPacket(buffer)
if err != nil { if err != nil {
buffer.Release() buffer.Release()
return return
} }
dataLen := buffer.Len() dataLen := buffer.Len()
buffer.OverCap(rearHeadroom) options.PostReturn(buffer)
err = destinationConn.WritePacket(buffer, destination) err = destination.WritePacket(buffer, destinationAddress)
if err != nil { if err != nil {
buffer.Leak() buffer.Leak()
if !notFirstTime { if !notFirstTime {
@ -277,34 +264,25 @@ func CopyPacketWithPool(originSource N.PacketReader, destinationConn N.PacketWri
} }
return return
} }
n += int64(dataLen)
for _, counter := range readCounters { for _, counter := range readCounters {
counter(int64(dataLen)) counter(int64(dataLen))
} }
for _, counter := range writeCounters { for _, counter := range writeCounters {
counter(int64(dataLen)) counter(int64(dataLen))
} }
n += int64(dataLen)
notFirstTime = true notFirstTime = true
} }
} }
func WritePacketWithPool(originSource N.PacketReader, destinationConn N.PacketWriter, packetBuffers []*N.PacketBuffer) (n int64, err error) { func WritePacketWithPool(originSource N.PacketReader, destination N.PacketWriter, packetBuffers []*N.PacketBuffer, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
frontHeadroom := N.CalculateFrontHeadroom(destinationConn) options := N.NewReadWaitOptions(nil, destination)
rearHeadroom := N.CalculateRearHeadroom(destinationConn)
var notFirstTime bool var notFirstTime bool
for _, packetBuffer := range packetBuffers { for _, packetBuffer := range packetBuffers {
buffer := buf.NewPacket() buffer := options.Copy(packetBuffer.Buffer)
buffer.Resize(frontHeadroom, 0)
buffer.Reserve(rearHeadroom)
_, err = buffer.Write(packetBuffer.Buffer.Bytes())
packetBuffer.Buffer.Release()
if err != nil {
buffer.Release()
continue
}
dataLen := buffer.Len() dataLen := buffer.Len()
buffer.OverCap(rearHeadroom) err = destination.WritePacket(buffer, packetBuffer.Destination)
err = destinationConn.WritePacket(buffer, packetBuffer.Destination) N.PutPacketBuffer(packetBuffer)
if err != nil { if err != nil {
buffer.Leak() buffer.Leak()
if !notFirstTime { if !notFirstTime {
@ -312,16 +290,19 @@ func WritePacketWithPool(originSource N.PacketReader, destinationConn N.PacketWr
} }
return return
} }
for _, counter := range readCounters {
counter(int64(dataLen))
}
for _, counter := range writeCounters {
counter(int64(dataLen))
}
n += int64(dataLen) n += int64(dataLen)
notFirstTime = true
} }
return return
} }
func CopyPacketConn(ctx context.Context, source N.PacketConn, destination N.PacketConn) error { func CopyPacketConn(ctx context.Context, source N.PacketConn, destination N.PacketConn) error {
return CopyPacketConnContextList([]context.Context{ctx}, source, destination)
}
func CopyPacketConnContextList(contextList []context.Context, source N.PacketConn, destination N.PacketConn) error {
var group task.Group var group task.Group
group.Append("upload", func(ctx context.Context) error { group.Append("upload", func(ctx context.Context) error {
return common.Error(CopyPacket(destination, source)) return common.Error(CopyPacket(destination, source))
@ -333,5 +314,5 @@ func CopyPacketConnContextList(contextList []context.Context, source N.PacketCon
common.Close(source, destination) common.Close(source, destination)
}) })
group.FastFail() group.FastFail()
return group.RunContextList(contextList) return group.Run(ctx)
} }

View file

@ -5,7 +5,6 @@ import (
"net/netip" "net/netip"
"os" "os"
"syscall" "syscall"
"unsafe"
"github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
@ -15,49 +14,6 @@ import (
"golang.org/x/sys/windows" "golang.org/x/sys/windows"
) )
var modws2_32 = windows.NewLazySystemDLL("ws2_32.dll")
var procrecv = modws2_32.NewProc("recv")
// Do the interface allocations only once for common
// Errno values.
const (
errnoERROR_IO_PENDING = 997
)
var (
errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
errERROR_EINVAL error = syscall.EINVAL
)
// errnoErr returns common boxed Errno values, to prevent
// allocations at runtime.
func errnoErr(e syscall.Errno) error {
switch e {
case 0:
return errERROR_EINVAL
case errnoERROR_IO_PENDING:
return errERROR_IO_PENDING
}
// TODO: add more here, after collecting data on the common
// error values see on Windows. (perhaps when running
// all.bat?)
return e
}
func recv(s windows.Handle, buf []byte, flags int32) (n int32, err error) {
var _p0 *byte
if len(buf) > 0 {
_p0 = &buf[0]
}
r0, _, e1 := syscall.SyscallN(procrecv.Addr(), uintptr(s), uintptr(unsafe.Pointer(_p0)), uintptr(len(buf)), uintptr(flags))
n = int32(r0)
if n == -1 {
err = errnoErr(e1)
}
return
}
var _ N.ReadWaiter = (*syscallReadWaiter)(nil) var _ N.ReadWaiter = (*syscallReadWaiter)(nil)
type syscallReadWaiter struct { type syscallReadWaiter struct {
@ -164,16 +120,16 @@ func (w *syscallPacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions
var readN int var readN int
var from windows.Sockaddr var from windows.Sockaddr
readN, from, w.readErr = windows.Recvfrom(windows.Handle(fd), buffer.FreeBytes(), 0) readN, from, w.readErr = windows.Recvfrom(windows.Handle(fd), buffer.FreeBytes(), 0)
//goland:noinspection GoDirectComparisonOfErrors
if w.readErr != nil {
buffer.Release()
return w.readErr != windows.WSAEWOULDBLOCK
}
if readN > 0 { if readN > 0 {
buffer.Truncate(readN) buffer.Truncate(readN)
w.options.PostReturn(buffer)
w.buffer = buffer
} else {
buffer.Release()
}
if w.readErr == windows.WSAEWOULDBLOCK {
return false
} }
w.options.PostReturn(buffer)
w.buffer = buffer
if from != nil { if from != nil {
switch fromAddr := from.(type) { switch fromAddr := from.(type) {
case *windows.SockaddrInet4: case *windows.SockaddrInet4:

View file

@ -25,6 +25,45 @@ func ReadPacket(reader N.PacketReader, buffer *buf.Buffer) (n int, addr net.Addr
return return
} }
func ReadBufferSize(reader io.Reader, bufferSize int) (buffer *buf.Buffer, err error) {
readWaiter, isReadWaiter := CreateReadWaiter(reader)
if isReadWaiter {
readWaiter.InitializeReadWaiter(N.ReadWaitOptions{
MTU: bufferSize,
})
return readWaiter.WaitReadBuffer()
}
buffer = buf.NewSize(bufferSize)
if extendedReader, isExtendedReader := reader.(N.ExtendedReader); isExtendedReader {
err = extendedReader.ReadBuffer(buffer)
} else {
_, err = buffer.ReadOnceFrom(reader)
}
if err != nil {
buffer.Release()
buffer = nil
}
return
}
func ReadPacketSize(reader N.PacketReader, packetSize int) (buffer *buf.Buffer, destination M.Socksaddr, err error) {
readWaiter, isReadWaiter := CreatePacketReadWaiter(reader)
if isReadWaiter {
readWaiter.InitializeReadWaiter(N.ReadWaitOptions{
MTU: packetSize,
})
buffer, destination, err = readWaiter.WaitReadPacket()
return
}
buffer = buf.NewSize(packetSize)
destination, err = reader.ReadPacket(buffer)
if err != nil {
buffer.Release()
buffer = nil
}
return
}
func Write(writer io.Writer, data []byte) (n int, err error) { func Write(writer io.Writer, data []byte) (n int, err error) {
if extendedWriter, isExtended := writer.(N.ExtendedWriter); isExtended { if extendedWriter, isExtended := writer.(N.ExtendedWriter); isExtended {
return WriteBuffer(extendedWriter, buf.As(data)) return WriteBuffer(extendedWriter, buf.As(data))

View file

@ -30,6 +30,14 @@ func NewNATPacketConn(conn N.NetPacketConn, origin M.Socksaddr, destination M.So
} }
} }
func NewDestinationNATPacketConn(conn N.NetPacketConn, origin M.Socksaddr, destination M.Socksaddr) NATPacketConn {
return &destinationNATPacketConn{
NetPacketConn: conn,
origin: origin,
destination: destination,
}
}
type unidirectionalNATPacketConn struct { type unidirectionalNATPacketConn struct {
N.NetPacketConn N.NetPacketConn
origin M.Socksaddr origin M.Socksaddr
@ -144,6 +152,60 @@ func (c *bidirectionalNATPacketConn) RemoteAddr() net.Addr {
return c.destination.UDPAddr() return c.destination.UDPAddr()
} }
type destinationNATPacketConn struct {
N.NetPacketConn
origin M.Socksaddr
destination M.Socksaddr
}
func (c *destinationNATPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
n, addr, err = c.NetPacketConn.ReadFrom(p)
if err != nil {
return
}
if M.SocksaddrFromNet(addr) == c.origin {
addr = c.destination.UDPAddr()
}
return
}
func (c *destinationNATPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
if M.SocksaddrFromNet(addr) == c.destination {
addr = c.origin.UDPAddr()
}
return c.NetPacketConn.WriteTo(p, addr)
}
func (c *destinationNATPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
destination, err = c.NetPacketConn.ReadPacket(buffer)
if err != nil {
return
}
if destination == c.origin {
destination = c.destination
}
return
}
func (c *destinationNATPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
if destination == c.destination {
destination = c.origin
}
return c.NetPacketConn.WritePacket(buffer, destination)
}
func (c *destinationNATPacketConn) UpdateDestination(destinationAddress netip.Addr) {
c.destination = M.SocksaddrFrom(destinationAddress, c.destination.Port)
}
func (c *destinationNATPacketConn) Upstream() any {
return c.NetPacketConn
}
func (c *destinationNATPacketConn) RemoteAddr() net.Addr {
return c.destination.UDPAddr()
}
func socksaddrWithoutPort(destination M.Socksaddr) M.Socksaddr { func socksaddrWithoutPort(destination M.Socksaddr) M.Socksaddr {
destination.Port = 0 destination.Port = 0
return destination return destination

View file

@ -36,7 +36,7 @@ func TCPPipe(t *testing.T) (net.Conn, net.Conn) {
clientConn, clientErr = net.Dial("tcp", listener.Addr().String()) clientConn, clientErr = net.Dial("tcp", listener.Addr().String())
return clientErr return clientErr
}) })
err = group.Run() err = group.Run(context.Background())
require.NoError(t, err) require.NoError(t, err)
listener.Close() listener.Close()
t.Cleanup(func() { t.Cleanup(func() {

View file

@ -0,0 +1,5 @@
package bufio
//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go syscall_windows.go
//sys recv(s windows.Handle, buf []byte, flags int32) (n int32, err error) [failretval == -1] = ws2_32.recv

View file

@ -38,7 +38,6 @@ func (w *SyscallVectorisedWriter) WriteVectorised(buffers []*buf.Buffer) error {
var innerErr unix.Errno var innerErr unix.Errno
err := w.rawConn.Write(func(fd uintptr) (done bool) { err := w.rawConn.Write(func(fd uintptr) (done bool) {
//nolint:staticcheck //nolint:staticcheck
//goland:noinspection GoDeprecation
_, _, innerErr = unix.Syscall(unix.SYS_WRITEV, fd, uintptr(unsafe.Pointer(&iovecList[0])), uintptr(len(iovecList))) _, _, innerErr = unix.Syscall(unix.SYS_WRITEV, fd, uintptr(unsafe.Pointer(&iovecList[0])), uintptr(len(iovecList)))
return innerErr != unix.EAGAIN && innerErr != unix.EWOULDBLOCK return innerErr != unix.EAGAIN && innerErr != unix.EWOULDBLOCK
}) })

View file

@ -0,0 +1,57 @@
// Code generated by 'go generate'; DO NOT EDIT.
package bufio
import (
"syscall"
"unsafe"
"golang.org/x/sys/windows"
)
var _ unsafe.Pointer
// Do the interface allocations only once for common
// Errno values.
const (
errnoERROR_IO_PENDING = 997
)
var (
errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
errERROR_EINVAL error = syscall.EINVAL
)
// errnoErr returns common boxed Errno values, to prevent
// allocations at runtime.
func errnoErr(e syscall.Errno) error {
switch e {
case 0:
return errERROR_EINVAL
case errnoERROR_IO_PENDING:
return errERROR_IO_PENDING
}
// TODO: add more here, after collecting data on the common
// error values see on Windows. (perhaps when running
// all.bat?)
return e
}
var (
modws2_32 = windows.NewLazySystemDLL("ws2_32.dll")
procrecv = modws2_32.NewProc("recv")
)
func recv(s windows.Handle, buf []byte, flags int32) (n int32, err error) {
var _p0 *byte
if len(buf) > 0 {
_p0 = &buf[0]
}
r0, _, e1 := syscall.Syscall6(procrecv.Addr(), 4, uintptr(s), uintptr(unsafe.Pointer(_p0)), uintptr(len(buf)), uintptr(flags), 0, 0)
n = int32(r0)
if n == -1 {
err = errnoErr(e1)
}
return
}

View file

@ -41,9 +41,9 @@ func (i *Instance) Timeout() time.Duration {
return i.timeout return i.timeout
} }
func (i *Instance) SetTimeout(timeout time.Duration) { func (i *Instance) SetTimeout(timeout time.Duration) bool {
i.timeout = timeout i.timeout = timeout
i.Update() return i.Update()
} }
func (i *Instance) wait() { func (i *Instance) wait() {

View file

@ -13,7 +13,7 @@ import (
type PacketConn interface { type PacketConn interface {
N.PacketConn N.PacketConn
Timeout() time.Duration Timeout() time.Duration
SetTimeout(timeout time.Duration) SetTimeout(timeout time.Duration) bool
} }
type TimerPacketConn struct { type TimerPacketConn struct {
@ -24,10 +24,12 @@ type TimerPacketConn struct {
func NewPacketConn(ctx context.Context, conn N.PacketConn, timeout time.Duration) (context.Context, N.PacketConn) { func NewPacketConn(ctx context.Context, conn N.PacketConn, timeout time.Duration) (context.Context, N.PacketConn) {
if timeoutConn, isTimeoutConn := common.Cast[PacketConn](conn); isTimeoutConn { if timeoutConn, isTimeoutConn := common.Cast[PacketConn](conn); isTimeoutConn {
oldTimeout := timeoutConn.Timeout() oldTimeout := timeoutConn.Timeout()
if timeout < oldTimeout { if oldTimeout > 0 && timeout >= oldTimeout {
timeoutConn.SetTimeout(timeout) return ctx, conn
}
if timeoutConn.SetTimeout(timeout) {
return ctx, conn
} }
return ctx, conn
} }
err := conn.SetReadDeadline(time.Time{}) err := conn.SetReadDeadline(time.Time{})
if err == nil { if err == nil {
@ -58,8 +60,8 @@ func (c *TimerPacketConn) Timeout() time.Duration {
return c.instance.Timeout() return c.instance.Timeout()
} }
func (c *TimerPacketConn) SetTimeout(timeout time.Duration) { func (c *TimerPacketConn) SetTimeout(timeout time.Duration) bool {
c.instance.SetTimeout(timeout) return c.instance.SetTimeout(timeout)
} }
func (c *TimerPacketConn) Close() error { func (c *TimerPacketConn) Close() error {

View file

@ -61,9 +61,9 @@ func (c *TimeoutPacketConn) Timeout() time.Duration {
return c.timeout return c.timeout
} }
func (c *TimeoutPacketConn) SetTimeout(timeout time.Duration) { func (c *TimeoutPacketConn) SetTimeout(timeout time.Duration) bool {
c.timeout = timeout c.timeout = timeout
c.PacketConn.SetReadDeadline(time.Now()) return c.PacketConn.SetReadDeadline(time.Now()) == nil
} }
func (c *TimeoutPacketConn) Close() error { func (c *TimeoutPacketConn) Close() error {

View file

@ -157,6 +157,18 @@ func IndexIndexed[T any](arr []T, block func(index int, it T) bool) int {
return -1 return -1
} }
func Equal[S ~[]E, E comparable](s1, s2 S) bool {
if len(s1) != len(s2) {
return false
}
for i := range s1 {
if s1[i] != s2[i] {
return false
}
}
return true
}
//go:norace //go:norace
func Dup[T any](obj T) T { func Dup[T any](obj T) T {
pointer := uintptr(unsafe.Pointer(&obj)) pointer := uintptr(unsafe.Pointer(&obj))
@ -268,6 +280,14 @@ func Reverse[T any](arr []T) []T {
return arr return arr
} }
func ReverseMap[K comparable, V comparable](m map[K]V) map[V]K {
ret := make(map[V]K, len(m))
for k, v := range m {
ret[v] = k
}
return ret
}
func Done(ctx context.Context) bool { func Done(ctx context.Context) bool {
select { select {
case <-ctx.Done(): case <-ctx.Done():
@ -362,24 +382,3 @@ func Close(closers ...any) error {
} }
return retErr return retErr
} }
// Deprecated: wtf is this?
type Starter interface {
Start() error
}
// Deprecated: wtf is this?
func Start(starters ...any) error {
for _, rawStarter := range starters {
if rawStarter == nil {
continue
}
if starter, isStarter := rawStarter.(Starter); isStarter {
err := starter.Start()
if err != nil {
return err
}
}
}
return nil
}

View file

@ -5,6 +5,7 @@ import (
"reflect" "reflect"
) )
// Deprecated: not used
func SelectContext(contextList []context.Context) (int, error) { func SelectContext(contextList []context.Context) (int, error) {
if len(contextList) == 1 { if len(contextList) == 1 {
<-contextList[0].Done() <-contextList[0].Done()

View file

@ -9,15 +9,15 @@ import (
func bindToInterface(conn syscall.RawConn, network string, address string, finder InterfaceFinder, interfaceName string, interfaceIndex int, preferInterfaceName bool) error { func bindToInterface(conn syscall.RawConn, network string, address string, finder InterfaceFinder, interfaceName string, interfaceIndex int, preferInterfaceName bool) error {
return Raw(conn, func(fd uintptr) error { return Raw(conn, func(fd uintptr) error {
var err error
if interfaceIndex == -1 { if interfaceIndex == -1 {
if finder == nil { if finder == nil {
return os.ErrInvalid return os.ErrInvalid
} }
interfaceIndex, err = finder.InterfaceIndexByName(interfaceName) iif, err := finder.ByName(interfaceName)
if err != nil { if err != nil {
return err return err
} }
interfaceIndex = iif.Index
} }
switch network { switch network {
case "tcp6", "udp6": case "tcp6", "udp6":

View file

@ -3,19 +3,57 @@ package control
import ( import (
"net" "net"
"net/netip" "net/netip"
"unsafe"
"github.com/sagernet/sing/common"
M "github.com/sagernet/sing/common/metadata"
) )
type InterfaceFinder interface { type InterfaceFinder interface {
Update() error
Interfaces() []Interface Interfaces() []Interface
InterfaceIndexByName(name string) (int, error) ByName(name string) (*Interface, error)
InterfaceNameByIndex(index int) (string, error) ByIndex(index int) (*Interface, error)
InterfaceByAddr(addr netip.Addr) (*Interface, error) ByAddr(addr netip.Addr) (*Interface, error)
} }
type Interface struct { type Interface struct {
Index int Index int
MTU int MTU int
Name string Name string
Addresses []netip.Prefix
HardwareAddr net.HardwareAddr HardwareAddr net.HardwareAddr
Flags net.Flags
Addresses []netip.Prefix
}
func (i Interface) Equals(other Interface) bool {
return i.Index == other.Index &&
i.MTU == other.MTU &&
i.Name == other.Name &&
common.Equal(i.HardwareAddr, other.HardwareAddr) &&
i.Flags == other.Flags &&
common.Equal(i.Addresses, other.Addresses)
}
func (i Interface) NetInterface() net.Interface {
return *(*net.Interface)(unsafe.Pointer(&i))
}
func InterfaceFromNet(iif net.Interface) (Interface, error) {
ifAddrs, err := iif.Addrs()
if err != nil {
return Interface{}, err
}
return InterfaceFromNetAddrs(iif, common.Map(ifAddrs, M.PrefixFromNet)), nil
}
func InterfaceFromNetAddrs(iif net.Interface, addresses []netip.Prefix) Interface {
return Interface{
Index: iif.Index,
MTU: iif.MTU,
Name: iif.Name,
HardwareAddr: iif.HardwareAddr,
Flags: iif.Flags,
Addresses: addresses,
}
} }

View file

@ -3,11 +3,8 @@ package control
import ( import (
"net" "net"
"net/netip" "net/netip"
_ "unsafe"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
) )
var _ InterfaceFinder = (*DefaultInterfaceFinder)(nil) var _ InterfaceFinder = (*DefaultInterfaceFinder)(nil)
@ -27,17 +24,12 @@ func (f *DefaultInterfaceFinder) Update() error {
} }
interfaces := make([]Interface, 0, len(netIfs)) interfaces := make([]Interface, 0, len(netIfs))
for _, netIf := range netIfs { for _, netIf := range netIfs {
ifAddrs, err := netIf.Addrs() var iif Interface
iif, err = InterfaceFromNet(netIf)
if err != nil { if err != nil {
return err return err
} }
interfaces = append(interfaces, Interface{ interfaces = append(interfaces, iif)
Index: netIf.Index,
MTU: netIf.MTU,
Name: netIf.Name,
Addresses: common.Map(ifAddrs, M.PrefixFromNet),
HardwareAddr: netIf.HardwareAddr,
})
} }
f.interfaces = interfaces f.interfaces = interfaces
return nil return nil
@ -51,46 +43,41 @@ func (f *DefaultInterfaceFinder) Interfaces() []Interface {
return f.interfaces return f.interfaces
} }
func (f *DefaultInterfaceFinder) InterfaceIndexByName(name string) (int, error) { func (f *DefaultInterfaceFinder) ByName(name string) (*Interface, error) {
for _, netInterface := range f.interfaces { for _, netInterface := range f.interfaces {
if netInterface.Name == name { if netInterface.Name == name {
return netInterface.Index, nil return &netInterface, nil
} }
} }
netInterface, err := net.InterfaceByName(name) _, err := net.InterfaceByName(name)
if err != nil { if err == nil {
return 0, err err = f.Update()
if err != nil {
return nil, err
}
return f.ByName(name)
} }
f.Update() return nil, &net.OpError{Op: "route", Net: "ip+net", Source: nil, Addr: &net.IPAddr{IP: nil}, Err: E.New("no such network interface")}
return netInterface.Index, nil
} }
func (f *DefaultInterfaceFinder) InterfaceNameByIndex(index int) (string, error) { func (f *DefaultInterfaceFinder) ByIndex(index int) (*Interface, error) {
for _, netInterface := range f.interfaces { for _, netInterface := range f.interfaces {
if netInterface.Index == index { if netInterface.Index == index {
return netInterface.Name, nil return &netInterface, nil
} }
} }
netInterface, err := net.InterfaceByIndex(index) _, err := net.InterfaceByIndex(index)
if err != nil { if err == nil {
return "", err err = f.Update()
if err != nil {
return nil, err
}
return f.ByIndex(index)
} }
f.Update() return nil, &net.OpError{Op: "route", Net: "ip+net", Source: nil, Addr: &net.IPAddr{IP: nil}, Err: E.New("no such network interface")}
return netInterface.Name, nil
} }
func (f *DefaultInterfaceFinder) InterfaceByAddr(addr netip.Addr) (*Interface, error) { func (f *DefaultInterfaceFinder) ByAddr(addr netip.Addr) (*Interface, error) {
for _, netInterface := range f.interfaces {
for _, prefix := range netInterface.Addresses {
if prefix.Contains(addr) {
return &netInterface, nil
}
}
}
err := f.Update()
if err != nil {
return nil, err
}
for _, netInterface := range f.interfaces { for _, netInterface := range f.interfaces {
for _, prefix := range netInterface.Addresses { for _, prefix := range netInterface.Addresses {
if prefix.Contains(addr) { if prefix.Contains(addr) {

View file

@ -19,11 +19,11 @@ func bindToInterface(conn syscall.RawConn, network string, address string, finde
if interfaceName == "" { if interfaceName == "" {
return os.ErrInvalid return os.ErrInvalid
} }
var err error iif, err := finder.ByName(interfaceName)
interfaceIndex, err = finder.InterfaceIndexByName(interfaceName)
if err != nil { if err != nil {
return err return err
} }
interfaceIndex = iif.Index
} }
err := unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_BINDTOIFINDEX, interfaceIndex) err := unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_BINDTOIFINDEX, interfaceIndex)
if err == nil { if err == nil {

View file

@ -11,19 +11,19 @@ import (
func bindToInterface(conn syscall.RawConn, network string, address string, finder InterfaceFinder, interfaceName string, interfaceIndex int, preferInterfaceName bool) error { func bindToInterface(conn syscall.RawConn, network string, address string, finder InterfaceFinder, interfaceName string, interfaceIndex int, preferInterfaceName bool) error {
return Raw(conn, func(fd uintptr) error { return Raw(conn, func(fd uintptr) error {
var err error
if interfaceIndex == -1 { if interfaceIndex == -1 {
if finder == nil { if finder == nil {
return os.ErrInvalid return os.ErrInvalid
} }
interfaceIndex, err = finder.InterfaceIndexByName(interfaceName) iif, err := finder.ByName(interfaceName)
if err != nil { if err != nil {
return err return err
} }
interfaceIndex = iif.Index
} }
handle := syscall.Handle(fd) handle := syscall.Handle(fd)
if M.ParseSocksaddr(address).AddrString() == "" { if M.ParseSocksaddr(address).AddrString() == "" {
err = bind4(handle, interfaceIndex) err := bind4(handle, interfaceIndex)
if err != nil { if err != nil {
return err return err
} }

View file

@ -4,19 +4,26 @@ import (
"os" "os"
"syscall" "syscall"
N "github.com/sagernet/sing/common/network"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
func DisableUDPFragment() Func { func DisableUDPFragment() Func {
return func(network, address string, conn syscall.RawConn) error { return func(network, address string, conn syscall.RawConn) error {
if N.NetworkName(network) != N.NetworkUDP {
return nil
}
return Raw(conn, func(fd uintptr) error { return Raw(conn, func(fd uintptr) error {
switch network { if network == "udp" || network == "udp4" {
case "udp4": err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_DONTFRAG, 1)
if err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_DONTFRAG, 1); err != nil { if err != nil {
return os.NewSyscallError("SETSOCKOPT IP_DONTFRAG", err) return os.NewSyscallError("SETSOCKOPT IP_DONTFRAG", err)
} }
case "udp6": }
if err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_DONTFRAG, 1); err != nil { if network == "udp" || network == "udp6" {
err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_DONTFRAG, 1)
if err != nil {
return os.NewSyscallError("SETSOCKOPT IPV6_DONTFRAG", err) return os.NewSyscallError("SETSOCKOPT IPV6_DONTFRAG", err)
} }
} }

View file

@ -11,17 +11,19 @@ import (
func DisableUDPFragment() Func { func DisableUDPFragment() Func {
return func(network, address string, conn syscall.RawConn) error { return func(network, address string, conn syscall.RawConn) error {
switch N.NetworkName(network) { if N.NetworkName(network) != N.NetworkUDP {
case N.NetworkUDP:
default:
return nil return nil
} }
return Raw(conn, func(fd uintptr) error { return Raw(conn, func(fd uintptr) error {
if err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_MTU_DISCOVER, unix.IP_PMTUDISC_DO); err != nil { if network == "udp" || network == "udp4" {
return os.NewSyscallError("SETSOCKOPT IP_MTU_DISCOVER IP_PMTUDISC_DO", err) err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_MTU_DISCOVER, unix.IP_PMTUDISC_DO)
if err != nil {
return os.NewSyscallError("SETSOCKOPT IP_MTU_DISCOVER IP_PMTUDISC_DO", err)
}
} }
if network == "udp6" { if network == "udp" || network == "udp6" {
if err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_MTU_DISCOVER, unix.IP_PMTUDISC_DO); err != nil { err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_MTU_DISCOVER, unix.IP_PMTUDISC_DO)
if err != nil {
return os.NewSyscallError("SETSOCKOPT IPV6_MTU_DISCOVER IP_PMTUDISC_DO", err) return os.NewSyscallError("SETSOCKOPT IPV6_MTU_DISCOVER IP_PMTUDISC_DO", err)
} }
} }

View file

@ -25,17 +25,19 @@ const (
func DisableUDPFragment() Func { func DisableUDPFragment() Func {
return func(network, address string, conn syscall.RawConn) error { return func(network, address string, conn syscall.RawConn) error {
switch N.NetworkName(network) { if N.NetworkName(network) != N.NetworkUDP {
case N.NetworkUDP:
default:
return nil return nil
} }
return Raw(conn, func(fd uintptr) error { return Raw(conn, func(fd uintptr) error {
if err := windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IP, IP_MTU_DISCOVER, IP_PMTUDISC_DO); err != nil { if network == "udp" || network == "udp4" {
return os.NewSyscallError("SETSOCKOPT IP_MTU_DISCOVER IP_PMTUDISC_DO", err) err := windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IP, IP_MTU_DISCOVER, IP_PMTUDISC_DO)
if err != nil {
return os.NewSyscallError("SETSOCKOPT IP_MTU_DISCOVER IP_PMTUDISC_DO", err)
}
} }
if network == "udp6" { if network == "udp" || network == "udp6" {
if err := windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IPV6, IPV6_MTU_DISCOVER, IP_PMTUDISC_DO); err != nil { err := windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IPV6, IPV6_MTU_DISCOVER, IP_PMTUDISC_DO)
if err != nil {
return os.NewSyscallError("SETSOCKOPT IPV6_MTU_DISCOVER IP_PMTUDISC_DO", err) return os.NewSyscallError("SETSOCKOPT IPV6_MTU_DISCOVER IP_PMTUDISC_DO", err)
} }
} }

View file

@ -0,0 +1,67 @@
package domain_test
import (
"sort"
"testing"
"github.com/sagernet/sing/common/domain"
"github.com/stretchr/testify/require"
)
func TestAdGuardMatcher(t *testing.T) {
t.Parallel()
ruleLines := []string{
"||example.org^",
"|example.com^",
"example.net^",
"||example.edu",
"||example.edu.tw^",
"|example.gov",
"example.arpa",
}
matcher := domain.NewAdGuardMatcher(ruleLines)
require.NotNil(t, matcher)
matchDomain := []string{
"example.org",
"www.example.org",
"example.com",
"example.net",
"isexample.net",
"www.example.net",
"example.edu",
"example.edu.cn",
"example.edu.tw",
"www.example.edu",
"www.example.edu.cn",
"example.gov",
"example.gov.cn",
"example.arpa",
"www.example.arpa",
"isexample.arpa",
"example.arpa.cn",
"www.example.arpa.cn",
"isexample.arpa.cn",
}
notMatchDomain := []string{
"example.org.cn",
"notexample.org",
"example.com.cn",
"www.example.com.cn",
"example.net.cn",
"notexample.edu",
"notexample.edu.cn",
"www.example.gov",
"notexample.gov",
}
for _, domain := range matchDomain {
require.True(t, matcher.Match(domain), domain)
}
for _, domain := range notMatchDomain {
require.False(t, matcher.Match(domain), domain)
}
dLines := matcher.Dump()
sort.Strings(ruleLines)
sort.Strings(dLines)
require.Equal(t, ruleLines, dLines)
}

View file

@ -0,0 +1,172 @@
package domain
import (
"bytes"
"sort"
"strings"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/varbin"
)
const (
anyLabel = '*'
suffixLabel = '\b'
)
type AdGuardMatcher struct {
set *succinctSet
}
func NewAdGuardMatcher(ruleLines []string) *AdGuardMatcher {
ruleList := make([]string, 0, len(ruleLines))
for _, ruleLine := range ruleLines {
var (
isSuffix bool // ||
hasStart bool // |
hasEnd bool // ^
)
if strings.HasPrefix(ruleLine, "||") {
ruleLine = ruleLine[2:]
isSuffix = true
} else if strings.HasPrefix(ruleLine, "|") {
ruleLine = ruleLine[1:]
hasStart = true
}
if strings.HasSuffix(ruleLine, "^") {
ruleLine = ruleLine[:len(ruleLine)-1]
hasEnd = true
}
if isSuffix {
ruleLine = string(rootLabel) + ruleLine
} else if !hasStart {
ruleLine = string(prefixLabel) + ruleLine
}
if !hasEnd {
if strings.HasSuffix(ruleLine, ".") {
ruleLine = ruleLine[:len(ruleLine)-1]
}
ruleLine += string(suffixLabel)
}
ruleList = append(ruleList, reverseDomain(ruleLine))
}
ruleList = common.Uniq(ruleList)
sort.Strings(ruleList)
return &AdGuardMatcher{newSuccinctSet(ruleList)}
}
func ReadAdGuardMatcher(reader varbin.Reader) (*AdGuardMatcher, error) {
set, err := readSuccinctSet(reader)
if err != nil {
return nil, err
}
return &AdGuardMatcher{set}, nil
}
func (m *AdGuardMatcher) Write(writer varbin.Writer) error {
return m.set.Write(writer)
}
func (m *AdGuardMatcher) Match(domain string) bool {
key := reverseDomain(domain)
if m.has([]byte(key), 0, 0) {
return true
}
for {
if m.has([]byte(string(suffixLabel)+key), 0, 0) {
return true
}
idx := strings.IndexByte(key, '.')
if idx == -1 {
return false
}
key = key[idx+1:]
}
}
func (m *AdGuardMatcher) has(key []byte, nodeId, bmIdx int) bool {
for i := 0; i < len(key); i++ {
currentChar := key[i]
for ; ; bmIdx++ {
if getBit(m.set.labelBitmap, bmIdx) != 0 {
return false
}
nextLabel := m.set.labels[bmIdx-nodeId]
if nextLabel == prefixLabel {
return true
}
if nextLabel == rootLabel {
nextNodeId := countZeros(m.set.labelBitmap, m.set.ranks, bmIdx+1)
hasNext := getBit(m.set.leaves, nextNodeId) != 0
if currentChar == '.' && hasNext {
return true
}
}
if nextLabel == currentChar {
break
}
if nextLabel == anyLabel {
idx := bytes.IndexRune(key[i:], '.')
nextNodeId := countZeros(m.set.labelBitmap, m.set.ranks, bmIdx+1)
if idx == -1 {
if getBit(m.set.leaves, nextNodeId) != 0 {
return true
}
idx = 0
}
nextBmIdx := selectIthOne(m.set.labelBitmap, m.set.ranks, m.set.selects, nextNodeId-1) + 1
if m.has(key[i+idx:], nextNodeId, nextBmIdx) {
return true
}
}
}
nodeId = countZeros(m.set.labelBitmap, m.set.ranks, bmIdx+1)
bmIdx = selectIthOne(m.set.labelBitmap, m.set.ranks, m.set.selects, nodeId-1) + 1
}
if getBit(m.set.leaves, nodeId) != 0 {
return true
}
for ; ; bmIdx++ {
if getBit(m.set.labelBitmap, bmIdx) != 0 {
return false
}
nextLabel := m.set.labels[bmIdx-nodeId]
if nextLabel == prefixLabel || nextLabel == rootLabel {
return true
}
}
}
func (m *AdGuardMatcher) Dump() (ruleLines []string) {
for _, key := range m.set.keys() {
key = reverseDomain(key)
var (
isSuffix bool
hasStart bool
hasEnd bool
)
if key[0] == prefixLabel {
key = key[1:]
} else if key[0] == rootLabel {
key = key[1:]
isSuffix = true
} else {
hasStart = true
}
if key[len(key)-1] == suffixLabel {
key = key[:len(key)-1]
} else {
hasEnd = true
}
if isSuffix {
key = "||" + key
} else if hasStart {
key = "|" + key
}
if hasEnd {
key += "^"
}
ruleLines = append(ruleLines, key)
}
return
}

View file

@ -1,13 +1,17 @@
package domain package domain
import ( import (
"encoding/binary"
"sort" "sort"
"unicode/utf8" "unicode/utf8"
"github.com/sagernet/sing/common/varbin" "github.com/sagernet/sing/common/varbin"
) )
const (
prefixLabel = '\r'
rootLabel = '\n'
)
type Matcher struct { type Matcher struct {
set *succinctSet set *succinctSet
} }
@ -21,16 +25,16 @@ func NewMatcher(domains []string, domainSuffix []string, generateLegacy bool) *M
} }
seen[domain] = true seen[domain] = true
if domain[0] == '.' { if domain[0] == '.' {
domainList = append(domainList, reverseDomainSuffix(domain)) domainList = append(domainList, reverseDomain(string(prefixLabel)+domain))
} else if generateLegacy { } else if generateLegacy {
domainList = append(domainList, reverseDomain(domain)) domainList = append(domainList, reverseDomain(domain))
suffixDomain := "." + domain suffixDomain := "." + domain
if !seen[suffixDomain] { if !seen[suffixDomain] {
seen[suffixDomain] = true seen[suffixDomain] = true
domainList = append(domainList, reverseDomainSuffix(suffixDomain)) domainList = append(domainList, reverseDomain(string(prefixLabel)+suffixDomain))
} }
} else { } else {
domainList = append(domainList, reverseDomainRoot(domain)) domainList = append(domainList, reverseDomain(string(rootLabel)+domain))
} }
} }
for _, domain := range domains { for _, domain := range domains {
@ -44,38 +48,60 @@ func NewMatcher(domains []string, domainSuffix []string, generateLegacy bool) *M
return &Matcher{newSuccinctSet(domainList)} return &Matcher{newSuccinctSet(domainList)}
} }
type matcherData struct {
Version uint8
Leaves []uint64
LabelBitmap []uint64
Labels []byte
}
func ReadMatcher(reader varbin.Reader) (*Matcher, error) { func ReadMatcher(reader varbin.Reader) (*Matcher, error) {
matcher, err := varbin.ReadValue[matcherData](reader, binary.BigEndian) set, err := readSuccinctSet(reader)
if err != nil { if err != nil {
return nil, err return nil, err
} }
set := &succinctSet{
leaves: matcher.Leaves,
labelBitmap: matcher.LabelBitmap,
labels: matcher.Labels,
}
set.init()
return &Matcher{set}, nil return &Matcher{set}, nil
} }
func (m *Matcher) Match(domain string) bool { func (m *Matcher) Write(writer varbin.Writer) error {
return m.set.Has(reverseDomain(domain)) return m.set.Write(writer)
} }
func (m *Matcher) Write(writer varbin.Writer) error { func (m *Matcher) Match(domain string) bool {
return varbin.Write(writer, binary.BigEndian, matcherData{ return m.has(reverseDomain(domain))
Version: 1, }
Leaves: m.set.leaves,
LabelBitmap: m.set.labelBitmap, func (m *Matcher) has(key string) bool {
Labels: m.set.labels, var nodeId, bmIdx int
}) for i := 0; i < len(key); i++ {
currentChar := key[i]
for ; ; bmIdx++ {
if getBit(m.set.labelBitmap, bmIdx) != 0 {
return false
}
nextLabel := m.set.labels[bmIdx-nodeId]
if nextLabel == prefixLabel {
return true
}
if nextLabel == rootLabel {
nextNodeId := countZeros(m.set.labelBitmap, m.set.ranks, bmIdx+1)
hasNext := getBit(m.set.leaves, nextNodeId) != 0
if currentChar == '.' && hasNext {
return true
}
}
if nextLabel == currentChar {
break
}
}
nodeId = countZeros(m.set.labelBitmap, m.set.ranks, bmIdx+1)
bmIdx = selectIthOne(m.set.labelBitmap, m.set.ranks, m.set.selects, nodeId-1) + 1
}
if getBit(m.set.leaves, nodeId) != 0 {
return true
}
for ; ; bmIdx++ {
if getBit(m.set.labelBitmap, bmIdx) != 0 {
return false
}
nextLabel := m.set.labels[bmIdx-nodeId]
if nextLabel == prefixLabel || nextLabel == rootLabel {
return true
}
}
} }
func (m *Matcher) Dump() (domainList []string, prefixList []string) { func (m *Matcher) Dump() (domainList []string, prefixList []string) {
@ -119,27 +145,3 @@ func reverseDomain(domain string) string {
} }
return string(b) return string(b)
} }
func reverseDomainSuffix(domain string) string {
l := len(domain)
b := make([]byte, l+1)
for i := 0; i < l; {
r, n := utf8.DecodeRuneInString(domain[i:])
i += n
utf8.EncodeRune(b[l-i:], r)
}
b[l] = prefixLabel
return string(b)
}
func reverseDomainRoot(domain string) string {
l := len(domain)
b := make([]byte, l+1)
for i := 0; i < l; {
r, n := utf8.DecodeRuneInString(domain[i:])
i += n
utf8.EncodeRune(b[l-i:], r)
}
b[l] = rootLabel
return string(b)
}

View file

@ -12,6 +12,7 @@ import (
) )
func TestMatcher(t *testing.T) { func TestMatcher(t *testing.T) {
t.Parallel()
testDomain := []string{"example.com", "example.org"} testDomain := []string{"example.com", "example.org"}
testDomainSuffix := []string{".com.cn", ".org.cn", "sagernet.org"} testDomainSuffix := []string{".com.cn", ".org.cn", "sagernet.org"}
matcher := domain.NewMatcher(testDomain, testDomainSuffix, false) matcher := domain.NewMatcher(testDomain, testDomainSuffix, false)
@ -31,6 +32,7 @@ func TestMatcher(t *testing.T) {
} }
func TestMatcherLegacy(t *testing.T) { func TestMatcherLegacy(t *testing.T) {
t.Parallel()
testDomain := []string{"example.com", "example.org"} testDomain := []string{"example.com", "example.org"}
testDomainSuffix := []string{".com.cn", ".org.cn", "sagernet.org"} testDomainSuffix := []string{".com.cn", ".org.cn", "sagernet.org"}
matcher := domain.NewMatcher(testDomain, testDomainSuffix, true) matcher := domain.NewMatcher(testDomain, testDomainSuffix, true)
@ -57,6 +59,7 @@ type simpleRuleSet struct {
} }
func TestDumpLarge(t *testing.T) { func TestDumpLarge(t *testing.T) {
t.Parallel()
response, err := http.Get("https://raw.githubusercontent.com/MetaCubeX/meta-rules-dat/sing/geo/geosite/cn.json") response, err := http.Get("https://raw.githubusercontent.com/MetaCubeX/meta-rules-dat/sing/geo/geosite/cn.json")
require.NoError(t, err) require.NoError(t, err)
defer response.Body.Close() defer response.Body.Close()

View file

@ -1,12 +1,10 @@
package domain package domain
import ( import (
"encoding/binary"
"math/bits" "math/bits"
)
const ( "github.com/sagernet/sing/common/varbin"
prefixLabel = '\r'
rootLabel = '\n'
) )
// mod from https://github.com/openacid/succinct // mod from https://github.com/openacid/succinct
@ -45,46 +43,6 @@ func newSuccinctSet(keys []string) *succinctSet {
return ss return ss
} }
func (ss *succinctSet) Has(key string) bool {
var nodeId, bmIdx int
for i := 0; i < len(key); i++ {
currentChar := key[i]
for ; ; bmIdx++ {
if getBit(ss.labelBitmap, bmIdx) != 0 {
return false
}
nextLabel := ss.labels[bmIdx-nodeId]
if nextLabel == prefixLabel {
return true
}
if nextLabel == rootLabel {
nextNodeId := countZeros(ss.labelBitmap, ss.ranks, bmIdx+1)
hasNext := getBit(ss.leaves, nextNodeId) != 0
if currentChar == '.' && hasNext {
return true
}
}
if nextLabel == currentChar {
break
}
}
nodeId = countZeros(ss.labelBitmap, ss.ranks, bmIdx+1)
bmIdx = selectIthOne(ss.labelBitmap, ss.ranks, ss.selects, nodeId-1) + 1
}
if getBit(ss.leaves, nodeId) != 0 {
return true
}
for ; ; bmIdx++ {
if getBit(ss.labelBitmap, bmIdx) != 0 {
return false
}
nextLabel := ss.labels[bmIdx-nodeId]
if nextLabel == prefixLabel || nextLabel == rootLabel {
return true
}
}
}
func (ss *succinctSet) keys() []string { func (ss *succinctSet) keys() []string {
var result []string var result []string
var currentKey []byte var currentKey []byte
@ -113,6 +71,35 @@ func (ss *succinctSet) keys() []string {
return result return result
} }
type succinctSetData struct {
Reserved uint8
Leaves []uint64
LabelBitmap []uint64
Labels []byte
}
func readSuccinctSet(reader varbin.Reader) (*succinctSet, error) {
matcher, err := varbin.ReadValue[succinctSetData](reader, binary.BigEndian)
if err != nil {
return nil, err
}
set := &succinctSet{
leaves: matcher.Leaves,
labelBitmap: matcher.LabelBitmap,
labels: matcher.Labels,
}
set.init()
return set, nil
}
func (ss *succinctSet) Write(writer varbin.Writer) error {
return varbin.Write(writer, binary.BigEndian, succinctSetData{
Leaves: ss.leaves,
LabelBitmap: ss.labelBitmap,
Labels: ss.labels,
})
}
func setBit(bm *[]uint64, i int, v int) { func setBit(bm *[]uint64, i int, v int) {
for i>>6 >= len(*bm) { for i>>6 >= len(*bm) {
*bm = append(*bm, 0) *bm = append(*bm, 0)

View file

@ -12,3 +12,16 @@ func (e *causeError) Error() string {
func (e *causeError) Unwrap() error { func (e *causeError) Unwrap() error {
return e.cause return e.cause
} }
type causeError1 struct {
error
cause error
}
func (e *causeError1) Error() string {
return e.error.Error() + ": " + e.cause.Error()
}
func (e *causeError1) Unwrap() []error {
return []error{e.error, e.cause}
}

View file

@ -12,6 +12,7 @@ import (
F "github.com/sagernet/sing/common/format" F "github.com/sagernet/sing/common/format"
) )
// Deprecated: wtf is this?
type Handler interface { type Handler interface {
NewError(ctx context.Context, err error) NewError(ctx context.Context, err error)
} }
@ -31,6 +32,13 @@ func Cause(cause error, message ...any) error {
return &causeError{F.ToString(message...), cause} return &causeError{F.ToString(message...), cause}
} }
func Cause1(err error, cause error) error {
if cause == nil {
panic("cause on an nil error")
}
return &causeError1{err, cause}
}
func Extend(cause error, message ...any) error { func Extend(cause error, message ...any) error {
if cause == nil { if cause == nil {
panic("extend on an nil error") panic("extend on an nil error")
@ -39,11 +47,11 @@ func Extend(cause error, message ...any) error {
} }
func IsClosedOrCanceled(err error) bool { func IsClosedOrCanceled(err error) bool {
return IsMulti(err, io.EOF, net.ErrClosed, io.ErrClosedPipe, os.ErrClosed, syscall.EPIPE, syscall.ECONNRESET, context.Canceled, context.DeadlineExceeded) return IsClosed(err) || IsCanceled(err) || IsTimeout(err)
} }
func IsClosed(err error) bool { func IsClosed(err error) bool {
return IsMulti(err, io.EOF, net.ErrClosed, io.ErrClosedPipe, os.ErrClosed, syscall.EPIPE, syscall.ECONNRESET) return IsMulti(err, io.EOF, net.ErrClosed, io.ErrClosedPipe, os.ErrClosed, syscall.EPIPE, syscall.ECONNRESET, syscall.ENOTCONN)
} }
func IsCanceled(err error) bool { func IsCanceled(err error) bool {

View file

@ -1,24 +1,14 @@
package exceptions package exceptions
import "github.com/sagernet/sing/common" import (
"errors"
type HasInnerError interface { "github.com/sagernet/sing/common"
Unwrap() error )
}
// Deprecated: Use errors.Unwrap instead.
func Unwrap(err error) error { func Unwrap(err error) error {
for { return errors.Unwrap(err)
inner, ok := err.(HasInnerError)
if !ok {
break
}
innerErr := inner.Unwrap()
if innerErr == nil {
break
}
err = innerErr
}
return err
} }
func Cast[T any](err error) (T, bool) { func Cast[T any](err error) (T, bool) {

View file

@ -63,12 +63,5 @@ func IsMulti(err error, targetList ...error) bool {
return true return true
} }
} }
err = Unwrap(err) return false
multiErr, isMulti := err.(MultiError)
if !isMulti {
return false
}
return common.All(multiErr.Unwrap(), func(it error) bool {
return IsMulti(it, targetList...)
})
} }

View file

@ -1,17 +1,21 @@
package exceptions package exceptions
import "net" import (
"errors"
"net"
)
type TimeoutError interface { type TimeoutError interface {
Timeout() bool Timeout() bool
} }
func IsTimeout(err error) bool { func IsTimeout(err error) bool {
if netErr, isNetErr := err.(net.Error); isNetErr { var netErr net.Error
//goland:noinspection GoDeprecation if errors.As(err, &netErr) {
//nolint:staticcheck //nolint:staticcheck
return netErr.Temporary() && netErr.Timeout() return netErr.Temporary() && netErr.Timeout()
} else if timeoutErr, isTimeout := Cast[TimeoutError](err); isTimeout { }
if timeoutErr, isTimeout := Cast[TimeoutError](err); isTimeout {
return timeoutErr.Timeout() return timeoutErr.Timeout()
} }
return false return false

View file

@ -2,13 +2,14 @@ package badjson
import ( import (
"bytes" "bytes"
"context"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/json" "github.com/sagernet/sing/common/json"
) )
func Decode(content []byte) (any, error) { func Decode(ctx context.Context, content []byte) (any, error) {
decoder := json.NewDecoder(bytes.NewReader(content)) decoder := json.NewDecoderContext(ctx, bytes.NewReader(content))
return decodeJSON(decoder) return decodeJSON(decoder)
} }

View file

@ -1,6 +1,7 @@
package badjson package badjson
import ( import (
"context"
"os" "os"
"reflect" "reflect"
@ -9,75 +10,75 @@ import (
"github.com/sagernet/sing/common/json" "github.com/sagernet/sing/common/json"
) )
func Omitempty[T any](value T) (T, error) { func Omitempty[T any](ctx context.Context, value T) (T, error) {
objectContent, err := json.Marshal(value) objectContent, err := json.MarshalContext(ctx, value)
if err != nil { if err != nil {
return common.DefaultValue[T](), E.Cause(err, "marshal object") return common.DefaultValue[T](), E.Cause(err, "marshal object")
} }
rawNewObject, err := Decode(objectContent) rawNewObject, err := Decode(ctx, objectContent)
if err != nil { if err != nil {
return common.DefaultValue[T](), err return common.DefaultValue[T](), err
} }
newObjectContent, err := json.Marshal(rawNewObject) newObjectContent, err := json.MarshalContext(ctx, rawNewObject)
if err != nil { if err != nil {
return common.DefaultValue[T](), E.Cause(err, "marshal new object") return common.DefaultValue[T](), E.Cause(err, "marshal new object")
} }
var newObject T var newObject T
err = json.Unmarshal(newObjectContent, &newObject) err = json.UnmarshalContext(ctx, newObjectContent, &newObject)
if err != nil { if err != nil {
return common.DefaultValue[T](), E.Cause(err, "unmarshal new object") return common.DefaultValue[T](), E.Cause(err, "unmarshal new object")
} }
return newObject, nil return newObject, nil
} }
func Merge[T any](source T, destination T, disableAppend bool) (T, error) { func Merge[T any](ctx context.Context, source T, destination T, disableAppend bool) (T, error) {
rawSource, err := json.Marshal(source) rawSource, err := json.MarshalContext(ctx, source)
if err != nil { if err != nil {
return common.DefaultValue[T](), E.Cause(err, "marshal source") return common.DefaultValue[T](), E.Cause(err, "marshal source")
} }
rawDestination, err := json.Marshal(destination) rawDestination, err := json.MarshalContext(ctx, destination)
if err != nil { if err != nil {
return common.DefaultValue[T](), E.Cause(err, "marshal destination") return common.DefaultValue[T](), E.Cause(err, "marshal destination")
} }
return MergeFrom[T](rawSource, rawDestination, disableAppend) return MergeFrom[T](ctx, rawSource, rawDestination, disableAppend)
} }
func MergeFromSource[T any](rawSource json.RawMessage, destination T, disableAppend bool) (T, error) { func MergeFromSource[T any](ctx context.Context, rawSource json.RawMessage, destination T, disableAppend bool) (T, error) {
if rawSource == nil { if rawSource == nil {
return destination, nil return destination, nil
} }
rawDestination, err := json.Marshal(destination) rawDestination, err := json.MarshalContext(ctx, destination)
if err != nil { if err != nil {
return common.DefaultValue[T](), E.Cause(err, "marshal destination") return common.DefaultValue[T](), E.Cause(err, "marshal destination")
} }
return MergeFrom[T](rawSource, rawDestination, disableAppend) return MergeFrom[T](ctx, rawSource, rawDestination, disableAppend)
} }
func MergeFromDestination[T any](source T, rawDestination json.RawMessage, disableAppend bool) (T, error) { func MergeFromDestination[T any](ctx context.Context, source T, rawDestination json.RawMessage, disableAppend bool) (T, error) {
if rawDestination == nil { if rawDestination == nil {
return source, nil return source, nil
} }
rawSource, err := json.Marshal(source) rawSource, err := json.MarshalContext(ctx, source)
if err != nil { if err != nil {
return common.DefaultValue[T](), E.Cause(err, "marshal source") return common.DefaultValue[T](), E.Cause(err, "marshal source")
} }
return MergeFrom[T](rawSource, rawDestination, disableAppend) return MergeFrom[T](ctx, rawSource, rawDestination, disableAppend)
} }
func MergeFrom[T any](rawSource json.RawMessage, rawDestination json.RawMessage, disableAppend bool) (T, error) { func MergeFrom[T any](ctx context.Context, rawSource json.RawMessage, rawDestination json.RawMessage, disableAppend bool) (T, error) {
rawMerged, err := MergeJSON(rawSource, rawDestination, disableAppend) rawMerged, err := MergeJSON(ctx, rawSource, rawDestination, disableAppend)
if err != nil { if err != nil {
return common.DefaultValue[T](), E.Cause(err, "merge options") return common.DefaultValue[T](), E.Cause(err, "merge options")
} }
var merged T var merged T
err = json.Unmarshal(rawMerged, &merged) err = json.UnmarshalContext(ctx, rawMerged, &merged)
if err != nil { if err != nil {
return common.DefaultValue[T](), E.Cause(err, "unmarshal merged options") return common.DefaultValue[T](), E.Cause(err, "unmarshal merged options")
} }
return merged, nil return merged, nil
} }
func MergeJSON(rawSource json.RawMessage, rawDestination json.RawMessage, disableAppend bool) (json.RawMessage, error) { func MergeJSON(ctx context.Context, rawSource json.RawMessage, rawDestination json.RawMessage, disableAppend bool) (json.RawMessage, error) {
if rawSource == nil && rawDestination == nil { if rawSource == nil && rawDestination == nil {
return nil, os.ErrInvalid return nil, os.ErrInvalid
} else if rawSource == nil { } else if rawSource == nil {
@ -85,16 +86,16 @@ func MergeJSON(rawSource json.RawMessage, rawDestination json.RawMessage, disabl
} else if rawDestination == nil { } else if rawDestination == nil {
return rawSource, nil return rawSource, nil
} }
source, err := Decode(rawSource) source, err := Decode(ctx, rawSource)
if err != nil { if err != nil {
return nil, E.Cause(err, "decode source") return nil, E.Cause(err, "decode source")
} }
destination, err := Decode(rawDestination) destination, err := Decode(ctx, rawDestination)
if err != nil { if err != nil {
return nil, E.Cause(err, "decode destination") return nil, E.Cause(err, "decode destination")
} }
if source == nil { if source == nil {
return json.Marshal(destination) return json.MarshalContext(ctx, destination)
} else if destination == nil { } else if destination == nil {
return json.Marshal(source) return json.Marshal(source)
} }
@ -102,7 +103,7 @@ func MergeJSON(rawSource json.RawMessage, rawDestination json.RawMessage, disabl
if err != nil { if err != nil {
return nil, err return nil, err
} }
return json.Marshal(merged) return json.MarshalContext(ctx, merged)
} }
func mergeJSON(anySource any, anyDestination any, disableAppend bool) (any, error) { func mergeJSON(anySource any, anyDestination any, disableAppend bool) (any, error) {

View file

@ -0,0 +1,68 @@
package badjson
import (
"context"
"reflect"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/json"
cJSON "github.com/sagernet/sing/common/json/internal/contextjson"
)
func MarshallObjects(objects ...any) ([]byte, error) {
return MarshallObjectsContext(context.Background(), objects...)
}
func MarshallObjectsContext(ctx context.Context, objects ...any) ([]byte, error) {
if len(objects) == 1 {
return json.Marshal(objects[0])
}
var content JSONObject
for _, object := range objects {
objectMap, err := newJSONObject(ctx, object)
if err != nil {
return nil, err
}
content.PutAll(objectMap)
}
return content.MarshalJSONContext(ctx)
}
func UnmarshallExcluded(inputContent []byte, parentObject any, object any) error {
return UnmarshallExcludedContext(context.Background(), inputContent, parentObject, object)
}
func UnmarshallExcludedContext(ctx context.Context, inputContent []byte, parentObject any, object any) error {
var content JSONObject
err := content.UnmarshalJSONContext(ctx, inputContent)
if err != nil {
return err
}
for _, key := range cJSON.ObjectKeys(reflect.TypeOf(parentObject)) {
content.Remove(key)
}
if object == nil {
if content.IsEmpty() {
return nil
}
return E.New("unexpected key: ", content.Keys()[0])
}
inputContent, err = content.MarshalJSONContext(ctx)
if err != nil {
return err
}
return json.UnmarshalContextDisallowUnknownFields(ctx, inputContent, object)
}
func newJSONObject(ctx context.Context, object any) (*JSONObject, error) {
inputContent, err := json.MarshalContext(ctx, object)
if err != nil {
return nil, err
}
var content JSONObject
err = content.UnmarshalJSONContext(ctx, inputContent)
if err != nil {
return nil, err
}
return &content, nil
}

View file

@ -2,6 +2,7 @@ package badjson
import ( import (
"bytes" "bytes"
"context"
"strings" "strings"
"github.com/sagernet/sing/common" "github.com/sagernet/sing/common"
@ -28,6 +29,10 @@ func (m *JSONObject) IsEmpty() bool {
} }
func (m *JSONObject) MarshalJSON() ([]byte, error) { func (m *JSONObject) MarshalJSON() ([]byte, error) {
return m.MarshalJSONContext(context.Background())
}
func (m *JSONObject) MarshalJSONContext(ctx context.Context) ([]byte, error) {
buffer := new(bytes.Buffer) buffer := new(bytes.Buffer)
buffer.WriteString("{") buffer.WriteString("{")
items := common.Filter(m.Entries(), func(it collections.MapEntry[string, any]) bool { items := common.Filter(m.Entries(), func(it collections.MapEntry[string, any]) bool {
@ -38,13 +43,13 @@ func (m *JSONObject) MarshalJSON() ([]byte, error) {
}) })
iLen := len(items) iLen := len(items)
for i, entry := range items { for i, entry := range items {
keyContent, err := json.Marshal(entry.Key) keyContent, err := json.MarshalContext(ctx, entry.Key)
if err != nil { if err != nil {
return nil, err return nil, err
} }
buffer.WriteString(strings.TrimSpace(string(keyContent))) buffer.WriteString(strings.TrimSpace(string(keyContent)))
buffer.WriteString(": ") buffer.WriteString(": ")
valueContent, err := json.Marshal(entry.Value) valueContent, err := json.MarshalContext(ctx, entry.Value)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -58,7 +63,11 @@ func (m *JSONObject) MarshalJSON() ([]byte, error) {
} }
func (m *JSONObject) UnmarshalJSON(content []byte) error { func (m *JSONObject) UnmarshalJSON(content []byte) error {
decoder := json.NewDecoder(bytes.NewReader(content)) return m.UnmarshalJSONContext(context.Background(), content)
}
func (m *JSONObject) UnmarshalJSONContext(ctx context.Context, content []byte) error {
decoder := json.NewDecoderContext(ctx, bytes.NewReader(content))
m.Clear() m.Clear()
objectStart, err := decoder.Token() objectStart, err := decoder.Token()
if err != nil { if err != nil {

View file

@ -2,6 +2,7 @@ package badjson
import ( import (
"bytes" "bytes"
"context"
"strings" "strings"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
@ -14,18 +15,22 @@ type TypedMap[K comparable, V any] struct {
} }
func (m TypedMap[K, V]) MarshalJSON() ([]byte, error) { func (m TypedMap[K, V]) MarshalJSON() ([]byte, error) {
return m.MarshalJSONContext(context.Background())
}
func (m TypedMap[K, V]) MarshalJSONContext(ctx context.Context) ([]byte, error) {
buffer := new(bytes.Buffer) buffer := new(bytes.Buffer)
buffer.WriteString("{") buffer.WriteString("{")
items := m.Entries() items := m.Entries()
iLen := len(items) iLen := len(items)
for i, entry := range items { for i, entry := range items {
keyContent, err := json.Marshal(entry.Key) keyContent, err := json.MarshalContext(ctx, entry.Key)
if err != nil { if err != nil {
return nil, err return nil, err
} }
buffer.WriteString(strings.TrimSpace(string(keyContent))) buffer.WriteString(strings.TrimSpace(string(keyContent)))
buffer.WriteString(": ") buffer.WriteString(": ")
valueContent, err := json.Marshal(entry.Value) valueContent, err := json.MarshalContext(ctx, entry.Value)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -39,7 +44,11 @@ func (m TypedMap[K, V]) MarshalJSON() ([]byte, error) {
} }
func (m *TypedMap[K, V]) UnmarshalJSON(content []byte) error { func (m *TypedMap[K, V]) UnmarshalJSON(content []byte) error {
decoder := json.NewDecoder(bytes.NewReader(content)) return m.UnmarshalJSONContext(context.Background(), content)
}
func (m *TypedMap[K, V]) UnmarshalJSONContext(ctx context.Context, content []byte) error {
decoder := json.NewDecoderContext(ctx, bytes.NewReader(content))
m.Clear() m.Clear()
objectStart, err := decoder.Token() objectStart, err := decoder.Token()
if err != nil { if err != nil {
@ -47,7 +56,7 @@ func (m *TypedMap[K, V]) UnmarshalJSON(content []byte) error {
} else if objectStart != json.Delim('{') { } else if objectStart != json.Delim('{') {
return E.New("expected json object start, but starts with ", objectStart) return E.New("expected json object start, but starts with ", objectStart)
} }
err = m.decodeJSON(decoder) err = m.decodeJSON(ctx, decoder)
if err != nil { if err != nil {
return E.Cause(err, "decode json object content") return E.Cause(err, "decode json object content")
} }
@ -60,18 +69,18 @@ func (m *TypedMap[K, V]) UnmarshalJSON(content []byte) error {
return nil return nil
} }
func (m *TypedMap[K, V]) decodeJSON(decoder *json.Decoder) error { func (m *TypedMap[K, V]) decodeJSON(ctx context.Context, decoder *json.Decoder) error {
for decoder.More() { for decoder.More() {
keyToken, err := decoder.Token() keyToken, err := decoder.Token()
if err != nil { if err != nil {
return err return err
} }
keyContent, err := json.Marshal(keyToken) keyContent, err := json.MarshalContext(ctx, keyToken)
if err != nil { if err != nil {
return err return err
} }
var entryKey K var entryKey K
err = json.Unmarshal(keyContent, &entryKey) err = json.UnmarshalContext(ctx, keyContent, &entryKey)
if err != nil { if err != nil {
return err return err
} }

View file

@ -0,0 +1,32 @@
package badoption
import (
"time"
"github.com/sagernet/sing/common/json"
"github.com/sagernet/sing/common/json/badoption/internal/my_time"
)
type Duration time.Duration
func (d Duration) Build() time.Duration {
return time.Duration(d)
}
func (d Duration) MarshalJSON() ([]byte, error) {
return json.Marshal((time.Duration)(d).String())
}
func (d *Duration) UnmarshalJSON(bytes []byte) error {
var value string
err := json.Unmarshal(bytes, &value)
if err != nil {
return err
}
duration, err := my_time.ParseDuration(value)
if err != nil {
return err
}
*d = Duration(duration)
return nil
}

View file

@ -0,0 +1,15 @@
package badoption
import "net/http"
type HTTPHeader map[string]Listable[string]
func (h HTTPHeader) Build() http.Header {
header := make(http.Header)
for name, values := range h {
for _, value := range values {
header.Add(name, value)
}
}
return header
}

View file

@ -0,0 +1,226 @@
package my_time
import (
"errors"
"time"
)
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
const durationDay = 24 * time.Hour
var unitMap = map[string]uint64{
"ns": uint64(time.Nanosecond),
"us": uint64(time.Microsecond),
"µs": uint64(time.Microsecond), // U+00B5 = micro symbol
"μs": uint64(time.Microsecond), // U+03BC = Greek letter mu
"ms": uint64(time.Millisecond),
"s": uint64(time.Second),
"m": uint64(time.Minute),
"h": uint64(time.Hour),
"d": uint64(durationDay),
}
// ParseDuration parses a duration string.
// A duration string is a possibly signed sequence of
// decimal numbers, each with optional fraction and a unit suffix,
// such as "300ms", "-1.5h" or "2h45m".
// Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h".
func ParseDuration(s string) (time.Duration, error) {
// [-+]?([0-9]*(\.[0-9]*)?[a-z]+)+
orig := s
var d uint64
neg := false
// Consume [-+]?
if s != "" {
c := s[0]
if c == '-' || c == '+' {
neg = c == '-'
s = s[1:]
}
}
// Special case: if all that is left is "0", this is zero.
if s == "0" {
return 0, nil
}
if s == "" {
return 0, errors.New("time: invalid duration " + quote(orig))
}
for s != "" {
var (
v, f uint64 // integers before, after decimal point
scale float64 = 1 // value = v + f/scale
)
var err error
// The next character must be [0-9.]
if !(s[0] == '.' || '0' <= s[0] && s[0] <= '9') {
return 0, errors.New("time: invalid duration " + quote(orig))
}
// Consume [0-9]*
pl := len(s)
v, s, err = leadingInt(s)
if err != nil {
return 0, errors.New("time: invalid duration " + quote(orig))
}
pre := pl != len(s) // whether we consumed anything before a period
// Consume (\.[0-9]*)?
post := false
if s != "" && s[0] == '.' {
s = s[1:]
pl := len(s)
f, scale, s = leadingFraction(s)
post = pl != len(s)
}
if !pre && !post {
// no digits (e.g. ".s" or "-.s")
return 0, errors.New("time: invalid duration " + quote(orig))
}
// Consume unit.
i := 0
for ; i < len(s); i++ {
c := s[i]
if c == '.' || '0' <= c && c <= '9' {
break
}
}
if i == 0 {
return 0, errors.New("time: missing unit in duration " + quote(orig))
}
u := s[:i]
s = s[i:]
unit, ok := unitMap[u]
if !ok {
return 0, errors.New("time: unknown unit " + quote(u) + " in duration " + quote(orig))
}
if v > 1<<63/unit {
// overflow
return 0, errors.New("time: invalid duration " + quote(orig))
}
v *= unit
if f > 0 {
// float64 is needed to be nanosecond accurate for fractions of hours.
// v >= 0 && (f*unit/scale) <= 3.6e+12 (ns/h, h is the largest unit)
v += uint64(float64(f) * (float64(unit) / scale))
if v > 1<<63 {
// overflow
return 0, errors.New("time: invalid duration " + quote(orig))
}
}
d += v
if d > 1<<63 {
return 0, errors.New("time: invalid duration " + quote(orig))
}
}
if neg {
return -time.Duration(d), nil
}
if d > 1<<63-1 {
return 0, errors.New("time: invalid duration " + quote(orig))
}
return time.Duration(d), nil
}
var errLeadingInt = errors.New("time: bad [0-9]*") // never printed
// leadingInt consumes the leading [0-9]* from s.
func leadingInt[bytes []byte | string](s bytes) (x uint64, rem bytes, err error) {
i := 0
for ; i < len(s); i++ {
c := s[i]
if c < '0' || c > '9' {
break
}
if x > 1<<63/10 {
// overflow
return 0, rem, errLeadingInt
}
x = x*10 + uint64(c) - '0'
if x > 1<<63 {
// overflow
return 0, rem, errLeadingInt
}
}
return x, s[i:], nil
}
// leadingFraction consumes the leading [0-9]* from s.
// It is used only for fractions, so does not return an error on overflow,
// it just stops accumulating precision.
func leadingFraction(s string) (x uint64, scale float64, rem string) {
i := 0
scale = 1
overflow := false
for ; i < len(s); i++ {
c := s[i]
if c < '0' || c > '9' {
break
}
if overflow {
continue
}
if x > (1<<63-1)/10 {
// It's possible for overflow to give a positive number, so take care.
overflow = true
continue
}
y := x*10 + uint64(c) - '0'
if y > 1<<63 {
overflow = true
continue
}
x = y
scale *= 10
}
return x, scale, s[i:]
}
// These are borrowed from unicode/utf8 and strconv and replicate behavior in
// that package, since we can't take a dependency on either.
const (
lowerhex = "0123456789abcdef"
runeSelf = 0x80
runeError = '\uFFFD'
)
func quote(s string) string {
buf := make([]byte, 1, len(s)+2) // slice will be at least len(s) + quotes
buf[0] = '"'
for i, c := range s {
if c >= runeSelf || c < ' ' {
// This means you are asking us to parse a time.Duration or
// time.Location with unprintable or non-ASCII characters in it.
// We don't expect to hit this case very often. We could try to
// reproduce strconv.Quote's behavior with full fidelity but
// given how rarely we expect to hit these edge cases, speed and
// conciseness are better.
var width int
if c == runeError {
width = 1
if i+2 < len(s) && s[i:i+3] == string(runeError) {
width = 3
}
} else {
width = len(string(c))
}
for j := 0; j < width; j++ {
buf = append(buf, `\x`...)
buf = append(buf, lowerhex[s[i+j]>>4])
buf = append(buf, lowerhex[s[i+j]&0xF])
}
} else {
if c == '"' || c == '\\' {
buf = append(buf, '\\')
}
buf = append(buf, string(c)...)
}
}
buf = append(buf, '"')
return string(buf)
}

View file

@ -0,0 +1,35 @@
package badoption
import (
"context"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/json"
)
type Listable[T any] []T
func (l Listable[T]) MarshalJSONContext(ctx context.Context) ([]byte, error) {
arrayList := []T(l)
if len(arrayList) == 1 {
return json.Marshal(arrayList[0])
}
return json.MarshalContext(ctx, arrayList)
}
func (l *Listable[T]) UnmarshalJSONContext(ctx context.Context, content []byte) error {
if string(content) == "null" {
return nil
}
var singleItem T
err := json.UnmarshalContextDisallowUnknownFields(ctx, content, &singleItem)
if err == nil {
*l = []T{singleItem}
return nil
}
newErr := json.UnmarshalContextDisallowUnknownFields(ctx, content, (*[]T)(l))
if newErr == nil {
return nil
}
return E.Errors(err, newErr)
}

View file

@ -0,0 +1,98 @@
package badoption
import (
"net/netip"
"github.com/sagernet/sing/common/json"
)
type Addr netip.Addr
func (a *Addr) Build(defaultAddr netip.Addr) netip.Addr {
if a == nil {
return defaultAddr
}
return netip.Addr(*a)
}
func (a *Addr) MarshalJSON() ([]byte, error) {
return json.Marshal(netip.Addr(*a).String())
}
func (a *Addr) UnmarshalJSON(content []byte) error {
var value string
err := json.Unmarshal(content, &value)
if err != nil {
return err
}
addr, err := netip.ParseAddr(value)
if err != nil {
return err
}
*a = Addr(addr)
return nil
}
type Prefix netip.Prefix
func (p *Prefix) Build(defaultPrefix netip.Prefix) netip.Prefix {
if p == nil {
return defaultPrefix
}
return netip.Prefix(*p)
}
func (p *Prefix) MarshalJSON() ([]byte, error) {
return json.Marshal(netip.Prefix(*p).String())
}
func (p *Prefix) UnmarshalJSON(content []byte) error {
var value string
err := json.Unmarshal(content, &value)
if err != nil {
return err
}
prefix, err := netip.ParsePrefix(value)
if err != nil {
return err
}
*p = Prefix(prefix)
return nil
}
type Prefixable netip.Prefix
func (p *Prefixable) Build(defaultPrefix netip.Prefix) netip.Prefix {
if p == nil {
return defaultPrefix
}
return netip.Prefix(*p)
}
func (p *Prefixable) MarshalJSON() ([]byte, error) {
prefix := netip.Prefix(*p)
if prefix.Bits() == prefix.Addr().BitLen() {
return json.Marshal(prefix.Addr().String())
} else {
return json.Marshal(prefix.String())
}
}
func (p *Prefixable) UnmarshalJSON(content []byte) error {
var value string
err := json.Unmarshal(content, &value)
if err != nil {
return err
}
prefix, prefixErr := netip.ParsePrefix(value)
if prefixErr == nil {
*p = Prefixable(prefix)
return nil
}
addr, addrErr := netip.ParseAddr(value)
if addrErr == nil {
*p = Prefixable(netip.PrefixFrom(addr, addr.BitLen()))
return nil
}
return prefixErr
}

View file

@ -0,0 +1,31 @@
package badoption
import (
"regexp"
"github.com/sagernet/sing/common/json"
)
type Regexp regexp.Regexp
func (r *Regexp) Build() *regexp.Regexp {
return (*regexp.Regexp)(r)
}
func (r *Regexp) MarshalJSON() ([]byte, error) {
return json.Marshal((*regexp.Regexp)(r).String())
}
func (r *Regexp) UnmarshalJSON(content []byte) error {
var stringValue string
err := json.Unmarshal(content, &stringValue)
if err != nil {
return err
}
regex, err := regexp.Compile(stringValue)
if err != nil {
return err
}
*r = Regexp(*regex)
return nil
}

View file

@ -0,0 +1,23 @@
package json
import (
"context"
"github.com/sagernet/sing/common/json/internal/contextjson"
)
var (
MarshalContext = json.MarshalContext
UnmarshalContext = json.UnmarshalContext
NewEncoderContext = json.NewEncoderContext
NewDecoderContext = json.NewDecoderContext
UnmarshalContextDisallowUnknownFields = json.UnmarshalContextDisallowUnknownFields
)
type ContextMarshaler interface {
MarshalJSONContext(ctx context.Context) ([]byte, error)
}
type ContextUnmarshaler interface {
UnmarshalJSONContext(ctx context.Context, content []byte) error
}

View file

@ -0,0 +1,11 @@
package json
import "context"
type ContextMarshaler interface {
MarshalJSONContext(ctx context.Context) ([]byte, error)
}
type ContextUnmarshaler interface {
UnmarshalJSONContext(ctx context.Context, content []byte) error
}

View file

@ -0,0 +1,43 @@
package json_test
import (
"context"
"testing"
"github.com/sagernet/sing/common/json/internal/contextjson"
"github.com/stretchr/testify/require"
)
type myStruct struct {
value string
}
func (m *myStruct) MarshalJSONContext(ctx context.Context) ([]byte, error) {
return json.Marshal(ctx.Value("key").(string))
}
func (m *myStruct) UnmarshalJSONContext(ctx context.Context, content []byte) error {
m.value = ctx.Value("key").(string)
return nil
}
//nolint:staticcheck
func TestMarshalContext(t *testing.T) {
t.Parallel()
ctx := context.WithValue(context.Background(), "key", "value")
var s myStruct
b, err := json.MarshalContext(ctx, &s)
require.NoError(t, err)
require.Equal(t, []byte(`"value"`), b)
}
//nolint:staticcheck
func TestUnmarshalContext(t *testing.T) {
t.Parallel()
ctx := context.WithValue(context.Background(), "key", "value")
var s myStruct
err := json.UnmarshalContext(ctx, []byte(`{}`), &s)
require.NoError(t, err)
require.Equal(t, "value", s.value)
}

View file

@ -8,6 +8,7 @@
package json package json
import ( import (
"context"
"encoding" "encoding"
"encoding/base64" "encoding/base64"
"fmt" "fmt"
@ -95,10 +96,15 @@ import (
// Instead, they are replaced by the Unicode replacement // Instead, they are replaced by the Unicode replacement
// character U+FFFD. // character U+FFFD.
func Unmarshal(data []byte, v any) error { func Unmarshal(data []byte, v any) error {
return UnmarshalContext(context.Background(), data, v)
}
func UnmarshalContext(ctx context.Context, data []byte, v any) error {
// Check for well-formedness. // Check for well-formedness.
// Avoids filling out half a data structure // Avoids filling out half a data structure
// before discovering a JSON syntax error. // before discovering a JSON syntax error.
var d decodeState var d decodeState
d.ctx = ctx
err := checkValid(data, &d.scan) err := checkValid(data, &d.scan)
if err != nil { if err != nil {
return err return err
@ -209,6 +215,7 @@ type errorContext struct {
// decodeState represents the state while decoding a JSON value. // decodeState represents the state while decoding a JSON value.
type decodeState struct { type decodeState struct {
ctx context.Context
data []byte data []byte
off int // next read offset in data off int // next read offset in data
opcode int // last read result opcode int // last read result
@ -428,7 +435,7 @@ func (d *decodeState) valueQuoted() any {
// If it encounters an Unmarshaler, indirect stops and returns that. // If it encounters an Unmarshaler, indirect stops and returns that.
// If decodingNull is true, indirect stops at the first settable pointer so it // If decodingNull is true, indirect stops at the first settable pointer so it
// can be set to nil. // can be set to nil.
func indirect(v reflect.Value, decodingNull bool) (Unmarshaler, encoding.TextUnmarshaler, reflect.Value) { func indirect(v reflect.Value, decodingNull bool) (Unmarshaler, ContextUnmarshaler, encoding.TextUnmarshaler, reflect.Value) {
// Issue #24153 indicates that it is generally not a guaranteed property // Issue #24153 indicates that it is generally not a guaranteed property
// that you may round-trip a reflect.Value by calling Value.Addr().Elem() // that you may round-trip a reflect.Value by calling Value.Addr().Elem()
// and expect the value to still be settable for values derived from // and expect the value to still be settable for values derived from
@ -482,11 +489,14 @@ func indirect(v reflect.Value, decodingNull bool) (Unmarshaler, encoding.TextUnm
} }
if v.Type().NumMethod() > 0 && v.CanInterface() { if v.Type().NumMethod() > 0 && v.CanInterface() {
if u, ok := v.Interface().(Unmarshaler); ok { if u, ok := v.Interface().(Unmarshaler); ok {
return u, nil, reflect.Value{} return u, nil, nil, reflect.Value{}
}
if cu, ok := v.Interface().(ContextUnmarshaler); ok {
return nil, cu, nil, reflect.Value{}
} }
if !decodingNull { if !decodingNull {
if u, ok := v.Interface().(encoding.TextUnmarshaler); ok { if u, ok := v.Interface().(encoding.TextUnmarshaler); ok {
return nil, u, reflect.Value{} return nil, nil, u, reflect.Value{}
} }
} }
} }
@ -498,14 +508,14 @@ func indirect(v reflect.Value, decodingNull bool) (Unmarshaler, encoding.TextUnm
v = v.Elem() v = v.Elem()
} }
} }
return nil, nil, v return nil, nil, nil, v
} }
// array consumes an array from d.data[d.off-1:], decoding into v. // array consumes an array from d.data[d.off-1:], decoding into v.
// The first byte of the array ('[') has been read already. // The first byte of the array ('[') has been read already.
func (d *decodeState) array(v reflect.Value) error { func (d *decodeState) array(v reflect.Value) error {
// Check for unmarshaler. // Check for unmarshaler.
u, ut, pv := indirect(v, false) u, cu, ut, pv := indirect(v, false)
if u != nil { if u != nil {
start := d.readIndex() start := d.readIndex()
d.skip() d.skip()
@ -515,6 +525,15 @@ func (d *decodeState) array(v reflect.Value) error {
} }
return nil return nil
} }
if cu != nil {
start := d.readIndex()
d.skip()
err := cu.UnmarshalJSONContext(d.ctx, d.data[start:d.off])
if err != nil {
d.saveError(err)
}
return nil
}
if ut != nil { if ut != nil {
d.saveError(&UnmarshalTypeError{Value: "array", Type: v.Type(), Offset: int64(d.off)}) d.saveError(&UnmarshalTypeError{Value: "array", Type: v.Type(), Offset: int64(d.off)})
d.skip() d.skip()
@ -612,7 +631,7 @@ var (
// The first byte ('{') of the object has been read already. // The first byte ('{') of the object has been read already.
func (d *decodeState) object(v reflect.Value) error { func (d *decodeState) object(v reflect.Value) error {
// Check for unmarshaler. // Check for unmarshaler.
u, ut, pv := indirect(v, false) u, cu, ut, pv := indirect(v, false)
if u != nil { if u != nil {
start := d.readIndex() start := d.readIndex()
d.skip() d.skip()
@ -622,6 +641,15 @@ func (d *decodeState) object(v reflect.Value) error {
} }
return nil return nil
} }
if cu != nil {
start := d.readIndex()
d.skip()
err := cu.UnmarshalJSONContext(d.ctx, d.data[start:d.off])
if err != nil {
d.saveError(err)
}
return nil
}
if ut != nil { if ut != nil {
d.saveError(&UnmarshalTypeError{Value: "object", Type: v.Type(), Offset: int64(d.off)}) d.saveError(&UnmarshalTypeError{Value: "object", Type: v.Type(), Offset: int64(d.off)})
d.skip() d.skip()
@ -870,7 +898,7 @@ func (d *decodeState) literalStore(item []byte, v reflect.Value, fromQuoted bool
return nil return nil
} }
isNull := item[0] == 'n' // null isNull := item[0] == 'n' // null
u, ut, pv := indirect(v, isNull) u, cu, ut, pv := indirect(v, isNull)
if u != nil { if u != nil {
err := u.UnmarshalJSON(item) err := u.UnmarshalJSON(item)
if err != nil { if err != nil {
@ -878,6 +906,13 @@ func (d *decodeState) literalStore(item []byte, v reflect.Value, fromQuoted bool
} }
return nil return nil
} }
if cu != nil {
err := cu.UnmarshalJSONContext(d.ctx, item)
if err != nil {
d.saveError(err)
}
return nil
}
if ut != nil { if ut != nil {
if item[0] != '"' { if item[0] != '"' {
if fromQuoted { if fromQuoted {

View file

@ -12,6 +12,7 @@ package json
import ( import (
"bytes" "bytes"
"context"
"encoding" "encoding"
"encoding/base64" "encoding/base64"
"fmt" "fmt"
@ -156,7 +157,11 @@ import (
// handle them. Passing cyclic structures to Marshal will result in // handle them. Passing cyclic structures to Marshal will result in
// an error. // an error.
func Marshal(v any) ([]byte, error) { func Marshal(v any) ([]byte, error) {
e := newEncodeState() return MarshalContext(context.Background(), v)
}
func MarshalContext(ctx context.Context, v any) ([]byte, error) {
e := newEncodeState(ctx)
defer encodeStatePool.Put(e) defer encodeStatePool.Put(e)
err := e.marshal(v, encOpts{escapeHTML: true}) err := e.marshal(v, encOpts{escapeHTML: true})
@ -251,6 +256,7 @@ var hex = "0123456789abcdef"
type encodeState struct { type encodeState struct {
bytes.Buffer // accumulated output bytes.Buffer // accumulated output
ctx context.Context
// Keep track of what pointers we've seen in the current recursive call // Keep track of what pointers we've seen in the current recursive call
// path, to avoid cycles that could lead to a stack overflow. Only do // path, to avoid cycles that could lead to a stack overflow. Only do
// the relatively expensive map operations if ptrLevel is larger than // the relatively expensive map operations if ptrLevel is larger than
@ -264,7 +270,7 @@ const startDetectingCyclesAfter = 1000
var encodeStatePool sync.Pool var encodeStatePool sync.Pool
func newEncodeState() *encodeState { func newEncodeState(ctx context.Context) *encodeState {
if v := encodeStatePool.Get(); v != nil { if v := encodeStatePool.Get(); v != nil {
e := v.(*encodeState) e := v.(*encodeState)
e.Reset() e.Reset()
@ -274,7 +280,7 @@ func newEncodeState() *encodeState {
e.ptrLevel = 0 e.ptrLevel = 0
return e return e
} }
return &encodeState{ptrSeen: make(map[any]struct{})} return &encodeState{ctx: ctx, ptrSeen: make(map[any]struct{})}
} }
// jsonError is an error wrapper type for internal use only. // jsonError is an error wrapper type for internal use only.
@ -371,8 +377,9 @@ func typeEncoder(t reflect.Type) encoderFunc {
} }
var ( var (
marshalerType = reflect.TypeOf((*Marshaler)(nil)).Elem() marshalerType = reflect.TypeOf((*Marshaler)(nil)).Elem()
textMarshalerType = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem() contextMarshalerType = reflect.TypeOf((*ContextMarshaler)(nil)).Elem()
textMarshalerType = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem()
) )
// newTypeEncoder constructs an encoderFunc for a type. // newTypeEncoder constructs an encoderFunc for a type.
@ -385,9 +392,15 @@ func newTypeEncoder(t reflect.Type, allowAddr bool) encoderFunc {
if t.Kind() != reflect.Pointer && allowAddr && reflect.PointerTo(t).Implements(marshalerType) { if t.Kind() != reflect.Pointer && allowAddr && reflect.PointerTo(t).Implements(marshalerType) {
return newCondAddrEncoder(addrMarshalerEncoder, newTypeEncoder(t, false)) return newCondAddrEncoder(addrMarshalerEncoder, newTypeEncoder(t, false))
} }
if t.Kind() != reflect.Pointer && allowAddr && reflect.PointerTo(t).Implements(contextMarshalerType) {
return newCondAddrEncoder(addrContextMarshalerEncoder, newTypeEncoder(t, false))
}
if t.Implements(marshalerType) { if t.Implements(marshalerType) {
return marshalerEncoder return marshalerEncoder
} }
if t.Implements(contextMarshalerType) {
return contextMarshalerEncoder
}
if t.Kind() != reflect.Pointer && allowAddr && reflect.PointerTo(t).Implements(textMarshalerType) { if t.Kind() != reflect.Pointer && allowAddr && reflect.PointerTo(t).Implements(textMarshalerType) {
return newCondAddrEncoder(addrTextMarshalerEncoder, newTypeEncoder(t, false)) return newCondAddrEncoder(addrTextMarshalerEncoder, newTypeEncoder(t, false))
} }
@ -470,6 +483,47 @@ func addrMarshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) {
} }
} }
func contextMarshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) {
if v.Kind() == reflect.Pointer && v.IsNil() {
e.WriteString("null")
return
}
m, ok := v.Interface().(ContextMarshaler)
if !ok {
e.WriteString("null")
return
}
b, err := m.MarshalJSONContext(e.ctx)
if err == nil {
e.Grow(len(b))
out := availableBuffer(&e.Buffer)
out, err = appendCompact(out, b, opts.escapeHTML)
e.Buffer.Write(out)
}
if err != nil {
e.error(&MarshalerError{v.Type(), err, "MarshalJSON"})
}
}
func addrContextMarshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) {
va := v.Addr()
if va.IsNil() {
e.WriteString("null")
return
}
m := va.Interface().(ContextMarshaler)
b, err := m.MarshalJSONContext(e.ctx)
if err == nil {
e.Grow(len(b))
out := availableBuffer(&e.Buffer)
out, err = appendCompact(out, b, opts.escapeHTML)
e.Buffer.Write(out)
}
if err != nil {
e.error(&MarshalerError{v.Type(), err, "MarshalJSON"})
}
}
func textMarshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) { func textMarshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) {
if v.Kind() == reflect.Pointer && v.IsNil() { if v.Kind() == reflect.Pointer && v.IsNil() {
e.WriteString("null") e.WriteString("null")
@ -827,7 +881,7 @@ func newSliceEncoder(t reflect.Type) encoderFunc {
// Byte slices get special treatment; arrays don't. // Byte slices get special treatment; arrays don't.
if t.Elem().Kind() == reflect.Uint8 { if t.Elem().Kind() == reflect.Uint8 {
p := reflect.PointerTo(t.Elem()) p := reflect.PointerTo(t.Elem())
if !p.Implements(marshalerType) && !p.Implements(textMarshalerType) { if !p.Implements(marshalerType) && !p.Implements(contextMarshalerType) && !p.Implements(textMarshalerType) {
return encodeByteSlice return encodeByteSlice
} }
} }

View file

@ -0,0 +1,20 @@
package json
import (
"reflect"
"github.com/sagernet/sing/common"
)
func ObjectKeys(object reflect.Type) []string {
switch object.Kind() {
case reflect.Pointer:
return ObjectKeys(object.Elem())
case reflect.Struct:
default:
panic("invalid non-struct input")
}
return common.Map(cachedTypeFields(object).list, func(field field) string {
return field.name
})
}

View file

@ -0,0 +1,26 @@
package json_test
import (
"reflect"
"testing"
json "github.com/sagernet/sing/common/json/internal/contextjson"
"github.com/stretchr/testify/require"
)
type MyObject struct {
Hello string `json:"hello,omitempty"`
MyWorld
MyWorld2 string `json:"-"`
}
type MyWorld struct {
World string `json:"world,omitempty"`
}
func TestObjectKeys(t *testing.T) {
t.Parallel()
keys := json.ObjectKeys(reflect.TypeOf(&MyObject{}))
require.Equal(t, []string{"hello", "world"}, keys)
}

View file

@ -6,6 +6,7 @@ package json
import ( import (
"bytes" "bytes"
"context"
"errors" "errors"
"io" "io"
) )
@ -29,7 +30,11 @@ type Decoder struct {
// The decoder introduces its own buffering and may // The decoder introduces its own buffering and may
// read data from r beyond the JSON values requested. // read data from r beyond the JSON values requested.
func NewDecoder(r io.Reader) *Decoder { func NewDecoder(r io.Reader) *Decoder {
return &Decoder{r: r} return NewDecoderContext(context.Background(), r)
}
func NewDecoderContext(ctx context.Context, r io.Reader) *Decoder {
return &Decoder{r: r, d: decodeState{ctx: ctx}}
} }
// UseNumber causes the Decoder to unmarshal a number into an interface{} as a // UseNumber causes the Decoder to unmarshal a number into an interface{} as a
@ -183,6 +188,7 @@ func nonSpace(b []byte) bool {
// An Encoder writes JSON values to an output stream. // An Encoder writes JSON values to an output stream.
type Encoder struct { type Encoder struct {
ctx context.Context
w io.Writer w io.Writer
err error err error
escapeHTML bool escapeHTML bool
@ -194,7 +200,11 @@ type Encoder struct {
// NewEncoder returns a new encoder that writes to w. // NewEncoder returns a new encoder that writes to w.
func NewEncoder(w io.Writer) *Encoder { func NewEncoder(w io.Writer) *Encoder {
return &Encoder{w: w, escapeHTML: true} return NewEncoderContext(context.Background(), w)
}
func NewEncoderContext(ctx context.Context, w io.Writer) *Encoder {
return &Encoder{ctx: ctx, w: w, escapeHTML: true}
} }
// Encode writes the JSON encoding of v to the stream, // Encode writes the JSON encoding of v to the stream,
@ -207,7 +217,7 @@ func (enc *Encoder) Encode(v any) error {
return enc.err return enc.err
} }
e := newEncodeState() e := newEncodeState(enc.ctx)
defer encodeStatePool.Put(e) defer encodeStatePool.Put(e)
err := e.marshal(v, encOpts{escapeHTML: enc.escapeHTML}) err := e.marshal(v, encOpts{escapeHTML: enc.escapeHTML})

View file

@ -1,5 +1,7 @@
package json package json
import "context"
func UnmarshalDisallowUnknownFields(data []byte, v any) error { func UnmarshalDisallowUnknownFields(data []byte, v any) error {
var d decodeState var d decodeState
d.disallowUnknownFields = true d.disallowUnknownFields = true
@ -10,3 +12,15 @@ func UnmarshalDisallowUnknownFields(data []byte, v any) error {
d.init(data) d.init(data)
return d.unmarshal(v) return d.unmarshal(v)
} }
func UnmarshalContextDisallowUnknownFields(ctx context.Context, data []byte, v any) error {
var d decodeState
d.ctx = ctx
d.disallowUnknownFields = true
err := checkValid(data, &d.scan)
if err != nil {
return err
}
d.init(data)
return d.unmarshal(v)
}

View file

@ -2,6 +2,7 @@ package json
import ( import (
"bytes" "bytes"
"context"
"errors" "errors"
"strings" "strings"
@ -10,7 +11,11 @@ import (
) )
func UnmarshalExtended[T any](content []byte) (T, error) { func UnmarshalExtended[T any](content []byte) (T, error) {
decoder := NewDecoder(NewCommentFilter(bytes.NewReader(content))) return UnmarshalExtendedContext[T](context.Background(), content)
}
func UnmarshalExtendedContext[T any](ctx context.Context, content []byte) (T, error) {
decoder := NewDecoderContext(ctx, NewCommentFilter(bytes.NewReader(content)))
var value T var value T
err := decoder.Decode(&value) err := decoder.Decode(&value)
if err == nil { if err == nil {

View file

@ -1,5 +1,6 @@
package metadata package metadata
// Deprecated: wtf is this?
type Metadata struct { type Metadata struct {
Protocol string Protocol string
Source Socksaddr Source Socksaddr

15
common/minmax.go Normal file
View file

@ -0,0 +1,15 @@
//go:build go1.21
package common
import (
"cmp"
)
func Min[T cmp.Ordered](x, y T) T {
return min(x, y)
}
func Max[T cmp.Ordered](x, y T) T {
return max(x, y)
}

19
common/minmax_compat.go Normal file
View file

@ -0,0 +1,19 @@
//go:build go1.20 && !go1.21
package common
import "github.com/sagernet/sing/common/x/constraints"
func Min[T constraints.Ordered](x, y T) T {
if x < y {
return x
}
return y
}
func Max[T constraints.Ordered](x, y T) T {
if x < y {
return y
}
return x
}

View file

@ -4,6 +4,7 @@ import (
"context" "context"
"io" "io"
"net" "net"
"sync"
"time" "time"
"github.com/sagernet/sing/common" "github.com/sagernet/sing/common"
@ -70,8 +71,39 @@ type ExtendedConn interface {
net.Conn net.Conn
} }
type CloseHandlerFunc = func(it error)
func AppendClose(parent CloseHandlerFunc, onClose CloseHandlerFunc) CloseHandlerFunc {
if onClose == nil {
panic("nil onClose")
}
if parent == nil {
return onClose
}
return func(it error) {
onClose(it)
parent(it)
}
}
func OnceClose(onClose CloseHandlerFunc) CloseHandlerFunc {
var once sync.Once
return func(it error) {
once.Do(func() {
onClose(it)
})
}
}
// Deprecated: Use TCPConnectionHandlerEx instead.
type TCPConnectionHandler interface { type TCPConnectionHandler interface {
NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error NewConnection(ctx context.Context, conn net.Conn,
//nolint:staticcheck
metadata M.Metadata) error
}
type TCPConnectionHandlerEx interface {
NewConnectionEx(ctx context.Context, conn net.Conn, source M.Socksaddr, destination M.Socksaddr, onClose CloseHandlerFunc)
} }
type NetPacketConn interface { type NetPacketConn interface {
@ -85,12 +117,26 @@ type BindPacketConn interface {
net.Conn net.Conn
} }
// Deprecated: Use UDPHandlerEx instead.
type UDPHandler interface { type UDPHandler interface {
NewPacket(ctx context.Context, conn PacketConn, buffer *buf.Buffer, metadata M.Metadata) error NewPacket(ctx context.Context, conn PacketConn, buffer *buf.Buffer,
//nolint:staticcheck
metadata M.Metadata) error
} }
type UDPHandlerEx interface {
NewPacketEx(buffer *buf.Buffer, source M.Socksaddr)
}
// Deprecated: Use UDPConnectionHandlerEx instead.
type UDPConnectionHandler interface { type UDPConnectionHandler interface {
NewPacketConnection(ctx context.Context, conn PacketConn, metadata M.Metadata) error NewPacketConnection(ctx context.Context, conn PacketConn,
//nolint:staticcheck
metadata M.Metadata) error
}
type UDPConnectionHandlerEx interface {
NewPacketConnectionEx(ctx context.Context, conn PacketConn, source M.Socksaddr, destination M.Socksaddr, onClose CloseHandlerFunc)
} }
type CachedReader interface { type CachedReader interface {
@ -101,11 +147,6 @@ type CachedPacketReader interface {
ReadCachedPacket() *PacketBuffer ReadCachedPacket() *PacketBuffer
} }
type PacketBuffer struct {
Buffer *buf.Buffer
Destination M.Socksaddr
}
type WithUpstreamReader interface { type WithUpstreamReader interface {
UpstreamReader() any UpstreamReader() any
} }

View file

@ -13,10 +13,6 @@ type Dialer interface {
ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error)
} }
type PayloadDialer interface {
DialPayloadContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error)
}
type ParallelDialer interface { type ParallelDialer interface {
Dialer Dialer
DialParallel(ctx context.Context, network string, destination M.Socksaddr, destinationAddresses []netip.Addr) (net.Conn, error) DialParallel(ctx context.Context, network string, destination M.Socksaddr, destinationAddresses []netip.Addr) (net.Conn, error)

View file

@ -15,19 +15,39 @@ type ReadWaitOptions struct {
MTU int MTU int
} }
func NewReadWaitOptions(source any, destination any) ReadWaitOptions {
return ReadWaitOptions{
FrontHeadroom: CalculateFrontHeadroom(destination),
RearHeadroom: CalculateRearHeadroom(destination),
MTU: CalculateMTU(source, destination),
}
}
func (o ReadWaitOptions) NeedHeadroom() bool { func (o ReadWaitOptions) NeedHeadroom() bool {
return o.FrontHeadroom > 0 || o.RearHeadroom > 0 return o.FrontHeadroom > 0 || o.RearHeadroom > 0
} }
func (o ReadWaitOptions) Copy(buffer *buf.Buffer) *buf.Buffer {
if o.FrontHeadroom > buffer.Start() ||
o.RearHeadroom > buffer.FreeLen() {
newBuffer := o.newBuffer(buf.UDPBufferSize, false)
newBuffer.Write(buffer.Bytes())
buffer.Release()
return newBuffer
} else {
return buffer
}
}
func (o ReadWaitOptions) NewBuffer() *buf.Buffer { func (o ReadWaitOptions) NewBuffer() *buf.Buffer {
return o.newBuffer(buf.BufferSize) return o.newBuffer(buf.BufferSize, true)
} }
func (o ReadWaitOptions) NewPacketBuffer() *buf.Buffer { func (o ReadWaitOptions) NewPacketBuffer() *buf.Buffer {
return o.newBuffer(buf.UDPBufferSize) return o.newBuffer(buf.UDPBufferSize, true)
} }
func (o ReadWaitOptions) newBuffer(defaultBufferSize int) *buf.Buffer { func (o ReadWaitOptions) newBuffer(defaultBufferSize int, reserve bool) *buf.Buffer {
var bufferSize int var bufferSize int
if o.MTU > 0 { if o.MTU > 0 {
bufferSize = o.MTU + o.FrontHeadroom + o.RearHeadroom bufferSize = o.MTU + o.FrontHeadroom + o.RearHeadroom
@ -38,7 +58,7 @@ func (o ReadWaitOptions) newBuffer(defaultBufferSize int) *buf.Buffer {
if o.FrontHeadroom > 0 { if o.FrontHeadroom > 0 {
buffer.Resize(o.FrontHeadroom, 0) buffer.Resize(o.FrontHeadroom, 0)
} }
if o.RearHeadroom > 0 { if o.RearHeadroom > 0 && reserve {
buffer.Reserve(o.RearHeadroom) buffer.Reserve(o.RearHeadroom)
} }
return buffer return buffer

View file

@ -1,6 +1,9 @@
package network package network
import ( import (
"io"
"net"
"github.com/sagernet/sing/common" "github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
) )
@ -13,17 +16,75 @@ type HandshakeSuccess interface {
HandshakeSuccess() error HandshakeSuccess() error
} }
func ReportHandshakeFailure(conn any, err error) error { type ConnHandshakeSuccess interface {
if handshakeConn, isHandshakeConn := common.Cast[HandshakeFailure](conn); isHandshakeConn { ConnHandshakeSuccess(conn net.Conn) error
}
type PacketConnHandshakeSuccess interface {
PacketConnHandshakeSuccess(conn net.PacketConn) error
}
func ReportHandshakeFailure(reporter any, err error) error {
if handshakeConn, isHandshakeConn := common.Cast[HandshakeFailure](reporter); isHandshakeConn {
return E.Append(err, handshakeConn.HandshakeFailure(err), func(err error) error { return E.Append(err, handshakeConn.HandshakeFailure(err), func(err error) error {
return E.Cause(err, "write handshake failure") return E.Cause(err, "write handshake failure")
}) })
} }
return nil
}
func CloseOnHandshakeFailure(reporter io.Closer, onClose CloseHandlerFunc, err error) error {
if err != nil {
if handshakeConn, isHandshakeConn := common.Cast[HandshakeFailure](reporter); isHandshakeConn {
hErr := handshakeConn.HandshakeFailure(err)
err = E.Append(err, hErr, func(err error) error {
if closer, isCloser := reporter.(io.Closer); isCloser {
err = E.Append(err, closer.Close(), func(err error) error {
return E.Cause(err, "close")
})
}
return E.Cause(err, "write handshake failure")
})
} else {
if tcpConn, isTCPConn := common.Cast[interface {
SetLinger(sec int) error
}](reporter); isTCPConn {
tcpConn.SetLinger(0)
}
}
err = E.Append(err, reporter.Close(), func(err error) error {
return E.Cause(err, "close")
})
}
if onClose != nil {
onClose(err)
}
return err return err
} }
func ReportHandshakeSuccess(conn any) error { // Deprecated: use ReportConnHandshakeSuccess/ReportPacketConnHandshakeSuccess instead
if handshakeConn, isHandshakeConn := common.Cast[HandshakeSuccess](conn); isHandshakeConn { func ReportHandshakeSuccess(reporter any) error {
if handshakeConn, isHandshakeConn := common.Cast[HandshakeSuccess](reporter); isHandshakeConn {
return handshakeConn.HandshakeSuccess()
}
return nil
}
func ReportConnHandshakeSuccess(reporter any, conn net.Conn) error {
if handshakeConn, isHandshakeConn := common.Cast[ConnHandshakeSuccess](reporter); isHandshakeConn {
return handshakeConn.ConnHandshakeSuccess(conn)
}
if handshakeConn, isHandshakeConn := common.Cast[HandshakeSuccess](reporter); isHandshakeConn {
return handshakeConn.HandshakeSuccess()
}
return nil
}
func ReportPacketConnHandshakeSuccess(reporter any, conn net.PacketConn) error {
if handshakeConn, isHandshakeConn := common.Cast[PacketConnHandshakeSuccess](reporter); isHandshakeConn {
return handshakeConn.PacketConnHandshakeSuccess(conn)
}
if handshakeConn, isHandshakeConn := common.Cast[HandshakeSuccess](reporter); isHandshakeConn {
return handshakeConn.HandshakeSuccess() return handshakeConn.HandshakeSuccess()
} }
return nil return nil

35
common/network/packet.go Normal file
View file

@ -0,0 +1,35 @@
package network
import (
"sync"
"github.com/sagernet/sing/common/buf"
M "github.com/sagernet/sing/common/metadata"
)
type PacketBuffer struct {
Buffer *buf.Buffer
Destination M.Socksaddr
}
var packetPool = sync.Pool{
New: func() any {
return new(PacketBuffer)
},
}
func NewPacketBuffer() *PacketBuffer {
return packetPool.Get().(*PacketBuffer)
}
func PutPacketBuffer(packet *PacketBuffer) {
*packet = PacketBuffer{}
packetPool.Put(packet)
}
func ReleaseMultiPacketBuffer(packetBuffers []*PacketBuffer) {
for _, packet := range packetBuffers {
packet.Buffer.Release()
PutPacketBuffer(packet)
}
}

View file

@ -11,6 +11,7 @@ type ThreadUnsafeWriter interface {
} }
// Deprecated: Use ReadWaiter interface instead. // Deprecated: Use ReadWaiter interface instead.
type ThreadSafeReader interface { type ThreadSafeReader interface {
// Deprecated: Use ReadWaiter interface instead. // Deprecated: Use ReadWaiter interface instead.
ReadBufferThreadSafe() (buffer *buf.Buffer, err error) ReadBufferThreadSafe() (buffer *buf.Buffer, err error)
@ -18,7 +19,6 @@ type ThreadSafeReader interface {
// Deprecated: Use ReadWaiter interface instead. // Deprecated: Use ReadWaiter interface instead.
type ThreadSafePacketReader interface { type ThreadSafePacketReader interface {
// Deprecated: Use ReadWaiter interface instead.
ReadPacketThreadSafe() (buffer *buf.Buffer, addr M.Socksaddr, err error) ReadPacketThreadSafe() (buffer *buf.Buffer, addr M.Socksaddr, err error)
} }

View file

@ -26,6 +26,7 @@ type Options struct {
Logger logger.Logger Logger logger.Logger
Server M.Socksaddr Server M.Socksaddr
Interval time.Duration Interval time.Duration
Timeout time.Duration
WriteToSystem bool WriteToSystem bool
} }
@ -39,6 +40,7 @@ type Service struct {
server M.Socksaddr server M.Socksaddr
writeToSystem bool writeToSystem bool
ticker *time.Ticker ticker *time.Ticker
timeout time.Duration
clockOffset time.Duration clockOffset time.Duration
pause pause.Manager pause pause.Manager
} }
@ -81,6 +83,7 @@ func NewService(options Options) *Service {
writeToSystem: options.WriteToSystem, writeToSystem: options.WriteToSystem,
server: destination, server: destination,
ticker: time.NewTicker(interval), ticker: time.NewTicker(interval),
timeout: options.Timeout,
pause: service.FromContext[pause.Manager](ctx), pause: service.FromContext[pause.Manager](ctx),
} }
} }
@ -88,9 +91,10 @@ func NewService(options Options) *Service {
func (s *Service) Start() error { func (s *Service) Start() error {
err := s.update() err := s.update()
if err != nil { if err != nil {
return E.Cause(err, "initialize time") s.logger.Error(E.Cause(err, "initialize time"))
} else {
s.logger.Info("updated time: ", s.TimeFunc()().Local().Format(TimeLayout))
} }
s.logger.Info("updated time: ", s.TimeFunc()().Local().Format(TimeLayout))
go s.loopUpdate() go s.loopUpdate()
return nil return nil
} }
@ -124,15 +128,23 @@ func (s *Service) loopUpdate() {
} }
err := s.update() err := s.update()
if err == nil { if err == nil {
s.logger.Debug("updated time: ", s.TimeFunc()().Local().Format(TimeLayout)) s.logger.Info("updated time: ", s.TimeFunc()().Local().Format(TimeLayout))
} else { } else {
s.logger.Warn("update time: ", err) s.logger.Error("update time: ", err)
} }
} }
} }
func (s *Service) update() error { func (s *Service) update() error {
response, err := Exchange(s.ctx, s.dialer, s.server) ctx := s.ctx
var cancel context.CancelFunc
if s.timeout > 0 {
ctx, cancel = context.WithTimeout(ctx, s.timeout)
}
response, err := Exchange(ctx, s.dialer, s.server)
if cancel != nil {
cancel()
}
if err != nil { if err != nil {
return err return err
} }
@ -140,7 +152,7 @@ func (s *Service) update() error {
if s.writeToSystem { if s.writeToSystem {
writeErr := SetSystemTime(s.TimeFunc()()) writeErr := SetSystemTime(s.TimeFunc()())
if writeErr != nil { if writeErr != nil {
s.logger.Warn("write time to system: ", writeErr) s.logger.Error("write time to system: ", writeErr)
} }
} }
return nil return nil

View file

@ -0,0 +1,6 @@
package ntp
//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go syscall_windows.go
// https://learn.microsoft.com/en-us/windows/win32/api/sysinfoapi/nf-sysinfoapi-setsystemtime
//sys setSystemTime(lpSystemTime *windows.Systemtime) (err error) = kernel32.SetSystemTime

View file

@ -2,12 +2,12 @@ package ntp
import ( import (
"time" "time"
"unsafe"
"golang.org/x/sys/windows" "golang.org/x/sys/windows"
) )
func SetSystemTime(nowTime time.Time) error { func SetSystemTime(nowTime time.Time) error {
nowTime = nowTime.UTC()
var systemTime windows.Systemtime var systemTime windows.Systemtime
systemTime.Year = uint16(nowTime.Year()) systemTime.Year = uint16(nowTime.Year())
systemTime.Month = uint16(nowTime.Month()) systemTime.Month = uint16(nowTime.Month())
@ -16,17 +16,5 @@ func SetSystemTime(nowTime time.Time) error {
systemTime.Minute = uint16(nowTime.Minute()) systemTime.Minute = uint16(nowTime.Minute())
systemTime.Second = uint16(nowTime.Second()) systemTime.Second = uint16(nowTime.Second())
systemTime.Milliseconds = uint16(nowTime.UnixMilli() - nowTime.Unix()*1000) systemTime.Milliseconds = uint16(nowTime.UnixMilli() - nowTime.Unix()*1000)
return setSystemTime(&systemTime)
dllKernel32 := windows.NewLazySystemDLL("kernel32.dll")
proc := dllKernel32.NewProc("SetSystemTime")
_, _, err := proc.Call(
uintptr(unsafe.Pointer(&systemTime)),
)
if err != nil && err.Error() != "The operation completed successfully." {
return err
}
return nil
} }

View file

@ -0,0 +1,52 @@
// Code generated by 'go generate'; DO NOT EDIT.
package ntp
import (
"syscall"
"unsafe"
"golang.org/x/sys/windows"
)
var _ unsafe.Pointer
// Do the interface allocations only once for common
// Errno values.
const (
errnoERROR_IO_PENDING = 997
)
var (
errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
errERROR_EINVAL error = syscall.EINVAL
)
// errnoErr returns common boxed Errno values, to prevent
// allocations at runtime.
func errnoErr(e syscall.Errno) error {
switch e {
case 0:
return errERROR_EINVAL
case errnoERROR_IO_PENDING:
return errERROR_IO_PENDING
}
// TODO: add more here, after collecting data on the common
// error values see on Windows. (perhaps when running
// all.bat?)
return e
}
var (
modkernel32 = windows.NewLazySystemDLL("kernel32.dll")
procSetSystemTime = modkernel32.NewProc("SetSystemTime")
)
func setSystemTime(lpSystemTime *windows.Systemtime) (err error) {
r1, _, e1 := syscall.Syscall(procSetSystemTime.Addr(), 1, uintptr(unsafe.Pointer(lpSystemTime)), 0, 0)
if r1 == 0 {
err = errnoErr(e1)
}
return
}

20
common/oncefunc.go Normal file
View file

@ -0,0 +1,20 @@
//go:build go1.21
package common
import "sync"
// OnceFunc is a wrapper around sync.OnceFunc.
func OnceFunc(f func()) func() {
return sync.OnceFunc(f)
}
// OnceValue is a wrapper around sync.OnceValue.
func OnceValue[T any](f func() T) func() T {
return sync.OnceValue(f)
}
// OnceValues is a wrapper around sync.OnceValues.
func OnceValues[T1, T2 any](f func() (T1, T2)) func() (T1, T2) {
return sync.OnceValues(f)
}

104
common/oncefunc_compat.go Normal file
View file

@ -0,0 +1,104 @@
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build go1.20 && !go1.21
package common
import "sync"
// OnceFunc returns a function that invokes f only once. The returned function
// may be called concurrently.
//
// If f panics, the returned function will panic with the same value on every call.
func OnceFunc(f func()) func() {
var (
once sync.Once
valid bool
p any
)
// Construct the inner closure just once to reduce costs on the fast path.
g := func() {
defer func() {
p = recover()
if !valid {
// Re-panic immediately so on the first call the user gets a
// complete stack trace into f.
panic(p)
}
}()
f()
f = nil // Do not keep f alive after invoking it.
valid = true // Set only if f does not panic.
}
return func() {
once.Do(g)
if !valid {
panic(p)
}
}
}
// OnceValue returns a function that invokes f only once and returns the value
// returned by f. The returned function may be called concurrently.
//
// If f panics, the returned function will panic with the same value on every call.
func OnceValue[T any](f func() T) func() T {
var (
once sync.Once
valid bool
p any
result T
)
g := func() {
defer func() {
p = recover()
if !valid {
panic(p)
}
}()
result = f()
f = nil
valid = true
}
return func() T {
once.Do(g)
if !valid {
panic(p)
}
return result
}
}
// OnceValues returns a function that invokes f only once and returns the values
// returned by f. The returned function may be called concurrently.
//
// If f panics, the returned function will panic with the same value on every call.
func OnceValues[T1, T2 any](f func() (T1, T2)) func() (T1, T2) {
var (
once sync.Once
valid bool
p any
r1 T1
r2 T2
)
g := func() {
defer func() {
p = recover()
if !valid {
panic(p)
}
}()
r1, r2 = f()
f = nil
valid = true
}
return func() (T1, T2) {
once.Do(g)
if !valid {
panic(p)
}
return r1, r2
}
}

View file

@ -14,24 +14,24 @@ import (
N "github.com/sagernet/sing/common/network" N "github.com/sagernet/sing/common/network"
) )
// pipeDeadline is an abstraction for handling timeouts. // Deadline is an abstraction for handling timeouts.
type pipeDeadline struct { type Deadline struct {
mu sync.Mutex // Guards timer and cancel mu sync.Mutex // Guards timer and cancel
timer *time.Timer timer *time.Timer
cancel chan struct{} // Must be non-nil cancel chan struct{} // Must be non-nil
} }
func makePipeDeadline() pipeDeadline { func MakeDeadline() Deadline {
return pipeDeadline{cancel: make(chan struct{})} return Deadline{cancel: make(chan struct{})}
} }
// set sets the point in time when the deadline will time out. // Set sets the point in time when the deadline will time out.
// A timeout event is signaled by closing the channel returned by waiter. // A timeout event is signaled by closing the channel returned by waiter.
// Once a timeout has occurred, the deadline can be refreshed by specifying a // Once a timeout has occurred, the deadline can be refreshed by specifying a
// t value in the future. // t value in the future.
// //
// A zero value for t prevents timeout. // A zero value for t prevents timeout.
func (d *pipeDeadline) set(t time.Time) { func (d *Deadline) Set(t time.Time) {
d.mu.Lock() d.mu.Lock()
defer d.mu.Unlock() defer d.mu.Unlock()
@ -66,8 +66,8 @@ func (d *pipeDeadline) set(t time.Time) {
} }
} }
// wait returns a channel that is closed when the deadline is exceeded. // Wait returns a channel that is closed when the deadline is exceeded.
func (d *pipeDeadline) wait() chan struct{} { func (d *Deadline) Wait() chan struct{} {
d.mu.Lock() d.mu.Lock()
defer d.mu.Unlock() defer d.mu.Unlock()
return d.cancel return d.cancel
@ -104,8 +104,8 @@ type pipe struct {
localDone chan struct{} localDone chan struct{}
remoteDone <-chan struct{} remoteDone <-chan struct{}
readDeadline pipeDeadline readDeadline Deadline
writeDeadline pipeDeadline writeDeadline Deadline
readWaitOptions N.ReadWaitOptions readWaitOptions N.ReadWaitOptions
} }
@ -127,15 +127,15 @@ func Pipe() (net.Conn, net.Conn) {
rdRx: cb1, rdTx: cn1, rdRx: cb1, rdTx: cn1,
wrTx: cb2, wrRx: cn2, wrTx: cb2, wrRx: cn2,
localDone: done1, remoteDone: done2, localDone: done1, remoteDone: done2,
readDeadline: makePipeDeadline(), readDeadline: MakeDeadline(),
writeDeadline: makePipeDeadline(), writeDeadline: MakeDeadline(),
} }
p2 := &pipe{ p2 := &pipe{
rdRx: cb2, rdTx: cn2, rdRx: cb2, rdTx: cn2,
wrTx: cb1, wrRx: cn1, wrTx: cb1, wrRx: cn1,
localDone: done2, remoteDone: done1, localDone: done2, remoteDone: done1,
readDeadline: makePipeDeadline(), readDeadline: MakeDeadline(),
writeDeadline: makePipeDeadline(), writeDeadline: MakeDeadline(),
} }
return p1, p2 return p1, p2
} }
@ -157,7 +157,7 @@ func (p *pipe) read(b []byte) (n int, err error) {
return 0, io.ErrClosedPipe return 0, io.ErrClosedPipe
case isClosedChan(p.remoteDone): case isClosedChan(p.remoteDone):
return 0, io.EOF return 0, io.EOF
case isClosedChan(p.readDeadline.wait()): case isClosedChan(p.readDeadline.Wait()):
return 0, os.ErrDeadlineExceeded return 0, os.ErrDeadlineExceeded
} }
@ -170,7 +170,7 @@ func (p *pipe) read(b []byte) (n int, err error) {
return 0, io.ErrClosedPipe return 0, io.ErrClosedPipe
case <-p.remoteDone: case <-p.remoteDone:
return 0, io.EOF return 0, io.EOF
case <-p.readDeadline.wait(): case <-p.readDeadline.Wait():
return 0, os.ErrDeadlineExceeded return 0, os.ErrDeadlineExceeded
} }
} }
@ -189,7 +189,7 @@ func (p *pipe) write(b []byte) (n int, err error) {
return 0, io.ErrClosedPipe return 0, io.ErrClosedPipe
case isClosedChan(p.remoteDone): case isClosedChan(p.remoteDone):
return 0, io.ErrClosedPipe return 0, io.ErrClosedPipe
case isClosedChan(p.writeDeadline.wait()): case isClosedChan(p.writeDeadline.Wait()):
return 0, os.ErrDeadlineExceeded return 0, os.ErrDeadlineExceeded
} }
@ -205,7 +205,7 @@ func (p *pipe) write(b []byte) (n int, err error) {
return n, io.ErrClosedPipe return n, io.ErrClosedPipe
case <-p.remoteDone: case <-p.remoteDone:
return n, io.ErrClosedPipe return n, io.ErrClosedPipe
case <-p.writeDeadline.wait(): case <-p.writeDeadline.Wait():
return n, os.ErrDeadlineExceeded return n, os.ErrDeadlineExceeded
} }
} }
@ -216,8 +216,8 @@ func (p *pipe) SetDeadline(t time.Time) error {
if isClosedChan(p.localDone) || isClosedChan(p.remoteDone) { if isClosedChan(p.localDone) || isClosedChan(p.remoteDone) {
return io.ErrClosedPipe return io.ErrClosedPipe
} }
p.readDeadline.set(t) p.readDeadline.Set(t)
p.writeDeadline.set(t) p.writeDeadline.Set(t)
return nil return nil
} }
@ -225,7 +225,7 @@ func (p *pipe) SetReadDeadline(t time.Time) error {
if isClosedChan(p.localDone) || isClosedChan(p.remoteDone) { if isClosedChan(p.localDone) || isClosedChan(p.remoteDone) {
return io.ErrClosedPipe return io.ErrClosedPipe
} }
p.readDeadline.set(t) p.readDeadline.Set(t)
return nil return nil
} }
@ -233,7 +233,7 @@ func (p *pipe) SetWriteDeadline(t time.Time) error {
if isClosedChan(p.localDone) || isClosedChan(p.remoteDone) { if isClosedChan(p.localDone) || isClosedChan(p.remoteDone) {
return io.ErrClosedPipe return io.ErrClosedPipe
} }
p.writeDeadline.set(t) p.writeDeadline.Set(t)
return nil return nil
} }

View file

@ -30,7 +30,7 @@ func (p *pipe) waitReadBuffer() (buffer *buf.Buffer, err error) {
return nil, io.ErrClosedPipe return nil, io.ErrClosedPipe
case isClosedChan(p.remoteDone): case isClosedChan(p.remoteDone):
return nil, io.EOF return nil, io.EOF
case isClosedChan(p.readDeadline.wait()): case isClosedChan(p.readDeadline.Wait()):
return nil, os.ErrDeadlineExceeded return nil, os.ErrDeadlineExceeded
} }
select { select {
@ -49,7 +49,7 @@ func (p *pipe) waitReadBuffer() (buffer *buf.Buffer, err error) {
return nil, io.ErrClosedPipe return nil, io.ErrClosedPipe
case <-p.remoteDone: case <-p.remoteDone:
return nil, io.EOF return nil, io.EOF
case <-p.readDeadline.wait(): case <-p.readDeadline.Wait():
return nil, os.ErrDeadlineExceeded return nil, os.ErrDeadlineExceeded
} }
} }

View file

@ -20,6 +20,5 @@ func InitializeSeed() {
func initializeSeed() { func initializeSeed() {
var seed int64 var seed int64
common.Must(binary.Read(rand.Reader, binary.LittleEndian, &seed)) common.Must(binary.Read(rand.Reader, binary.LittleEndian, &seed))
//goland:noinspection GoDeprecation
mRand.Seed(seed) mRand.Seed(seed)
} }

View file

@ -6,6 +6,7 @@ import (
) )
func TestRevertRanges(t *testing.T) { func TestRevertRanges(t *testing.T) {
t.Parallel()
for _, testRange := range []struct { for _, testRange := range []struct {
start, end int start, end int
ranges []Range[int] ranges []Range[int]
@ -77,6 +78,7 @@ func TestRevertRanges(t *testing.T) {
} }
func TestMergeRanges(t *testing.T) { func TestMergeRanges(t *testing.T) {
t.Parallel()
for _, testRange := range []struct { for _, testRange := range []struct {
ranges []Range[int] ranges []Range[int]
expected []Range[int] expected []Range[int]
@ -144,6 +146,7 @@ func TestMergeRanges(t *testing.T) {
} }
func TestExcludeRanges(t *testing.T) { func TestExcludeRanges(t *testing.T) {
t.Parallel()
for _, testRange := range []struct { for _, testRange := range []struct {
ranges []Range[int] ranges []Range[int]
exclude []Range[int] exclude []Range[int]

View file

@ -27,7 +27,6 @@ func ToByteReader(reader io.Reader) io.ByteReader {
// Deprecated: Use binary.ReadUvarint instead. // Deprecated: Use binary.ReadUvarint instead.
func ReadUVariant(reader io.Reader) (uint64, error) { func ReadUVariant(reader io.Reader) (uint64, error) {
//goland:noinspection GoDeprecation
return binary.ReadUvarint(ToByteReader(reader)) return binary.ReadUvarint(ToByteReader(reader))
} }

View file

@ -54,17 +54,9 @@ func (g *Group) Concurrency(n int) {
} }
} }
func (g *Group) Run(contextList ...context.Context) error { func (g *Group) Run(ctx context.Context) error {
return g.RunContextList(contextList)
}
func (g *Group) RunContextList(contextList []context.Context) error {
if len(contextList) == 0 {
contextList = append(contextList, context.Background())
}
taskContext, taskFinish := common.ContextWithCancelCause(context.Background()) taskContext, taskFinish := common.ContextWithCancelCause(context.Background())
taskCancelContext, taskCancel := common.ContextWithCancelCause(context.Background()) taskCancelContext, taskCancel := common.ContextWithCancelCause(ctx)
var errorAccess sync.Mutex var errorAccess sync.Mutex
var returnError error var returnError error
@ -112,10 +104,12 @@ func (g *Group) RunContextList(contextList []context.Context) error {
}() }()
} }
selectedContext, upstreamErr := common.SelectContext(append([]context.Context{taskCancelContext}, contextList...)) var upstreamErr bool
select {
if selectedContext != 0 { case <-taskCancelContext.Done():
taskCancel(upstreamErr) case <-ctx.Done():
upstreamErr = true
taskCancel(ctx.Err())
} }
if g.cleanup != nil { if g.cleanup != nil {
@ -124,10 +118,8 @@ func (g *Group) RunContextList(contextList []context.Context) error {
<-taskContext.Done() <-taskContext.Done()
if selectedContext != 0 { if upstreamErr {
returnError = E.Append(returnError, upstreamErr, func(err error) error { return ctx.Err()
return E.Cause(err, "upstream")
})
} }
return returnError return returnError

View file

@ -2,6 +2,7 @@ package task
import "context" import "context"
// Deprecated: Use Group instead
func Run(ctx context.Context, tasks ...func() error) error { func Run(ctx context.Context, tasks ...func() error) error {
var group Group var group Group
for _, task := range tasks { for _, task := range tasks {
@ -13,6 +14,7 @@ func Run(ctx context.Context, tasks ...func() error) error {
return group.Run(ctx) return group.Run(ctx)
} }
// Deprecated: Use Group instead
func Any(ctx context.Context, tasks ...func(ctx context.Context) error) error { func Any(ctx context.Context, tasks ...func(ctx context.Context) error) error {
var group Group var group Group
for _, task := range tasks { for _, task := range tasks {

View file

@ -2,6 +2,7 @@ package udpnat
import ( import (
"io" "io"
"os"
"github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/buf"
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
@ -34,5 +35,7 @@ func (c *conn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, er
return return
case <-c.ctx.Done(): case <-c.ctx.Done():
return nil, M.Socksaddr{}, io.ErrClosedPipe return nil, M.Socksaddr{}, io.ErrClosedPipe
case <-c.readDeadline.Wait():
return nil, M.Socksaddr{}, os.ErrDeadlineExceeded
} }
} }

View file

@ -13,20 +13,26 @@ import (
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network" N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/pipe"
) )
// Deprecated: Use N.UDPConnectionHandler instead.
//
//nolint:staticcheck
type Handler interface { type Handler interface {
N.UDPConnectionHandler N.UDPConnectionHandler
E.Handler E.Handler
} }
type Service[K comparable] struct { type Service[K comparable] struct {
nat *cache.LruCache[K, *conn] nat *cache.LruCache[K, *conn]
handler Handler handler Handler
handlerEx N.UDPConnectionHandlerEx
} }
// Deprecated: Use NewEx instead.
func New[K comparable](maxAge int64, handler Handler) *Service[K] { func New[K comparable](maxAge int64, handler Handler) *Service[K] {
return &Service[K]{ service := &Service[K]{
nat: cache.New( nat: cache.New(
cache.WithAge[K, *conn](maxAge), cache.WithAge[K, *conn](maxAge),
cache.WithUpdateAgeOnGet[K, *conn](), cache.WithUpdateAgeOnGet[K, *conn](),
@ -36,11 +42,27 @@ func New[K comparable](maxAge int64, handler Handler) *Service[K] {
), ),
handler: handler, handler: handler,
} }
return service
}
func NewEx[K comparable](maxAge int64, handler N.UDPConnectionHandlerEx) *Service[K] {
service := &Service[K]{
nat: cache.New(
cache.WithAge[K, *conn](maxAge),
cache.WithUpdateAgeOnGet[K, *conn](),
cache.WithEvict[K, *conn](func(key K, conn *conn) {
conn.Close()
}),
),
handlerEx: handler,
}
return service
} }
func (s *Service[T]) WriteIsThreadUnsafe() { func (s *Service[T]) WriteIsThreadUnsafe() {
} }
// Deprecated: don't use
func (s *Service[T]) NewPacketDirect(ctx context.Context, key T, conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) { func (s *Service[T]) NewPacketDirect(ctx context.Context, key T, conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) {
s.NewContextPacket(ctx, key, buffer, metadata, func(natConn N.PacketConn) (context.Context, N.PacketWriter) { s.NewContextPacket(ctx, key, buffer, metadata, func(natConn N.PacketConn) (context.Context, N.PacketWriter) {
return ctx, &DirectBackWriter{conn, natConn} return ctx, &DirectBackWriter{conn, natConn}
@ -60,18 +82,31 @@ func (w *DirectBackWriter) Upstream() any {
return w.Source return w.Source
} }
// Deprecated: use NewPacketEx instead.
func (s *Service[T]) NewPacket(ctx context.Context, key T, buffer *buf.Buffer, metadata M.Metadata, init func(natConn N.PacketConn) N.PacketWriter) { func (s *Service[T]) NewPacket(ctx context.Context, key T, buffer *buf.Buffer, metadata M.Metadata, init func(natConn N.PacketConn) N.PacketWriter) {
s.NewContextPacket(ctx, key, buffer, metadata, func(natConn N.PacketConn) (context.Context, N.PacketWriter) { s.NewContextPacket(ctx, key, buffer, metadata, func(natConn N.PacketConn) (context.Context, N.PacketWriter) {
return ctx, init(natConn) return ctx, init(natConn)
}) })
} }
func (s *Service[T]) NewPacketEx(ctx context.Context, key T, buffer *buf.Buffer, source M.Socksaddr, destination M.Socksaddr, init func(natConn N.PacketConn) N.PacketWriter) {
s.NewContextPacketEx(ctx, key, buffer, source, destination, func(natConn N.PacketConn) (context.Context, N.PacketWriter) {
return ctx, init(natConn)
})
}
// Deprecated: Use NewPacketConnectionEx instead.
func (s *Service[T]) NewContextPacket(ctx context.Context, key T, buffer *buf.Buffer, metadata M.Metadata, init func(natConn N.PacketConn) (context.Context, N.PacketWriter)) { func (s *Service[T]) NewContextPacket(ctx context.Context, key T, buffer *buf.Buffer, metadata M.Metadata, init func(natConn N.PacketConn) (context.Context, N.PacketWriter)) {
s.NewContextPacketEx(ctx, key, buffer, metadata.Source, metadata.Destination, init)
}
func (s *Service[T]) NewContextPacketEx(ctx context.Context, key T, buffer *buf.Buffer, source M.Socksaddr, destination M.Socksaddr, init func(natConn N.PacketConn) (context.Context, N.PacketWriter)) {
c, loaded := s.nat.LoadOrStore(key, func() *conn { c, loaded := s.nat.LoadOrStore(key, func() *conn {
c := &conn{ c := &conn{
data: make(chan packet, 64), data: make(chan packet, 64),
localAddr: metadata.Source, localAddr: source,
remoteAddr: metadata.Destination, remoteAddr: destination,
readDeadline: pipe.MakeDeadline(),
} }
c.ctx, c.cancel = common.ContextWithCancelCause(ctx) c.ctx, c.cancel = common.ContextWithCancelCause(ctx)
return c return c
@ -79,26 +114,34 @@ func (s *Service[T]) NewContextPacket(ctx context.Context, key T, buffer *buf.Bu
if !loaded { if !loaded {
ctx, c.source = init(c) ctx, c.source = init(c)
go func() { go func() {
err := s.handler.NewPacketConnection(ctx, c, metadata) if s.handlerEx != nil {
if err != nil { s.handlerEx.NewPacketConnectionEx(ctx, c, source, destination, func(err error) {
s.handler.NewError(ctx, err) s.nat.Delete(key)
})
} else {
//nolint:staticcheck
err := s.handler.NewPacketConnection(ctx, c, M.Metadata{
Source: source,
Destination: destination,
})
if err != nil {
s.handler.NewError(ctx, err)
}
c.Close()
s.nat.Delete(key)
} }
c.Close()
s.nat.Delete(key)
}() }()
} else {
c.localAddr = metadata.Source
} }
if common.Done(c.ctx) { if common.Done(c.ctx) {
s.nat.Delete(key) s.nat.Delete(key)
if !common.Done(ctx) { if !common.Done(ctx) {
s.NewContextPacket(ctx, key, buffer, metadata, init) s.NewContextPacketEx(ctx, key, buffer, source, destination, init)
} }
return return
} }
c.data <- packet{ c.data <- packet{
data: buffer, data: buffer,
destination: metadata.Destination, destination: destination,
} }
} }
@ -116,6 +159,7 @@ type conn struct {
localAddr M.Socksaddr localAddr M.Socksaddr
remoteAddr M.Socksaddr remoteAddr M.Socksaddr
source N.PacketWriter source N.PacketWriter
readDeadline pipe.Deadline
readWaitOptions N.ReadWaitOptions readWaitOptions N.ReadWaitOptions
} }
@ -127,6 +171,8 @@ func (c *conn) ReadPacket(buffer *buf.Buffer) (addr M.Socksaddr, err error) {
return p.destination, err return p.destination, err
case <-c.ctx.Done(): case <-c.ctx.Done():
return M.Socksaddr{}, io.ErrClosedPipe return M.Socksaddr{}, io.ErrClosedPipe
case <-c.readDeadline.Wait():
return M.Socksaddr{}, os.ErrDeadlineExceeded
} }
} }
@ -159,17 +205,14 @@ func (c *conn) SetDeadline(t time.Time) error {
} }
func (c *conn) SetReadDeadline(t time.Time) error { func (c *conn) SetReadDeadline(t time.Time) error {
return os.ErrInvalid c.readDeadline.Set(t)
return nil
} }
func (c *conn) SetWriteDeadline(t time.Time) error { func (c *conn) SetWriteDeadline(t time.Time) error {
return os.ErrInvalid return os.ErrInvalid
} }
func (c *conn) NeedAdditionalReadDeadline() bool {
return true
}
func (c *conn) Upstream() any { func (c *conn) Upstream() any {
return c.source return c.source
} }

138
common/udpnat2/conn.go Normal file
View file

@ -0,0 +1,138 @@
package udpnat
import (
"io"
"net"
"net/netip"
"os"
"sync"
"time"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/canceler"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/pipe"
"github.com/sagernet/sing/contrab/freelru"
)
type Conn interface {
N.PacketConn
SetHandler(handler N.UDPHandlerEx)
canceler.PacketConn
}
var _ Conn = (*natConn)(nil)
type natConn struct {
cache freelru.Cache[netip.AddrPort, *natConn]
writer N.PacketWriter
localAddr M.Socksaddr
handlerAccess sync.RWMutex
handler N.UDPHandlerEx
packetChan chan *N.PacketBuffer
closeOnce sync.Once
doneChan chan struct{}
readDeadline pipe.Deadline
readWaitOptions N.ReadWaitOptions
}
func (c *natConn) ReadPacket(buffer *buf.Buffer) (addr M.Socksaddr, err error) {
select {
case p := <-c.packetChan:
_, err = buffer.ReadOnceFrom(p.Buffer)
destination := p.Destination
p.Buffer.Release()
N.PutPacketBuffer(p)
return destination, err
case <-c.doneChan:
return M.Socksaddr{}, io.ErrClosedPipe
case <-c.readDeadline.Wait():
return M.Socksaddr{}, os.ErrDeadlineExceeded
}
}
func (c *natConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
return c.writer.WritePacket(buffer, destination)
}
func (c *natConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
c.readWaitOptions = options
return false
}
func (c *natConn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
select {
case packet := <-c.packetChan:
buffer = c.readWaitOptions.Copy(packet.Buffer)
destination = packet.Destination
N.PutPacketBuffer(packet)
return
case <-c.doneChan:
return nil, M.Socksaddr{}, io.ErrClosedPipe
case <-c.readDeadline.Wait():
return nil, M.Socksaddr{}, os.ErrDeadlineExceeded
}
}
func (c *natConn) SetHandler(handler N.UDPHandlerEx) {
c.handlerAccess.Lock()
c.handler = handler
c.readWaitOptions = N.NewReadWaitOptions(c.writer, handler)
c.handlerAccess.Unlock()
fetch:
for {
select {
case packet := <-c.packetChan:
c.handler.NewPacketEx(packet.Buffer, packet.Destination)
N.PutPacketBuffer(packet)
continue fetch
default:
break fetch
}
}
}
func (c *natConn) Timeout() time.Duration {
rawConn, lifetime, loaded := c.cache.PeekWithLifetime(c.localAddr.AddrPort())
if !loaded || rawConn != c {
return 0
}
return time.Until(lifetime)
}
func (c *natConn) SetTimeout(timeout time.Duration) bool {
return c.cache.UpdateLifetime(c.localAddr.AddrPort(), c, timeout)
}
func (c *natConn) Close() error {
c.closeOnce.Do(func() {
close(c.doneChan)
})
return nil
}
func (c *natConn) LocalAddr() net.Addr {
return c.localAddr
}
func (c *natConn) RemoteAddr() net.Addr {
return M.Socksaddr{}
}
func (c *natConn) SetDeadline(t time.Time) error {
return os.ErrInvalid
}
func (c *natConn) SetReadDeadline(t time.Time) error {
c.readDeadline.Set(t)
return nil
}
func (c *natConn) SetWriteDeadline(t time.Time) error {
return os.ErrInvalid
}
func (c *natConn) Upstream() any {
return c.writer
}

103
common/udpnat2/service.go Normal file
View file

@ -0,0 +1,103 @@
package udpnat
import (
"context"
"net/netip"
"time"
"github.com/sagernet/sing/common"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/pipe"
"github.com/sagernet/sing/contrab/freelru"
"github.com/sagernet/sing/contrab/maphash"
)
type Service struct {
cache freelru.Cache[netip.AddrPort, *natConn]
handler N.UDPConnectionHandlerEx
prepare PrepareFunc
}
type PrepareFunc func(source M.Socksaddr, destination M.Socksaddr, userData any) (bool, context.Context, N.PacketWriter, N.CloseHandlerFunc)
func New(handler N.UDPConnectionHandlerEx, prepare PrepareFunc, timeout time.Duration, shared bool) *Service {
if timeout == 0 {
panic("invalid timeout")
}
var cache freelru.Cache[netip.AddrPort, *natConn]
if !shared {
cache = common.Must1(freelru.NewSynced[netip.AddrPort, *natConn](1024, maphash.NewHasher[netip.AddrPort]().Hash32))
} else {
cache = common.Must1(freelru.NewSharded[netip.AddrPort, *natConn](1024, maphash.NewHasher[netip.AddrPort]().Hash32))
}
cache.SetLifetime(timeout)
cache.SetHealthCheck(func(port netip.AddrPort, conn *natConn) bool {
select {
case <-conn.doneChan:
return false
default:
return true
}
})
cache.SetOnEvict(func(_ netip.AddrPort, conn *natConn) {
conn.Close()
})
return &Service{
cache: cache,
handler: handler,
prepare: prepare,
}
}
func (s *Service) NewPacket(bufferSlices [][]byte, source M.Socksaddr, destination M.Socksaddr, userData any) {
conn, _, ok := s.cache.GetAndRefreshOrAdd(source.AddrPort(), func() (*natConn, bool) {
ok, ctx, writer, onClose := s.prepare(source, destination, userData)
if !ok {
return nil, false
}
newConn := &natConn{
cache: s.cache,
writer: writer,
localAddr: source,
packetChan: make(chan *N.PacketBuffer, 64),
doneChan: make(chan struct{}),
readDeadline: pipe.MakeDeadline(),
}
go s.handler.NewPacketConnectionEx(ctx, newConn, source, destination, onClose)
return newConn, true
})
if !ok {
return
}
buffer := conn.readWaitOptions.NewPacketBuffer()
for _, bufferSlice := range bufferSlices {
buffer.Write(bufferSlice)
}
conn.handlerAccess.RLock()
handler := conn.handler
conn.handlerAccess.RUnlock()
if handler != nil {
handler.NewPacketEx(buffer, destination)
return
}
packet := N.NewPacketBuffer()
*packet = N.PacketBuffer{
Buffer: buffer,
Destination: destination,
}
select {
case conn.packetChan <- packet:
default:
packet.Buffer.Release()
N.PutPacketBuffer(packet)
}
}
func (s *Service) Purge() {
s.cache.Purge()
}
func (s *Service) PurgeExpired() {
s.cache.PurgeExpired()
}

View file

@ -9,6 +9,7 @@ import (
) )
func TestSlicesValue(t *testing.T) { func TestSlicesValue(t *testing.T) {
t.Parallel()
int64Arr := make([]int64, 64) int64Arr := make([]int64, 64)
for i := range int64Arr { for i := range int64Arr {
int64Arr[i] = rand.Int63() int64Arr[i] = rand.Int63()
@ -18,6 +19,7 @@ func TestSlicesValue(t *testing.T) {
} }
func TestSetSliceValue(t *testing.T) { func TestSetSliceValue(t *testing.T) {
t.Parallel()
int64Arr := make([]int64, 64) int64Arr := make([]int64, 64)
value := reflect.Indirect(reflect.ValueOf(&int64Arr)) value := reflect.Indirect(reflect.ValueOf(&int64Arr))
newInt64Arr := make([]int64, 64) newInt64Arr := make([]int64, 64)

View file

@ -0,0 +1,9 @@
//go:build !windows
package windnsapi
import "os"
func FlushResolverCache() error {
return os.ErrInvalid
}

View file

@ -0,0 +1,14 @@
//go:build windows
package windnsapi
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestDNSAPI(t *testing.T) {
t.Parallel()
require.NoError(t, FlushResolverCache())
}

View file

@ -1,21 +0,0 @@
package windnsapi
import (
"os"
"syscall"
"golang.org/x/sys/windows"
)
var (
moddnsapi = windows.NewLazySystemDLL("dnsapi.dll")
procDnsFlushResolverCache = moddnsapi.NewProc("DnsFlushResolverCache")
)
func FlushResolverCache() error {
r0, _, err := syscall.SyscallN(procDnsFlushResolverCache.Addr())
if r0 == 0 {
return os.NewSyscallError("DnsFlushResolverCache", err)
}
return nil
}

View file

@ -0,0 +1,6 @@
package windnsapi
//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go syscall_windows.go
// dnsapi.DnsFlushResolverCache is an undocumented function
//sys FlushResolverCache() (err error) = dnsapi.DnsFlushResolverCache

View file

@ -0,0 +1,52 @@
// Code generated by 'go generate'; DO NOT EDIT.
package windnsapi
import (
"syscall"
"unsafe"
"golang.org/x/sys/windows"
)
var _ unsafe.Pointer
// Do the interface allocations only once for common
// Errno values.
const (
errnoERROR_IO_PENDING = 997
)
var (
errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
errERROR_EINVAL error = syscall.EINVAL
)
// errnoErr returns common boxed Errno values, to prevent
// allocations at runtime.
func errnoErr(e syscall.Errno) error {
switch e {
case 0:
return errERROR_EINVAL
case errnoERROR_IO_PENDING:
return errERROR_IO_PENDING
}
// TODO: add more here, after collecting data on the common
// error values see on Windows. (perhaps when running
// all.bat?)
return e
}
var (
moddnsapi = windows.NewLazySystemDLL("dnsapi.dll")
procDnsFlushResolverCache = moddnsapi.NewProc("DnsFlushResolverCache")
)
func FlushResolverCache() (err error) {
r1, _, e1 := syscall.Syscall(procDnsFlushResolverCache.Addr(), 0, 0, 0, 0)
if r1 == 0 {
err = errnoErr(e1)
}
return
}

View file

@ -0,0 +1,217 @@
//go:build windows
package winiphlpapi
import (
"context"
"encoding/binary"
"net"
"net/netip"
"os"
"time"
"unsafe"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
func LoadEStats() error {
err := modiphlpapi.Load()
if err != nil {
return err
}
err = procGetTcpTable.Find()
if err != nil {
return err
}
err = procGetTcp6Table.Find()
if err != nil {
return err
}
err = procGetPerTcp6ConnectionEStats.Find()
if err != nil {
return err
}
err = procGetPerTcp6ConnectionEStats.Find()
if err != nil {
return err
}
err = procSetPerTcpConnectionEStats.Find()
if err != nil {
return err
}
err = procSetPerTcp6ConnectionEStats.Find()
if err != nil {
return err
}
return nil
}
func LoadExtendedTable() error {
err := modiphlpapi.Load()
if err != nil {
return err
}
err = procGetExtendedTcpTable.Find()
if err != nil {
return err
}
err = procGetExtendedUdpTable.Find()
if err != nil {
return err
}
return nil
}
func FindPid(network string, source netip.AddrPort) (uint32, error) {
switch N.NetworkName(network) {
case N.NetworkTCP:
if source.Addr().Is4() {
tcpTable, err := GetExtendedTcpTable()
if err != nil {
return 0, err
}
for _, row := range tcpTable {
if source == netip.AddrPortFrom(DwordToAddr(row.DwLocalAddr), DwordToPort(row.DwLocalPort)) {
return row.DwOwningPid, nil
}
}
} else {
tcpTable, err := GetExtendedTcp6Table()
if err != nil {
return 0, err
}
for _, row := range tcpTable {
if source == netip.AddrPortFrom(netip.AddrFrom16(row.UcLocalAddr), DwordToPort(row.DwLocalPort)) {
return row.DwOwningPid, nil
}
}
}
case N.NetworkUDP:
if source.Addr().Is4() {
udpTable, err := GetExtendedUdpTable()
if err != nil {
return 0, err
}
for _, row := range udpTable {
if source == netip.AddrPortFrom(DwordToAddr(row.DwLocalAddr), DwordToPort(row.DwLocalPort)) {
return row.DwOwningPid, nil
}
}
} else {
udpTable, err := GetExtendedUdp6Table()
if err != nil {
return 0, err
}
for _, row := range udpTable {
if source == netip.AddrPortFrom(netip.AddrFrom16(row.UcLocalAddr), DwordToPort(row.DwLocalPort)) {
return row.DwOwningPid, nil
}
}
}
}
return 0, E.New("process not found for ", source)
}
func WriteAndWaitAck(ctx context.Context, conn net.Conn, payload []byte) error {
source := M.AddrPortFromNet(conn.LocalAddr())
destination := M.AddrPortFromNet(conn.RemoteAddr())
if source.Addr().Is4() {
tcpTable, err := GetTcpTable()
if err != nil {
return err
}
var tcpRow *MibTcpRow
for _, row := range tcpTable {
if source == netip.AddrPortFrom(DwordToAddr(row.DwLocalAddr), DwordToPort(row.DwLocalPort)) ||
destination == netip.AddrPortFrom(DwordToAddr(row.DwRemoteAddr), DwordToPort(row.DwRemotePort)) {
tcpRow = &row
break
}
}
if tcpRow == nil {
return E.New("row not found for: ", source)
}
err = SetPerTcpConnectionEStatsSendBuffer(tcpRow, &TcpEstatsSendBuffRwV0{
EnableCollection: true,
})
if err != nil {
return os.NewSyscallError("SetPerTcpConnectionEStatsSendBufferV0", err)
}
defer SetPerTcpConnectionEStatsSendBuffer(tcpRow, &TcpEstatsSendBuffRwV0{
EnableCollection: false,
})
_, err = conn.Write(payload)
if err != nil {
return err
}
for {
select {
case <-ctx.Done():
return ctx.Err()
default:
}
eStstsSendBuffer, err := GetPerTcpConnectionEStatsSendBuffer(tcpRow)
if err != nil {
return err
}
if eStstsSendBuffer.CurRetxQueue == 0 {
return nil
}
time.Sleep(10 * time.Millisecond)
}
} else {
tcpTable, err := GetTcp6Table()
if err != nil {
return err
}
var tcpRow *MibTcp6Row
for _, row := range tcpTable {
if source == netip.AddrPortFrom(netip.AddrFrom16(row.LocalAddr), DwordToPort(row.LocalPort)) ||
destination == netip.AddrPortFrom(netip.AddrFrom16(row.RemoteAddr), DwordToPort(row.RemotePort)) {
tcpRow = &row
break
}
}
if tcpRow == nil {
return E.New("row not found for: ", source)
}
err = SetPerTcp6ConnectionEStatsSendBuffer(tcpRow, &TcpEstatsSendBuffRwV0{
EnableCollection: true,
})
if err != nil {
return os.NewSyscallError("SetPerTcpConnectionEStatsSendBufferV0", err)
}
defer SetPerTcp6ConnectionEStatsSendBuffer(tcpRow, &TcpEstatsSendBuffRwV0{
EnableCollection: false,
})
_, err = conn.Write(payload)
if err != nil {
return err
}
for {
select {
case <-ctx.Done():
return ctx.Err()
default:
}
eStstsSendBuffer, err := GetPerTcp6ConnectionEStatsSendBuffer(tcpRow)
if err != nil {
return err
}
if eStstsSendBuffer.CurRetxQueue == 0 {
return nil
}
time.Sleep(10 * time.Millisecond)
}
}
}
func DwordToAddr(addr uint32) netip.Addr {
return netip.AddrFrom4(*(*[4]byte)(unsafe.Pointer(&addr)))
}
func DwordToPort(dword uint32) uint16 {
return binary.BigEndian.Uint16((*[4]byte)(unsafe.Pointer(&dword))[:])
}

View file

@ -0,0 +1,313 @@
//go:build windows
package winiphlpapi
import (
"errors"
"os"
"unsafe"
"golang.org/x/sys/windows"
)
const (
TcpTableBasicListener uint32 = iota
TcpTableBasicConnections
TcpTableBasicAll
TcpTableOwnerPidListener
TcpTableOwnerPidConnections
TcpTableOwnerPidAll
TcpTableOwnerModuleListener
TcpTableOwnerModuleConnections
TcpTableOwnerModuleAll
)
const (
UdpTableBasic uint32 = iota
UdpTableOwnerPid
UdpTableOwnerModule
)
const (
TcpConnectionEstatsSynOpts uint32 = iota
TcpConnectionEstatsData
TcpConnectionEstatsSndCong
TcpConnectionEstatsPath
TcpConnectionEstatsSendBuff
TcpConnectionEstatsRec
TcpConnectionEstatsObsRec
TcpConnectionEstatsBandwidth
TcpConnectionEstatsFineRtt
TcpConnectionEstatsMaximum
)
type MibTcpTable struct {
DwNumEntries uint32
Table [1]MibTcpRow
}
type MibTcpRow struct {
DwState uint32
DwLocalAddr uint32
DwLocalPort uint32
DwRemoteAddr uint32
DwRemotePort uint32
}
type MibTcp6Table struct {
DwNumEntries uint32
Table [1]MibTcp6Row
}
type MibTcp6Row struct {
State uint32
LocalAddr [16]byte
LocalScopeId uint32
LocalPort uint32
RemoteAddr [16]byte
RemoteScopeId uint32
RemotePort uint32
}
type MibTcpTableOwnerPid struct {
DwNumEntries uint32
Table [1]MibTcpRowOwnerPid
}
type MibTcpRowOwnerPid struct {
DwState uint32
DwLocalAddr uint32
DwLocalPort uint32
DwRemoteAddr uint32
DwRemotePort uint32
DwOwningPid uint32
}
type MibTcp6TableOwnerPid struct {
DwNumEntries uint32
Table [1]MibTcp6RowOwnerPid
}
type MibTcp6RowOwnerPid struct {
UcLocalAddr [16]byte
DwLocalScopeId uint32
DwLocalPort uint32
UcRemoteAddr [16]byte
DwRemoteScopeId uint32
DwRemotePort uint32
DwState uint32
DwOwningPid uint32
}
type MibUdpTableOwnerPid struct {
DwNumEntries uint32
Table [1]MibUdpRowOwnerPid
}
type MibUdpRowOwnerPid struct {
DwLocalAddr uint32
DwLocalPort uint32
DwOwningPid uint32
}
type MibUdp6TableOwnerPid struct {
DwNumEntries uint32
Table [1]MibUdp6RowOwnerPid
}
type MibUdp6RowOwnerPid struct {
UcLocalAddr [16]byte
DwLocalScopeId uint32
DwLocalPort uint32
DwOwningPid uint32
}
type TcpEstatsSendBufferRodV0 struct {
CurRetxQueue uint64
MaxRetxQueue uint64
CurAppWQueue uint64
MaxAppWQueue uint64
}
type TcpEstatsSendBuffRwV0 struct {
EnableCollection bool
}
const (
offsetOfMibTcpTable = unsafe.Offsetof(MibTcpTable{}.Table)
offsetOfMibTcp6Table = unsafe.Offsetof(MibTcp6Table{}.Table)
offsetOfMibTcpTableOwnerPid = unsafe.Offsetof(MibTcpTableOwnerPid{}.Table)
offsetOfMibTcp6TableOwnerPid = unsafe.Offsetof(MibTcpTableOwnerPid{}.Table)
offsetOfMibUdpTableOwnerPid = unsafe.Offsetof(MibUdpTableOwnerPid{}.Table)
offsetOfMibUdp6TableOwnerPid = unsafe.Offsetof(MibUdp6TableOwnerPid{}.Table)
sizeOfTcpEstatsSendBuffRwV0 = unsafe.Sizeof(TcpEstatsSendBuffRwV0{})
sizeOfTcpEstatsSendBufferRodV0 = unsafe.Sizeof(TcpEstatsSendBufferRodV0{})
)
func GetTcpTable() ([]MibTcpRow, error) {
var size uint32
err := getTcpTable(nil, &size, false)
if !errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) {
return nil, err
}
for {
table := make([]byte, size)
err = getTcpTable(&table[0], &size, false)
if err != nil {
if errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) {
continue
}
return nil, err
}
dwNumEntries := int(*(*uint32)(unsafe.Pointer(&table[0])))
return unsafe.Slice((*MibTcpRow)(unsafe.Pointer(&table[offsetOfMibTcpTable])), dwNumEntries), nil
}
}
func GetTcp6Table() ([]MibTcp6Row, error) {
var size uint32
err := getTcp6Table(nil, &size, false)
if !errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) {
return nil, err
}
for {
table := make([]byte, size)
err = getTcp6Table(&table[0], &size, false)
if err != nil {
if errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) {
continue
}
return nil, err
}
dwNumEntries := int(*(*uint32)(unsafe.Pointer(&table[0])))
return unsafe.Slice((*MibTcp6Row)(unsafe.Pointer(&table[offsetOfMibTcp6Table])), dwNumEntries), nil
}
}
func GetExtendedTcpTable() ([]MibTcpRowOwnerPid, error) {
var size uint32
err := getExtendedTcpTable(nil, &size, false, windows.AF_INET, TcpTableOwnerPidConnections, 0)
if !errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) {
return nil, os.NewSyscallError("GetExtendedTcpTable", err)
}
for {
table := make([]byte, size)
err = getExtendedTcpTable(&table[0], &size, false, windows.AF_INET, TcpTableOwnerPidConnections, 0)
if err != nil {
if errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) {
continue
}
return nil, os.NewSyscallError("GetExtendedTcpTable", err)
}
dwNumEntries := int(*(*uint32)(unsafe.Pointer(&table[0])))
return unsafe.Slice((*MibTcpRowOwnerPid)(unsafe.Pointer(&table[offsetOfMibTcpTableOwnerPid])), dwNumEntries), nil
}
}
func GetExtendedTcp6Table() ([]MibTcp6RowOwnerPid, error) {
var size uint32
err := getExtendedTcpTable(nil, &size, false, windows.AF_INET6, TcpTableOwnerPidConnections, 0)
if !errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) {
return nil, os.NewSyscallError("GetExtendedTcpTable", err)
}
for {
table := make([]byte, size)
err = getExtendedTcpTable(&table[0], &size, false, windows.AF_INET6, TcpTableOwnerPidConnections, 0)
if err != nil {
if errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) {
continue
}
return nil, os.NewSyscallError("GetExtendedTcpTable", err)
}
dwNumEntries := int(*(*uint32)(unsafe.Pointer(&table[0])))
return unsafe.Slice((*MibTcp6RowOwnerPid)(unsafe.Pointer(&table[offsetOfMibTcp6TableOwnerPid])), dwNumEntries), nil
}
}
func GetExtendedUdpTable() ([]MibUdpRowOwnerPid, error) {
var size uint32
err := getExtendedUdpTable(nil, &size, false, windows.AF_INET, UdpTableOwnerPid, 0)
if !errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) {
return nil, os.NewSyscallError("GetExtendedUdpTable", err)
}
for {
table := make([]byte, size)
err = getExtendedUdpTable(&table[0], &size, false, windows.AF_INET, UdpTableOwnerPid, 0)
if err != nil {
if errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) {
continue
}
return nil, os.NewSyscallError("GetExtendedUdpTable", err)
}
dwNumEntries := int(*(*uint32)(unsafe.Pointer(&table[0])))
return unsafe.Slice((*MibUdpRowOwnerPid)(unsafe.Pointer(&table[offsetOfMibUdpTableOwnerPid])), dwNumEntries), nil
}
}
func GetExtendedUdp6Table() ([]MibUdp6RowOwnerPid, error) {
var size uint32
err := getExtendedUdpTable(nil, &size, false, windows.AF_INET6, UdpTableOwnerPid, 0)
if !errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) {
return nil, os.NewSyscallError("GetExtendedUdpTable", err)
}
for {
table := make([]byte, size)
err = getExtendedUdpTable(&table[0], &size, false, windows.AF_INET6, UdpTableOwnerPid, 0)
if err != nil {
if errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) {
continue
}
return nil, os.NewSyscallError("GetExtendedUdpTable", err)
}
dwNumEntries := int(*(*uint32)(unsafe.Pointer(&table[0])))
return unsafe.Slice((*MibUdp6RowOwnerPid)(unsafe.Pointer(&table[offsetOfMibUdp6TableOwnerPid])), dwNumEntries), nil
}
}
func GetPerTcpConnectionEStatsSendBuffer(row *MibTcpRow) (*TcpEstatsSendBufferRodV0, error) {
var rod TcpEstatsSendBufferRodV0
err := getPerTcpConnectionEStats(row,
TcpConnectionEstatsSendBuff,
0,
0,
0,
0,
0,
0,
uintptr(unsafe.Pointer(&rod)),
0,
uint64(sizeOfTcpEstatsSendBufferRodV0),
)
if err != nil {
return nil, err
}
return &rod, nil
}
func GetPerTcp6ConnectionEStatsSendBuffer(row *MibTcp6Row) (*TcpEstatsSendBufferRodV0, error) {
var rod TcpEstatsSendBufferRodV0
err := getPerTcp6ConnectionEStats(row,
TcpConnectionEstatsSendBuff,
0,
0,
0,
0,
0,
0,
uintptr(unsafe.Pointer(&rod)),
0,
uint64(sizeOfTcpEstatsSendBufferRodV0),
)
if err != nil {
return nil, err
}
return &rod, nil
}
func SetPerTcpConnectionEStatsSendBuffer(row *MibTcpRow, rw *TcpEstatsSendBuffRwV0) error {
return setPerTcpConnectionEStats(row, TcpConnectionEstatsSendBuff, uintptr(unsafe.Pointer(&rw)), 0, uint64(sizeOfTcpEstatsSendBuffRwV0), 0)
}
func SetPerTcp6ConnectionEStatsSendBuffer(row *MibTcp6Row, rw *TcpEstatsSendBuffRwV0) error {
return setPerTcp6ConnectionEStats(row, TcpConnectionEstatsSendBuff, uintptr(unsafe.Pointer(&rw)), 0, uint64(sizeOfTcpEstatsSendBuffRwV0), 0)
}

View file

@ -0,0 +1,90 @@
//go:build windows
package winiphlpapi_test
import (
"context"
"net"
"syscall"
"testing"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/winiphlpapi"
"github.com/stretchr/testify/require"
)
func TestFindPidTcp4(t *testing.T) {
t.Parallel()
listener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer listener.Close()
go listener.Accept()
conn, err := net.Dial("tcp", listener.Addr().String())
require.NoError(t, err)
defer conn.Close()
pid, err := winiphlpapi.FindPid(N.NetworkTCP, M.AddrPortFromNet(conn.LocalAddr()))
require.NoError(t, err)
require.Equal(t, uint32(syscall.Getpid()), pid)
}
func TestFindPidTcp6(t *testing.T) {
t.Parallel()
listener, err := net.Listen("tcp", "[::1]:0")
require.NoError(t, err)
defer listener.Close()
go listener.Accept()
conn, err := net.Dial("tcp", listener.Addr().String())
require.NoError(t, err)
defer conn.Close()
pid, err := winiphlpapi.FindPid(N.NetworkTCP, M.AddrPortFromNet(conn.LocalAddr()))
require.NoError(t, err)
require.Equal(t, uint32(syscall.Getpid()), pid)
}
func TestFindPidUdp4(t *testing.T) {
t.Parallel()
conn, err := net.ListenPacket("udp", "127.0.0.1:0")
require.NoError(t, err)
defer conn.Close()
pid, err := winiphlpapi.FindPid(N.NetworkUDP, M.AddrPortFromNet(conn.LocalAddr()))
require.NoError(t, err)
require.Equal(t, uint32(syscall.Getpid()), pid)
}
func TestFindPidUdp6(t *testing.T) {
t.Parallel()
conn, err := net.ListenPacket("udp", "[::1]:0")
require.NoError(t, err)
defer conn.Close()
pid, err := winiphlpapi.FindPid(N.NetworkUDP, M.AddrPortFromNet(conn.LocalAddr()))
require.NoError(t, err)
require.Equal(t, uint32(syscall.Getpid()), pid)
}
func TestWaitAck4(t *testing.T) {
t.Parallel()
listener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer listener.Close()
go listener.Accept()
conn, err := net.Dial("tcp", listener.Addr().String())
require.NoError(t, err)
defer conn.Close()
err = winiphlpapi.WriteAndWaitAck(context.Background(), conn, []byte("hello"))
require.NoError(t, err)
}
func TestWaitAck6(t *testing.T) {
t.Parallel()
listener, err := net.Listen("tcp", "[::1]:0")
require.NoError(t, err)
defer listener.Close()
go listener.Accept()
conn, err := net.Dial("tcp", listener.Addr().String())
require.NoError(t, err)
defer conn.Close()
err = winiphlpapi.WriteAndWaitAck(context.Background(), conn, []byte("hello"))
require.NoError(t, err)
}

View file

@ -0,0 +1,27 @@
package winiphlpapi
//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go syscall_windows.go
// https://learn.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-gettcptable
//sys getTcpTable(tcpTable *byte, sizePointer *uint32, order bool) (errcode error) = iphlpapi.GetTcpTable
// https://learn.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-gettcp6table
//sys getTcp6Table(tcpTable *byte, sizePointer *uint32, order bool) (errcode error) = iphlpapi.GetTcp6Table
// https://learn.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-getpertcpconnectionestats
//sys getPerTcpConnectionEStats(row *MibTcpRow, estatsType uint32, rw uintptr, rwVersion uint64, rwSize uint64, ros uintptr, rosVersion uint64, rosSize uint64, rod uintptr, rodVersion uint64, rodSize uint64) (errcode error) = iphlpapi.GetPerTcpConnectionEStats
// https://learn.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-getpertcp6connectionestats
//sys getPerTcp6ConnectionEStats(row *MibTcp6Row, estatsType uint32, rw uintptr, rwVersion uint64, rwSize uint64, ros uintptr, rosVersion uint64, rosSize uint64, rod uintptr, rodVersion uint64, rodSize uint64) (errcode error) = iphlpapi.GetPerTcp6ConnectionEStats
// https://learn.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-setpertcpconnectionestats
//sys setPerTcpConnectionEStats(row *MibTcpRow, estatsType uint32, rw uintptr, rwVersion uint64, rwSize uint64, offset uint64) (errcode error) = iphlpapi.SetPerTcpConnectionEStats
// https://learn.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-setpertcp6connectionestats
//sys setPerTcp6ConnectionEStats(row *MibTcp6Row, estatsType uint32, rw uintptr, rwVersion uint64, rwSize uint64, offset uint64) (errcode error) = iphlpapi.SetPerTcp6ConnectionEStats
// https://learn.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-getextendedtcptable
//sys getExtendedTcpTable(pTcpTable *byte, pdwSize *uint32, bOrder bool, ulAf uint64, tableClass uint32, reserved uint64) = (errcode error) = iphlpapi.GetExtendedTcpTable
// https://learn.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-getextendedudptable
//sys getExtendedUdpTable(pUdpTable *byte, pdwSize *uint32, bOrder bool, ulAf uint64, tableClass uint32, reserved uint64) = (errcode error) = iphlpapi.GetExtendedUdpTable

View file

@ -0,0 +1,131 @@
// Code generated by 'go generate'; DO NOT EDIT.
package winiphlpapi
import (
"syscall"
"unsafe"
"golang.org/x/sys/windows"
)
var _ unsafe.Pointer
// Do the interface allocations only once for common
// Errno values.
const (
errnoERROR_IO_PENDING = 997
)
var (
errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
errERROR_EINVAL error = syscall.EINVAL
)
// errnoErr returns common boxed Errno values, to prevent
// allocations at runtime.
func errnoErr(e syscall.Errno) error {
switch e {
case 0:
return errERROR_EINVAL
case errnoERROR_IO_PENDING:
return errERROR_IO_PENDING
}
// TODO: add more here, after collecting data on the common
// error values see on Windows. (perhaps when running
// all.bat?)
return e
}
var (
modiphlpapi = windows.NewLazySystemDLL("iphlpapi.dll")
procGetExtendedTcpTable = modiphlpapi.NewProc("GetExtendedTcpTable")
procGetExtendedUdpTable = modiphlpapi.NewProc("GetExtendedUdpTable")
procGetPerTcp6ConnectionEStats = modiphlpapi.NewProc("GetPerTcp6ConnectionEStats")
procGetPerTcpConnectionEStats = modiphlpapi.NewProc("GetPerTcpConnectionEStats")
procGetTcp6Table = modiphlpapi.NewProc("GetTcp6Table")
procGetTcpTable = modiphlpapi.NewProc("GetTcpTable")
procSetPerTcp6ConnectionEStats = modiphlpapi.NewProc("SetPerTcp6ConnectionEStats")
procSetPerTcpConnectionEStats = modiphlpapi.NewProc("SetPerTcpConnectionEStats")
)
func getExtendedTcpTable(pTcpTable *byte, pdwSize *uint32, bOrder bool, ulAf uint64, tableClass uint32, reserved uint64) (errcode error) {
var _p0 uint32
if bOrder {
_p0 = 1
}
r0, _, _ := syscall.Syscall6(procGetExtendedTcpTable.Addr(), 6, uintptr(unsafe.Pointer(pTcpTable)), uintptr(unsafe.Pointer(pdwSize)), uintptr(_p0), uintptr(ulAf), uintptr(tableClass), uintptr(reserved))
if r0 != 0 {
errcode = syscall.Errno(r0)
}
return
}
func getExtendedUdpTable(pUdpTable *byte, pdwSize *uint32, bOrder bool, ulAf uint64, tableClass uint32, reserved uint64) (errcode error) {
var _p0 uint32
if bOrder {
_p0 = 1
}
r0, _, _ := syscall.Syscall6(procGetExtendedUdpTable.Addr(), 6, uintptr(unsafe.Pointer(pUdpTable)), uintptr(unsafe.Pointer(pdwSize)), uintptr(_p0), uintptr(ulAf), uintptr(tableClass), uintptr(reserved))
if r0 != 0 {
errcode = syscall.Errno(r0)
}
return
}
func getPerTcp6ConnectionEStats(row *MibTcp6Row, estatsType uint32, rw uintptr, rwVersion uint64, rwSize uint64, ros uintptr, rosVersion uint64, rosSize uint64, rod uintptr, rodVersion uint64, rodSize uint64) (errcode error) {
r0, _, _ := syscall.Syscall12(procGetPerTcp6ConnectionEStats.Addr(), 11, uintptr(unsafe.Pointer(row)), uintptr(estatsType), uintptr(rw), uintptr(rwVersion), uintptr(rwSize), uintptr(ros), uintptr(rosVersion), uintptr(rosSize), uintptr(rod), uintptr(rodVersion), uintptr(rodSize), 0)
if r0 != 0 {
errcode = syscall.Errno(r0)
}
return
}
func getPerTcpConnectionEStats(row *MibTcpRow, estatsType uint32, rw uintptr, rwVersion uint64, rwSize uint64, ros uintptr, rosVersion uint64, rosSize uint64, rod uintptr, rodVersion uint64, rodSize uint64) (errcode error) {
r0, _, _ := syscall.Syscall12(procGetPerTcpConnectionEStats.Addr(), 11, uintptr(unsafe.Pointer(row)), uintptr(estatsType), uintptr(rw), uintptr(rwVersion), uintptr(rwSize), uintptr(ros), uintptr(rosVersion), uintptr(rosSize), uintptr(rod), uintptr(rodVersion), uintptr(rodSize), 0)
if r0 != 0 {
errcode = syscall.Errno(r0)
}
return
}
func getTcp6Table(tcpTable *byte, sizePointer *uint32, order bool) (errcode error) {
var _p0 uint32
if order {
_p0 = 1
}
r0, _, _ := syscall.Syscall(procGetTcp6Table.Addr(), 3, uintptr(unsafe.Pointer(tcpTable)), uintptr(unsafe.Pointer(sizePointer)), uintptr(_p0))
if r0 != 0 {
errcode = syscall.Errno(r0)
}
return
}
func getTcpTable(tcpTable *byte, sizePointer *uint32, order bool) (errcode error) {
var _p0 uint32
if order {
_p0 = 1
}
r0, _, _ := syscall.Syscall(procGetTcpTable.Addr(), 3, uintptr(unsafe.Pointer(tcpTable)), uintptr(unsafe.Pointer(sizePointer)), uintptr(_p0))
if r0 != 0 {
errcode = syscall.Errno(r0)
}
return
}
func setPerTcp6ConnectionEStats(row *MibTcp6Row, estatsType uint32, rw uintptr, rwVersion uint64, rwSize uint64, offset uint64) (errcode error) {
r0, _, _ := syscall.Syscall6(procSetPerTcp6ConnectionEStats.Addr(), 6, uintptr(unsafe.Pointer(row)), uintptr(estatsType), uintptr(rw), uintptr(rwVersion), uintptr(rwSize), uintptr(offset))
if r0 != 0 {
errcode = syscall.Errno(r0)
}
return
}
func setPerTcpConnectionEStats(row *MibTcpRow, estatsType uint32, rw uintptr, rwVersion uint64, rwSize uint64, offset uint64) (errcode error) {
r0, _, _ := syscall.Syscall6(procSetPerTcpConnectionEStats.Addr(), 6, uintptr(unsafe.Pointer(row)), uintptr(estatsType), uintptr(rw), uintptr(rwVersion), uintptr(rwSize), uintptr(offset))
if r0 != 0 {
errcode = syscall.Errno(r0)
}
return
}

View file

@ -6,6 +6,8 @@ const (
EVENT_RESUME_AUTOMATIC // Because the user is not present, most applications should do nothing. EVENT_RESUME_AUTOMATIC // Because the user is not present, most applications should do nothing.
) )
type EventCallback = func(event int)
type EventListener interface { type EventListener interface {
Start() error Start() error
Close() error Close() error

View file

@ -6,6 +6,6 @@ import (
"os" "os"
) )
func NewEventListener(callback func(event int)) (EventListener, error) { func NewEventListener(callback EventCallback) (EventListener, error) {
return nil, os.ErrInvalid return nil, os.ErrInvalid
} }

View file

@ -11,6 +11,7 @@ func TestPowerEvents(t *testing.T) {
if runtime.GOOS != "windows" { if runtime.GOOS != "windows" {
t.SkipNow() t.SkipNow()
} }
t.Parallel()
listener, err := NewEventListener(func(event int) {}) listener, err := NewEventListener(func(event int) {})
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, listener) require.NotNil(t, listener)

Some files were not shown because too many files have changed in this diff Show more