Compare commits

...

181 commits
v0.2.21 ... dev

Author SHA1 Message Date
世界
159e489fc3
Add E.Cause1 2025-04-03 18:26:45 +08:00
世界
d39c2c2fdd
socks: Add custom udp listener 2025-03-26 13:18:24 +08:00
世界
ea82ac275f
Add freelru.GetWithLifetimeNoExpire 2025-03-26 13:18:18 +08:00
世界
ea0ac932ae
Add winiphlpapi 2025-03-26 13:18:17 +08:00
世界
2b41455f5a
Fix udpnat2 handler again 2025-03-26 12:46:15 +08:00
世界
23b0180a1b
Fix crash on udpnat2 handler 2025-03-24 18:11:10 +08:00
世界
ce1b4851a4
Fix socks5 UDP 2025-03-16 10:23:29 +08:00
世界
2238a05966
Fix merge objects 2025-03-16 10:23:29 +08:00
世界
b55d1c78b3
bufio: Add destination NAT packet conn 2025-03-09 15:20:32 +08:00
世界
d54716612c
Fix syscall packet read waiter for Windows 2025-02-28 12:07:45 +08:00
世界
9eafc7fc62
udpnat2: Fix crash 2025-02-10 15:08:18 +08:00
世界
d8153df67f
Add ENOTCONN to IsClosed 2025-02-06 08:41:32 +08:00
世界
d9f6eb136d
Fix set windows system time 2025-01-09 23:30:25 +08:00
世界
4dabb9be97
freelru: Fix GetAndRefreshOrAdd 2025-01-09 15:59:26 +08:00
世界
be9840c70f
listable: Fix incorrect unmarshaling of null to []T{null} 2025-01-09 15:57:12 +08:00
世界
aa7d2543a3
Fix errors usage 2024-12-16 09:20:34 +08:00
世界
33beacc053
Fix socks5 UDP handshake 2024-12-14 18:16:15 +08:00
世界
442cceb9fa
Fix disable UDP fragment 2024-12-12 20:43:56 +08:00
世界
3374a45475
Fix socks5 UDP implementation 2024-12-10 19:53:57 +08:00
世界
73776cf797
Fix lru test 2024-12-10 19:42:55 +08:00
世界
957166799e
Fix CloseOnHandshakeFailure 2024-12-04 17:14:58 +08:00
世界
809d8eca13
freelru: fix PurgeExpired 2024-12-04 11:36:20 +08:00
世界
9f69e7f9f7
E: IsClosedOrCanceled check IsTimeout 2024-12-01 20:19:37 +08:00
世界
478265cd45
badoption: Finish netip options 2024-12-01 14:33:23 +08:00
世界
3f30aaf25e
freelru: purge all expired items 2024-11-30 16:06:59 +08:00
世界
39040e06dc
udpnat2: Fix concurrency 2024-11-28 13:51:17 +08:00
世界
6edd2ce0ea
freelru: Update source and add GetAndRefreshOrAdd 2024-11-28 13:51:17 +08:00
世界
0a2e2a3eaf
udpnat2: Fix timeout 2024-11-27 18:02:22 +08:00
世界
4ba1eb123c
Fix set timeout 2024-11-27 17:28:18 +08:00
世界
c44912a861
freelru: Fix purge 2024-11-27 13:51:08 +08:00
世界
a8f5bf4eb0
udpnat2: Add timeout check 2024-11-26 19:08:35 +08:00
世界
30e9d91b57
Fix AppendClose 2024-11-26 12:21:37 +08:00
世界
7fd3517e4d
udpnat2: Add purge expire ticker 2024-11-26 12:21:37 +08:00
世界
a8285e06a5
udpnat2: Implement set timeout for nat conn 2024-11-26 12:21:37 +08:00
世界
3613ead480
freelru: Add PeekWithLifetime and UpdateLifetime 2024-11-26 11:29:14 +08:00
世界
c8f251c668
Fix copy count 2024-11-24 19:02:21 +08:00
世界
fa5355e99e
bufio: more copy funcs 2024-11-20 11:27:20 +08:00
世界
30fbafd954
udpnat2: Add cache funcs 2024-11-18 12:14:35 +08:00
世界
fdca9b3f8e
badjson: Fix Listable 2024-11-16 16:03:00 +08:00
世界
e52e04f721
Fix HandshakeFailure usages 2024-11-15 16:27:03 +08:00
世界
7f621fdd78
Add freelru.SetUpdateLifetimeOnGet/GetWithLifetime 2024-11-14 17:49:49 +08:00
世界
ae139d9ee1
Update N.PayloadDialer 2024-11-14 17:49:49 +08:00
世界
c432befd02
http: Fix proxying websocket 2024-11-13 19:02:07 +08:00
世界
cc7e630923
control: Refactor interface finder 2024-11-12 20:15:50 +08:00
世界
0998999911
udpnat2: Fix missing shared impl 2024-11-09 11:40:27 +08:00
世界
72ff654ee0
shared: Add SetHealthCheck to interface 2024-11-09 11:40:27 +08:00
世界
11ffb962ae
freelru: Fix impl 2024-11-09 11:40:27 +08:00
世界
fcb19641e6
freelru: Copy shared source 2024-11-09 11:40:27 +08:00
世界
524a6bd0d1
udpnat2: Set upstream to writer 2024-11-09 11:40:27 +08:00
世界
b5f9e70ffd
badjson: Fix Listable 2024-11-09 11:40:27 +08:00
世界
c80c8f907c
badjson: Add context marshaler/unmarshaler 2024-11-05 18:43:05 +08:00
世界
a4eb7fa900
udpnat2: Add SetHandler 2024-11-05 18:43:05 +08:00
世界
7ec09d6045
udpnat2: New synced udp nat service 2024-11-05 18:43:04 +08:00
世界
0641c71805
maphash: copy source from v0.1.0 2024-11-05 18:43:04 +08:00
世界
e7ec021b81
freelru: copy source from v0.14.0 2024-11-05 18:43:04 +08:00
世界
0f2447a95b
Crazy sekai overturns the small pond 2024-11-05 18:43:04 +08:00
世界
72db784fc7
Add bind.Interface.Flags 2024-11-04 11:05:38 +08:00
世界
d59ac57aaa
Add go1.21 compat funcs 2024-10-19 09:09:15 +08:00
世界
c63546470b
Add Update() error to control.InterfaceFinder 2024-09-22 22:15:12 +08:00
世界
55908bea36
Update linter configuration 2024-09-14 21:36:41 +08:00
世界
6567829958
Fix cached conn eats up read deadlines 2024-09-14 10:11:50 +08:00
世界
c324d4143d
json: Add badoption templates 2024-09-10 23:57:22 +08:00
世界
0acb36c118
Minor fixes 2024-09-10 23:46:03 +08:00
世界
26511a251f
udpnat: Fix read deadline not initialized 2024-08-19 17:56:31 +08:00
世界
afd8993773
windnsapi: Fix incorrect error checking 2024-08-19 17:47:42 +08:00
世界
96bef0733f
Fix bad group usages 2024-08-18 11:15:20 +08:00
世界
ec1df651e8
Update golangci-lint configuration 2024-08-18 11:15:15 +08:00
世界
e33b1d67d5
bufio: Add ReadBufferSize and ReadPacketSize 2024-08-18 09:14:52 +08:00
世界
ed6cde73f7
udpnat: Implement read deadline 2024-08-18 09:14:44 +08:00
世界
73cc65605e
pipe: Make pipeDeadline public for use 2024-08-18 08:59:04 +08:00
世界
6c19e0736d
windows: Migrate to mkwinsyscall 2024-08-09 12:00:21 +08:00
世界
08e8c02fb1
Fix usage of PowerUnregisterSuspendResumeNotification 2024-08-08 15:27:35 +08:00
世界
7beca62e4f
Improve winpowrprof callback 2024-08-06 13:19:09 +08:00
世界
e422e3d048
Reuse winpowrprof callback 2024-08-06 12:23:56 +08:00
世界
fa81eabc29
ntp: Fix a bad context usage 2024-08-06 12:23:56 +08:00
世界
4498e57839
task: Fix context not continuous 2024-07-31 18:06:42 +08:00
世界
f97054e917
ntp: Ignore setup error 2024-07-31 10:25:01 +08:00
世界
a2f9fef936
domain: Add adguard matcher 2024-07-26 08:00:09 +08:00
世界
7893a74f75
json: Add UnmarshalDisallowUnknownFields 2024-07-22 12:02:46 +08:00
世界
332e470075
domain: Add a new label type for domain suffix 2024-07-17 15:55:30 +08:00
世界
2bf9cc7253
Remove unused legacy variable 2024-07-17 15:54:12 +08:00
世界
bf8fc103a4
Fix golangci-lint configuration 2024-07-17 15:53:40 +08:00
世界
774893928c
Add crazy badges to README 2024-07-02 16:39:51 +08:00
renovate[bot]
7ceaf63d41
[dependencies] Update github-actions
Co-authored-by: renovate[bot] <29139614+renovate[bot]@users.noreply.github.com>
2024-07-02 16:32:54 +08:00
世界
8806e421f2
Fix renovate configuration 2024-07-02 16:29:55 +08:00
世界
c37f988a4f
Add dump for domain matcher 2024-07-02 15:49:37 +08:00
世界
0b4c0a1283
varbin: Unique pointer value format for array or map value 2024-06-24 19:51:17 +08:00
世界
4745c34b4c
varbin: Accept fixed slices 2024-06-24 13:51:49 +08:00
世界
e0196407a3
Remove bad rw usages 2024-06-24 09:42:23 +08:00
世界
d8ec9c46cc
Deprecated bad rw funcs 2024-06-24 09:42:23 +08:00
世界
a33349366d
Remove bad linkname usage 2024-06-24 09:42:23 +08:00
世界
3155c16990
Use varbin for domain matcher read write 2024-06-24 09:42:23 +08:00
世界
caa4340dc9
binary: Move varint utils to new package 2024-06-24 09:42:23 +08:00
世界
a31dba8ad2
bianry: Improve varint read and write 2024-06-23 14:07:07 +08:00
世界
9571124cf4
badjson: Add disableAppend option 2024-06-22 20:07:31 +08:00
世界
1c495c9b07
Make RoutingMark uint32 2024-06-17 12:38:23 +08:00
wwqgtxx
0f95dfe0e3
Recover context json for go1.20 2024-06-17 12:38:22 +08:00
世界
f3380c8dfe
Drop support for go1.18 and go1.19 2024-06-17 12:38:22 +08:00
世界
ab4353dd13
binary: Add variant data reader/writer 2024-06-17 12:38:22 +08:00
世界
e0e490af7b
binary: init from encoding/binary 2024-06-17 12:38:22 +08:00
世界
ad4d59e2ed
Add control.Interface.HardwareAddr 2024-06-17 12:38:21 +08:00
世界
aca2a85545
Add redirect and tproxy controls 2024-06-17 12:38:21 +08:00
世界
589c7eb4df
contentjson: Add support for tailing comma 2024-06-17 12:38:21 +08:00
世界
2873799b6d
Drop context json for go1.20 2024-06-17 12:38:21 +08:00
世界
47cc308abf
Add InterfaceFinder.Interfaces() []control.Interface 2024-06-17 12:38:21 +08:00
HystericalDragon
d9f2559214
Fix http socks close 2024-06-17 12:37:52 +08:00
世界
b8736cc58d
Update workflows 2024-06-11 21:13:49 +08:00
世界
e82ff8e2e6
Fix get source address from X-Forwarded-For 2024-06-06 22:18:55 +08:00
世界
ba68e017a9
Fix PrefixFromNet 2024-06-06 22:18:20 +08:00
世界
967afcf6c1
Update dependencies 2024-06-06 22:18:13 +08:00
世界
0c110ad733
Update dependencies 2024-05-27 19:35:12 +08:00
世界
de1b0bd772
Fix Fix socks5 packet conn 2024-05-21 15:09:17 +08:00
wwqgtxx
f67a0988a6
Remove linkname usages of x/sys/windows 2024-05-18 20:51:36 +08:00
世界
e0ee7f49e2
Remove disallowed linkname usages 2024-05-18 13:20:31 +08:00
dyhkwong
284cb5ce98
Fix socks5 packet conn 2024-05-17 21:24:05 +08:00
世界
8fb1634c9a
Fix calculate reader headroom 2024-05-17 13:20:35 +08:00
世界
4ab8cac5eb
Update dependencies 2024-04-23 23:36:51 +08:00
世界
eec2fc325a
Improve interface finder 2024-04-12 09:05:39 +08:00
世界
2fa039945c
Add control.SetKeepAlivePeriod 2024-04-10 20:25:47 +08:00
世界
e5825dcb59
Merge time service to library 2024-04-10 20:25:47 +08:00
wwqgtxx
6b73a57a24
Add winpowrprof package 2024-04-10 20:25:46 +08:00
世界
3e2631ef0b
badjson: Improve omitempty 2024-04-10 20:25:46 +08:00
世界
4d96f15eca
Improve domain suffix match behavior 2024-04-10 20:25:46 +08:00
dyhkwong
f9c59e9940
Improve bufio.NATPacketConn 2024-04-10 20:24:58 +08:00
dyhkwong
8b68fc4d7a
Fix canceler.PacketConn 2024-04-10 20:17:34 +08:00
世界
5bfc326913
Fix syscall packet read waiter 2024-03-25 01:47:03 +08:00
世界
04152ea672
Fix canceler 2024-03-24 00:12:02 +08:00
wwqgtxx
a069af4787
Fix TypedValue 2024-03-13 16:07:00 +08:00
世界
807a51bb81
Fix crash in HTTP server again 2024-03-10 16:52:22 +08:00
DuFoxit
ec2595f010
Fix SwitchyOmega authentication failed 2024-03-05 13:04:20 +08:00
世界
c98e8b6921
Update dependencies 2024-03-05 12:51:24 +08:00
世界
a4a9ec42c6
Fix crash in HTTP server 2024-03-02 14:25:20 +08:00
世界
6e3921083b
Fix SO_BINDTOIFINDEX usage 2024-02-29 12:56:57 +08:00
世界
8e89f9b4dc
Update workflow 2024-02-29 12:54:19 +08:00
世界
5f02cb1cff
Fix HTTP server authenticate 2024-02-22 20:46:01 +08:00
世界
ef00a1ec1e
Fix badjson merge 2024-02-22 17:27:13 +08:00
世界
5ee4f84faf
Fix IPv6 handshake for HTTP proxy 2024-02-21 13:40:24 +08:00
世界
30f7629317
Update dependencies 2024-02-21 13:40:21 +08:00
世界
9e1749e108
Fix task 2024-02-10 11:41:03 +08:00
世界
b1355d7a4b
Improve rw.Discard 2024-02-10 11:41:02 +08:00
世界
45f572495e
Fix actions/setup-go usage 2024-01-24 11:53:34 +08:00
世界
3ac055b755
Fix missing Upstream() for timeout debug conns 2023-12-26 17:28:48 +08:00
世界
a6e8fa3019
Remove legacy writev funcs 2023-12-25 00:05:51 +08:00
世界
57b8a4c64a
Add test for copy waiter 2023-12-25 00:05:51 +08:00
wwqgtxx
b7a631f798
Support syscallReadWaiter on windows 2023-12-25 00:05:51 +08:00
世界
fa0cc448dc
Fix missing clear iovecList for unix vectorised packet writer 2023-12-25 00:05:51 +08:00
世界
0d1b3d6d6d
badjson: Refactor TypedMap to handle multiple key types 2023-12-25 00:05:51 +08:00
世界
2196f193ac
Fix filemanager.MkdirAll 2023-12-24 08:03:43 +08:00
世界
c501a58ae7
Add 'preferInterfaceName' parameter to BindToInterface0 2023-12-24 08:03:43 +08:00
世界
c9319a35ee
Remove unnecessary context wrappers 2023-12-24 08:03:42 +08:00
世界
cdb9908442
Improve pause manager 2023-12-24 08:03:42 +08:00
世界
81d1bc2768
Fix vectorised writer test 2023-12-24 08:03:41 +08:00
世界
2e36fa6849
Improve vectorised writer 2023-12-24 08:03:41 +08:00
世界
edd320c3a8
badjson: Add UnmarshalExtended 2023-12-24 08:03:41 +08:00
世界
56b953e091
Add nil checking to MergeJSON function in badjson
The MergeJSON function in the badjson package has been updated to handle cases where the source or destination raw JSON is nil. This introduces error checking that results in returning an error when both are nil and returning the non-nil JSON when only one is nil.
2023-12-24 08:03:41 +08:00
世界
4c4773fe54
badjson: Refactor and restructure Merge functions
The Merge function has been refactored for clearer code by splitting it into multiple functions: "Merge", "MergeFromSource", "MergeFromDestination" and "MergeFrom" in the badjson package. These new functions improve the handling of raw JSON during merge and with this refactoring, the responsibility of each function is more defined.
2023-12-24 08:03:41 +08:00
世界
36acc18bfb
Add Ptr function to common
A new function Ptr has been added to common/cond.go. This function returns address of any type of object that is passed as an argument to it. This enhancement will be useful in many scenarios where pointer to an object is required.
2023-12-24 08:03:40 +08:00
世界
ad670bab68
Improve WriteZeroN using clear in go1.21 2023-12-24 08:03:40 +08:00
世界
2a2dbf1971
Add compat func for clear 2023-12-24 08:03:40 +08:00
世界
afa72012e5
Update renovate configuration 2023-12-24 08:03:40 +08:00
世界
c7ef05a85b
Fix buffer
Will be merged into f0be1a9e
2023-12-24 08:03:39 +08:00
世界
0f7de716ac
Refactor the Authenticator interface to a struct 2023-12-24 08:03:39 +08:00
世界
231d7607bc
Enable read wait copy for windows 2023-12-24 08:03:39 +08:00
世界
8b43ec8058
Add reserve support for buffer 2023-12-24 08:03:39 +08:00
世界
c17babe0ba
Merge ThreadSafeReader into ReadWaiter interface 2023-12-24 08:03:38 +08:00
世界
1f02d6daca
Implementation read waiter for pipe 2023-12-24 08:03:38 +08:00
世界
aa34723225
Implementation read waiter for socks5 UDP and UoT 2023-12-24 08:03:38 +08:00
世界
ae8098ad39
Refactor read waiter interface 2023-12-24 08:03:38 +08:00
世界
05c71c99d1
badjson: Add Omitempty 2023-12-24 08:03:37 +08:00
世界
060edf2d69
badjson: Remove empty JSON object in JSON object 2023-12-24 08:03:37 +08:00
世界
d171f04941
json: use context json in go1.20 2023-12-24 08:03:36 +08:00
世界
51aeb14a87
contextjson120: Add context to decode error message 2023-12-24 08:03:36 +08:00
世界
96f5dea24b
contextjson120: Import form go1.20.11 2023-12-24 08:03:36 +08:00
世界
3336b50119
Migrate json wrapper and badjson to library 2023-12-24 08:03:35 +08:00
世界
36be4ef141
contextjson: Add context to decode error message 2023-12-24 08:03:35 +08:00
世界
843bab522a
contentjson: Import from go1.21.4 2023-12-24 08:03:35 +08:00
H1JK
af92594d6d
Shrink buf pool range 2023-12-24 08:03:35 +08:00
H1JK
f23499eaea
Pool allocate arrays instead of slices
This is inspired by https://go-review.googlesource.com/c/net/+/539915
2023-12-24 08:03:34 +08:00
世界
d7ce998e7e
Remove legacy buffer header 2023-12-24 08:03:34 +08:00
世界
99d07d6e5a
Add concurrency limit for task 2023-12-24 08:03:34 +08:00
世界
028dcd722c
Add serialize support for domain matcher 2023-12-24 08:03:34 +08:00
212 changed files with 15666 additions and 1379 deletions

View file

@ -5,6 +5,9 @@
"config:base",
":disableRateLimiting"
],
"baseBranches": [
"dev"
],
"packageRules": [
{
"matchManagers": [
@ -14,9 +17,9 @@
},
{
"matchManagers": [
"gomod"
"dockerfile"
],
"groupName": "gomod"
"groupName": "Dockerfile"
}
]
}

View file

@ -1,43 +0,0 @@
name: Debug build
on:
push:
branches:
- dev
paths-ignore:
- '**.md'
- '.github/**'
- '!.github/workflows/debug.yml'
pull_request:
branches:
- dev
jobs:
build:
name: Debug build
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Get latest go version
id: version
run: |
echo ::set-output name=go_version::$(curl -s https://raw.githubusercontent.com/actions/go-versions/main/versions-manifest.json | grep -oE '"version": "[0-9]{1}.[0-9]{1,}(.[0-9]{1,})?"' | head -1 | cut -d':' -f2 | sed 's/ //g; s/"//g')
- name: Setup Go
uses: actions/setup-go@v4
with:
go-version: ${{ steps.version.outputs.go_version }}
- name: Add cache to Go proxy
run: |
version=`git rev-parse HEAD`
mkdir build
pushd build
go mod init build
go get -v github.com/sagernet/sing@$version
popd
continue-on-error: true
- name: Build
run: |
make test

View file

@ -1,8 +1,9 @@
name: Lint
name: lint
on:
push:
branches:
- main
- dev
paths-ignore:
- '**.md'
@ -10,6 +11,7 @@ on:
- '!.github/workflows/lint.yml'
pull_request:
branches:
- main
- dev
jobs:
@ -21,21 +23,17 @@ jobs:
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Get latest go version
id: version
run: |
echo ::set-output name=go_version::$(curl -s https://raw.githubusercontent.com/actions/go-versions/main/versions-manifest.json | grep -oE '"version": "[0-9]{1}.[0-9]{1,}(.[0-9]{1,})?"' | head -1 | cut -d':' -f2 | sed 's/ //g; s/"//g')
- name: Setup Go
uses: actions/setup-go@v4
uses: actions/setup-go@v5
with:
go-version: ${{ steps.version.outputs.go_version }}
go-version: ^1.23
- name: Cache go module
uses: actions/cache@v3
uses: actions/cache@v4
with:
path: |
~/go/pkg/mod
key: go-${{ hashFiles('**/go.sum') }}
- name: golangci-lint
uses: golangci/golangci-lint-action@v3
uses: golangci/golangci-lint-action@v6
with:
version: latest

112
.github/workflows/test.yml vendored Normal file
View file

@ -0,0 +1,112 @@
name: test
on:
push:
branches:
- main
- dev
paths-ignore:
- '**.md'
- '.github/**'
- '!.github/workflows/debug.yml'
pull_request:
branches:
- main
- dev
jobs:
build:
name: Linux
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Setup Go
uses: actions/setup-go@v5
with:
go-version: ^1.23
- name: Build
run: |
make test
build_go120:
name: Linux (Go 1.20)
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Setup Go
uses: actions/setup-go@v5
with:
go-version: ~1.20
continue-on-error: true
- name: Build
run: |
make test
build_go121:
name: Linux (Go 1.21)
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Setup Go
uses: actions/setup-go@v5
with:
go-version: ~1.21
continue-on-error: true
- name: Build
run: |
make test
build_go122:
name: Linux (Go 1.22)
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Setup Go
uses: actions/setup-go@v5
with:
go-version: ~1.22
continue-on-error: true
- name: Build
run: |
make test
build_windows:
name: Windows
runs-on: windows-latest
steps:
- name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Setup Go
uses: actions/setup-go@v5
with:
go-version: ^1.23
continue-on-error: true
- name: Build
run: |
make test
build_darwin:
name: macOS
runs-on: macos-latest
steps:
- name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Setup Go
uses: actions/setup-go@v5
with:
go-version: ^1.23
continue-on-error: true
- name: Build
run: |
make test

View file

@ -5,6 +5,8 @@ linters:
- govet
- gci
- staticcheck
- paralleltest
- ineffassign
linters-settings:
gci:
@ -14,4 +16,9 @@ linters-settings:
- prefix(github.com/sagernet/)
- default
staticcheck:
go: '1.20'
checks:
- all
- -SA1003
run:
go: "1.23"

View file

@ -8,14 +8,14 @@ fmt_install:
go install -v github.com/daixiang0/gci@latest
lint:
GOOS=linux golangci-lint run ./...
GOOS=android golangci-lint run ./...
GOOS=windows golangci-lint run ./...
GOOS=darwin golangci-lint run ./...
GOOS=freebsd golangci-lint run ./...
GOOS=linux golangci-lint run
GOOS=android golangci-lint run
GOOS=windows golangci-lint run
GOOS=darwin golangci-lint run
GOOS=freebsd golangci-lint run
lint_install:
go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest
test:
go test -v ./...
go test ./...

View file

@ -1,3 +1,6 @@
# sing
![test](https://github.com/sagernet/sing/actions/workflows/test.yml/badge.svg)
![lint](https://github.com/sagernet/sing/actions/workflows/lint.yml/badge.svg)
Do you hear the people sing?

View file

@ -10,26 +10,37 @@ type TypedValue[T any] struct {
value atomic.Value
}
// typedValue is a struct with determined type to resolve atomic.Value usages with interface types
// https://github.com/golang/go/issues/22550
//
// The intention to have an atomic value store for errors. However, running this code panics:
// panic: sync/atomic: store of inconsistently typed value into Value
// This is because atomic.Value requires that the underlying concrete type be the same (which is a reasonable expectation for its implementation).
// When going through the atomic.Value.Store method call, the fact that both these are of the error interface is lost.
type typedValue[T any] struct {
value T
}
func (t *TypedValue[T]) Load() T {
value := t.value.Load()
if value == nil {
return common.DefaultValue[T]()
}
return value.(T)
return value.(typedValue[T]).value
}
func (t *TypedValue[T]) Store(value T) {
t.value.Store(value)
t.value.Store(typedValue[T]{value})
}
func (t *TypedValue[T]) Swap(new T) T {
old := t.value.Swap(new)
old := t.value.Swap(typedValue[T]{new})
if old == nil {
return common.DefaultValue[T]()
}
return old.(T)
return old.(typedValue[T]).value
}
func (t *TypedValue[T]) CompareAndSwap(old, new T) bool {
return t.value.CompareAndSwap(old, new)
return t.value.CompareAndSwap(typedValue[T]{old}, typedValue[T]{new})
}

View file

@ -1,38 +1,30 @@
package auth
type Authenticator interface {
Verify(user string, pass string) bool
Users() []string
}
import "github.com/sagernet/sing/common"
type User struct {
Username string `json:"username"`
Password string `json:"password"`
Username string
Password string
}
type inMemoryAuthenticator struct {
storage map[string]string
usernames []string
type Authenticator struct {
userMap map[string][]string
}
func (au *inMemoryAuthenticator) Verify(username string, password string) bool {
realPass, ok := au.storage[username]
return ok && realPass == password
}
func (au *inMemoryAuthenticator) Users() []string { return au.usernames }
func NewAuthenticator(users []User) Authenticator {
func NewAuthenticator(users []User) *Authenticator {
if len(users) == 0 {
return nil
}
au := &inMemoryAuthenticator{
storage: make(map[string]string),
usernames: make([]string, 0, len(users)),
au := &Authenticator{
userMap: make(map[string][]string),
}
for _, user := range users {
au.storage[user.Username] = user.Password
au.usernames = append(au.usernames, user.Username)
au.userMap[user.Username] = append(au.userMap[user.Username], user.Password)
}
return au
}
func (au *Authenticator) Verify(username string, password string) bool {
passwordList, ok := au.userMap[username]
return ok && common.Contains(passwordList, password)
}

View file

@ -2,11 +2,10 @@ package baderror
import (
"context"
"errors"
"io"
"net"
"strings"
E "github.com/sagernet/sing/common/exceptions"
)
func Contains(err error, msgList ...string) bool {
@ -22,8 +21,7 @@ func WrapH2(err error) error {
if err == nil {
return nil
}
err = E.Unwrap(err)
if err == io.ErrUnexpectedEOF {
if errors.Is(err, io.ErrUnexpectedEOF) {
return io.EOF
}
if Contains(err, "client disconnected", "body closed by handler", "response body closed", "; CANCEL") {

3
common/binary/README.md Normal file
View file

@ -0,0 +1,3 @@
# binary
mod from go 1.22.3

817
common/binary/binary.go Normal file
View file

@ -0,0 +1,817 @@
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package binary implements simple translation between numbers and byte
// sequences and encoding and decoding of varints.
//
// Numbers are translated by reading and writing fixed-size values.
// A fixed-size value is either a fixed-size arithmetic
// type (bool, int8, uint8, int16, float32, complex64, ...)
// or an array or struct containing only fixed-size values.
//
// The varint functions encode and decode single integer values using
// a variable-length encoding; smaller values require fewer bytes.
// For a specification, see
// https://developers.google.com/protocol-buffers/docs/encoding.
//
// This package favors simplicity over efficiency. Clients that require
// high-performance serialization, especially for large data structures,
// should look at more advanced solutions such as the [encoding/gob]
// package or [google.golang.org/protobuf] for protocol buffers.
package binary
import (
"errors"
"io"
"math"
"reflect"
"sync"
)
// A ByteOrder specifies how to convert byte slices into
// 16-, 32-, or 64-bit unsigned integers.
//
// It is implemented by [LittleEndian], [BigEndian], and [NativeEndian].
type ByteOrder interface {
Uint16([]byte) uint16
Uint32([]byte) uint32
Uint64([]byte) uint64
PutUint16([]byte, uint16)
PutUint32([]byte, uint32)
PutUint64([]byte, uint64)
String() string
}
// AppendByteOrder specifies how to append 16-, 32-, or 64-bit unsigned integers
// into a byte slice.
//
// It is implemented by [LittleEndian], [BigEndian], and [NativeEndian].
type AppendByteOrder interface {
AppendUint16([]byte, uint16) []byte
AppendUint32([]byte, uint32) []byte
AppendUint64([]byte, uint64) []byte
String() string
}
// LittleEndian is the little-endian implementation of [ByteOrder] and [AppendByteOrder].
var LittleEndian littleEndian
// BigEndian is the big-endian implementation of [ByteOrder] and [AppendByteOrder].
var BigEndian bigEndian
type littleEndian struct{}
func (littleEndian) Uint16(b []byte) uint16 {
_ = b[1] // bounds check hint to compiler; see golang.org/issue/14808
return uint16(b[0]) | uint16(b[1])<<8
}
func (littleEndian) PutUint16(b []byte, v uint16) {
_ = b[1] // early bounds check to guarantee safety of writes below
b[0] = byte(v)
b[1] = byte(v >> 8)
}
func (littleEndian) AppendUint16(b []byte, v uint16) []byte {
return append(b,
byte(v),
byte(v>>8),
)
}
func (littleEndian) Uint32(b []byte) uint32 {
_ = b[3] // bounds check hint to compiler; see golang.org/issue/14808
return uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24
}
func (littleEndian) PutUint32(b []byte, v uint32) {
_ = b[3] // early bounds check to guarantee safety of writes below
b[0] = byte(v)
b[1] = byte(v >> 8)
b[2] = byte(v >> 16)
b[3] = byte(v >> 24)
}
func (littleEndian) AppendUint32(b []byte, v uint32) []byte {
return append(b,
byte(v),
byte(v>>8),
byte(v>>16),
byte(v>>24),
)
}
func (littleEndian) Uint64(b []byte) uint64 {
_ = b[7] // bounds check hint to compiler; see golang.org/issue/14808
return uint64(b[0]) | uint64(b[1])<<8 | uint64(b[2])<<16 | uint64(b[3])<<24 |
uint64(b[4])<<32 | uint64(b[5])<<40 | uint64(b[6])<<48 | uint64(b[7])<<56
}
func (littleEndian) PutUint64(b []byte, v uint64) {
_ = b[7] // early bounds check to guarantee safety of writes below
b[0] = byte(v)
b[1] = byte(v >> 8)
b[2] = byte(v >> 16)
b[3] = byte(v >> 24)
b[4] = byte(v >> 32)
b[5] = byte(v >> 40)
b[6] = byte(v >> 48)
b[7] = byte(v >> 56)
}
func (littleEndian) AppendUint64(b []byte, v uint64) []byte {
return append(b,
byte(v),
byte(v>>8),
byte(v>>16),
byte(v>>24),
byte(v>>32),
byte(v>>40),
byte(v>>48),
byte(v>>56),
)
}
func (littleEndian) String() string { return "LittleEndian" }
func (littleEndian) GoString() string { return "binary.LittleEndian" }
type bigEndian struct{}
func (bigEndian) Uint16(b []byte) uint16 {
_ = b[1] // bounds check hint to compiler; see golang.org/issue/14808
return uint16(b[1]) | uint16(b[0])<<8
}
func (bigEndian) PutUint16(b []byte, v uint16) {
_ = b[1] // early bounds check to guarantee safety of writes below
b[0] = byte(v >> 8)
b[1] = byte(v)
}
func (bigEndian) AppendUint16(b []byte, v uint16) []byte {
return append(b,
byte(v>>8),
byte(v),
)
}
func (bigEndian) Uint32(b []byte) uint32 {
_ = b[3] // bounds check hint to compiler; see golang.org/issue/14808
return uint32(b[3]) | uint32(b[2])<<8 | uint32(b[1])<<16 | uint32(b[0])<<24
}
func (bigEndian) PutUint32(b []byte, v uint32) {
_ = b[3] // early bounds check to guarantee safety of writes below
b[0] = byte(v >> 24)
b[1] = byte(v >> 16)
b[2] = byte(v >> 8)
b[3] = byte(v)
}
func (bigEndian) AppendUint32(b []byte, v uint32) []byte {
return append(b,
byte(v>>24),
byte(v>>16),
byte(v>>8),
byte(v),
)
}
func (bigEndian) Uint64(b []byte) uint64 {
_ = b[7] // bounds check hint to compiler; see golang.org/issue/14808
return uint64(b[7]) | uint64(b[6])<<8 | uint64(b[5])<<16 | uint64(b[4])<<24 |
uint64(b[3])<<32 | uint64(b[2])<<40 | uint64(b[1])<<48 | uint64(b[0])<<56
}
func (bigEndian) PutUint64(b []byte, v uint64) {
_ = b[7] // early bounds check to guarantee safety of writes below
b[0] = byte(v >> 56)
b[1] = byte(v >> 48)
b[2] = byte(v >> 40)
b[3] = byte(v >> 32)
b[4] = byte(v >> 24)
b[5] = byte(v >> 16)
b[6] = byte(v >> 8)
b[7] = byte(v)
}
func (bigEndian) AppendUint64(b []byte, v uint64) []byte {
return append(b,
byte(v>>56),
byte(v>>48),
byte(v>>40),
byte(v>>32),
byte(v>>24),
byte(v>>16),
byte(v>>8),
byte(v),
)
}
func (bigEndian) String() string { return "BigEndian" }
func (bigEndian) GoString() string { return "binary.BigEndian" }
func (nativeEndian) String() string { return "NativeEndian" }
func (nativeEndian) GoString() string { return "binary.NativeEndian" }
// Read reads structured binary data from r into data.
// Data must be a pointer to a fixed-size value or a slice
// of fixed-size values.
// Bytes read from r are decoded using the specified byte order
// and written to successive fields of the data.
// When decoding boolean values, a zero byte is decoded as false, and
// any other non-zero byte is decoded as true.
// When reading into structs, the field data for fields with
// blank (_) field names is skipped; i.e., blank field names
// may be used for padding.
// When reading into a struct, all non-blank fields must be exported
// or Read may panic.
//
// The error is [io.EOF] only if no bytes were read.
// If an [io.EOF] happens after reading some but not all the bytes,
// Read returns [io.ErrUnexpectedEOF].
func Read(r io.Reader, order ByteOrder, data any) error {
// Fast path for basic types and slices.
if n := intDataSize(data); n != 0 {
bs := make([]byte, n)
if _, err := io.ReadFull(r, bs); err != nil {
return err
}
switch data := data.(type) {
case *bool:
*data = bs[0] != 0
case *int8:
*data = int8(bs[0])
case *uint8:
*data = bs[0]
case *int16:
*data = int16(order.Uint16(bs))
case *uint16:
*data = order.Uint16(bs)
case *int32:
*data = int32(order.Uint32(bs))
case *uint32:
*data = order.Uint32(bs)
case *int64:
*data = int64(order.Uint64(bs))
case *uint64:
*data = order.Uint64(bs)
case *float32:
*data = math.Float32frombits(order.Uint32(bs))
case *float64:
*data = math.Float64frombits(order.Uint64(bs))
case []bool:
for i, x := range bs { // Easier to loop over the input for 8-bit values.
data[i] = x != 0
}
case []int8:
for i, x := range bs {
data[i] = int8(x)
}
case []uint8:
copy(data, bs)
case []int16:
for i := range data {
data[i] = int16(order.Uint16(bs[2*i:]))
}
case []uint16:
for i := range data {
data[i] = order.Uint16(bs[2*i:])
}
case []int32:
for i := range data {
data[i] = int32(order.Uint32(bs[4*i:]))
}
case []uint32:
for i := range data {
data[i] = order.Uint32(bs[4*i:])
}
case []int64:
for i := range data {
data[i] = int64(order.Uint64(bs[8*i:]))
}
case []uint64:
for i := range data {
data[i] = order.Uint64(bs[8*i:])
}
case []float32:
for i := range data {
data[i] = math.Float32frombits(order.Uint32(bs[4*i:]))
}
case []float64:
for i := range data {
data[i] = math.Float64frombits(order.Uint64(bs[8*i:]))
}
default:
n = 0 // fast path doesn't apply
}
if n != 0 {
return nil
}
}
// Fallback to reflect-based decoding.
v := reflect.ValueOf(data)
size := -1
switch v.Kind() {
case reflect.Pointer:
v = v.Elem()
size = dataSize(v)
case reflect.Slice:
size = dataSize(v)
}
if size < 0 {
return errors.New("binary.Read: invalid type " + reflect.TypeOf(data).String())
}
d := &decoder{order: order, buf: make([]byte, size)}
if _, err := io.ReadFull(r, d.buf); err != nil {
return err
}
d.value(v)
return nil
}
// Write writes the binary representation of data into w.
// Data must be a fixed-size value or a slice of fixed-size
// values, or a pointer to such data.
// Boolean values encode as one byte: 1 for true, and 0 for false.
// Bytes written to w are encoded using the specified byte order
// and read from successive fields of the data.
// When writing structs, zero values are written for fields
// with blank (_) field names.
func Write(w io.Writer, order ByteOrder, data any) error {
// Fast path for basic types and slices.
if n := intDataSize(data); n != 0 {
bs := make([]byte, n)
switch v := data.(type) {
case *bool:
if *v {
bs[0] = 1
} else {
bs[0] = 0
}
case bool:
if v {
bs[0] = 1
} else {
bs[0] = 0
}
case []bool:
for i, x := range v {
if x {
bs[i] = 1
} else {
bs[i] = 0
}
}
case *int8:
bs[0] = byte(*v)
case int8:
bs[0] = byte(v)
case []int8:
for i, x := range v {
bs[i] = byte(x)
}
case *uint8:
bs[0] = *v
case uint8:
bs[0] = v
case []uint8:
bs = v
case *int16:
order.PutUint16(bs, uint16(*v))
case int16:
order.PutUint16(bs, uint16(v))
case []int16:
for i, x := range v {
order.PutUint16(bs[2*i:], uint16(x))
}
case *uint16:
order.PutUint16(bs, *v)
case uint16:
order.PutUint16(bs, v)
case []uint16:
for i, x := range v {
order.PutUint16(bs[2*i:], x)
}
case *int32:
order.PutUint32(bs, uint32(*v))
case int32:
order.PutUint32(bs, uint32(v))
case []int32:
for i, x := range v {
order.PutUint32(bs[4*i:], uint32(x))
}
case *uint32:
order.PutUint32(bs, *v)
case uint32:
order.PutUint32(bs, v)
case []uint32:
for i, x := range v {
order.PutUint32(bs[4*i:], x)
}
case *int64:
order.PutUint64(bs, uint64(*v))
case int64:
order.PutUint64(bs, uint64(v))
case []int64:
for i, x := range v {
order.PutUint64(bs[8*i:], uint64(x))
}
case *uint64:
order.PutUint64(bs, *v)
case uint64:
order.PutUint64(bs, v)
case []uint64:
for i, x := range v {
order.PutUint64(bs[8*i:], x)
}
case *float32:
order.PutUint32(bs, math.Float32bits(*v))
case float32:
order.PutUint32(bs, math.Float32bits(v))
case []float32:
for i, x := range v {
order.PutUint32(bs[4*i:], math.Float32bits(x))
}
case *float64:
order.PutUint64(bs, math.Float64bits(*v))
case float64:
order.PutUint64(bs, math.Float64bits(v))
case []float64:
for i, x := range v {
order.PutUint64(bs[8*i:], math.Float64bits(x))
}
}
_, err := w.Write(bs)
return err
}
// Fallback to reflect-based encoding.
v := reflect.Indirect(reflect.ValueOf(data))
size := dataSize(v)
if size < 0 {
return errors.New("binary.Write: some values are not fixed-sized in type " + reflect.TypeOf(data).String())
}
buf := make([]byte, size)
e := &encoder{order: order, buf: buf}
e.value(v)
_, err := w.Write(buf)
return err
}
// Size returns how many bytes [Write] would generate to encode the value v, which
// must be a fixed-size value or a slice of fixed-size values, or a pointer to such data.
// If v is neither of these, Size returns -1.
func Size(v any) int {
return dataSize(reflect.Indirect(reflect.ValueOf(v)))
}
var structSize sync.Map // map[reflect.Type]int
// dataSize returns the number of bytes the actual data represented by v occupies in memory.
// For compound structures, it sums the sizes of the elements. Thus, for instance, for a slice
// it returns the length of the slice times the element size and does not count the memory
// occupied by the header. If the type of v is not acceptable, dataSize returns -1.
func dataSize(v reflect.Value) int {
switch v.Kind() {
case reflect.Slice:
if s := sizeof(v.Type().Elem()); s >= 0 {
return s * v.Len()
}
case reflect.Struct:
t := v.Type()
if size, ok := structSize.Load(t); ok {
return size.(int)
}
size := sizeof(t)
structSize.Store(t, size)
return size
default:
if v.IsValid() {
return sizeof(v.Type())
}
}
return -1
}
// sizeof returns the size >= 0 of variables for the given type or -1 if the type is not acceptable.
func sizeof(t reflect.Type) int {
switch t.Kind() {
case reflect.Array:
if s := sizeof(t.Elem()); s >= 0 {
return s * t.Len()
}
case reflect.Struct:
sum := 0
for i, n := 0, t.NumField(); i < n; i++ {
s := sizeof(t.Field(i).Type)
if s < 0 {
return -1
}
sum += s
}
return sum
case reflect.Bool,
reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128:
return int(t.Size())
}
return -1
}
type coder struct {
order ByteOrder
buf []byte
offset int
}
type (
decoder coder
encoder coder
)
func (d *decoder) bool() bool {
x := d.buf[d.offset]
d.offset++
return x != 0
}
func (e *encoder) bool(x bool) {
if x {
e.buf[e.offset] = 1
} else {
e.buf[e.offset] = 0
}
e.offset++
}
func (d *decoder) uint8() uint8 {
x := d.buf[d.offset]
d.offset++
return x
}
func (e *encoder) uint8(x uint8) {
e.buf[e.offset] = x
e.offset++
}
func (d *decoder) uint16() uint16 {
x := d.order.Uint16(d.buf[d.offset : d.offset+2])
d.offset += 2
return x
}
func (e *encoder) uint16(x uint16) {
e.order.PutUint16(e.buf[e.offset:e.offset+2], x)
e.offset += 2
}
func (d *decoder) uint32() uint32 {
x := d.order.Uint32(d.buf[d.offset : d.offset+4])
d.offset += 4
return x
}
func (e *encoder) uint32(x uint32) {
e.order.PutUint32(e.buf[e.offset:e.offset+4], x)
e.offset += 4
}
func (d *decoder) uint64() uint64 {
x := d.order.Uint64(d.buf[d.offset : d.offset+8])
d.offset += 8
return x
}
func (e *encoder) uint64(x uint64) {
e.order.PutUint64(e.buf[e.offset:e.offset+8], x)
e.offset += 8
}
func (d *decoder) int8() int8 { return int8(d.uint8()) }
func (e *encoder) int8(x int8) { e.uint8(uint8(x)) }
func (d *decoder) int16() int16 { return int16(d.uint16()) }
func (e *encoder) int16(x int16) { e.uint16(uint16(x)) }
func (d *decoder) int32() int32 { return int32(d.uint32()) }
func (e *encoder) int32(x int32) { e.uint32(uint32(x)) }
func (d *decoder) int64() int64 { return int64(d.uint64()) }
func (e *encoder) int64(x int64) { e.uint64(uint64(x)) }
func (d *decoder) value(v reflect.Value) {
switch v.Kind() {
case reflect.Array:
l := v.Len()
for i := 0; i < l; i++ {
d.value(v.Index(i))
}
case reflect.Struct:
t := v.Type()
l := v.NumField()
for i := 0; i < l; i++ {
// Note: Calling v.CanSet() below is an optimization.
// It would be sufficient to check the field name,
// but creating the StructField info for each field is
// costly (run "go test -bench=ReadStruct" and compare
// results when making changes to this code).
if v := v.Field(i); v.CanSet() || t.Field(i).Name != "_" {
d.value(v)
} else {
d.skip(v)
}
}
case reflect.Slice:
l := v.Len()
for i := 0; i < l; i++ {
d.value(v.Index(i))
}
case reflect.Bool:
v.SetBool(d.bool())
case reflect.Int8:
v.SetInt(int64(d.int8()))
case reflect.Int16:
v.SetInt(int64(d.int16()))
case reflect.Int32:
v.SetInt(int64(d.int32()))
case reflect.Int64:
v.SetInt(d.int64())
case reflect.Uint8:
v.SetUint(uint64(d.uint8()))
case reflect.Uint16:
v.SetUint(uint64(d.uint16()))
case reflect.Uint32:
v.SetUint(uint64(d.uint32()))
case reflect.Uint64:
v.SetUint(d.uint64())
case reflect.Float32:
v.SetFloat(float64(math.Float32frombits(d.uint32())))
case reflect.Float64:
v.SetFloat(math.Float64frombits(d.uint64()))
case reflect.Complex64:
v.SetComplex(complex(
float64(math.Float32frombits(d.uint32())),
float64(math.Float32frombits(d.uint32())),
))
case reflect.Complex128:
v.SetComplex(complex(
math.Float64frombits(d.uint64()),
math.Float64frombits(d.uint64()),
))
}
}
func (e *encoder) value(v reflect.Value) {
switch v.Kind() {
case reflect.Array:
l := v.Len()
for i := 0; i < l; i++ {
e.value(v.Index(i))
}
case reflect.Struct:
t := v.Type()
l := v.NumField()
for i := 0; i < l; i++ {
// see comment for corresponding code in decoder.value()
if v := v.Field(i); v.CanSet() || t.Field(i).Name != "_" {
e.value(v)
} else {
e.skip(v)
}
}
case reflect.Slice:
l := v.Len()
for i := 0; i < l; i++ {
e.value(v.Index(i))
}
case reflect.Bool:
e.bool(v.Bool())
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
switch v.Type().Kind() {
case reflect.Int8:
e.int8(int8(v.Int()))
case reflect.Int16:
e.int16(int16(v.Int()))
case reflect.Int32:
e.int32(int32(v.Int()))
case reflect.Int64:
e.int64(v.Int())
}
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
switch v.Type().Kind() {
case reflect.Uint8:
e.uint8(uint8(v.Uint()))
case reflect.Uint16:
e.uint16(uint16(v.Uint()))
case reflect.Uint32:
e.uint32(uint32(v.Uint()))
case reflect.Uint64:
e.uint64(v.Uint())
}
case reflect.Float32, reflect.Float64:
switch v.Type().Kind() {
case reflect.Float32:
e.uint32(math.Float32bits(float32(v.Float())))
case reflect.Float64:
e.uint64(math.Float64bits(v.Float()))
}
case reflect.Complex64, reflect.Complex128:
switch v.Type().Kind() {
case reflect.Complex64:
x := v.Complex()
e.uint32(math.Float32bits(float32(real(x))))
e.uint32(math.Float32bits(float32(imag(x))))
case reflect.Complex128:
x := v.Complex()
e.uint64(math.Float64bits(real(x)))
e.uint64(math.Float64bits(imag(x)))
}
}
}
func (d *decoder) skip(v reflect.Value) {
d.offset += dataSize(v)
}
func (e *encoder) skip(v reflect.Value) {
n := dataSize(v)
zero := e.buf[e.offset : e.offset+n]
for i := range zero {
zero[i] = 0
}
e.offset += n
}
// intDataSize returns the size of the data required to represent the data when encoded.
// It returns zero if the type cannot be implemented by the fast path in Read or Write.
func intDataSize(data any) int {
switch data := data.(type) {
case bool, int8, uint8, *bool, *int8, *uint8:
return 1
case []bool:
return len(data)
case []int8:
return len(data)
case []uint8:
return len(data)
case int16, uint16, *int16, *uint16:
return 2
case []int16:
return 2 * len(data)
case []uint16:
return 2 * len(data)
case int32, uint32, *int32, *uint32:
return 4
case []int32:
return 4 * len(data)
case []uint32:
return 4 * len(data)
case int64, uint64, *int64, *uint64:
return 8
case []int64:
return 8 * len(data)
case []uint64:
return 8 * len(data)
case float32, *float32:
return 4
case float64, *float64:
return 8
case []float32:
return 4 * len(data)
case []float64:
return 8 * len(data)
}
return 0
}

18
common/binary/export.go Normal file
View file

@ -0,0 +1,18 @@
package binary
import (
"encoding/binary"
"reflect"
)
func DataSize(t reflect.Value) int {
return dataSize(t)
}
func EncodeValue(order binary.ByteOrder, buf []byte, v reflect.Value) {
(&encoder{order: order, buf: buf}).value(v)
}
func DecodeValue(order binary.ByteOrder, buf []byte, v reflect.Value) {
(&decoder{order: order, buf: buf}).value(v)
}

View file

@ -0,0 +1,14 @@
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build armbe || arm64be || m68k || mips || mips64 || mips64p32 || ppc || ppc64 || s390 || s390x || shbe || sparc || sparc64
package binary
type nativeEndian struct {
bigEndian
}
// NativeEndian is the native-endian implementation of [ByteOrder] and [AppendByteOrder].
var NativeEndian nativeEndian

View file

@ -0,0 +1,14 @@
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build 386 || amd64 || amd64p32 || alpha || arm || arm64 || loong64 || mipsle || mips64le || mips64p32le || nios2 || ppc64le || riscv || riscv64 || sh || wasm
package binary
type nativeEndian struct {
littleEndian
}
// NativeEndian is the native-endian implementation of [ByteOrder] and [AppendByteOrder].
var NativeEndian nativeEndian

166
common/binary/varint.go Normal file
View file

@ -0,0 +1,166 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package binary
// This file implements "varint" encoding of 64-bit integers.
// The encoding is:
// - unsigned integers are serialized 7 bits at a time, starting with the
// least significant bits
// - the most significant bit (msb) in each output byte indicates if there
// is a continuation byte (msb = 1)
// - signed integers are mapped to unsigned integers using "zig-zag"
// encoding: Positive values x are written as 2*x + 0, negative values
// are written as 2*(^x) + 1; that is, negative numbers are complemented
// and whether to complement is encoded in bit 0.
//
// Design note:
// At most 10 bytes are needed for 64-bit values. The encoding could
// be more dense: a full 64-bit value needs an extra byte just to hold bit 63.
// Instead, the msb of the previous byte could be used to hold bit 63 since we
// know there can't be more than 64 bits. This is a trivial improvement and
// would reduce the maximum encoding length to 9 bytes. However, it breaks the
// invariant that the msb is always the "continuation bit" and thus makes the
// format incompatible with a varint encoding for larger numbers (say 128-bit).
import (
"errors"
"io"
)
// MaxVarintLenN is the maximum length of a varint-encoded N-bit integer.
const (
MaxVarintLen16 = 3
MaxVarintLen32 = 5
MaxVarintLen64 = 10
)
// AppendUvarint appends the varint-encoded form of x,
// as generated by [PutUvarint], to buf and returns the extended buffer.
func AppendUvarint(buf []byte, x uint64) []byte {
for x >= 0x80 {
buf = append(buf, byte(x)|0x80)
x >>= 7
}
return append(buf, byte(x))
}
// PutUvarint encodes a uint64 into buf and returns the number of bytes written.
// If the buffer is too small, PutUvarint will panic.
func PutUvarint(buf []byte, x uint64) int {
i := 0
for x >= 0x80 {
buf[i] = byte(x) | 0x80
x >>= 7
i++
}
buf[i] = byte(x)
return i + 1
}
// Uvarint decodes a uint64 from buf and returns that value and the
// number of bytes read (> 0). If an error occurred, the value is 0
// and the number of bytes n is <= 0 meaning:
//
// n == 0: buf too small
// n < 0: value larger than 64 bits (overflow)
// and -n is the number of bytes read
func Uvarint(buf []byte) (uint64, int) {
var x uint64
var s uint
for i, b := range buf {
if i == MaxVarintLen64 {
// Catch byte reads past MaxVarintLen64.
// See issue https://golang.org/issues/41185
return 0, -(i + 1) // overflow
}
if b < 0x80 {
if i == MaxVarintLen64-1 && b > 1 {
return 0, -(i + 1) // overflow
}
return x | uint64(b)<<s, i + 1
}
x |= uint64(b&0x7f) << s
s += 7
}
return 0, 0
}
// AppendVarint appends the varint-encoded form of x,
// as generated by [PutVarint], to buf and returns the extended buffer.
func AppendVarint(buf []byte, x int64) []byte {
ux := uint64(x) << 1
if x < 0 {
ux = ^ux
}
return AppendUvarint(buf, ux)
}
// PutVarint encodes an int64 into buf and returns the number of bytes written.
// If the buffer is too small, PutVarint will panic.
func PutVarint(buf []byte, x int64) int {
ux := uint64(x) << 1
if x < 0 {
ux = ^ux
}
return PutUvarint(buf, ux)
}
// Varint decodes an int64 from buf and returns that value and the
// number of bytes read (> 0). If an error occurred, the value is 0
// and the number of bytes n is <= 0 with the following meaning:
//
// n == 0: buf too small
// n < 0: value larger than 64 bits (overflow)
// and -n is the number of bytes read
func Varint(buf []byte) (int64, int) {
ux, n := Uvarint(buf) // ok to continue in presence of error
x := int64(ux >> 1)
if ux&1 != 0 {
x = ^x
}
return x, n
}
var errOverflow = errors.New("binary: varint overflows a 64-bit integer")
// ReadUvarint reads an encoded unsigned integer from r and returns it as a uint64.
// The error is [io.EOF] only if no bytes were read.
// If an [io.EOF] happens after reading some but not all the bytes,
// ReadUvarint returns [io.ErrUnexpectedEOF].
func ReadUvarint(r io.ByteReader) (uint64, error) {
var x uint64
var s uint
for i := 0; i < MaxVarintLen64; i++ {
b, err := r.ReadByte()
if err != nil {
if i > 0 && err == io.EOF {
err = io.ErrUnexpectedEOF
}
return x, err
}
if b < 0x80 {
if i == MaxVarintLen64-1 && b > 1 {
return x, errOverflow
}
return x | uint64(b)<<s, nil
}
x |= uint64(b&0x7f) << s
s += 7
}
return x, errOverflow
}
// ReadVarint reads an encoded signed integer from r and returns it as an int64.
// The error is [io.EOF] only if no bytes were read.
// If an [io.EOF] happens after reading some but not all the bytes,
// ReadVarint returns [io.ErrUnexpectedEOF].
func ReadVarint(r io.ByteReader) (int64, error) {
ux, err := ReadUvarint(r) // ok to continue in presence of error
x := int64(ux >> 1)
if ux&1 != 0 {
x = ^x
}
return x, err
}

View file

@ -8,7 +8,7 @@ import (
"sync"
)
var DefaultAllocator = newDefaultAllocer()
var DefaultAllocator = newDefaultAllocator()
type Allocator interface {
Get(size int) []byte
@ -17,22 +17,28 @@ type Allocator interface {
// defaultAllocator for incoming frames, optimized to prevent overwriting after zeroing
type defaultAllocator struct {
buffers []sync.Pool
buffers [11]sync.Pool
}
// NewAllocator initiates a []byte allocator for frames less than 65536 bytes,
// the waste(memory fragmentation) of space allocation is guaranteed to be
// no more than 50%.
func newDefaultAllocer() Allocator {
alloc := new(defaultAllocator)
alloc.buffers = make([]sync.Pool, 17) // 1B -> 64K
for k := range alloc.buffers {
i := k
alloc.buffers[k].New = func() any {
return make([]byte, 1<<uint32(i))
}
func newDefaultAllocator() Allocator {
return &defaultAllocator{
buffers: [...]sync.Pool{ // 64B -> 64K
{New: func() any { return new([1 << 6]byte) }},
{New: func() any { return new([1 << 7]byte) }},
{New: func() any { return new([1 << 8]byte) }},
{New: func() any { return new([1 << 9]byte) }},
{New: func() any { return new([1 << 10]byte) }},
{New: func() any { return new([1 << 11]byte) }},
{New: func() any { return new([1 << 12]byte) }},
{New: func() any { return new([1 << 13]byte) }},
{New: func() any { return new([1 << 14]byte) }},
{New: func() any { return new([1 << 15]byte) }},
{New: func() any { return new([1 << 16]byte) }},
},
}
return alloc
}
// Get a []byte from pool with most appropriate cap
@ -41,12 +47,42 @@ func (alloc *defaultAllocator) Get(size int) []byte {
return nil
}
bits := msb(size)
if size == 1<<bits {
return alloc.buffers[bits].Get().([]byte)[:size]
var index uint16
if size > 64 {
index = msb(size)
if size != 1<<index {
index += 1
}
index -= 6
}
return alloc.buffers[bits+1].Get().([]byte)[:size]
buffer := alloc.buffers[index].Get()
switch index {
case 0:
return buffer.(*[1 << 6]byte)[:size]
case 1:
return buffer.(*[1 << 7]byte)[:size]
case 2:
return buffer.(*[1 << 8]byte)[:size]
case 3:
return buffer.(*[1 << 9]byte)[:size]
case 4:
return buffer.(*[1 << 10]byte)[:size]
case 5:
return buffer.(*[1 << 11]byte)[:size]
case 6:
return buffer.(*[1 << 12]byte)[:size]
case 7:
return buffer.(*[1 << 13]byte)[:size]
case 8:
return buffer.(*[1 << 14]byte)[:size]
case 9:
return buffer.(*[1 << 15]byte)[:size]
case 10:
return buffer.(*[1 << 16]byte)[:size]
default:
panic("invalid pool index")
}
}
// Put returns a []byte to pool for future use,
@ -56,10 +92,37 @@ func (alloc *defaultAllocator) Put(buf []byte) error {
if cap(buf) == 0 || cap(buf) > 65536 || cap(buf) != 1<<bits {
return errors.New("allocator Put() incorrect buffer size")
}
bits -= 6
buf = buf[:cap(buf)]
//nolint
//lint:ignore SA6002 ignore temporarily
alloc.buffers[bits].Put(buf)
switch bits {
case 0:
alloc.buffers[bits].Put((*[1 << 6]byte)(buf))
case 1:
alloc.buffers[bits].Put((*[1 << 7]byte)(buf))
case 2:
alloc.buffers[bits].Put((*[1 << 8]byte)(buf))
case 3:
alloc.buffers[bits].Put((*[1 << 9]byte)(buf))
case 4:
alloc.buffers[bits].Put((*[1 << 10]byte)(buf))
case 5:
alloc.buffers[bits].Put((*[1 << 11]byte)(buf))
case 6:
alloc.buffers[bits].Put((*[1 << 12]byte)(buf))
case 7:
alloc.buffers[bits].Put((*[1 << 13]byte)(buf))
case 8:
alloc.buffers[bits].Put((*[1 << 14]byte)(buf))
case 9:
alloc.buffers[bits].Put((*[1 << 15]byte)(buf))
case 10:
alloc.buffers[bits].Put((*[1 << 16]byte)(buf))
default:
panic("invalid pool index")
}
return nil
}

View file

@ -4,39 +4,36 @@ import (
"crypto/rand"
"io"
"net"
"strconv"
"sync/atomic"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/atomic"
"github.com/sagernet/sing/common/debug"
E "github.com/sagernet/sing/common/exceptions"
F "github.com/sagernet/sing/common/format"
)
const ReversedHeader = 1024
type Buffer struct {
data []byte
start int
end int
refs int32
managed bool
closed bool
data []byte
start int
end int
capacity int
refs atomic.Int32
managed bool
}
func New() *Buffer {
return &Buffer{
data: Get(BufferSize),
start: ReversedHeader,
end: ReversedHeader,
managed: true,
data: Get(BufferSize),
capacity: BufferSize,
managed: true,
}
}
func NewPacket() *Buffer {
return &Buffer{
data: Get(UDPBufferSize),
start: ReversedHeader,
end: ReversedHeader,
managed: true,
data: Get(UDPBufferSize),
capacity: UDPBufferSize,
managed: true,
}
}
@ -45,40 +42,29 @@ func NewSize(size int) *Buffer {
return &Buffer{}
} else if size > 65535 {
return &Buffer{
data: make([]byte, size),
data: make([]byte, size),
capacity: size,
}
}
return &Buffer{
data: Get(size),
managed: true,
data: Get(size),
capacity: size,
managed: true,
}
}
// Deprecated: use New instead.
func StackNew() *Buffer {
return New()
}
// Deprecated: use NewPacket instead.
func StackNewPacket() *Buffer {
return NewPacket()
}
// Deprecated: use NewSize instead.
func StackNewSize(size int) *Buffer {
return NewSize(size)
}
func As(data []byte) *Buffer {
return &Buffer{
data: data,
end: len(data),
data: data,
end: len(data),
capacity: len(data),
}
}
func With(data []byte) *Buffer {
return &Buffer{
data: data,
data: data,
capacity: len(data),
}
}
@ -92,8 +78,8 @@ func (b *Buffer) SetByte(index int, value byte) {
func (b *Buffer) Extend(n int) []byte {
end := b.end + n
if end > cap(b.data) {
panic("buffer overflow: cap " + strconv.Itoa(cap(b.data)) + ",end " + strconv.Itoa(b.end) + ", need " + strconv.Itoa(n))
if end > b.capacity {
panic(F.ToString("buffer overflow: capacity ", b.capacity, ",end ", b.end, ", need ", n))
}
ext := b.data[b.end:end]
b.end = end
@ -115,14 +101,14 @@ func (b *Buffer) Write(data []byte) (n int, err error) {
if b.IsFull() {
return 0, io.ErrShortBuffer
}
n = copy(b.data[b.end:], data)
n = copy(b.data[b.end:b.capacity], data)
b.end += n
return
}
func (b *Buffer) ExtendHeader(n int) []byte {
if b.start < n {
panic("buffer overflow: cap " + strconv.Itoa(cap(b.data)) + ",start " + strconv.Itoa(b.start) + ", need " + strconv.Itoa(n))
panic(F.ToString("buffer overflow: capacity ", b.capacity, ",start ", b.start, ", need ", n))
}
b.start -= n
return b.data[b.start : b.start+n]
@ -175,7 +161,7 @@ func (b *Buffer) ReadAtLeastFrom(r io.Reader, min int) (int64, error) {
}
func (b *Buffer) ReadFullFrom(r io.Reader, size int) (n int, err error) {
if b.end+size > b.Cap() {
if b.end+size > b.capacity {
return 0, io.ErrShortBuffer
}
n, err = io.ReadFull(r, b.data[b.end:b.end+size])
@ -212,7 +198,7 @@ func (b *Buffer) WriteString(s string) (n int, err error) {
if b.IsFull() {
return 0, io.ErrShortBuffer
}
n = copy(b.data[b.end:], s)
n = copy(b.data[b.end:b.capacity], s)
b.end += n
return
}
@ -227,13 +213,10 @@ func (b *Buffer) WriteZero() error {
}
func (b *Buffer) WriteZeroN(n int) error {
if b.end+n > b.Cap() {
if b.end+n > b.capacity {
return io.ErrShortBuffer
}
for i := b.end; i < b.end+n; i++ {
b.data[i] = 0
}
b.end += n
common.ClearArray(b.Extend(n))
return nil
}
@ -276,40 +259,63 @@ func (b *Buffer) Resize(start, end int) {
b.end = b.start + end
}
func (b *Buffer) Reset() {
b.start = ReversedHeader
b.end = ReversedHeader
func (b *Buffer) Reserve(n int) {
if n > b.capacity {
panic(F.ToString("buffer overflow: capacity ", b.capacity, ", need ", n))
}
b.capacity -= n
}
func (b *Buffer) FullReset() {
func (b *Buffer) OverCap(n int) {
if b.capacity+n > len(b.data) {
panic(F.ToString("buffer overflow: capacity ", len(b.data), ", need ", b.capacity+n))
}
b.capacity += n
}
func (b *Buffer) Reset() {
b.start = 0
b.end = 0
b.capacity = len(b.data)
}
// Deprecated: use Reset instead.
func (b *Buffer) FullReset() {
b.Reset()
}
func (b *Buffer) IncRef() {
atomic.AddInt32(&b.refs, 1)
b.refs.Add(1)
}
func (b *Buffer) DecRef() {
atomic.AddInt32(&b.refs, -1)
b.refs.Add(-1)
}
func (b *Buffer) Release() {
if b == nil || b.closed || !b.managed {
if b == nil || !b.managed {
return
}
if atomic.LoadInt32(&b.refs) > 0 {
if b.refs.Load() > 0 {
return
}
common.Must(Put(b.data))
*b = Buffer{closed: true}
*b = Buffer{}
}
func (b *Buffer) Cut(start int, end int) *Buffer {
b.start += start
b.end = len(b.data) - end
return &Buffer{
data: b.data[b.start:b.end],
func (b *Buffer) Leak() {
if debug.Enabled {
if b == nil || !b.managed {
return
}
refs := b.refs.Load()
if refs == 0 {
panic("leaking buffer")
} else {
panic(F.ToString("leaking buffer with ", refs, " references"))
}
} else {
b.Release()
}
}
@ -322,6 +328,10 @@ func (b *Buffer) Len() int {
}
func (b *Buffer) Cap() int {
return b.capacity
}
func (b *Buffer) RawCap() int {
return len(b.data)
}
@ -329,10 +339,6 @@ func (b *Buffer) Bytes() []byte {
return b.data[b.start:b.end]
}
func (b *Buffer) Slice() []byte {
return b.data
}
func (b *Buffer) From(n int) []byte {
return b.data[b.start+n : b.end]
}
@ -350,11 +356,11 @@ func (b *Buffer) Index(start int) []byte {
}
func (b *Buffer) FreeLen() int {
return b.Cap() - b.end
return b.capacity - b.end
}
func (b *Buffer) FreeBytes() []byte {
return b.data[b.end:b.Cap()]
return b.data[b.end:b.capacity]
}
func (b *Buffer) IsEmpty() bool {
@ -362,7 +368,7 @@ func (b *Buffer) IsEmpty() bool {
}
func (b *Buffer) IsFull() bool {
return b.end == b.Cap()
return b.end == b.capacity
}
func (b *Buffer) ToOwned() *Buffer {
@ -370,5 +376,6 @@ func (b *Buffer) ToOwned() *Buffer {
copy(n.data[b.start:b.end], b.data[b.start:b.end])
n.start = b.start
n.end = b.end
n.capacity = b.capacity
return n
}

34
common/bufio/addr_bsd.go Normal file
View file

@ -0,0 +1,34 @@
//go:build darwin || dragonfly || freebsd || netbsd || openbsd
package bufio
import (
"encoding/binary"
"net/netip"
"unsafe"
"golang.org/x/sys/unix"
)
func ToSockaddr(destination netip.AddrPort) (name unsafe.Pointer, nameLen uint32) {
if destination.Addr().Is4() {
sa := unix.RawSockaddrInet4{
Len: unix.SizeofSockaddrInet4,
Family: unix.AF_INET,
Addr: destination.Addr().As4(),
}
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], destination.Port())
name = unsafe.Pointer(&sa)
nameLen = unix.SizeofSockaddrInet4
} else {
sa := unix.RawSockaddrInet6{
Len: unix.SizeofSockaddrInet6,
Family: unix.AF_INET6,
Addr: destination.Addr().As16(),
}
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], destination.Port())
name = unsafe.Pointer(&sa)
nameLen = unix.SizeofSockaddrInet6
}
return
}

View file

@ -9,19 +9,20 @@ import (
type AddrConn struct {
net.Conn
M.Metadata
Source M.Socksaddr
Destination M.Socksaddr
}
func (c *AddrConn) LocalAddr() net.Addr {
if c.Metadata.Destination.IsValid() {
return c.Metadata.Destination.TCPAddr()
if c.Destination.IsValid() {
return c.Destination.TCPAddr()
}
return c.Conn.LocalAddr()
}
func (c *AddrConn) RemoteAddr() net.Addr {
if c.Metadata.Source.IsValid() {
return c.Metadata.Source.TCPAddr()
if c.Source.IsValid() {
return c.Source.TCPAddr()
}
return c.Conn.RemoteAddr()
}

View file

@ -0,0 +1,30 @@
package bufio
import (
"encoding/binary"
"net/netip"
"unsafe"
"golang.org/x/sys/unix"
)
func ToSockaddr(destination netip.AddrPort) (name unsafe.Pointer, nameLen uint32) {
if destination.Addr().Is4() {
sa := unix.RawSockaddrInet4{
Family: unix.AF_INET,
Addr: destination.Addr().As4(),
}
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], destination.Port())
name = unsafe.Pointer(&sa)
nameLen = unix.SizeofSockaddrInet4
} else {
sa := unix.RawSockaddrInet6{
Family: unix.AF_INET6,
Addr: destination.Addr().As16(),
}
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], destination.Port())
name = unsafe.Pointer(&sa)
nameLen = unix.SizeofSockaddrInet6
}
return
}

View file

@ -0,0 +1,30 @@
package bufio
import (
"encoding/binary"
"net/netip"
"unsafe"
"golang.org/x/sys/windows"
)
func ToSockaddr(destination netip.AddrPort) (name unsafe.Pointer, nameLen int32) {
if destination.Addr().Is4() {
sa := windows.RawSockaddrInet4{
Family: windows.AF_INET,
Addr: destination.Addr().As4(),
}
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], destination.Port())
name = unsafe.Pointer(&sa)
nameLen = int32(unsafe.Sizeof(sa))
} else {
sa := windows.RawSockaddrInet6{
Family: windows.AF_INET6,
Addr: destination.Addr().As16(),
}
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], destination.Port())
name = unsafe.Pointer(&sa)
nameLen = int32(unsafe.Sizeof(sa))
}
return
}

View file

@ -8,51 +8,76 @@ import (
N "github.com/sagernet/sing/common/network"
)
type BindPacketConn struct {
type BindPacketConn interface {
N.NetPacketConn
Addr net.Addr
net.Conn
}
func NewBindPacketConn(conn net.PacketConn, addr net.Addr) *BindPacketConn {
return &BindPacketConn{
type bindPacketConn struct {
N.NetPacketConn
addr net.Addr
}
func NewBindPacketConn(conn net.PacketConn, addr net.Addr) BindPacketConn {
return &bindPacketConn{
NewPacketConn(conn),
addr,
}
}
func (c *BindPacketConn) Read(b []byte) (n int, err error) {
func (c *bindPacketConn) Read(b []byte) (n int, err error) {
n, _, err = c.ReadFrom(b)
return
}
func (c *BindPacketConn) Write(b []byte) (n int, err error) {
return c.WriteTo(b, c.Addr)
func (c *bindPacketConn) Write(b []byte) (n int, err error) {
return c.WriteTo(b, c.addr)
}
func (c *BindPacketConn) RemoteAddr() net.Addr {
return c.Addr
func (c *bindPacketConn) CreateReadWaiter() (N.ReadWaiter, bool) {
readWaiter, isReadWaiter := CreatePacketReadWaiter(c.NetPacketConn)
if !isReadWaiter {
return nil, false
}
return &bindPacketReadWaiter{readWaiter}, true
}
func (c *BindPacketConn) Upstream() any {
func (c *bindPacketConn) RemoteAddr() net.Addr {
return c.addr
}
func (c *bindPacketConn) Upstream() any {
return c.NetPacketConn
}
var (
_ N.NetPacketConn = (*UnbindPacketConn)(nil)
_ N.PacketReadWaitCreator = (*UnbindPacketConn)(nil)
)
type UnbindPacketConn struct {
N.ExtendedConn
Addr M.Socksaddr
addr M.Socksaddr
}
func NewUnbindPacketConn(conn net.Conn) *UnbindPacketConn {
func NewUnbindPacketConn(conn net.Conn) N.NetPacketConn {
return &UnbindPacketConn{
NewExtendedConn(conn),
M.SocksaddrFromNet(conn.RemoteAddr()),
}
}
func NewUnbindPacketConnWithAddr(conn net.Conn, addr M.Socksaddr) N.NetPacketConn {
return &UnbindPacketConn{
NewExtendedConn(conn),
addr,
}
}
func (c *UnbindPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
n, err = c.ExtendedConn.Read(p)
if err == nil {
addr = c.Addr.UDPAddr()
addr = c.addr.UDPAddr()
}
return
}
@ -66,7 +91,7 @@ func (c *UnbindPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksad
if err != nil {
return
}
destination = c.Addr
destination = c.addr
return
}
@ -74,6 +99,67 @@ func (c *UnbindPacketConn) WritePacket(buffer *buf.Buffer, _ M.Socksaddr) error
return c.ExtendedConn.WriteBuffer(buffer)
}
func (c *UnbindPacketConn) CreateReadWaiter() (N.PacketReadWaiter, bool) {
readWaiter, isReadWaiter := CreateReadWaiter(c.ExtendedConn)
if !isReadWaiter {
return nil, false
}
return &unbindPacketReadWaiter{readWaiter, c.addr}, true
}
func (c *UnbindPacketConn) Upstream() any {
return c.ExtendedConn
}
func NewServerPacketConn(conn net.PacketConn) N.ExtendedConn {
return &serverPacketConn{
NetPacketConn: NewPacketConn(conn),
}
}
type serverPacketConn struct {
N.NetPacketConn
remoteAddr M.Socksaddr
}
func (c *serverPacketConn) Read(p []byte) (n int, err error) {
n, addr, err := c.NetPacketConn.ReadFrom(p)
if err != nil {
return
}
c.remoteAddr = M.SocksaddrFromNet(addr)
return
}
func (c *serverPacketConn) ReadBuffer(buffer *buf.Buffer) error {
destination, err := c.NetPacketConn.ReadPacket(buffer)
if err != nil {
return err
}
c.remoteAddr = destination
return nil
}
func (c *serverPacketConn) Write(p []byte) (n int, err error) {
return c.NetPacketConn.WriteTo(p, c.remoteAddr.UDPAddr())
}
func (c *serverPacketConn) WriteBuffer(buffer *buf.Buffer) error {
return c.NetPacketConn.WritePacket(buffer, c.remoteAddr)
}
func (c *serverPacketConn) RemoteAddr() net.Addr {
return c.remoteAddr
}
func (c *serverPacketConn) Upstream() any {
return c.NetPacketConn
}
func (c *serverPacketConn) CreateReadWaiter() (N.ReadWaiter, bool) {
readWaiter, isReadWaiter := CreatePacketReadWaiter(c.NetPacketConn)
if !isReadWaiter {
return nil, false
}
return &serverPacketReadWaiter{c, readWaiter}, true
}

62
common/bufio/bind_wait.go Normal file
View file

@ -0,0 +1,62 @@
package bufio
import (
"github.com/sagernet/sing/common/buf"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
var _ N.ReadWaiter = (*bindPacketReadWaiter)(nil)
type bindPacketReadWaiter struct {
readWaiter N.PacketReadWaiter
}
func (w *bindPacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
return w.readWaiter.InitializeReadWaiter(options)
}
func (w *bindPacketReadWaiter) WaitReadBuffer() (buffer *buf.Buffer, err error) {
buffer, _, err = w.readWaiter.WaitReadPacket()
return
}
var _ N.PacketReadWaiter = (*unbindPacketReadWaiter)(nil)
type unbindPacketReadWaiter struct {
readWaiter N.ReadWaiter
addr M.Socksaddr
}
func (w *unbindPacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
return w.readWaiter.InitializeReadWaiter(options)
}
func (w *unbindPacketReadWaiter) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
buffer, err = w.readWaiter.WaitReadBuffer()
if err != nil {
return
}
destination = w.addr
return
}
var _ N.ReadWaiter = (*serverPacketReadWaiter)(nil)
type serverPacketReadWaiter struct {
*serverPacketConn
readWaiter N.PacketReadWaiter
}
func (w *serverPacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
return w.readWaiter.InitializeReadWaiter(options)
}
func (w *serverPacketReadWaiter) WaitReadBuffer() (buffer *buf.Buffer, err error) {
buffer, destination, err := w.readWaiter.WaitReadPacket()
if err != nil {
return
}
w.remoteAddr = destination
return
}

View file

@ -4,6 +4,7 @@ import (
"io"
"sync"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
)
@ -37,7 +38,26 @@ func (w *BufferedWriter) Write(p []byte) (n int, err error) {
if err != nil {
return
}
w.buffer.FullReset()
w.buffer.Reset()
}
}
func (w *BufferedWriter) WriteByte(c byte) error {
w.access.Lock()
defer w.access.Unlock()
if w.buffer == nil {
return common.Error(w.upstream.Write([]byte{c}))
}
for {
err := w.buffer.WriteByte(c)
if err == nil {
return nil
}
_, err = w.upstream.Write(w.buffer.Bytes())
if err != nil {
return err
}
w.buffer.Reset()
}
}

View file

@ -3,7 +3,6 @@ package bufio
import (
"io"
"net"
"time"
"github.com/sagernet/sing/common/buf"
M "github.com/sagernet/sing/common/metadata"
@ -60,13 +59,6 @@ func (c *CachedConn) WriteTo(w io.Writer) (n int64, err error) {
return
}
func (c *CachedConn) SetReadDeadline(t time.Time) error {
if c.buffer != nil && !c.buffer.IsEmpty() {
return nil
}
return c.Conn.SetReadDeadline(t)
}
func (c *CachedConn) ReadFrom(r io.Reader) (n int64, err error) {
return Copy(c.Conn, r)
}
@ -192,10 +184,12 @@ func (c *CachedPacketConn) ReadCachedPacket() *N.PacketBuffer {
if buffer != nil {
buffer.DecRef()
}
return &N.PacketBuffer{
packet := N.NewPacketBuffer()
*packet = N.PacketBuffer{
Buffer: buffer,
Destination: c.destination,
}
return packet
}
func (c *CachedPacketConn) Upstream() any {

View file

@ -30,7 +30,7 @@ func (c *ChunkReader) ReadBuffer(buffer *buf.Buffer) error {
} else if !c.cache.IsEmpty() {
return common.Error(buffer.ReadFrom(c.cache))
}
c.cache.FullReset()
c.cache.Reset()
err := c.upstream.ReadBuffer(c.cache)
if err != nil {
c.cache.Release()
@ -46,7 +46,7 @@ func (c *ChunkReader) Read(p []byte) (n int, err error) {
} else if !c.cache.IsEmpty() {
return c.cache.Read(p)
}
c.cache.FullReset()
c.cache.Reset()
err = c.upstream.ReadBuffer(c.cache)
if err != nil {
c.cache.Release()
@ -70,7 +70,7 @@ func (c *ChunkReader) ReadChunk() (*buf.Buffer, error) {
} else if !c.cache.IsEmpty() {
return c.cache, nil
}
c.cache.FullReset()
c.cache.Reset()
err := c.upstream.ReadBuffer(c.cache)
if err != nil {
c.cache.Release()

View file

@ -35,14 +35,7 @@ func (w *ExtendedUDPConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
func (w *ExtendedUDPConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
defer buffer.Release()
if destination.IsFqdn() {
udpAddr, err := net.ResolveUDPAddr("udp", destination.String())
if err != nil {
return err
}
return common.Error(w.UDPConn.WriteTo(buffer.Bytes(), udpAddr))
}
return common.Error(w.UDPConn.WriteToUDP(buffer.Bytes(), destination.UDPAddr()))
return common.Error(w.UDPConn.WriteToUDPAddrPort(buffer.Bytes(), destination.AddrPort()))
}
func (w *ExtendedUDPConn) Upstream() any {

View file

@ -5,7 +5,6 @@ import (
"errors"
"io"
"net"
"reflect"
"syscall"
"github.com/sagernet/sing/common"
@ -13,7 +12,6 @@ import (
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/rw"
"github.com/sagernet/sing/common/task"
)
@ -31,93 +29,71 @@ func Copy(destination io.Writer, source io.Reader) (n int64, err error) {
if cachedSrc, isCached := source.(N.CachedReader); isCached {
cachedBuffer := cachedSrc.ReadCached()
if cachedBuffer != nil {
if !cachedBuffer.IsEmpty() {
_, err = destination.Write(cachedBuffer.Bytes())
if err != nil {
cachedBuffer.Release()
return
}
}
dataLen := cachedBuffer.Len()
_, err = destination.Write(cachedBuffer.Bytes())
cachedBuffer.Release()
if err != nil {
return
}
for _, counter := range readCounters {
counter(int64(dataLen))
}
for _, counter := range writeCounters {
counter(int64(dataLen))
}
continue
}
}
srcSyscallConn, srcIsSyscall := source.(syscall.Conn)
dstSyscallConn, dstIsSyscall := destination.(syscall.Conn)
if srcIsSyscall && dstIsSyscall {
var handled bool
handled, n, err = CopyDirect(srcSyscallConn, dstSyscallConn, readCounters, writeCounters)
if handled {
return
}
}
break
}
return CopyWithCounters(destination, source, originSource, readCounters, writeCounters)
}
func CopyWithCounters(destination io.Writer, source io.Reader, originSource io.Reader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
srcSyscallConn, srcIsSyscall := source.(syscall.Conn)
dstSyscallConn, dstIsSyscall := destination.(syscall.Conn)
if srcIsSyscall && dstIsSyscall {
var handled bool
handled, n, err = copyDirect(srcSyscallConn, dstSyscallConn, readCounters, writeCounters)
if handled {
return
}
}
return CopyExtended(originSource, NewExtendedWriter(destination), NewExtendedReader(source), readCounters, writeCounters)
}
func CopyExtended(originSource io.Reader, destination N.ExtendedWriter, source N.ExtendedReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
safeSrc := N.IsSafeReader(source)
headroom := N.CalculateFrontHeadroom(destination) + N.CalculateRearHeadroom(destination)
if safeSrc != nil {
if headroom == 0 {
return CopyExtendedWithSrcBuffer(originSource, destination, safeSrc, readCounters, writeCounters)
}
}
frontHeadroom := N.CalculateFrontHeadroom(destination)
rearHeadroom := N.CalculateRearHeadroom(destination)
readWaiter, isReadWaiter := CreateReadWaiter(source)
if isReadWaiter {
var handled bool
handled, n, err = copyWaitWithPool(originSource, destination, readWaiter, readCounters, writeCounters)
if handled {
return
needCopy := readWaiter.InitializeReadWaiter(N.ReadWaitOptions{
FrontHeadroom: frontHeadroom,
RearHeadroom: rearHeadroom,
MTU: N.CalculateMTU(source, destination),
})
if !needCopy || common.LowMemory {
var handled bool
handled, n, err = copyWaitWithPool(originSource, destination, readWaiter, readCounters, writeCounters)
if handled {
return
}
}
}
return CopyExtendedWithPool(originSource, destination, source, readCounters, writeCounters)
}
// Deprecated: not used
func CopyExtendedBuffer(originSource io.Writer, destination N.ExtendedWriter, source N.ExtendedReader, buffer *buf.Buffer, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
buffer.IncRef()
defer buffer.DecRef()
frontHeadroom := N.CalculateFrontHeadroom(destination)
rearHeadroom := N.CalculateRearHeadroom(destination)
readBufferRaw := buffer.Slice()
readBuffer := buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom])
buffer.Resize(frontHeadroom, 0)
buffer.Reserve(rearHeadroom)
var notFirstTime bool
for {
readBuffer.Resize(frontHeadroom, 0)
err = source.ReadBuffer(readBuffer)
if err != nil {
if errors.Is(err, io.EOF) {
err = nil
return
}
return
}
dataLen := readBuffer.Len()
buffer.Resize(readBuffer.Start(), dataLen)
err = destination.WriteBuffer(buffer)
if err != nil {
if !notFirstTime {
err = N.ReportHandshakeFailure(originSource, err)
}
return
}
n += int64(dataLen)
for _, counter := range readCounters {
counter(int64(dataLen))
}
for _, counter := range writeCounters {
counter(int64(dataLen))
}
notFirstTime = true
}
}
func CopyExtendedWithSrcBuffer(originSource io.Reader, destination N.ExtendedWriter, source N.ThreadSafeReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
var notFirstTime bool
for {
var buffer *buf.Buffer
buffer, err = source.ReadBufferThreadSafe()
err = source.ReadBuffer(buffer)
if err != nil {
if errors.Is(err, io.EOF) {
err = nil
@ -126,9 +102,9 @@ func CopyExtendedWithSrcBuffer(originSource io.Reader, destination N.ExtendedWri
return
}
dataLen := buffer.Len()
buffer.OverCap(rearHeadroom)
err = destination.WriteBuffer(buffer)
if err != nil {
buffer.Release()
if !notFirstTime {
err = N.ReportHandshakeFailure(originSource, err)
}
@ -146,21 +122,11 @@ func CopyExtendedWithSrcBuffer(originSource io.Reader, destination N.ExtendedWri
}
func CopyExtendedWithPool(originSource io.Reader, destination N.ExtendedWriter, source N.ExtendedReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
frontHeadroom := N.CalculateFrontHeadroom(destination)
rearHeadroom := N.CalculateRearHeadroom(destination)
bufferSize := N.CalculateMTU(source, destination)
if bufferSize > 0 {
bufferSize += frontHeadroom + rearHeadroom
} else {
bufferSize = buf.BufferSize
}
options := N.NewReadWaitOptions(source, destination)
var notFirstTime bool
for {
buffer := buf.NewSize(bufferSize)
readBufferRaw := buffer.Slice()
readBuffer := buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom])
readBuffer.Resize(frontHeadroom, 0)
err = source.ReadBuffer(readBuffer)
buffer := options.NewBuffer()
err = source.ReadBuffer(buffer)
if err != nil {
buffer.Release()
if errors.Is(err, io.EOF) {
@ -169,11 +135,11 @@ func CopyExtendedWithPool(originSource io.Reader, destination N.ExtendedWriter,
}
return
}
dataLen := readBuffer.Len()
buffer.Resize(readBuffer.Start(), dataLen)
dataLen := buffer.Len()
options.PostReturn(buffer)
err = destination.WriteBuffer(buffer)
if err != nil {
buffer.Release()
buffer.Leak()
if !notFirstTime {
err = N.ReportHandshakeFailure(originSource, err)
}
@ -191,16 +157,12 @@ func CopyExtendedWithPool(originSource io.Reader, destination N.ExtendedWriter,
}
func CopyConn(ctx context.Context, source net.Conn, destination net.Conn) error {
return CopyConnContextList([]context.Context{ctx}, source, destination)
}
func CopyConnContextList(contextList []context.Context, source net.Conn, destination net.Conn) error {
var group task.Group
if _, dstDuplex := common.Cast[rw.WriteCloser](destination); dstDuplex {
if _, dstDuplex := common.Cast[N.WriteCloser](destination); dstDuplex {
group.Append("upload", func(ctx context.Context) error {
err := common.Error(Copy(destination, source))
if err == nil {
rw.CloseWrite(destination)
N.CloseWrite(destination)
} else {
common.Close(destination)
}
@ -212,11 +174,11 @@ func CopyConnContextList(contextList []context.Context, source net.Conn, destina
return common.Error(Copy(destination, source))
})
}
if _, srcDuplex := common.Cast[rw.WriteCloser](source); srcDuplex {
if _, srcDuplex := common.Cast[N.WriteCloser](source); srcDuplex {
group.Append("download", func(ctx context.Context) error {
err := common.Error(Copy(source, destination))
if err == nil {
rw.CloseWrite(source)
N.CloseWrite(source)
} else {
common.Close(source)
}
@ -231,7 +193,7 @@ func CopyConnContextList(contextList []context.Context, source net.Conn, destina
group.Cleanup(func() {
common.Close(source, destination)
})
return group.RunContextList(contextList)
return group.Run(ctx)
}
func CopyPacket(destinationConn N.PacketWriter, source N.PacketReader) (n int64, err error) {
@ -251,33 +213,30 @@ func CopyPacket(destinationConn N.PacketWriter, source N.PacketReader) (n int64,
break
}
if cachedPackets != nil {
n, err = WritePacketWithPool(originSource, destinationConn, cachedPackets)
n, err = WritePacketWithPool(originSource, destinationConn, cachedPackets, readCounters, writeCounters)
if err != nil {
return
}
}
safeSrc := N.IsSafePacketReader(source)
frontHeadroom := N.CalculateFrontHeadroom(destinationConn)
rearHeadroom := N.CalculateRearHeadroom(destinationConn)
headroom := frontHeadroom + rearHeadroom
if safeSrc != nil {
if headroom == 0 {
var copyN int64
copyN, err = CopyPacketWithSrcBuffer(originSource, destinationConn, safeSrc, readCounters, writeCounters, n > 0)
n += copyN
return
}
}
copeN, err := CopyPacketWithCounters(destinationConn, source, originSource, readCounters, writeCounters)
n += copeN
return
}
func CopyPacketWithCounters(destinationConn N.PacketWriter, source N.PacketReader, originSource N.PacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
var (
handled bool
copeN int64
)
readWaiter, isReadWaiter := CreatePacketReadWaiter(source)
if isReadWaiter {
handled, copeN, err = copyPacketWaitWithPool(originSource, destinationConn, readWaiter, readCounters, writeCounters, n > 0)
if handled {
n += copeN
return
needCopy := readWaiter.InitializeReadWaiter(N.NewReadWaitOptions(source, destinationConn))
if !needCopy || common.LowMemory {
handled, copeN, err = copyPacketWaitWithPool(originSource, destinationConn, readWaiter, readCounters, writeCounters, n > 0)
if handled {
n += copeN
return
}
}
}
copeN, err = CopyPacketWithPool(originSource, destinationConn, source, readCounters, writeCounters, n > 0)
@ -285,116 +244,65 @@ func CopyPacket(destinationConn N.PacketWriter, source N.PacketReader) (n int64,
return
}
func CopyPacketWithSrcBuffer(originSource N.PacketReader, destinationConn N.PacketWriter, source N.ThreadSafePacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (n int64, err error) {
var buffer *buf.Buffer
var destination M.Socksaddr
func CopyPacketWithPool(originSource N.PacketReader, destination N.PacketWriter, source N.PacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (n int64, err error) {
options := N.NewReadWaitOptions(source, destination)
var destinationAddress M.Socksaddr
for {
buffer, destination, err = source.ReadPacketThreadSafe()
buffer := options.NewPacketBuffer()
destinationAddress, err = source.ReadPacket(buffer)
if err != nil {
buffer.Release()
return
}
if buffer == nil {
panic("nil buffer returned from " + reflect.TypeOf(source).String())
}
dataLen := buffer.Len()
if dataLen == 0 {
continue
}
err = destinationConn.WritePacket(buffer, destination)
options.PostReturn(buffer)
err = destination.WritePacket(buffer, destinationAddress)
if err != nil {
buffer.Release()
buffer.Leak()
if !notFirstTime {
err = N.ReportHandshakeFailure(originSource, err)
}
return
}
n += int64(dataLen)
for _, counter := range readCounters {
counter(int64(dataLen))
}
for _, counter := range writeCounters {
counter(int64(dataLen))
}
notFirstTime = true
}
}
func CopyPacketWithPool(originSource N.PacketReader, destinationConn N.PacketWriter, source N.PacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (n int64, err error) {
frontHeadroom := N.CalculateFrontHeadroom(destinationConn)
rearHeadroom := N.CalculateRearHeadroom(destinationConn)
bufferSize := N.CalculateMTU(source, destinationConn)
if bufferSize > 0 {
bufferSize += frontHeadroom + rearHeadroom
} else {
bufferSize = buf.UDPBufferSize
}
var destination M.Socksaddr
for {
buffer := buf.NewSize(bufferSize)
readBufferRaw := buffer.Slice()
readBuffer := buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom])
readBuffer.Resize(frontHeadroom, 0)
destination, err = source.ReadPacket(readBuffer)
if err != nil {
buffer.Release()
return
}
dataLen := readBuffer.Len()
buffer.Resize(readBuffer.Start(), dataLen)
err = destinationConn.WritePacket(buffer, destination)
if err != nil {
buffer.Release()
if !notFirstTime {
err = N.ReportHandshakeFailure(originSource, err)
}
return
}
n += int64(dataLen)
for _, counter := range readCounters {
counter(int64(dataLen))
}
for _, counter := range writeCounters {
counter(int64(dataLen))
}
notFirstTime = true
}
}
func WritePacketWithPool(originSource N.PacketReader, destinationConn N.PacketWriter, packetBuffers []*N.PacketBuffer) (n int64, err error) {
frontHeadroom := N.CalculateFrontHeadroom(destinationConn)
rearHeadroom := N.CalculateRearHeadroom(destinationConn)
func WritePacketWithPool(originSource N.PacketReader, destination N.PacketWriter, packetBuffers []*N.PacketBuffer, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
options := N.NewReadWaitOptions(nil, destination)
var notFirstTime bool
for _, packetBuffer := range packetBuffers {
buffer := buf.NewPacket()
readBufferRaw := buffer.Slice()
readBuffer := buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom])
readBuffer.Resize(frontHeadroom, 0)
_, err = readBuffer.Write(packetBuffer.Buffer.Bytes())
packetBuffer.Buffer.Release()
buffer := options.Copy(packetBuffer.Buffer)
dataLen := buffer.Len()
err = destination.WritePacket(buffer, packetBuffer.Destination)
N.PutPacketBuffer(packetBuffer)
if err != nil {
buffer.Release()
continue
}
dataLen := readBuffer.Len()
buffer.Resize(readBuffer.Start(), dataLen)
err = destinationConn.WritePacket(buffer, packetBuffer.Destination)
if err != nil {
buffer.Release()
buffer.Leak()
if !notFirstTime {
err = N.ReportHandshakeFailure(originSource, err)
}
return
}
for _, counter := range readCounters {
counter(int64(dataLen))
}
for _, counter := range writeCounters {
counter(int64(dataLen))
}
n += int64(dataLen)
notFirstTime = true
}
return
}
func CopyPacketConn(ctx context.Context, source N.PacketConn, destination N.PacketConn) error {
return CopyPacketConnContextList([]context.Context{ctx}, source, destination)
}
func CopyPacketConnContextList(contextList []context.Context, source N.PacketConn, destination N.PacketConn) error {
var group task.Group
group.Append("upload", func(ctx context.Context) error {
return common.Error(CopyPacket(destination, source))
@ -406,5 +314,5 @@ func CopyPacketConnContextList(contextList []context.Context, source N.PacketCon
common.Close(source, destination)
})
group.FastFail()
return group.RunContextList(contextList)
return group.Run(ctx)
}

View file

@ -1,12 +1,16 @@
package bufio
import (
"errors"
"io"
"syscall"
"github.com/sagernet/sing/common/buf"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
func CopyDirect(source syscall.Conn, destination syscall.Conn, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handed bool, n int64, err error) {
func copyDirect(source syscall.Conn, destination syscall.Conn, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handed bool, n int64, err error) {
rawSource, err := source.SyscallConn()
if err != nil {
return
@ -18,3 +22,69 @@ func CopyDirect(source syscall.Conn, destination syscall.Conn, readCounters []N.
handed, n, err = splice(rawSource, rawDestination, readCounters, writeCounters)
return
}
func copyWaitWithPool(originSource io.Reader, destination N.ExtendedWriter, source N.ReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handled bool, n int64, err error) {
handled = true
var (
buffer *buf.Buffer
notFirstTime bool
)
for {
buffer, err = source.WaitReadBuffer()
if err != nil {
if errors.Is(err, io.EOF) {
err = nil
return
}
return
}
dataLen := buffer.Len()
err = destination.WriteBuffer(buffer)
if err != nil {
buffer.Leak()
if !notFirstTime {
err = N.ReportHandshakeFailure(originSource, err)
}
return
}
n += int64(dataLen)
for _, counter := range readCounters {
counter(int64(dataLen))
}
for _, counter := range writeCounters {
counter(int64(dataLen))
}
notFirstTime = true
}
}
func copyPacketWaitWithPool(originSource N.PacketReader, destinationConn N.PacketWriter, source N.PacketReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (handled bool, n int64, err error) {
handled = true
var (
buffer *buf.Buffer
destination M.Socksaddr
)
for {
buffer, destination, err = source.WaitReadPacket()
if err != nil {
return
}
dataLen := buffer.Len()
err = destinationConn.WritePacket(buffer, destination)
if err != nil {
buffer.Leak()
if !notFirstTime {
err = N.ReportHandshakeFailure(originSource, err)
}
return
}
n += int64(dataLen)
for _, counter := range readCounters {
counter(int64(dataLen))
}
for _, counter := range writeCounters {
counter(int64(dataLen))
}
notFirstTime = true
}
}

View file

@ -3,7 +3,6 @@
package bufio
import (
"errors"
"io"
"net/netip"
"os"
@ -15,114 +14,14 @@ import (
N "github.com/sagernet/sing/common/network"
)
func copyWaitWithPool(originSource io.Reader, destination N.ExtendedWriter, source N.ReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handled bool, n int64, err error) {
handled = true
frontHeadroom := N.CalculateFrontHeadroom(destination)
rearHeadroom := N.CalculateRearHeadroom(destination)
bufferSize := N.CalculateMTU(source, destination)
if bufferSize > 0 {
bufferSize += frontHeadroom + rearHeadroom
} else {
bufferSize = buf.BufferSize
}
var (
buffer *buf.Buffer
readBuffer *buf.Buffer
notFirstTime bool
)
source.InitializeReadWaiter(func() *buf.Buffer {
buffer = buf.NewSize(bufferSize)
readBufferRaw := buffer.Slice()
readBuffer = buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom])
readBuffer.Resize(frontHeadroom, 0)
return readBuffer
})
defer source.InitializeReadWaiter(nil)
for {
err = source.WaitReadBuffer()
if err != nil {
if errors.Is(err, io.EOF) {
err = nil
return
}
return
}
dataLen := readBuffer.Len()
buffer.Resize(readBuffer.Start(), dataLen)
err = destination.WriteBuffer(buffer)
if err != nil {
buffer.Release()
if !notFirstTime {
err = N.ReportHandshakeFailure(originSource, err)
}
return
}
n += int64(dataLen)
for _, counter := range readCounters {
counter(int64(dataLen))
}
for _, counter := range writeCounters {
counter(int64(dataLen))
}
notFirstTime = true
}
}
func copyPacketWaitWithPool(originSource N.PacketReader, destinationConn N.PacketWriter, source N.PacketReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (handled bool, n int64, err error) {
handled = true
frontHeadroom := N.CalculateFrontHeadroom(destinationConn)
rearHeadroom := N.CalculateRearHeadroom(destinationConn)
bufferSize := N.CalculateMTU(source, destinationConn)
if bufferSize > 0 {
bufferSize += frontHeadroom + rearHeadroom
} else {
bufferSize = buf.UDPBufferSize
}
var (
buffer *buf.Buffer
readBuffer *buf.Buffer
destination M.Socksaddr
)
source.InitializeReadWaiter(func() *buf.Buffer {
buffer = buf.NewSize(bufferSize)
readBufferRaw := buffer.Slice()
readBuffer = buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom])
readBuffer.Resize(frontHeadroom, 0)
return readBuffer
})
defer source.InitializeReadWaiter(nil)
for {
destination, err = source.WaitReadPacket()
if err != nil {
return
}
dataLen := readBuffer.Len()
buffer.Resize(readBuffer.Start(), dataLen)
err = destinationConn.WritePacket(buffer, destination)
if err != nil {
buffer.Release()
if !notFirstTime {
err = N.ReportHandshakeFailure(originSource, err)
}
return
}
n += int64(dataLen)
for _, counter := range readCounters {
counter(int64(dataLen))
}
for _, counter := range writeCounters {
counter(int64(dataLen))
}
notFirstTime = true
}
}
var _ N.ReadWaiter = (*syscallReadWaiter)(nil)
type syscallReadWaiter struct {
rawConn syscall.RawConn
readErr error
readFunc func(fd uintptr) (done bool)
buffer *buf.Buffer
options N.ReadWaitOptions
}
func createSyscallReadWaiter(reader any) (*syscallReadWaiter, bool) {
@ -135,47 +34,48 @@ func createSyscallReadWaiter(reader any) (*syscallReadWaiter, bool) {
return nil, false
}
func (w *syscallReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buffer) {
w.readErr = nil
if newBuffer == nil {
w.readFunc = nil
} else {
w.readFunc = func(fd uintptr) (done bool) {
buffer := newBuffer()
var readN int
readN, w.readErr = syscall.Read(int(fd), buffer.FreeBytes())
if readN > 0 {
buffer.Truncate(readN)
} else {
buffer.Release()
buffer = nil
}
if w.readErr == syscall.EAGAIN {
return false
}
if readN == 0 {
w.readErr = io.EOF
}
return true
func (w *syscallReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
w.options = options
w.readFunc = func(fd uintptr) (done bool) {
buffer := w.options.NewBuffer()
var readN int
readN, w.readErr = syscall.Read(int(fd), buffer.FreeBytes())
if readN > 0 {
buffer.Truncate(readN)
w.options.PostReturn(buffer)
w.buffer = buffer
} else {
buffer.Release()
}
//goland:noinspection GoDirectComparisonOfErrors
if w.readErr == syscall.EAGAIN {
return false
}
if readN == 0 && w.readErr == nil {
w.readErr = io.EOF
}
return true
}
return false
}
func (w *syscallReadWaiter) WaitReadBuffer() error {
func (w *syscallReadWaiter) WaitReadBuffer() (buffer *buf.Buffer, err error) {
if w.readFunc == nil {
return os.ErrInvalid
return nil, os.ErrInvalid
}
err := w.rawConn.Read(w.readFunc)
err = w.rawConn.Read(w.readFunc)
if err != nil {
return err
return
}
if w.readErr != nil {
if w.readErr == io.EOF {
return io.EOF
return nil, io.EOF
}
return E.Cause(w.readErr, "raw read")
return nil, E.Cause(w.readErr, "raw read")
}
return nil
buffer = w.buffer
w.buffer = nil
return
}
var _ N.PacketReadWaiter = (*syscallPacketReadWaiter)(nil)
@ -185,6 +85,8 @@ type syscallPacketReadWaiter struct {
readErr error
readFrom M.Socksaddr
readFunc func(fd uintptr) (done bool)
buffer *buf.Buffer
options N.ReadWaitOptions
}
func createSyscallPacketReadWaiter(reader any) (*syscallPacketReadWaiter, bool) {
@ -197,42 +99,37 @@ func createSyscallPacketReadWaiter(reader any) (*syscallPacketReadWaiter, bool)
return nil, false
}
func (w *syscallPacketReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buffer) {
w.readErr = nil
w.readFrom = M.Socksaddr{}
if newBuffer == nil {
w.readFunc = nil
} else {
w.readFunc = func(fd uintptr) (done bool) {
buffer := newBuffer()
var readN int
var from syscall.Sockaddr
readN, _, _, from, w.readErr = syscall.Recvmsg(int(fd), buffer.FreeBytes(), nil, 0)
if readN > 0 {
buffer.Truncate(readN)
} else {
buffer.Release()
buffer = nil
}
if w.readErr == syscall.EAGAIN {
return false
}
if from != nil {
switch fromAddr := from.(type) {
case *syscall.SockaddrInet4:
w.readFrom = M.SocksaddrFrom(netip.AddrFrom4(fromAddr.Addr), uint16(fromAddr.Port))
case *syscall.SockaddrInet6:
w.readFrom = M.SocksaddrFrom(netip.AddrFrom16(fromAddr.Addr), uint16(fromAddr.Port)).Unwrap()
}
}
return true
func (w *syscallPacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
w.options = options
w.readFunc = func(fd uintptr) (done bool) {
buffer := w.options.NewPacketBuffer()
var readN int
var from syscall.Sockaddr
readN, _, _, from, w.readErr = syscall.Recvmsg(int(fd), buffer.FreeBytes(), nil, 0)
//goland:noinspection GoDirectComparisonOfErrors
if w.readErr != nil {
buffer.Release()
return w.readErr != syscall.EAGAIN
}
if readN > 0 {
buffer.Truncate(readN)
}
w.options.PostReturn(buffer)
w.buffer = buffer
switch fromAddr := from.(type) {
case *syscall.SockaddrInet4:
w.readFrom = M.SocksaddrFrom(netip.AddrFrom4(fromAddr.Addr), uint16(fromAddr.Port))
case *syscall.SockaddrInet6:
w.readFrom = M.SocksaddrFrom(netip.AddrFrom16(fromAddr.Addr), uint16(fromAddr.Port)).Unwrap()
}
return true
}
return false
}
func (w *syscallPacketReadWaiter) WaitReadPacket() (destination M.Socksaddr, err error) {
func (w *syscallPacketReadWaiter) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
if w.readFunc == nil {
return M.Socksaddr{}, os.ErrInvalid
return nil, M.Socksaddr{}, os.ErrInvalid
}
err = w.rawConn.Read(w.readFunc)
if err != nil {
@ -242,6 +139,8 @@ func (w *syscallPacketReadWaiter) WaitReadPacket() (destination M.Socksaddr, err
err = E.Cause(w.readErr, "raw read")
return
}
buffer = w.buffer
w.buffer = nil
destination = w.readFrom
return
}

View file

@ -0,0 +1,77 @@
package bufio
import (
"net"
"testing"
"github.com/sagernet/sing/common/buf"
N "github.com/sagernet/sing/common/network"
"github.com/stretchr/testify/require"
)
func TestCopyWaitTCP(t *testing.T) {
t.Parallel()
inputConn, outputConn := TCPPipe(t)
readWaiter, created := CreateReadWaiter(outputConn)
require.True(t, created)
require.NotNil(t, readWaiter)
readWaiter.InitializeReadWaiter(N.ReadWaitOptions{})
require.NoError(t, TCPTest(t, inputConn, &readWaitWrapper{
Conn: outputConn,
readWaiter: readWaiter,
}))
}
type readWaitWrapper struct {
net.Conn
readWaiter N.ReadWaiter
buffer *buf.Buffer
}
func (r *readWaitWrapper) Read(p []byte) (n int, err error) {
if r.buffer != nil {
if r.buffer.Len() > 0 {
return r.buffer.Read(p)
}
if r.buffer.IsEmpty() {
r.buffer.Release()
r.buffer = nil
}
}
buffer, err := r.readWaiter.WaitReadBuffer()
if err != nil {
return
}
r.buffer = buffer
return r.buffer.Read(p)
}
func TestCopyWaitUDP(t *testing.T) {
t.Parallel()
inputConn, outputConn, outputAddr := UDPPipe(t)
readWaiter, created := CreatePacketReadWaiter(NewPacketConn(outputConn))
require.True(t, created)
require.NotNil(t, readWaiter)
readWaiter.InitializeReadWaiter(N.ReadWaitOptions{})
require.NoError(t, UDPTest(t, inputConn, &packetReadWaitWrapper{
PacketConn: outputConn,
readWaiter: readWaiter,
}, outputAddr))
}
type packetReadWaitWrapper struct {
net.PacketConn
readWaiter N.PacketReadWaiter
}
func (r *packetReadWaitWrapper) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
buffer, destination, err := r.readWaiter.WaitReadPacket()
if err != nil {
return
}
n = copy(p, buffer.Bytes())
buffer.Release()
addr = destination.UDPAddr()
return
}

View file

@ -2,22 +2,162 @@ package bufio
import (
"io"
"net/netip"
"os"
"syscall"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"golang.org/x/sys/windows"
)
func copyWaitWithPool(originSource io.Reader, destination N.ExtendedWriter, source N.ReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handled bool, n int64, err error) {
return
var _ N.ReadWaiter = (*syscallReadWaiter)(nil)
type syscallReadWaiter struct {
rawConn syscall.RawConn
readErr error
readFunc func(fd uintptr) (done bool)
hasData bool
buffer *buf.Buffer
options N.ReadWaitOptions
}
func copyPacketWaitWithPool(originSource N.PacketReader, destinationConn N.PacketWriter, source N.PacketReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (handled bool, n int64, err error) {
return
}
func createSyscallReadWaiter(reader any) (N.ReadWaiter, bool) {
func createSyscallReadWaiter(reader any) (*syscallReadWaiter, bool) {
if syscallConn, isSyscallConn := reader.(syscall.Conn); isSyscallConn {
rawConn, err := syscallConn.SyscallConn()
if err == nil {
return &syscallReadWaiter{rawConn: rawConn}, true
}
}
return nil, false
}
func createSyscallPacketReadWaiter(reader any) (N.PacketReadWaiter, bool) {
func (w *syscallReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
w.options = options
w.readFunc = func(fd uintptr) (done bool) {
if !w.hasData {
w.hasData = true
// golang's internal/poll.FD.RawRead will Use a zero-byte read as a way to get notified when this
// socket is readable if we return false. So the `recv` syscall will not block the system thread.
return false
}
buffer := w.options.NewBuffer()
var readN int32
readN, w.readErr = recv(windows.Handle(fd), buffer.FreeBytes(), 0)
if readN > 0 {
buffer.Truncate(int(readN))
w.options.PostReturn(buffer)
w.buffer = buffer
} else {
buffer.Release()
}
if w.readErr == windows.WSAEWOULDBLOCK {
return false
}
if readN == 0 && w.readErr == nil {
w.readErr = io.EOF
}
w.hasData = false
return true
}
return false
}
func (w *syscallReadWaiter) WaitReadBuffer() (buffer *buf.Buffer, err error) {
if w.readFunc == nil {
return nil, os.ErrInvalid
}
err = w.rawConn.Read(w.readFunc)
if err != nil {
return
}
if w.readErr != nil {
if w.readErr == io.EOF {
return nil, io.EOF
}
return nil, E.Cause(w.readErr, "raw read")
}
buffer = w.buffer
w.buffer = nil
return
}
var _ N.PacketReadWaiter = (*syscallPacketReadWaiter)(nil)
type syscallPacketReadWaiter struct {
rawConn syscall.RawConn
readErr error
readFrom M.Socksaddr
readFunc func(fd uintptr) (done bool)
hasData bool
buffer *buf.Buffer
options N.ReadWaitOptions
}
func createSyscallPacketReadWaiter(reader any) (*syscallPacketReadWaiter, bool) {
if syscallConn, isSyscallConn := reader.(syscall.Conn); isSyscallConn {
rawConn, err := syscallConn.SyscallConn()
if err == nil {
return &syscallPacketReadWaiter{rawConn: rawConn}, true
}
}
return nil, false
}
func (w *syscallPacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
w.options = options
w.readFunc = func(fd uintptr) (done bool) {
if !w.hasData {
w.hasData = true
// golang's internal/poll.FD.RawRead will Use a zero-byte read as a way to get notified when this
// socket is readable if we return false. So the `recvfrom` syscall will not block the system thread.
return false
}
buffer := w.options.NewPacketBuffer()
var readN int
var from windows.Sockaddr
readN, from, w.readErr = windows.Recvfrom(windows.Handle(fd), buffer.FreeBytes(), 0)
//goland:noinspection GoDirectComparisonOfErrors
if w.readErr != nil {
buffer.Release()
return w.readErr != windows.WSAEWOULDBLOCK
}
if readN > 0 {
buffer.Truncate(readN)
}
w.options.PostReturn(buffer)
w.buffer = buffer
if from != nil {
switch fromAddr := from.(type) {
case *windows.SockaddrInet4:
w.readFrom = M.SocksaddrFrom(netip.AddrFrom4(fromAddr.Addr), uint16(fromAddr.Port))
case *windows.SockaddrInet6:
w.readFrom = M.SocksaddrFrom(netip.AddrFrom16(fromAddr.Addr), uint16(fromAddr.Port)).Unwrap()
}
}
w.hasData = false
return true
}
return false
}
func (w *syscallPacketReadWaiter) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
if w.readFunc == nil {
return nil, M.Socksaddr{}, os.ErrInvalid
}
err = w.rawConn.Read(w.readFunc)
if err != nil {
return
}
if w.readErr != nil {
err = E.Cause(w.readErr, "raw read")
return
}
buffer = w.buffer
w.buffer = nil
destination = w.readFrom
return
}

View file

@ -38,6 +38,10 @@ func (c *SerialConn) ReadBuffer(buffer *buf.Buffer) error {
return c.ExtendedConn.ReadBuffer(buffer)
}
func (c *SerialConn) Upstream() any {
return c.ExtendedConn
}
type SerialPacketConn struct {
N.NetPacketConn
access sync.Mutex

View file

@ -25,6 +25,45 @@ func ReadPacket(reader N.PacketReader, buffer *buf.Buffer) (n int, addr net.Addr
return
}
func ReadBufferSize(reader io.Reader, bufferSize int) (buffer *buf.Buffer, err error) {
readWaiter, isReadWaiter := CreateReadWaiter(reader)
if isReadWaiter {
readWaiter.InitializeReadWaiter(N.ReadWaitOptions{
MTU: bufferSize,
})
return readWaiter.WaitReadBuffer()
}
buffer = buf.NewSize(bufferSize)
if extendedReader, isExtendedReader := reader.(N.ExtendedReader); isExtendedReader {
err = extendedReader.ReadBuffer(buffer)
} else {
_, err = buffer.ReadOnceFrom(reader)
}
if err != nil {
buffer.Release()
buffer = nil
}
return
}
func ReadPacketSize(reader N.PacketReader, packetSize int) (buffer *buf.Buffer, destination M.Socksaddr, err error) {
readWaiter, isReadWaiter := CreatePacketReadWaiter(reader)
if isReadWaiter {
readWaiter.InitializeReadWaiter(N.ReadWaitOptions{
MTU: packetSize,
})
buffer, destination, err = readWaiter.WaitReadPacket()
return
}
buffer = buf.NewSize(packetSize)
destination, err = reader.ReadPacket(buffer)
if err != nil {
buffer.Release()
buffer = nil
}
return
}
func Write(writer io.Writer, data []byte) (n int, err error) {
if extendedWriter, isExtended := writer.(N.ExtendedWriter); isExtended {
return WriteBuffer(extendedWriter, buf.As(data))

View file

@ -17,13 +17,21 @@ type NATPacketConn interface {
func NewUnidirectionalNATPacketConn(conn N.NetPacketConn, origin M.Socksaddr, destination M.Socksaddr) NATPacketConn {
return &unidirectionalNATPacketConn{
NetPacketConn: conn,
origin: origin,
destination: destination,
origin: socksaddrWithoutPort(origin),
destination: socksaddrWithoutPort(destination),
}
}
func NewNATPacketConn(conn N.NetPacketConn, origin M.Socksaddr, destination M.Socksaddr) NATPacketConn {
return &bidirectionalNATPacketConn{
NetPacketConn: conn,
origin: socksaddrWithoutPort(origin),
destination: socksaddrWithoutPort(destination),
}
}
func NewDestinationNATPacketConn(conn N.NetPacketConn, origin M.Socksaddr, destination M.Socksaddr) NATPacketConn {
return &destinationNATPacketConn{
NetPacketConn: conn,
origin: origin,
destination: destination,
@ -37,15 +45,24 @@ type unidirectionalNATPacketConn struct {
}
func (c *unidirectionalNATPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
if M.SocksaddrFromNet(addr) == c.destination {
addr = c.origin.UDPAddr()
destination := M.SocksaddrFromNet(addr)
if socksaddrWithoutPort(destination) == c.destination {
destination = M.Socksaddr{
Addr: c.origin.Addr,
Fqdn: c.origin.Fqdn,
Port: destination.Port,
}
}
return c.NetPacketConn.WriteTo(p, addr)
return c.NetPacketConn.WriteTo(p, destination.UDPAddr())
}
func (c *unidirectionalNATPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
if destination == c.destination {
destination = c.origin
if socksaddrWithoutPort(destination) == c.destination {
destination = M.Socksaddr{
Addr: c.origin.Addr,
Fqdn: c.origin.Fqdn,
Port: destination.Port,
}
}
return c.NetPacketConn.WritePacket(buffer, destination)
}
@ -54,6 +71,10 @@ func (c *unidirectionalNATPacketConn) UpdateDestination(destinationAddress netip
c.destination = M.SocksaddrFrom(destinationAddress, c.destination.Port)
}
func (c *unidirectionalNATPacketConn) RemoteAddr() net.Addr {
return c.destination.UDPAddr()
}
func (c *unidirectionalNATPacketConn) Upstream() any {
return c.NetPacketConn
}
@ -66,30 +87,55 @@ type bidirectionalNATPacketConn struct {
func (c *bidirectionalNATPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
n, addr, err = c.NetPacketConn.ReadFrom(p)
if err == nil && M.SocksaddrFromNet(addr) == c.origin {
addr = c.destination.UDPAddr()
if err != nil {
return
}
destination := M.SocksaddrFromNet(addr)
if socksaddrWithoutPort(destination) == c.origin {
destination = M.Socksaddr{
Addr: c.destination.Addr,
Fqdn: c.destination.Fqdn,
Port: destination.Port,
}
}
addr = destination.UDPAddr()
return
}
func (c *bidirectionalNATPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
if M.SocksaddrFromNet(addr) == c.destination {
addr = c.origin.UDPAddr()
destination := M.SocksaddrFromNet(addr)
if socksaddrWithoutPort(destination) == c.destination {
destination = M.Socksaddr{
Addr: c.origin.Addr,
Fqdn: c.origin.Fqdn,
Port: destination.Port,
}
}
return c.NetPacketConn.WriteTo(p, addr)
return c.NetPacketConn.WriteTo(p, destination.UDPAddr())
}
func (c *bidirectionalNATPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
destination, err = c.NetPacketConn.ReadPacket(buffer)
if destination == c.origin {
destination = c.destination
if err != nil {
return
}
if socksaddrWithoutPort(destination) == c.origin {
destination = M.Socksaddr{
Addr: c.destination.Addr,
Fqdn: c.destination.Fqdn,
Port: destination.Port,
}
}
return
}
func (c *bidirectionalNATPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
if destination == c.destination {
destination = c.origin
if socksaddrWithoutPort(destination) == c.destination {
destination = M.Socksaddr{
Addr: c.origin.Addr,
Fqdn: c.origin.Fqdn,
Port: destination.Port,
}
}
return c.NetPacketConn.WritePacket(buffer, destination)
}
@ -101,3 +147,66 @@ func (c *bidirectionalNATPacketConn) UpdateDestination(destinationAddress netip.
func (c *bidirectionalNATPacketConn) Upstream() any {
return c.NetPacketConn
}
func (c *bidirectionalNATPacketConn) RemoteAddr() net.Addr {
return c.destination.UDPAddr()
}
type destinationNATPacketConn struct {
N.NetPacketConn
origin M.Socksaddr
destination M.Socksaddr
}
func (c *destinationNATPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
n, addr, err = c.NetPacketConn.ReadFrom(p)
if err != nil {
return
}
if M.SocksaddrFromNet(addr) == c.origin {
addr = c.destination.UDPAddr()
}
return
}
func (c *destinationNATPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
if M.SocksaddrFromNet(addr) == c.destination {
addr = c.origin.UDPAddr()
}
return c.NetPacketConn.WriteTo(p, addr)
}
func (c *destinationNATPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
destination, err = c.NetPacketConn.ReadPacket(buffer)
if err != nil {
return
}
if destination == c.origin {
destination = c.destination
}
return
}
func (c *destinationNATPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
if destination == c.destination {
destination = c.origin
}
return c.NetPacketConn.WritePacket(buffer, destination)
}
func (c *destinationNATPacketConn) UpdateDestination(destinationAddress netip.Addr) {
c.destination = M.SocksaddrFrom(destinationAddress, c.destination.Port)
}
func (c *destinationNATPacketConn) Upstream() any {
return c.NetPacketConn
}
func (c *destinationNATPacketConn) RemoteAddr() net.Addr {
return c.destination.UDPAddr()
}
func socksaddrWithoutPort(destination M.Socksaddr) M.Socksaddr {
destination.Port = 0
return destination
}

39
common/bufio/nat_wait.go Normal file
View file

@ -0,0 +1,39 @@
package bufio
import (
"github.com/sagernet/sing/common/buf"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
func (c *bidirectionalNATPacketConn) CreatePacketReadWaiter() (N.PacketReadWaiter, bool) {
waiter, created := CreatePacketReadWaiter(c.NetPacketConn)
if !created {
return nil, false
}
return &waitBidirectionalNATPacketConn{c, waiter}, true
}
type waitBidirectionalNATPacketConn struct {
*bidirectionalNATPacketConn
readWaiter N.PacketReadWaiter
}
func (c *waitBidirectionalNATPacketConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
return c.readWaiter.InitializeReadWaiter(options)
}
func (c *waitBidirectionalNATPacketConn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
buffer, destination, err = c.readWaiter.WaitReadPacket()
if err != nil {
return
}
if socksaddrWithoutPort(destination) == c.origin {
destination = M.Socksaddr{
Addr: c.destination.Addr,
Fqdn: c.destination.Fqdn,
Port: destination.Port,
}
}
return
}

277
common/bufio/net_test.go Normal file
View file

@ -0,0 +1,277 @@
package bufio
import (
"context"
"crypto/md5"
"crypto/rand"
"errors"
"io"
"net"
"sync"
"testing"
"time"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/common/task"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TCPPipe(t *testing.T) (net.Conn, net.Conn) {
listener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
var (
group task.Group
serverConn net.Conn
clientConn net.Conn
)
group.Append0(func(ctx context.Context) error {
var serverErr error
serverConn, serverErr = listener.Accept()
return serverErr
})
group.Append0(func(ctx context.Context) error {
var clientErr error
clientConn, clientErr = net.Dial("tcp", listener.Addr().String())
return clientErr
})
err = group.Run(context.Background())
require.NoError(t, err)
listener.Close()
t.Cleanup(func() {
serverConn.Close()
clientConn.Close()
})
return serverConn, clientConn
}
func UDPPipe(t *testing.T) (net.PacketConn, net.PacketConn, M.Socksaddr) {
serverConn, err := net.ListenPacket("udp", "127.0.0.1:0")
require.NoError(t, err)
clientConn, err := net.ListenPacket("udp", "127.0.0.1:0")
require.NoError(t, err)
return serverConn, clientConn, M.SocksaddrFromNet(clientConn.LocalAddr())
}
func Timeout(t *testing.T) context.CancelFunc {
ctx, cancel := context.WithCancel(context.Background())
go func() {
select {
case <-ctx.Done():
return
case <-time.After(5 * time.Second):
t.Error("timeout")
}
}()
return cancel
}
type hashPair struct {
sendHash map[int][]byte
recvHash map[int][]byte
}
func newLargeDataPair() (chan hashPair, chan hashPair, func(t *testing.T) error) {
pingCh := make(chan hashPair)
pongCh := make(chan hashPair)
test := func(t *testing.T) error {
defer close(pingCh)
defer close(pongCh)
pingOpen := false
pongOpen := false
var serverPair hashPair
var clientPair hashPair
for {
if pingOpen && pongOpen {
break
}
select {
case serverPair, pingOpen = <-pingCh:
assert.True(t, pingOpen)
case clientPair, pongOpen = <-pongCh:
assert.True(t, pongOpen)
case <-time.After(10 * time.Second):
return errors.New("timeout")
}
}
assert.Equal(t, serverPair.recvHash, clientPair.sendHash)
assert.Equal(t, serverPair.sendHash, clientPair.recvHash)
return nil
}
return pingCh, pongCh, test
}
func TCPTest(t *testing.T, inputConn net.Conn, outputConn net.Conn) error {
times := 100
chunkSize := int64(64 * 1024)
pingCh, pongCh, test := newLargeDataPair()
writeRandData := func(conn net.Conn) (map[int][]byte, error) {
buf := make([]byte, chunkSize)
hashMap := map[int][]byte{}
for i := 0; i < times; i++ {
if _, err := rand.Read(buf[1:]); err != nil {
return nil, err
}
buf[0] = byte(i)
hash := md5.Sum(buf)
hashMap[i] = hash[:]
if _, err := conn.Write(buf); err != nil {
return nil, err
}
}
return hashMap, nil
}
go func() {
hashMap := map[int][]byte{}
buf := make([]byte, chunkSize)
for i := 0; i < times; i++ {
_, err := io.ReadFull(outputConn, buf)
if err != nil {
t.Log(err.Error())
return
}
hash := md5.Sum(buf)
hashMap[int(buf[0])] = hash[:]
}
sendHash, err := writeRandData(outputConn)
if err != nil {
t.Log(err.Error())
return
}
pingCh <- hashPair{
sendHash: sendHash,
recvHash: hashMap,
}
}()
go func() {
sendHash, err := writeRandData(inputConn)
if err != nil {
t.Log(err.Error())
return
}
hashMap := map[int][]byte{}
buf := make([]byte, chunkSize)
for i := 0; i < times; i++ {
_, err = io.ReadFull(inputConn, buf)
if err != nil {
t.Log(err.Error())
return
}
hash := md5.Sum(buf)
hashMap[int(buf[0])] = hash[:]
}
pongCh <- hashPair{
sendHash: sendHash,
recvHash: hashMap,
}
}()
return test(t)
}
func UDPTest(t *testing.T, inputConn net.PacketConn, outputConn net.PacketConn, outputAddr M.Socksaddr) error {
rAddr := outputAddr.UDPAddr()
times := 50
chunkSize := 9000
pingCh, pongCh, test := newLargeDataPair()
writeRandData := func(pc net.PacketConn, addr net.Addr) (map[int][]byte, error) {
hashMap := map[int][]byte{}
mux := sync.Mutex{}
for i := 0; i < times; i++ {
buf := make([]byte, chunkSize)
if _, err := rand.Read(buf[1:]); err != nil {
t.Log(err.Error())
continue
}
buf[0] = byte(i)
hash := md5.Sum(buf)
mux.Lock()
hashMap[i] = hash[:]
mux.Unlock()
if _, err := pc.WriteTo(buf, addr); err != nil {
t.Log(err.Error())
}
time.Sleep(10 * time.Millisecond)
}
return hashMap, nil
}
go func() {
var (
lAddr net.Addr
err error
)
hashMap := map[int][]byte{}
buf := make([]byte, 64*1024)
for i := 0; i < times; i++ {
_, lAddr, err = outputConn.ReadFrom(buf)
if err != nil {
t.Log(err.Error())
return
}
hash := md5.Sum(buf[:chunkSize])
hashMap[int(buf[0])] = hash[:]
}
sendHash, err := writeRandData(outputConn, lAddr)
if err != nil {
t.Log(err.Error())
return
}
pingCh <- hashPair{
sendHash: sendHash,
recvHash: hashMap,
}
}()
go func() {
sendHash, err := writeRandData(inputConn, rAddr)
if err != nil {
t.Log(err.Error())
return
}
hashMap := map[int][]byte{}
buf := make([]byte, 64*1024)
for i := 0; i < times; i++ {
_, _, err := inputConn.ReadFrom(buf)
if err != nil {
t.Log(err.Error())
return
}
hash := md5.Sum(buf[:chunkSize])
hashMap[int(buf[0])] = hash[:]
}
pongCh <- hashPair{
sendHash: sendHash,
recvHash: hashMap,
}
}()
return test(t)
}

View file

@ -0,0 +1,5 @@
package bufio
//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go syscall_windows.go
//sys recv(s windows.Handle, buf []byte, flags int32) (n int32, err error) [failretval == -1] = ws2_32.recv

View file

@ -33,10 +33,10 @@ func CreateVectorisedWriter(writer any) (N.VectorisedWriter, bool) {
case syscall.Conn:
rawConn, err := w.SyscallConn()
if err == nil {
return &SyscallVectorisedWriter{writer, rawConn}, true
return &SyscallVectorisedWriter{upstream: writer, rawConn: rawConn}, true
}
case syscall.RawConn:
return &SyscallVectorisedWriter{writer, w}, true
return &SyscallVectorisedWriter{upstream: writer, rawConn: w}, true
}
return nil, false
}
@ -48,10 +48,10 @@ func CreateVectorisedPacketWriter(writer any) (N.VectorisedPacketWriter, bool) {
case syscall.Conn:
rawConn, err := w.SyscallConn()
if err == nil {
return &SyscallVectorisedPacketWriter{writer, rawConn}, true
return &SyscallVectorisedPacketWriter{upstream: writer, rawConn: rawConn}, true
}
case syscall.RawConn:
return &SyscallVectorisedPacketWriter{writer, w}, true
return &SyscallVectorisedPacketWriter{upstream: writer, rawConn: w}, true
}
return nil, false
}
@ -111,6 +111,7 @@ var _ N.VectorisedWriter = (*SyscallVectorisedWriter)(nil)
type SyscallVectorisedWriter struct {
upstream any
rawConn syscall.RawConn
syscallVectorisedWriterFields
}
func (w *SyscallVectorisedWriter) Upstream() any {
@ -126,6 +127,7 @@ var _ N.VectorisedPacketWriter = (*SyscallVectorisedPacketWriter)(nil)
type SyscallVectorisedPacketWriter struct {
upstream any
rawConn syscall.RawConn
syscallVectorisedWriterFields
}
func (w *SyscallVectorisedPacketWriter) Upstream() any {

View file

@ -0,0 +1,60 @@
package bufio
import (
"crypto/rand"
"io"
"testing"
"github.com/stretchr/testify/require"
)
func TestWriteVectorised(t *testing.T) {
t.Parallel()
inputConn, outputConn := TCPPipe(t)
vectorisedWriter, created := CreateVectorisedWriter(inputConn)
require.True(t, created)
require.NotNil(t, vectorisedWriter)
var bufA [1024]byte
var bufB [1024]byte
var bufC [2048]byte
_, err := io.ReadFull(rand.Reader, bufA[:])
require.NoError(t, err)
_, err = io.ReadFull(rand.Reader, bufB[:])
require.NoError(t, err)
copy(bufC[:], bufA[:])
copy(bufC[1024:], bufB[:])
finish := Timeout(t)
_, err = WriteVectorised(vectorisedWriter, [][]byte{bufA[:], bufB[:]})
require.NoError(t, err)
output := make([]byte, 2048)
_, err = io.ReadFull(outputConn, output)
finish()
require.NoError(t, err)
require.Equal(t, bufC[:], output)
}
func TestWriteVectorisedPacket(t *testing.T) {
t.Parallel()
inputConn, outputConn, outputAddr := UDPPipe(t)
vectorisedWriter, created := CreateVectorisedPacketWriter(inputConn)
require.True(t, created)
require.NotNil(t, vectorisedWriter)
var bufA [1024]byte
var bufB [1024]byte
var bufC [2048]byte
_, err := io.ReadFull(rand.Reader, bufA[:])
require.NoError(t, err)
_, err = io.ReadFull(rand.Reader, bufB[:])
require.NoError(t, err)
copy(bufC[:], bufA[:])
copy(bufC[1024:], bufB[:])
finish := Timeout(t)
_, err = WriteVectorisedPacket(vectorisedWriter, [][]byte{bufA[:], bufB[:]}, outputAddr)
require.NoError(t, err)
output := make([]byte, 2048)
n, _, err := outputConn.ReadFrom(output)
finish()
require.NoError(t, err)
require.Equal(t, 2048, n)
require.Equal(t, bufC[:], output)
}

View file

@ -3,6 +3,8 @@
package bufio
import (
"os"
"sync"
"unsafe"
"github.com/sagernet/sing/common/buf"
@ -11,49 +13,81 @@ import (
"golang.org/x/sys/unix"
)
type syscallVectorisedWriterFields struct {
access sync.Mutex
iovecList *[]unix.Iovec
}
func (w *SyscallVectorisedWriter) WriteVectorised(buffers []*buf.Buffer) error {
w.access.Lock()
defer w.access.Unlock()
defer buf.ReleaseMulti(buffers)
iovecList := make([]unix.Iovec, 0, len(buffers))
for _, buffer := range buffers {
var iovec unix.Iovec
iovec.Base = &buffer.Bytes()[0]
iovec.SetLen(buffer.Len())
iovecList = append(iovecList, iovec)
var iovecList []unix.Iovec
if w.iovecList != nil {
iovecList = *w.iovecList
}
iovecList = iovecList[:0]
for index, buffer := range buffers {
iovecList = append(iovecList, unix.Iovec{Base: &buffer.Bytes()[0]})
iovecList[index].SetLen(buffer.Len())
}
if w.iovecList == nil {
w.iovecList = new([]unix.Iovec)
}
*w.iovecList = iovecList // cache
var innerErr unix.Errno
err := w.rawConn.Write(func(fd uintptr) (done bool) {
//nolint:staticcheck
//goland:noinspection GoDeprecation
_, _, innerErr = unix.Syscall(unix.SYS_WRITEV, fd, uintptr(unsafe.Pointer(&iovecList[0])), uintptr(len(iovecList)))
return innerErr != unix.EAGAIN && innerErr != unix.EWOULDBLOCK
})
if innerErr != 0 {
err = innerErr
err = os.NewSyscallError("SYS_WRITEV", innerErr)
}
for index := range iovecList {
iovecList[index] = unix.Iovec{}
}
return err
}
func (w *SyscallVectorisedPacketWriter) WriteVectorisedPacket(buffers []*buf.Buffer, destination M.Socksaddr) error {
w.access.Lock()
defer w.access.Unlock()
defer buf.ReleaseMulti(buffers)
var sockaddr unix.Sockaddr
if destination.IsIPv4() {
sockaddr = &unix.SockaddrInet4{
Port: int(destination.Port),
Addr: destination.Addr.As4(),
}
} else {
sockaddr = &unix.SockaddrInet6{
Port: int(destination.Port),
Addr: destination.Addr.As16(),
}
var iovecList []unix.Iovec
if w.iovecList != nil {
iovecList = *w.iovecList
}
iovecList = iovecList[:0]
for index, buffer := range buffers {
iovecList = append(iovecList, unix.Iovec{Base: &buffer.Bytes()[0]})
iovecList[index].SetLen(buffer.Len())
}
if w.iovecList == nil {
w.iovecList = new([]unix.Iovec)
}
*w.iovecList = iovecList // cache
var innerErr error
err := w.rawConn.Write(func(fd uintptr) (done bool) {
_, innerErr = unix.SendmsgBuffers(int(fd), buf.ToSliceMulti(buffers), nil, sockaddr, 0)
var msg unix.Msghdr
name, nameLen := ToSockaddr(destination.AddrPort())
msg.Name = (*byte)(name)
msg.Namelen = nameLen
if len(iovecList) > 0 {
msg.Iov = &iovecList[0]
msg.SetIovlen(len(iovecList))
}
_, innerErr = sendmsg(int(fd), &msg, 0)
return innerErr != unix.EAGAIN && innerErr != unix.EWOULDBLOCK
})
if innerErr != nil {
err = innerErr
}
for index := range iovecList {
iovecList[index] = unix.Iovec{}
}
return err
}
//go:linkname sendmsg golang.org/x/sys/unix.sendmsg
func sendmsg(s int, msg *unix.Msghdr, flags int) (n int, err error)

View file

@ -1,62 +1,93 @@
package bufio
import (
"sync"
"github.com/sagernet/sing/common/buf"
M "github.com/sagernet/sing/common/metadata"
"golang.org/x/sys/windows"
)
type syscallVectorisedWriterFields struct {
access sync.Mutex
iovecList *[]windows.WSABuf
}
func (w *SyscallVectorisedWriter) WriteVectorised(buffers []*buf.Buffer) error {
w.access.Lock()
defer w.access.Unlock()
defer buf.ReleaseMulti(buffers)
iovecList := make([]*windows.WSABuf, len(buffers))
for i, buffer := range buffers {
iovecList[i] = &windows.WSABuf{
Len: uint32(buffer.Len()),
Buf: &buffer.Bytes()[0],
}
var iovecList []windows.WSABuf
if w.iovecList != nil {
iovecList = *w.iovecList
}
iovecList = iovecList[:0]
for _, buffer := range buffers {
iovecList = append(iovecList, windows.WSABuf{
Buf: &buffer.Bytes()[0],
Len: uint32(buffer.Len()),
})
}
if w.iovecList == nil {
w.iovecList = new([]windows.WSABuf)
}
*w.iovecList = iovecList // cache
var n uint32
var innerErr error
err := w.rawConn.Write(func(fd uintptr) (done bool) {
innerErr = windows.WSASend(windows.Handle(fd), iovecList[0], uint32(len(iovecList)), &n, 0, nil, nil)
innerErr = windows.WSASend(windows.Handle(fd), &iovecList[0], uint32(len(iovecList)), &n, 0, nil, nil)
return innerErr != windows.WSAEWOULDBLOCK
})
if innerErr != nil {
err = innerErr
}
for index := range iovecList {
iovecList[index] = windows.WSABuf{}
}
return err
}
func (w *SyscallVectorisedPacketWriter) WriteVectorisedPacket(buffers []*buf.Buffer, destination M.Socksaddr) error {
w.access.Lock()
defer w.access.Unlock()
defer buf.ReleaseMulti(buffers)
iovecList := make([]*windows.WSABuf, len(buffers))
for i, buffer := range buffers {
iovecList[i] = &windows.WSABuf{
Len: uint32(buffer.Len()),
var iovecList []windows.WSABuf
if w.iovecList != nil {
iovecList = *w.iovecList
}
iovecList = iovecList[:0]
for _, buffer := range buffers {
iovecList = append(iovecList, windows.WSABuf{
Buf: &buffer.Bytes()[0],
}
Len: uint32(buffer.Len()),
})
}
var sockaddr windows.Sockaddr
if destination.IsIPv4() {
sockaddr = &windows.SockaddrInet4{
Port: int(destination.Port),
Addr: destination.Addr.As4(),
}
} else {
sockaddr = &windows.SockaddrInet6{
Port: int(destination.Port),
Addr: destination.Addr.As16(),
}
if w.iovecList == nil {
w.iovecList = new([]windows.WSABuf)
}
*w.iovecList = iovecList // cache
var n uint32
var innerErr error
err := w.rawConn.Write(func(fd uintptr) (done bool) {
innerErr = windows.WSASendto(windows.Handle(fd), iovecList[0], uint32(len(iovecList)), &n, 0, sockaddr, nil, nil)
name, nameLen := ToSockaddr(destination.AddrPort())
innerErr = windows.WSASendTo(
windows.Handle(fd),
&iovecList[0],
uint32(len(iovecList)),
&n,
0,
(*windows.RawSockaddrAny)(name),
nameLen,
nil,
nil)
return innerErr != windows.WSAEWOULDBLOCK
})
if innerErr != nil {
err = innerErr
}
for index := range iovecList {
iovecList[index] = windows.WSABuf{}
}
return err
}

View file

@ -0,0 +1,57 @@
// Code generated by 'go generate'; DO NOT EDIT.
package bufio
import (
"syscall"
"unsafe"
"golang.org/x/sys/windows"
)
var _ unsafe.Pointer
// Do the interface allocations only once for common
// Errno values.
const (
errnoERROR_IO_PENDING = 997
)
var (
errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
errERROR_EINVAL error = syscall.EINVAL
)
// errnoErr returns common boxed Errno values, to prevent
// allocations at runtime.
func errnoErr(e syscall.Errno) error {
switch e {
case 0:
return errERROR_EINVAL
case errnoERROR_IO_PENDING:
return errERROR_IO_PENDING
}
// TODO: add more here, after collecting data on the common
// error values see on Windows. (perhaps when running
// all.bat?)
return e
}
var (
modws2_32 = windows.NewLazySystemDLL("ws2_32.dll")
procrecv = modws2_32.NewProc("recv")
)
func recv(s windows.Handle, buf []byte, flags int32) (n int32, err error) {
var _p0 *byte
if len(buf) > 0 {
_p0 = &buf[0]
}
r0, _, e1 := syscall.Syscall6(procrecv.Addr(), 4, uintptr(s), uintptr(unsafe.Pointer(_p0)), uintptr(len(buf)), uintptr(flags), 0, 0)
n = int32(r0)
if n == -1 {
err = errnoErr(e1)
}
return
}

View file

@ -41,9 +41,9 @@ func (i *Instance) Timeout() time.Duration {
return i.timeout
}
func (i *Instance) SetTimeout(timeout time.Duration) {
func (i *Instance) SetTimeout(timeout time.Duration) bool {
i.timeout = timeout
i.Update()
return i.Update()
}
func (i *Instance) wait() {

View file

@ -13,7 +13,7 @@ import (
type PacketConn interface {
N.PacketConn
Timeout() time.Duration
SetTimeout(timeout time.Duration)
SetTimeout(timeout time.Duration) bool
}
type TimerPacketConn struct {
@ -21,13 +21,15 @@ type TimerPacketConn struct {
instance *Instance
}
func NewPacketConn(ctx context.Context, conn N.PacketConn, timeout time.Duration) (context.Context, PacketConn) {
func NewPacketConn(ctx context.Context, conn N.PacketConn, timeout time.Duration) (context.Context, N.PacketConn) {
if timeoutConn, isTimeoutConn := common.Cast[PacketConn](conn); isTimeoutConn {
oldTimeout := timeoutConn.Timeout()
if timeout < oldTimeout {
timeoutConn.SetTimeout(timeout)
if oldTimeout > 0 && timeout >= oldTimeout {
return ctx, conn
}
if timeoutConn.SetTimeout(timeout) {
return ctx, conn
}
return ctx, timeoutConn
}
err := conn.SetReadDeadline(time.Time{})
if err == nil {
@ -58,8 +60,8 @@ func (c *TimerPacketConn) Timeout() time.Duration {
return c.instance.Timeout()
}
func (c *TimerPacketConn) SetTimeout(timeout time.Duration) {
c.instance.SetTimeout(timeout)
func (c *TimerPacketConn) SetTimeout(timeout time.Duration) bool {
return c.instance.SetTimeout(timeout)
}
func (c *TimerPacketConn) Close() error {

View file

@ -2,6 +2,7 @@ package canceler
import (
"context"
"net"
"time"
"github.com/sagernet/sing/common"
@ -31,7 +32,7 @@ func (c *TimeoutPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksa
for {
err = c.PacketConn.SetReadDeadline(time.Now().Add(c.timeout))
if err != nil {
return M.Socksaddr{}, err
return
}
destination, err = c.PacketConn.ReadPacket(buffer)
if err == nil {
@ -43,7 +44,7 @@ func (c *TimeoutPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksa
return
}
} else {
return M.Socksaddr{}, err
return
}
}
}
@ -60,12 +61,13 @@ func (c *TimeoutPacketConn) Timeout() time.Duration {
return c.timeout
}
func (c *TimeoutPacketConn) SetTimeout(timeout time.Duration) {
func (c *TimeoutPacketConn) SetTimeout(timeout time.Duration) bool {
c.timeout = timeout
c.PacketConn.SetReadDeadline(time.Now())
return c.PacketConn.SetReadDeadline(time.Now()) == nil
}
func (c *TimeoutPacketConn) Close() error {
c.cancel(net.ErrClosed)
return c.PacketConn.Close()
}

11
common/clear.go Normal file
View file

@ -0,0 +1,11 @@
//go:build go1.21
package common
func ClearArray[T ~[]E, E any](t T) {
clear(t)
}
func ClearMap[T ~map[K]V, K comparable, V any](t T) {
clear(t)
}

16
common/clear_compat.go Normal file
View file

@ -0,0 +1,16 @@
//go:build !go1.21
package common
func ClearArray[T ~[]E, E any](t T) {
var defaultValue E
for i := range t {
t[i] = defaultValue
}
}
func ClearMap[T ~map[K]V, K comparable, V any](t T) {
for k := range t {
delete(t, k)
}
}

View file

@ -157,6 +157,18 @@ func IndexIndexed[T any](arr []T, block func(index int, it T) bool) int {
return -1
}
func Equal[S ~[]E, E comparable](s1, s2 S) bool {
if len(s1) != len(s2) {
return false
}
for i := range s1 {
if s1[i] != s2[i] {
return false
}
}
return true
}
//go:norace
func Dup[T any](obj T) T {
pointer := uintptr(unsafe.Pointer(&obj))
@ -268,6 +280,14 @@ func Reverse[T any](arr []T) []T {
return arr
}
func ReverseMap[K comparable, V comparable](m map[K]V) map[V]K {
ret := make(map[V]K, len(m))
for k, v := range m {
ret[v] = k
}
return ret
}
func Done(ctx context.Context) bool {
select {
case <-ctx.Done():
@ -336,6 +356,10 @@ func DefaultValue[T any]() T {
return defaultValue
}
func Ptr[T any](obj T) *T {
return &obj
}
func Close(closers ...any) error {
var retErr error
for _, closer := range closers {
@ -358,22 +382,3 @@ func Close(closers ...any) error {
}
return retErr
}
type Starter interface {
Start() error
}
func Start(starters ...any) error {
for _, rawStarter := range starters {
if rawStarter == nil {
continue
}
if starter, isStarter := rawStarter.(Starter); isStarter {
err := starter.Start()
if err != nil {
return err
}
}
}
return nil
}

View file

@ -5,6 +5,7 @@ import (
"reflect"
)
// Deprecated: not used
func SelectContext(contextList []context.Context) (int, error) {
if len(contextList) == 1 {
<-contextList[0].Done()

View file

@ -10,7 +10,7 @@ import (
func BindToInterface(finder InterfaceFinder, interfaceName string, interfaceIndex int) Func {
return func(network, address string, conn syscall.RawConn) error {
return BindToInterface0(finder, conn, network, address, interfaceName, interfaceIndex)
return BindToInterface0(finder, conn, network, address, interfaceName, interfaceIndex, false)
}
}
@ -20,16 +20,16 @@ func BindToInterfaceFunc(finder InterfaceFinder, block func(network string, addr
if err != nil {
return err
}
return BindToInterface0(finder, conn, network, address, interfaceName, interfaceIndex)
return BindToInterface0(finder, conn, network, address, interfaceName, interfaceIndex, false)
}
}
func BindToInterface0(finder InterfaceFinder, conn syscall.RawConn, network string, address string, interfaceName string, interfaceIndex int) error {
func BindToInterface0(finder InterfaceFinder, conn syscall.RawConn, network string, address string, interfaceName string, interfaceIndex int, preferInterfaceName bool) error {
if interfaceName == "" && interfaceIndex == -1 {
return E.New("interface not found: ", interfaceName)
}
if addr := M.ParseSocksaddr(address).Addr; addr.IsValid() && N.IsVirtual(addr) {
return nil
}
return bindToInterface(conn, network, address, finder, interfaceName, interfaceIndex)
return bindToInterface(conn, network, address, finder, interfaceName, interfaceIndex, preferInterfaceName)
}

View file

@ -7,17 +7,17 @@ import (
"golang.org/x/sys/unix"
)
func bindToInterface(conn syscall.RawConn, network string, address string, finder InterfaceFinder, interfaceName string, interfaceIndex int) error {
func bindToInterface(conn syscall.RawConn, network string, address string, finder InterfaceFinder, interfaceName string, interfaceIndex int, preferInterfaceName bool) error {
return Raw(conn, func(fd uintptr) error {
var err error
if interfaceIndex == -1 {
if finder == nil {
return os.ErrInvalid
}
interfaceIndex, err = finder.InterfaceIndexByName(interfaceName)
iif, err := finder.ByName(interfaceName)
if err != nil {
return err
}
interfaceIndex = iif.Index
}
switch network {
case "tcp6", "udp6":

View file

@ -1,30 +1,59 @@
package control
import "net"
import (
"net"
"net/netip"
"unsafe"
"github.com/sagernet/sing/common"
M "github.com/sagernet/sing/common/metadata"
)
type InterfaceFinder interface {
InterfaceIndexByName(name string) (int, error)
InterfaceNameByIndex(index int) (string, error)
Update() error
Interfaces() []Interface
ByName(name string) (*Interface, error)
ByIndex(index int) (*Interface, error)
ByAddr(addr netip.Addr) (*Interface, error)
}
func DefaultInterfaceFinder() InterfaceFinder {
return (*netInterfaceFinder)(nil)
type Interface struct {
Index int
MTU int
Name string
HardwareAddr net.HardwareAddr
Flags net.Flags
Addresses []netip.Prefix
}
type netInterfaceFinder struct{}
func (i Interface) Equals(other Interface) bool {
return i.Index == other.Index &&
i.MTU == other.MTU &&
i.Name == other.Name &&
common.Equal(i.HardwareAddr, other.HardwareAddr) &&
i.Flags == other.Flags &&
common.Equal(i.Addresses, other.Addresses)
}
func (w *netInterfaceFinder) InterfaceIndexByName(name string) (int, error) {
netInterface, err := net.InterfaceByName(name)
func (i Interface) NetInterface() net.Interface {
return *(*net.Interface)(unsafe.Pointer(&i))
}
func InterfaceFromNet(iif net.Interface) (Interface, error) {
ifAddrs, err := iif.Addrs()
if err != nil {
return 0, err
return Interface{}, err
}
return netInterface.Index, nil
return InterfaceFromNetAddrs(iif, common.Map(ifAddrs, M.PrefixFromNet)), nil
}
func (w *netInterfaceFinder) InterfaceNameByIndex(index int) (string, error) {
netInterface, err := net.InterfaceByIndex(index)
if err != nil {
return "", err
func InterfaceFromNetAddrs(iif net.Interface, addresses []netip.Prefix) Interface {
return Interface{
Index: iif.Index,
MTU: iif.MTU,
Name: iif.Name,
HardwareAddr: iif.HardwareAddr,
Flags: iif.Flags,
Addresses: addresses,
}
return netInterface.Name, nil
}

View file

@ -0,0 +1,89 @@
package control
import (
"net"
"net/netip"
E "github.com/sagernet/sing/common/exceptions"
)
var _ InterfaceFinder = (*DefaultInterfaceFinder)(nil)
type DefaultInterfaceFinder struct {
interfaces []Interface
}
func NewDefaultInterfaceFinder() *DefaultInterfaceFinder {
return &DefaultInterfaceFinder{}
}
func (f *DefaultInterfaceFinder) Update() error {
netIfs, err := net.Interfaces()
if err != nil {
return err
}
interfaces := make([]Interface, 0, len(netIfs))
for _, netIf := range netIfs {
var iif Interface
iif, err = InterfaceFromNet(netIf)
if err != nil {
return err
}
interfaces = append(interfaces, iif)
}
f.interfaces = interfaces
return nil
}
func (f *DefaultInterfaceFinder) UpdateInterfaces(interfaces []Interface) {
f.interfaces = interfaces
}
func (f *DefaultInterfaceFinder) Interfaces() []Interface {
return f.interfaces
}
func (f *DefaultInterfaceFinder) ByName(name string) (*Interface, error) {
for _, netInterface := range f.interfaces {
if netInterface.Name == name {
return &netInterface, nil
}
}
_, err := net.InterfaceByName(name)
if err == nil {
err = f.Update()
if err != nil {
return nil, err
}
return f.ByName(name)
}
return nil, &net.OpError{Op: "route", Net: "ip+net", Source: nil, Addr: &net.IPAddr{IP: nil}, Err: E.New("no such network interface")}
}
func (f *DefaultInterfaceFinder) ByIndex(index int) (*Interface, error) {
for _, netInterface := range f.interfaces {
if netInterface.Index == index {
return &netInterface, nil
}
}
_, err := net.InterfaceByIndex(index)
if err == nil {
err = f.Update()
if err != nil {
return nil, err
}
return f.ByIndex(index)
}
return nil, &net.OpError{Op: "route", Net: "ip+net", Source: nil, Addr: &net.IPAddr{IP: nil}, Err: E.New("no such network interface")}
}
func (f *DefaultInterfaceFinder) ByAddr(addr netip.Addr) (*Interface, error) {
for _, netInterface := range f.interfaces {
for _, prefix := range netInterface.Addresses {
if prefix.Contains(addr) {
return &netInterface, nil
}
}
}
return nil, &net.OpError{Op: "route", Net: "ip+net", Source: nil, Addr: &net.IPAddr{IP: addr.AsSlice()}, Err: E.New("no such network interface")}
}

View file

@ -12,20 +12,20 @@ import (
var ifIndexDisabled atomic.Bool
func bindToInterface(conn syscall.RawConn, network string, address string, finder InterfaceFinder, interfaceName string, interfaceIndex int) error {
func bindToInterface(conn syscall.RawConn, network string, address string, finder InterfaceFinder, interfaceName string, interfaceIndex int, preferInterfaceName bool) error {
return Raw(conn, func(fd uintptr) error {
var err error
if !ifIndexDisabled.Load() {
if !preferInterfaceName && !ifIndexDisabled.Load() {
if interfaceIndex == -1 {
if finder == nil {
if interfaceName == "" {
return os.ErrInvalid
}
interfaceIndex, err = finder.InterfaceIndexByName(interfaceName)
iif, err := finder.ByName(interfaceName)
if err != nil {
return err
}
interfaceIndex = iif.Index
}
err = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_BINDTOIFINDEX, interfaceIndex)
err := unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_BINDTOIFINDEX, interfaceIndex)
if err == nil {
return nil
} else if E.IsMulti(err, unix.ENOPROTOOPT, unix.EINVAL) {
@ -35,13 +35,7 @@ func bindToInterface(conn syscall.RawConn, network string, address string, finde
}
}
if interfaceName == "" {
if finder == nil {
return os.ErrInvalid
}
interfaceName, err = finder.InterfaceNameByIndex(interfaceIndex)
if err != nil {
return err
}
return os.ErrInvalid
}
return unix.BindToDevice(int(fd), interfaceName)
})

View file

@ -4,6 +4,6 @@ package control
import "syscall"
func bindToInterface(conn syscall.RawConn, network string, address string, finder InterfaceFinder, interfaceName string, interfaceIndex int) error {
func bindToInterface(conn syscall.RawConn, network string, address string, finder InterfaceFinder, interfaceName string, interfaceIndex int, preferInterfaceName bool) error {
return nil
}

View file

@ -9,21 +9,21 @@ import (
M "github.com/sagernet/sing/common/metadata"
)
func bindToInterface(conn syscall.RawConn, network string, address string, finder InterfaceFinder, interfaceName string, interfaceIndex int) error {
func bindToInterface(conn syscall.RawConn, network string, address string, finder InterfaceFinder, interfaceName string, interfaceIndex int, preferInterfaceName bool) error {
return Raw(conn, func(fd uintptr) error {
var err error
if interfaceIndex == -1 {
if finder == nil {
return os.ErrInvalid
}
interfaceIndex, err = finder.InterfaceIndexByName(interfaceName)
iif, err := finder.ByName(interfaceName)
if err != nil {
return err
}
interfaceIndex = iif.Index
}
handle := syscall.Handle(fd)
if M.ParseSocksaddr(address).AddrString() == "" {
err = bind4(handle, interfaceIndex)
err := bind4(handle, interfaceIndex)
if err != nil {
return err
}

View file

@ -4,19 +4,26 @@ import (
"os"
"syscall"
N "github.com/sagernet/sing/common/network"
"golang.org/x/sys/unix"
)
func DisableUDPFragment() Func {
return func(network, address string, conn syscall.RawConn) error {
if N.NetworkName(network) != N.NetworkUDP {
return nil
}
return Raw(conn, func(fd uintptr) error {
switch network {
case "udp4":
if err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_DONTFRAG, 1); err != nil {
if network == "udp" || network == "udp4" {
err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_DONTFRAG, 1)
if err != nil {
return os.NewSyscallError("SETSOCKOPT IP_DONTFRAG", err)
}
case "udp6":
if err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_DONTFRAG, 1); err != nil {
}
if network == "udp" || network == "udp6" {
err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_DONTFRAG, 1)
if err != nil {
return os.NewSyscallError("SETSOCKOPT IPV6_DONTFRAG", err)
}
}

View file

@ -11,17 +11,19 @@ import (
func DisableUDPFragment() Func {
return func(network, address string, conn syscall.RawConn) error {
switch N.NetworkName(network) {
case N.NetworkUDP:
default:
if N.NetworkName(network) != N.NetworkUDP {
return nil
}
return Raw(conn, func(fd uintptr) error {
if err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_MTU_DISCOVER, unix.IP_PMTUDISC_DO); err != nil {
return os.NewSyscallError("SETSOCKOPT IP_MTU_DISCOVER IP_PMTUDISC_DO", err)
if network == "udp" || network == "udp4" {
err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_MTU_DISCOVER, unix.IP_PMTUDISC_DO)
if err != nil {
return os.NewSyscallError("SETSOCKOPT IP_MTU_DISCOVER IP_PMTUDISC_DO", err)
}
}
if network == "udp6" {
if err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_MTU_DISCOVER, unix.IP_PMTUDISC_DO); err != nil {
if network == "udp" || network == "udp6" {
err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_MTU_DISCOVER, unix.IP_PMTUDISC_DO)
if err != nil {
return os.NewSyscallError("SETSOCKOPT IPV6_MTU_DISCOVER IP_PMTUDISC_DO", err)
}
}

View file

@ -25,17 +25,19 @@ const (
func DisableUDPFragment() Func {
return func(network, address string, conn syscall.RawConn) error {
switch N.NetworkName(network) {
case N.NetworkUDP:
default:
if N.NetworkName(network) != N.NetworkUDP {
return nil
}
return Raw(conn, func(fd uintptr) error {
if err := windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IP, IP_MTU_DISCOVER, IP_PMTUDISC_DO); err != nil {
return os.NewSyscallError("SETSOCKOPT IP_MTU_DISCOVER IP_PMTUDISC_DO", err)
if network == "udp" || network == "udp4" {
err := windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IP, IP_MTU_DISCOVER, IP_PMTUDISC_DO)
if err != nil {
return os.NewSyscallError("SETSOCKOPT IP_MTU_DISCOVER IP_PMTUDISC_DO", err)
}
}
if network == "udp6" {
if err := windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IPV6, IPV6_MTU_DISCOVER, IP_PMTUDISC_DO); err != nil {
if network == "udp" || network == "udp6" {
err := windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IPV6, IPV6_MTU_DISCOVER, IP_PMTUDISC_DO)
if err != nil {
return os.NewSyscallError("SETSOCKOPT IPV6_MTU_DISCOVER IP_PMTUDISC_DO", err)
}
}

View file

@ -3,6 +3,7 @@ package control
import (
"syscall"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
)
@ -30,6 +31,14 @@ func Conn(conn syscall.Conn, block func(fd uintptr) error) error {
return Raw(rawConn, block)
}
func Conn0[T any](conn syscall.Conn, block func(fd uintptr) (T, error)) (T, error) {
rawConn, err := conn.SyscallConn()
if err != nil {
return common.DefaultValue[T](), err
}
return Raw0[T](rawConn, block)
}
func Raw(rawConn syscall.RawConn, block func(fd uintptr) error) error {
var innerErr error
err := rawConn.Control(func(fd uintptr) {
@ -37,3 +46,14 @@ func Raw(rawConn syscall.RawConn, block func(fd uintptr) error) error {
})
return E.Errors(innerErr, err)
}
func Raw0[T any](rawConn syscall.RawConn, block func(fd uintptr) (T, error)) (T, error) {
var (
value T
innerErr error
)
err := rawConn.Control(func(fd uintptr) {
value, innerErr = block(fd)
})
return value, E.Errors(innerErr, err)
}

View file

@ -4,10 +4,10 @@ import (
"syscall"
)
func RoutingMark(mark int) Func {
func RoutingMark(mark uint32) Func {
return func(network, address string, conn syscall.RawConn) error {
return Raw(conn, func(fd uintptr) error {
return syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_MARK, mark)
return syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_MARK, int(mark))
})
}
}

View file

@ -2,6 +2,6 @@
package control
func RoutingMark(mark int) Func {
func RoutingMark(mark uint32) Func {
return nil
}

View file

@ -0,0 +1,58 @@
package control
import (
"encoding/binary"
"net"
"net/netip"
"syscall"
"unsafe"
M "github.com/sagernet/sing/common/metadata"
"golang.org/x/sys/unix"
)
const (
PF_OUT = 0x2
DIOCNATLOOK = 0xc0544417
)
func GetOriginalDestination(conn net.Conn) (netip.AddrPort, error) {
pfFd, err := syscall.Open("/dev/pf", 0, syscall.O_RDONLY)
if err != nil {
return netip.AddrPort{}, err
}
defer syscall.Close(pfFd)
nl := struct {
saddr, daddr, rsaddr, rdaddr [16]byte
sxport, dxport, rsxport, rdxport [4]byte
af, proto, protoVariant, direction uint8
}{
af: syscall.AF_INET,
proto: syscall.IPPROTO_TCP,
direction: PF_OUT,
}
localAddr := M.SocksaddrFromNet(conn.LocalAddr()).Unwrap()
removeAddr := M.SocksaddrFromNet(conn.RemoteAddr()).Unwrap()
if localAddr.IsIPv4() {
copy(nl.saddr[:net.IPv4len], removeAddr.Addr.AsSlice())
copy(nl.daddr[:net.IPv4len], localAddr.Addr.AsSlice())
nl.af = syscall.AF_INET
} else {
copy(nl.saddr[:], removeAddr.Addr.AsSlice())
copy(nl.daddr[:], localAddr.Addr.AsSlice())
nl.af = syscall.AF_INET6
}
binary.BigEndian.PutUint16(nl.sxport[:], removeAddr.Port)
binary.BigEndian.PutUint16(nl.dxport[:], localAddr.Port)
if _, _, errno := unix.Syscall(syscall.SYS_IOCTL, uintptr(pfFd), DIOCNATLOOK, uintptr(unsafe.Pointer(&nl))); errno != 0 {
return netip.AddrPort{}, errno
}
var address netip.Addr
if nl.af == unix.AF_INET {
address = M.AddrFromIP(nl.rdaddr[:net.IPv4len])
} else {
address = netip.AddrFrom16(nl.rdaddr)
}
return netip.AddrPortFrom(address, binary.BigEndian.Uint16(nl.rdxport[:])), nil
}

View file

@ -0,0 +1,38 @@
package control
import (
"encoding/binary"
"net"
"net/netip"
"os"
"syscall"
"github.com/sagernet/sing/common"
M "github.com/sagernet/sing/common/metadata"
"golang.org/x/sys/unix"
)
func GetOriginalDestination(conn net.Conn) (netip.AddrPort, error) {
syscallConn, loaded := common.Cast[syscall.Conn](conn)
if !loaded {
return netip.AddrPort{}, os.ErrInvalid
}
return Conn0[netip.AddrPort](syscallConn, func(fd uintptr) (netip.AddrPort, error) {
if M.SocksaddrFromNet(conn.RemoteAddr()).Unwrap().IsIPv4() {
raw, err := unix.GetsockoptIPv6Mreq(int(fd), unix.IPPROTO_IP, unix.SO_ORIGINAL_DST)
if err != nil {
return netip.AddrPort{}, err
}
return netip.AddrPortFrom(M.AddrFromIP(raw.Multiaddr[4:8]), uint16(raw.Multiaddr[2])<<8+uint16(raw.Multiaddr[3])), nil
} else {
raw, err := unix.GetsockoptIPv6MTUInfo(int(fd), unix.IPPROTO_IPV6, unix.SO_ORIGINAL_DST)
if err != nil {
return netip.AddrPort{}, err
}
var port [2]byte
binary.BigEndian.PutUint16(port[:], raw.Addr.Port)
return netip.AddrPortFrom(M.AddrFromIP(raw.Addr.Addr[:]), binary.LittleEndian.Uint16(port[:])), nil
}
})
}

View file

@ -0,0 +1,13 @@
//go:build !linux && !darwin
package control
import (
"net"
"net/netip"
"os"
)
func GetOriginalDestination(conn net.Conn) (netip.AddrPort, error) {
return netip.AddrPort{}, os.ErrInvalid
}

View file

@ -0,0 +1,30 @@
package control
import (
"syscall"
"time"
_ "unsafe"
E "github.com/sagernet/sing/common/exceptions"
N "github.com/sagernet/sing/common/network"
"golang.org/x/sys/unix"
)
func SetKeepAlivePeriod(idle time.Duration, interval time.Duration) Func {
return func(network, address string, conn syscall.RawConn) error {
if N.NetworkName(network) != N.NetworkTCP {
return nil
}
return Raw(conn, func(fd uintptr) error {
return E.Errors(
unix.SetsockoptInt(int(fd), unix.IPPROTO_TCP, unix.TCP_KEEPIDLE, int(roundDurationUp(idle, time.Second))),
unix.SetsockoptInt(int(fd), unix.IPPROTO_TCP, unix.TCP_KEEPINTVL, int(roundDurationUp(interval, time.Second))),
)
})
}
}
func roundDurationUp(d time.Duration, to time.Duration) time.Duration {
return (d + to - 1) / to
}

View file

@ -0,0 +1,11 @@
//go:build !linux
package control
import (
"time"
)
func SetKeepAlivePeriod(idle time.Duration, interval time.Duration) Func {
return nil
}

View file

@ -0,0 +1,56 @@
package control
import (
"encoding/binary"
"net/netip"
"syscall"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
"golang.org/x/sys/unix"
)
func TProxy(fd uintptr, family int) error {
err := syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1)
if err == nil {
err = syscall.SetsockoptInt(int(fd), syscall.SOL_IP, syscall.IP_TRANSPARENT, 1)
}
if err == nil && family == unix.AF_INET6 {
err = syscall.SetsockoptInt(int(fd), syscall.SOL_IPV6, unix.IPV6_TRANSPARENT, 1)
}
if err == nil {
err = syscall.SetsockoptInt(int(fd), syscall.SOL_IP, syscall.IP_RECVORIGDSTADDR, 1)
}
if err == nil && family == unix.AF_INET6 {
err = syscall.SetsockoptInt(int(fd), syscall.SOL_IPV6, unix.IPV6_RECVORIGDSTADDR, 1)
}
return err
}
func TProxyWriteBack() Func {
return func(network, address string, conn syscall.RawConn) error {
return Raw(conn, func(fd uintptr) error {
if M.ParseSocksaddr(address).Addr.Is6() {
return syscall.SetsockoptInt(int(fd), syscall.SOL_IPV6, unix.IPV6_TRANSPARENT, 1)
} else {
return syscall.SetsockoptInt(int(fd), syscall.SOL_IP, syscall.IP_TRANSPARENT, 1)
}
})
}
}
func GetOriginalDestinationFromOOB(oob []byte) (netip.AddrPort, error) {
controlMessages, err := unix.ParseSocketControlMessage(oob)
if err != nil {
return netip.AddrPort{}, err
}
for _, message := range controlMessages {
if message.Header.Level == unix.SOL_IP && message.Header.Type == unix.IP_RECVORIGDSTADDR {
return netip.AddrPortFrom(M.AddrFromIP(message.Data[4:8]), binary.BigEndian.Uint16(message.Data[2:4])), nil
} else if message.Header.Level == unix.SOL_IPV6 && message.Header.Type == unix.IPV6_RECVORIGDSTADDR {
return netip.AddrPortFrom(M.AddrFromIP(message.Data[8:24]), binary.BigEndian.Uint16(message.Data[2:4])), nil
}
}
return netip.AddrPort{}, E.New("not found")
}

View file

@ -0,0 +1,20 @@
//go:build !linux
package control
import (
"net/netip"
"os"
)
func TProxy(fd uintptr, isIPv6 bool) error {
return os.ErrInvalid
}
func TProxyWriteBack() Func {
return nil
}
func GetOriginalDestinationFromOOB(oob []byte) (netip.AddrPort, error) {
return netip.AddrPort{}, os.ErrInvalid
}

View file

@ -0,0 +1,67 @@
package domain_test
import (
"sort"
"testing"
"github.com/sagernet/sing/common/domain"
"github.com/stretchr/testify/require"
)
func TestAdGuardMatcher(t *testing.T) {
t.Parallel()
ruleLines := []string{
"||example.org^",
"|example.com^",
"example.net^",
"||example.edu",
"||example.edu.tw^",
"|example.gov",
"example.arpa",
}
matcher := domain.NewAdGuardMatcher(ruleLines)
require.NotNil(t, matcher)
matchDomain := []string{
"example.org",
"www.example.org",
"example.com",
"example.net",
"isexample.net",
"www.example.net",
"example.edu",
"example.edu.cn",
"example.edu.tw",
"www.example.edu",
"www.example.edu.cn",
"example.gov",
"example.gov.cn",
"example.arpa",
"www.example.arpa",
"isexample.arpa",
"example.arpa.cn",
"www.example.arpa.cn",
"isexample.arpa.cn",
}
notMatchDomain := []string{
"example.org.cn",
"notexample.org",
"example.com.cn",
"www.example.com.cn",
"example.net.cn",
"notexample.edu",
"notexample.edu.cn",
"www.example.gov",
"notexample.gov",
}
for _, domain := range matchDomain {
require.True(t, matcher.Match(domain), domain)
}
for _, domain := range notMatchDomain {
require.False(t, matcher.Match(domain), domain)
}
dLines := matcher.Dump()
sort.Strings(ruleLines)
sort.Strings(dLines)
require.Equal(t, ruleLines, dLines)
}

View file

@ -0,0 +1,172 @@
package domain
import (
"bytes"
"sort"
"strings"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/varbin"
)
const (
anyLabel = '*'
suffixLabel = '\b'
)
type AdGuardMatcher struct {
set *succinctSet
}
func NewAdGuardMatcher(ruleLines []string) *AdGuardMatcher {
ruleList := make([]string, 0, len(ruleLines))
for _, ruleLine := range ruleLines {
var (
isSuffix bool // ||
hasStart bool // |
hasEnd bool // ^
)
if strings.HasPrefix(ruleLine, "||") {
ruleLine = ruleLine[2:]
isSuffix = true
} else if strings.HasPrefix(ruleLine, "|") {
ruleLine = ruleLine[1:]
hasStart = true
}
if strings.HasSuffix(ruleLine, "^") {
ruleLine = ruleLine[:len(ruleLine)-1]
hasEnd = true
}
if isSuffix {
ruleLine = string(rootLabel) + ruleLine
} else if !hasStart {
ruleLine = string(prefixLabel) + ruleLine
}
if !hasEnd {
if strings.HasSuffix(ruleLine, ".") {
ruleLine = ruleLine[:len(ruleLine)-1]
}
ruleLine += string(suffixLabel)
}
ruleList = append(ruleList, reverseDomain(ruleLine))
}
ruleList = common.Uniq(ruleList)
sort.Strings(ruleList)
return &AdGuardMatcher{newSuccinctSet(ruleList)}
}
func ReadAdGuardMatcher(reader varbin.Reader) (*AdGuardMatcher, error) {
set, err := readSuccinctSet(reader)
if err != nil {
return nil, err
}
return &AdGuardMatcher{set}, nil
}
func (m *AdGuardMatcher) Write(writer varbin.Writer) error {
return m.set.Write(writer)
}
func (m *AdGuardMatcher) Match(domain string) bool {
key := reverseDomain(domain)
if m.has([]byte(key), 0, 0) {
return true
}
for {
if m.has([]byte(string(suffixLabel)+key), 0, 0) {
return true
}
idx := strings.IndexByte(key, '.')
if idx == -1 {
return false
}
key = key[idx+1:]
}
}
func (m *AdGuardMatcher) has(key []byte, nodeId, bmIdx int) bool {
for i := 0; i < len(key); i++ {
currentChar := key[i]
for ; ; bmIdx++ {
if getBit(m.set.labelBitmap, bmIdx) != 0 {
return false
}
nextLabel := m.set.labels[bmIdx-nodeId]
if nextLabel == prefixLabel {
return true
}
if nextLabel == rootLabel {
nextNodeId := countZeros(m.set.labelBitmap, m.set.ranks, bmIdx+1)
hasNext := getBit(m.set.leaves, nextNodeId) != 0
if currentChar == '.' && hasNext {
return true
}
}
if nextLabel == currentChar {
break
}
if nextLabel == anyLabel {
idx := bytes.IndexRune(key[i:], '.')
nextNodeId := countZeros(m.set.labelBitmap, m.set.ranks, bmIdx+1)
if idx == -1 {
if getBit(m.set.leaves, nextNodeId) != 0 {
return true
}
idx = 0
}
nextBmIdx := selectIthOne(m.set.labelBitmap, m.set.ranks, m.set.selects, nextNodeId-1) + 1
if m.has(key[i+idx:], nextNodeId, nextBmIdx) {
return true
}
}
}
nodeId = countZeros(m.set.labelBitmap, m.set.ranks, bmIdx+1)
bmIdx = selectIthOne(m.set.labelBitmap, m.set.ranks, m.set.selects, nodeId-1) + 1
}
if getBit(m.set.leaves, nodeId) != 0 {
return true
}
for ; ; bmIdx++ {
if getBit(m.set.labelBitmap, bmIdx) != 0 {
return false
}
nextLabel := m.set.labels[bmIdx-nodeId]
if nextLabel == prefixLabel || nextLabel == rootLabel {
return true
}
}
}
func (m *AdGuardMatcher) Dump() (ruleLines []string) {
for _, key := range m.set.keys() {
key = reverseDomain(key)
var (
isSuffix bool
hasStart bool
hasEnd bool
)
if key[0] == prefixLabel {
key = key[1:]
} else if key[0] == rootLabel {
key = key[1:]
isSuffix = true
} else {
hasStart = true
}
if key[len(key)-1] == suffixLabel {
key = key[:len(key)-1]
} else {
hasEnd = true
}
if isSuffix {
key = "||" + key
} else if hasStart {
key = "|" + key
}
if hasEnd {
key += "^"
}
ruleLines = append(ruleLines, key)
}
return
}

View file

@ -3,21 +3,39 @@ package domain
import (
"sort"
"unicode/utf8"
"github.com/sagernet/sing/common/varbin"
)
const (
prefixLabel = '\r'
rootLabel = '\n'
)
type Matcher struct {
set *succinctSet
}
func NewMatcher(domains []string, domainSuffix []string) *Matcher {
domainList := make([]string, 0, len(domains)+len(domainSuffix))
func NewMatcher(domains []string, domainSuffix []string, generateLegacy bool) *Matcher {
domainList := make([]string, 0, len(domains)+2*len(domainSuffix))
seen := make(map[string]bool, len(domainList))
for _, domain := range domainSuffix {
if seen[domain] {
continue
}
seen[domain] = true
domainList = append(domainList, reverseDomainSuffix(domain))
if domain[0] == '.' {
domainList = append(domainList, reverseDomain(string(prefixLabel)+domain))
} else if generateLegacy {
domainList = append(domainList, reverseDomain(domain))
suffixDomain := "." + domain
if !seen[suffixDomain] {
seen[suffixDomain] = true
domainList = append(domainList, reverseDomain(string(prefixLabel)+suffixDomain))
}
} else {
domainList = append(domainList, reverseDomain(string(rootLabel)+domain))
}
}
for _, domain := range domains {
if seen[domain] {
@ -27,13 +45,94 @@ func NewMatcher(domains []string, domainSuffix []string) *Matcher {
domainList = append(domainList, reverseDomain(domain))
}
sort.Strings(domainList)
return &Matcher{
newSuccinctSet(domainList),
return &Matcher{newSuccinctSet(domainList)}
}
func ReadMatcher(reader varbin.Reader) (*Matcher, error) {
set, err := readSuccinctSet(reader)
if err != nil {
return nil, err
}
return &Matcher{set}, nil
}
func (m *Matcher) Write(writer varbin.Writer) error {
return m.set.Write(writer)
}
func (m *Matcher) Match(domain string) bool {
return m.set.Has(reverseDomain(domain))
return m.has(reverseDomain(domain))
}
func (m *Matcher) has(key string) bool {
var nodeId, bmIdx int
for i := 0; i < len(key); i++ {
currentChar := key[i]
for ; ; bmIdx++ {
if getBit(m.set.labelBitmap, bmIdx) != 0 {
return false
}
nextLabel := m.set.labels[bmIdx-nodeId]
if nextLabel == prefixLabel {
return true
}
if nextLabel == rootLabel {
nextNodeId := countZeros(m.set.labelBitmap, m.set.ranks, bmIdx+1)
hasNext := getBit(m.set.leaves, nextNodeId) != 0
if currentChar == '.' && hasNext {
return true
}
}
if nextLabel == currentChar {
break
}
}
nodeId = countZeros(m.set.labelBitmap, m.set.ranks, bmIdx+1)
bmIdx = selectIthOne(m.set.labelBitmap, m.set.ranks, m.set.selects, nodeId-1) + 1
}
if getBit(m.set.leaves, nodeId) != 0 {
return true
}
for ; ; bmIdx++ {
if getBit(m.set.labelBitmap, bmIdx) != 0 {
return false
}
nextLabel := m.set.labels[bmIdx-nodeId]
if nextLabel == prefixLabel || nextLabel == rootLabel {
return true
}
}
}
func (m *Matcher) Dump() (domainList []string, prefixList []string) {
domainMap := make(map[string]bool)
prefixMap := make(map[string]bool)
for _, key := range m.set.keys() {
key = reverseDomain(key)
if key[0] == prefixLabel {
prefixMap[key[1:]] = true
} else if key[0] == rootLabel {
prefixList = append(prefixList, key[1:])
} else {
domainMap[key] = true
}
}
for rawPrefix := range prefixMap {
if rawPrefix[0] == '.' {
if rootDomain := rawPrefix[1:]; domainMap[rootDomain] {
delete(domainMap, rootDomain)
prefixList = append(prefixList, rootDomain)
continue
}
}
prefixList = append(prefixList, rawPrefix)
}
for domain := range domainMap {
domainList = append(domainList, domain)
}
sort.Strings(domainList)
sort.Strings(prefixList)
return domainList, prefixList
}
func reverseDomain(domain string) string {
@ -46,15 +145,3 @@ func reverseDomain(domain string) string {
}
return string(b)
}
func reverseDomainSuffix(domain string) string {
l := len(domain)
b := make([]byte, l+1)
for i := 0; i < l; {
r, n := utf8.DecodeRuneInString(domain[i:])
i += n
utf8.EncodeRune(b[l-i:], r)
}
b[l] = prefixLabel
return string(b)
}

View file

@ -0,0 +1,80 @@
package domain_test
import (
"encoding/json"
"net/http"
"sort"
"testing"
"github.com/sagernet/sing/common/domain"
"github.com/stretchr/testify/require"
)
func TestMatcher(t *testing.T) {
t.Parallel()
testDomain := []string{"example.com", "example.org"}
testDomainSuffix := []string{".com.cn", ".org.cn", "sagernet.org"}
matcher := domain.NewMatcher(testDomain, testDomainSuffix, false)
require.NotNil(t, matcher)
require.True(t, matcher.Match("example.com"))
require.True(t, matcher.Match("example.org"))
require.False(t, matcher.Match("example.cn"))
require.True(t, matcher.Match("example.com.cn"))
require.True(t, matcher.Match("example.org.cn"))
require.False(t, matcher.Match("com.cn"))
require.False(t, matcher.Match("org.cn"))
require.True(t, matcher.Match("sagernet.org"))
require.True(t, matcher.Match("sing-box.sagernet.org"))
dDomain, dDomainSuffix := matcher.Dump()
require.Equal(t, testDomain, dDomain)
require.Equal(t, testDomainSuffix, dDomainSuffix)
}
func TestMatcherLegacy(t *testing.T) {
t.Parallel()
testDomain := []string{"example.com", "example.org"}
testDomainSuffix := []string{".com.cn", ".org.cn", "sagernet.org"}
matcher := domain.NewMatcher(testDomain, testDomainSuffix, true)
require.NotNil(t, matcher)
require.True(t, matcher.Match("example.com"))
require.True(t, matcher.Match("example.org"))
require.False(t, matcher.Match("example.cn"))
require.True(t, matcher.Match("example.com.cn"))
require.True(t, matcher.Match("example.org.cn"))
require.False(t, matcher.Match("com.cn"))
require.False(t, matcher.Match("org.cn"))
require.True(t, matcher.Match("sagernet.org"))
require.True(t, matcher.Match("sing-box.sagernet.org"))
dDomain, dDomainSuffix := matcher.Dump()
require.Equal(t, testDomain, dDomain)
require.Equal(t, testDomainSuffix, dDomainSuffix)
}
type simpleRuleSet struct {
Rules []struct {
Domain []string `json:"domain"`
DomainSuffix []string `json:"domain_suffix"`
}
}
func TestDumpLarge(t *testing.T) {
t.Parallel()
response, err := http.Get("https://raw.githubusercontent.com/MetaCubeX/meta-rules-dat/sing/geo/geosite/cn.json")
require.NoError(t, err)
defer response.Body.Close()
var ruleSet simpleRuleSet
err = json.NewDecoder(response.Body).Decode(&ruleSet)
require.NoError(t, err)
domainList := ruleSet.Rules[0].Domain
domainSuffixList := ruleSet.Rules[0].DomainSuffix
require.Len(t, ruleSet.Rules, 1)
require.True(t, len(domainList)+len(domainSuffixList) > 0)
sort.Strings(domainList)
sort.Strings(domainSuffixList)
matcher := domain.NewMatcher(domainList, domainSuffixList, false)
require.NotNil(t, matcher)
dDomain, dDomainSuffix := matcher.Dump()
require.Equal(t, domainList, dDomain)
require.Equal(t, domainSuffixList, dDomainSuffix)
}

View file

@ -1,10 +1,11 @@
package domain
import (
"encoding/binary"
"math/bits"
)
const prefixLabel = '\r'
"github.com/sagernet/sing/common/varbin"
)
// mod from https://github.com/openacid/succinct
@ -42,36 +43,61 @@ func newSuccinctSet(keys []string) *succinctSet {
return ss
}
func (ss *succinctSet) Has(key string) bool {
var nodeId, bmIdx int
for i := 0; i < len(key); i++ {
currentChar := key[i]
func (ss *succinctSet) keys() []string {
var result []string
var currentKey []byte
var bmIdx, nodeId int
var traverse func(int, int)
traverse = func(nodeId, bmIdx int) {
if getBit(ss.leaves, nodeId) != 0 {
result = append(result, string(currentKey))
}
for ; ; bmIdx++ {
if getBit(ss.labelBitmap, bmIdx) != 0 {
return false
return
}
nextLabel := ss.labels[bmIdx-nodeId]
if nextLabel == prefixLabel {
return true
}
if nextLabel == currentChar {
break
}
}
nodeId = countZeros(ss.labelBitmap, ss.ranks, bmIdx+1)
bmIdx = selectIthOne(ss.labelBitmap, ss.ranks, ss.selects, nodeId-1) + 1
}
if getBit(ss.leaves, nodeId) != 0 {
return true
}
for ; ; bmIdx++ {
if getBit(ss.labelBitmap, bmIdx) != 0 {
return false
}
if ss.labels[bmIdx-nodeId] == prefixLabel {
return true
currentKey = append(currentKey, nextLabel)
nextNodeId := countZeros(ss.labelBitmap, ss.ranks, bmIdx+1)
nextBmIdx := selectIthOne(ss.labelBitmap, ss.ranks, ss.selects, nextNodeId-1) + 1
traverse(nextNodeId, nextBmIdx)
currentKey = currentKey[:len(currentKey)-1]
}
}
traverse(nodeId, bmIdx)
return result
}
type succinctSetData struct {
Reserved uint8
Leaves []uint64
LabelBitmap []uint64
Labels []byte
}
func readSuccinctSet(reader varbin.Reader) (*succinctSet, error) {
matcher, err := varbin.ReadValue[succinctSetData](reader, binary.BigEndian)
if err != nil {
return nil, err
}
set := &succinctSet{
leaves: matcher.Leaves,
labelBitmap: matcher.LabelBitmap,
labels: matcher.Labels,
}
set.init()
return set, nil
}
func (ss *succinctSet) Write(writer varbin.Writer) error {
return varbin.Write(writer, binary.BigEndian, succinctSetData{
Leaves: ss.leaves,
LabelBitmap: ss.labelBitmap,
Labels: ss.labels,
})
}
func setBit(bm *[]uint64, i int, v int) {

View file

@ -12,3 +12,16 @@ func (e *causeError) Error() string {
func (e *causeError) Unwrap() error {
return e.cause
}
type causeError1 struct {
error
cause error
}
func (e *causeError1) Error() string {
return e.error.Error() + ": " + e.cause.Error()
}
func (e *causeError1) Unwrap() []error {
return []error{e.error, e.cause}
}

View file

@ -12,6 +12,7 @@ import (
F "github.com/sagernet/sing/common/format"
)
// Deprecated: wtf is this?
type Handler interface {
NewError(ctx context.Context, err error)
}
@ -31,6 +32,13 @@ func Cause(cause error, message ...any) error {
return &causeError{F.ToString(message...), cause}
}
func Cause1(err error, cause error) error {
if cause == nil {
panic("cause on an nil error")
}
return &causeError1{err, cause}
}
func Extend(cause error, message ...any) error {
if cause == nil {
panic("extend on an nil error")
@ -39,11 +47,11 @@ func Extend(cause error, message ...any) error {
}
func IsClosedOrCanceled(err error) bool {
return IsMulti(err, io.EOF, net.ErrClosed, io.ErrClosedPipe, os.ErrClosed, syscall.EPIPE, syscall.ECONNRESET, context.Canceled, context.DeadlineExceeded)
return IsClosed(err) || IsCanceled(err) || IsTimeout(err)
}
func IsClosed(err error) bool {
return IsMulti(err, io.EOF, net.ErrClosed, io.ErrClosedPipe, os.ErrClosed, syscall.EPIPE, syscall.ECONNRESET)
return IsMulti(err, io.EOF, net.ErrClosed, io.ErrClosedPipe, os.ErrClosed, syscall.EPIPE, syscall.ECONNRESET, syscall.ENOTCONN)
}
func IsCanceled(err error) bool {

View file

@ -1,24 +1,14 @@
package exceptions
import "github.com/sagernet/sing/common"
import (
"errors"
type HasInnerError interface {
Unwrap() error
}
"github.com/sagernet/sing/common"
)
// Deprecated: Use errors.Unwrap instead.
func Unwrap(err error) error {
for {
inner, ok := err.(HasInnerError)
if !ok {
break
}
innerErr := inner.Unwrap()
if innerErr == nil {
break
}
err = innerErr
}
return err
return errors.Unwrap(err)
}
func Cast[T any](err error) (T, bool) {

View file

@ -37,10 +37,13 @@ func Errors(errors ...error) error {
}
func Expand(err error) []error {
if multiErr, isMultiErr := err.(MultiError); isMultiErr {
return ExpandAll(multiErr.Unwrap())
if err == nil {
return nil
} else if multiErr, isMultiErr := err.(MultiError); isMultiErr {
return ExpandAll(common.FilterNotNil(multiErr.Unwrap()))
} else {
return []error{err}
}
return []error{err}
}
func ExpandAll(errs []error) []error {
@ -60,12 +63,5 @@ func IsMulti(err error, targetList ...error) bool {
return true
}
}
err = Unwrap(err)
multiErr, isMulti := err.(MultiError)
if !isMulti {
return false
}
return common.All(multiErr.Unwrap(), func(it error) bool {
return IsMulti(it, targetList...)
})
return false
}

View file

@ -1,17 +1,21 @@
package exceptions
import "net"
import (
"errors"
"net"
)
type TimeoutError interface {
Timeout() bool
}
func IsTimeout(err error) bool {
if netErr, isNetErr := err.(net.Error); isNetErr {
//goland:noinspection GoDeprecation
var netErr net.Error
if errors.As(err, &netErr) {
//nolint:staticcheck
return netErr.Temporary() && netErr.Timeout()
} else if timeoutErr, isTimeout := Cast[TimeoutError](err); isTimeout {
}
if timeoutErr, isTimeout := Cast[TimeoutError](err); isTimeout {
return timeoutErr.Timeout()
}
return false

View file

@ -0,0 +1,59 @@
package badjson
import (
"bytes"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/json"
)
type JSONArray []any
func (a JSONArray) IsEmpty() bool {
if len(a) == 0 {
return true
}
return common.All(a, func(it any) bool {
if valueInterface, valueMaybeEmpty := it.(isEmpty); valueMaybeEmpty && valueInterface.IsEmpty() {
return true
}
return false
})
}
func (a JSONArray) MarshalJSON() ([]byte, error) {
return json.Marshal([]any(a))
}
func (a *JSONArray) UnmarshalJSON(content []byte) error {
decoder := json.NewDecoder(bytes.NewReader(content))
arrayStart, err := decoder.Token()
if err != nil {
return err
} else if arrayStart != json.Delim('[') {
return E.New("excepted array start, but got ", arrayStart)
}
err = a.decodeJSON(decoder)
if err != nil {
return err
}
arrayEnd, err := decoder.Token()
if err != nil {
return err
} else if arrayEnd != json.Delim(']') {
return E.New("excepted array end, but got ", arrayEnd)
}
return nil
}
func (a *JSONArray) decodeJSON(decoder *json.Decoder) error {
for decoder.More() {
item, err := decodeJSON(decoder)
if err != nil {
return err
}
*a = append(*a, item)
}
return nil
}

View file

@ -0,0 +1,5 @@
package badjson
type isEmpty interface {
IsEmpty() bool
}

View file

@ -0,0 +1,55 @@
package badjson
import (
"bytes"
"context"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/json"
)
func Decode(ctx context.Context, content []byte) (any, error) {
decoder := json.NewDecoderContext(ctx, bytes.NewReader(content))
return decodeJSON(decoder)
}
func decodeJSON(decoder *json.Decoder) (any, error) {
rawToken, err := decoder.Token()
if err != nil {
return nil, err
}
switch token := rawToken.(type) {
case json.Delim:
switch token {
case '{':
var object JSONObject
err = object.decodeJSON(decoder)
if err != nil {
return nil, err
}
rawToken, err = decoder.Token()
if err != nil {
return nil, err
} else if rawToken != json.Delim('}') {
return nil, E.New("excepted object end, but got ", rawToken)
}
return &object, nil
case '[':
var array JSONArray
err = array.decodeJSON(decoder)
if err != nil {
return nil, err
}
rawToken, err = decoder.Token()
if err != nil {
return nil, err
} else if rawToken != json.Delim(']') {
return nil, E.New("excepted array end, but got ", rawToken)
}
return array, nil
default:
return nil, E.New("excepted object or array end: ", token)
}
}
return rawToken, nil
}

View file

@ -0,0 +1,142 @@
package badjson
import (
"context"
"os"
"reflect"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/json"
)
func Omitempty[T any](ctx context.Context, value T) (T, error) {
objectContent, err := json.MarshalContext(ctx, value)
if err != nil {
return common.DefaultValue[T](), E.Cause(err, "marshal object")
}
rawNewObject, err := Decode(ctx, objectContent)
if err != nil {
return common.DefaultValue[T](), err
}
newObjectContent, err := json.MarshalContext(ctx, rawNewObject)
if err != nil {
return common.DefaultValue[T](), E.Cause(err, "marshal new object")
}
var newObject T
err = json.UnmarshalContext(ctx, newObjectContent, &newObject)
if err != nil {
return common.DefaultValue[T](), E.Cause(err, "unmarshal new object")
}
return newObject, nil
}
func Merge[T any](ctx context.Context, source T, destination T, disableAppend bool) (T, error) {
rawSource, err := json.MarshalContext(ctx, source)
if err != nil {
return common.DefaultValue[T](), E.Cause(err, "marshal source")
}
rawDestination, err := json.MarshalContext(ctx, destination)
if err != nil {
return common.DefaultValue[T](), E.Cause(err, "marshal destination")
}
return MergeFrom[T](ctx, rawSource, rawDestination, disableAppend)
}
func MergeFromSource[T any](ctx context.Context, rawSource json.RawMessage, destination T, disableAppend bool) (T, error) {
if rawSource == nil {
return destination, nil
}
rawDestination, err := json.MarshalContext(ctx, destination)
if err != nil {
return common.DefaultValue[T](), E.Cause(err, "marshal destination")
}
return MergeFrom[T](ctx, rawSource, rawDestination, disableAppend)
}
func MergeFromDestination[T any](ctx context.Context, source T, rawDestination json.RawMessage, disableAppend bool) (T, error) {
if rawDestination == nil {
return source, nil
}
rawSource, err := json.MarshalContext(ctx, source)
if err != nil {
return common.DefaultValue[T](), E.Cause(err, "marshal source")
}
return MergeFrom[T](ctx, rawSource, rawDestination, disableAppend)
}
func MergeFrom[T any](ctx context.Context, rawSource json.RawMessage, rawDestination json.RawMessage, disableAppend bool) (T, error) {
rawMerged, err := MergeJSON(ctx, rawSource, rawDestination, disableAppend)
if err != nil {
return common.DefaultValue[T](), E.Cause(err, "merge options")
}
var merged T
err = json.UnmarshalContext(ctx, rawMerged, &merged)
if err != nil {
return common.DefaultValue[T](), E.Cause(err, "unmarshal merged options")
}
return merged, nil
}
func MergeJSON(ctx context.Context, rawSource json.RawMessage, rawDestination json.RawMessage, disableAppend bool) (json.RawMessage, error) {
if rawSource == nil && rawDestination == nil {
return nil, os.ErrInvalid
} else if rawSource == nil {
return rawDestination, nil
} else if rawDestination == nil {
return rawSource, nil
}
source, err := Decode(ctx, rawSource)
if err != nil {
return nil, E.Cause(err, "decode source")
}
destination, err := Decode(ctx, rawDestination)
if err != nil {
return nil, E.Cause(err, "decode destination")
}
if source == nil {
return json.MarshalContext(ctx, destination)
} else if destination == nil {
return json.Marshal(source)
}
merged, err := mergeJSON(source, destination, disableAppend)
if err != nil {
return nil, err
}
return json.MarshalContext(ctx, merged)
}
func mergeJSON(anySource any, anyDestination any, disableAppend bool) (any, error) {
switch destination := anyDestination.(type) {
case JSONArray:
if !disableAppend {
switch source := anySource.(type) {
case JSONArray:
destination = append(destination, source...)
default:
destination = append(destination, source)
}
}
return destination, nil
case *JSONObject:
switch source := anySource.(type) {
case *JSONObject:
for _, entry := range source.Entries() {
oldValue, loaded := destination.Get(entry.Key)
if loaded {
var err error
entry.Value, err = mergeJSON(entry.Value, oldValue, disableAppend)
if err != nil {
return nil, E.Cause(err, "merge object item ", entry.Key)
}
}
destination.Put(entry.Key, entry.Value)
}
default:
return nil, E.New("cannot merge json object into ", reflect.TypeOf(source))
}
return destination, nil
default:
return destination, nil
}
}

View file

@ -0,0 +1,68 @@
package badjson
import (
"context"
"reflect"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/json"
cJSON "github.com/sagernet/sing/common/json/internal/contextjson"
)
func MarshallObjects(objects ...any) ([]byte, error) {
return MarshallObjectsContext(context.Background(), objects...)
}
func MarshallObjectsContext(ctx context.Context, objects ...any) ([]byte, error) {
if len(objects) == 1 {
return json.Marshal(objects[0])
}
var content JSONObject
for _, object := range objects {
objectMap, err := newJSONObject(ctx, object)
if err != nil {
return nil, err
}
content.PutAll(objectMap)
}
return content.MarshalJSONContext(ctx)
}
func UnmarshallExcluded(inputContent []byte, parentObject any, object any) error {
return UnmarshallExcludedContext(context.Background(), inputContent, parentObject, object)
}
func UnmarshallExcludedContext(ctx context.Context, inputContent []byte, parentObject any, object any) error {
var content JSONObject
err := content.UnmarshalJSONContext(ctx, inputContent)
if err != nil {
return err
}
for _, key := range cJSON.ObjectKeys(reflect.TypeOf(parentObject)) {
content.Remove(key)
}
if object == nil {
if content.IsEmpty() {
return nil
}
return E.New("unexpected key: ", content.Keys()[0])
}
inputContent, err = content.MarshalJSONContext(ctx)
if err != nil {
return err
}
return json.UnmarshalContextDisallowUnknownFields(ctx, inputContent, object)
}
func newJSONObject(ctx context.Context, object any) (*JSONObject, error) {
inputContent, err := json.MarshalContext(ctx, object)
if err != nil {
return nil, err
}
var content JSONObject
err = content.UnmarshalJSONContext(ctx, inputContent)
if err != nil {
return nil, err
}
return &content, nil
}

View file

@ -0,0 +1,107 @@
package badjson
import (
"bytes"
"context"
"strings"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/json"
"github.com/sagernet/sing/common/x/collections"
"github.com/sagernet/sing/common/x/linkedhashmap"
)
type JSONObject struct {
linkedhashmap.Map[string, any]
}
func (m *JSONObject) IsEmpty() bool {
if m.Size() == 0 {
return true
}
return common.All(m.Entries(), func(it collections.MapEntry[string, any]) bool {
if valueInterface, valueMaybeEmpty := it.Value.(isEmpty); valueMaybeEmpty && valueInterface.IsEmpty() {
return true
}
return false
})
}
func (m *JSONObject) MarshalJSON() ([]byte, error) {
return m.MarshalJSONContext(context.Background())
}
func (m *JSONObject) MarshalJSONContext(ctx context.Context) ([]byte, error) {
buffer := new(bytes.Buffer)
buffer.WriteString("{")
items := common.Filter(m.Entries(), func(it collections.MapEntry[string, any]) bool {
if valueInterface, valueMaybeEmpty := it.Value.(isEmpty); valueMaybeEmpty && valueInterface.IsEmpty() {
return false
}
return true
})
iLen := len(items)
for i, entry := range items {
keyContent, err := json.MarshalContext(ctx, entry.Key)
if err != nil {
return nil, err
}
buffer.WriteString(strings.TrimSpace(string(keyContent)))
buffer.WriteString(": ")
valueContent, err := json.MarshalContext(ctx, entry.Value)
if err != nil {
return nil, err
}
buffer.WriteString(strings.TrimSpace(string(valueContent)))
if i < iLen-1 {
buffer.WriteString(", ")
}
}
buffer.WriteString("}")
return buffer.Bytes(), nil
}
func (m *JSONObject) UnmarshalJSON(content []byte) error {
return m.UnmarshalJSONContext(context.Background(), content)
}
func (m *JSONObject) UnmarshalJSONContext(ctx context.Context, content []byte) error {
decoder := json.NewDecoderContext(ctx, bytes.NewReader(content))
m.Clear()
objectStart, err := decoder.Token()
if err != nil {
return err
} else if objectStart != json.Delim('{') {
return E.New("expected json object start, but starts with ", objectStart)
}
err = m.decodeJSON(decoder)
if err != nil {
return E.Cause(err, "decode json object content")
}
objectEnd, err := decoder.Token()
if err != nil {
return err
} else if objectEnd != json.Delim('}') {
return E.New("expected json object end, but ends with ", objectEnd)
}
return nil
}
func (m *JSONObject) decodeJSON(decoder *json.Decoder) error {
for decoder.More() {
var entryKey string
keyToken, err := decoder.Token()
if err != nil {
return err
}
entryKey = keyToken.(string)
var entryValue any
entryValue, err = decodeJSON(decoder)
if err != nil {
return E.Cause(err, "decode value for ", entryKey)
}
m.Put(entryKey, entryValue)
}
return nil
}

View file

@ -0,0 +1,95 @@
package badjson
import (
"bytes"
"context"
"strings"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/json"
"github.com/sagernet/sing/common/x/linkedhashmap"
)
type TypedMap[K comparable, V any] struct {
linkedhashmap.Map[K, V]
}
func (m TypedMap[K, V]) MarshalJSON() ([]byte, error) {
return m.MarshalJSONContext(context.Background())
}
func (m TypedMap[K, V]) MarshalJSONContext(ctx context.Context) ([]byte, error) {
buffer := new(bytes.Buffer)
buffer.WriteString("{")
items := m.Entries()
iLen := len(items)
for i, entry := range items {
keyContent, err := json.MarshalContext(ctx, entry.Key)
if err != nil {
return nil, err
}
buffer.WriteString(strings.TrimSpace(string(keyContent)))
buffer.WriteString(": ")
valueContent, err := json.MarshalContext(ctx, entry.Value)
if err != nil {
return nil, err
}
buffer.WriteString(strings.TrimSpace(string(valueContent)))
if i < iLen-1 {
buffer.WriteString(", ")
}
}
buffer.WriteString("}")
return buffer.Bytes(), nil
}
func (m *TypedMap[K, V]) UnmarshalJSON(content []byte) error {
return m.UnmarshalJSONContext(context.Background(), content)
}
func (m *TypedMap[K, V]) UnmarshalJSONContext(ctx context.Context, content []byte) error {
decoder := json.NewDecoderContext(ctx, bytes.NewReader(content))
m.Clear()
objectStart, err := decoder.Token()
if err != nil {
return err
} else if objectStart != json.Delim('{') {
return E.New("expected json object start, but starts with ", objectStart)
}
err = m.decodeJSON(ctx, decoder)
if err != nil {
return E.Cause(err, "decode json object content")
}
objectEnd, err := decoder.Token()
if err != nil {
return err
} else if objectEnd != json.Delim('}') {
return E.New("expected json object end, but ends with ", objectEnd)
}
return nil
}
func (m *TypedMap[K, V]) decodeJSON(ctx context.Context, decoder *json.Decoder) error {
for decoder.More() {
keyToken, err := decoder.Token()
if err != nil {
return err
}
keyContent, err := json.MarshalContext(ctx, keyToken)
if err != nil {
return err
}
var entryKey K
err = json.UnmarshalContext(ctx, keyContent, &entryKey)
if err != nil {
return err
}
var entryValue V
err = decoder.Decode(&entryValue)
if err != nil {
return err
}
m.Put(entryKey, entryValue)
}
return nil
}

View file

@ -0,0 +1,32 @@
package badoption
import (
"time"
"github.com/sagernet/sing/common/json"
"github.com/sagernet/sing/common/json/badoption/internal/my_time"
)
type Duration time.Duration
func (d Duration) Build() time.Duration {
return time.Duration(d)
}
func (d Duration) MarshalJSON() ([]byte, error) {
return json.Marshal((time.Duration)(d).String())
}
func (d *Duration) UnmarshalJSON(bytes []byte) error {
var value string
err := json.Unmarshal(bytes, &value)
if err != nil {
return err
}
duration, err := my_time.ParseDuration(value)
if err != nil {
return err
}
*d = Duration(duration)
return nil
}

View file

@ -0,0 +1,15 @@
package badoption
import "net/http"
type HTTPHeader map[string]Listable[string]
func (h HTTPHeader) Build() http.Header {
header := make(http.Header)
for name, values := range h {
for _, value := range values {
header.Add(name, value)
}
}
return header
}

View file

@ -0,0 +1,226 @@
package my_time
import (
"errors"
"time"
)
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
const durationDay = 24 * time.Hour
var unitMap = map[string]uint64{
"ns": uint64(time.Nanosecond),
"us": uint64(time.Microsecond),
"µs": uint64(time.Microsecond), // U+00B5 = micro symbol
"μs": uint64(time.Microsecond), // U+03BC = Greek letter mu
"ms": uint64(time.Millisecond),
"s": uint64(time.Second),
"m": uint64(time.Minute),
"h": uint64(time.Hour),
"d": uint64(durationDay),
}
// ParseDuration parses a duration string.
// A duration string is a possibly signed sequence of
// decimal numbers, each with optional fraction and a unit suffix,
// such as "300ms", "-1.5h" or "2h45m".
// Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h".
func ParseDuration(s string) (time.Duration, error) {
// [-+]?([0-9]*(\.[0-9]*)?[a-z]+)+
orig := s
var d uint64
neg := false
// Consume [-+]?
if s != "" {
c := s[0]
if c == '-' || c == '+' {
neg = c == '-'
s = s[1:]
}
}
// Special case: if all that is left is "0", this is zero.
if s == "0" {
return 0, nil
}
if s == "" {
return 0, errors.New("time: invalid duration " + quote(orig))
}
for s != "" {
var (
v, f uint64 // integers before, after decimal point
scale float64 = 1 // value = v + f/scale
)
var err error
// The next character must be [0-9.]
if !(s[0] == '.' || '0' <= s[0] && s[0] <= '9') {
return 0, errors.New("time: invalid duration " + quote(orig))
}
// Consume [0-9]*
pl := len(s)
v, s, err = leadingInt(s)
if err != nil {
return 0, errors.New("time: invalid duration " + quote(orig))
}
pre := pl != len(s) // whether we consumed anything before a period
// Consume (\.[0-9]*)?
post := false
if s != "" && s[0] == '.' {
s = s[1:]
pl := len(s)
f, scale, s = leadingFraction(s)
post = pl != len(s)
}
if !pre && !post {
// no digits (e.g. ".s" or "-.s")
return 0, errors.New("time: invalid duration " + quote(orig))
}
// Consume unit.
i := 0
for ; i < len(s); i++ {
c := s[i]
if c == '.' || '0' <= c && c <= '9' {
break
}
}
if i == 0 {
return 0, errors.New("time: missing unit in duration " + quote(orig))
}
u := s[:i]
s = s[i:]
unit, ok := unitMap[u]
if !ok {
return 0, errors.New("time: unknown unit " + quote(u) + " in duration " + quote(orig))
}
if v > 1<<63/unit {
// overflow
return 0, errors.New("time: invalid duration " + quote(orig))
}
v *= unit
if f > 0 {
// float64 is needed to be nanosecond accurate for fractions of hours.
// v >= 0 && (f*unit/scale) <= 3.6e+12 (ns/h, h is the largest unit)
v += uint64(float64(f) * (float64(unit) / scale))
if v > 1<<63 {
// overflow
return 0, errors.New("time: invalid duration " + quote(orig))
}
}
d += v
if d > 1<<63 {
return 0, errors.New("time: invalid duration " + quote(orig))
}
}
if neg {
return -time.Duration(d), nil
}
if d > 1<<63-1 {
return 0, errors.New("time: invalid duration " + quote(orig))
}
return time.Duration(d), nil
}
var errLeadingInt = errors.New("time: bad [0-9]*") // never printed
// leadingInt consumes the leading [0-9]* from s.
func leadingInt[bytes []byte | string](s bytes) (x uint64, rem bytes, err error) {
i := 0
for ; i < len(s); i++ {
c := s[i]
if c < '0' || c > '9' {
break
}
if x > 1<<63/10 {
// overflow
return 0, rem, errLeadingInt
}
x = x*10 + uint64(c) - '0'
if x > 1<<63 {
// overflow
return 0, rem, errLeadingInt
}
}
return x, s[i:], nil
}
// leadingFraction consumes the leading [0-9]* from s.
// It is used only for fractions, so does not return an error on overflow,
// it just stops accumulating precision.
func leadingFraction(s string) (x uint64, scale float64, rem string) {
i := 0
scale = 1
overflow := false
for ; i < len(s); i++ {
c := s[i]
if c < '0' || c > '9' {
break
}
if overflow {
continue
}
if x > (1<<63-1)/10 {
// It's possible for overflow to give a positive number, so take care.
overflow = true
continue
}
y := x*10 + uint64(c) - '0'
if y > 1<<63 {
overflow = true
continue
}
x = y
scale *= 10
}
return x, scale, s[i:]
}
// These are borrowed from unicode/utf8 and strconv and replicate behavior in
// that package, since we can't take a dependency on either.
const (
lowerhex = "0123456789abcdef"
runeSelf = 0x80
runeError = '\uFFFD'
)
func quote(s string) string {
buf := make([]byte, 1, len(s)+2) // slice will be at least len(s) + quotes
buf[0] = '"'
for i, c := range s {
if c >= runeSelf || c < ' ' {
// This means you are asking us to parse a time.Duration or
// time.Location with unprintable or non-ASCII characters in it.
// We don't expect to hit this case very often. We could try to
// reproduce strconv.Quote's behavior with full fidelity but
// given how rarely we expect to hit these edge cases, speed and
// conciseness are better.
var width int
if c == runeError {
width = 1
if i+2 < len(s) && s[i:i+3] == string(runeError) {
width = 3
}
} else {
width = len(string(c))
}
for j := 0; j < width; j++ {
buf = append(buf, `\x`...)
buf = append(buf, lowerhex[s[i+j]>>4])
buf = append(buf, lowerhex[s[i+j]&0xF])
}
} else {
if c == '"' || c == '\\' {
buf = append(buf, '\\')
}
buf = append(buf, string(c)...)
}
}
buf = append(buf, '"')
return string(buf)
}

View file

@ -0,0 +1,35 @@
package badoption
import (
"context"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/json"
)
type Listable[T any] []T
func (l Listable[T]) MarshalJSONContext(ctx context.Context) ([]byte, error) {
arrayList := []T(l)
if len(arrayList) == 1 {
return json.Marshal(arrayList[0])
}
return json.MarshalContext(ctx, arrayList)
}
func (l *Listable[T]) UnmarshalJSONContext(ctx context.Context, content []byte) error {
if string(content) == "null" {
return nil
}
var singleItem T
err := json.UnmarshalContextDisallowUnknownFields(ctx, content, &singleItem)
if err == nil {
*l = []T{singleItem}
return nil
}
newErr := json.UnmarshalContextDisallowUnknownFields(ctx, content, (*[]T)(l))
if newErr == nil {
return nil
}
return E.Errors(err, newErr)
}

View file

@ -0,0 +1,98 @@
package badoption
import (
"net/netip"
"github.com/sagernet/sing/common/json"
)
type Addr netip.Addr
func (a *Addr) Build(defaultAddr netip.Addr) netip.Addr {
if a == nil {
return defaultAddr
}
return netip.Addr(*a)
}
func (a *Addr) MarshalJSON() ([]byte, error) {
return json.Marshal(netip.Addr(*a).String())
}
func (a *Addr) UnmarshalJSON(content []byte) error {
var value string
err := json.Unmarshal(content, &value)
if err != nil {
return err
}
addr, err := netip.ParseAddr(value)
if err != nil {
return err
}
*a = Addr(addr)
return nil
}
type Prefix netip.Prefix
func (p *Prefix) Build(defaultPrefix netip.Prefix) netip.Prefix {
if p == nil {
return defaultPrefix
}
return netip.Prefix(*p)
}
func (p *Prefix) MarshalJSON() ([]byte, error) {
return json.Marshal(netip.Prefix(*p).String())
}
func (p *Prefix) UnmarshalJSON(content []byte) error {
var value string
err := json.Unmarshal(content, &value)
if err != nil {
return err
}
prefix, err := netip.ParsePrefix(value)
if err != nil {
return err
}
*p = Prefix(prefix)
return nil
}
type Prefixable netip.Prefix
func (p *Prefixable) Build(defaultPrefix netip.Prefix) netip.Prefix {
if p == nil {
return defaultPrefix
}
return netip.Prefix(*p)
}
func (p *Prefixable) MarshalJSON() ([]byte, error) {
prefix := netip.Prefix(*p)
if prefix.Bits() == prefix.Addr().BitLen() {
return json.Marshal(prefix.Addr().String())
} else {
return json.Marshal(prefix.String())
}
}
func (p *Prefixable) UnmarshalJSON(content []byte) error {
var value string
err := json.Unmarshal(content, &value)
if err != nil {
return err
}
prefix, prefixErr := netip.ParsePrefix(value)
if prefixErr == nil {
*p = Prefixable(prefix)
return nil
}
addr, addrErr := netip.ParseAddr(value)
if addrErr == nil {
*p = Prefixable(netip.PrefixFrom(addr, addr.BitLen()))
return nil
}
return prefixErr
}

View file

@ -0,0 +1,31 @@
package badoption
import (
"regexp"
"github.com/sagernet/sing/common/json"
)
type Regexp regexp.Regexp
func (r *Regexp) Build() *regexp.Regexp {
return (*regexp.Regexp)(r)
}
func (r *Regexp) MarshalJSON() ([]byte, error) {
return json.Marshal((*regexp.Regexp)(r).String())
}
func (r *Regexp) UnmarshalJSON(content []byte) error {
var stringValue string
err := json.Unmarshal(content, &stringValue)
if err != nil {
return err
}
regex, err := regexp.Compile(stringValue)
if err != nil {
return err
}
*r = Regexp(*regex)
return nil
}

128
common/json/comment.go Normal file
View file

@ -0,0 +1,128 @@
package json
import (
"bufio"
"io"
)
// kanged from v2ray
type commentFilterState = byte
const (
commentFilterStateContent commentFilterState = iota
commentFilterStateEscape
commentFilterStateDoubleQuote
commentFilterStateDoubleQuoteEscape
commentFilterStateSingleQuote
commentFilterStateSingleQuoteEscape
commentFilterStateComment
commentFilterStateSlash
commentFilterStateMultilineComment
commentFilterStateMultilineCommentStar
)
type CommentFilter struct {
br *bufio.Reader
state commentFilterState
}
func NewCommentFilter(reader io.Reader) io.Reader {
return &CommentFilter{br: bufio.NewReader(reader)}
}
func (v *CommentFilter) Read(b []byte) (int, error) {
p := b[:0]
for len(p) < len(b)-2 {
x, err := v.br.ReadByte()
if err != nil {
if len(p) == 0 {
return 0, err
}
return len(p), nil
}
switch v.state {
case commentFilterStateContent:
switch x {
case '"':
v.state = commentFilterStateDoubleQuote
p = append(p, x)
case '\'':
v.state = commentFilterStateSingleQuote
p = append(p, x)
case '\\':
v.state = commentFilterStateEscape
case '#':
v.state = commentFilterStateComment
case '/':
v.state = commentFilterStateSlash
default:
p = append(p, x)
}
case commentFilterStateEscape:
p = append(p, '\\', x)
v.state = commentFilterStateContent
case commentFilterStateDoubleQuote:
switch x {
case '"':
v.state = commentFilterStateContent
p = append(p, x)
case '\\':
v.state = commentFilterStateDoubleQuoteEscape
default:
p = append(p, x)
}
case commentFilterStateDoubleQuoteEscape:
p = append(p, '\\', x)
v.state = commentFilterStateDoubleQuote
case commentFilterStateSingleQuote:
switch x {
case '\'':
v.state = commentFilterStateContent
p = append(p, x)
case '\\':
v.state = commentFilterStateSingleQuoteEscape
default:
p = append(p, x)
}
case commentFilterStateSingleQuoteEscape:
p = append(p, '\\', x)
v.state = commentFilterStateSingleQuote
case commentFilterStateComment:
if x == '\n' {
v.state = commentFilterStateContent
p = append(p, '\n')
}
case commentFilterStateSlash:
switch x {
case '/':
v.state = commentFilterStateComment
case '*':
v.state = commentFilterStateMultilineComment
default:
p = append(p, '/', x)
}
case commentFilterStateMultilineComment:
switch x {
case '*':
v.state = commentFilterStateMultilineCommentStar
case '\n':
p = append(p, '\n')
}
case commentFilterStateMultilineCommentStar:
switch x {
case '/':
v.state = commentFilterStateContent
case '*':
// Stay
case '\n':
p = append(p, '\n')
default:
v.state = commentFilterStateMultilineComment
}
default:
panic("Unknown state.")
}
}
return len(p), nil
}

23
common/json/context.go Normal file
View file

@ -0,0 +1,23 @@
//go:build go1.20 && !without_contextjson
package json
import (
"github.com/sagernet/sing/common/json/internal/contextjson"
)
var (
Marshal = json.Marshal
Unmarshal = json.Unmarshal
NewEncoder = json.NewEncoder
NewDecoder = json.NewDecoder
)
type (
Encoder = json.Encoder
Decoder = json.Decoder
Token = json.Token
Delim = json.Delim
SyntaxError = json.SyntaxError
RawMessage = json.RawMessage
)

View file

@ -0,0 +1,23 @@
package json
import (
"context"
"github.com/sagernet/sing/common/json/internal/contextjson"
)
var (
MarshalContext = json.MarshalContext
UnmarshalContext = json.UnmarshalContext
NewEncoderContext = json.NewEncoderContext
NewDecoderContext = json.NewDecoderContext
UnmarshalContextDisallowUnknownFields = json.UnmarshalContextDisallowUnknownFields
)
type ContextMarshaler interface {
MarshalJSONContext(ctx context.Context) ([]byte, error)
}
type ContextUnmarshaler interface {
UnmarshalJSONContext(ctx context.Context, content []byte) error
}

View file

@ -0,0 +1,3 @@
# contextjson
mod from go1.21.4

View file

@ -0,0 +1,11 @@
package json
import "context"
type ContextMarshaler interface {
MarshalJSONContext(ctx context.Context) ([]byte, error)
}
type ContextUnmarshaler interface {
UnmarshalJSONContext(ctx context.Context, content []byte) error
}

View file

@ -0,0 +1,43 @@
package json_test
import (
"context"
"testing"
"github.com/sagernet/sing/common/json/internal/contextjson"
"github.com/stretchr/testify/require"
)
type myStruct struct {
value string
}
func (m *myStruct) MarshalJSONContext(ctx context.Context) ([]byte, error) {
return json.Marshal(ctx.Value("key").(string))
}
func (m *myStruct) UnmarshalJSONContext(ctx context.Context, content []byte) error {
m.value = ctx.Value("key").(string)
return nil
}
//nolint:staticcheck
func TestMarshalContext(t *testing.T) {
t.Parallel()
ctx := context.WithValue(context.Background(), "key", "value")
var s myStruct
b, err := json.MarshalContext(ctx, &s)
require.NoError(t, err)
require.Equal(t, []byte(`"value"`), b)
}
//nolint:staticcheck
func TestUnmarshalContext(t *testing.T) {
t.Parallel()
ctx := context.WithValue(context.Background(), "key", "value")
var s myStruct
err := json.UnmarshalContext(ctx, []byte(`{}`), &s)
require.NoError(t, err)
require.Equal(t, "value", s.value)
}

Some files were not shown because too many files have changed in this diff Show more