mirror of
https://github.com/SagerNet/sing.git
synced 2025-04-04 20:37:40 +03:00
Compare commits
335 commits
Author | SHA1 | Date | |
---|---|---|---|
|
159e489fc3 | ||
|
d39c2c2fdd | ||
|
ea82ac275f | ||
|
ea0ac932ae | ||
|
2b41455f5a | ||
|
23b0180a1b | ||
|
ce1b4851a4 | ||
|
2238a05966 | ||
|
b55d1c78b3 | ||
|
d54716612c | ||
|
9eafc7fc62 | ||
|
d8153df67f | ||
|
d9f6eb136d | ||
|
4dabb9be97 | ||
|
be9840c70f | ||
|
aa7d2543a3 | ||
|
33beacc053 | ||
|
442cceb9fa | ||
|
3374a45475 | ||
|
73776cf797 | ||
|
957166799e | ||
|
809d8eca13 | ||
|
9f69e7f9f7 | ||
|
478265cd45 | ||
|
3f30aaf25e | ||
|
39040e06dc | ||
|
6edd2ce0ea | ||
|
0a2e2a3eaf | ||
|
4ba1eb123c | ||
|
c44912a861 | ||
|
a8f5bf4eb0 | ||
|
30e9d91b57 | ||
|
7fd3517e4d | ||
|
a8285e06a5 | ||
|
3613ead480 | ||
|
c8f251c668 | ||
|
fa5355e99e | ||
|
30fbafd954 | ||
|
fdca9b3f8e | ||
|
e52e04f721 | ||
|
7f621fdd78 | ||
|
ae139d9ee1 | ||
|
c432befd02 | ||
|
cc7e630923 | ||
|
0998999911 | ||
|
72ff654ee0 | ||
|
11ffb962ae | ||
|
fcb19641e6 | ||
|
524a6bd0d1 | ||
|
b5f9e70ffd | ||
|
c80c8f907c | ||
|
a4eb7fa900 | ||
|
7ec09d6045 | ||
|
0641c71805 | ||
|
e7ec021b81 | ||
|
0f2447a95b | ||
|
72db784fc7 | ||
|
d59ac57aaa | ||
|
c63546470b | ||
|
55908bea36 | ||
|
6567829958 | ||
|
c324d4143d | ||
|
0acb36c118 | ||
|
26511a251f | ||
|
afd8993773 | ||
|
96bef0733f | ||
|
ec1df651e8 | ||
|
e33b1d67d5 | ||
|
ed6cde73f7 | ||
|
73cc65605e | ||
|
6c19e0736d | ||
|
08e8c02fb1 | ||
|
7beca62e4f | ||
|
e422e3d048 | ||
|
fa81eabc29 | ||
|
4498e57839 | ||
|
f97054e917 | ||
|
a2f9fef936 | ||
|
7893a74f75 | ||
|
332e470075 | ||
|
2bf9cc7253 | ||
|
bf8fc103a4 | ||
|
774893928c | ||
|
7ceaf63d41 | ||
|
8806e421f2 | ||
|
c37f988a4f | ||
|
0b4c0a1283 | ||
|
4745c34b4c | ||
|
e0196407a3 | ||
|
d8ec9c46cc | ||
|
a33349366d | ||
|
3155c16990 | ||
|
caa4340dc9 | ||
|
a31dba8ad2 | ||
|
9571124cf4 | ||
|
1c495c9b07 | ||
|
0f95dfe0e3 | ||
|
f3380c8dfe | ||
|
ab4353dd13 | ||
|
e0e490af7b | ||
|
ad4d59e2ed | ||
|
aca2a85545 | ||
|
589c7eb4df | ||
|
2873799b6d | ||
|
47cc308abf | ||
|
d9f2559214 | ||
|
b8736cc58d | ||
|
e82ff8e2e6 | ||
|
ba68e017a9 | ||
|
967afcf6c1 | ||
|
0c110ad733 | ||
|
de1b0bd772 | ||
|
f67a0988a6 | ||
|
e0ee7f49e2 | ||
|
284cb5ce98 | ||
|
8fb1634c9a | ||
|
4ab8cac5eb | ||
|
eec2fc325a | ||
|
2fa039945c | ||
|
e5825dcb59 | ||
|
6b73a57a24 | ||
|
3e2631ef0b | ||
|
4d96f15eca | ||
|
f9c59e9940 | ||
|
8b68fc4d7a | ||
|
5bfc326913 | ||
|
04152ea672 | ||
|
a069af4787 | ||
|
807a51bb81 | ||
|
ec2595f010 | ||
|
c98e8b6921 | ||
|
a4a9ec42c6 | ||
|
6e3921083b | ||
|
8e89f9b4dc | ||
|
5f02cb1cff | ||
|
ef00a1ec1e | ||
|
5ee4f84faf | ||
|
30f7629317 | ||
|
9e1749e108 | ||
|
b1355d7a4b | ||
|
45f572495e | ||
|
3ac055b755 | ||
|
a6e8fa3019 | ||
|
57b8a4c64a | ||
|
b7a631f798 | ||
|
fa0cc448dc | ||
|
0d1b3d6d6d | ||
|
2196f193ac | ||
|
c501a58ae7 | ||
|
c9319a35ee | ||
|
cdb9908442 | ||
|
81d1bc2768 | ||
|
2e36fa6849 | ||
|
edd320c3a8 | ||
|
56b953e091 | ||
|
4c4773fe54 | ||
|
36acc18bfb | ||
|
ad670bab68 | ||
|
2a2dbf1971 | ||
|
afa72012e5 | ||
|
c7ef05a85b | ||
|
0f7de716ac | ||
|
231d7607bc | ||
|
8b43ec8058 | ||
|
c17babe0ba | ||
|
1f02d6daca | ||
|
aa34723225 | ||
|
ae8098ad39 | ||
|
05c71c99d1 | ||
|
060edf2d69 | ||
|
d171f04941 | ||
|
51aeb14a87 | ||
|
96f5dea24b | ||
|
3336b50119 | ||
|
36be4ef141 | ||
|
843bab522a | ||
|
af92594d6d | ||
|
f23499eaea | ||
|
d7ce998e7e | ||
|
99d07d6e5a | ||
|
028dcd722c | ||
|
349d7d31b3 | ||
|
544863e3f4 | ||
|
5a3d0edd1c | ||
|
0d701cfff0 | ||
|
01c915e1e4 | ||
|
6b69046063 | ||
|
1ee2a5bd0e | ||
|
0ba5576c7b | ||
|
bca74039ea | ||
|
2dcabf4bfc | ||
|
7c05b33b2d | ||
|
0d98e82146 | ||
|
e50e7ae2d3 | ||
|
5b9d6eba38 | ||
|
d6fe25153c | ||
|
81c1436b69 | ||
|
38cdffccc5 | ||
|
8002db54c0 | ||
|
27518fdf12 | ||
|
570295cd12 | ||
|
49f5dfd767 | ||
|
96a05f9afe | ||
|
d16ad13362 | ||
|
e0ec961fb1 | ||
|
e727641a98 | ||
|
63b82af61f | ||
|
e781e86e32 | ||
|
5b05b5c147 | ||
|
494f88c9b8 | ||
|
57f342a847 | ||
|
3c4a2b06a9 | ||
|
bc044ee31d | ||
|
b1cca65a05 | ||
|
1453c7c8c2 | ||
|
b0849c43a6 | ||
|
03c21c0a12 | ||
|
0eec7bbe19 | ||
|
30bf19f283 | ||
|
8d731e6885 | ||
|
620f3a3b88 | ||
|
4db0062caa | ||
|
a755de3bbd | ||
|
c6a69b4912 | ||
|
83ce0be4d4 | ||
|
26d3f3d91b | ||
|
221477cf17 | ||
|
32f9f628a0 | ||
|
8807070904 | ||
|
37622ea16f | ||
|
f494f694c7 | ||
|
f8874e3e1c | ||
|
c68251b6d0 | ||
|
d852e9c03d | ||
|
dc27334e9a | ||
|
e2392d8d40 | ||
|
a3b120b25e | ||
|
00f3153336 | ||
|
2812461739 | ||
|
221f066dba | ||
|
49166ac427 | ||
|
be60138936 | ||
|
9be7806bab | ||
|
ab3e4694cb | ||
|
e29eff15cd | ||
|
1fa58a8663 | ||
|
f60c80c56f | ||
|
8365dd48a1 | ||
|
7662278795 | ||
|
ed2d05ab51 | ||
|
a23ffbaeb5 | ||
|
0c037cb0e2 | ||
|
bf0aaacc67 | ||
|
44534566a3 | ||
|
99737e617d | ||
|
81e2f7d664 | ||
|
a39dcdba79 | ||
|
2fc9c6028c | ||
|
72471d9b35 | ||
|
5fae6fa434 | ||
|
3b5e6c1812 | ||
|
f33bd0f122 | ||
|
d9426b04ab | ||
|
f196b4303e | ||
|
0b4d134fe9 | ||
|
f4e8bc868f | ||
|
28dfeaa762 | ||
|
421056635d | ||
|
816684484a | ||
|
0478ecc1e6 | ||
|
18a3739974 | ||
|
ad6dd0aa7b | ||
|
59e662e6e2 | ||
|
520dd58fb0 | ||
|
3abad1519f | ||
|
f8049ca89b | ||
|
d88db59703 | ||
|
28b0682207 | ||
|
20b4148381 | ||
|
b7cd741872 | ||
|
a82d82e559 | ||
|
7f8eaee1b6 | ||
|
b0d2e900ca | ||
|
3bedba1e1e | ||
|
df54c89b04 | ||
|
cee74ef1f4 | ||
|
121c0b14e4 | ||
|
6d63c1a7dc | ||
|
5326612db4 | ||
|
8d4b1ac38d | ||
|
8afcf45878 | ||
|
f8038854d2 | ||
|
1c4c60c739 | ||
|
3e60222a1a | ||
|
4bbf5f2c30 | ||
|
18cd006d26 | ||
|
2adcd8e205 | ||
|
46fc706837 | ||
|
85a9429ead | ||
|
ed73785ecc | ||
|
2731df1672 | ||
|
2f422b53b0 | ||
|
2cee5a24f6 | ||
|
448948d26d | ||
|
b8ca9f5424 | ||
|
bef6988dcf | ||
|
3ccf42b7d5 | ||
|
0560a4da41 | ||
|
e16845727f | ||
|
c875a4ffab | ||
|
6c2116b204 | ||
|
9fab0a9f43 | ||
|
bb61749065 | ||
|
b60f6390df | ||
|
8540030b40 | ||
|
83d9121b04 | ||
|
e839483670 | ||
|
7def9588a5 | ||
|
3401d21038 | ||
|
989b59665f | ||
|
76391bb71c | ||
|
5f1ef3441c | ||
|
e6cec20420 | ||
|
c28588d162 | ||
|
3e93dc9574 | ||
|
7faee4bf60 | ||
|
27d2950cdb | ||
|
3f1a1243a5 | ||
|
af83426404 | ||
|
439ecb1a20 | ||
|
bc788b0271 | ||
|
620a4e75cd | ||
|
6adf1a53aa | ||
|
2e9748c3ef | ||
|
bb02986064 |
279 changed files with 20571 additions and 1912 deletions
25
.github/renovate.json
vendored
Normal file
25
.github/renovate.json
vendored
Normal file
|
@ -0,0 +1,25 @@
|
|||
{
|
||||
"$schema": "https://docs.renovatebot.com/renovate-schema.json",
|
||||
"commitMessagePrefix": "[dependencies]",
|
||||
"extends": [
|
||||
"config:base",
|
||||
":disableRateLimiting"
|
||||
],
|
||||
"baseBranches": [
|
||||
"dev"
|
||||
],
|
||||
"packageRules": [
|
||||
{
|
||||
"matchManagers": [
|
||||
"github-actions"
|
||||
],
|
||||
"groupName": "github-actions"
|
||||
},
|
||||
{
|
||||
"matchManagers": [
|
||||
"dockerfile"
|
||||
],
|
||||
"groupName": "Dockerfile"
|
||||
}
|
||||
]
|
||||
}
|
46
.github/workflows/debug.yml
vendored
46
.github/workflows/debug.yml
vendored
|
@ -1,46 +0,0 @@
|
|||
name: Debug build
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
- dev
|
||||
paths-ignore:
|
||||
- '**.md'
|
||||
- '.github/**'
|
||||
- '!.github/workflows/debug.yml'
|
||||
pull_request:
|
||||
branches:
|
||||
- dev
|
||||
|
||||
jobs:
|
||||
build:
|
||||
name: Debug build
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v2
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- name: Get latest go version
|
||||
id: version
|
||||
run: |
|
||||
echo ::set-output name=go_version::$(curl -s https://raw.githubusercontent.com/actions/go-versions/main/versions-manifest.json | grep -oE '"version": "[0-9]{1}.[0-9]{1,}(.[0-9]{1,})?"' | head -1 | cut -d':' -f2 | sed 's/ //g; s/"//g')
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v2
|
||||
with:
|
||||
go-version: ${{ steps.version.outputs.go_version }}
|
||||
- name: Add cache to Go proxy
|
||||
run: |
|
||||
version=`git rev-parse HEAD`
|
||||
mkdir build
|
||||
pushd build
|
||||
go mod init build
|
||||
go get -v github.com/sagernet/sing@$version
|
||||
popd
|
||||
continue-on-error: true
|
||||
- name: Build
|
||||
run: |
|
||||
make test
|
||||
make lint_install
|
||||
make lint
|
39
.github/workflows/lint.yml
vendored
Normal file
39
.github/workflows/lint.yml
vendored
Normal file
|
@ -0,0 +1,39 @@
|
|||
name: lint
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
- dev
|
||||
paths-ignore:
|
||||
- '**.md'
|
||||
- '.github/**'
|
||||
- '!.github/workflows/lint.yml'
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
- dev
|
||||
|
||||
jobs:
|
||||
build:
|
||||
name: Build
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: ^1.23
|
||||
- name: Cache go module
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: |
|
||||
~/go/pkg/mod
|
||||
key: go-${{ hashFiles('**/go.sum') }}
|
||||
- name: golangci-lint
|
||||
uses: golangci/golangci-lint-action@v6
|
||||
with:
|
||||
version: latest
|
112
.github/workflows/test.yml
vendored
Normal file
112
.github/workflows/test.yml
vendored
Normal file
|
@ -0,0 +1,112 @@
|
|||
name: test
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
- dev
|
||||
paths-ignore:
|
||||
- '**.md'
|
||||
- '.github/**'
|
||||
- '!.github/workflows/debug.yml'
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
- dev
|
||||
|
||||
jobs:
|
||||
build:
|
||||
name: Linux
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: ^1.23
|
||||
- name: Build
|
||||
run: |
|
||||
make test
|
||||
build_go120:
|
||||
name: Linux (Go 1.20)
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: ~1.20
|
||||
continue-on-error: true
|
||||
- name: Build
|
||||
run: |
|
||||
make test
|
||||
build_go121:
|
||||
name: Linux (Go 1.21)
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: ~1.21
|
||||
continue-on-error: true
|
||||
- name: Build
|
||||
run: |
|
||||
make test
|
||||
build_go122:
|
||||
name: Linux (Go 1.22)
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: ~1.22
|
||||
continue-on-error: true
|
||||
- name: Build
|
||||
run: |
|
||||
make test
|
||||
build_windows:
|
||||
name: Windows
|
||||
runs-on: windows-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: ^1.23
|
||||
continue-on-error: true
|
||||
- name: Build
|
||||
run: |
|
||||
make test
|
||||
build_darwin:
|
||||
name: macOS
|
||||
runs-on: macos-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: ^1.23
|
||||
continue-on-error: true
|
||||
- name: Build
|
||||
run: |
|
||||
make test
|
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -1,2 +1,3 @@
|
|||
/.idea/
|
||||
/vendor/
|
||||
.DS_Store
|
||||
|
|
|
@ -3,14 +3,22 @@ linters:
|
|||
enable:
|
||||
- gofumpt
|
||||
- govet
|
||||
# - gci
|
||||
- gci
|
||||
- staticcheck
|
||||
- paralleltest
|
||||
- ineffassign
|
||||
|
||||
linters-settings:
|
||||
# gci:
|
||||
# sections:
|
||||
# - standard
|
||||
# - prefix(github.com/sagernet/)
|
||||
# - default
|
||||
gci:
|
||||
custom-order: true
|
||||
sections:
|
||||
- standard
|
||||
- prefix(github.com/sagernet/)
|
||||
- default
|
||||
staticcheck:
|
||||
go: '1.19'
|
||||
checks:
|
||||
- all
|
||||
- -SA1003
|
||||
|
||||
run:
|
||||
go: "1.23"
|
16
Makefile
16
Makefile
|
@ -1,21 +1,21 @@
|
|||
fmt:
|
||||
@gofumpt -l -w .
|
||||
@gofmt -s -w .
|
||||
@gci write -s "standard,prefix(github.com/sagernet/),default" .
|
||||
@gci write --custom-order -s standard -s "prefix(github.com/sagernet/)" -s "default" .
|
||||
|
||||
fmt_install:
|
||||
go install -v mvdan.cc/gofumpt@latest
|
||||
go install -v github.com/daixiang0/gci@v0.4.0
|
||||
go install -v github.com/daixiang0/gci@latest
|
||||
|
||||
lint:
|
||||
GOOS=linux golangci-lint run ./...
|
||||
GOOS=android golangci-lint run ./...
|
||||
GOOS=windows golangci-lint run ./...
|
||||
GOOS=darwin golangci-lint run ./...
|
||||
GOOS=freebsd golangci-lint run ./...
|
||||
GOOS=linux golangci-lint run
|
||||
GOOS=android golangci-lint run
|
||||
GOOS=windows golangci-lint run
|
||||
GOOS=darwin golangci-lint run
|
||||
GOOS=freebsd golangci-lint run
|
||||
|
||||
lint_install:
|
||||
go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest
|
||||
|
||||
test:
|
||||
go test -v ./...
|
||||
go test ./...
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
# sing
|
||||
|
||||

|
||||

|
||||
|
||||
Do you hear the people sing?
|
|
@ -18,14 +18,13 @@ var _ xml.TokenReader = (*Reader)(nil)
|
|||
type Reader struct {
|
||||
reader *bytes.Reader
|
||||
stringRefs []string
|
||||
attrs []xml.Attr
|
||||
}
|
||||
|
||||
func NewReader(content []byte) (xml.TokenReader, bool) {
|
||||
if len(content) < 4 || !bytes.Equal(content[:4], ProtocolMagicVersion0) {
|
||||
return nil, false
|
||||
}
|
||||
return &Reader{reader: bytes.NewReader(content)}, true
|
||||
return &Reader{reader: bytes.NewReader(content[4:])}, true
|
||||
}
|
||||
|
||||
func (r *Reader) Token() (token xml.Token, err error) {
|
||||
|
@ -47,7 +46,7 @@ func (r *Reader) Token() (token xml.Token, err error) {
|
|||
return
|
||||
}
|
||||
var attrs []xml.Attr
|
||||
attrs, err = r.pullAttributes()
|
||||
attrs, err = r.readAttributes()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
@ -93,35 +92,41 @@ func (r *Reader) Token() (token xml.Token, err error) {
|
|||
_, err = r.readUTF()
|
||||
return
|
||||
case ATTRIBUTE:
|
||||
return nil, E.New("unexpected attribute")
|
||||
_, err = r.readAttribute()
|
||||
return
|
||||
}
|
||||
return nil, E.New("unknown token type ", tokenType, " with type ", eventType)
|
||||
}
|
||||
|
||||
func (r *Reader) pullAttributes() ([]xml.Attr, error) {
|
||||
err := r.pullAttribute()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
func (r *Reader) readAttributes() ([]xml.Attr, error) {
|
||||
var attrs []xml.Attr
|
||||
for {
|
||||
attr, err := r.readAttribute()
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
attrs = append(attrs, attr)
|
||||
}
|
||||
attrs := r.attrs
|
||||
r.attrs = nil
|
||||
return attrs, nil
|
||||
}
|
||||
|
||||
func (r *Reader) pullAttribute() error {
|
||||
func (r *Reader) readAttribute() (xml.Attr, error) {
|
||||
event, err := r.reader.ReadByte()
|
||||
if err != nil {
|
||||
return nil
|
||||
return xml.Attr{}, nil
|
||||
}
|
||||
tokenType := event & 0x0f
|
||||
eventType := event & 0xf0
|
||||
if tokenType != ATTRIBUTE {
|
||||
return r.reader.UnreadByte()
|
||||
err = r.reader.UnreadByte()
|
||||
if err != nil {
|
||||
return xml.Attr{}, nil
|
||||
}
|
||||
return xml.Attr{}, io.EOF
|
||||
}
|
||||
var name string
|
||||
name, err = r.readInternedUTF()
|
||||
name, err := r.readInternedUTF()
|
||||
if err != nil {
|
||||
return err
|
||||
return xml.Attr{}, err
|
||||
}
|
||||
var value string
|
||||
switch eventType {
|
||||
|
@ -134,74 +139,73 @@ func (r *Reader) pullAttribute() error {
|
|||
case TypeString:
|
||||
value, err = r.readUTF()
|
||||
if err != nil {
|
||||
return err
|
||||
return xml.Attr{}, err
|
||||
}
|
||||
case TypeStringInterned:
|
||||
value, err = r.readInternedUTF()
|
||||
if err != nil {
|
||||
return err
|
||||
return xml.Attr{}, err
|
||||
}
|
||||
case TypeBytesHex:
|
||||
var data []byte
|
||||
data, err = r.readBytes()
|
||||
if err != nil {
|
||||
return err
|
||||
return xml.Attr{}, err
|
||||
}
|
||||
value = hex.EncodeToString(data)
|
||||
case TypeBytesBase64:
|
||||
var data []byte
|
||||
data, err = r.readBytes()
|
||||
if err != nil {
|
||||
return err
|
||||
return xml.Attr{}, err
|
||||
}
|
||||
value = base64.StdEncoding.EncodeToString(data)
|
||||
case TypeInt:
|
||||
var data int32
|
||||
err = binary.Read(r.reader, binary.BigEndian, &data)
|
||||
if err != nil {
|
||||
return err
|
||||
return xml.Attr{}, err
|
||||
}
|
||||
value = strconv.FormatInt(int64(data), 10)
|
||||
case TypeIntHex:
|
||||
var data int32
|
||||
err = binary.Read(r.reader, binary.BigEndian, &data)
|
||||
if err != nil {
|
||||
return err
|
||||
return xml.Attr{}, err
|
||||
}
|
||||
value = "0x" + strconv.FormatInt(int64(data), 16)
|
||||
case TypeLong:
|
||||
var data int64
|
||||
err = binary.Read(r.reader, binary.BigEndian, &data)
|
||||
if err != nil {
|
||||
return err
|
||||
return xml.Attr{}, err
|
||||
}
|
||||
value = strconv.FormatInt(data, 10)
|
||||
case TypeLongHex:
|
||||
var data int64
|
||||
err = binary.Read(r.reader, binary.BigEndian, &data)
|
||||
if err != nil {
|
||||
return err
|
||||
return xml.Attr{}, err
|
||||
}
|
||||
value = "0x" + strconv.FormatInt(data, 16)
|
||||
case TypeFloat:
|
||||
var data float32
|
||||
err = binary.Read(r.reader, binary.BigEndian, &data)
|
||||
if err != nil {
|
||||
return err
|
||||
return xml.Attr{}, err
|
||||
}
|
||||
value = strconv.FormatFloat(float64(data), 'g', -1, 32)
|
||||
case TypeDouble:
|
||||
var data float64
|
||||
err = binary.Read(r.reader, binary.BigEndian, &data)
|
||||
if err != nil {
|
||||
return err
|
||||
return xml.Attr{}, err
|
||||
}
|
||||
value = strconv.FormatFloat(data, 'g', -1, 64)
|
||||
default:
|
||||
return E.New("unexpected attribute type, ", eventType)
|
||||
return xml.Attr{}, E.New("unexpected attribute type, ", eventType)
|
||||
}
|
||||
r.attrs = append(r.attrs, xml.Attr{Name: xml.Name{Local: name}, Value: value})
|
||||
return r.pullAttribute()
|
||||
return xml.Attr{Name: xml.Name{Local: name}, Value: value}, nil
|
||||
}
|
||||
|
||||
func (r *Reader) readUnsignedShort() (uint16, error) {
|
||||
|
|
46
common/atomic/typed.go
Normal file
46
common/atomic/typed.go
Normal file
|
@ -0,0 +1,46 @@
|
|||
package atomic
|
||||
|
||||
import (
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
)
|
||||
|
||||
type TypedValue[T any] struct {
|
||||
value atomic.Value
|
||||
}
|
||||
|
||||
// typedValue is a struct with determined type to resolve atomic.Value usages with interface types
|
||||
// https://github.com/golang/go/issues/22550
|
||||
//
|
||||
// The intention to have an atomic value store for errors. However, running this code panics:
|
||||
// panic: sync/atomic: store of inconsistently typed value into Value
|
||||
// This is because atomic.Value requires that the underlying concrete type be the same (which is a reasonable expectation for its implementation).
|
||||
// When going through the atomic.Value.Store method call, the fact that both these are of the error interface is lost.
|
||||
type typedValue[T any] struct {
|
||||
value T
|
||||
}
|
||||
|
||||
func (t *TypedValue[T]) Load() T {
|
||||
value := t.value.Load()
|
||||
if value == nil {
|
||||
return common.DefaultValue[T]()
|
||||
}
|
||||
return value.(typedValue[T]).value
|
||||
}
|
||||
|
||||
func (t *TypedValue[T]) Store(value T) {
|
||||
t.value.Store(typedValue[T]{value})
|
||||
}
|
||||
|
||||
func (t *TypedValue[T]) Swap(new T) T {
|
||||
old := t.value.Swap(typedValue[T]{new})
|
||||
if old == nil {
|
||||
return common.DefaultValue[T]()
|
||||
}
|
||||
return old.(typedValue[T]).value
|
||||
}
|
||||
|
||||
func (t *TypedValue[T]) CompareAndSwap(old, new T) bool {
|
||||
return t.value.CompareAndSwap(typedValue[T]{old}, typedValue[T]{new})
|
||||
}
|
19
common/atomic/types.go
Normal file
19
common/atomic/types.go
Normal file
|
@ -0,0 +1,19 @@
|
|||
//go:build go1.19
|
||||
|
||||
package atomic
|
||||
|
||||
import "sync/atomic"
|
||||
|
||||
type (
|
||||
Bool = atomic.Bool
|
||||
Int32 = atomic.Int32
|
||||
Int64 = atomic.Int64
|
||||
Uint32 = atomic.Uint32
|
||||
Uint64 = atomic.Uint64
|
||||
Uintptr = atomic.Uintptr
|
||||
Value = atomic.Value
|
||||
)
|
||||
|
||||
type Pointer[T any] struct {
|
||||
atomic.Pointer[T]
|
||||
}
|
198
common/atomic/types_compat.go
Normal file
198
common/atomic/types_compat.go
Normal file
|
@ -0,0 +1,198 @@
|
|||
// Copyright 2022 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
//go:build !go1.19
|
||||
|
||||
package atomic
|
||||
|
||||
import (
|
||||
"sync/atomic"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// A Bool is an atomic boolean value.
|
||||
// The zero value is false.
|
||||
type Bool struct {
|
||||
_ noCopy
|
||||
v uint32
|
||||
}
|
||||
|
||||
// Load atomically loads and returns the value stored in x.
|
||||
func (x *Bool) Load() bool { return atomic.LoadUint32(&x.v) != 0 }
|
||||
|
||||
// Store atomically stores val into x.
|
||||
func (x *Bool) Store(val bool) { atomic.StoreUint32(&x.v, b32(val)) }
|
||||
|
||||
// Swap atomically stores new into x and returns the previous value.
|
||||
func (x *Bool) Swap(new bool) (old bool) { return atomic.SwapUint32(&x.v, b32(new)) != 0 }
|
||||
|
||||
// CompareAndSwap executes the compare-and-swap operation for the boolean value x.
|
||||
func (x *Bool) CompareAndSwap(old, new bool) (swapped bool) {
|
||||
return atomic.CompareAndSwapUint32(&x.v, b32(old), b32(new))
|
||||
}
|
||||
|
||||
// b32 returns a uint32 0 or 1 representing b.
|
||||
func b32(b bool) uint32 {
|
||||
if b {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// A Pointer is an atomic pointer of type *T. The zero value is a nil *T.
|
||||
type Pointer[T any] struct {
|
||||
// Mention *T in a field to disallow conversion between Pointer types.
|
||||
// See go.dev/issue/56603 for more details.
|
||||
// Use *T, not T, to avoid spurious recursive type definition errors.
|
||||
_ [0]*T
|
||||
|
||||
_ noCopy
|
||||
v unsafe.Pointer
|
||||
}
|
||||
|
||||
// Load atomically loads and returns the value stored in x.
|
||||
func (x *Pointer[T]) Load() *T { return (*T)(atomic.LoadPointer(&x.v)) }
|
||||
|
||||
// Store atomically stores val into x.
|
||||
func (x *Pointer[T]) Store(val *T) { atomic.StorePointer(&x.v, unsafe.Pointer(val)) }
|
||||
|
||||
// Swap atomically stores new into x and returns the previous value.
|
||||
func (x *Pointer[T]) Swap(new *T) (old *T) {
|
||||
return (*T)(atomic.SwapPointer(&x.v, unsafe.Pointer(new)))
|
||||
}
|
||||
|
||||
// CompareAndSwap executes the compare-and-swap operation for x.
|
||||
func (x *Pointer[T]) CompareAndSwap(old, new *T) (swapped bool) {
|
||||
return atomic.CompareAndSwapPointer(&x.v, unsafe.Pointer(old), unsafe.Pointer(new))
|
||||
}
|
||||
|
||||
// An Int32 is an atomic int32. The zero value is zero.
|
||||
type Int32 struct {
|
||||
_ noCopy
|
||||
v int32
|
||||
}
|
||||
|
||||
// Load atomically loads and returns the value stored in x.
|
||||
func (x *Int32) Load() int32 { return atomic.LoadInt32(&x.v) }
|
||||
|
||||
// Store atomically stores val into x.
|
||||
func (x *Int32) Store(val int32) { atomic.StoreInt32(&x.v, val) }
|
||||
|
||||
// Swap atomically stores new into x and returns the previous value.
|
||||
func (x *Int32) Swap(new int32) (old int32) { return atomic.SwapInt32(&x.v, new) }
|
||||
|
||||
// CompareAndSwap executes the compare-and-swap operation for x.
|
||||
func (x *Int32) CompareAndSwap(old, new int32) (swapped bool) {
|
||||
return atomic.CompareAndSwapInt32(&x.v, old, new)
|
||||
}
|
||||
|
||||
// Add atomically adds delta to x and returns the new value.
|
||||
func (x *Int32) Add(delta int32) (new int32) { return atomic.AddInt32(&x.v, delta) }
|
||||
|
||||
// An Int64 is an atomic int64. The zero value is zero.
|
||||
type Int64 struct {
|
||||
_ noCopy
|
||||
v int64
|
||||
}
|
||||
|
||||
// Load atomically loads and returns the value stored in x.
|
||||
func (x *Int64) Load() int64 { return atomic.LoadInt64(&x.v) }
|
||||
|
||||
// Store atomically stores val into x.
|
||||
func (x *Int64) Store(val int64) { atomic.StoreInt64(&x.v, val) }
|
||||
|
||||
// Swap atomically stores new into x and returns the previous value.
|
||||
func (x *Int64) Swap(new int64) (old int64) { return atomic.SwapInt64(&x.v, new) }
|
||||
|
||||
// CompareAndSwap executes the compare-and-swap operation for x.
|
||||
func (x *Int64) CompareAndSwap(old, new int64) (swapped bool) {
|
||||
return atomic.CompareAndSwapInt64(&x.v, old, new)
|
||||
}
|
||||
|
||||
// Add atomically adds delta to x and returns the new value.
|
||||
func (x *Int64) Add(delta int64) (new int64) { return atomic.AddInt64(&x.v, delta) }
|
||||
|
||||
// An Uint32 is an atomic uint32. The zero value is zero.
|
||||
type Uint32 struct {
|
||||
_ noCopy
|
||||
v uint32
|
||||
}
|
||||
|
||||
// Load atomically loads and returns the value stored in x.
|
||||
func (x *Uint32) Load() uint32 { return atomic.LoadUint32(&x.v) }
|
||||
|
||||
// Store atomically stores val into x.
|
||||
func (x *Uint32) Store(val uint32) { atomic.StoreUint32(&x.v, val) }
|
||||
|
||||
// Swap atomically stores new into x and returns the previous value.
|
||||
func (x *Uint32) Swap(new uint32) (old uint32) { return atomic.SwapUint32(&x.v, new) }
|
||||
|
||||
// CompareAndSwap executes the compare-and-swap operation for x.
|
||||
func (x *Uint32) CompareAndSwap(old, new uint32) (swapped bool) {
|
||||
return atomic.CompareAndSwapUint32(&x.v, old, new)
|
||||
}
|
||||
|
||||
// Add atomically adds delta to x and returns the new value.
|
||||
func (x *Uint32) Add(delta uint32) (new uint32) { return atomic.AddUint32(&x.v, delta) }
|
||||
|
||||
// An Uint64 is an atomic uint64. The zero value is zero.
|
||||
type Uint64 struct {
|
||||
_ noCopy
|
||||
v uint64
|
||||
}
|
||||
|
||||
// Load atomically loads and returns the value stored in x.
|
||||
func (x *Uint64) Load() uint64 { return atomic.LoadUint64(&x.v) }
|
||||
|
||||
// Store atomically stores val into x.
|
||||
func (x *Uint64) Store(val uint64) { atomic.StoreUint64(&x.v, val) }
|
||||
|
||||
// Swap atomically stores new into x and returns the previous value.
|
||||
func (x *Uint64) Swap(new uint64) (old uint64) { return atomic.SwapUint64(&x.v, new) }
|
||||
|
||||
// CompareAndSwap executes the compare-and-swap operation for x.
|
||||
func (x *Uint64) CompareAndSwap(old, new uint64) (swapped bool) {
|
||||
return atomic.CompareAndSwapUint64(&x.v, old, new)
|
||||
}
|
||||
|
||||
// Add atomically adds delta to x and returns the new value.
|
||||
func (x *Uint64) Add(delta uint64) (new uint64) { return atomic.AddUint64(&x.v, delta) }
|
||||
|
||||
// An Uintptr is an atomic uintptr. The zero value is zero.
|
||||
type Uintptr struct {
|
||||
_ noCopy
|
||||
v uintptr
|
||||
}
|
||||
|
||||
// Load atomically loads and returns the value stored in x.
|
||||
func (x *Uintptr) Load() uintptr { return atomic.LoadUintptr(&x.v) }
|
||||
|
||||
// Store atomically stores val into x.
|
||||
func (x *Uintptr) Store(val uintptr) { atomic.StoreUintptr(&x.v, val) }
|
||||
|
||||
// Swap atomically stores new into x and returns the previous value.
|
||||
func (x *Uintptr) Swap(new uintptr) (old uintptr) { return atomic.SwapUintptr(&x.v, new) }
|
||||
|
||||
// CompareAndSwap executes the compare-and-swap operation for x.
|
||||
func (x *Uintptr) CompareAndSwap(old, new uintptr) (swapped bool) {
|
||||
return atomic.CompareAndSwapUintptr(&x.v, old, new)
|
||||
}
|
||||
|
||||
// Add atomically adds delta to x and returns the new value.
|
||||
func (x *Uintptr) Add(delta uintptr) (new uintptr) { return atomic.AddUintptr(&x.v, delta) }
|
||||
|
||||
// noCopy may be added to structs which must not be copied
|
||||
// after the first use.
|
||||
//
|
||||
// See https://golang.org/issues/8005#issuecomment-190753527
|
||||
// for details.
|
||||
//
|
||||
// Note that it must not be embedded, due to the Lock and Unlock methods.
|
||||
type noCopy struct{}
|
||||
|
||||
// Lock is a no-op used by -copylocks checker from `go vet`.
|
||||
func (*noCopy) Lock() {}
|
||||
func (*noCopy) Unlock() {}
|
||||
|
||||
type Value = atomic.Value
|
|
@ -1,38 +1,30 @@
|
|||
package auth
|
||||
|
||||
type Authenticator interface {
|
||||
Verify(user string, pass string) bool
|
||||
Users() []string
|
||||
}
|
||||
import "github.com/sagernet/sing/common"
|
||||
|
||||
type User struct {
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
Username string
|
||||
Password string
|
||||
}
|
||||
|
||||
type inMemoryAuthenticator struct {
|
||||
storage map[string]string
|
||||
usernames []string
|
||||
type Authenticator struct {
|
||||
userMap map[string][]string
|
||||
}
|
||||
|
||||
func (au *inMemoryAuthenticator) Verify(username string, password string) bool {
|
||||
realPass, ok := au.storage[username]
|
||||
return ok && realPass == password
|
||||
}
|
||||
|
||||
func (au *inMemoryAuthenticator) Users() []string { return au.usernames }
|
||||
|
||||
func NewAuthenticator(users []User) Authenticator {
|
||||
func NewAuthenticator(users []User) *Authenticator {
|
||||
if len(users) == 0 {
|
||||
return nil
|
||||
}
|
||||
au := &inMemoryAuthenticator{
|
||||
storage: make(map[string]string),
|
||||
usernames: make([]string, 0, len(users)),
|
||||
au := &Authenticator{
|
||||
userMap: make(map[string][]string),
|
||||
}
|
||||
for _, user := range users {
|
||||
au.storage[user.Username] = user.Password
|
||||
au.usernames = append(au.usernames, user.Username)
|
||||
au.userMap[user.Username] = append(au.userMap[user.Username], user.Password)
|
||||
}
|
||||
return au
|
||||
}
|
||||
|
||||
func (au *Authenticator) Verify(username string, password string) bool {
|
||||
passwordList, ok := au.userMap[username]
|
||||
return ok && common.Contains(passwordList, password)
|
||||
}
|
||||
|
|
63
common/baderror/baderror.go
Normal file
63
common/baderror/baderror.go
Normal file
|
@ -0,0 +1,63 @@
|
|||
package baderror
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func Contains(err error, msgList ...string) bool {
|
||||
for _, msg := range msgList {
|
||||
if strings.Contains(err.Error(), msg) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func WrapH2(err error) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
if errors.Is(err, io.ErrUnexpectedEOF) {
|
||||
return io.EOF
|
||||
}
|
||||
if Contains(err, "client disconnected", "body closed by handler", "response body closed", "; CANCEL") {
|
||||
return net.ErrClosed
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func WrapGRPC(err error) error {
|
||||
// grpc uses stupid internal error types
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
if Contains(err, "EOF") {
|
||||
return io.EOF
|
||||
}
|
||||
if Contains(err, "Canceled") {
|
||||
return context.Canceled
|
||||
}
|
||||
if Contains(err,
|
||||
"the client connection is closing",
|
||||
"server closed the stream without sending trailers") {
|
||||
return net.ErrClosed
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func WrapQUIC(err error) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
if Contains(err,
|
||||
"canceled by remote with error code 0",
|
||||
"canceled by local with error code 0",
|
||||
) {
|
||||
return net.ErrClosed
|
||||
}
|
||||
return err
|
||||
}
|
|
@ -3,6 +3,9 @@ package batch
|
|||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
)
|
||||
|
||||
type Option[T any] func(b *Batch[T])
|
||||
|
@ -17,6 +20,10 @@ type Error struct {
|
|||
Err error
|
||||
}
|
||||
|
||||
func (e *Error) Error() string {
|
||||
return E.Cause(e.Err, e.Key).Error()
|
||||
}
|
||||
|
||||
func WithConcurrencyNum[T any](n int) Option[T] {
|
||||
return func(b *Batch[T]) {
|
||||
q := make(chan struct{}, n)
|
||||
|
@ -35,7 +42,7 @@ type Batch[T any] struct {
|
|||
mux sync.Mutex
|
||||
err *Error
|
||||
once sync.Once
|
||||
cancel func()
|
||||
cancel common.ContextCancelCauseFunc
|
||||
}
|
||||
|
||||
func (b *Batch[T]) Go(key string, fn func() (T, error)) {
|
||||
|
@ -54,7 +61,7 @@ func (b *Batch[T]) Go(key string, fn func() (T, error)) {
|
|||
b.once.Do(func() {
|
||||
b.err = &Error{key, err}
|
||||
if b.cancel != nil {
|
||||
b.cancel()
|
||||
b.cancel(b.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@ -69,7 +76,7 @@ func (b *Batch[T]) Go(key string, fn func() (T, error)) {
|
|||
func (b *Batch[T]) Wait() *Error {
|
||||
b.wg.Wait()
|
||||
if b.cancel != nil {
|
||||
b.cancel()
|
||||
b.cancel(nil)
|
||||
}
|
||||
return b.err
|
||||
}
|
||||
|
@ -90,7 +97,7 @@ func (b *Batch[T]) Result() map[string]Result[T] {
|
|||
}
|
||||
|
||||
func New[T any](ctx context.Context, opts ...Option[T]) (*Batch[T], context.Context) {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
ctx, cancel := common.ContextWithCancelCause(ctx)
|
||||
|
||||
b := &Batch[T]{
|
||||
result: map[string]Result[T]{},
|
||||
|
|
3
common/binary/README.md
Normal file
3
common/binary/README.md
Normal file
|
@ -0,0 +1,3 @@
|
|||
# binary
|
||||
|
||||
mod from go 1.22.3
|
817
common/binary/binary.go
Normal file
817
common/binary/binary.go
Normal file
|
@ -0,0 +1,817 @@
|
|||
// Copyright 2009 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package binary implements simple translation between numbers and byte
|
||||
// sequences and encoding and decoding of varints.
|
||||
//
|
||||
// Numbers are translated by reading and writing fixed-size values.
|
||||
// A fixed-size value is either a fixed-size arithmetic
|
||||
// type (bool, int8, uint8, int16, float32, complex64, ...)
|
||||
// or an array or struct containing only fixed-size values.
|
||||
//
|
||||
// The varint functions encode and decode single integer values using
|
||||
// a variable-length encoding; smaller values require fewer bytes.
|
||||
// For a specification, see
|
||||
// https://developers.google.com/protocol-buffers/docs/encoding.
|
||||
//
|
||||
// This package favors simplicity over efficiency. Clients that require
|
||||
// high-performance serialization, especially for large data structures,
|
||||
// should look at more advanced solutions such as the [encoding/gob]
|
||||
// package or [google.golang.org/protobuf] for protocol buffers.
|
||||
package binary
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"math"
|
||||
"reflect"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// A ByteOrder specifies how to convert byte slices into
|
||||
// 16-, 32-, or 64-bit unsigned integers.
|
||||
//
|
||||
// It is implemented by [LittleEndian], [BigEndian], and [NativeEndian].
|
||||
type ByteOrder interface {
|
||||
Uint16([]byte) uint16
|
||||
Uint32([]byte) uint32
|
||||
Uint64([]byte) uint64
|
||||
PutUint16([]byte, uint16)
|
||||
PutUint32([]byte, uint32)
|
||||
PutUint64([]byte, uint64)
|
||||
String() string
|
||||
}
|
||||
|
||||
// AppendByteOrder specifies how to append 16-, 32-, or 64-bit unsigned integers
|
||||
// into a byte slice.
|
||||
//
|
||||
// It is implemented by [LittleEndian], [BigEndian], and [NativeEndian].
|
||||
type AppendByteOrder interface {
|
||||
AppendUint16([]byte, uint16) []byte
|
||||
AppendUint32([]byte, uint32) []byte
|
||||
AppendUint64([]byte, uint64) []byte
|
||||
String() string
|
||||
}
|
||||
|
||||
// LittleEndian is the little-endian implementation of [ByteOrder] and [AppendByteOrder].
|
||||
var LittleEndian littleEndian
|
||||
|
||||
// BigEndian is the big-endian implementation of [ByteOrder] and [AppendByteOrder].
|
||||
var BigEndian bigEndian
|
||||
|
||||
type littleEndian struct{}
|
||||
|
||||
func (littleEndian) Uint16(b []byte) uint16 {
|
||||
_ = b[1] // bounds check hint to compiler; see golang.org/issue/14808
|
||||
return uint16(b[0]) | uint16(b[1])<<8
|
||||
}
|
||||
|
||||
func (littleEndian) PutUint16(b []byte, v uint16) {
|
||||
_ = b[1] // early bounds check to guarantee safety of writes below
|
||||
b[0] = byte(v)
|
||||
b[1] = byte(v >> 8)
|
||||
}
|
||||
|
||||
func (littleEndian) AppendUint16(b []byte, v uint16) []byte {
|
||||
return append(b,
|
||||
byte(v),
|
||||
byte(v>>8),
|
||||
)
|
||||
}
|
||||
|
||||
func (littleEndian) Uint32(b []byte) uint32 {
|
||||
_ = b[3] // bounds check hint to compiler; see golang.org/issue/14808
|
||||
return uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24
|
||||
}
|
||||
|
||||
func (littleEndian) PutUint32(b []byte, v uint32) {
|
||||
_ = b[3] // early bounds check to guarantee safety of writes below
|
||||
b[0] = byte(v)
|
||||
b[1] = byte(v >> 8)
|
||||
b[2] = byte(v >> 16)
|
||||
b[3] = byte(v >> 24)
|
||||
}
|
||||
|
||||
func (littleEndian) AppendUint32(b []byte, v uint32) []byte {
|
||||
return append(b,
|
||||
byte(v),
|
||||
byte(v>>8),
|
||||
byte(v>>16),
|
||||
byte(v>>24),
|
||||
)
|
||||
}
|
||||
|
||||
func (littleEndian) Uint64(b []byte) uint64 {
|
||||
_ = b[7] // bounds check hint to compiler; see golang.org/issue/14808
|
||||
return uint64(b[0]) | uint64(b[1])<<8 | uint64(b[2])<<16 | uint64(b[3])<<24 |
|
||||
uint64(b[4])<<32 | uint64(b[5])<<40 | uint64(b[6])<<48 | uint64(b[7])<<56
|
||||
}
|
||||
|
||||
func (littleEndian) PutUint64(b []byte, v uint64) {
|
||||
_ = b[7] // early bounds check to guarantee safety of writes below
|
||||
b[0] = byte(v)
|
||||
b[1] = byte(v >> 8)
|
||||
b[2] = byte(v >> 16)
|
||||
b[3] = byte(v >> 24)
|
||||
b[4] = byte(v >> 32)
|
||||
b[5] = byte(v >> 40)
|
||||
b[6] = byte(v >> 48)
|
||||
b[7] = byte(v >> 56)
|
||||
}
|
||||
|
||||
func (littleEndian) AppendUint64(b []byte, v uint64) []byte {
|
||||
return append(b,
|
||||
byte(v),
|
||||
byte(v>>8),
|
||||
byte(v>>16),
|
||||
byte(v>>24),
|
||||
byte(v>>32),
|
||||
byte(v>>40),
|
||||
byte(v>>48),
|
||||
byte(v>>56),
|
||||
)
|
||||
}
|
||||
|
||||
func (littleEndian) String() string { return "LittleEndian" }
|
||||
|
||||
func (littleEndian) GoString() string { return "binary.LittleEndian" }
|
||||
|
||||
type bigEndian struct{}
|
||||
|
||||
func (bigEndian) Uint16(b []byte) uint16 {
|
||||
_ = b[1] // bounds check hint to compiler; see golang.org/issue/14808
|
||||
return uint16(b[1]) | uint16(b[0])<<8
|
||||
}
|
||||
|
||||
func (bigEndian) PutUint16(b []byte, v uint16) {
|
||||
_ = b[1] // early bounds check to guarantee safety of writes below
|
||||
b[0] = byte(v >> 8)
|
||||
b[1] = byte(v)
|
||||
}
|
||||
|
||||
func (bigEndian) AppendUint16(b []byte, v uint16) []byte {
|
||||
return append(b,
|
||||
byte(v>>8),
|
||||
byte(v),
|
||||
)
|
||||
}
|
||||
|
||||
func (bigEndian) Uint32(b []byte) uint32 {
|
||||
_ = b[3] // bounds check hint to compiler; see golang.org/issue/14808
|
||||
return uint32(b[3]) | uint32(b[2])<<8 | uint32(b[1])<<16 | uint32(b[0])<<24
|
||||
}
|
||||
|
||||
func (bigEndian) PutUint32(b []byte, v uint32) {
|
||||
_ = b[3] // early bounds check to guarantee safety of writes below
|
||||
b[0] = byte(v >> 24)
|
||||
b[1] = byte(v >> 16)
|
||||
b[2] = byte(v >> 8)
|
||||
b[3] = byte(v)
|
||||
}
|
||||
|
||||
func (bigEndian) AppendUint32(b []byte, v uint32) []byte {
|
||||
return append(b,
|
||||
byte(v>>24),
|
||||
byte(v>>16),
|
||||
byte(v>>8),
|
||||
byte(v),
|
||||
)
|
||||
}
|
||||
|
||||
func (bigEndian) Uint64(b []byte) uint64 {
|
||||
_ = b[7] // bounds check hint to compiler; see golang.org/issue/14808
|
||||
return uint64(b[7]) | uint64(b[6])<<8 | uint64(b[5])<<16 | uint64(b[4])<<24 |
|
||||
uint64(b[3])<<32 | uint64(b[2])<<40 | uint64(b[1])<<48 | uint64(b[0])<<56
|
||||
}
|
||||
|
||||
func (bigEndian) PutUint64(b []byte, v uint64) {
|
||||
_ = b[7] // early bounds check to guarantee safety of writes below
|
||||
b[0] = byte(v >> 56)
|
||||
b[1] = byte(v >> 48)
|
||||
b[2] = byte(v >> 40)
|
||||
b[3] = byte(v >> 32)
|
||||
b[4] = byte(v >> 24)
|
||||
b[5] = byte(v >> 16)
|
||||
b[6] = byte(v >> 8)
|
||||
b[7] = byte(v)
|
||||
}
|
||||
|
||||
func (bigEndian) AppendUint64(b []byte, v uint64) []byte {
|
||||
return append(b,
|
||||
byte(v>>56),
|
||||
byte(v>>48),
|
||||
byte(v>>40),
|
||||
byte(v>>32),
|
||||
byte(v>>24),
|
||||
byte(v>>16),
|
||||
byte(v>>8),
|
||||
byte(v),
|
||||
)
|
||||
}
|
||||
|
||||
func (bigEndian) String() string { return "BigEndian" }
|
||||
|
||||
func (bigEndian) GoString() string { return "binary.BigEndian" }
|
||||
|
||||
func (nativeEndian) String() string { return "NativeEndian" }
|
||||
|
||||
func (nativeEndian) GoString() string { return "binary.NativeEndian" }
|
||||
|
||||
// Read reads structured binary data from r into data.
|
||||
// Data must be a pointer to a fixed-size value or a slice
|
||||
// of fixed-size values.
|
||||
// Bytes read from r are decoded using the specified byte order
|
||||
// and written to successive fields of the data.
|
||||
// When decoding boolean values, a zero byte is decoded as false, and
|
||||
// any other non-zero byte is decoded as true.
|
||||
// When reading into structs, the field data for fields with
|
||||
// blank (_) field names is skipped; i.e., blank field names
|
||||
// may be used for padding.
|
||||
// When reading into a struct, all non-blank fields must be exported
|
||||
// or Read may panic.
|
||||
//
|
||||
// The error is [io.EOF] only if no bytes were read.
|
||||
// If an [io.EOF] happens after reading some but not all the bytes,
|
||||
// Read returns [io.ErrUnexpectedEOF].
|
||||
func Read(r io.Reader, order ByteOrder, data any) error {
|
||||
// Fast path for basic types and slices.
|
||||
if n := intDataSize(data); n != 0 {
|
||||
bs := make([]byte, n)
|
||||
if _, err := io.ReadFull(r, bs); err != nil {
|
||||
return err
|
||||
}
|
||||
switch data := data.(type) {
|
||||
case *bool:
|
||||
*data = bs[0] != 0
|
||||
case *int8:
|
||||
*data = int8(bs[0])
|
||||
case *uint8:
|
||||
*data = bs[0]
|
||||
case *int16:
|
||||
*data = int16(order.Uint16(bs))
|
||||
case *uint16:
|
||||
*data = order.Uint16(bs)
|
||||
case *int32:
|
||||
*data = int32(order.Uint32(bs))
|
||||
case *uint32:
|
||||
*data = order.Uint32(bs)
|
||||
case *int64:
|
||||
*data = int64(order.Uint64(bs))
|
||||
case *uint64:
|
||||
*data = order.Uint64(bs)
|
||||
case *float32:
|
||||
*data = math.Float32frombits(order.Uint32(bs))
|
||||
case *float64:
|
||||
*data = math.Float64frombits(order.Uint64(bs))
|
||||
case []bool:
|
||||
for i, x := range bs { // Easier to loop over the input for 8-bit values.
|
||||
data[i] = x != 0
|
||||
}
|
||||
case []int8:
|
||||
for i, x := range bs {
|
||||
data[i] = int8(x)
|
||||
}
|
||||
case []uint8:
|
||||
copy(data, bs)
|
||||
case []int16:
|
||||
for i := range data {
|
||||
data[i] = int16(order.Uint16(bs[2*i:]))
|
||||
}
|
||||
case []uint16:
|
||||
for i := range data {
|
||||
data[i] = order.Uint16(bs[2*i:])
|
||||
}
|
||||
case []int32:
|
||||
for i := range data {
|
||||
data[i] = int32(order.Uint32(bs[4*i:]))
|
||||
}
|
||||
case []uint32:
|
||||
for i := range data {
|
||||
data[i] = order.Uint32(bs[4*i:])
|
||||
}
|
||||
case []int64:
|
||||
for i := range data {
|
||||
data[i] = int64(order.Uint64(bs[8*i:]))
|
||||
}
|
||||
case []uint64:
|
||||
for i := range data {
|
||||
data[i] = order.Uint64(bs[8*i:])
|
||||
}
|
||||
case []float32:
|
||||
for i := range data {
|
||||
data[i] = math.Float32frombits(order.Uint32(bs[4*i:]))
|
||||
}
|
||||
case []float64:
|
||||
for i := range data {
|
||||
data[i] = math.Float64frombits(order.Uint64(bs[8*i:]))
|
||||
}
|
||||
default:
|
||||
n = 0 // fast path doesn't apply
|
||||
}
|
||||
if n != 0 {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to reflect-based decoding.
|
||||
v := reflect.ValueOf(data)
|
||||
size := -1
|
||||
switch v.Kind() {
|
||||
case reflect.Pointer:
|
||||
v = v.Elem()
|
||||
size = dataSize(v)
|
||||
case reflect.Slice:
|
||||
size = dataSize(v)
|
||||
}
|
||||
if size < 0 {
|
||||
return errors.New("binary.Read: invalid type " + reflect.TypeOf(data).String())
|
||||
}
|
||||
d := &decoder{order: order, buf: make([]byte, size)}
|
||||
if _, err := io.ReadFull(r, d.buf); err != nil {
|
||||
return err
|
||||
}
|
||||
d.value(v)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Write writes the binary representation of data into w.
|
||||
// Data must be a fixed-size value or a slice of fixed-size
|
||||
// values, or a pointer to such data.
|
||||
// Boolean values encode as one byte: 1 for true, and 0 for false.
|
||||
// Bytes written to w are encoded using the specified byte order
|
||||
// and read from successive fields of the data.
|
||||
// When writing structs, zero values are written for fields
|
||||
// with blank (_) field names.
|
||||
func Write(w io.Writer, order ByteOrder, data any) error {
|
||||
// Fast path for basic types and slices.
|
||||
if n := intDataSize(data); n != 0 {
|
||||
bs := make([]byte, n)
|
||||
switch v := data.(type) {
|
||||
case *bool:
|
||||
if *v {
|
||||
bs[0] = 1
|
||||
} else {
|
||||
bs[0] = 0
|
||||
}
|
||||
case bool:
|
||||
if v {
|
||||
bs[0] = 1
|
||||
} else {
|
||||
bs[0] = 0
|
||||
}
|
||||
case []bool:
|
||||
for i, x := range v {
|
||||
if x {
|
||||
bs[i] = 1
|
||||
} else {
|
||||
bs[i] = 0
|
||||
}
|
||||
}
|
||||
case *int8:
|
||||
bs[0] = byte(*v)
|
||||
case int8:
|
||||
bs[0] = byte(v)
|
||||
case []int8:
|
||||
for i, x := range v {
|
||||
bs[i] = byte(x)
|
||||
}
|
||||
case *uint8:
|
||||
bs[0] = *v
|
||||
case uint8:
|
||||
bs[0] = v
|
||||
case []uint8:
|
||||
bs = v
|
||||
case *int16:
|
||||
order.PutUint16(bs, uint16(*v))
|
||||
case int16:
|
||||
order.PutUint16(bs, uint16(v))
|
||||
case []int16:
|
||||
for i, x := range v {
|
||||
order.PutUint16(bs[2*i:], uint16(x))
|
||||
}
|
||||
case *uint16:
|
||||
order.PutUint16(bs, *v)
|
||||
case uint16:
|
||||
order.PutUint16(bs, v)
|
||||
case []uint16:
|
||||
for i, x := range v {
|
||||
order.PutUint16(bs[2*i:], x)
|
||||
}
|
||||
case *int32:
|
||||
order.PutUint32(bs, uint32(*v))
|
||||
case int32:
|
||||
order.PutUint32(bs, uint32(v))
|
||||
case []int32:
|
||||
for i, x := range v {
|
||||
order.PutUint32(bs[4*i:], uint32(x))
|
||||
}
|
||||
case *uint32:
|
||||
order.PutUint32(bs, *v)
|
||||
case uint32:
|
||||
order.PutUint32(bs, v)
|
||||
case []uint32:
|
||||
for i, x := range v {
|
||||
order.PutUint32(bs[4*i:], x)
|
||||
}
|
||||
case *int64:
|
||||
order.PutUint64(bs, uint64(*v))
|
||||
case int64:
|
||||
order.PutUint64(bs, uint64(v))
|
||||
case []int64:
|
||||
for i, x := range v {
|
||||
order.PutUint64(bs[8*i:], uint64(x))
|
||||
}
|
||||
case *uint64:
|
||||
order.PutUint64(bs, *v)
|
||||
case uint64:
|
||||
order.PutUint64(bs, v)
|
||||
case []uint64:
|
||||
for i, x := range v {
|
||||
order.PutUint64(bs[8*i:], x)
|
||||
}
|
||||
case *float32:
|
||||
order.PutUint32(bs, math.Float32bits(*v))
|
||||
case float32:
|
||||
order.PutUint32(bs, math.Float32bits(v))
|
||||
case []float32:
|
||||
for i, x := range v {
|
||||
order.PutUint32(bs[4*i:], math.Float32bits(x))
|
||||
}
|
||||
case *float64:
|
||||
order.PutUint64(bs, math.Float64bits(*v))
|
||||
case float64:
|
||||
order.PutUint64(bs, math.Float64bits(v))
|
||||
case []float64:
|
||||
for i, x := range v {
|
||||
order.PutUint64(bs[8*i:], math.Float64bits(x))
|
||||
}
|
||||
}
|
||||
_, err := w.Write(bs)
|
||||
return err
|
||||
}
|
||||
|
||||
// Fallback to reflect-based encoding.
|
||||
v := reflect.Indirect(reflect.ValueOf(data))
|
||||
size := dataSize(v)
|
||||
if size < 0 {
|
||||
return errors.New("binary.Write: some values are not fixed-sized in type " + reflect.TypeOf(data).String())
|
||||
}
|
||||
buf := make([]byte, size)
|
||||
e := &encoder{order: order, buf: buf}
|
||||
e.value(v)
|
||||
_, err := w.Write(buf)
|
||||
return err
|
||||
}
|
||||
|
||||
// Size returns how many bytes [Write] would generate to encode the value v, which
|
||||
// must be a fixed-size value or a slice of fixed-size values, or a pointer to such data.
|
||||
// If v is neither of these, Size returns -1.
|
||||
func Size(v any) int {
|
||||
return dataSize(reflect.Indirect(reflect.ValueOf(v)))
|
||||
}
|
||||
|
||||
var structSize sync.Map // map[reflect.Type]int
|
||||
|
||||
// dataSize returns the number of bytes the actual data represented by v occupies in memory.
|
||||
// For compound structures, it sums the sizes of the elements. Thus, for instance, for a slice
|
||||
// it returns the length of the slice times the element size and does not count the memory
|
||||
// occupied by the header. If the type of v is not acceptable, dataSize returns -1.
|
||||
func dataSize(v reflect.Value) int {
|
||||
switch v.Kind() {
|
||||
case reflect.Slice:
|
||||
if s := sizeof(v.Type().Elem()); s >= 0 {
|
||||
return s * v.Len()
|
||||
}
|
||||
|
||||
case reflect.Struct:
|
||||
t := v.Type()
|
||||
if size, ok := structSize.Load(t); ok {
|
||||
return size.(int)
|
||||
}
|
||||
size := sizeof(t)
|
||||
structSize.Store(t, size)
|
||||
return size
|
||||
|
||||
default:
|
||||
if v.IsValid() {
|
||||
return sizeof(v.Type())
|
||||
}
|
||||
}
|
||||
|
||||
return -1
|
||||
}
|
||||
|
||||
// sizeof returns the size >= 0 of variables for the given type or -1 if the type is not acceptable.
|
||||
func sizeof(t reflect.Type) int {
|
||||
switch t.Kind() {
|
||||
case reflect.Array:
|
||||
if s := sizeof(t.Elem()); s >= 0 {
|
||||
return s * t.Len()
|
||||
}
|
||||
|
||||
case reflect.Struct:
|
||||
sum := 0
|
||||
for i, n := 0, t.NumField(); i < n; i++ {
|
||||
s := sizeof(t.Field(i).Type)
|
||||
if s < 0 {
|
||||
return -1
|
||||
}
|
||||
sum += s
|
||||
}
|
||||
return sum
|
||||
|
||||
case reflect.Bool,
|
||||
reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
|
||||
reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
|
||||
reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128:
|
||||
return int(t.Size())
|
||||
}
|
||||
|
||||
return -1
|
||||
}
|
||||
|
||||
type coder struct {
|
||||
order ByteOrder
|
||||
buf []byte
|
||||
offset int
|
||||
}
|
||||
|
||||
type (
|
||||
decoder coder
|
||||
encoder coder
|
||||
)
|
||||
|
||||
func (d *decoder) bool() bool {
|
||||
x := d.buf[d.offset]
|
||||
d.offset++
|
||||
return x != 0
|
||||
}
|
||||
|
||||
func (e *encoder) bool(x bool) {
|
||||
if x {
|
||||
e.buf[e.offset] = 1
|
||||
} else {
|
||||
e.buf[e.offset] = 0
|
||||
}
|
||||
e.offset++
|
||||
}
|
||||
|
||||
func (d *decoder) uint8() uint8 {
|
||||
x := d.buf[d.offset]
|
||||
d.offset++
|
||||
return x
|
||||
}
|
||||
|
||||
func (e *encoder) uint8(x uint8) {
|
||||
e.buf[e.offset] = x
|
||||
e.offset++
|
||||
}
|
||||
|
||||
func (d *decoder) uint16() uint16 {
|
||||
x := d.order.Uint16(d.buf[d.offset : d.offset+2])
|
||||
d.offset += 2
|
||||
return x
|
||||
}
|
||||
|
||||
func (e *encoder) uint16(x uint16) {
|
||||
e.order.PutUint16(e.buf[e.offset:e.offset+2], x)
|
||||
e.offset += 2
|
||||
}
|
||||
|
||||
func (d *decoder) uint32() uint32 {
|
||||
x := d.order.Uint32(d.buf[d.offset : d.offset+4])
|
||||
d.offset += 4
|
||||
return x
|
||||
}
|
||||
|
||||
func (e *encoder) uint32(x uint32) {
|
||||
e.order.PutUint32(e.buf[e.offset:e.offset+4], x)
|
||||
e.offset += 4
|
||||
}
|
||||
|
||||
func (d *decoder) uint64() uint64 {
|
||||
x := d.order.Uint64(d.buf[d.offset : d.offset+8])
|
||||
d.offset += 8
|
||||
return x
|
||||
}
|
||||
|
||||
func (e *encoder) uint64(x uint64) {
|
||||
e.order.PutUint64(e.buf[e.offset:e.offset+8], x)
|
||||
e.offset += 8
|
||||
}
|
||||
|
||||
func (d *decoder) int8() int8 { return int8(d.uint8()) }
|
||||
|
||||
func (e *encoder) int8(x int8) { e.uint8(uint8(x)) }
|
||||
|
||||
func (d *decoder) int16() int16 { return int16(d.uint16()) }
|
||||
|
||||
func (e *encoder) int16(x int16) { e.uint16(uint16(x)) }
|
||||
|
||||
func (d *decoder) int32() int32 { return int32(d.uint32()) }
|
||||
|
||||
func (e *encoder) int32(x int32) { e.uint32(uint32(x)) }
|
||||
|
||||
func (d *decoder) int64() int64 { return int64(d.uint64()) }
|
||||
|
||||
func (e *encoder) int64(x int64) { e.uint64(uint64(x)) }
|
||||
|
||||
func (d *decoder) value(v reflect.Value) {
|
||||
switch v.Kind() {
|
||||
case reflect.Array:
|
||||
l := v.Len()
|
||||
for i := 0; i < l; i++ {
|
||||
d.value(v.Index(i))
|
||||
}
|
||||
|
||||
case reflect.Struct:
|
||||
t := v.Type()
|
||||
l := v.NumField()
|
||||
for i := 0; i < l; i++ {
|
||||
// Note: Calling v.CanSet() below is an optimization.
|
||||
// It would be sufficient to check the field name,
|
||||
// but creating the StructField info for each field is
|
||||
// costly (run "go test -bench=ReadStruct" and compare
|
||||
// results when making changes to this code).
|
||||
if v := v.Field(i); v.CanSet() || t.Field(i).Name != "_" {
|
||||
d.value(v)
|
||||
} else {
|
||||
d.skip(v)
|
||||
}
|
||||
}
|
||||
|
||||
case reflect.Slice:
|
||||
l := v.Len()
|
||||
for i := 0; i < l; i++ {
|
||||
d.value(v.Index(i))
|
||||
}
|
||||
|
||||
case reflect.Bool:
|
||||
v.SetBool(d.bool())
|
||||
|
||||
case reflect.Int8:
|
||||
v.SetInt(int64(d.int8()))
|
||||
case reflect.Int16:
|
||||
v.SetInt(int64(d.int16()))
|
||||
case reflect.Int32:
|
||||
v.SetInt(int64(d.int32()))
|
||||
case reflect.Int64:
|
||||
v.SetInt(d.int64())
|
||||
|
||||
case reflect.Uint8:
|
||||
v.SetUint(uint64(d.uint8()))
|
||||
case reflect.Uint16:
|
||||
v.SetUint(uint64(d.uint16()))
|
||||
case reflect.Uint32:
|
||||
v.SetUint(uint64(d.uint32()))
|
||||
case reflect.Uint64:
|
||||
v.SetUint(d.uint64())
|
||||
|
||||
case reflect.Float32:
|
||||
v.SetFloat(float64(math.Float32frombits(d.uint32())))
|
||||
case reflect.Float64:
|
||||
v.SetFloat(math.Float64frombits(d.uint64()))
|
||||
|
||||
case reflect.Complex64:
|
||||
v.SetComplex(complex(
|
||||
float64(math.Float32frombits(d.uint32())),
|
||||
float64(math.Float32frombits(d.uint32())),
|
||||
))
|
||||
case reflect.Complex128:
|
||||
v.SetComplex(complex(
|
||||
math.Float64frombits(d.uint64()),
|
||||
math.Float64frombits(d.uint64()),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
func (e *encoder) value(v reflect.Value) {
|
||||
switch v.Kind() {
|
||||
case reflect.Array:
|
||||
l := v.Len()
|
||||
for i := 0; i < l; i++ {
|
||||
e.value(v.Index(i))
|
||||
}
|
||||
|
||||
case reflect.Struct:
|
||||
t := v.Type()
|
||||
l := v.NumField()
|
||||
for i := 0; i < l; i++ {
|
||||
// see comment for corresponding code in decoder.value()
|
||||
if v := v.Field(i); v.CanSet() || t.Field(i).Name != "_" {
|
||||
e.value(v)
|
||||
} else {
|
||||
e.skip(v)
|
||||
}
|
||||
}
|
||||
|
||||
case reflect.Slice:
|
||||
l := v.Len()
|
||||
for i := 0; i < l; i++ {
|
||||
e.value(v.Index(i))
|
||||
}
|
||||
|
||||
case reflect.Bool:
|
||||
e.bool(v.Bool())
|
||||
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
switch v.Type().Kind() {
|
||||
case reflect.Int8:
|
||||
e.int8(int8(v.Int()))
|
||||
case reflect.Int16:
|
||||
e.int16(int16(v.Int()))
|
||||
case reflect.Int32:
|
||||
e.int32(int32(v.Int()))
|
||||
case reflect.Int64:
|
||||
e.int64(v.Int())
|
||||
}
|
||||
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
|
||||
switch v.Type().Kind() {
|
||||
case reflect.Uint8:
|
||||
e.uint8(uint8(v.Uint()))
|
||||
case reflect.Uint16:
|
||||
e.uint16(uint16(v.Uint()))
|
||||
case reflect.Uint32:
|
||||
e.uint32(uint32(v.Uint()))
|
||||
case reflect.Uint64:
|
||||
e.uint64(v.Uint())
|
||||
}
|
||||
|
||||
case reflect.Float32, reflect.Float64:
|
||||
switch v.Type().Kind() {
|
||||
case reflect.Float32:
|
||||
e.uint32(math.Float32bits(float32(v.Float())))
|
||||
case reflect.Float64:
|
||||
e.uint64(math.Float64bits(v.Float()))
|
||||
}
|
||||
|
||||
case reflect.Complex64, reflect.Complex128:
|
||||
switch v.Type().Kind() {
|
||||
case reflect.Complex64:
|
||||
x := v.Complex()
|
||||
e.uint32(math.Float32bits(float32(real(x))))
|
||||
e.uint32(math.Float32bits(float32(imag(x))))
|
||||
case reflect.Complex128:
|
||||
x := v.Complex()
|
||||
e.uint64(math.Float64bits(real(x)))
|
||||
e.uint64(math.Float64bits(imag(x)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (d *decoder) skip(v reflect.Value) {
|
||||
d.offset += dataSize(v)
|
||||
}
|
||||
|
||||
func (e *encoder) skip(v reflect.Value) {
|
||||
n := dataSize(v)
|
||||
zero := e.buf[e.offset : e.offset+n]
|
||||
for i := range zero {
|
||||
zero[i] = 0
|
||||
}
|
||||
e.offset += n
|
||||
}
|
||||
|
||||
// intDataSize returns the size of the data required to represent the data when encoded.
|
||||
// It returns zero if the type cannot be implemented by the fast path in Read or Write.
|
||||
func intDataSize(data any) int {
|
||||
switch data := data.(type) {
|
||||
case bool, int8, uint8, *bool, *int8, *uint8:
|
||||
return 1
|
||||
case []bool:
|
||||
return len(data)
|
||||
case []int8:
|
||||
return len(data)
|
||||
case []uint8:
|
||||
return len(data)
|
||||
case int16, uint16, *int16, *uint16:
|
||||
return 2
|
||||
case []int16:
|
||||
return 2 * len(data)
|
||||
case []uint16:
|
||||
return 2 * len(data)
|
||||
case int32, uint32, *int32, *uint32:
|
||||
return 4
|
||||
case []int32:
|
||||
return 4 * len(data)
|
||||
case []uint32:
|
||||
return 4 * len(data)
|
||||
case int64, uint64, *int64, *uint64:
|
||||
return 8
|
||||
case []int64:
|
||||
return 8 * len(data)
|
||||
case []uint64:
|
||||
return 8 * len(data)
|
||||
case float32, *float32:
|
||||
return 4
|
||||
case float64, *float64:
|
||||
return 8
|
||||
case []float32:
|
||||
return 4 * len(data)
|
||||
case []float64:
|
||||
return 8 * len(data)
|
||||
}
|
||||
return 0
|
||||
}
|
18
common/binary/export.go
Normal file
18
common/binary/export.go
Normal file
|
@ -0,0 +1,18 @@
|
|||
package binary
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
func DataSize(t reflect.Value) int {
|
||||
return dataSize(t)
|
||||
}
|
||||
|
||||
func EncodeValue(order binary.ByteOrder, buf []byte, v reflect.Value) {
|
||||
(&encoder{order: order, buf: buf}).value(v)
|
||||
}
|
||||
|
||||
func DecodeValue(order binary.ByteOrder, buf []byte, v reflect.Value) {
|
||||
(&decoder{order: order, buf: buf}).value(v)
|
||||
}
|
14
common/binary/native_endian_big.go
Normal file
14
common/binary/native_endian_big.go
Normal file
|
@ -0,0 +1,14 @@
|
|||
// Copyright 2023 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
//go:build armbe || arm64be || m68k || mips || mips64 || mips64p32 || ppc || ppc64 || s390 || s390x || shbe || sparc || sparc64
|
||||
|
||||
package binary
|
||||
|
||||
type nativeEndian struct {
|
||||
bigEndian
|
||||
}
|
||||
|
||||
// NativeEndian is the native-endian implementation of [ByteOrder] and [AppendByteOrder].
|
||||
var NativeEndian nativeEndian
|
14
common/binary/native_endian_little.go
Normal file
14
common/binary/native_endian_little.go
Normal file
|
@ -0,0 +1,14 @@
|
|||
// Copyright 2023 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
//go:build 386 || amd64 || amd64p32 || alpha || arm || arm64 || loong64 || mipsle || mips64le || mips64p32le || nios2 || ppc64le || riscv || riscv64 || sh || wasm
|
||||
|
||||
package binary
|
||||
|
||||
type nativeEndian struct {
|
||||
littleEndian
|
||||
}
|
||||
|
||||
// NativeEndian is the native-endian implementation of [ByteOrder] and [AppendByteOrder].
|
||||
var NativeEndian nativeEndian
|
166
common/binary/varint.go
Normal file
166
common/binary/varint.go
Normal file
|
@ -0,0 +1,166 @@
|
|||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package binary
|
||||
|
||||
// This file implements "varint" encoding of 64-bit integers.
|
||||
// The encoding is:
|
||||
// - unsigned integers are serialized 7 bits at a time, starting with the
|
||||
// least significant bits
|
||||
// - the most significant bit (msb) in each output byte indicates if there
|
||||
// is a continuation byte (msb = 1)
|
||||
// - signed integers are mapped to unsigned integers using "zig-zag"
|
||||
// encoding: Positive values x are written as 2*x + 0, negative values
|
||||
// are written as 2*(^x) + 1; that is, negative numbers are complemented
|
||||
// and whether to complement is encoded in bit 0.
|
||||
//
|
||||
// Design note:
|
||||
// At most 10 bytes are needed for 64-bit values. The encoding could
|
||||
// be more dense: a full 64-bit value needs an extra byte just to hold bit 63.
|
||||
// Instead, the msb of the previous byte could be used to hold bit 63 since we
|
||||
// know there can't be more than 64 bits. This is a trivial improvement and
|
||||
// would reduce the maximum encoding length to 9 bytes. However, it breaks the
|
||||
// invariant that the msb is always the "continuation bit" and thus makes the
|
||||
// format incompatible with a varint encoding for larger numbers (say 128-bit).
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
)
|
||||
|
||||
// MaxVarintLenN is the maximum length of a varint-encoded N-bit integer.
|
||||
const (
|
||||
MaxVarintLen16 = 3
|
||||
MaxVarintLen32 = 5
|
||||
MaxVarintLen64 = 10
|
||||
)
|
||||
|
||||
// AppendUvarint appends the varint-encoded form of x,
|
||||
// as generated by [PutUvarint], to buf and returns the extended buffer.
|
||||
func AppendUvarint(buf []byte, x uint64) []byte {
|
||||
for x >= 0x80 {
|
||||
buf = append(buf, byte(x)|0x80)
|
||||
x >>= 7
|
||||
}
|
||||
return append(buf, byte(x))
|
||||
}
|
||||
|
||||
// PutUvarint encodes a uint64 into buf and returns the number of bytes written.
|
||||
// If the buffer is too small, PutUvarint will panic.
|
||||
func PutUvarint(buf []byte, x uint64) int {
|
||||
i := 0
|
||||
for x >= 0x80 {
|
||||
buf[i] = byte(x) | 0x80
|
||||
x >>= 7
|
||||
i++
|
||||
}
|
||||
buf[i] = byte(x)
|
||||
return i + 1
|
||||
}
|
||||
|
||||
// Uvarint decodes a uint64 from buf and returns that value and the
|
||||
// number of bytes read (> 0). If an error occurred, the value is 0
|
||||
// and the number of bytes n is <= 0 meaning:
|
||||
//
|
||||
// n == 0: buf too small
|
||||
// n < 0: value larger than 64 bits (overflow)
|
||||
// and -n is the number of bytes read
|
||||
func Uvarint(buf []byte) (uint64, int) {
|
||||
var x uint64
|
||||
var s uint
|
||||
for i, b := range buf {
|
||||
if i == MaxVarintLen64 {
|
||||
// Catch byte reads past MaxVarintLen64.
|
||||
// See issue https://golang.org/issues/41185
|
||||
return 0, -(i + 1) // overflow
|
||||
}
|
||||
if b < 0x80 {
|
||||
if i == MaxVarintLen64-1 && b > 1 {
|
||||
return 0, -(i + 1) // overflow
|
||||
}
|
||||
return x | uint64(b)<<s, i + 1
|
||||
}
|
||||
x |= uint64(b&0x7f) << s
|
||||
s += 7
|
||||
}
|
||||
return 0, 0
|
||||
}
|
||||
|
||||
// AppendVarint appends the varint-encoded form of x,
|
||||
// as generated by [PutVarint], to buf and returns the extended buffer.
|
||||
func AppendVarint(buf []byte, x int64) []byte {
|
||||
ux := uint64(x) << 1
|
||||
if x < 0 {
|
||||
ux = ^ux
|
||||
}
|
||||
return AppendUvarint(buf, ux)
|
||||
}
|
||||
|
||||
// PutVarint encodes an int64 into buf and returns the number of bytes written.
|
||||
// If the buffer is too small, PutVarint will panic.
|
||||
func PutVarint(buf []byte, x int64) int {
|
||||
ux := uint64(x) << 1
|
||||
if x < 0 {
|
||||
ux = ^ux
|
||||
}
|
||||
return PutUvarint(buf, ux)
|
||||
}
|
||||
|
||||
// Varint decodes an int64 from buf and returns that value and the
|
||||
// number of bytes read (> 0). If an error occurred, the value is 0
|
||||
// and the number of bytes n is <= 0 with the following meaning:
|
||||
//
|
||||
// n == 0: buf too small
|
||||
// n < 0: value larger than 64 bits (overflow)
|
||||
// and -n is the number of bytes read
|
||||
func Varint(buf []byte) (int64, int) {
|
||||
ux, n := Uvarint(buf) // ok to continue in presence of error
|
||||
x := int64(ux >> 1)
|
||||
if ux&1 != 0 {
|
||||
x = ^x
|
||||
}
|
||||
return x, n
|
||||
}
|
||||
|
||||
var errOverflow = errors.New("binary: varint overflows a 64-bit integer")
|
||||
|
||||
// ReadUvarint reads an encoded unsigned integer from r and returns it as a uint64.
|
||||
// The error is [io.EOF] only if no bytes were read.
|
||||
// If an [io.EOF] happens after reading some but not all the bytes,
|
||||
// ReadUvarint returns [io.ErrUnexpectedEOF].
|
||||
func ReadUvarint(r io.ByteReader) (uint64, error) {
|
||||
var x uint64
|
||||
var s uint
|
||||
for i := 0; i < MaxVarintLen64; i++ {
|
||||
b, err := r.ReadByte()
|
||||
if err != nil {
|
||||
if i > 0 && err == io.EOF {
|
||||
err = io.ErrUnexpectedEOF
|
||||
}
|
||||
return x, err
|
||||
}
|
||||
if b < 0x80 {
|
||||
if i == MaxVarintLen64-1 && b > 1 {
|
||||
return x, errOverflow
|
||||
}
|
||||
return x | uint64(b)<<s, nil
|
||||
}
|
||||
x |= uint64(b&0x7f) << s
|
||||
s += 7
|
||||
}
|
||||
return x, errOverflow
|
||||
}
|
||||
|
||||
// ReadVarint reads an encoded signed integer from r and returns it as an int64.
|
||||
// The error is [io.EOF] only if no bytes were read.
|
||||
// If an [io.EOF] happens after reading some but not all the bytes,
|
||||
// ReadVarint returns [io.ErrUnexpectedEOF].
|
||||
func ReadVarint(r io.ByteReader) (int64, error) {
|
||||
ux, err := ReadUvarint(r) // ok to continue in presence of error
|
||||
x := int64(ux >> 1)
|
||||
if ux&1 != 0 {
|
||||
x = ^x
|
||||
}
|
||||
return x, err
|
||||
}
|
|
@ -5,11 +5,10 @@ package buf
|
|||
import (
|
||||
"errors"
|
||||
"math/bits"
|
||||
"strconv"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var DefaultAllocator = newDefaultAllocer()
|
||||
var DefaultAllocator = newDefaultAllocator()
|
||||
|
||||
type Allocator interface {
|
||||
Get(size int) []byte
|
||||
|
@ -18,36 +17,72 @@ type Allocator interface {
|
|||
|
||||
// defaultAllocator for incoming frames, optimized to prevent overwriting after zeroing
|
||||
type defaultAllocator struct {
|
||||
buffers []sync.Pool
|
||||
buffers [11]sync.Pool
|
||||
}
|
||||
|
||||
// NewAllocator initiates a []byte allocator for frames less than 65536 bytes,
|
||||
// the waste(memory fragmentation) of space allocation is guaranteed to be
|
||||
// no more than 50%.
|
||||
func newDefaultAllocer() Allocator {
|
||||
alloc := new(defaultAllocator)
|
||||
alloc.buffers = make([]sync.Pool, 17) // 1B -> 64K
|
||||
for k := range alloc.buffers {
|
||||
i := k
|
||||
alloc.buffers[k].New = func() any {
|
||||
return make([]byte, 1<<uint32(i))
|
||||
}
|
||||
func newDefaultAllocator() Allocator {
|
||||
return &defaultAllocator{
|
||||
buffers: [...]sync.Pool{ // 64B -> 64K
|
||||
{New: func() any { return new([1 << 6]byte) }},
|
||||
{New: func() any { return new([1 << 7]byte) }},
|
||||
{New: func() any { return new([1 << 8]byte) }},
|
||||
{New: func() any { return new([1 << 9]byte) }},
|
||||
{New: func() any { return new([1 << 10]byte) }},
|
||||
{New: func() any { return new([1 << 11]byte) }},
|
||||
{New: func() any { return new([1 << 12]byte) }},
|
||||
{New: func() any { return new([1 << 13]byte) }},
|
||||
{New: func() any { return new([1 << 14]byte) }},
|
||||
{New: func() any { return new([1 << 15]byte) }},
|
||||
{New: func() any { return new([1 << 16]byte) }},
|
||||
},
|
||||
}
|
||||
return alloc
|
||||
}
|
||||
|
||||
// Get a []byte from pool with most appropriate cap
|
||||
func (alloc *defaultAllocator) Get(size int) []byte {
|
||||
if size <= 0 || size > 65536 {
|
||||
panic("alloc bad size: " + strconv.Itoa(size))
|
||||
return nil
|
||||
}
|
||||
|
||||
bits := msb(size)
|
||||
if size == 1<<bits {
|
||||
return alloc.buffers[bits].Get().([]byte)[:size]
|
||||
var index uint16
|
||||
if size > 64 {
|
||||
index = msb(size)
|
||||
if size != 1<<index {
|
||||
index += 1
|
||||
}
|
||||
index -= 6
|
||||
}
|
||||
|
||||
return alloc.buffers[bits+1].Get().([]byte)[:size]
|
||||
buffer := alloc.buffers[index].Get()
|
||||
switch index {
|
||||
case 0:
|
||||
return buffer.(*[1 << 6]byte)[:size]
|
||||
case 1:
|
||||
return buffer.(*[1 << 7]byte)[:size]
|
||||
case 2:
|
||||
return buffer.(*[1 << 8]byte)[:size]
|
||||
case 3:
|
||||
return buffer.(*[1 << 9]byte)[:size]
|
||||
case 4:
|
||||
return buffer.(*[1 << 10]byte)[:size]
|
||||
case 5:
|
||||
return buffer.(*[1 << 11]byte)[:size]
|
||||
case 6:
|
||||
return buffer.(*[1 << 12]byte)[:size]
|
||||
case 7:
|
||||
return buffer.(*[1 << 13]byte)[:size]
|
||||
case 8:
|
||||
return buffer.(*[1 << 14]byte)[:size]
|
||||
case 9:
|
||||
return buffer.(*[1 << 15]byte)[:size]
|
||||
case 10:
|
||||
return buffer.(*[1 << 16]byte)[:size]
|
||||
default:
|
||||
panic("invalid pool index")
|
||||
}
|
||||
}
|
||||
|
||||
// Put returns a []byte to pool for future use,
|
||||
|
@ -57,10 +92,37 @@ func (alloc *defaultAllocator) Put(buf []byte) error {
|
|||
if cap(buf) == 0 || cap(buf) > 65536 || cap(buf) != 1<<bits {
|
||||
return errors.New("allocator Put() incorrect buffer size")
|
||||
}
|
||||
bits -= 6
|
||||
buf = buf[:cap(buf)]
|
||||
|
||||
//nolint
|
||||
//lint:ignore SA6002 ignore temporarily
|
||||
alloc.buffers[bits].Put(buf)
|
||||
switch bits {
|
||||
case 0:
|
||||
alloc.buffers[bits].Put((*[1 << 6]byte)(buf))
|
||||
case 1:
|
||||
alloc.buffers[bits].Put((*[1 << 7]byte)(buf))
|
||||
case 2:
|
||||
alloc.buffers[bits].Put((*[1 << 8]byte)(buf))
|
||||
case 3:
|
||||
alloc.buffers[bits].Put((*[1 << 9]byte)(buf))
|
||||
case 4:
|
||||
alloc.buffers[bits].Put((*[1 << 10]byte)(buf))
|
||||
case 5:
|
||||
alloc.buffers[bits].Put((*[1 << 11]byte)(buf))
|
||||
case 6:
|
||||
alloc.buffers[bits].Put((*[1 << 12]byte)(buf))
|
||||
case 7:
|
||||
alloc.buffers[bits].Put((*[1 << 13]byte)(buf))
|
||||
case 8:
|
||||
alloc.buffers[bits].Put((*[1 << 14]byte)(buf))
|
||||
case 9:
|
||||
alloc.buffers[bits].Put((*[1 << 15]byte)(buf))
|
||||
case 10:
|
||||
alloc.buffers[bits].Put((*[1 << 16]byte)(buf))
|
||||
default:
|
||||
panic("invalid pool index")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -4,109 +4,70 @@ import (
|
|||
"crypto/rand"
|
||||
"io"
|
||||
"net"
|
||||
"strconv"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/atomic"
|
||||
"github.com/sagernet/sing/common/debug"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
)
|
||||
|
||||
const (
|
||||
ReversedHeader = 1024
|
||||
BufferSize = 32 * 1024
|
||||
UDPBufferSize = 16 * 1024
|
||||
F "github.com/sagernet/sing/common/format"
|
||||
)
|
||||
|
||||
type Buffer struct {
|
||||
data []byte
|
||||
start int
|
||||
end int
|
||||
refs int32
|
||||
managed bool
|
||||
closed bool
|
||||
data []byte
|
||||
start int
|
||||
end int
|
||||
capacity int
|
||||
refs atomic.Int32
|
||||
managed bool
|
||||
}
|
||||
|
||||
func New() *Buffer {
|
||||
return &Buffer{
|
||||
data: Get(BufferSize),
|
||||
start: ReversedHeader,
|
||||
end: ReversedHeader,
|
||||
managed: true,
|
||||
data: Get(BufferSize),
|
||||
capacity: BufferSize,
|
||||
managed: true,
|
||||
}
|
||||
}
|
||||
|
||||
func NewPacket() *Buffer {
|
||||
return &Buffer{
|
||||
data: Get(UDPBufferSize),
|
||||
start: ReversedHeader,
|
||||
end: ReversedHeader,
|
||||
managed: true,
|
||||
data: Get(UDPBufferSize),
|
||||
capacity: UDPBufferSize,
|
||||
managed: true,
|
||||
}
|
||||
}
|
||||
|
||||
func NewSize(size int) *Buffer {
|
||||
if size > 65535 {
|
||||
if size == 0 {
|
||||
return &Buffer{}
|
||||
} else if size > 65535 {
|
||||
return &Buffer{
|
||||
data: make([]byte, size),
|
||||
data: make([]byte, size),
|
||||
capacity: size,
|
||||
}
|
||||
}
|
||||
return &Buffer{
|
||||
data: Get(size),
|
||||
managed: true,
|
||||
}
|
||||
}
|
||||
|
||||
func StackNew() *Buffer {
|
||||
if common.UnsafeBuffer {
|
||||
return &Buffer{
|
||||
data: make([]byte, BufferSize),
|
||||
start: ReversedHeader,
|
||||
end: ReversedHeader,
|
||||
}
|
||||
} else {
|
||||
return New()
|
||||
}
|
||||
}
|
||||
|
||||
func StackNewPacket() *Buffer {
|
||||
if common.UnsafeBuffer {
|
||||
return &Buffer{
|
||||
data: make([]byte, UDPBufferSize),
|
||||
start: ReversedHeader,
|
||||
end: ReversedHeader,
|
||||
}
|
||||
} else {
|
||||
return NewPacket()
|
||||
}
|
||||
}
|
||||
|
||||
func StackNewSize(size int) *Buffer {
|
||||
if common.UnsafeBuffer {
|
||||
return &Buffer{
|
||||
data: Make(size),
|
||||
}
|
||||
} else {
|
||||
return NewSize(size)
|
||||
data: Get(size),
|
||||
capacity: size,
|
||||
managed: true,
|
||||
}
|
||||
}
|
||||
|
||||
func As(data []byte) *Buffer {
|
||||
return &Buffer{
|
||||
data: data,
|
||||
end: len(data),
|
||||
data: data,
|
||||
end: len(data),
|
||||
capacity: len(data),
|
||||
}
|
||||
}
|
||||
|
||||
func With(data []byte) *Buffer {
|
||||
return &Buffer{
|
||||
data: data,
|
||||
data: data,
|
||||
capacity: len(data),
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Buffer) Closed() bool {
|
||||
return b.closed
|
||||
}
|
||||
|
||||
func (b *Buffer) Byte(index int) byte {
|
||||
return b.data[b.start+index]
|
||||
}
|
||||
|
@ -117,8 +78,8 @@ func (b *Buffer) SetByte(index int, value byte) {
|
|||
|
||||
func (b *Buffer) Extend(n int) []byte {
|
||||
end := b.end + n
|
||||
if end > cap(b.data) {
|
||||
panic("buffer overflow: cap " + strconv.Itoa(cap(b.data)) + ",end " + strconv.Itoa(b.end) + ", need " + strconv.Itoa(n))
|
||||
if end > b.capacity {
|
||||
panic(F.ToString("buffer overflow: capacity ", b.capacity, ",end ", b.end, ", need ", n))
|
||||
}
|
||||
ext := b.data[b.end:end]
|
||||
b.end = end
|
||||
|
@ -140,14 +101,14 @@ func (b *Buffer) Write(data []byte) (n int, err error) {
|
|||
if b.IsFull() {
|
||||
return 0, io.ErrShortBuffer
|
||||
}
|
||||
n = copy(b.data[b.end:], data)
|
||||
n = copy(b.data[b.end:b.capacity], data)
|
||||
b.end += n
|
||||
return
|
||||
}
|
||||
|
||||
func (b *Buffer) ExtendHeader(n int) []byte {
|
||||
if b.start < n {
|
||||
panic("buffer overflow: cap " + strconv.Itoa(cap(b.data)) + ",start " + strconv.Itoa(b.start) + ", need " + strconv.Itoa(n))
|
||||
panic(F.ToString("buffer overflow: capacity ", b.capacity, ",start ", b.start, ", need ", n))
|
||||
}
|
||||
b.start -= n
|
||||
return b.data[b.start : b.start+n]
|
||||
|
@ -168,13 +129,13 @@ func (b *Buffer) WriteByte(d byte) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (b *Buffer) ReadOnceFrom(r io.Reader) (int64, error) {
|
||||
func (b *Buffer) ReadOnceFrom(r io.Reader) (int, error) {
|
||||
if b.IsFull() {
|
||||
return 0, io.ErrShortBuffer
|
||||
}
|
||||
n, err := r.Read(b.FreeBytes())
|
||||
b.end += n
|
||||
return int64(n), err
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (b *Buffer) ReadPacketFrom(r net.PacketConn) (int64, net.Addr, error) {
|
||||
|
@ -188,7 +149,8 @@ func (b *Buffer) ReadPacketFrom(r net.PacketConn) (int64, net.Addr, error) {
|
|||
|
||||
func (b *Buffer) ReadAtLeastFrom(r io.Reader, min int) (int64, error) {
|
||||
if min <= 0 {
|
||||
return b.ReadOnceFrom(r)
|
||||
n, err := b.ReadOnceFrom(r)
|
||||
return int64(n), err
|
||||
}
|
||||
if b.IsFull() {
|
||||
return 0, io.ErrShortBuffer
|
||||
|
@ -199,7 +161,7 @@ func (b *Buffer) ReadAtLeastFrom(r io.Reader, min int) (int64, error) {
|
|||
}
|
||||
|
||||
func (b *Buffer) ReadFullFrom(r io.Reader, size int) (n int, err error) {
|
||||
if b.end+size > b.Cap() {
|
||||
if b.end+size > b.capacity {
|
||||
return 0, io.ErrShortBuffer
|
||||
}
|
||||
n, err = io.ReadFull(r, b.data[b.end:b.end+size])
|
||||
|
@ -229,8 +191,16 @@ func (b *Buffer) WriteRune(s rune) (int, error) {
|
|||
return b.Write([]byte{byte(s)})
|
||||
}
|
||||
|
||||
func (b *Buffer) WriteString(s string) (int, error) {
|
||||
return b.Write([]byte(s))
|
||||
func (b *Buffer) WriteString(s string) (n int, err error) {
|
||||
if len(s) == 0 {
|
||||
return
|
||||
}
|
||||
if b.IsFull() {
|
||||
return 0, io.ErrShortBuffer
|
||||
}
|
||||
n = copy(b.data[b.end:b.capacity], s)
|
||||
b.end += n
|
||||
return
|
||||
}
|
||||
|
||||
func (b *Buffer) WriteZero() error {
|
||||
|
@ -243,13 +213,10 @@ func (b *Buffer) WriteZero() error {
|
|||
}
|
||||
|
||||
func (b *Buffer) WriteZeroN(n int) error {
|
||||
if b.end+n > b.Cap() {
|
||||
if b.end+n > b.capacity {
|
||||
return io.ErrShortBuffer
|
||||
}
|
||||
for i := b.end; i <= b.end+n; i++ {
|
||||
b.data[i] = 0
|
||||
}
|
||||
b.end += n
|
||||
common.ClearArray(b.Extend(n))
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -292,40 +259,63 @@ func (b *Buffer) Resize(start, end int) {
|
|||
b.end = b.start + end
|
||||
}
|
||||
|
||||
func (b *Buffer) Reset() {
|
||||
b.start = ReversedHeader
|
||||
b.end = ReversedHeader
|
||||
func (b *Buffer) Reserve(n int) {
|
||||
if n > b.capacity {
|
||||
panic(F.ToString("buffer overflow: capacity ", b.capacity, ", need ", n))
|
||||
}
|
||||
b.capacity -= n
|
||||
}
|
||||
|
||||
func (b *Buffer) FullReset() {
|
||||
func (b *Buffer) OverCap(n int) {
|
||||
if b.capacity+n > len(b.data) {
|
||||
panic(F.ToString("buffer overflow: capacity ", len(b.data), ", need ", b.capacity+n))
|
||||
}
|
||||
b.capacity += n
|
||||
}
|
||||
|
||||
func (b *Buffer) Reset() {
|
||||
b.start = 0
|
||||
b.end = 0
|
||||
b.capacity = len(b.data)
|
||||
}
|
||||
|
||||
// Deprecated: use Reset instead.
|
||||
func (b *Buffer) FullReset() {
|
||||
b.Reset()
|
||||
}
|
||||
|
||||
func (b *Buffer) IncRef() {
|
||||
atomic.AddInt32(&b.refs, 1)
|
||||
b.refs.Add(1)
|
||||
}
|
||||
|
||||
func (b *Buffer) DecRef() {
|
||||
atomic.AddInt32(&b.refs, -1)
|
||||
b.refs.Add(-1)
|
||||
}
|
||||
|
||||
func (b *Buffer) Release() {
|
||||
if b == nil || b.closed || !b.managed {
|
||||
if b == nil || !b.managed {
|
||||
return
|
||||
}
|
||||
if atomic.LoadInt32(&b.refs) > 0 {
|
||||
if b.refs.Load() > 0 {
|
||||
return
|
||||
}
|
||||
common.Must(Put(b.data))
|
||||
*b = Buffer{closed: true}
|
||||
*b = Buffer{}
|
||||
}
|
||||
|
||||
func (b *Buffer) Cut(start int, end int) *Buffer {
|
||||
b.start += start
|
||||
b.end = len(b.data) - end
|
||||
return &Buffer{
|
||||
data: b.data[b.start:b.end],
|
||||
func (b *Buffer) Leak() {
|
||||
if debug.Enabled {
|
||||
if b == nil || !b.managed {
|
||||
return
|
||||
}
|
||||
refs := b.refs.Load()
|
||||
if refs == 0 {
|
||||
panic("leaking buffer")
|
||||
} else {
|
||||
panic(F.ToString("leaking buffer with ", refs, " references"))
|
||||
}
|
||||
} else {
|
||||
b.Release()
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -338,6 +328,10 @@ func (b *Buffer) Len() int {
|
|||
}
|
||||
|
||||
func (b *Buffer) Cap() int {
|
||||
return b.capacity
|
||||
}
|
||||
|
||||
func (b *Buffer) RawCap() int {
|
||||
return len(b.data)
|
||||
}
|
||||
|
||||
|
@ -345,10 +339,6 @@ func (b *Buffer) Bytes() []byte {
|
|||
return b.data[b.start:b.end]
|
||||
}
|
||||
|
||||
func (b *Buffer) Slice() []byte {
|
||||
return b.data
|
||||
}
|
||||
|
||||
func (b *Buffer) From(n int) []byte {
|
||||
return b.data[b.start+n : b.end]
|
||||
}
|
||||
|
@ -366,11 +356,11 @@ func (b *Buffer) Index(start int) []byte {
|
|||
}
|
||||
|
||||
func (b *Buffer) FreeLen() int {
|
||||
return b.Cap() - b.end
|
||||
return b.capacity - b.end
|
||||
}
|
||||
|
||||
func (b *Buffer) FreeBytes() []byte {
|
||||
return b.data[b.end:b.Cap()]
|
||||
return b.data[b.end:b.capacity]
|
||||
}
|
||||
|
||||
func (b *Buffer) IsEmpty() bool {
|
||||
|
@ -378,7 +368,7 @@ func (b *Buffer) IsEmpty() bool {
|
|||
}
|
||||
|
||||
func (b *Buffer) IsFull() bool {
|
||||
return b.end == b.Cap()
|
||||
return b.end == b.capacity
|
||||
}
|
||||
|
||||
func (b *Buffer) ToOwned() *Buffer {
|
||||
|
@ -386,5 +376,6 @@ func (b *Buffer) ToOwned() *Buffer {
|
|||
copy(n.data[b.start:b.end], b.data[b.start:b.end])
|
||||
n.start = b.start
|
||||
n.end = b.end
|
||||
n.capacity = b.capacity
|
||||
return n
|
||||
}
|
||||
|
|
8
common/buf/buffer_low_memory.go
Normal file
8
common/buf/buffer_low_memory.go
Normal file
|
@ -0,0 +1,8 @@
|
|||
//go:build with_low_memory
|
||||
|
||||
package buf
|
||||
|
||||
const (
|
||||
BufferSize = 16 * 1024
|
||||
UDPBufferSize = 8 * 1024
|
||||
)
|
8
common/buf/buffer_standard.go
Normal file
8
common/buf/buffer_standard.go
Normal file
|
@ -0,0 +1,8 @@
|
|||
//go:build !with_low_memory
|
||||
|
||||
package buf
|
||||
|
||||
const (
|
||||
BufferSize = 32 * 1024
|
||||
UDPBufferSize = 16 * 1024
|
||||
)
|
|
@ -1,9 +0,0 @@
|
|||
package buf
|
||||
|
||||
import "encoding/hex"
|
||||
|
||||
func EncodeHexString(src []byte) string {
|
||||
dst := Make(hex.EncodedLen(len(src)))
|
||||
hex.Encode(dst, src)
|
||||
return string(dst)
|
||||
}
|
|
@ -16,6 +16,14 @@ func ToSliceMulti(buffers []*Buffer) [][]byte {
|
|||
})
|
||||
}
|
||||
|
||||
func CopyMulti(toBuffer []byte, buffers []*Buffer) int {
|
||||
var n int
|
||||
for _, buffer := range buffers {
|
||||
n += copy(toBuffer[n:], buffer.Bytes())
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
func ReleaseMulti(buffers []*Buffer) {
|
||||
for _, buffer := range buffers {
|
||||
buffer.Release()
|
||||
|
|
|
@ -1,6 +1,9 @@
|
|||
package buf
|
||||
|
||||
func Get(size int) []byte {
|
||||
if size == 0 {
|
||||
return nil
|
||||
}
|
||||
return DefaultAllocator.Get(size)
|
||||
}
|
||||
|
||||
|
@ -8,43 +11,7 @@ func Put(buf []byte) error {
|
|||
return DefaultAllocator.Put(buf)
|
||||
}
|
||||
|
||||
// Deprecated: use array instead.
|
||||
func Make(size int) []byte {
|
||||
var buffer []byte
|
||||
switch {
|
||||
case size <= 2:
|
||||
buffer = make([]byte, 2)
|
||||
case size <= 4:
|
||||
buffer = make([]byte, 4)
|
||||
case size <= 8:
|
||||
buffer = make([]byte, 8)
|
||||
case size <= 16:
|
||||
buffer = make([]byte, 16)
|
||||
case size <= 32:
|
||||
buffer = make([]byte, 32)
|
||||
case size <= 64:
|
||||
buffer = make([]byte, 64)
|
||||
case size <= 128:
|
||||
buffer = make([]byte, 128)
|
||||
case size <= 256:
|
||||
buffer = make([]byte, 256)
|
||||
case size <= 512:
|
||||
buffer = make([]byte, 512)
|
||||
case size <= 1024:
|
||||
buffer = make([]byte, 1024)
|
||||
case size <= 2048:
|
||||
buffer = make([]byte, 2048)
|
||||
case size <= 4096:
|
||||
buffer = make([]byte, 4096)
|
||||
case size <= 8192:
|
||||
buffer = make([]byte, 8192)
|
||||
case size <= 16384:
|
||||
buffer = make([]byte, 16384)
|
||||
case size <= 32768:
|
||||
buffer = make([]byte, 32768)
|
||||
case size <= 65535:
|
||||
buffer = make([]byte, 65535)
|
||||
default:
|
||||
return make([]byte, size)
|
||||
}
|
||||
return buffer[:size]
|
||||
return make([]byte, size)
|
||||
}
|
||||
|
|
|
@ -1,34 +0,0 @@
|
|||
//go:build !disable_unsafe
|
||||
|
||||
package buf
|
||||
|
||||
import (
|
||||
"unsafe"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
)
|
||||
|
||||
type dbgVar struct {
|
||||
name string
|
||||
value *int32
|
||||
}
|
||||
|
||||
//go:linkname dbgvars runtime.dbgvars
|
||||
var dbgvars any
|
||||
|
||||
// go.info.runtime.dbgvars: relocation target go.info.[]github.com/sagernet/sing/common/buf.dbgVar not defined
|
||||
// var dbgvars []dbgVar
|
||||
|
||||
func init() {
|
||||
if !common.UnsafeBuffer {
|
||||
return
|
||||
}
|
||||
debugVars := *(*[]dbgVar)(unsafe.Pointer(&dbgvars))
|
||||
for _, v := range debugVars {
|
||||
if v.name == "invalidptr" {
|
||||
*v.value = 0
|
||||
return
|
||||
}
|
||||
}
|
||||
panic("can't disable invalidptr")
|
||||
}
|
34
common/bufio/addr_bsd.go
Normal file
34
common/bufio/addr_bsd.go
Normal file
|
@ -0,0 +1,34 @@
|
|||
//go:build darwin || dragonfly || freebsd || netbsd || openbsd
|
||||
|
||||
package bufio
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"net/netip"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func ToSockaddr(destination netip.AddrPort) (name unsafe.Pointer, nameLen uint32) {
|
||||
if destination.Addr().Is4() {
|
||||
sa := unix.RawSockaddrInet4{
|
||||
Len: unix.SizeofSockaddrInet4,
|
||||
Family: unix.AF_INET,
|
||||
Addr: destination.Addr().As4(),
|
||||
}
|
||||
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], destination.Port())
|
||||
name = unsafe.Pointer(&sa)
|
||||
nameLen = unix.SizeofSockaddrInet4
|
||||
} else {
|
||||
sa := unix.RawSockaddrInet6{
|
||||
Len: unix.SizeofSockaddrInet6,
|
||||
Family: unix.AF_INET6,
|
||||
Addr: destination.Addr().As16(),
|
||||
}
|
||||
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], destination.Port())
|
||||
name = unsafe.Pointer(&sa)
|
||||
nameLen = unix.SizeofSockaddrInet6
|
||||
}
|
||||
return
|
||||
}
|
|
@ -9,19 +9,20 @@ import (
|
|||
|
||||
type AddrConn struct {
|
||||
net.Conn
|
||||
M.Metadata
|
||||
Source M.Socksaddr
|
||||
Destination M.Socksaddr
|
||||
}
|
||||
|
||||
func (c *AddrConn) LocalAddr() net.Addr {
|
||||
if c.Metadata.Destination.IsValid() {
|
||||
return c.Metadata.Destination.TCPAddr()
|
||||
if c.Destination.IsValid() {
|
||||
return c.Destination.TCPAddr()
|
||||
}
|
||||
return c.Conn.LocalAddr()
|
||||
}
|
||||
|
||||
func (c *AddrConn) RemoteAddr() net.Addr {
|
||||
if c.Metadata.Source.IsValid() {
|
||||
return c.Metadata.Source.TCPAddr()
|
||||
if c.Source.IsValid() {
|
||||
return c.Source.TCPAddr()
|
||||
}
|
||||
return c.Conn.RemoteAddr()
|
||||
}
|
30
common/bufio/addr_linux.go
Normal file
30
common/bufio/addr_linux.go
Normal file
|
@ -0,0 +1,30 @@
|
|||
package bufio
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"net/netip"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func ToSockaddr(destination netip.AddrPort) (name unsafe.Pointer, nameLen uint32) {
|
||||
if destination.Addr().Is4() {
|
||||
sa := unix.RawSockaddrInet4{
|
||||
Family: unix.AF_INET,
|
||||
Addr: destination.Addr().As4(),
|
||||
}
|
||||
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], destination.Port())
|
||||
name = unsafe.Pointer(&sa)
|
||||
nameLen = unix.SizeofSockaddrInet4
|
||||
} else {
|
||||
sa := unix.RawSockaddrInet6{
|
||||
Family: unix.AF_INET6,
|
||||
Addr: destination.Addr().As16(),
|
||||
}
|
||||
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], destination.Port())
|
||||
name = unsafe.Pointer(&sa)
|
||||
nameLen = unix.SizeofSockaddrInet6
|
||||
}
|
||||
return
|
||||
}
|
30
common/bufio/addr_windows.go
Normal file
30
common/bufio/addr_windows.go
Normal file
|
@ -0,0 +1,30 @@
|
|||
package bufio
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"net/netip"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
func ToSockaddr(destination netip.AddrPort) (name unsafe.Pointer, nameLen int32) {
|
||||
if destination.Addr().Is4() {
|
||||
sa := windows.RawSockaddrInet4{
|
||||
Family: windows.AF_INET,
|
||||
Addr: destination.Addr().As4(),
|
||||
}
|
||||
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], destination.Port())
|
||||
name = unsafe.Pointer(&sa)
|
||||
nameLen = int32(unsafe.Sizeof(sa))
|
||||
} else {
|
||||
sa := windows.RawSockaddrInet6{
|
||||
Family: windows.AF_INET6,
|
||||
Addr: destination.Addr().As16(),
|
||||
}
|
||||
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], destination.Port())
|
||||
name = unsafe.Pointer(&sa)
|
||||
nameLen = int32(unsafe.Sizeof(sa))
|
||||
}
|
||||
return
|
||||
}
|
81
common/bufio/append.go
Normal file
81
common/bufio/append.go
Normal file
|
@ -0,0 +1,81 @@
|
|||
package bufio
|
||||
|
||||
import (
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
type appendConn struct {
|
||||
N.ExtendedConn
|
||||
reader N.ExtendedReader
|
||||
writer N.ExtendedWriter
|
||||
}
|
||||
|
||||
func NewAppendConn(conn N.ExtendedConn, reader N.ExtendedReader, writer N.ExtendedWriter) N.ExtendedConn {
|
||||
return &appendConn{
|
||||
ExtendedConn: conn,
|
||||
reader: reader,
|
||||
writer: writer,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *appendConn) Read(p []byte) (n int, err error) {
|
||||
if c.reader == nil {
|
||||
return c.ExtendedConn.Read(p)
|
||||
} else {
|
||||
return c.reader.Read(p)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *appendConn) ReadBuffer(buffer *buf.Buffer) error {
|
||||
if c.reader == nil {
|
||||
return c.ExtendedConn.ReadBuffer(buffer)
|
||||
} else {
|
||||
return c.reader.ReadBuffer(buffer)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *appendConn) Write(p []byte) (n int, err error) {
|
||||
if c.writer == nil {
|
||||
return c.ExtendedConn.Write(p)
|
||||
} else {
|
||||
return c.writer.Write(p)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *appendConn) WriteBuffer(buffer *buf.Buffer) error {
|
||||
if c.writer == nil {
|
||||
return c.ExtendedConn.WriteBuffer(buffer)
|
||||
} else {
|
||||
return c.writer.WriteBuffer(buffer)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *appendConn) Close() error {
|
||||
return common.Close(
|
||||
c.ExtendedConn,
|
||||
c.reader,
|
||||
c.writer,
|
||||
)
|
||||
}
|
||||
|
||||
func (c *appendConn) UpstreamReader() any {
|
||||
return c.reader
|
||||
}
|
||||
|
||||
func (c *appendConn) ReaderReplaceable() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *appendConn) UpstreamWriter() any {
|
||||
return c.writer
|
||||
}
|
||||
|
||||
func (c *appendConn) WriterReplaceable() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *appendConn) Upstream() any {
|
||||
return c.ExtendedConn
|
||||
}
|
165
common/bufio/bind.go
Normal file
165
common/bufio/bind.go
Normal file
|
@ -0,0 +1,165 @@
|
|||
package bufio
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
type BindPacketConn interface {
|
||||
N.NetPacketConn
|
||||
net.Conn
|
||||
}
|
||||
|
||||
type bindPacketConn struct {
|
||||
N.NetPacketConn
|
||||
addr net.Addr
|
||||
}
|
||||
|
||||
func NewBindPacketConn(conn net.PacketConn, addr net.Addr) BindPacketConn {
|
||||
return &bindPacketConn{
|
||||
NewPacketConn(conn),
|
||||
addr,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *bindPacketConn) Read(b []byte) (n int, err error) {
|
||||
n, _, err = c.ReadFrom(b)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *bindPacketConn) Write(b []byte) (n int, err error) {
|
||||
return c.WriteTo(b, c.addr)
|
||||
}
|
||||
|
||||
func (c *bindPacketConn) CreateReadWaiter() (N.ReadWaiter, bool) {
|
||||
readWaiter, isReadWaiter := CreatePacketReadWaiter(c.NetPacketConn)
|
||||
if !isReadWaiter {
|
||||
return nil, false
|
||||
}
|
||||
return &bindPacketReadWaiter{readWaiter}, true
|
||||
}
|
||||
|
||||
func (c *bindPacketConn) RemoteAddr() net.Addr {
|
||||
return c.addr
|
||||
}
|
||||
|
||||
func (c *bindPacketConn) Upstream() any {
|
||||
return c.NetPacketConn
|
||||
}
|
||||
|
||||
var (
|
||||
_ N.NetPacketConn = (*UnbindPacketConn)(nil)
|
||||
_ N.PacketReadWaitCreator = (*UnbindPacketConn)(nil)
|
||||
)
|
||||
|
||||
type UnbindPacketConn struct {
|
||||
N.ExtendedConn
|
||||
addr M.Socksaddr
|
||||
}
|
||||
|
||||
func NewUnbindPacketConn(conn net.Conn) N.NetPacketConn {
|
||||
return &UnbindPacketConn{
|
||||
NewExtendedConn(conn),
|
||||
M.SocksaddrFromNet(conn.RemoteAddr()),
|
||||
}
|
||||
}
|
||||
|
||||
func NewUnbindPacketConnWithAddr(conn net.Conn, addr M.Socksaddr) N.NetPacketConn {
|
||||
return &UnbindPacketConn{
|
||||
NewExtendedConn(conn),
|
||||
addr,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *UnbindPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
||||
n, err = c.ExtendedConn.Read(p)
|
||||
if err == nil {
|
||||
addr = c.addr.UDPAddr()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (c *UnbindPacketConn) WriteTo(p []byte, _ net.Addr) (n int, err error) {
|
||||
return c.ExtendedConn.Write(p)
|
||||
}
|
||||
|
||||
func (c *UnbindPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
|
||||
err = c.ExtendedConn.ReadBuffer(buffer)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
destination = c.addr
|
||||
return
|
||||
}
|
||||
|
||||
func (c *UnbindPacketConn) WritePacket(buffer *buf.Buffer, _ M.Socksaddr) error {
|
||||
return c.ExtendedConn.WriteBuffer(buffer)
|
||||
}
|
||||
|
||||
func (c *UnbindPacketConn) CreateReadWaiter() (N.PacketReadWaiter, bool) {
|
||||
readWaiter, isReadWaiter := CreateReadWaiter(c.ExtendedConn)
|
||||
if !isReadWaiter {
|
||||
return nil, false
|
||||
}
|
||||
return &unbindPacketReadWaiter{readWaiter, c.addr}, true
|
||||
}
|
||||
|
||||
func (c *UnbindPacketConn) Upstream() any {
|
||||
return c.ExtendedConn
|
||||
}
|
||||
|
||||
func NewServerPacketConn(conn net.PacketConn) N.ExtendedConn {
|
||||
return &serverPacketConn{
|
||||
NetPacketConn: NewPacketConn(conn),
|
||||
}
|
||||
}
|
||||
|
||||
type serverPacketConn struct {
|
||||
N.NetPacketConn
|
||||
remoteAddr M.Socksaddr
|
||||
}
|
||||
|
||||
func (c *serverPacketConn) Read(p []byte) (n int, err error) {
|
||||
n, addr, err := c.NetPacketConn.ReadFrom(p)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
c.remoteAddr = M.SocksaddrFromNet(addr)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *serverPacketConn) ReadBuffer(buffer *buf.Buffer) error {
|
||||
destination, err := c.NetPacketConn.ReadPacket(buffer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.remoteAddr = destination
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *serverPacketConn) Write(p []byte) (n int, err error) {
|
||||
return c.NetPacketConn.WriteTo(p, c.remoteAddr.UDPAddr())
|
||||
}
|
||||
|
||||
func (c *serverPacketConn) WriteBuffer(buffer *buf.Buffer) error {
|
||||
return c.NetPacketConn.WritePacket(buffer, c.remoteAddr)
|
||||
}
|
||||
|
||||
func (c *serverPacketConn) RemoteAddr() net.Addr {
|
||||
return c.remoteAddr
|
||||
}
|
||||
|
||||
func (c *serverPacketConn) Upstream() any {
|
||||
return c.NetPacketConn
|
||||
}
|
||||
|
||||
func (c *serverPacketConn) CreateReadWaiter() (N.ReadWaiter, bool) {
|
||||
readWaiter, isReadWaiter := CreatePacketReadWaiter(c.NetPacketConn)
|
||||
if !isReadWaiter {
|
||||
return nil, false
|
||||
}
|
||||
return &serverPacketReadWaiter{c, readWaiter}, true
|
||||
}
|
62
common/bufio/bind_wait.go
Normal file
62
common/bufio/bind_wait.go
Normal file
|
@ -0,0 +1,62 @@
|
|||
package bufio
|
||||
|
||||
import (
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
var _ N.ReadWaiter = (*bindPacketReadWaiter)(nil)
|
||||
|
||||
type bindPacketReadWaiter struct {
|
||||
readWaiter N.PacketReadWaiter
|
||||
}
|
||||
|
||||
func (w *bindPacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
|
||||
return w.readWaiter.InitializeReadWaiter(options)
|
||||
}
|
||||
|
||||
func (w *bindPacketReadWaiter) WaitReadBuffer() (buffer *buf.Buffer, err error) {
|
||||
buffer, _, err = w.readWaiter.WaitReadPacket()
|
||||
return
|
||||
}
|
||||
|
||||
var _ N.PacketReadWaiter = (*unbindPacketReadWaiter)(nil)
|
||||
|
||||
type unbindPacketReadWaiter struct {
|
||||
readWaiter N.ReadWaiter
|
||||
addr M.Socksaddr
|
||||
}
|
||||
|
||||
func (w *unbindPacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
|
||||
return w.readWaiter.InitializeReadWaiter(options)
|
||||
}
|
||||
|
||||
func (w *unbindPacketReadWaiter) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
|
||||
buffer, err = w.readWaiter.WaitReadBuffer()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
destination = w.addr
|
||||
return
|
||||
}
|
||||
|
||||
var _ N.ReadWaiter = (*serverPacketReadWaiter)(nil)
|
||||
|
||||
type serverPacketReadWaiter struct {
|
||||
*serverPacketConn
|
||||
readWaiter N.PacketReadWaiter
|
||||
}
|
||||
|
||||
func (w *serverPacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
|
||||
return w.readWaiter.InitializeReadWaiter(options)
|
||||
}
|
||||
|
||||
func (w *serverPacketReadWaiter) WaitReadBuffer() (buffer *buf.Buffer, err error) {
|
||||
buffer, destination, err := w.readWaiter.WaitReadPacket()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
w.remoteAddr = destination
|
||||
return
|
||||
}
|
|
@ -2,79 +2,16 @@ package bufio
|
|||
|
||||
import (
|
||||
"io"
|
||||
"os"
|
||||
"sync"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
type BufferedReader struct {
|
||||
upstream N.ExtendedReader
|
||||
buffer *buf.Buffer
|
||||
}
|
||||
|
||||
func NewBufferedReader(upstream io.Reader, buffer *buf.Buffer) *BufferedReader {
|
||||
return &BufferedReader{
|
||||
upstream: NewExtendedReader(upstream),
|
||||
buffer: buffer,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *BufferedReader) Read(p []byte) (n int, err error) {
|
||||
if r.buffer.Closed() {
|
||||
return 0, os.ErrClosed
|
||||
}
|
||||
if r.buffer.IsEmpty() {
|
||||
r.buffer.Reset()
|
||||
err = r.upstream.ReadBuffer(r.buffer)
|
||||
if err != nil {
|
||||
r.buffer.Release()
|
||||
return
|
||||
}
|
||||
}
|
||||
return r.buffer.Read(p)
|
||||
}
|
||||
|
||||
func (r *BufferedReader) ReadBuffer(buffer *buf.Buffer) error {
|
||||
if r.buffer.Closed() {
|
||||
return os.ErrClosed
|
||||
}
|
||||
var err error
|
||||
if r.buffer.IsEmpty() {
|
||||
r.buffer.Reset()
|
||||
err = r.upstream.ReadBuffer(r.buffer)
|
||||
if err != nil {
|
||||
r.buffer.Release()
|
||||
return err
|
||||
}
|
||||
}
|
||||
if r.buffer.Len() > buffer.FreeLen() {
|
||||
err = common.Error(buffer.ReadFullFrom(r.buffer, buffer.FreeLen()))
|
||||
} else {
|
||||
err = common.Error(buffer.ReadFullFrom(r.buffer, r.buffer.Len()))
|
||||
}
|
||||
if err != nil {
|
||||
r.buffer.Release()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *BufferedReader) WriteTo(w io.Writer) (n int64, err error) {
|
||||
if r.buffer.Closed() {
|
||||
return 0, os.ErrClosed
|
||||
}
|
||||
defer r.buffer.Release()
|
||||
return CopyExtendedBuffer(NewExtendedWriter(w), NewExtendedReader(r.upstream), r.buffer)
|
||||
}
|
||||
|
||||
func (r *BufferedReader) Upstream() any {
|
||||
return r.upstream
|
||||
}
|
||||
|
||||
type BufferedWriter struct {
|
||||
upstream io.Writer
|
||||
buffer *buf.Buffer
|
||||
access sync.Mutex
|
||||
}
|
||||
|
||||
func NewBufferedWriter(upstream io.Writer, buffer *buf.Buffer) *BufferedWriter {
|
||||
|
@ -85,6 +22,11 @@ func NewBufferedWriter(upstream io.Writer, buffer *buf.Buffer) *BufferedWriter {
|
|||
}
|
||||
|
||||
func (w *BufferedWriter) Write(p []byte) (n int, err error) {
|
||||
w.access.Lock()
|
||||
defer w.access.Unlock()
|
||||
if w.buffer == nil {
|
||||
return w.upstream.Write(p)
|
||||
}
|
||||
for {
|
||||
var writeN int
|
||||
writeN, err = w.buffer.Write(p[n:])
|
||||
|
@ -96,10 +38,46 @@ func (w *BufferedWriter) Write(p []byte) (n int, err error) {
|
|||
if err != nil {
|
||||
return
|
||||
}
|
||||
w.buffer.FullReset()
|
||||
w.buffer.Reset()
|
||||
}
|
||||
}
|
||||
|
||||
func (w *BufferedWriter) ReadFrom(r io.Reader) (n int64, err error) {
|
||||
return CopyExtendedBuffer(NewExtendedWriter(w), NewExtendedReader(r), w.buffer)
|
||||
func (w *BufferedWriter) WriteByte(c byte) error {
|
||||
w.access.Lock()
|
||||
defer w.access.Unlock()
|
||||
if w.buffer == nil {
|
||||
return common.Error(w.upstream.Write([]byte{c}))
|
||||
}
|
||||
for {
|
||||
err := w.buffer.WriteByte(c)
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
_, err = w.upstream.Write(w.buffer.Bytes())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
w.buffer.Reset()
|
||||
}
|
||||
}
|
||||
|
||||
func (w *BufferedWriter) Fallthrough() error {
|
||||
w.access.Lock()
|
||||
defer w.access.Unlock()
|
||||
if w.buffer == nil {
|
||||
return nil
|
||||
}
|
||||
if !w.buffer.IsEmpty() {
|
||||
_, err := w.upstream.Write(w.buffer.Bytes())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
w.buffer.Release()
|
||||
w.buffer = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *BufferedWriter) WriterReplaceable() bool {
|
||||
return w.buffer == nil
|
||||
}
|
||||
|
|
|
@ -3,7 +3,6 @@ package bufio
|
|||
import (
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
|
@ -16,6 +15,7 @@ type CachedConn struct {
|
|||
}
|
||||
|
||||
func NewCachedConn(conn net.Conn, buffer *buf.Buffer) *CachedConn {
|
||||
buffer.IncRef()
|
||||
return &CachedConn{
|
||||
Conn: conn,
|
||||
buffer: buffer,
|
||||
|
@ -25,6 +25,9 @@ func NewCachedConn(conn net.Conn, buffer *buf.Buffer) *CachedConn {
|
|||
func (c *CachedConn) ReadCached() *buf.Buffer {
|
||||
buffer := c.buffer
|
||||
c.buffer = nil
|
||||
if buffer != nil {
|
||||
buffer.DecRef()
|
||||
}
|
||||
return buffer
|
||||
}
|
||||
|
||||
|
@ -34,6 +37,7 @@ func (c *CachedConn) Read(p []byte) (n int, err error) {
|
|||
if err == nil {
|
||||
return
|
||||
}
|
||||
c.buffer.DecRef()
|
||||
c.buffer.Release()
|
||||
c.buffer = nil
|
||||
}
|
||||
|
@ -44,6 +48,7 @@ func (c *CachedConn) WriteTo(w io.Writer) (n int64, err error) {
|
|||
if c.buffer != nil {
|
||||
wn, wErr := w.Write(c.buffer.Bytes())
|
||||
if wErr != nil {
|
||||
c.buffer.DecRef()
|
||||
c.buffer.Release()
|
||||
c.buffer = nil
|
||||
}
|
||||
|
@ -54,13 +59,6 @@ func (c *CachedConn) WriteTo(w io.Writer) (n int64, err error) {
|
|||
return
|
||||
}
|
||||
|
||||
func (c *CachedConn) SetReadDeadline(t time.Time) error {
|
||||
if c.buffer != nil && !c.buffer.IsEmpty() {
|
||||
return nil
|
||||
}
|
||||
return c.Conn.SetReadDeadline(t)
|
||||
}
|
||||
|
||||
func (c *CachedConn) ReadFrom(r io.Reader) (n int64, err error) {
|
||||
return Copy(c.Conn, r)
|
||||
}
|
||||
|
@ -78,7 +76,11 @@ func (c *CachedConn) WriterReplaceable() bool {
|
|||
}
|
||||
|
||||
func (c *CachedConn) Close() error {
|
||||
c.buffer.Release()
|
||||
if buffer := c.buffer; buffer != nil {
|
||||
buffer.DecRef()
|
||||
buffer.Release()
|
||||
c.buffer = nil
|
||||
}
|
||||
return c.Conn.Close()
|
||||
}
|
||||
|
||||
|
@ -88,6 +90,7 @@ type CachedReader struct {
|
|||
}
|
||||
|
||||
func NewCachedReader(upstream io.Reader, buffer *buf.Buffer) *CachedReader {
|
||||
buffer.IncRef()
|
||||
return &CachedReader{
|
||||
upstream: upstream,
|
||||
buffer: buffer,
|
||||
|
@ -97,6 +100,9 @@ func NewCachedReader(upstream io.Reader, buffer *buf.Buffer) *CachedReader {
|
|||
func (r *CachedReader) ReadCached() *buf.Buffer {
|
||||
buffer := r.buffer
|
||||
r.buffer = nil
|
||||
if buffer != nil {
|
||||
buffer.DecRef()
|
||||
}
|
||||
return buffer
|
||||
}
|
||||
|
||||
|
@ -106,6 +112,7 @@ func (r *CachedReader) Read(p []byte) (n int, err error) {
|
|||
if err == nil {
|
||||
return
|
||||
}
|
||||
r.buffer.DecRef()
|
||||
r.buffer.Release()
|
||||
r.buffer = nil
|
||||
}
|
||||
|
@ -134,7 +141,11 @@ func (r *CachedReader) ReaderReplaceable() bool {
|
|||
}
|
||||
|
||||
func (r *CachedReader) Close() error {
|
||||
r.buffer.Release()
|
||||
if buffer := r.buffer; buffer != nil {
|
||||
buffer.DecRef()
|
||||
buffer.Release()
|
||||
r.buffer = nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -145,6 +156,7 @@ type CachedPacketConn struct {
|
|||
}
|
||||
|
||||
func NewCachedPacketConn(conn N.PacketConn, buffer *buf.Buffer, destination M.Socksaddr) *CachedPacketConn {
|
||||
buffer.IncRef()
|
||||
return &CachedPacketConn{
|
||||
PacketConn: conn,
|
||||
buffer: buffer,
|
||||
|
@ -158,16 +170,26 @@ func (c *CachedPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksad
|
|||
if err != nil {
|
||||
return M.Socksaddr{}, err
|
||||
}
|
||||
c.buffer.DecRef()
|
||||
c.buffer.Release()
|
||||
c.buffer = nil
|
||||
return c.destination, nil
|
||||
}
|
||||
return c.PacketConn.ReadPacket(buffer)
|
||||
}
|
||||
|
||||
func (c *CachedPacketConn) ReadCachedPacket() (destination M.Socksaddr, buffer *buf.Buffer) {
|
||||
buffer = c.buffer
|
||||
func (c *CachedPacketConn) ReadCachedPacket() *N.PacketBuffer {
|
||||
buffer := c.buffer
|
||||
c.buffer = nil
|
||||
return c.destination, buffer
|
||||
if buffer != nil {
|
||||
buffer.DecRef()
|
||||
}
|
||||
packet := N.NewPacketBuffer()
|
||||
*packet = N.PacketBuffer{
|
||||
Buffer: buffer,
|
||||
Destination: c.destination,
|
||||
}
|
||||
return packet
|
||||
}
|
||||
|
||||
func (c *CachedPacketConn) Upstream() any {
|
||||
|
@ -181,3 +203,12 @@ func (c *CachedPacketConn) ReaderReplaceable() bool {
|
|||
func (c *CachedPacketConn) WriterReplaceable() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *CachedPacketConn) Close() error {
|
||||
if buffer := c.buffer; buffer != nil {
|
||||
buffer.DecRef()
|
||||
buffer.Release()
|
||||
c.buffer = nil
|
||||
}
|
||||
return c.PacketConn.Close()
|
||||
}
|
||||
|
|
|
@ -30,7 +30,7 @@ func (c *ChunkReader) ReadBuffer(buffer *buf.Buffer) error {
|
|||
} else if !c.cache.IsEmpty() {
|
||||
return common.Error(buffer.ReadFrom(c.cache))
|
||||
}
|
||||
c.cache.FullReset()
|
||||
c.cache.Reset()
|
||||
err := c.upstream.ReadBuffer(c.cache)
|
||||
if err != nil {
|
||||
c.cache.Release()
|
||||
|
@ -46,14 +46,40 @@ func (c *ChunkReader) Read(p []byte) (n int, err error) {
|
|||
} else if !c.cache.IsEmpty() {
|
||||
return c.cache.Read(p)
|
||||
}
|
||||
c.cache.FullReset()
|
||||
c.cache.Reset()
|
||||
err = c.upstream.ReadBuffer(c.cache)
|
||||
if err != nil {
|
||||
c.cache.Release()
|
||||
c.cache = nil
|
||||
return
|
||||
}
|
||||
return c.cache.Read(p)
|
||||
}
|
||||
|
||||
func (c *ChunkReader) ReadByte() (byte, error) {
|
||||
buffer, err := c.ReadChunk()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return buffer.ReadByte()
|
||||
}
|
||||
|
||||
func (c *ChunkReader) ReadChunk() (*buf.Buffer, error) {
|
||||
if c.cache == nil {
|
||||
c.cache = buf.NewSize(c.maxChunkSize)
|
||||
} else if !c.cache.IsEmpty() {
|
||||
return c.cache, nil
|
||||
}
|
||||
c.cache.Reset()
|
||||
err := c.upstream.ReadBuffer(c.cache)
|
||||
if err != nil {
|
||||
c.cache.Release()
|
||||
c.cache = nil
|
||||
return nil, err
|
||||
}
|
||||
return c.cache, nil
|
||||
}
|
||||
|
||||
func (c *ChunkReader) MTU() int {
|
||||
return c.maxChunkSize
|
||||
}
|
||||
|
|
|
@ -35,14 +35,7 @@ func (w *ExtendedUDPConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
|
|||
|
||||
func (w *ExtendedUDPConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
|
||||
defer buffer.Release()
|
||||
if destination.IsFqdn() {
|
||||
udpAddr, err := net.ResolveUDPAddr("udp", destination.String())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return common.Error(w.UDPConn.WriteTo(buffer.Bytes(), udpAddr))
|
||||
}
|
||||
return common.Error(w.UDPConn.WriteToUDP(buffer.Bytes(), destination.UDPAddr()))
|
||||
return common.Error(w.UDPConn.WriteToUDPAddrPort(buffer.Bytes(), destination.AddrPort()))
|
||||
}
|
||||
|
||||
func (w *ExtendedUDPConn) Upstream() any {
|
||||
|
@ -70,69 +63,6 @@ func (w *ExtendedPacketConn) Upstream() any {
|
|||
return w.PacketConn
|
||||
}
|
||||
|
||||
type BindPacketConn struct {
|
||||
net.PacketConn
|
||||
Addr net.Addr
|
||||
}
|
||||
|
||||
func (c *BindPacketConn) Read(b []byte) (n int, err error) {
|
||||
n, _, err = c.ReadFrom(b)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *BindPacketConn) Write(b []byte) (n int, err error) {
|
||||
return c.WriteTo(b, c.Addr)
|
||||
}
|
||||
|
||||
func (c *BindPacketConn) RemoteAddr() net.Addr {
|
||||
return c.Addr
|
||||
}
|
||||
|
||||
func (c *BindPacketConn) Upstream() any {
|
||||
return c.PacketConn
|
||||
}
|
||||
|
||||
type UnbindPacketConn struct {
|
||||
N.ExtendedConn
|
||||
Addr M.Socksaddr
|
||||
}
|
||||
|
||||
func (c *UnbindPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
||||
n, err = c.ExtendedConn.Read(p)
|
||||
if err == nil {
|
||||
addr = c.Addr.UDPAddr()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (c *UnbindPacketConn) WriteTo(p []byte, _ net.Addr) (n int, err error) {
|
||||
return c.ExtendedConn.Write(p)
|
||||
}
|
||||
|
||||
func (c *UnbindPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
|
||||
err = c.ExtendedConn.ReadBuffer(buffer)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
destination = c.Addr
|
||||
return
|
||||
}
|
||||
|
||||
func (c *UnbindPacketConn) WritePacket(buffer *buf.Buffer, _ M.Socksaddr) error {
|
||||
return c.ExtendedConn.WriteBuffer(buffer)
|
||||
}
|
||||
|
||||
func (c *UnbindPacketConn) Upstream() any {
|
||||
return c.ExtendedConn
|
||||
}
|
||||
|
||||
func NewUnbindPacketConn(conn net.Conn) *UnbindPacketConn {
|
||||
return &UnbindPacketConn{
|
||||
NewExtendedConn(conn),
|
||||
M.SocksaddrFromNet(conn.RemoteAddr()),
|
||||
}
|
||||
}
|
||||
|
||||
type ExtendedReaderWrapper struct {
|
||||
io.Reader
|
||||
}
|
||||
|
@ -188,7 +118,7 @@ func (w *ExtendedWriterWrapper) Upstream() any {
|
|||
return w.Writer
|
||||
}
|
||||
|
||||
func (w *ExtendedReaderWrapper) WriterReplaceable() bool {
|
||||
func (w *ExtendedWriterWrapper) WriterReplaceable() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
|
|
|
@ -2,348 +2,316 @@ package bufio
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"syscall"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/sing/common/rw"
|
||||
"github.com/sagernet/sing/common/task"
|
||||
)
|
||||
|
||||
type readOnlyReader struct {
|
||||
io.Reader
|
||||
}
|
||||
|
||||
func (r *readOnlyReader) WriteTo(w io.Writer) (n int64, err error) {
|
||||
return Copy(w, r.Reader)
|
||||
}
|
||||
|
||||
func (r *readOnlyReader) Upstream() any {
|
||||
return r.Reader
|
||||
}
|
||||
|
||||
func (r *readOnlyReader) ReaderReplaceable() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
type writeOnlyWriter struct {
|
||||
io.Writer
|
||||
}
|
||||
|
||||
func (w *writeOnlyWriter) ReadFrom(r io.Reader) (n int64, err error) {
|
||||
return Copy(w.Writer, r)
|
||||
}
|
||||
|
||||
func (w *writeOnlyWriter) Upstream() any {
|
||||
return w.Writer
|
||||
}
|
||||
|
||||
func (w *writeOnlyWriter) WriterReplaceable() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func needWrapper(src, dst any) bool {
|
||||
_, srcTCPConn := src.(*net.TCPConn)
|
||||
_, dstTCPConn := dst.(*net.TCPConn)
|
||||
return (srcTCPConn || dstTCPConn) && !(srcTCPConn && dstTCPConn)
|
||||
}
|
||||
|
||||
func Copy(dst io.Writer, src io.Reader) (n int64, err error) {
|
||||
if src == nil {
|
||||
func Copy(destination io.Writer, source io.Reader) (n int64, err error) {
|
||||
if source == nil {
|
||||
return 0, E.New("nil reader")
|
||||
} else if dst == nil {
|
||||
} else if destination == nil {
|
||||
return 0, E.New("nil writer")
|
||||
}
|
||||
src = N.UnwrapReader(src)
|
||||
dst = N.UnwrapWriter(dst)
|
||||
if wt, ok := src.(io.WriterTo); ok {
|
||||
if needWrapper(dst, src) {
|
||||
dst = &writeOnlyWriter{dst}
|
||||
originSource := source
|
||||
var readCounters, writeCounters []N.CountFunc
|
||||
for {
|
||||
source, readCounters = N.UnwrapCountReader(source, readCounters)
|
||||
destination, writeCounters = N.UnwrapCountWriter(destination, writeCounters)
|
||||
if cachedSrc, isCached := source.(N.CachedReader); isCached {
|
||||
cachedBuffer := cachedSrc.ReadCached()
|
||||
if cachedBuffer != nil {
|
||||
dataLen := cachedBuffer.Len()
|
||||
_, err = destination.Write(cachedBuffer.Bytes())
|
||||
cachedBuffer.Release()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
for _, counter := range readCounters {
|
||||
counter(int64(dataLen))
|
||||
}
|
||||
for _, counter := range writeCounters {
|
||||
counter(int64(dataLen))
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
return wt.WriteTo(dst)
|
||||
break
|
||||
}
|
||||
if rt, ok := dst.(io.ReaderFrom); ok {
|
||||
if needWrapper(rt, src) {
|
||||
src = &readOnlyReader{src}
|
||||
}
|
||||
return rt.ReadFrom(src)
|
||||
}
|
||||
return CopyExtended(NewExtendedWriter(dst), NewExtendedReader(src))
|
||||
return CopyWithCounters(destination, source, originSource, readCounters, writeCounters)
|
||||
}
|
||||
|
||||
func CopyExtended(dst N.ExtendedWriter, src N.ExtendedReader) (n int64, err error) {
|
||||
safeSrc := N.IsSafeReader(src)
|
||||
headroom := N.CalculateFrontHeadroom(dst) + N.CalculateRearHeadroom(dst)
|
||||
if safeSrc != nil {
|
||||
if headroom == 0 {
|
||||
return CopyExtendedWithSrcBuffer(dst, safeSrc)
|
||||
func CopyWithCounters(destination io.Writer, source io.Reader, originSource io.Reader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
|
||||
srcSyscallConn, srcIsSyscall := source.(syscall.Conn)
|
||||
dstSyscallConn, dstIsSyscall := destination.(syscall.Conn)
|
||||
if srcIsSyscall && dstIsSyscall {
|
||||
var handled bool
|
||||
handled, n, err = copyDirect(srcSyscallConn, dstSyscallConn, readCounters, writeCounters)
|
||||
if handled {
|
||||
return
|
||||
}
|
||||
}
|
||||
if N.IsUnsafeWriter(dst) {
|
||||
return CopyExtendedWithPool(dst, src)
|
||||
}
|
||||
bufferSize := N.CalculateMTU(src, dst)
|
||||
if bufferSize > 0 {
|
||||
bufferSize += headroom
|
||||
} else {
|
||||
bufferSize = buf.BufferSize
|
||||
}
|
||||
_buffer := buf.StackNewSize(bufferSize)
|
||||
defer common.KeepAlive(_buffer)
|
||||
buffer := common.Dup(_buffer)
|
||||
defer buffer.Release()
|
||||
return CopyExtendedBuffer(dst, src, buffer)
|
||||
return CopyExtended(originSource, NewExtendedWriter(destination), NewExtendedReader(source), readCounters, writeCounters)
|
||||
}
|
||||
|
||||
func CopyExtendedBuffer(dst N.ExtendedWriter, src N.ExtendedReader, buffer *buf.Buffer) (n int64, err error) {
|
||||
func CopyExtended(originSource io.Reader, destination N.ExtendedWriter, source N.ExtendedReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
|
||||
frontHeadroom := N.CalculateFrontHeadroom(destination)
|
||||
rearHeadroom := N.CalculateRearHeadroom(destination)
|
||||
readWaiter, isReadWaiter := CreateReadWaiter(source)
|
||||
if isReadWaiter {
|
||||
needCopy := readWaiter.InitializeReadWaiter(N.ReadWaitOptions{
|
||||
FrontHeadroom: frontHeadroom,
|
||||
RearHeadroom: rearHeadroom,
|
||||
MTU: N.CalculateMTU(source, destination),
|
||||
})
|
||||
if !needCopy || common.LowMemory {
|
||||
var handled bool
|
||||
handled, n, err = copyWaitWithPool(originSource, destination, readWaiter, readCounters, writeCounters)
|
||||
if handled {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
return CopyExtendedWithPool(originSource, destination, source, readCounters, writeCounters)
|
||||
}
|
||||
|
||||
// Deprecated: not used
|
||||
func CopyExtendedBuffer(originSource io.Writer, destination N.ExtendedWriter, source N.ExtendedReader, buffer *buf.Buffer, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
|
||||
buffer.IncRef()
|
||||
defer buffer.DecRef()
|
||||
frontHeadroom := N.CalculateFrontHeadroom(dst)
|
||||
rearHeadroom := N.CalculateRearHeadroom(dst)
|
||||
readBufferRaw := buffer.Slice()
|
||||
readBuffer := buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom])
|
||||
frontHeadroom := N.CalculateFrontHeadroom(destination)
|
||||
rearHeadroom := N.CalculateRearHeadroom(destination)
|
||||
buffer.Resize(frontHeadroom, 0)
|
||||
buffer.Reserve(rearHeadroom)
|
||||
var notFirstTime bool
|
||||
for {
|
||||
readBuffer.Resize(frontHeadroom, 0)
|
||||
err = src.ReadBuffer(readBuffer)
|
||||
err = source.ReadBuffer(buffer)
|
||||
if err != nil {
|
||||
if !notFirstTime {
|
||||
err = N.HandshakeFailure(dst, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
dataLen := readBuffer.Len()
|
||||
buffer.Resize(readBuffer.Start(), dataLen)
|
||||
err = dst.WriteBuffer(buffer)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
n += int64(dataLen)
|
||||
notFirstTime = true
|
||||
}
|
||||
}
|
||||
|
||||
func CopyExtendedWithSrcBuffer(dst N.ExtendedWriter, src N.ThreadSafeReader) (n int64, err error) {
|
||||
var notFirstTime bool
|
||||
for {
|
||||
var buffer *buf.Buffer
|
||||
buffer, err = src.ReadBufferThreadSafe()
|
||||
if err != nil {
|
||||
if !notFirstTime {
|
||||
err = N.HandshakeFailure(dst, err)
|
||||
if errors.Is(err, io.EOF) {
|
||||
err = nil
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
dataLen := buffer.Len()
|
||||
err = dst.WriteBuffer(buffer)
|
||||
buffer.OverCap(rearHeadroom)
|
||||
err = destination.WriteBuffer(buffer)
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
if !notFirstTime {
|
||||
err = N.ReportHandshakeFailure(originSource, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
n += int64(dataLen)
|
||||
for _, counter := range readCounters {
|
||||
counter(int64(dataLen))
|
||||
}
|
||||
for _, counter := range writeCounters {
|
||||
counter(int64(dataLen))
|
||||
}
|
||||
notFirstTime = true
|
||||
}
|
||||
}
|
||||
|
||||
func CopyExtendedWithPool(dst N.ExtendedWriter, src N.ExtendedReader) (n int64, err error) {
|
||||
frontHeadroom := N.CalculateFrontHeadroom(dst)
|
||||
rearHeadroom := N.CalculateRearHeadroom(dst)
|
||||
bufferSize := N.CalculateMTU(src, dst)
|
||||
if bufferSize > 0 {
|
||||
bufferSize += frontHeadroom + rearHeadroom
|
||||
} else {
|
||||
bufferSize = buf.BufferSize
|
||||
}
|
||||
func CopyExtendedWithPool(originSource io.Reader, destination N.ExtendedWriter, source N.ExtendedReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
|
||||
options := N.NewReadWaitOptions(source, destination)
|
||||
var notFirstTime bool
|
||||
for {
|
||||
buffer := buf.NewSize(bufferSize)
|
||||
readBufferRaw := buffer.Slice()
|
||||
readBuffer := buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom])
|
||||
readBuffer.Resize(frontHeadroom, 0)
|
||||
err = src.ReadBuffer(readBuffer)
|
||||
buffer := options.NewBuffer()
|
||||
err = source.ReadBuffer(buffer)
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
if !notFirstTime {
|
||||
err = N.HandshakeFailure(dst, err)
|
||||
if errors.Is(err, io.EOF) {
|
||||
err = nil
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
dataLen := readBuffer.Len()
|
||||
buffer.Resize(readBuffer.Start(), dataLen)
|
||||
err = dst.WriteBuffer(buffer)
|
||||
dataLen := buffer.Len()
|
||||
options.PostReturn(buffer)
|
||||
err = destination.WriteBuffer(buffer)
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
buffer.Leak()
|
||||
if !notFirstTime {
|
||||
err = N.ReportHandshakeFailure(originSource, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
n += int64(dataLen)
|
||||
for _, counter := range readCounters {
|
||||
counter(int64(dataLen))
|
||||
}
|
||||
for _, counter := range writeCounters {
|
||||
counter(int64(dataLen))
|
||||
}
|
||||
notFirstTime = true
|
||||
}
|
||||
}
|
||||
|
||||
func CopyConn(ctx context.Context, conn net.Conn, dest net.Conn) error {
|
||||
func CopyConn(ctx context.Context, source net.Conn, destination net.Conn) error {
|
||||
var group task.Group
|
||||
if _, dstDuplex := common.Cast[rw.WriteCloser](dest); dstDuplex {
|
||||
if _, dstDuplex := common.Cast[N.WriteCloser](destination); dstDuplex {
|
||||
group.Append("upload", func(ctx context.Context) error {
|
||||
err := common.Error(Copy(dest, conn))
|
||||
if E.IsMulti(err, io.EOF) {
|
||||
rw.CloseWrite(dest)
|
||||
err := common.Error(Copy(destination, source))
|
||||
if err == nil {
|
||||
N.CloseWrite(destination)
|
||||
} else {
|
||||
common.Close(dest)
|
||||
common.Close(destination)
|
||||
}
|
||||
return err
|
||||
})
|
||||
} else {
|
||||
group.Append("upload", func(ctx context.Context) error {
|
||||
defer common.Close(dest)
|
||||
return common.Error(Copy(dest, conn))
|
||||
defer common.Close(destination)
|
||||
return common.Error(Copy(destination, source))
|
||||
})
|
||||
}
|
||||
if _, srcDuplex := common.Cast[rw.WriteCloser](conn); srcDuplex {
|
||||
if _, srcDuplex := common.Cast[N.WriteCloser](source); srcDuplex {
|
||||
group.Append("download", func(ctx context.Context) error {
|
||||
err := common.Error(Copy(conn, dest))
|
||||
if E.IsMulti(err, io.EOF) {
|
||||
rw.CloseWrite(conn)
|
||||
err := common.Error(Copy(source, destination))
|
||||
if err == nil {
|
||||
N.CloseWrite(source)
|
||||
} else {
|
||||
common.Close(conn)
|
||||
common.Close(source)
|
||||
}
|
||||
return err
|
||||
})
|
||||
} else {
|
||||
group.Append("download", func(ctx context.Context) error {
|
||||
defer common.Close(conn)
|
||||
return common.Error(Copy(conn, dest))
|
||||
defer common.Close(source)
|
||||
return common.Error(Copy(source, destination))
|
||||
})
|
||||
}
|
||||
group.Cleanup(func() {
|
||||
common.Close(conn, dest)
|
||||
common.Close(source, destination)
|
||||
})
|
||||
return group.Run(ctx)
|
||||
}
|
||||
|
||||
func CopyPacket(dst N.PacketWriter, src N.PacketReader) (n int64, err error) {
|
||||
src = N.UnwrapPacketReader(src)
|
||||
dst = N.UnwrapPacketWriter(dst)
|
||||
safeSrc := N.IsSafePacketReader(src)
|
||||
frontHeadroom := N.CalculateFrontHeadroom(dst)
|
||||
rearHeadroom := N.CalculateRearHeadroom(dst)
|
||||
headroom := frontHeadroom + rearHeadroom
|
||||
if safeSrc != nil {
|
||||
if headroom == 0 {
|
||||
return CopyPacketWithSrcBuffer(dst, safeSrc)
|
||||
}
|
||||
}
|
||||
if N.IsUnsafeWriter(dst) {
|
||||
return CopyPacketWithPool(dst, src)
|
||||
}
|
||||
bufferSize := N.CalculateMTU(src, dst)
|
||||
if bufferSize > 0 {
|
||||
bufferSize += headroom
|
||||
} else {
|
||||
bufferSize = buf.UDPBufferSize
|
||||
}
|
||||
_buffer := buf.StackNewSize(bufferSize)
|
||||
defer common.KeepAlive(_buffer)
|
||||
buffer := common.Dup(_buffer)
|
||||
defer buffer.Release()
|
||||
buffer.IncRef()
|
||||
defer buffer.DecRef()
|
||||
var destination M.Socksaddr
|
||||
var notFirstTime bool
|
||||
readBufferRaw := buffer.Slice()
|
||||
readBuffer := buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom])
|
||||
func CopyPacket(destinationConn N.PacketWriter, source N.PacketReader) (n int64, err error) {
|
||||
var readCounters, writeCounters []N.CountFunc
|
||||
var cachedPackets []*N.PacketBuffer
|
||||
originSource := source
|
||||
for {
|
||||
readBuffer.Resize(frontHeadroom, 0)
|
||||
destination, err = src.ReadPacket(readBuffer)
|
||||
if err != nil {
|
||||
if !notFirstTime {
|
||||
err = N.HandshakeFailure(dst, err)
|
||||
source, readCounters = N.UnwrapCountPacketReader(source, readCounters)
|
||||
destinationConn, writeCounters = N.UnwrapCountPacketWriter(destinationConn, writeCounters)
|
||||
if cachedReader, isCached := source.(N.CachedPacketReader); isCached {
|
||||
packet := cachedReader.ReadCachedPacket()
|
||||
if packet != nil {
|
||||
cachedPackets = append(cachedPackets, packet)
|
||||
continue
|
||||
}
|
||||
return
|
||||
}
|
||||
dataLen := readBuffer.Len()
|
||||
buffer.Resize(readBuffer.Start(), dataLen)
|
||||
err = dst.WritePacket(buffer, destination)
|
||||
break
|
||||
}
|
||||
if cachedPackets != nil {
|
||||
n, err = WritePacketWithPool(originSource, destinationConn, cachedPackets, readCounters, writeCounters)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
n += int64(dataLen)
|
||||
notFirstTime = true
|
||||
}
|
||||
copeN, err := CopyPacketWithCounters(destinationConn, source, originSource, readCounters, writeCounters)
|
||||
n += copeN
|
||||
return
|
||||
}
|
||||
|
||||
func CopyPacketWithSrcBuffer(dst N.PacketWriter, src N.ThreadSafePacketReader) (n int64, err error) {
|
||||
var buffer *buf.Buffer
|
||||
var destination M.Socksaddr
|
||||
var notFirstTime bool
|
||||
for {
|
||||
buffer, destination, err = src.ReadPacketThreadSafe()
|
||||
if err != nil {
|
||||
if !notFirstTime {
|
||||
err = N.HandshakeFailure(dst, err)
|
||||
func CopyPacketWithCounters(destinationConn N.PacketWriter, source N.PacketReader, originSource N.PacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
|
||||
var (
|
||||
handled bool
|
||||
copeN int64
|
||||
)
|
||||
readWaiter, isReadWaiter := CreatePacketReadWaiter(source)
|
||||
if isReadWaiter {
|
||||
needCopy := readWaiter.InitializeReadWaiter(N.NewReadWaitOptions(source, destinationConn))
|
||||
if !needCopy || common.LowMemory {
|
||||
handled, copeN, err = copyPacketWaitWithPool(originSource, destinationConn, readWaiter, readCounters, writeCounters, n > 0)
|
||||
if handled {
|
||||
n += copeN
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
copeN, err = CopyPacketWithPool(originSource, destinationConn, source, readCounters, writeCounters, n > 0)
|
||||
n += copeN
|
||||
return
|
||||
}
|
||||
|
||||
func CopyPacketWithPool(originSource N.PacketReader, destination N.PacketWriter, source N.PacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (n int64, err error) {
|
||||
options := N.NewReadWaitOptions(source, destination)
|
||||
var destinationAddress M.Socksaddr
|
||||
for {
|
||||
buffer := options.NewPacketBuffer()
|
||||
destinationAddress, err = source.ReadPacket(buffer)
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
return
|
||||
}
|
||||
dataLen := buffer.Len()
|
||||
err = dst.WritePacket(buffer, destination)
|
||||
options.PostReturn(buffer)
|
||||
err = destination.WritePacket(buffer, destinationAddress)
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
return
|
||||
}
|
||||
n += int64(dataLen)
|
||||
notFirstTime = true
|
||||
}
|
||||
}
|
||||
|
||||
func CopyPacketWithPool(dst N.PacketWriter, src N.PacketReader) (n int64, err error) {
|
||||
frontHeadroom := N.CalculateFrontHeadroom(dst)
|
||||
rearHeadroom := N.CalculateRearHeadroom(dst)
|
||||
bufferSize := N.CalculateMTU(src, dst)
|
||||
if bufferSize > 0 {
|
||||
bufferSize += frontHeadroom + rearHeadroom
|
||||
} else {
|
||||
bufferSize = buf.UDPBufferSize
|
||||
}
|
||||
var destination M.Socksaddr
|
||||
var notFirstTime bool
|
||||
for {
|
||||
buffer := buf.NewSize(bufferSize)
|
||||
readBufferRaw := buffer.Slice()
|
||||
readBuffer := buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom])
|
||||
readBuffer.Resize(frontHeadroom, 0)
|
||||
destination, err = src.ReadPacket(readBuffer)
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
buffer.Leak()
|
||||
if !notFirstTime {
|
||||
err = N.HandshakeFailure(dst, err)
|
||||
err = N.ReportHandshakeFailure(originSource, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
dataLen := readBuffer.Len()
|
||||
buffer.Resize(readBuffer.Start(), dataLen)
|
||||
err = dst.WritePacket(buffer, destination)
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
return
|
||||
for _, counter := range readCounters {
|
||||
counter(int64(dataLen))
|
||||
}
|
||||
for _, counter := range writeCounters {
|
||||
counter(int64(dataLen))
|
||||
}
|
||||
n += int64(dataLen)
|
||||
notFirstTime = true
|
||||
}
|
||||
}
|
||||
|
||||
func CopyPacketConn(ctx context.Context, conn N.PacketConn, dest N.PacketConn) error {
|
||||
func WritePacketWithPool(originSource N.PacketReader, destination N.PacketWriter, packetBuffers []*N.PacketBuffer, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
|
||||
options := N.NewReadWaitOptions(nil, destination)
|
||||
var notFirstTime bool
|
||||
for _, packetBuffer := range packetBuffers {
|
||||
buffer := options.Copy(packetBuffer.Buffer)
|
||||
dataLen := buffer.Len()
|
||||
err = destination.WritePacket(buffer, packetBuffer.Destination)
|
||||
N.PutPacketBuffer(packetBuffer)
|
||||
if err != nil {
|
||||
buffer.Leak()
|
||||
if !notFirstTime {
|
||||
err = N.ReportHandshakeFailure(originSource, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
for _, counter := range readCounters {
|
||||
counter(int64(dataLen))
|
||||
}
|
||||
for _, counter := range writeCounters {
|
||||
counter(int64(dataLen))
|
||||
}
|
||||
n += int64(dataLen)
|
||||
notFirstTime = true
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func CopyPacketConn(ctx context.Context, source N.PacketConn, destination N.PacketConn) error {
|
||||
var group task.Group
|
||||
group.Append("upload", func(ctx context.Context) error {
|
||||
return common.Error(CopyPacket(dest, conn))
|
||||
return common.Error(CopyPacket(destination, source))
|
||||
})
|
||||
group.Append("download", func(ctx context.Context) error {
|
||||
return common.Error(CopyPacket(conn, dest))
|
||||
return common.Error(CopyPacket(source, destination))
|
||||
})
|
||||
group.Cleanup(func() {
|
||||
common.Close(conn, dest)
|
||||
common.Close(source, destination)
|
||||
})
|
||||
group.FastFail()
|
||||
return group.Run(ctx)
|
||||
|
|
90
common/bufio/copy_direct.go
Normal file
90
common/bufio/copy_direct.go
Normal file
|
@ -0,0 +1,90 @@
|
|||
package bufio
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"syscall"
|
||||
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
func copyDirect(source syscall.Conn, destination syscall.Conn, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handed bool, n int64, err error) {
|
||||
rawSource, err := source.SyscallConn()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
rawDestination, err := destination.SyscallConn()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
handed, n, err = splice(rawSource, rawDestination, readCounters, writeCounters)
|
||||
return
|
||||
}
|
||||
|
||||
func copyWaitWithPool(originSource io.Reader, destination N.ExtendedWriter, source N.ReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handled bool, n int64, err error) {
|
||||
handled = true
|
||||
var (
|
||||
buffer *buf.Buffer
|
||||
notFirstTime bool
|
||||
)
|
||||
for {
|
||||
buffer, err = source.WaitReadBuffer()
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
err = nil
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
dataLen := buffer.Len()
|
||||
err = destination.WriteBuffer(buffer)
|
||||
if err != nil {
|
||||
buffer.Leak()
|
||||
if !notFirstTime {
|
||||
err = N.ReportHandshakeFailure(originSource, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
n += int64(dataLen)
|
||||
for _, counter := range readCounters {
|
||||
counter(int64(dataLen))
|
||||
}
|
||||
for _, counter := range writeCounters {
|
||||
counter(int64(dataLen))
|
||||
}
|
||||
notFirstTime = true
|
||||
}
|
||||
}
|
||||
|
||||
func copyPacketWaitWithPool(originSource N.PacketReader, destinationConn N.PacketWriter, source N.PacketReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (handled bool, n int64, err error) {
|
||||
handled = true
|
||||
var (
|
||||
buffer *buf.Buffer
|
||||
destination M.Socksaddr
|
||||
)
|
||||
for {
|
||||
buffer, destination, err = source.WaitReadPacket()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
dataLen := buffer.Len()
|
||||
err = destinationConn.WritePacket(buffer, destination)
|
||||
if err != nil {
|
||||
buffer.Leak()
|
||||
if !notFirstTime {
|
||||
err = N.ReportHandshakeFailure(originSource, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
n += int64(dataLen)
|
||||
for _, counter := range readCounters {
|
||||
counter(int64(dataLen))
|
||||
}
|
||||
for _, counter := range writeCounters {
|
||||
counter(int64(dataLen))
|
||||
}
|
||||
notFirstTime = true
|
||||
}
|
||||
}
|
146
common/bufio/copy_direct_posix.go
Normal file
146
common/bufio/copy_direct_posix.go
Normal file
|
@ -0,0 +1,146 @@
|
|||
//go:build !windows
|
||||
|
||||
package bufio
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/netip"
|
||||
"os"
|
||||
"syscall"
|
||||
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
var _ N.ReadWaiter = (*syscallReadWaiter)(nil)
|
||||
|
||||
type syscallReadWaiter struct {
|
||||
rawConn syscall.RawConn
|
||||
readErr error
|
||||
readFunc func(fd uintptr) (done bool)
|
||||
buffer *buf.Buffer
|
||||
options N.ReadWaitOptions
|
||||
}
|
||||
|
||||
func createSyscallReadWaiter(reader any) (*syscallReadWaiter, bool) {
|
||||
if syscallConn, isSyscallConn := reader.(syscall.Conn); isSyscallConn {
|
||||
rawConn, err := syscallConn.SyscallConn()
|
||||
if err == nil {
|
||||
return &syscallReadWaiter{rawConn: rawConn}, true
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (w *syscallReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
|
||||
w.options = options
|
||||
w.readFunc = func(fd uintptr) (done bool) {
|
||||
buffer := w.options.NewBuffer()
|
||||
var readN int
|
||||
readN, w.readErr = syscall.Read(int(fd), buffer.FreeBytes())
|
||||
if readN > 0 {
|
||||
buffer.Truncate(readN)
|
||||
w.options.PostReturn(buffer)
|
||||
w.buffer = buffer
|
||||
} else {
|
||||
buffer.Release()
|
||||
}
|
||||
//goland:noinspection GoDirectComparisonOfErrors
|
||||
if w.readErr == syscall.EAGAIN {
|
||||
return false
|
||||
}
|
||||
if readN == 0 && w.readErr == nil {
|
||||
w.readErr = io.EOF
|
||||
}
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (w *syscallReadWaiter) WaitReadBuffer() (buffer *buf.Buffer, err error) {
|
||||
if w.readFunc == nil {
|
||||
return nil, os.ErrInvalid
|
||||
}
|
||||
err = w.rawConn.Read(w.readFunc)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if w.readErr != nil {
|
||||
if w.readErr == io.EOF {
|
||||
return nil, io.EOF
|
||||
}
|
||||
return nil, E.Cause(w.readErr, "raw read")
|
||||
}
|
||||
buffer = w.buffer
|
||||
w.buffer = nil
|
||||
return
|
||||
}
|
||||
|
||||
var _ N.PacketReadWaiter = (*syscallPacketReadWaiter)(nil)
|
||||
|
||||
type syscallPacketReadWaiter struct {
|
||||
rawConn syscall.RawConn
|
||||
readErr error
|
||||
readFrom M.Socksaddr
|
||||
readFunc func(fd uintptr) (done bool)
|
||||
buffer *buf.Buffer
|
||||
options N.ReadWaitOptions
|
||||
}
|
||||
|
||||
func createSyscallPacketReadWaiter(reader any) (*syscallPacketReadWaiter, bool) {
|
||||
if syscallConn, isSyscallConn := reader.(syscall.Conn); isSyscallConn {
|
||||
rawConn, err := syscallConn.SyscallConn()
|
||||
if err == nil {
|
||||
return &syscallPacketReadWaiter{rawConn: rawConn}, true
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (w *syscallPacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
|
||||
w.options = options
|
||||
w.readFunc = func(fd uintptr) (done bool) {
|
||||
buffer := w.options.NewPacketBuffer()
|
||||
var readN int
|
||||
var from syscall.Sockaddr
|
||||
readN, _, _, from, w.readErr = syscall.Recvmsg(int(fd), buffer.FreeBytes(), nil, 0)
|
||||
//goland:noinspection GoDirectComparisonOfErrors
|
||||
if w.readErr != nil {
|
||||
buffer.Release()
|
||||
return w.readErr != syscall.EAGAIN
|
||||
}
|
||||
if readN > 0 {
|
||||
buffer.Truncate(readN)
|
||||
}
|
||||
w.options.PostReturn(buffer)
|
||||
w.buffer = buffer
|
||||
switch fromAddr := from.(type) {
|
||||
case *syscall.SockaddrInet4:
|
||||
w.readFrom = M.SocksaddrFrom(netip.AddrFrom4(fromAddr.Addr), uint16(fromAddr.Port))
|
||||
case *syscall.SockaddrInet6:
|
||||
w.readFrom = M.SocksaddrFrom(netip.AddrFrom16(fromAddr.Addr), uint16(fromAddr.Port)).Unwrap()
|
||||
}
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (w *syscallPacketReadWaiter) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
|
||||
if w.readFunc == nil {
|
||||
return nil, M.Socksaddr{}, os.ErrInvalid
|
||||
}
|
||||
err = w.rawConn.Read(w.readFunc)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if w.readErr != nil {
|
||||
err = E.Cause(w.readErr, "raw read")
|
||||
return
|
||||
}
|
||||
buffer = w.buffer
|
||||
w.buffer = nil
|
||||
destination = w.readFrom
|
||||
return
|
||||
}
|
77
common/bufio/copy_direct_test.go
Normal file
77
common/bufio/copy_direct_test.go
Normal file
|
@ -0,0 +1,77 @@
|
|||
package bufio
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestCopyWaitTCP(t *testing.T) {
|
||||
t.Parallel()
|
||||
inputConn, outputConn := TCPPipe(t)
|
||||
readWaiter, created := CreateReadWaiter(outputConn)
|
||||
require.True(t, created)
|
||||
require.NotNil(t, readWaiter)
|
||||
readWaiter.InitializeReadWaiter(N.ReadWaitOptions{})
|
||||
require.NoError(t, TCPTest(t, inputConn, &readWaitWrapper{
|
||||
Conn: outputConn,
|
||||
readWaiter: readWaiter,
|
||||
}))
|
||||
}
|
||||
|
||||
type readWaitWrapper struct {
|
||||
net.Conn
|
||||
readWaiter N.ReadWaiter
|
||||
buffer *buf.Buffer
|
||||
}
|
||||
|
||||
func (r *readWaitWrapper) Read(p []byte) (n int, err error) {
|
||||
if r.buffer != nil {
|
||||
if r.buffer.Len() > 0 {
|
||||
return r.buffer.Read(p)
|
||||
}
|
||||
if r.buffer.IsEmpty() {
|
||||
r.buffer.Release()
|
||||
r.buffer = nil
|
||||
}
|
||||
}
|
||||
buffer, err := r.readWaiter.WaitReadBuffer()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
r.buffer = buffer
|
||||
return r.buffer.Read(p)
|
||||
}
|
||||
|
||||
func TestCopyWaitUDP(t *testing.T) {
|
||||
t.Parallel()
|
||||
inputConn, outputConn, outputAddr := UDPPipe(t)
|
||||
readWaiter, created := CreatePacketReadWaiter(NewPacketConn(outputConn))
|
||||
require.True(t, created)
|
||||
require.NotNil(t, readWaiter)
|
||||
readWaiter.InitializeReadWaiter(N.ReadWaitOptions{})
|
||||
require.NoError(t, UDPTest(t, inputConn, &packetReadWaitWrapper{
|
||||
PacketConn: outputConn,
|
||||
readWaiter: readWaiter,
|
||||
}, outputAddr))
|
||||
}
|
||||
|
||||
type packetReadWaitWrapper struct {
|
||||
net.PacketConn
|
||||
readWaiter N.PacketReadWaiter
|
||||
}
|
||||
|
||||
func (r *packetReadWaitWrapper) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
||||
buffer, destination, err := r.readWaiter.WaitReadPacket()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
n = copy(p, buffer.Bytes())
|
||||
buffer.Release()
|
||||
addr = destination.UDPAddr()
|
||||
return
|
||||
}
|
163
common/bufio/copy_direct_windows.go
Normal file
163
common/bufio/copy_direct_windows.go
Normal file
|
@ -0,0 +1,163 @@
|
|||
package bufio
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/netip"
|
||||
"os"
|
||||
"syscall"
|
||||
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
var _ N.ReadWaiter = (*syscallReadWaiter)(nil)
|
||||
|
||||
type syscallReadWaiter struct {
|
||||
rawConn syscall.RawConn
|
||||
readErr error
|
||||
readFunc func(fd uintptr) (done bool)
|
||||
hasData bool
|
||||
buffer *buf.Buffer
|
||||
options N.ReadWaitOptions
|
||||
}
|
||||
|
||||
func createSyscallReadWaiter(reader any) (*syscallReadWaiter, bool) {
|
||||
if syscallConn, isSyscallConn := reader.(syscall.Conn); isSyscallConn {
|
||||
rawConn, err := syscallConn.SyscallConn()
|
||||
if err == nil {
|
||||
return &syscallReadWaiter{rawConn: rawConn}, true
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (w *syscallReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
|
||||
w.options = options
|
||||
w.readFunc = func(fd uintptr) (done bool) {
|
||||
if !w.hasData {
|
||||
w.hasData = true
|
||||
// golang's internal/poll.FD.RawRead will Use a zero-byte read as a way to get notified when this
|
||||
// socket is readable if we return false. So the `recv` syscall will not block the system thread.
|
||||
return false
|
||||
}
|
||||
buffer := w.options.NewBuffer()
|
||||
var readN int32
|
||||
readN, w.readErr = recv(windows.Handle(fd), buffer.FreeBytes(), 0)
|
||||
if readN > 0 {
|
||||
buffer.Truncate(int(readN))
|
||||
w.options.PostReturn(buffer)
|
||||
w.buffer = buffer
|
||||
} else {
|
||||
buffer.Release()
|
||||
}
|
||||
if w.readErr == windows.WSAEWOULDBLOCK {
|
||||
return false
|
||||
}
|
||||
if readN == 0 && w.readErr == nil {
|
||||
w.readErr = io.EOF
|
||||
}
|
||||
w.hasData = false
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (w *syscallReadWaiter) WaitReadBuffer() (buffer *buf.Buffer, err error) {
|
||||
if w.readFunc == nil {
|
||||
return nil, os.ErrInvalid
|
||||
}
|
||||
err = w.rawConn.Read(w.readFunc)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if w.readErr != nil {
|
||||
if w.readErr == io.EOF {
|
||||
return nil, io.EOF
|
||||
}
|
||||
return nil, E.Cause(w.readErr, "raw read")
|
||||
}
|
||||
buffer = w.buffer
|
||||
w.buffer = nil
|
||||
return
|
||||
}
|
||||
|
||||
var _ N.PacketReadWaiter = (*syscallPacketReadWaiter)(nil)
|
||||
|
||||
type syscallPacketReadWaiter struct {
|
||||
rawConn syscall.RawConn
|
||||
readErr error
|
||||
readFrom M.Socksaddr
|
||||
readFunc func(fd uintptr) (done bool)
|
||||
hasData bool
|
||||
buffer *buf.Buffer
|
||||
options N.ReadWaitOptions
|
||||
}
|
||||
|
||||
func createSyscallPacketReadWaiter(reader any) (*syscallPacketReadWaiter, bool) {
|
||||
if syscallConn, isSyscallConn := reader.(syscall.Conn); isSyscallConn {
|
||||
rawConn, err := syscallConn.SyscallConn()
|
||||
if err == nil {
|
||||
return &syscallPacketReadWaiter{rawConn: rawConn}, true
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (w *syscallPacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
|
||||
w.options = options
|
||||
w.readFunc = func(fd uintptr) (done bool) {
|
||||
if !w.hasData {
|
||||
w.hasData = true
|
||||
// golang's internal/poll.FD.RawRead will Use a zero-byte read as a way to get notified when this
|
||||
// socket is readable if we return false. So the `recvfrom` syscall will not block the system thread.
|
||||
return false
|
||||
}
|
||||
buffer := w.options.NewPacketBuffer()
|
||||
var readN int
|
||||
var from windows.Sockaddr
|
||||
readN, from, w.readErr = windows.Recvfrom(windows.Handle(fd), buffer.FreeBytes(), 0)
|
||||
//goland:noinspection GoDirectComparisonOfErrors
|
||||
if w.readErr != nil {
|
||||
buffer.Release()
|
||||
return w.readErr != windows.WSAEWOULDBLOCK
|
||||
}
|
||||
if readN > 0 {
|
||||
buffer.Truncate(readN)
|
||||
}
|
||||
w.options.PostReturn(buffer)
|
||||
w.buffer = buffer
|
||||
if from != nil {
|
||||
switch fromAddr := from.(type) {
|
||||
case *windows.SockaddrInet4:
|
||||
w.readFrom = M.SocksaddrFrom(netip.AddrFrom4(fromAddr.Addr), uint16(fromAddr.Port))
|
||||
case *windows.SockaddrInet6:
|
||||
w.readFrom = M.SocksaddrFrom(netip.AddrFrom16(fromAddr.Addr), uint16(fromAddr.Port)).Unwrap()
|
||||
}
|
||||
}
|
||||
w.hasData = false
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (w *syscallPacketReadWaiter) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
|
||||
if w.readFunc == nil {
|
||||
return nil, M.Socksaddr{}, os.ErrInvalid
|
||||
}
|
||||
err = w.rawConn.Read(w.readFunc)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if w.readErr != nil {
|
||||
err = E.Cause(w.readErr, "raw read")
|
||||
return
|
||||
}
|
||||
buffer = w.buffer
|
||||
w.buffer = nil
|
||||
destination = w.readFrom
|
||||
return
|
||||
}
|
96
common/bufio/counter_conn.go
Normal file
96
common/bufio/counter_conn.go
Normal file
|
@ -0,0 +1,96 @@
|
|||
package bufio
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/atomic"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
func NewInt64CounterConn(conn net.Conn, readCounter []*atomic.Int64, writeCounter []*atomic.Int64) *CounterConn {
|
||||
return &CounterConn{
|
||||
NewExtendedConn(conn),
|
||||
common.Map(readCounter, func(it *atomic.Int64) N.CountFunc {
|
||||
return func(n int64) {
|
||||
it.Add(n)
|
||||
}
|
||||
}),
|
||||
common.Map(writeCounter, func(it *atomic.Int64) N.CountFunc {
|
||||
return func(n int64) {
|
||||
it.Add(n)
|
||||
}
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
func NewCounterConn(conn net.Conn, readCounter []N.CountFunc, writeCounter []N.CountFunc) *CounterConn {
|
||||
return &CounterConn{NewExtendedConn(conn), readCounter, writeCounter}
|
||||
}
|
||||
|
||||
type CounterConn struct {
|
||||
N.ExtendedConn
|
||||
readCounter []N.CountFunc
|
||||
writeCounter []N.CountFunc
|
||||
}
|
||||
|
||||
func (c *CounterConn) Read(p []byte) (n int, err error) {
|
||||
n, err = c.ExtendedConn.Read(p)
|
||||
if n > 0 {
|
||||
for _, counter := range c.readCounter {
|
||||
counter(int64(n))
|
||||
}
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (c *CounterConn) ReadBuffer(buffer *buf.Buffer) error {
|
||||
err := c.ExtendedConn.ReadBuffer(buffer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if buffer.Len() > 0 {
|
||||
for _, counter := range c.readCounter {
|
||||
counter(int64(buffer.Len()))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *CounterConn) Write(p []byte) (n int, err error) {
|
||||
n, err = c.ExtendedConn.Write(p)
|
||||
if n > 0 {
|
||||
for _, counter := range c.writeCounter {
|
||||
counter(int64(n))
|
||||
}
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (c *CounterConn) WriteBuffer(buffer *buf.Buffer) error {
|
||||
dataLen := int64(buffer.Len())
|
||||
err := c.ExtendedConn.WriteBuffer(buffer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if dataLen > 0 {
|
||||
for _, counter := range c.writeCounter {
|
||||
counter(dataLen)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *CounterConn) UnwrapReader() (io.Reader, []N.CountFunc) {
|
||||
return c.ExtendedConn, c.readCounter
|
||||
}
|
||||
|
||||
func (c *CounterConn) UnwrapWriter() (io.Writer, []N.CountFunc) {
|
||||
return c.ExtendedConn, c.writeCounter
|
||||
}
|
||||
|
||||
func (c *CounterConn) Upstream() any {
|
||||
return c.ExtendedConn
|
||||
}
|
73
common/bufio/counter_packet_conn.go
Normal file
73
common/bufio/counter_packet_conn.go
Normal file
|
@ -0,0 +1,73 @@
|
|||
package bufio
|
||||
|
||||
import (
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/atomic"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
type CounterPacketConn struct {
|
||||
N.PacketConn
|
||||
readCounter []N.CountFunc
|
||||
writeCounter []N.CountFunc
|
||||
}
|
||||
|
||||
func NewInt64CounterPacketConn(conn N.PacketConn, readCounter []*atomic.Int64, writeCounter []*atomic.Int64) *CounterPacketConn {
|
||||
return &CounterPacketConn{
|
||||
conn,
|
||||
common.Map(readCounter, func(it *atomic.Int64) N.CountFunc {
|
||||
return func(n int64) {
|
||||
it.Add(n)
|
||||
}
|
||||
}),
|
||||
common.Map(writeCounter, func(it *atomic.Int64) N.CountFunc {
|
||||
return func(n int64) {
|
||||
it.Add(n)
|
||||
}
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
func NewCounterPacketConn(conn N.PacketConn, readCounter []N.CountFunc, writeCounter []N.CountFunc) *CounterPacketConn {
|
||||
return &CounterPacketConn{conn, readCounter, writeCounter}
|
||||
}
|
||||
|
||||
func (c *CounterPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
|
||||
destination, err = c.PacketConn.ReadPacket(buffer)
|
||||
if err == nil {
|
||||
if buffer.Len() > 0 {
|
||||
for _, counter := range c.readCounter {
|
||||
counter(int64(buffer.Len()))
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (c *CounterPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
|
||||
dataLen := int64(buffer.Len())
|
||||
err := c.PacketConn.WritePacket(buffer, destination)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if dataLen > 0 {
|
||||
for _, counter := range c.writeCounter {
|
||||
counter(dataLen)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *CounterPacketConn) UnwrapPacketReader() (N.PacketReader, []N.CountFunc) {
|
||||
return c.PacketConn, c.readCounter
|
||||
}
|
||||
|
||||
func (c *CounterPacketConn) UnwrapPacketWriter() (N.PacketWriter, []N.CountFunc) {
|
||||
return c.PacketConn, c.writeCounter
|
||||
}
|
||||
|
||||
func (c *CounterPacketConn) Upstream() any {
|
||||
return c.PacketConn
|
||||
}
|
23
common/bufio/deadline/check.go
Normal file
23
common/bufio/deadline/check.go
Normal file
|
@ -0,0 +1,23 @@
|
|||
package deadline
|
||||
|
||||
import (
|
||||
"github.com/sagernet/sing/common"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
type WithoutReadDeadline interface {
|
||||
NeedAdditionalReadDeadline() bool
|
||||
}
|
||||
|
||||
func NeedAdditionalReadDeadline(rawReader any) bool {
|
||||
if deadlineReader, loaded := rawReader.(WithoutReadDeadline); loaded {
|
||||
return deadlineReader.NeedAdditionalReadDeadline()
|
||||
}
|
||||
if upstream, hasUpstream := rawReader.(N.WithUpstreamReader); hasUpstream {
|
||||
return NeedAdditionalReadDeadline(upstream.UpstreamReader())
|
||||
}
|
||||
if upstream, hasUpstream := rawReader.(common.WithUpstream); hasUpstream {
|
||||
return NeedAdditionalReadDeadline(upstream.Upstream())
|
||||
}
|
||||
return false
|
||||
}
|
61
common/bufio/deadline/conn.go
Normal file
61
common/bufio/deadline/conn.go
Normal file
|
@ -0,0 +1,61 @@
|
|||
package deadline
|
||||
|
||||
import (
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
"github.com/sagernet/sing/common/bufio"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
type Conn struct {
|
||||
N.ExtendedConn
|
||||
reader Reader
|
||||
}
|
||||
|
||||
func NewConn(conn net.Conn) N.ExtendedConn {
|
||||
if deadlineConn, isDeadline := conn.(*Conn); isDeadline {
|
||||
return deadlineConn
|
||||
}
|
||||
return NewSerialConn(&Conn{ExtendedConn: bufio.NewExtendedConn(conn), reader: NewReader(conn)})
|
||||
}
|
||||
|
||||
func NewFallbackConn(conn net.Conn) N.ExtendedConn {
|
||||
if deadlineConn, isDeadline := conn.(*Conn); isDeadline {
|
||||
return deadlineConn
|
||||
}
|
||||
return NewSerialConn(&Conn{ExtendedConn: bufio.NewExtendedConn(conn), reader: NewFallbackReader(conn)})
|
||||
}
|
||||
|
||||
func (c *Conn) Read(p []byte) (n int, err error) {
|
||||
return c.reader.Read(p)
|
||||
}
|
||||
|
||||
func (c *Conn) ReadBuffer(buffer *buf.Buffer) error {
|
||||
return c.reader.ReadBuffer(buffer)
|
||||
}
|
||||
|
||||
func (c *Conn) SetReadDeadline(t time.Time) error {
|
||||
return c.reader.SetReadDeadline(t)
|
||||
}
|
||||
|
||||
func (c *Conn) ReaderReplaceable() bool {
|
||||
return c.reader.ReaderReplaceable()
|
||||
}
|
||||
|
||||
func (c *Conn) UpstreamReader() any {
|
||||
return c.reader.UpstreamReader()
|
||||
}
|
||||
|
||||
func (c *Conn) WriterReplaceable() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *Conn) Upstream() any {
|
||||
return c.ExtendedConn
|
||||
}
|
||||
|
||||
func (c *Conn) NeedAdditionalReadDeadline() bool {
|
||||
return false
|
||||
}
|
57
common/bufio/deadline/packet_conn.go
Normal file
57
common/bufio/deadline/packet_conn.go
Normal file
|
@ -0,0 +1,57 @@
|
|||
package deadline
|
||||
|
||||
import (
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
type PacketConn struct {
|
||||
N.NetPacketConn
|
||||
reader PacketReader
|
||||
}
|
||||
|
||||
func NewPacketConn(conn N.NetPacketConn) N.NetPacketConn {
|
||||
if deadlineConn, isDeadline := conn.(*PacketConn); isDeadline {
|
||||
return deadlineConn
|
||||
}
|
||||
return NewSerialPacketConn(&PacketConn{NetPacketConn: conn, reader: NewPacketReader(conn)})
|
||||
}
|
||||
|
||||
func NewFallbackPacketConn(conn N.NetPacketConn) N.NetPacketConn {
|
||||
if deadlineConn, isDeadline := conn.(*PacketConn); isDeadline {
|
||||
return deadlineConn
|
||||
}
|
||||
return NewSerialPacketConn(&PacketConn{NetPacketConn: conn, reader: NewFallbackPacketReader(conn)})
|
||||
}
|
||||
|
||||
func (c *PacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
||||
return c.reader.ReadFrom(p)
|
||||
}
|
||||
|
||||
func (c *PacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
|
||||
return c.reader.ReadPacket(buffer)
|
||||
}
|
||||
|
||||
func (c *PacketConn) SetReadDeadline(t time.Time) error {
|
||||
return c.reader.SetReadDeadline(t)
|
||||
}
|
||||
|
||||
func (c *PacketConn) ReaderReplaceable() bool {
|
||||
return c.reader.ReaderReplaceable()
|
||||
}
|
||||
|
||||
func (c *PacketConn) WriterReplaceable() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *PacketConn) Upstream() any {
|
||||
return c.NetPacketConn
|
||||
}
|
||||
|
||||
func (c *PacketConn) NeedAdditionalReadDeadline() bool {
|
||||
return false
|
||||
}
|
159
common/bufio/deadline/packet_reader.go
Normal file
159
common/bufio/deadline/packet_reader.go
Normal file
|
@ -0,0 +1,159 @@
|
|||
package deadline
|
||||
|
||||
import (
|
||||
"net"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing/common/atomic"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
type TimeoutPacketReader interface {
|
||||
N.NetPacketReader
|
||||
SetReadDeadline(t time.Time) error
|
||||
}
|
||||
|
||||
type PacketReader interface {
|
||||
TimeoutPacketReader
|
||||
N.WithUpstreamReader
|
||||
N.ReaderWithUpstream
|
||||
}
|
||||
|
||||
type packetReader struct {
|
||||
TimeoutPacketReader
|
||||
deadline atomic.TypedValue[time.Time]
|
||||
pipeDeadline pipeDeadline
|
||||
result chan *packetReadResult
|
||||
done chan struct{}
|
||||
}
|
||||
|
||||
type packetReadResult struct {
|
||||
buffer *buf.Buffer
|
||||
destination M.Socksaddr
|
||||
err error
|
||||
}
|
||||
|
||||
func NewPacketReader(timeoutReader TimeoutPacketReader) PacketReader {
|
||||
return &packetReader{
|
||||
TimeoutPacketReader: timeoutReader,
|
||||
pipeDeadline: makePipeDeadline(),
|
||||
result: make(chan *packetReadResult, 1),
|
||||
done: makeFilledChan(),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *packetReader) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
||||
select {
|
||||
case result := <-r.result:
|
||||
return r.pipeReturnFrom(result, p)
|
||||
default:
|
||||
}
|
||||
select {
|
||||
case result := <-r.result:
|
||||
return r.pipeReturnFrom(result, p)
|
||||
case <-r.pipeDeadline.wait():
|
||||
return 0, nil, os.ErrDeadlineExceeded
|
||||
case <-r.done:
|
||||
go r.pipeReadFrom(len(p))
|
||||
}
|
||||
select {
|
||||
case result := <-r.result:
|
||||
return r.pipeReturnFrom(result, p)
|
||||
case <-r.pipeDeadline.wait():
|
||||
return 0, nil, os.ErrDeadlineExceeded
|
||||
}
|
||||
}
|
||||
|
||||
func (r *packetReader) pipeReadFrom(pLen int) {
|
||||
buffer := buf.NewSize(pLen)
|
||||
n, addr, err := r.TimeoutPacketReader.ReadFrom(buffer.FreeBytes())
|
||||
buffer.Truncate(n)
|
||||
r.result <- &packetReadResult{
|
||||
buffer: buffer,
|
||||
destination: M.SocksaddrFromNet(addr),
|
||||
err: err,
|
||||
}
|
||||
r.done <- struct{}{}
|
||||
}
|
||||
|
||||
func (r *packetReader) pipeReturnFrom(result *packetReadResult, p []byte) (n int, addr net.Addr, err error) {
|
||||
n = copy(p, result.buffer.Bytes())
|
||||
if result.destination.IsValid() {
|
||||
if result.destination.IsFqdn() {
|
||||
addr = result.destination
|
||||
} else {
|
||||
addr = result.destination.UDPAddr()
|
||||
}
|
||||
}
|
||||
result.buffer.Advance(n)
|
||||
if result.buffer.IsEmpty() {
|
||||
result.buffer.Release()
|
||||
err = result.err
|
||||
} else {
|
||||
r.result <- result
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (r *packetReader) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
|
||||
select {
|
||||
case result := <-r.result:
|
||||
return r.pipeReturnFromBuffer(result, buffer)
|
||||
default:
|
||||
}
|
||||
select {
|
||||
case result := <-r.result:
|
||||
return r.pipeReturnFromBuffer(result, buffer)
|
||||
case <-r.pipeDeadline.wait():
|
||||
return M.Socksaddr{}, os.ErrDeadlineExceeded
|
||||
case <-r.done:
|
||||
go r.pipeReadFrom(buffer.FreeLen())
|
||||
}
|
||||
select {
|
||||
case result := <-r.result:
|
||||
return r.pipeReturnFromBuffer(result, buffer)
|
||||
case <-r.pipeDeadline.wait():
|
||||
return M.Socksaddr{}, os.ErrDeadlineExceeded
|
||||
}
|
||||
}
|
||||
|
||||
func (r *packetReader) pipeReturnFromBuffer(result *packetReadResult, buffer *buf.Buffer) (M.Socksaddr, error) {
|
||||
n, _ := buffer.Write(result.buffer.Bytes())
|
||||
result.buffer.Advance(n)
|
||||
if !result.buffer.IsEmpty() {
|
||||
r.result <- result
|
||||
return result.destination, nil
|
||||
} else {
|
||||
result.buffer.Release()
|
||||
return result.destination, result.err
|
||||
}
|
||||
}
|
||||
|
||||
func (r *packetReader) SetReadDeadline(t time.Time) error {
|
||||
r.deadline.Store(t)
|
||||
r.pipeDeadline.set(t)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *packetReader) ReaderReplaceable() bool {
|
||||
select {
|
||||
case <-r.done:
|
||||
r.done <- struct{}{}
|
||||
default:
|
||||
return false
|
||||
}
|
||||
select {
|
||||
case result := <-r.result:
|
||||
r.result <- result
|
||||
return false
|
||||
default:
|
||||
}
|
||||
return r.deadline.Load().IsZero()
|
||||
}
|
||||
|
||||
func (r *packetReader) UpstreamReader() any {
|
||||
return r.TimeoutPacketReader
|
||||
}
|
101
common/bufio/deadline/packet_reader_fallback.go
Normal file
101
common/bufio/deadline/packet_reader_fallback.go
Normal file
|
@ -0,0 +1,101 @@
|
|||
package deadline
|
||||
|
||||
import (
|
||||
"net"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing/common/atomic"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
)
|
||||
|
||||
type fallbackPacketReader struct {
|
||||
*packetReader
|
||||
disablePipe atomic.Bool
|
||||
inRead atomic.Bool
|
||||
}
|
||||
|
||||
func NewFallbackPacketReader(timeoutReader TimeoutPacketReader) PacketReader {
|
||||
return &fallbackPacketReader{packetReader: NewPacketReader(timeoutReader).(*packetReader)}
|
||||
}
|
||||
|
||||
func (r *fallbackPacketReader) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
||||
select {
|
||||
case result := <-r.result:
|
||||
return r.pipeReturnFrom(result, p)
|
||||
default:
|
||||
}
|
||||
select {
|
||||
case result := <-r.result:
|
||||
return r.pipeReturnFrom(result, p)
|
||||
case <-r.pipeDeadline.wait():
|
||||
return 0, nil, os.ErrDeadlineExceeded
|
||||
case <-r.done:
|
||||
if r.disablePipe.Load() {
|
||||
return r.TimeoutPacketReader.ReadFrom(p)
|
||||
} else if r.deadline.Load().IsZero() {
|
||||
r.done <- struct{}{}
|
||||
r.inRead.Store(true)
|
||||
defer r.inRead.Store(false)
|
||||
n, addr, err = r.TimeoutPacketReader.ReadFrom(p)
|
||||
return
|
||||
}
|
||||
go r.pipeReadFrom(len(p))
|
||||
}
|
||||
select {
|
||||
case result := <-r.result:
|
||||
return r.pipeReturnFrom(result, p)
|
||||
case <-r.pipeDeadline.wait():
|
||||
return 0, nil, os.ErrDeadlineExceeded
|
||||
}
|
||||
}
|
||||
|
||||
func (r *fallbackPacketReader) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
|
||||
select {
|
||||
case result := <-r.result:
|
||||
return r.pipeReturnFromBuffer(result, buffer)
|
||||
default:
|
||||
}
|
||||
select {
|
||||
case result := <-r.result:
|
||||
return r.pipeReturnFromBuffer(result, buffer)
|
||||
case <-r.pipeDeadline.wait():
|
||||
return M.Socksaddr{}, os.ErrDeadlineExceeded
|
||||
case <-r.done:
|
||||
if r.disablePipe.Load() {
|
||||
return r.TimeoutPacketReader.ReadPacket(buffer)
|
||||
} else if r.deadline.Load().IsZero() {
|
||||
r.done <- struct{}{}
|
||||
r.inRead.Store(true)
|
||||
defer r.inRead.Store(false)
|
||||
destination, err = r.TimeoutPacketReader.ReadPacket(buffer)
|
||||
return
|
||||
}
|
||||
go r.pipeReadFrom(buffer.FreeLen())
|
||||
}
|
||||
select {
|
||||
case result := <-r.result:
|
||||
return r.pipeReturnFromBuffer(result, buffer)
|
||||
case <-r.pipeDeadline.wait():
|
||||
return M.Socksaddr{}, os.ErrDeadlineExceeded
|
||||
}
|
||||
}
|
||||
|
||||
func (r *fallbackPacketReader) SetReadDeadline(t time.Time) error {
|
||||
if r.disablePipe.Load() {
|
||||
return r.TimeoutPacketReader.SetReadDeadline(t)
|
||||
} else if r.inRead.Load() {
|
||||
r.disablePipe.Store(true)
|
||||
return r.TimeoutPacketReader.SetReadDeadline(t)
|
||||
}
|
||||
return r.packetReader.SetReadDeadline(t)
|
||||
}
|
||||
|
||||
func (r *fallbackPacketReader) ReaderReplaceable() bool {
|
||||
return r.disablePipe.Load() || r.packetReader.ReaderReplaceable()
|
||||
}
|
||||
|
||||
func (r *fallbackPacketReader) UpstreamReader() any {
|
||||
return r.packetReader.UpstreamReader()
|
||||
}
|
84
common/bufio/deadline/pipe.go
Normal file
84
common/bufio/deadline/pipe.go
Normal file
|
@ -0,0 +1,84 @@
|
|||
// Copyright 2010 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package deadline
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// pipeDeadline is an abstraction for handling timeouts.
|
||||
type pipeDeadline struct {
|
||||
mu sync.Mutex // Guards timer and cancel
|
||||
timer *time.Timer
|
||||
cancel chan struct{} // Must be non-nil
|
||||
}
|
||||
|
||||
func makePipeDeadline() pipeDeadline {
|
||||
return pipeDeadline{cancel: make(chan struct{})}
|
||||
}
|
||||
|
||||
// set sets the point in time when the deadline will time out.
|
||||
// A timeout event is signaled by closing the channel returned by waiter.
|
||||
// Once a timeout has occurred, the deadline can be refreshed by specifying a
|
||||
// t value in the future.
|
||||
//
|
||||
// A zero value for t prevents timeout.
|
||||
func (d *pipeDeadline) set(t time.Time) {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
|
||||
if d.timer != nil && !d.timer.Stop() {
|
||||
<-d.cancel // Wait for the timer callback to finish and close cancel
|
||||
}
|
||||
d.timer = nil
|
||||
|
||||
// Time is zero, then there is no deadline.
|
||||
closed := isClosedChan(d.cancel)
|
||||
if t.IsZero() {
|
||||
if closed {
|
||||
d.cancel = make(chan struct{})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Time in the future, setup a timer to cancel in the future.
|
||||
if dur := time.Until(t); dur > 0 {
|
||||
if closed {
|
||||
d.cancel = make(chan struct{})
|
||||
}
|
||||
d.timer = time.AfterFunc(dur, func() {
|
||||
close(d.cancel)
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Time in the past, so close immediately.
|
||||
if !closed {
|
||||
close(d.cancel)
|
||||
}
|
||||
}
|
||||
|
||||
// wait returns a channel that is closed when the deadline is exceeded.
|
||||
func (d *pipeDeadline) wait() chan struct{} {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
return d.cancel
|
||||
}
|
||||
|
||||
func isClosedChan(c <-chan struct{}) bool {
|
||||
select {
|
||||
case <-c:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func makeFilledChan() chan struct{} {
|
||||
ch := make(chan struct{}, 1)
|
||||
ch <- struct{}{}
|
||||
return ch
|
||||
}
|
152
common/bufio/deadline/reader.go
Normal file
152
common/bufio/deadline/reader.go
Normal file
|
@ -0,0 +1,152 @@
|
|||
package deadline
|
||||
|
||||
import (
|
||||
"io"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing/common/atomic"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
"github.com/sagernet/sing/common/bufio"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
type TimeoutReader interface {
|
||||
io.Reader
|
||||
SetReadDeadline(t time.Time) error
|
||||
}
|
||||
|
||||
type Reader interface {
|
||||
N.ExtendedReader
|
||||
TimeoutReader
|
||||
N.WithUpstreamReader
|
||||
N.ReaderWithUpstream
|
||||
}
|
||||
|
||||
type reader struct {
|
||||
N.ExtendedReader
|
||||
timeoutReader TimeoutReader
|
||||
deadline atomic.TypedValue[time.Time]
|
||||
pipeDeadline pipeDeadline
|
||||
result chan *readResult
|
||||
done chan struct{}
|
||||
}
|
||||
|
||||
type readResult struct {
|
||||
buffer *buf.Buffer
|
||||
err error
|
||||
}
|
||||
|
||||
func NewReader(timeoutReader TimeoutReader) Reader {
|
||||
return &reader{
|
||||
ExtendedReader: bufio.NewExtendedReader(timeoutReader),
|
||||
timeoutReader: timeoutReader,
|
||||
pipeDeadline: makePipeDeadline(),
|
||||
result: make(chan *readResult, 1),
|
||||
done: makeFilledChan(),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *reader) Read(p []byte) (n int, err error) {
|
||||
select {
|
||||
case result := <-r.result:
|
||||
return r.pipeReturn(result, p)
|
||||
default:
|
||||
}
|
||||
select {
|
||||
case result := <-r.result:
|
||||
return r.pipeReturn(result, p)
|
||||
case <-r.pipeDeadline.wait():
|
||||
return 0, os.ErrDeadlineExceeded
|
||||
case <-r.done:
|
||||
go r.pipeRead(len(p))
|
||||
}
|
||||
select {
|
||||
case result := <-r.result:
|
||||
return r.pipeReturn(result, p)
|
||||
case <-r.pipeDeadline.wait():
|
||||
return 0, os.ErrDeadlineExceeded
|
||||
}
|
||||
}
|
||||
|
||||
func (r *reader) pipeReturn(result *readResult, p []byte) (n int, err error) {
|
||||
n = copy(p, result.buffer.Bytes())
|
||||
result.buffer.Advance(n)
|
||||
if result.buffer.IsEmpty() {
|
||||
result.buffer.Release()
|
||||
err = result.err
|
||||
} else {
|
||||
r.result <- result
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (r *reader) pipeRead(pLen int) {
|
||||
buffer := buf.NewSize(pLen)
|
||||
_, err := buffer.ReadOnceFrom(r.ExtendedReader)
|
||||
r.result <- &readResult{
|
||||
buffer: buffer,
|
||||
err: err,
|
||||
}
|
||||
r.done <- struct{}{}
|
||||
}
|
||||
|
||||
func (r *reader) ReadBuffer(buffer *buf.Buffer) error {
|
||||
select {
|
||||
case result := <-r.result:
|
||||
return r.pipeReturnBuffer(result, buffer)
|
||||
default:
|
||||
}
|
||||
select {
|
||||
case result := <-r.result:
|
||||
return r.pipeReturnBuffer(result, buffer)
|
||||
case <-r.pipeDeadline.wait():
|
||||
return os.ErrDeadlineExceeded
|
||||
case <-r.done:
|
||||
go r.pipeRead(buffer.FreeLen())
|
||||
}
|
||||
select {
|
||||
case result := <-r.result:
|
||||
return r.pipeReturnBuffer(result, buffer)
|
||||
case <-r.pipeDeadline.wait():
|
||||
return os.ErrDeadlineExceeded
|
||||
}
|
||||
}
|
||||
|
||||
func (r *reader) pipeReturnBuffer(result *readResult, buffer *buf.Buffer) error {
|
||||
n, _ := buffer.Write(result.buffer.Bytes())
|
||||
result.buffer.Advance(n)
|
||||
if !result.buffer.IsEmpty() {
|
||||
r.result <- result
|
||||
return nil
|
||||
} else {
|
||||
result.buffer.Release()
|
||||
return result.err
|
||||
}
|
||||
}
|
||||
|
||||
func (r *reader) SetReadDeadline(t time.Time) error {
|
||||
r.deadline.Store(t)
|
||||
r.pipeDeadline.set(t)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *reader) ReaderReplaceable() bool {
|
||||
select {
|
||||
case <-r.done:
|
||||
r.done <- struct{}{}
|
||||
default:
|
||||
return false
|
||||
}
|
||||
select {
|
||||
case result := <-r.result:
|
||||
r.result <- result
|
||||
return false
|
||||
default:
|
||||
}
|
||||
return r.deadline.Load().IsZero()
|
||||
}
|
||||
|
||||
func (r *reader) UpstreamReader() any {
|
||||
return r.ExtendedReader
|
||||
}
|
98
common/bufio/deadline/reader_fallback.go
Normal file
98
common/bufio/deadline/reader_fallback.go
Normal file
|
@ -0,0 +1,98 @@
|
|||
package deadline
|
||||
|
||||
import (
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing/common/atomic"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
)
|
||||
|
||||
type fallbackReader struct {
|
||||
*reader
|
||||
disablePipe atomic.Bool
|
||||
inRead atomic.Bool
|
||||
}
|
||||
|
||||
func NewFallbackReader(timeoutReader TimeoutReader) Reader {
|
||||
return &fallbackReader{reader: NewReader(timeoutReader).(*reader)}
|
||||
}
|
||||
|
||||
func (r *fallbackReader) Read(p []byte) (n int, err error) {
|
||||
select {
|
||||
case result := <-r.result:
|
||||
return r.pipeReturn(result, p)
|
||||
default:
|
||||
}
|
||||
select {
|
||||
case result := <-r.result:
|
||||
return r.pipeReturn(result, p)
|
||||
case <-r.pipeDeadline.wait():
|
||||
return 0, os.ErrDeadlineExceeded
|
||||
case <-r.done:
|
||||
if r.disablePipe.Load() {
|
||||
return r.ExtendedReader.Read(p)
|
||||
} else if r.deadline.Load().IsZero() {
|
||||
r.done <- struct{}{}
|
||||
r.inRead.Store(true)
|
||||
defer r.inRead.Store(false)
|
||||
n, err = r.ExtendedReader.Read(p)
|
||||
return
|
||||
}
|
||||
go r.pipeRead(len(p))
|
||||
}
|
||||
select {
|
||||
case result := <-r.result:
|
||||
return r.pipeReturn(result, p)
|
||||
case <-r.pipeDeadline.wait():
|
||||
return 0, os.ErrDeadlineExceeded
|
||||
}
|
||||
}
|
||||
|
||||
func (r *fallbackReader) ReadBuffer(buffer *buf.Buffer) error {
|
||||
select {
|
||||
case result := <-r.result:
|
||||
return r.pipeReturnBuffer(result, buffer)
|
||||
default:
|
||||
}
|
||||
select {
|
||||
case result := <-r.result:
|
||||
return r.pipeReturnBuffer(result, buffer)
|
||||
case <-r.pipeDeadline.wait():
|
||||
return os.ErrDeadlineExceeded
|
||||
case <-r.done:
|
||||
if r.disablePipe.Load() {
|
||||
return r.ExtendedReader.ReadBuffer(buffer)
|
||||
} else if r.deadline.Load().IsZero() {
|
||||
r.done <- struct{}{}
|
||||
r.inRead.Store(true)
|
||||
defer r.inRead.Store(false)
|
||||
return r.ExtendedReader.ReadBuffer(buffer)
|
||||
}
|
||||
go r.pipeRead(buffer.FreeLen())
|
||||
}
|
||||
select {
|
||||
case result := <-r.result:
|
||||
return r.pipeReturnBuffer(result, buffer)
|
||||
case <-r.pipeDeadline.wait():
|
||||
return os.ErrDeadlineExceeded
|
||||
}
|
||||
}
|
||||
|
||||
func (r *fallbackReader) SetReadDeadline(t time.Time) error {
|
||||
if r.disablePipe.Load() {
|
||||
return r.timeoutReader.SetReadDeadline(t)
|
||||
} else if r.inRead.Load() {
|
||||
r.disablePipe.Store(true)
|
||||
return r.timeoutReader.SetReadDeadline(t)
|
||||
}
|
||||
return r.reader.SetReadDeadline(t)
|
||||
}
|
||||
|
||||
func (r *fallbackReader) ReaderReplaceable() bool {
|
||||
return r.disablePipe.Load() || r.reader.ReaderReplaceable()
|
||||
}
|
||||
|
||||
func (r *fallbackReader) UpstreamReader() any {
|
||||
return r.reader.UpstreamReader()
|
||||
}
|
75
common/bufio/deadline/serial.go
Normal file
75
common/bufio/deadline/serial.go
Normal file
|
@ -0,0 +1,75 @@
|
|||
package deadline
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
"github.com/sagernet/sing/common/debug"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
type SerialConn struct {
|
||||
N.ExtendedConn
|
||||
access sync.Mutex
|
||||
}
|
||||
|
||||
func NewSerialConn(conn N.ExtendedConn) N.ExtendedConn {
|
||||
if !debug.Enabled {
|
||||
return conn
|
||||
}
|
||||
return &SerialConn{ExtendedConn: conn}
|
||||
}
|
||||
|
||||
func (c *SerialConn) Read(p []byte) (n int, err error) {
|
||||
if !c.access.TryLock() {
|
||||
panic("concurrent read on deadline conn")
|
||||
}
|
||||
defer c.access.Unlock()
|
||||
return c.ExtendedConn.Read(p)
|
||||
}
|
||||
|
||||
func (c *SerialConn) ReadBuffer(buffer *buf.Buffer) error {
|
||||
if !c.access.TryLock() {
|
||||
panic("concurrent read on deadline conn")
|
||||
}
|
||||
defer c.access.Unlock()
|
||||
return c.ExtendedConn.ReadBuffer(buffer)
|
||||
}
|
||||
|
||||
func (c *SerialConn) Upstream() any {
|
||||
return c.ExtendedConn
|
||||
}
|
||||
|
||||
type SerialPacketConn struct {
|
||||
N.NetPacketConn
|
||||
access sync.Mutex
|
||||
}
|
||||
|
||||
func NewSerialPacketConn(conn N.NetPacketConn) N.NetPacketConn {
|
||||
if !debug.Enabled {
|
||||
return conn
|
||||
}
|
||||
return &SerialPacketConn{NetPacketConn: conn}
|
||||
}
|
||||
|
||||
func (c *SerialPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
||||
if !c.access.TryLock() {
|
||||
panic("concurrent read on deadline conn")
|
||||
}
|
||||
defer c.access.Unlock()
|
||||
return c.NetPacketConn.ReadFrom(p)
|
||||
}
|
||||
|
||||
func (c *SerialPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
|
||||
if !c.access.TryLock() {
|
||||
panic("concurrent read on deadline conn")
|
||||
}
|
||||
defer c.access.Unlock()
|
||||
return c.NetPacketConn.ReadPacket(buffer)
|
||||
}
|
||||
|
||||
func (c *SerialPacketConn) Upstream() any {
|
||||
return c.NetPacketConn
|
||||
}
|
104
common/bufio/fallback.go
Normal file
104
common/bufio/fallback.go
Normal file
|
@ -0,0 +1,104 @@
|
|||
package bufio
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
var _ N.NetPacketConn = (*FallbackPacketConn)(nil)
|
||||
|
||||
type FallbackPacketConn struct {
|
||||
N.PacketConn
|
||||
writer N.NetPacketWriter
|
||||
}
|
||||
|
||||
func NewNetPacketConn(conn N.PacketConn) N.NetPacketConn {
|
||||
if packetConn, loaded := conn.(N.NetPacketConn); loaded {
|
||||
return packetConn
|
||||
}
|
||||
return &FallbackPacketConn{
|
||||
PacketConn: conn,
|
||||
writer: NewNetPacketWriter(conn),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *FallbackPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
||||
buffer := buf.With(p)
|
||||
destination, err := c.ReadPacket(buffer)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
n = buffer.Len()
|
||||
if buffer.Start() > 0 {
|
||||
copy(p, buffer.Bytes())
|
||||
}
|
||||
addr = destination.UDPAddr()
|
||||
return
|
||||
}
|
||||
|
||||
func (c *FallbackPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
||||
return c.writer.WriteTo(p, addr)
|
||||
}
|
||||
|
||||
func (c *FallbackPacketConn) ReaderReplaceable() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *FallbackPacketConn) WriterReplaceable() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *FallbackPacketConn) Upstream() any {
|
||||
return c.PacketConn
|
||||
}
|
||||
|
||||
func (c *FallbackPacketConn) UpstreamWriter() any {
|
||||
return c.writer
|
||||
}
|
||||
|
||||
var _ N.NetPacketWriter = (*FallbackPacketWriter)(nil)
|
||||
|
||||
type FallbackPacketWriter struct {
|
||||
N.PacketWriter
|
||||
frontHeadroom int
|
||||
rearHeadroom int
|
||||
}
|
||||
|
||||
func NewNetPacketWriter(writer N.PacketWriter) N.NetPacketWriter {
|
||||
if packetWriter, loaded := writer.(N.NetPacketWriter); loaded {
|
||||
return packetWriter
|
||||
}
|
||||
return &FallbackPacketWriter{
|
||||
PacketWriter: writer,
|
||||
frontHeadroom: N.CalculateFrontHeadroom(writer),
|
||||
rearHeadroom: N.CalculateRearHeadroom(writer),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *FallbackPacketWriter) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
||||
if c.frontHeadroom > 0 || c.rearHeadroom > 0 {
|
||||
buffer := buf.NewSize(len(p) + c.frontHeadroom + c.rearHeadroom)
|
||||
buffer.Resize(c.frontHeadroom, 0)
|
||||
common.Must1(buffer.Write(p))
|
||||
err = c.PacketWriter.WritePacket(buffer, M.SocksaddrFromNet(addr))
|
||||
} else {
|
||||
err = c.PacketWriter.WritePacket(buf.As(p), M.SocksaddrFromNet(addr))
|
||||
}
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
n = len(p)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *FallbackPacketWriter) WriterReplaceable() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *FallbackPacketWriter) Upstream() any {
|
||||
return c.PacketWriter
|
||||
}
|
|
@ -10,12 +10,14 @@ import (
|
|||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
// Deprecated: bad usage
|
||||
func ReadBuffer(reader N.ExtendedReader, buffer *buf.Buffer) (n int, err error) {
|
||||
n, err = reader.Read(buffer.FreeBytes())
|
||||
buffer.Truncate(n)
|
||||
return
|
||||
}
|
||||
|
||||
// Deprecated: bad usage
|
||||
func ReadPacket(reader N.PacketReader, buffer *buf.Buffer) (n int, addr net.Addr, err error) {
|
||||
startLen := buffer.Len()
|
||||
addr, err = reader.ReadPacket(buffer)
|
||||
|
@ -23,6 +25,45 @@ func ReadPacket(reader N.PacketReader, buffer *buf.Buffer) (n int, addr net.Addr
|
|||
return
|
||||
}
|
||||
|
||||
func ReadBufferSize(reader io.Reader, bufferSize int) (buffer *buf.Buffer, err error) {
|
||||
readWaiter, isReadWaiter := CreateReadWaiter(reader)
|
||||
if isReadWaiter {
|
||||
readWaiter.InitializeReadWaiter(N.ReadWaitOptions{
|
||||
MTU: bufferSize,
|
||||
})
|
||||
return readWaiter.WaitReadBuffer()
|
||||
}
|
||||
buffer = buf.NewSize(bufferSize)
|
||||
if extendedReader, isExtendedReader := reader.(N.ExtendedReader); isExtendedReader {
|
||||
err = extendedReader.ReadBuffer(buffer)
|
||||
} else {
|
||||
_, err = buffer.ReadOnceFrom(reader)
|
||||
}
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
buffer = nil
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func ReadPacketSize(reader N.PacketReader, packetSize int) (buffer *buf.Buffer, destination M.Socksaddr, err error) {
|
||||
readWaiter, isReadWaiter := CreatePacketReadWaiter(reader)
|
||||
if isReadWaiter {
|
||||
readWaiter.InitializeReadWaiter(N.ReadWaitOptions{
|
||||
MTU: packetSize,
|
||||
})
|
||||
buffer, destination, err = readWaiter.WaitReadPacket()
|
||||
return
|
||||
}
|
||||
buffer = buf.NewSize(packetSize)
|
||||
destination, err = reader.ReadPacket(buffer)
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
buffer = nil
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func Write(writer io.Writer, data []byte) (n int, err error) {
|
||||
if extendedWriter, isExtended := writer.(N.ExtendedWriter); isExtended {
|
||||
return WriteBuffer(extendedWriter, buf.As(data))
|
||||
|
@ -35,13 +76,7 @@ func WriteBuffer(writer N.ExtendedWriter, buffer *buf.Buffer) (n int, err error)
|
|||
frontHeadroom := N.CalculateFrontHeadroom(writer)
|
||||
rearHeadroom := N.CalculateRearHeadroom(writer)
|
||||
if frontHeadroom > buffer.Start() || rearHeadroom > buffer.FreeLen() {
|
||||
bufferSize := N.CalculateMTU(nil, writer)
|
||||
if bufferSize > 0 {
|
||||
bufferSize += frontHeadroom + rearHeadroom
|
||||
} else {
|
||||
bufferSize = buf.BufferSize
|
||||
}
|
||||
newBuffer := buf.NewSize(bufferSize)
|
||||
newBuffer := buf.NewSize(buffer.Len() + frontHeadroom + rearHeadroom)
|
||||
newBuffer.Resize(frontHeadroom, 0)
|
||||
common.Must1(newBuffer.Write(buffer.Bytes()))
|
||||
buffer.Release()
|
||||
|
@ -67,13 +102,7 @@ func WritePacketBuffer(writer N.PacketWriter, buffer *buf.Buffer, destination M.
|
|||
frontHeadroom := N.CalculateFrontHeadroom(writer)
|
||||
rearHeadroom := N.CalculateRearHeadroom(writer)
|
||||
if frontHeadroom > buffer.Start() || rearHeadroom > buffer.FreeLen() {
|
||||
bufferSize := N.CalculateMTU(nil, writer)
|
||||
if bufferSize > 0 {
|
||||
bufferSize += frontHeadroom + rearHeadroom
|
||||
} else {
|
||||
bufferSize = buf.BufferSize
|
||||
}
|
||||
newBuffer := buf.NewSize(bufferSize)
|
||||
newBuffer := buf.NewSize(buffer.Len() + frontHeadroom + rearHeadroom)
|
||||
newBuffer.Resize(frontHeadroom, 0)
|
||||
common.Must1(newBuffer.Write(buffer.Bytes()))
|
||||
buffer.Release()
|
||||
|
|
212
common/bufio/nat.go
Normal file
212
common/bufio/nat.go
Normal file
|
@ -0,0 +1,212 @@
|
|||
package bufio
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
type NATPacketConn interface {
|
||||
N.NetPacketConn
|
||||
UpdateDestination(destinationAddress netip.Addr)
|
||||
}
|
||||
|
||||
func NewUnidirectionalNATPacketConn(conn N.NetPacketConn, origin M.Socksaddr, destination M.Socksaddr) NATPacketConn {
|
||||
return &unidirectionalNATPacketConn{
|
||||
NetPacketConn: conn,
|
||||
origin: socksaddrWithoutPort(origin),
|
||||
destination: socksaddrWithoutPort(destination),
|
||||
}
|
||||
}
|
||||
|
||||
func NewNATPacketConn(conn N.NetPacketConn, origin M.Socksaddr, destination M.Socksaddr) NATPacketConn {
|
||||
return &bidirectionalNATPacketConn{
|
||||
NetPacketConn: conn,
|
||||
origin: socksaddrWithoutPort(origin),
|
||||
destination: socksaddrWithoutPort(destination),
|
||||
}
|
||||
}
|
||||
|
||||
func NewDestinationNATPacketConn(conn N.NetPacketConn, origin M.Socksaddr, destination M.Socksaddr) NATPacketConn {
|
||||
return &destinationNATPacketConn{
|
||||
NetPacketConn: conn,
|
||||
origin: origin,
|
||||
destination: destination,
|
||||
}
|
||||
}
|
||||
|
||||
type unidirectionalNATPacketConn struct {
|
||||
N.NetPacketConn
|
||||
origin M.Socksaddr
|
||||
destination M.Socksaddr
|
||||
}
|
||||
|
||||
func (c *unidirectionalNATPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
||||
destination := M.SocksaddrFromNet(addr)
|
||||
if socksaddrWithoutPort(destination) == c.destination {
|
||||
destination = M.Socksaddr{
|
||||
Addr: c.origin.Addr,
|
||||
Fqdn: c.origin.Fqdn,
|
||||
Port: destination.Port,
|
||||
}
|
||||
}
|
||||
return c.NetPacketConn.WriteTo(p, destination.UDPAddr())
|
||||
}
|
||||
|
||||
func (c *unidirectionalNATPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
|
||||
if socksaddrWithoutPort(destination) == c.destination {
|
||||
destination = M.Socksaddr{
|
||||
Addr: c.origin.Addr,
|
||||
Fqdn: c.origin.Fqdn,
|
||||
Port: destination.Port,
|
||||
}
|
||||
}
|
||||
return c.NetPacketConn.WritePacket(buffer, destination)
|
||||
}
|
||||
|
||||
func (c *unidirectionalNATPacketConn) UpdateDestination(destinationAddress netip.Addr) {
|
||||
c.destination = M.SocksaddrFrom(destinationAddress, c.destination.Port)
|
||||
}
|
||||
|
||||
func (c *unidirectionalNATPacketConn) RemoteAddr() net.Addr {
|
||||
return c.destination.UDPAddr()
|
||||
}
|
||||
|
||||
func (c *unidirectionalNATPacketConn) Upstream() any {
|
||||
return c.NetPacketConn
|
||||
}
|
||||
|
||||
type bidirectionalNATPacketConn struct {
|
||||
N.NetPacketConn
|
||||
origin M.Socksaddr
|
||||
destination M.Socksaddr
|
||||
}
|
||||
|
||||
func (c *bidirectionalNATPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
||||
n, addr, err = c.NetPacketConn.ReadFrom(p)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
destination := M.SocksaddrFromNet(addr)
|
||||
if socksaddrWithoutPort(destination) == c.origin {
|
||||
destination = M.Socksaddr{
|
||||
Addr: c.destination.Addr,
|
||||
Fqdn: c.destination.Fqdn,
|
||||
Port: destination.Port,
|
||||
}
|
||||
}
|
||||
addr = destination.UDPAddr()
|
||||
return
|
||||
}
|
||||
|
||||
func (c *bidirectionalNATPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
||||
destination := M.SocksaddrFromNet(addr)
|
||||
if socksaddrWithoutPort(destination) == c.destination {
|
||||
destination = M.Socksaddr{
|
||||
Addr: c.origin.Addr,
|
||||
Fqdn: c.origin.Fqdn,
|
||||
Port: destination.Port,
|
||||
}
|
||||
}
|
||||
return c.NetPacketConn.WriteTo(p, destination.UDPAddr())
|
||||
}
|
||||
|
||||
func (c *bidirectionalNATPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
|
||||
destination, err = c.NetPacketConn.ReadPacket(buffer)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if socksaddrWithoutPort(destination) == c.origin {
|
||||
destination = M.Socksaddr{
|
||||
Addr: c.destination.Addr,
|
||||
Fqdn: c.destination.Fqdn,
|
||||
Port: destination.Port,
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (c *bidirectionalNATPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
|
||||
if socksaddrWithoutPort(destination) == c.destination {
|
||||
destination = M.Socksaddr{
|
||||
Addr: c.origin.Addr,
|
||||
Fqdn: c.origin.Fqdn,
|
||||
Port: destination.Port,
|
||||
}
|
||||
}
|
||||
return c.NetPacketConn.WritePacket(buffer, destination)
|
||||
}
|
||||
|
||||
func (c *bidirectionalNATPacketConn) UpdateDestination(destinationAddress netip.Addr) {
|
||||
c.destination = M.SocksaddrFrom(destinationAddress, c.destination.Port)
|
||||
}
|
||||
|
||||
func (c *bidirectionalNATPacketConn) Upstream() any {
|
||||
return c.NetPacketConn
|
||||
}
|
||||
|
||||
func (c *bidirectionalNATPacketConn) RemoteAddr() net.Addr {
|
||||
return c.destination.UDPAddr()
|
||||
}
|
||||
|
||||
type destinationNATPacketConn struct {
|
||||
N.NetPacketConn
|
||||
origin M.Socksaddr
|
||||
destination M.Socksaddr
|
||||
}
|
||||
|
||||
func (c *destinationNATPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
||||
n, addr, err = c.NetPacketConn.ReadFrom(p)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if M.SocksaddrFromNet(addr) == c.origin {
|
||||
addr = c.destination.UDPAddr()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (c *destinationNATPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
||||
if M.SocksaddrFromNet(addr) == c.destination {
|
||||
addr = c.origin.UDPAddr()
|
||||
}
|
||||
return c.NetPacketConn.WriteTo(p, addr)
|
||||
}
|
||||
|
||||
func (c *destinationNATPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
|
||||
destination, err = c.NetPacketConn.ReadPacket(buffer)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if destination == c.origin {
|
||||
destination = c.destination
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (c *destinationNATPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
|
||||
if destination == c.destination {
|
||||
destination = c.origin
|
||||
}
|
||||
return c.NetPacketConn.WritePacket(buffer, destination)
|
||||
}
|
||||
|
||||
func (c *destinationNATPacketConn) UpdateDestination(destinationAddress netip.Addr) {
|
||||
c.destination = M.SocksaddrFrom(destinationAddress, c.destination.Port)
|
||||
}
|
||||
|
||||
func (c *destinationNATPacketConn) Upstream() any {
|
||||
return c.NetPacketConn
|
||||
}
|
||||
|
||||
func (c *destinationNATPacketConn) RemoteAddr() net.Addr {
|
||||
return c.destination.UDPAddr()
|
||||
}
|
||||
|
||||
func socksaddrWithoutPort(destination M.Socksaddr) M.Socksaddr {
|
||||
destination.Port = 0
|
||||
return destination
|
||||
}
|
39
common/bufio/nat_wait.go
Normal file
39
common/bufio/nat_wait.go
Normal file
|
@ -0,0 +1,39 @@
|
|||
package bufio
|
||||
|
||||
import (
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
func (c *bidirectionalNATPacketConn) CreatePacketReadWaiter() (N.PacketReadWaiter, bool) {
|
||||
waiter, created := CreatePacketReadWaiter(c.NetPacketConn)
|
||||
if !created {
|
||||
return nil, false
|
||||
}
|
||||
return &waitBidirectionalNATPacketConn{c, waiter}, true
|
||||
}
|
||||
|
||||
type waitBidirectionalNATPacketConn struct {
|
||||
*bidirectionalNATPacketConn
|
||||
readWaiter N.PacketReadWaiter
|
||||
}
|
||||
|
||||
func (c *waitBidirectionalNATPacketConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
|
||||
return c.readWaiter.InitializeReadWaiter(options)
|
||||
}
|
||||
|
||||
func (c *waitBidirectionalNATPacketConn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
|
||||
buffer, destination, err = c.readWaiter.WaitReadPacket()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if socksaddrWithoutPort(destination) == c.origin {
|
||||
destination = M.Socksaddr{
|
||||
Addr: c.destination.Addr,
|
||||
Fqdn: c.destination.Fqdn,
|
||||
Port: destination.Port,
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
277
common/bufio/net_test.go
Normal file
277
common/bufio/net_test.go
Normal file
|
@ -0,0 +1,277 @@
|
|||
package bufio
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/md5"
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
"github.com/sagernet/sing/common/task"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TCPPipe(t *testing.T) (net.Conn, net.Conn) {
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
var (
|
||||
group task.Group
|
||||
serverConn net.Conn
|
||||
clientConn net.Conn
|
||||
)
|
||||
group.Append0(func(ctx context.Context) error {
|
||||
var serverErr error
|
||||
serverConn, serverErr = listener.Accept()
|
||||
return serverErr
|
||||
})
|
||||
group.Append0(func(ctx context.Context) error {
|
||||
var clientErr error
|
||||
clientConn, clientErr = net.Dial("tcp", listener.Addr().String())
|
||||
return clientErr
|
||||
})
|
||||
err = group.Run(context.Background())
|
||||
require.NoError(t, err)
|
||||
listener.Close()
|
||||
t.Cleanup(func() {
|
||||
serverConn.Close()
|
||||
clientConn.Close()
|
||||
})
|
||||
return serverConn, clientConn
|
||||
}
|
||||
|
||||
func UDPPipe(t *testing.T) (net.PacketConn, net.PacketConn, M.Socksaddr) {
|
||||
serverConn, err := net.ListenPacket("udp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
clientConn, err := net.ListenPacket("udp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
return serverConn, clientConn, M.SocksaddrFromNet(clientConn.LocalAddr())
|
||||
}
|
||||
|
||||
func Timeout(t *testing.T) context.CancelFunc {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
go func() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Error("timeout")
|
||||
}
|
||||
}()
|
||||
return cancel
|
||||
}
|
||||
|
||||
type hashPair struct {
|
||||
sendHash map[int][]byte
|
||||
recvHash map[int][]byte
|
||||
}
|
||||
|
||||
func newLargeDataPair() (chan hashPair, chan hashPair, func(t *testing.T) error) {
|
||||
pingCh := make(chan hashPair)
|
||||
pongCh := make(chan hashPair)
|
||||
test := func(t *testing.T) error {
|
||||
defer close(pingCh)
|
||||
defer close(pongCh)
|
||||
pingOpen := false
|
||||
pongOpen := false
|
||||
var serverPair hashPair
|
||||
var clientPair hashPair
|
||||
|
||||
for {
|
||||
if pingOpen && pongOpen {
|
||||
break
|
||||
}
|
||||
|
||||
select {
|
||||
case serverPair, pingOpen = <-pingCh:
|
||||
assert.True(t, pingOpen)
|
||||
case clientPair, pongOpen = <-pongCh:
|
||||
assert.True(t, pongOpen)
|
||||
case <-time.After(10 * time.Second):
|
||||
return errors.New("timeout")
|
||||
}
|
||||
}
|
||||
|
||||
assert.Equal(t, serverPair.recvHash, clientPair.sendHash)
|
||||
assert.Equal(t, serverPair.sendHash, clientPair.recvHash)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
return pingCh, pongCh, test
|
||||
}
|
||||
|
||||
func TCPTest(t *testing.T, inputConn net.Conn, outputConn net.Conn) error {
|
||||
times := 100
|
||||
chunkSize := int64(64 * 1024)
|
||||
|
||||
pingCh, pongCh, test := newLargeDataPair()
|
||||
writeRandData := func(conn net.Conn) (map[int][]byte, error) {
|
||||
buf := make([]byte, chunkSize)
|
||||
hashMap := map[int][]byte{}
|
||||
for i := 0; i < times; i++ {
|
||||
if _, err := rand.Read(buf[1:]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
buf[0] = byte(i)
|
||||
|
||||
hash := md5.Sum(buf)
|
||||
hashMap[i] = hash[:]
|
||||
|
||||
if _, err := conn.Write(buf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return hashMap, nil
|
||||
}
|
||||
go func() {
|
||||
hashMap := map[int][]byte{}
|
||||
buf := make([]byte, chunkSize)
|
||||
|
||||
for i := 0; i < times; i++ {
|
||||
_, err := io.ReadFull(outputConn, buf)
|
||||
if err != nil {
|
||||
t.Log(err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
hash := md5.Sum(buf)
|
||||
hashMap[int(buf[0])] = hash[:]
|
||||
}
|
||||
|
||||
sendHash, err := writeRandData(outputConn)
|
||||
if err != nil {
|
||||
t.Log(err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
pingCh <- hashPair{
|
||||
sendHash: sendHash,
|
||||
recvHash: hashMap,
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
sendHash, err := writeRandData(inputConn)
|
||||
if err != nil {
|
||||
t.Log(err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
hashMap := map[int][]byte{}
|
||||
buf := make([]byte, chunkSize)
|
||||
|
||||
for i := 0; i < times; i++ {
|
||||
_, err = io.ReadFull(inputConn, buf)
|
||||
if err != nil {
|
||||
t.Log(err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
hash := md5.Sum(buf)
|
||||
hashMap[int(buf[0])] = hash[:]
|
||||
}
|
||||
|
||||
pongCh <- hashPair{
|
||||
sendHash: sendHash,
|
||||
recvHash: hashMap,
|
||||
}
|
||||
}()
|
||||
return test(t)
|
||||
}
|
||||
|
||||
func UDPTest(t *testing.T, inputConn net.PacketConn, outputConn net.PacketConn, outputAddr M.Socksaddr) error {
|
||||
rAddr := outputAddr.UDPAddr()
|
||||
times := 50
|
||||
chunkSize := 9000
|
||||
pingCh, pongCh, test := newLargeDataPair()
|
||||
writeRandData := func(pc net.PacketConn, addr net.Addr) (map[int][]byte, error) {
|
||||
hashMap := map[int][]byte{}
|
||||
mux := sync.Mutex{}
|
||||
for i := 0; i < times; i++ {
|
||||
buf := make([]byte, chunkSize)
|
||||
if _, err := rand.Read(buf[1:]); err != nil {
|
||||
t.Log(err.Error())
|
||||
continue
|
||||
}
|
||||
buf[0] = byte(i)
|
||||
|
||||
hash := md5.Sum(buf)
|
||||
mux.Lock()
|
||||
hashMap[i] = hash[:]
|
||||
mux.Unlock()
|
||||
|
||||
if _, err := pc.WriteTo(buf, addr); err != nil {
|
||||
t.Log(err.Error())
|
||||
}
|
||||
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
|
||||
return hashMap, nil
|
||||
}
|
||||
go func() {
|
||||
var (
|
||||
lAddr net.Addr
|
||||
err error
|
||||
)
|
||||
hashMap := map[int][]byte{}
|
||||
buf := make([]byte, 64*1024)
|
||||
|
||||
for i := 0; i < times; i++ {
|
||||
_, lAddr, err = outputConn.ReadFrom(buf)
|
||||
if err != nil {
|
||||
t.Log(err.Error())
|
||||
return
|
||||
}
|
||||
hash := md5.Sum(buf[:chunkSize])
|
||||
hashMap[int(buf[0])] = hash[:]
|
||||
}
|
||||
sendHash, err := writeRandData(outputConn, lAddr)
|
||||
if err != nil {
|
||||
t.Log(err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
pingCh <- hashPair{
|
||||
sendHash: sendHash,
|
||||
recvHash: hashMap,
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
sendHash, err := writeRandData(inputConn, rAddr)
|
||||
if err != nil {
|
||||
t.Log(err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
hashMap := map[int][]byte{}
|
||||
buf := make([]byte, 64*1024)
|
||||
|
||||
for i := 0; i < times; i++ {
|
||||
_, _, err := inputConn.ReadFrom(buf)
|
||||
if err != nil {
|
||||
t.Log(err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
hash := md5.Sum(buf[:chunkSize])
|
||||
hashMap[int(buf[0])] = hash[:]
|
||||
}
|
||||
|
||||
pongCh <- hashPair{
|
||||
sendHash: sendHash,
|
||||
recvHash: hashMap,
|
||||
}
|
||||
}()
|
||||
|
||||
return test(t)
|
||||
}
|
|
@ -1,127 +0,0 @@
|
|||
package bufio
|
||||
|
||||
import (
|
||||
"io"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
func CopyTimes(dst io.Writer, src io.Reader, times int) (n int64, err error) {
|
||||
return CopyExtendedTimes(NewExtendedWriter(N.UnwrapWriter(dst)), NewExtendedReader(N.UnwrapReader(src)), times)
|
||||
}
|
||||
|
||||
func CopyExtendedTimes(dst N.ExtendedWriter, src N.ExtendedReader, times int) (n int64, err error) {
|
||||
frontHeadroom := N.CalculateFrontHeadroom(dst)
|
||||
rearHeadroom := N.CalculateRearHeadroom(dst)
|
||||
bufferSize := N.CalculateMTU(src, dst)
|
||||
if bufferSize > 0 {
|
||||
bufferSize += frontHeadroom + rearHeadroom
|
||||
} else {
|
||||
bufferSize = buf.BufferSize
|
||||
}
|
||||
dstUnsafe := N.IsUnsafeWriter(dst)
|
||||
var buffer *buf.Buffer
|
||||
if !dstUnsafe {
|
||||
_buffer := buf.StackNewSize(bufferSize)
|
||||
defer common.KeepAlive(_buffer)
|
||||
buffer = common.Dup(_buffer)
|
||||
defer buffer.Release()
|
||||
buffer.IncRef()
|
||||
defer buffer.DecRef()
|
||||
}
|
||||
notFirstTime := true
|
||||
for i := 0; i < times; i++ {
|
||||
if dstUnsafe {
|
||||
buffer = buf.NewSize(bufferSize)
|
||||
}
|
||||
readBufferRaw := buffer.Slice()
|
||||
readBuffer := buf.With(readBufferRaw[:cap(readBufferRaw)-rearHeadroom])
|
||||
readBuffer.Resize(frontHeadroom, 0)
|
||||
err = src.ReadBuffer(readBuffer)
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
if !notFirstTime {
|
||||
err = N.HandshakeFailure(dst, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
dataLen := readBuffer.Len()
|
||||
buffer.Resize(readBuffer.Start(), dataLen)
|
||||
err = dst.WriteBuffer(buffer)
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
return
|
||||
}
|
||||
n += int64(dataLen)
|
||||
notFirstTime = true
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
type ReadFromWriter interface {
|
||||
io.ReaderFrom
|
||||
io.Writer
|
||||
}
|
||||
|
||||
func ReadFrom0(readerFrom ReadFromWriter, reader io.Reader) (n int64, err error) {
|
||||
n, err = CopyTimes(readerFrom, reader, 1)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var rn int64
|
||||
rn, err = readerFrom.ReadFrom(reader)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
n += rn
|
||||
return
|
||||
}
|
||||
|
||||
func ReadFromN(readerFrom ReadFromWriter, reader io.Reader, times int) (n int64, err error) {
|
||||
n, err = CopyTimes(readerFrom, reader, times)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var rn int64
|
||||
rn, err = readerFrom.ReadFrom(reader)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
n += rn
|
||||
return
|
||||
}
|
||||
|
||||
type WriteToReader interface {
|
||||
io.WriterTo
|
||||
io.Reader
|
||||
}
|
||||
|
||||
func WriteTo0(writerTo WriteToReader, writer io.Writer) (n int64, err error) {
|
||||
n, err = CopyTimes(writer, writerTo, 1)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var wn int64
|
||||
wn, err = writerTo.WriteTo(writer)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
n += wn
|
||||
return
|
||||
}
|
||||
|
||||
func WriteToN(writerTo WriteToReader, writer io.Writer, times int) (n int64, err error) {
|
||||
n, err = CopyTimes(writer, writerTo, times)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var wn int64
|
||||
wn, err = writerTo.WriteTo(writer)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
n += wn
|
||||
return
|
||||
}
|
79
common/bufio/splice_linux.go
Normal file
79
common/bufio/splice_linux.go
Normal file
|
@ -0,0 +1,79 @@
|
|||
package bufio
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
const maxSpliceSize = 1 << 20
|
||||
|
||||
func splice(source syscall.RawConn, destination syscall.RawConn, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handed bool, n int64, err error) {
|
||||
handed = true
|
||||
var pipeFDs [2]int
|
||||
err = unix.Pipe2(pipeFDs[:], syscall.O_CLOEXEC|syscall.O_NONBLOCK)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer unix.Close(pipeFDs[0])
|
||||
defer unix.Close(pipeFDs[1])
|
||||
|
||||
_, _ = unix.FcntlInt(uintptr(pipeFDs[0]), unix.F_SETPIPE_SZ, maxSpliceSize)
|
||||
var readN int
|
||||
var readErr error
|
||||
var writeSize int
|
||||
var writeErr error
|
||||
readFunc := func(fd uintptr) (done bool) {
|
||||
p0, p1 := unix.Splice(int(fd), nil, pipeFDs[1], nil, maxSpliceSize, unix.SPLICE_F_NONBLOCK)
|
||||
readN = int(p0)
|
||||
readErr = p1
|
||||
return readErr != unix.EAGAIN
|
||||
}
|
||||
writeFunc := func(fd uintptr) (done bool) {
|
||||
for writeSize > 0 {
|
||||
p0, p1 := unix.Splice(pipeFDs[0], nil, int(fd), nil, writeSize, unix.SPLICE_F_NONBLOCK|unix.SPLICE_F_MOVE)
|
||||
writeN := int(p0)
|
||||
writeErr = p1
|
||||
if writeErr != nil {
|
||||
return writeErr != unix.EAGAIN
|
||||
}
|
||||
writeSize -= writeN
|
||||
}
|
||||
return true
|
||||
}
|
||||
for {
|
||||
err = source.Read(readFunc)
|
||||
if err != nil {
|
||||
readErr = err
|
||||
}
|
||||
if readErr != nil {
|
||||
if readErr == unix.EINVAL || readErr == unix.ENOSYS {
|
||||
handed = false
|
||||
return
|
||||
}
|
||||
err = E.Cause(readErr, "splice read")
|
||||
return
|
||||
}
|
||||
if readN == 0 {
|
||||
return
|
||||
}
|
||||
writeSize = readN
|
||||
err = destination.Write(writeFunc)
|
||||
if err != nil {
|
||||
writeErr = err
|
||||
}
|
||||
if writeErr != nil {
|
||||
err = E.Cause(writeErr, "splice write")
|
||||
return
|
||||
}
|
||||
for _, readCounter := range readCounters {
|
||||
readCounter(int64(readN))
|
||||
}
|
||||
for _, writeCounter := range writeCounters {
|
||||
writeCounter(int64(readN))
|
||||
}
|
||||
}
|
||||
}
|
13
common/bufio/splice_stub.go
Normal file
13
common/bufio/splice_stub.go
Normal file
|
@ -0,0 +1,13 @@
|
|||
//go:build !linux
|
||||
|
||||
package bufio
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
func splice(source syscall.RawConn, destination syscall.RawConn, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handed bool, n int64, err error) {
|
||||
return
|
||||
}
|
5
common/bufio/syscall_windows.go
Normal file
5
common/bufio/syscall_windows.go
Normal file
|
@ -0,0 +1,5 @@
|
|||
package bufio
|
||||
|
||||
//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go syscall_windows.go
|
||||
|
||||
//sys recv(s windows.Handle, buf []byte, flags int32) (n int32, err error) [failretval == -1] = ws2_32.recv
|
|
@ -3,7 +3,6 @@ package bufio
|
|||
import (
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"syscall"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
|
@ -13,10 +12,10 @@ import (
|
|||
)
|
||||
|
||||
func NewVectorisedWriter(writer io.Writer) N.VectorisedWriter {
|
||||
if vectorisedWriter, ok := CreateVectorisedWriter(writer); ok {
|
||||
if vectorisedWriter, ok := CreateVectorisedWriter(N.UnwrapWriter(writer)); ok {
|
||||
return vectorisedWriter
|
||||
}
|
||||
return &SerialVectorisedWriter{upstream: writer}
|
||||
return &BufferedVectorisedWriter{upstream: writer}
|
||||
}
|
||||
|
||||
func CreateVectorisedWriter(writer any) (N.VectorisedWriter, bool) {
|
||||
|
@ -34,10 +33,10 @@ func CreateVectorisedWriter(writer any) (N.VectorisedWriter, bool) {
|
|||
case syscall.Conn:
|
||||
rawConn, err := w.SyscallConn()
|
||||
if err == nil {
|
||||
return &SyscallVectorisedWriter{writer, rawConn}, true
|
||||
return &SyscallVectorisedWriter{upstream: writer, rawConn: rawConn}, true
|
||||
}
|
||||
case syscall.RawConn:
|
||||
return &SyscallVectorisedWriter{writer, w}, true
|
||||
return &SyscallVectorisedWriter{upstream: writer, rawConn: w}, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
@ -49,34 +48,41 @@ func CreateVectorisedPacketWriter(writer any) (N.VectorisedPacketWriter, bool) {
|
|||
case syscall.Conn:
|
||||
rawConn, err := w.SyscallConn()
|
||||
if err == nil {
|
||||
return &SyscallVectorisedPacketWriter{writer, rawConn}, true
|
||||
return &SyscallVectorisedPacketWriter{upstream: writer, rawConn: rawConn}, true
|
||||
}
|
||||
case syscall.RawConn:
|
||||
return &SyscallVectorisedPacketWriter{writer, w}, true
|
||||
return &SyscallVectorisedPacketWriter{upstream: writer, rawConn: w}, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
var _ N.VectorisedWriter = (*SerialVectorisedWriter)(nil)
|
||||
var _ N.VectorisedWriter = (*BufferedVectorisedWriter)(nil)
|
||||
|
||||
type SerialVectorisedWriter struct {
|
||||
type BufferedVectorisedWriter struct {
|
||||
upstream io.Writer
|
||||
access sync.Mutex
|
||||
}
|
||||
|
||||
func (w *SerialVectorisedWriter) WriteVectorised(buffers []*buf.Buffer) error {
|
||||
w.access.Lock()
|
||||
defer w.access.Unlock()
|
||||
for _, buffer := range buffers {
|
||||
_, err := w.upstream.Write(buffer.Bytes())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
func (w *BufferedVectorisedWriter) WriteVectorised(buffers []*buf.Buffer) error {
|
||||
defer buf.ReleaseMulti(buffers)
|
||||
bufferLen := buf.LenMulti(buffers)
|
||||
if bufferLen == 0 {
|
||||
return common.Error(w.upstream.Write(nil))
|
||||
} else if len(buffers) == 1 {
|
||||
return common.Error(w.upstream.Write(buffers[0].Bytes()))
|
||||
}
|
||||
return nil
|
||||
var bufferBytes []byte
|
||||
if bufferLen > 65535 {
|
||||
bufferBytes = make([]byte, bufferLen)
|
||||
} else {
|
||||
buffer := buf.NewSize(bufferLen)
|
||||
defer buffer.Release()
|
||||
bufferBytes = buffer.FreeBytes()
|
||||
}
|
||||
buf.CopyMulti(bufferBytes, buffers)
|
||||
return common.Error(w.upstream.Write(bufferBytes))
|
||||
}
|
||||
|
||||
func (w *SerialVectorisedWriter) Upstream() any {
|
||||
func (w *BufferedVectorisedWriter) Upstream() any {
|
||||
return w.upstream
|
||||
}
|
||||
|
||||
|
@ -105,6 +111,7 @@ var _ N.VectorisedWriter = (*SyscallVectorisedWriter)(nil)
|
|||
type SyscallVectorisedWriter struct {
|
||||
upstream any
|
||||
rawConn syscall.RawConn
|
||||
syscallVectorisedWriterFields
|
||||
}
|
||||
|
||||
func (w *SyscallVectorisedWriter) Upstream() any {
|
||||
|
@ -120,6 +127,7 @@ var _ N.VectorisedPacketWriter = (*SyscallVectorisedPacketWriter)(nil)
|
|||
type SyscallVectorisedPacketWriter struct {
|
||||
upstream any
|
||||
rawConn syscall.RawConn
|
||||
syscallVectorisedWriterFields
|
||||
}
|
||||
|
||||
func (w *SyscallVectorisedPacketWriter) Upstream() any {
|
||||
|
|
60
common/bufio/vectorised_test.go
Normal file
60
common/bufio/vectorised_test.go
Normal file
|
@ -0,0 +1,60 @@
|
|||
package bufio
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestWriteVectorised(t *testing.T) {
|
||||
t.Parallel()
|
||||
inputConn, outputConn := TCPPipe(t)
|
||||
vectorisedWriter, created := CreateVectorisedWriter(inputConn)
|
||||
require.True(t, created)
|
||||
require.NotNil(t, vectorisedWriter)
|
||||
var bufA [1024]byte
|
||||
var bufB [1024]byte
|
||||
var bufC [2048]byte
|
||||
_, err := io.ReadFull(rand.Reader, bufA[:])
|
||||
require.NoError(t, err)
|
||||
_, err = io.ReadFull(rand.Reader, bufB[:])
|
||||
require.NoError(t, err)
|
||||
copy(bufC[:], bufA[:])
|
||||
copy(bufC[1024:], bufB[:])
|
||||
finish := Timeout(t)
|
||||
_, err = WriteVectorised(vectorisedWriter, [][]byte{bufA[:], bufB[:]})
|
||||
require.NoError(t, err)
|
||||
output := make([]byte, 2048)
|
||||
_, err = io.ReadFull(outputConn, output)
|
||||
finish()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, bufC[:], output)
|
||||
}
|
||||
|
||||
func TestWriteVectorisedPacket(t *testing.T) {
|
||||
t.Parallel()
|
||||
inputConn, outputConn, outputAddr := UDPPipe(t)
|
||||
vectorisedWriter, created := CreateVectorisedPacketWriter(inputConn)
|
||||
require.True(t, created)
|
||||
require.NotNil(t, vectorisedWriter)
|
||||
var bufA [1024]byte
|
||||
var bufB [1024]byte
|
||||
var bufC [2048]byte
|
||||
_, err := io.ReadFull(rand.Reader, bufA[:])
|
||||
require.NoError(t, err)
|
||||
_, err = io.ReadFull(rand.Reader, bufB[:])
|
||||
require.NoError(t, err)
|
||||
copy(bufC[:], bufA[:])
|
||||
copy(bufC[1024:], bufB[:])
|
||||
finish := Timeout(t)
|
||||
_, err = WriteVectorisedPacket(vectorisedWriter, [][]byte{bufA[:], bufB[:]}, outputAddr)
|
||||
require.NoError(t, err)
|
||||
output := make([]byte, 2048)
|
||||
n, _, err := outputConn.ReadFrom(output)
|
||||
finish()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 2048, n)
|
||||
require.Equal(t, bufC[:], output)
|
||||
}
|
|
@ -3,6 +3,8 @@
|
|||
package bufio
|
||||
|
||||
import (
|
||||
"os"
|
||||
"sync"
|
||||
"unsafe"
|
||||
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
|
@ -11,47 +13,81 @@ import (
|
|||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
type syscallVectorisedWriterFields struct {
|
||||
access sync.Mutex
|
||||
iovecList *[]unix.Iovec
|
||||
}
|
||||
|
||||
func (w *SyscallVectorisedWriter) WriteVectorised(buffers []*buf.Buffer) error {
|
||||
w.access.Lock()
|
||||
defer w.access.Unlock()
|
||||
defer buf.ReleaseMulti(buffers)
|
||||
iovecList := make([]unix.Iovec, 0, len(buffers))
|
||||
for _, buffer := range buffers {
|
||||
var iovec unix.Iovec
|
||||
iovec.Base = &buffer.Bytes()[0]
|
||||
iovec.SetLen(buffer.Len())
|
||||
iovecList = append(iovecList, iovec)
|
||||
var iovecList []unix.Iovec
|
||||
if w.iovecList != nil {
|
||||
iovecList = *w.iovecList
|
||||
}
|
||||
iovecList = iovecList[:0]
|
||||
for index, buffer := range buffers {
|
||||
iovecList = append(iovecList, unix.Iovec{Base: &buffer.Bytes()[0]})
|
||||
iovecList[index].SetLen(buffer.Len())
|
||||
}
|
||||
if w.iovecList == nil {
|
||||
w.iovecList = new([]unix.Iovec)
|
||||
}
|
||||
*w.iovecList = iovecList // cache
|
||||
var innerErr unix.Errno
|
||||
err := w.rawConn.Write(func(fd uintptr) (done bool) {
|
||||
//nolint:staticcheck
|
||||
_, _, innerErr = unix.Syscall(unix.SYS_WRITEV, fd, uintptr(unsafe.Pointer(&iovecList[0])), uintptr(len(iovecList)))
|
||||
return innerErr != unix.EAGAIN && innerErr != unix.EWOULDBLOCK
|
||||
})
|
||||
if innerErr != 0 {
|
||||
err = innerErr
|
||||
err = os.NewSyscallError("SYS_WRITEV", innerErr)
|
||||
}
|
||||
for index := range iovecList {
|
||||
iovecList[index] = unix.Iovec{}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (w *SyscallVectorisedPacketWriter) WriteVectorisedPacket(buffers []*buf.Buffer, destination M.Socksaddr) error {
|
||||
w.access.Lock()
|
||||
defer w.access.Unlock()
|
||||
defer buf.ReleaseMulti(buffers)
|
||||
var sockaddr unix.Sockaddr
|
||||
if destination.IsIPv4() {
|
||||
sockaddr = &unix.SockaddrInet4{
|
||||
Port: int(destination.Port),
|
||||
Addr: destination.Addr.As4(),
|
||||
}
|
||||
} else {
|
||||
sockaddr = &unix.SockaddrInet6{
|
||||
Port: int(destination.Port),
|
||||
Addr: destination.Addr.As16(),
|
||||
}
|
||||
var iovecList []unix.Iovec
|
||||
if w.iovecList != nil {
|
||||
iovecList = *w.iovecList
|
||||
}
|
||||
iovecList = iovecList[:0]
|
||||
for index, buffer := range buffers {
|
||||
iovecList = append(iovecList, unix.Iovec{Base: &buffer.Bytes()[0]})
|
||||
iovecList[index].SetLen(buffer.Len())
|
||||
}
|
||||
if w.iovecList == nil {
|
||||
w.iovecList = new([]unix.Iovec)
|
||||
}
|
||||
*w.iovecList = iovecList // cache
|
||||
var innerErr error
|
||||
err := w.rawConn.Write(func(fd uintptr) (done bool) {
|
||||
_, innerErr = unix.SendmsgBuffers(int(fd), buf.ToSliceMulti(buffers), nil, sockaddr, 0)
|
||||
var msg unix.Msghdr
|
||||
name, nameLen := ToSockaddr(destination.AddrPort())
|
||||
msg.Name = (*byte)(name)
|
||||
msg.Namelen = nameLen
|
||||
if len(iovecList) > 0 {
|
||||
msg.Iov = &iovecList[0]
|
||||
msg.SetIovlen(len(iovecList))
|
||||
}
|
||||
_, innerErr = sendmsg(int(fd), &msg, 0)
|
||||
return innerErr != unix.EAGAIN && innerErr != unix.EWOULDBLOCK
|
||||
})
|
||||
if innerErr != nil {
|
||||
err = innerErr
|
||||
}
|
||||
for index := range iovecList {
|
||||
iovecList[index] = unix.Iovec{}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
//go:linkname sendmsg golang.org/x/sys/unix.sendmsg
|
||||
func sendmsg(s int, msg *unix.Msghdr, flags int) (n int, err error)
|
||||
|
|
|
@ -1,62 +1,93 @@
|
|||
package bufio
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
type syscallVectorisedWriterFields struct {
|
||||
access sync.Mutex
|
||||
iovecList *[]windows.WSABuf
|
||||
}
|
||||
|
||||
func (w *SyscallVectorisedWriter) WriteVectorised(buffers []*buf.Buffer) error {
|
||||
w.access.Lock()
|
||||
defer w.access.Unlock()
|
||||
defer buf.ReleaseMulti(buffers)
|
||||
iovecList := make([]*windows.WSABuf, len(buffers))
|
||||
for i, buffer := range buffers {
|
||||
iovecList[i] = &windows.WSABuf{
|
||||
Len: uint32(buffer.Len()),
|
||||
Buf: &buffer.Bytes()[0],
|
||||
}
|
||||
var iovecList []windows.WSABuf
|
||||
if w.iovecList != nil {
|
||||
iovecList = *w.iovecList
|
||||
}
|
||||
iovecList = iovecList[:0]
|
||||
for _, buffer := range buffers {
|
||||
iovecList = append(iovecList, windows.WSABuf{
|
||||
Buf: &buffer.Bytes()[0],
|
||||
Len: uint32(buffer.Len()),
|
||||
})
|
||||
}
|
||||
if w.iovecList == nil {
|
||||
w.iovecList = new([]windows.WSABuf)
|
||||
}
|
||||
*w.iovecList = iovecList // cache
|
||||
var n uint32
|
||||
var innerErr error
|
||||
err := w.rawConn.Write(func(fd uintptr) (done bool) {
|
||||
innerErr = windows.WSASend(windows.Handle(fd), iovecList[0], uint32(len(iovecList)), &n, 0, nil, nil)
|
||||
innerErr = windows.WSASend(windows.Handle(fd), &iovecList[0], uint32(len(iovecList)), &n, 0, nil, nil)
|
||||
return innerErr != windows.WSAEWOULDBLOCK
|
||||
})
|
||||
if innerErr != nil {
|
||||
err = innerErr
|
||||
}
|
||||
for index := range iovecList {
|
||||
iovecList[index] = windows.WSABuf{}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (w *SyscallVectorisedPacketWriter) WriteVectorisedPacket(buffers []*buf.Buffer, destination M.Socksaddr) error {
|
||||
w.access.Lock()
|
||||
defer w.access.Unlock()
|
||||
defer buf.ReleaseMulti(buffers)
|
||||
iovecList := make([]*windows.WSABuf, len(buffers))
|
||||
for i, buffer := range buffers {
|
||||
iovecList[i] = &windows.WSABuf{
|
||||
Len: uint32(buffer.Len()),
|
||||
var iovecList []windows.WSABuf
|
||||
if w.iovecList != nil {
|
||||
iovecList = *w.iovecList
|
||||
}
|
||||
iovecList = iovecList[:0]
|
||||
for _, buffer := range buffers {
|
||||
iovecList = append(iovecList, windows.WSABuf{
|
||||
Buf: &buffer.Bytes()[0],
|
||||
}
|
||||
Len: uint32(buffer.Len()),
|
||||
})
|
||||
}
|
||||
var sockaddr windows.Sockaddr
|
||||
if destination.IsIPv4() {
|
||||
sockaddr = &windows.SockaddrInet4{
|
||||
Port: int(destination.Port),
|
||||
Addr: destination.Addr.As4(),
|
||||
}
|
||||
} else {
|
||||
sockaddr = &windows.SockaddrInet6{
|
||||
Port: int(destination.Port),
|
||||
Addr: destination.Addr.As16(),
|
||||
}
|
||||
if w.iovecList == nil {
|
||||
w.iovecList = new([]windows.WSABuf)
|
||||
}
|
||||
*w.iovecList = iovecList // cache
|
||||
var n uint32
|
||||
var innerErr error
|
||||
err := w.rawConn.Write(func(fd uintptr) (done bool) {
|
||||
innerErr = windows.WSASendto(windows.Handle(fd), iovecList[0], uint32(len(iovecList)), &n, 0, sockaddr, nil, nil)
|
||||
name, nameLen := ToSockaddr(destination.AddrPort())
|
||||
innerErr = windows.WSASendTo(
|
||||
windows.Handle(fd),
|
||||
&iovecList[0],
|
||||
uint32(len(iovecList)),
|
||||
&n,
|
||||
0,
|
||||
(*windows.RawSockaddrAny)(name),
|
||||
nameLen,
|
||||
nil,
|
||||
nil)
|
||||
return innerErr != windows.WSAEWOULDBLOCK
|
||||
})
|
||||
if innerErr != nil {
|
||||
err = innerErr
|
||||
}
|
||||
for index := range iovecList {
|
||||
iovecList[index] = windows.WSABuf{}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
|
35
common/bufio/wait.go
Normal file
35
common/bufio/wait.go
Normal file
|
@ -0,0 +1,35 @@
|
|||
package bufio
|
||||
|
||||
import (
|
||||
"io"
|
||||
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
func CreateReadWaiter(reader io.Reader) (N.ReadWaiter, bool) {
|
||||
reader = N.UnwrapReader(reader)
|
||||
if readWaiter, isReadWaiter := reader.(N.ReadWaiter); isReadWaiter {
|
||||
return readWaiter, true
|
||||
}
|
||||
if readWaitCreator, isCreator := reader.(N.ReadWaitCreator); isCreator {
|
||||
return readWaitCreator.CreateReadWaiter()
|
||||
}
|
||||
if readWaiter, created := createSyscallReadWaiter(reader); created {
|
||||
return readWaiter, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func CreatePacketReadWaiter(reader N.PacketReader) (N.PacketReadWaiter, bool) {
|
||||
reader = N.UnwrapPacketReader(reader)
|
||||
if readWaiter, isReadWaiter := reader.(N.PacketReadWaiter); isReadWaiter {
|
||||
return readWaiter, true
|
||||
}
|
||||
if readWaitCreator, isCreator := reader.(N.PacketReadWaitCreator); isCreator {
|
||||
return readWaitCreator.CreateReadWaiter()
|
||||
}
|
||||
if readWaiter, created := createSyscallPacketReadWaiter(reader); created {
|
||||
return readWaiter, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
57
common/bufio/zsyscall_windows.go
Normal file
57
common/bufio/zsyscall_windows.go
Normal file
|
@ -0,0 +1,57 @@
|
|||
// Code generated by 'go generate'; DO NOT EDIT.
|
||||
|
||||
package bufio
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
var _ unsafe.Pointer
|
||||
|
||||
// Do the interface allocations only once for common
|
||||
// Errno values.
|
||||
const (
|
||||
errnoERROR_IO_PENDING = 997
|
||||
)
|
||||
|
||||
var (
|
||||
errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
|
||||
errERROR_EINVAL error = syscall.EINVAL
|
||||
)
|
||||
|
||||
// errnoErr returns common boxed Errno values, to prevent
|
||||
// allocations at runtime.
|
||||
func errnoErr(e syscall.Errno) error {
|
||||
switch e {
|
||||
case 0:
|
||||
return errERROR_EINVAL
|
||||
case errnoERROR_IO_PENDING:
|
||||
return errERROR_IO_PENDING
|
||||
}
|
||||
// TODO: add more here, after collecting data on the common
|
||||
// error values see on Windows. (perhaps when running
|
||||
// all.bat?)
|
||||
return e
|
||||
}
|
||||
|
||||
var (
|
||||
modws2_32 = windows.NewLazySystemDLL("ws2_32.dll")
|
||||
|
||||
procrecv = modws2_32.NewProc("recv")
|
||||
)
|
||||
|
||||
func recv(s windows.Handle, buf []byte, flags int32) (n int32, err error) {
|
||||
var _p0 *byte
|
||||
if len(buf) > 0 {
|
||||
_p0 = &buf[0]
|
||||
}
|
||||
r0, _, e1 := syscall.Syscall6(procrecv.Addr(), 4, uintptr(s), uintptr(unsafe.Pointer(_p0)), uintptr(len(buf)), uintptr(flags), 0, 0)
|
||||
n = int32(r0)
|
||||
if n == -1 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
8
common/cache/lrucache.go
vendored
8
common/cache/lrucache.go
vendored
|
@ -258,6 +258,14 @@ func (c *LruCache[K, V]) Delete(key K) {
|
|||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
func (c *LruCache[K, V]) Clear() {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
for element := c.lru.Front(); element != nil; element = element.Next() {
|
||||
c.deleteElement(element)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *LruCache[K, V]) maybeDeleteOldest() {
|
||||
if !c.staleReturn && c.maxAge > 0 {
|
||||
now := time.Now().Unix()
|
||||
|
|
65
common/canceler/instance.go
Normal file
65
common/canceler/instance.go
Normal file
|
@ -0,0 +1,65 @@
|
|||
package canceler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
)
|
||||
|
||||
type Instance struct {
|
||||
ctx context.Context
|
||||
cancelFunc common.ContextCancelCauseFunc
|
||||
timer *time.Timer
|
||||
timeout time.Duration
|
||||
}
|
||||
|
||||
func New(ctx context.Context, cancelFunc common.ContextCancelCauseFunc, timeout time.Duration) *Instance {
|
||||
instance := &Instance{
|
||||
ctx,
|
||||
cancelFunc,
|
||||
time.NewTimer(timeout),
|
||||
timeout,
|
||||
}
|
||||
go instance.wait()
|
||||
return instance
|
||||
}
|
||||
|
||||
func (i *Instance) Update() bool {
|
||||
if !i.timer.Stop() {
|
||||
return false
|
||||
}
|
||||
if !i.timer.Reset(i.timeout) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (i *Instance) Timeout() time.Duration {
|
||||
return i.timeout
|
||||
}
|
||||
|
||||
func (i *Instance) SetTimeout(timeout time.Duration) bool {
|
||||
i.timeout = timeout
|
||||
return i.Update()
|
||||
}
|
||||
|
||||
func (i *Instance) wait() {
|
||||
select {
|
||||
case <-i.timer.C:
|
||||
case <-i.ctx.Done():
|
||||
}
|
||||
i.CloseWithError(os.ErrDeadlineExceeded)
|
||||
}
|
||||
|
||||
func (i *Instance) Close() error {
|
||||
i.CloseWithError(net.ErrClosed)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (i *Instance) CloseWithError(err error) {
|
||||
i.timer.Stop()
|
||||
i.cancelFunc(err)
|
||||
}
|
76
common/canceler/packet.go
Normal file
76
common/canceler/packet.go
Normal file
|
@ -0,0 +1,76 @@
|
|||
package canceler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
type PacketConn interface {
|
||||
N.PacketConn
|
||||
Timeout() time.Duration
|
||||
SetTimeout(timeout time.Duration) bool
|
||||
}
|
||||
|
||||
type TimerPacketConn struct {
|
||||
N.PacketConn
|
||||
instance *Instance
|
||||
}
|
||||
|
||||
func NewPacketConn(ctx context.Context, conn N.PacketConn, timeout time.Duration) (context.Context, N.PacketConn) {
|
||||
if timeoutConn, isTimeoutConn := common.Cast[PacketConn](conn); isTimeoutConn {
|
||||
oldTimeout := timeoutConn.Timeout()
|
||||
if oldTimeout > 0 && timeout >= oldTimeout {
|
||||
return ctx, conn
|
||||
}
|
||||
if timeoutConn.SetTimeout(timeout) {
|
||||
return ctx, conn
|
||||
}
|
||||
}
|
||||
err := conn.SetReadDeadline(time.Time{})
|
||||
if err == nil {
|
||||
return NewTimeoutPacketConn(ctx, conn, timeout)
|
||||
}
|
||||
ctx, cancel := common.ContextWithCancelCause(ctx)
|
||||
instance := New(ctx, cancel, timeout)
|
||||
return ctx, &TimerPacketConn{conn, instance}
|
||||
}
|
||||
|
||||
func (c *TimerPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
|
||||
destination, err = c.PacketConn.ReadPacket(buffer)
|
||||
if err == nil {
|
||||
c.instance.Update()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (c *TimerPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
|
||||
err := c.PacketConn.WritePacket(buffer, destination)
|
||||
if err == nil {
|
||||
c.instance.Update()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *TimerPacketConn) Timeout() time.Duration {
|
||||
return c.instance.Timeout()
|
||||
}
|
||||
|
||||
func (c *TimerPacketConn) SetTimeout(timeout time.Duration) bool {
|
||||
return c.instance.SetTimeout(timeout)
|
||||
}
|
||||
|
||||
func (c *TimerPacketConn) Close() error {
|
||||
return common.Close(
|
||||
c.PacketConn,
|
||||
c.instance,
|
||||
)
|
||||
}
|
||||
|
||||
func (c *TimerPacketConn) Upstream() any {
|
||||
return c.PacketConn
|
||||
}
|
76
common/canceler/packet_timeout.go
Normal file
76
common/canceler/packet_timeout.go
Normal file
|
@ -0,0 +1,76 @@
|
|||
package canceler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
type TimeoutPacketConn struct {
|
||||
N.PacketConn
|
||||
timeout time.Duration
|
||||
cancel common.ContextCancelCauseFunc
|
||||
active time.Time
|
||||
}
|
||||
|
||||
func NewTimeoutPacketConn(ctx context.Context, conn N.PacketConn, timeout time.Duration) (context.Context, PacketConn) {
|
||||
ctx, cancel := common.ContextWithCancelCause(ctx)
|
||||
return ctx, &TimeoutPacketConn{
|
||||
PacketConn: conn,
|
||||
timeout: timeout,
|
||||
cancel: cancel,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *TimeoutPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
|
||||
for {
|
||||
err = c.PacketConn.SetReadDeadline(time.Now().Add(c.timeout))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
destination, err = c.PacketConn.ReadPacket(buffer)
|
||||
if err == nil {
|
||||
c.active = time.Now()
|
||||
return
|
||||
} else if E.IsTimeout(err) {
|
||||
if time.Since(c.active) > c.timeout {
|
||||
c.cancel(err)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *TimeoutPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
|
||||
err := c.PacketConn.WritePacket(buffer, destination)
|
||||
if err == nil {
|
||||
c.active = time.Now()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *TimeoutPacketConn) Timeout() time.Duration {
|
||||
return c.timeout
|
||||
}
|
||||
|
||||
func (c *TimeoutPacketConn) SetTimeout(timeout time.Duration) bool {
|
||||
c.timeout = timeout
|
||||
return c.PacketConn.SetReadDeadline(time.Now()) == nil
|
||||
}
|
||||
|
||||
func (c *TimeoutPacketConn) Close() error {
|
||||
c.cancel(net.ErrClosed)
|
||||
return c.PacketConn.Close()
|
||||
}
|
||||
|
||||
func (c *TimeoutPacketConn) Upstream() any {
|
||||
return c.PacketConn
|
||||
}
|
11
common/clear.go
Normal file
11
common/clear.go
Normal file
|
@ -0,0 +1,11 @@
|
|||
//go:build go1.21
|
||||
|
||||
package common
|
||||
|
||||
func ClearArray[T ~[]E, E any](t T) {
|
||||
clear(t)
|
||||
}
|
||||
|
||||
func ClearMap[T ~map[K]V, K comparable, V any](t T) {
|
||||
clear(t)
|
||||
}
|
16
common/clear_compat.go
Normal file
16
common/clear_compat.go
Normal file
|
@ -0,0 +1,16 @@
|
|||
//go:build !go1.21
|
||||
|
||||
package common
|
||||
|
||||
func ClearArray[T ~[]E, E any](t T) {
|
||||
var defaultValue E
|
||||
for i := range t {
|
||||
t[i] = defaultValue
|
||||
}
|
||||
}
|
||||
|
||||
func ClearMap[T ~map[K]V, K comparable, V any](t T) {
|
||||
for k := range t {
|
||||
delete(t, k)
|
||||
}
|
||||
}
|
112
common/cond.go
112
common/cond.go
|
@ -20,8 +20,8 @@ func Any[T any](array []T, block func(it T) bool) bool {
|
|||
}
|
||||
|
||||
func AnyIndexed[T any](array []T, block func(index int, it T) bool) bool {
|
||||
for i, it := range array {
|
||||
if block(i, it) {
|
||||
for index, it := range array {
|
||||
if block(index, it) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
@ -38,8 +38,8 @@ func All[T any](array []T, block func(it T) bool) bool {
|
|||
}
|
||||
|
||||
func AllIndexed[T any](array []T, block func(index int, it T) bool) bool {
|
||||
for i, it := range array {
|
||||
if !block(i, it) {
|
||||
for index, it := range array {
|
||||
if !block(index, it) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
@ -47,8 +47,8 @@ func AllIndexed[T any](array []T, block func(index int, it T) bool) bool {
|
|||
}
|
||||
|
||||
func Contains[T comparable](arr []T, target T) bool {
|
||||
for i := range arr {
|
||||
if target == arr[i] {
|
||||
for index := range arr {
|
||||
if target == arr[index] {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
@ -81,8 +81,8 @@ func FlatMap[T any, N any](arr []T, block func(it T) []N) []N {
|
|||
|
||||
func FlatMapIndexed[T any, N any](arr []T, block func(index int, it T) []N) []N {
|
||||
var retAddr []N
|
||||
for i, item := range arr {
|
||||
retAddr = append(retAddr, block(i, item)...)
|
||||
for index, item := range arr {
|
||||
retAddr = append(retAddr, block(index, item)...)
|
||||
}
|
||||
return retAddr
|
||||
}
|
||||
|
@ -113,8 +113,8 @@ func FilterNotDefault[T comparable](arr []T) []T {
|
|||
|
||||
func FilterIndexed[T any](arr []T, block func(index int, it T) bool) []T {
|
||||
var retArr []T
|
||||
for i, it := range arr {
|
||||
if block(i, it) {
|
||||
for index, it := range arr {
|
||||
if block(index, it) {
|
||||
retArr = append(retArr, it)
|
||||
}
|
||||
}
|
||||
|
@ -130,22 +130,55 @@ func Find[T any](arr []T, block func(it T) bool) T {
|
|||
return DefaultValue[T]()
|
||||
}
|
||||
|
||||
func FindIndexed[T any](arr []T, block func(index int, it T) bool) T {
|
||||
for index, it := range arr {
|
||||
if block(index, it) {
|
||||
return it
|
||||
}
|
||||
}
|
||||
return DefaultValue[T]()
|
||||
}
|
||||
|
||||
func Index[T any](arr []T, block func(it T) bool) int {
|
||||
for index, it := range arr {
|
||||
if block(it) {
|
||||
return index
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
func IndexIndexed[T any](arr []T, block func(index int, it T) bool) int {
|
||||
for index, it := range arr {
|
||||
if block(index, it) {
|
||||
return index
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
func Equal[S ~[]E, E comparable](s1, s2 S) bool {
|
||||
if len(s1) != len(s2) {
|
||||
return false
|
||||
}
|
||||
for i := range s1 {
|
||||
if s1[i] != s2[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
//go:norace
|
||||
func Dup[T any](obj T) T {
|
||||
if UnsafeBuffer {
|
||||
pointer := uintptr(unsafe.Pointer(&obj))
|
||||
//nolint:staticcheck
|
||||
//goland:noinspection GoVetUnsafePointer
|
||||
return *(*T)(unsafe.Pointer(pointer))
|
||||
} else {
|
||||
return obj
|
||||
}
|
||||
pointer := uintptr(unsafe.Pointer(&obj))
|
||||
//nolint:staticcheck
|
||||
//goland:noinspection GoVetUnsafePointer
|
||||
return *(*T)(unsafe.Pointer(pointer))
|
||||
}
|
||||
|
||||
func KeepAlive(obj any) {
|
||||
if UnsafeBuffer {
|
||||
runtime.KeepAlive(obj)
|
||||
}
|
||||
runtime.KeepAlive(obj)
|
||||
}
|
||||
|
||||
func Uniq[T comparable](arr []T) []T {
|
||||
|
@ -247,6 +280,14 @@ func Reverse[T any](arr []T) []T {
|
|||
return arr
|
||||
}
|
||||
|
||||
func ReverseMap[K comparable, V comparable](m map[K]V) map[V]K {
|
||||
ret := make(map[V]K, len(m))
|
||||
for k, v := range m {
|
||||
ret[v] = k
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
func Done(ctx context.Context) bool {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
|
@ -268,16 +309,18 @@ func Must(errs ...error) {
|
|||
}
|
||||
}
|
||||
|
||||
func Must1(_ any, err error) {
|
||||
func Must1[T any](result T, err error) T {
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func Must2(_, _ any, err error) {
|
||||
func Must2[T any, T2 any](result T, result2 T2, err error) (T, T2) {
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return result, result2
|
||||
}
|
||||
|
||||
// Deprecated: use E.Errors
|
||||
|
@ -313,6 +356,10 @@ func DefaultValue[T any]() T {
|
|||
return defaultValue
|
||||
}
|
||||
|
||||
func Ptr[T any](obj T) *T {
|
||||
return &obj
|
||||
}
|
||||
|
||||
func Close(closers ...any) error {
|
||||
var retErr error
|
||||
for _, closer := range closers {
|
||||
|
@ -335,22 +382,3 @@ func Close(closers ...any) error {
|
|||
}
|
||||
return retErr
|
||||
}
|
||||
|
||||
type Starter interface {
|
||||
Start() error
|
||||
}
|
||||
|
||||
func Start(starters ...any) error {
|
||||
for _, rawStarter := range starters {
|
||||
if rawStarter == nil {
|
||||
continue
|
||||
}
|
||||
if starter, isStarter := rawStarter.(Starter); isStarter {
|
||||
err := starter.Start()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
23
common/context.go
Normal file
23
common/context.go
Normal file
|
@ -0,0 +1,23 @@
|
|||
package common
|
||||
|
||||
import (
|
||||
"context"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
// Deprecated: not used
|
||||
func SelectContext(contextList []context.Context) (int, error) {
|
||||
if len(contextList) == 1 {
|
||||
<-contextList[0].Done()
|
||||
return 0, contextList[0].Err()
|
||||
}
|
||||
chosen, _, _ := reflect.Select(Map(Filter(contextList, func(it context.Context) bool {
|
||||
return it.Done() != nil
|
||||
}), func(it context.Context) reflect.SelectCase {
|
||||
return reflect.SelectCase{
|
||||
Dir: reflect.SelectRecv,
|
||||
Chan: reflect.ValueOf(it.Done()),
|
||||
}
|
||||
}))
|
||||
return chosen, contextList[chosen].Err()
|
||||
}
|
14
common/context_compat.go
Normal file
14
common/context_compat.go
Normal file
|
@ -0,0 +1,14 @@
|
|||
//go:build go1.20
|
||||
|
||||
package common
|
||||
|
||||
import "context"
|
||||
|
||||
type (
|
||||
ContextCancelCauseFunc = context.CancelCauseFunc
|
||||
)
|
||||
|
||||
var (
|
||||
ContextWithCancelCause = context.WithCancelCause
|
||||
ContextCause = context.Cause
|
||||
)
|
16
common/context_lagacy.go
Normal file
16
common/context_lagacy.go
Normal file
|
@ -0,0 +1,16 @@
|
|||
//go:build !go1.20
|
||||
|
||||
package common
|
||||
|
||||
import "context"
|
||||
|
||||
type ContextCancelCauseFunc func(cause error)
|
||||
|
||||
func ContextWithCancelCause(parentContext context.Context) (context.Context, ContextCancelCauseFunc) {
|
||||
ctx, cancel := context.WithCancel(parentContext)
|
||||
return ctx, func(_ error) { cancel() }
|
||||
}
|
||||
|
||||
func ContextCause(context context.Context) error {
|
||||
return context.Err()
|
||||
}
|
|
@ -1,59 +1,35 @@
|
|||
package control
|
||||
|
||||
import (
|
||||
"os"
|
||||
"runtime"
|
||||
"syscall"
|
||||
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
func BindToInterface(finder InterfaceFinder, interfaceName string, interfaceIndex int) Func {
|
||||
return func(network, address string, conn syscall.RawConn) error {
|
||||
return BindToInterface0(finder, conn, network, address, interfaceName, interfaceIndex)
|
||||
return BindToInterface0(finder, conn, network, address, interfaceName, interfaceIndex, false)
|
||||
}
|
||||
}
|
||||
|
||||
func BindToInterfaceFunc(finder InterfaceFinder, block func(network string, address string) (interfaceName string, interfaceIndex int)) Func {
|
||||
func BindToInterfaceFunc(finder InterfaceFinder, block func(network string, address string) (interfaceName string, interfaceIndex int, err error)) Func {
|
||||
return func(network, address string, conn syscall.RawConn) error {
|
||||
interfaceName, interfaceIndex := block(network, address)
|
||||
return BindToInterface0(finder, conn, network, address, interfaceName, interfaceIndex)
|
||||
interfaceName, interfaceIndex, err := block(network, address)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return BindToInterface0(finder, conn, network, address, interfaceName, interfaceIndex, false)
|
||||
}
|
||||
}
|
||||
|
||||
const useInterfaceName = runtime.GOOS == "linux" || runtime.GOOS == "android"
|
||||
|
||||
func BindToInterface0(finder InterfaceFinder, conn syscall.RawConn, network string, address string, interfaceName string, interfaceIndex int) error {
|
||||
func BindToInterface0(finder InterfaceFinder, conn syscall.RawConn, network string, address string, interfaceName string, interfaceIndex int, preferInterfaceName bool) error {
|
||||
if interfaceName == "" && interfaceIndex == -1 {
|
||||
return E.New("interface not found: ", interfaceName)
|
||||
}
|
||||
if addr := M.ParseSocksaddr(address).Addr; addr.IsValid() && N.IsVirtual(addr) {
|
||||
return nil
|
||||
}
|
||||
if interfaceName == "" && interfaceIndex == -1 {
|
||||
return nil
|
||||
}
|
||||
if interfaceName != "" && useInterfaceName || interfaceIndex != -1 && !useInterfaceName {
|
||||
return bindToInterface(conn, network, address, interfaceName, interfaceIndex)
|
||||
}
|
||||
if finder == nil {
|
||||
return os.ErrInvalid
|
||||
}
|
||||
var err error
|
||||
if useInterfaceName {
|
||||
interfaceName, err = finder.InterfaceNameByIndex(interfaceIndex)
|
||||
} else {
|
||||
interfaceIndex, err = finder.InterfaceIndexByName(interfaceName)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if useInterfaceName {
|
||||
if interfaceName == "" {
|
||||
return nil
|
||||
}
|
||||
} else {
|
||||
if interfaceIndex == -1 {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return bindToInterface(conn, network, address, interfaceName, interfaceIndex)
|
||||
return bindToInterface(conn, network, address, finder, interfaceName, interfaceIndex, preferInterfaceName)
|
||||
}
|
||||
|
|
|
@ -1,16 +1,24 @@
|
|||
package control
|
||||
|
||||
import (
|
||||
"os"
|
||||
"syscall"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func bindToInterface(conn syscall.RawConn, network string, address string, interfaceName string, interfaceIndex int) error {
|
||||
if interfaceIndex == -1 {
|
||||
return nil
|
||||
}
|
||||
func bindToInterface(conn syscall.RawConn, network string, address string, finder InterfaceFinder, interfaceName string, interfaceIndex int, preferInterfaceName bool) error {
|
||||
return Raw(conn, func(fd uintptr) error {
|
||||
if interfaceIndex == -1 {
|
||||
if finder == nil {
|
||||
return os.ErrInvalid
|
||||
}
|
||||
iif, err := finder.ByName(interfaceName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
interfaceIndex = iif.Index
|
||||
}
|
||||
switch network {
|
||||
case "tcp6", "udp6":
|
||||
return unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_BOUND_IF, interfaceIndex)
|
||||
|
|
|
@ -1,30 +1,59 @@
|
|||
package control
|
||||
|
||||
import "net"
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"unsafe"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
)
|
||||
|
||||
type InterfaceFinder interface {
|
||||
InterfaceIndexByName(name string) (int, error)
|
||||
InterfaceNameByIndex(index int) (string, error)
|
||||
Update() error
|
||||
Interfaces() []Interface
|
||||
ByName(name string) (*Interface, error)
|
||||
ByIndex(index int) (*Interface, error)
|
||||
ByAddr(addr netip.Addr) (*Interface, error)
|
||||
}
|
||||
|
||||
func DefaultInterfaceFinder() InterfaceFinder {
|
||||
return (*netInterfaceFinder)(nil)
|
||||
type Interface struct {
|
||||
Index int
|
||||
MTU int
|
||||
Name string
|
||||
HardwareAddr net.HardwareAddr
|
||||
Flags net.Flags
|
||||
Addresses []netip.Prefix
|
||||
}
|
||||
|
||||
type netInterfaceFinder struct{}
|
||||
func (i Interface) Equals(other Interface) bool {
|
||||
return i.Index == other.Index &&
|
||||
i.MTU == other.MTU &&
|
||||
i.Name == other.Name &&
|
||||
common.Equal(i.HardwareAddr, other.HardwareAddr) &&
|
||||
i.Flags == other.Flags &&
|
||||
common.Equal(i.Addresses, other.Addresses)
|
||||
}
|
||||
|
||||
func (w *netInterfaceFinder) InterfaceIndexByName(name string) (int, error) {
|
||||
netInterface, err := net.InterfaceByName(name)
|
||||
func (i Interface) NetInterface() net.Interface {
|
||||
return *(*net.Interface)(unsafe.Pointer(&i))
|
||||
}
|
||||
|
||||
func InterfaceFromNet(iif net.Interface) (Interface, error) {
|
||||
ifAddrs, err := iif.Addrs()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
return Interface{}, err
|
||||
}
|
||||
return netInterface.Index, nil
|
||||
return InterfaceFromNetAddrs(iif, common.Map(ifAddrs, M.PrefixFromNet)), nil
|
||||
}
|
||||
|
||||
func (w *netInterfaceFinder) InterfaceNameByIndex(index int) (string, error) {
|
||||
netInterface, err := net.InterfaceByIndex(index)
|
||||
if err != nil {
|
||||
return "", err
|
||||
func InterfaceFromNetAddrs(iif net.Interface, addresses []netip.Prefix) Interface {
|
||||
return Interface{
|
||||
Index: iif.Index,
|
||||
MTU: iif.MTU,
|
||||
Name: iif.Name,
|
||||
HardwareAddr: iif.HardwareAddr,
|
||||
Flags: iif.Flags,
|
||||
Addresses: addresses,
|
||||
}
|
||||
return netInterface.Name, nil
|
||||
}
|
||||
|
|
89
common/control/bind_finder_default.go
Normal file
89
common/control/bind_finder_default.go
Normal file
|
@ -0,0 +1,89 @@
|
|||
package control
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
)
|
||||
|
||||
var _ InterfaceFinder = (*DefaultInterfaceFinder)(nil)
|
||||
|
||||
type DefaultInterfaceFinder struct {
|
||||
interfaces []Interface
|
||||
}
|
||||
|
||||
func NewDefaultInterfaceFinder() *DefaultInterfaceFinder {
|
||||
return &DefaultInterfaceFinder{}
|
||||
}
|
||||
|
||||
func (f *DefaultInterfaceFinder) Update() error {
|
||||
netIfs, err := net.Interfaces()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
interfaces := make([]Interface, 0, len(netIfs))
|
||||
for _, netIf := range netIfs {
|
||||
var iif Interface
|
||||
iif, err = InterfaceFromNet(netIf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
interfaces = append(interfaces, iif)
|
||||
}
|
||||
f.interfaces = interfaces
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *DefaultInterfaceFinder) UpdateInterfaces(interfaces []Interface) {
|
||||
f.interfaces = interfaces
|
||||
}
|
||||
|
||||
func (f *DefaultInterfaceFinder) Interfaces() []Interface {
|
||||
return f.interfaces
|
||||
}
|
||||
|
||||
func (f *DefaultInterfaceFinder) ByName(name string) (*Interface, error) {
|
||||
for _, netInterface := range f.interfaces {
|
||||
if netInterface.Name == name {
|
||||
return &netInterface, nil
|
||||
}
|
||||
}
|
||||
_, err := net.InterfaceByName(name)
|
||||
if err == nil {
|
||||
err = f.Update()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return f.ByName(name)
|
||||
}
|
||||
return nil, &net.OpError{Op: "route", Net: "ip+net", Source: nil, Addr: &net.IPAddr{IP: nil}, Err: E.New("no such network interface")}
|
||||
}
|
||||
|
||||
func (f *DefaultInterfaceFinder) ByIndex(index int) (*Interface, error) {
|
||||
for _, netInterface := range f.interfaces {
|
||||
if netInterface.Index == index {
|
||||
return &netInterface, nil
|
||||
}
|
||||
}
|
||||
_, err := net.InterfaceByIndex(index)
|
||||
if err == nil {
|
||||
err = f.Update()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return f.ByIndex(index)
|
||||
}
|
||||
return nil, &net.OpError{Op: "route", Net: "ip+net", Source: nil, Addr: &net.IPAddr{IP: nil}, Err: E.New("no such network interface")}
|
||||
}
|
||||
|
||||
func (f *DefaultInterfaceFinder) ByAddr(addr netip.Addr) (*Interface, error) {
|
||||
for _, netInterface := range f.interfaces {
|
||||
for _, prefix := range netInterface.Addresses {
|
||||
if prefix.Contains(addr) {
|
||||
return &netInterface, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil, &net.OpError{Op: "route", Net: "ip+net", Source: nil, Addr: &net.IPAddr{IP: addr.AsSlice()}, Err: E.New("no such network interface")}
|
||||
}
|
|
@ -1,13 +1,42 @@
|
|||
package control
|
||||
|
||||
import (
|
||||
"os"
|
||||
"syscall"
|
||||
|
||||
"github.com/sagernet/sing/common/atomic"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func bindToInterface(conn syscall.RawConn, network string, address string, interfaceName string, interfaceIndex int) error {
|
||||
var ifIndexDisabled atomic.Bool
|
||||
|
||||
func bindToInterface(conn syscall.RawConn, network string, address string, finder InterfaceFinder, interfaceName string, interfaceIndex int, preferInterfaceName bool) error {
|
||||
return Raw(conn, func(fd uintptr) error {
|
||||
if !preferInterfaceName && !ifIndexDisabled.Load() {
|
||||
if interfaceIndex == -1 {
|
||||
if interfaceName == "" {
|
||||
return os.ErrInvalid
|
||||
}
|
||||
iif, err := finder.ByName(interfaceName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
interfaceIndex = iif.Index
|
||||
}
|
||||
err := unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_BINDTOIFINDEX, interfaceIndex)
|
||||
if err == nil {
|
||||
return nil
|
||||
} else if E.IsMulti(err, unix.ENOPROTOOPT, unix.EINVAL) {
|
||||
ifIndexDisabled.Store(true)
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if interfaceName == "" {
|
||||
return os.ErrInvalid
|
||||
}
|
||||
return unix.BindToDevice(int(fd), interfaceName)
|
||||
})
|
||||
}
|
||||
|
|
|
@ -4,6 +4,6 @@ package control
|
|||
|
||||
import "syscall"
|
||||
|
||||
func bindToInterface(conn syscall.RawConn, network string, address string, interfaceName string, interfaceIndex int) error {
|
||||
func bindToInterface(conn syscall.RawConn, network string, address string, finder InterfaceFinder, interfaceName string, interfaceIndex int, preferInterfaceName bool) error {
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -2,14 +2,25 @@ package control
|
|||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"os"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
)
|
||||
|
||||
func bindToInterface(conn syscall.RawConn, network string, address string, interfaceName string, interfaceIndex int) error {
|
||||
func bindToInterface(conn syscall.RawConn, network string, address string, finder InterfaceFinder, interfaceName string, interfaceIndex int, preferInterfaceName bool) error {
|
||||
return Raw(conn, func(fd uintptr) error {
|
||||
if interfaceIndex == -1 {
|
||||
if finder == nil {
|
||||
return os.ErrInvalid
|
||||
}
|
||||
iif, err := finder.ByName(interfaceName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
interfaceIndex = iif.Index
|
||||
}
|
||||
handle := syscall.Handle(fd)
|
||||
if M.ParseSocksaddr(address).AddrString() == "" {
|
||||
err := bind4(handle, interfaceIndex)
|
||||
|
|
33
common/control/frag_darwin.go
Normal file
33
common/control/frag_darwin.go
Normal file
|
@ -0,0 +1,33 @@
|
|||
package control
|
||||
|
||||
import (
|
||||
"os"
|
||||
"syscall"
|
||||
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func DisableUDPFragment() Func {
|
||||
return func(network, address string, conn syscall.RawConn) error {
|
||||
if N.NetworkName(network) != N.NetworkUDP {
|
||||
return nil
|
||||
}
|
||||
return Raw(conn, func(fd uintptr) error {
|
||||
if network == "udp" || network == "udp4" {
|
||||
err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_DONTFRAG, 1)
|
||||
if err != nil {
|
||||
return os.NewSyscallError("SETSOCKOPT IP_DONTFRAG", err)
|
||||
}
|
||||
}
|
||||
if network == "udp" || network == "udp6" {
|
||||
err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_DONTFRAG, 1)
|
||||
if err != nil {
|
||||
return os.NewSyscallError("SETSOCKOPT IPV6_DONTFRAG", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
}
|
|
@ -11,17 +11,19 @@ import (
|
|||
|
||||
func DisableUDPFragment() Func {
|
||||
return func(network, address string, conn syscall.RawConn) error {
|
||||
switch N.NetworkName(network) {
|
||||
case N.NetworkUDP:
|
||||
default:
|
||||
if N.NetworkName(network) != N.NetworkUDP {
|
||||
return nil
|
||||
}
|
||||
return Raw(conn, func(fd uintptr) error {
|
||||
if err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_MTU_DISCOVER, unix.IP_PMTUDISC_DO); err != nil {
|
||||
return os.NewSyscallError("SETSOCKOPT IP_MTU_DISCOVER IP_PMTUDISC_DO", err)
|
||||
if network == "udp" || network == "udp4" {
|
||||
err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_MTU_DISCOVER, unix.IP_PMTUDISC_DO)
|
||||
if err != nil {
|
||||
return os.NewSyscallError("SETSOCKOPT IP_MTU_DISCOVER IP_PMTUDISC_DO", err)
|
||||
}
|
||||
}
|
||||
if network == "udp6" {
|
||||
if err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_MTU_DISCOVER, unix.IP_PMTUDISC_DO); err != nil {
|
||||
if network == "udp" || network == "udp6" {
|
||||
err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_MTU_DISCOVER, unix.IP_PMTUDISC_DO)
|
||||
if err != nil {
|
||||
return os.NewSyscallError("SETSOCKOPT IPV6_MTU_DISCOVER IP_PMTUDISC_DO", err)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
//go:build !((go1.19 && unix) || (!go1.19 && (linux || darwin)) || windows)
|
||||
//go:build !(linux || windows || darwin)
|
||||
|
||||
package control
|
||||
|
||||
|
|
|
@ -1,28 +0,0 @@
|
|||
//go:build (go1.19 && unix && !linux) || (!go1.19 && darwin)
|
||||
|
||||
package control
|
||||
|
||||
import (
|
||||
"os"
|
||||
"syscall"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func DisableUDPFragment() Func {
|
||||
return func(network, address string, conn syscall.RawConn) error {
|
||||
return Raw(conn, func(fd uintptr) error {
|
||||
switch network {
|
||||
case "udp4":
|
||||
if err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_DONTFRAG, 1); err != nil {
|
||||
return os.NewSyscallError("SETSOCKOPT IP_DONTFRAG", err)
|
||||
}
|
||||
case "udp6":
|
||||
if err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_DONTFRAG, 1); err != nil {
|
||||
return os.NewSyscallError("SETSOCKOPT IPV6_DONTFRAG", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
}
|
|
@ -25,17 +25,19 @@ const (
|
|||
|
||||
func DisableUDPFragment() Func {
|
||||
return func(network, address string, conn syscall.RawConn) error {
|
||||
switch N.NetworkName(network) {
|
||||
case N.NetworkUDP:
|
||||
default:
|
||||
if N.NetworkName(network) != N.NetworkUDP {
|
||||
return nil
|
||||
}
|
||||
return Raw(conn, func(fd uintptr) error {
|
||||
if err := windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IP, IP_MTU_DISCOVER, IP_PMTUDISC_DO); err != nil {
|
||||
return os.NewSyscallError("SETSOCKOPT IP_MTU_DISCOVER IP_PMTUDISC_DO", err)
|
||||
if network == "udp" || network == "udp4" {
|
||||
err := windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IP, IP_MTU_DISCOVER, IP_PMTUDISC_DO)
|
||||
if err != nil {
|
||||
return os.NewSyscallError("SETSOCKOPT IP_MTU_DISCOVER IP_PMTUDISC_DO", err)
|
||||
}
|
||||
}
|
||||
if network == "udp6" {
|
||||
if err := windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IPV6, IPV6_MTU_DISCOVER, IP_PMTUDISC_DO); err != nil {
|
||||
if network == "udp" || network == "udp6" {
|
||||
err := windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IPV6, IPV6_MTU_DISCOVER, IP_PMTUDISC_DO)
|
||||
if err != nil {
|
||||
return os.NewSyscallError("SETSOCKOPT IPV6_MTU_DISCOVER IP_PMTUDISC_DO", err)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,6 +3,7 @@ package control
|
|||
import (
|
||||
"syscall"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
)
|
||||
|
||||
|
@ -30,6 +31,14 @@ func Conn(conn syscall.Conn, block func(fd uintptr) error) error {
|
|||
return Raw(rawConn, block)
|
||||
}
|
||||
|
||||
func Conn0[T any](conn syscall.Conn, block func(fd uintptr) (T, error)) (T, error) {
|
||||
rawConn, err := conn.SyscallConn()
|
||||
if err != nil {
|
||||
return common.DefaultValue[T](), err
|
||||
}
|
||||
return Raw0[T](rawConn, block)
|
||||
}
|
||||
|
||||
func Raw(rawConn syscall.RawConn, block func(fd uintptr) error) error {
|
||||
var innerErr error
|
||||
err := rawConn.Control(func(fd uintptr) {
|
||||
|
@ -37,3 +46,14 @@ func Raw(rawConn syscall.RawConn, block func(fd uintptr) error) error {
|
|||
})
|
||||
return E.Errors(innerErr, err)
|
||||
}
|
||||
|
||||
func Raw0[T any](rawConn syscall.RawConn, block func(fd uintptr) (T, error)) (T, error) {
|
||||
var (
|
||||
value T
|
||||
innerErr error
|
||||
)
|
||||
err := rawConn.Control(func(fd uintptr) {
|
||||
value, innerErr = block(fd)
|
||||
})
|
||||
return value, E.Errors(innerErr, err)
|
||||
}
|
||||
|
|
|
@ -4,10 +4,10 @@ import (
|
|||
"syscall"
|
||||
)
|
||||
|
||||
func RoutingMark(mark int) Func {
|
||||
func RoutingMark(mark uint32) Func {
|
||||
return func(network, address string, conn syscall.RawConn) error {
|
||||
return Raw(conn, func(fd uintptr) error {
|
||||
return syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_MARK, mark)
|
||||
return syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_MARK, int(mark))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2,6 +2,6 @@
|
|||
|
||||
package control
|
||||
|
||||
func RoutingMark(mark int) Func {
|
||||
func RoutingMark(mark uint32) Func {
|
||||
return nil
|
||||
}
|
||||
|
|
58
common/control/redirect_darwin.go
Normal file
58
common/control/redirect_darwin.go
Normal file
|
@ -0,0 +1,58 @@
|
|||
package control
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"net"
|
||||
"net/netip"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
const (
|
||||
PF_OUT = 0x2
|
||||
DIOCNATLOOK = 0xc0544417
|
||||
)
|
||||
|
||||
func GetOriginalDestination(conn net.Conn) (netip.AddrPort, error) {
|
||||
pfFd, err := syscall.Open("/dev/pf", 0, syscall.O_RDONLY)
|
||||
if err != nil {
|
||||
return netip.AddrPort{}, err
|
||||
}
|
||||
defer syscall.Close(pfFd)
|
||||
nl := struct {
|
||||
saddr, daddr, rsaddr, rdaddr [16]byte
|
||||
sxport, dxport, rsxport, rdxport [4]byte
|
||||
af, proto, protoVariant, direction uint8
|
||||
}{
|
||||
af: syscall.AF_INET,
|
||||
proto: syscall.IPPROTO_TCP,
|
||||
direction: PF_OUT,
|
||||
}
|
||||
localAddr := M.SocksaddrFromNet(conn.LocalAddr()).Unwrap()
|
||||
removeAddr := M.SocksaddrFromNet(conn.RemoteAddr()).Unwrap()
|
||||
if localAddr.IsIPv4() {
|
||||
copy(nl.saddr[:net.IPv4len], removeAddr.Addr.AsSlice())
|
||||
copy(nl.daddr[:net.IPv4len], localAddr.Addr.AsSlice())
|
||||
nl.af = syscall.AF_INET
|
||||
} else {
|
||||
copy(nl.saddr[:], removeAddr.Addr.AsSlice())
|
||||
copy(nl.daddr[:], localAddr.Addr.AsSlice())
|
||||
nl.af = syscall.AF_INET6
|
||||
}
|
||||
binary.BigEndian.PutUint16(nl.sxport[:], removeAddr.Port)
|
||||
binary.BigEndian.PutUint16(nl.dxport[:], localAddr.Port)
|
||||
if _, _, errno := unix.Syscall(syscall.SYS_IOCTL, uintptr(pfFd), DIOCNATLOOK, uintptr(unsafe.Pointer(&nl))); errno != 0 {
|
||||
return netip.AddrPort{}, errno
|
||||
}
|
||||
var address netip.Addr
|
||||
if nl.af == unix.AF_INET {
|
||||
address = M.AddrFromIP(nl.rdaddr[:net.IPv4len])
|
||||
} else {
|
||||
address = netip.AddrFrom16(nl.rdaddr)
|
||||
}
|
||||
return netip.AddrPortFrom(address, binary.BigEndian.Uint16(nl.rdxport[:])), nil
|
||||
}
|
38
common/control/redirect_linux.go
Normal file
38
common/control/redirect_linux.go
Normal file
|
@ -0,0 +1,38 @@
|
|||
package control
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"syscall"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func GetOriginalDestination(conn net.Conn) (netip.AddrPort, error) {
|
||||
syscallConn, loaded := common.Cast[syscall.Conn](conn)
|
||||
if !loaded {
|
||||
return netip.AddrPort{}, os.ErrInvalid
|
||||
}
|
||||
return Conn0[netip.AddrPort](syscallConn, func(fd uintptr) (netip.AddrPort, error) {
|
||||
if M.SocksaddrFromNet(conn.RemoteAddr()).Unwrap().IsIPv4() {
|
||||
raw, err := unix.GetsockoptIPv6Mreq(int(fd), unix.IPPROTO_IP, unix.SO_ORIGINAL_DST)
|
||||
if err != nil {
|
||||
return netip.AddrPort{}, err
|
||||
}
|
||||
return netip.AddrPortFrom(M.AddrFromIP(raw.Multiaddr[4:8]), uint16(raw.Multiaddr[2])<<8+uint16(raw.Multiaddr[3])), nil
|
||||
} else {
|
||||
raw, err := unix.GetsockoptIPv6MTUInfo(int(fd), unix.IPPROTO_IPV6, unix.SO_ORIGINAL_DST)
|
||||
if err != nil {
|
||||
return netip.AddrPort{}, err
|
||||
}
|
||||
var port [2]byte
|
||||
binary.BigEndian.PutUint16(port[:], raw.Addr.Port)
|
||||
return netip.AddrPortFrom(M.AddrFromIP(raw.Addr.Addr[:]), binary.LittleEndian.Uint16(port[:])), nil
|
||||
}
|
||||
})
|
||||
}
|
13
common/control/redirect_other.go
Normal file
13
common/control/redirect_other.go
Normal file
|
@ -0,0 +1,13 @@
|
|||
//go:build !linux && !darwin
|
||||
|
||||
package control
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
)
|
||||
|
||||
func GetOriginalDestination(conn net.Conn) (netip.AddrPort, error) {
|
||||
return netip.AddrPort{}, os.ErrInvalid
|
||||
}
|
30
common/control/tcp_keep_alive_linux.go
Normal file
30
common/control/tcp_keep_alive_linux.go
Normal file
|
@ -0,0 +1,30 @@
|
|||
package control
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
"time"
|
||||
_ "unsafe"
|
||||
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func SetKeepAlivePeriod(idle time.Duration, interval time.Duration) Func {
|
||||
return func(network, address string, conn syscall.RawConn) error {
|
||||
if N.NetworkName(network) != N.NetworkTCP {
|
||||
return nil
|
||||
}
|
||||
return Raw(conn, func(fd uintptr) error {
|
||||
return E.Errors(
|
||||
unix.SetsockoptInt(int(fd), unix.IPPROTO_TCP, unix.TCP_KEEPIDLE, int(roundDurationUp(idle, time.Second))),
|
||||
unix.SetsockoptInt(int(fd), unix.IPPROTO_TCP, unix.TCP_KEEPINTVL, int(roundDurationUp(interval, time.Second))),
|
||||
)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func roundDurationUp(d time.Duration, to time.Duration) time.Duration {
|
||||
return (d + to - 1) / to
|
||||
}
|
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue