mirror of
https://github.com/SagerNet/sing.git
synced 2025-04-04 20:37:40 +03:00
Compare commits
181 commits
Author | SHA1 | Date | |
---|---|---|---|
|
159e489fc3 | ||
|
d39c2c2fdd | ||
|
ea82ac275f | ||
|
ea0ac932ae | ||
|
2b41455f5a | ||
|
23b0180a1b | ||
|
ce1b4851a4 | ||
|
2238a05966 | ||
|
b55d1c78b3 | ||
|
d54716612c | ||
|
9eafc7fc62 | ||
|
d8153df67f | ||
|
d9f6eb136d | ||
|
4dabb9be97 | ||
|
be9840c70f | ||
|
aa7d2543a3 | ||
|
33beacc053 | ||
|
442cceb9fa | ||
|
3374a45475 | ||
|
73776cf797 | ||
|
957166799e | ||
|
809d8eca13 | ||
|
9f69e7f9f7 | ||
|
478265cd45 | ||
|
3f30aaf25e | ||
|
39040e06dc | ||
|
6edd2ce0ea | ||
|
0a2e2a3eaf | ||
|
4ba1eb123c | ||
|
c44912a861 | ||
|
a8f5bf4eb0 | ||
|
30e9d91b57 | ||
|
7fd3517e4d | ||
|
a8285e06a5 | ||
|
3613ead480 | ||
|
c8f251c668 | ||
|
fa5355e99e | ||
|
30fbafd954 | ||
|
fdca9b3f8e | ||
|
e52e04f721 | ||
|
7f621fdd78 | ||
|
ae139d9ee1 | ||
|
c432befd02 | ||
|
cc7e630923 | ||
|
0998999911 | ||
|
72ff654ee0 | ||
|
11ffb962ae | ||
|
fcb19641e6 | ||
|
524a6bd0d1 | ||
|
b5f9e70ffd | ||
|
c80c8f907c | ||
|
a4eb7fa900 | ||
|
7ec09d6045 | ||
|
0641c71805 | ||
|
e7ec021b81 | ||
|
0f2447a95b | ||
|
72db784fc7 | ||
|
d59ac57aaa | ||
|
c63546470b | ||
|
55908bea36 | ||
|
6567829958 | ||
|
c324d4143d | ||
|
0acb36c118 | ||
|
26511a251f | ||
|
afd8993773 | ||
|
96bef0733f | ||
|
ec1df651e8 | ||
|
e33b1d67d5 | ||
|
ed6cde73f7 | ||
|
73cc65605e | ||
|
6c19e0736d | ||
|
08e8c02fb1 | ||
|
7beca62e4f | ||
|
e422e3d048 | ||
|
fa81eabc29 | ||
|
4498e57839 | ||
|
f97054e917 | ||
|
a2f9fef936 | ||
|
7893a74f75 | ||
|
332e470075 | ||
|
2bf9cc7253 | ||
|
bf8fc103a4 | ||
|
774893928c | ||
|
7ceaf63d41 | ||
|
8806e421f2 | ||
|
c37f988a4f | ||
|
0b4c0a1283 | ||
|
4745c34b4c | ||
|
e0196407a3 | ||
|
d8ec9c46cc | ||
|
a33349366d | ||
|
3155c16990 | ||
|
caa4340dc9 | ||
|
a31dba8ad2 | ||
|
9571124cf4 | ||
|
1c495c9b07 | ||
|
0f95dfe0e3 | ||
|
f3380c8dfe | ||
|
ab4353dd13 | ||
|
e0e490af7b | ||
|
ad4d59e2ed | ||
|
aca2a85545 | ||
|
589c7eb4df | ||
|
2873799b6d | ||
|
47cc308abf | ||
|
d9f2559214 | ||
|
b8736cc58d | ||
|
e82ff8e2e6 | ||
|
ba68e017a9 | ||
|
967afcf6c1 | ||
|
0c110ad733 | ||
|
de1b0bd772 | ||
|
f67a0988a6 | ||
|
e0ee7f49e2 | ||
|
284cb5ce98 | ||
|
8fb1634c9a | ||
|
4ab8cac5eb | ||
|
eec2fc325a | ||
|
2fa039945c | ||
|
e5825dcb59 | ||
|
6b73a57a24 | ||
|
3e2631ef0b | ||
|
4d96f15eca | ||
|
f9c59e9940 | ||
|
8b68fc4d7a | ||
|
5bfc326913 | ||
|
04152ea672 | ||
|
a069af4787 | ||
|
807a51bb81 | ||
|
ec2595f010 | ||
|
c98e8b6921 | ||
|
a4a9ec42c6 | ||
|
6e3921083b | ||
|
8e89f9b4dc | ||
|
5f02cb1cff | ||
|
ef00a1ec1e | ||
|
5ee4f84faf | ||
|
30f7629317 | ||
|
9e1749e108 | ||
|
b1355d7a4b | ||
|
45f572495e | ||
|
3ac055b755 | ||
|
a6e8fa3019 | ||
|
57b8a4c64a | ||
|
b7a631f798 | ||
|
fa0cc448dc | ||
|
0d1b3d6d6d | ||
|
2196f193ac | ||
|
c501a58ae7 | ||
|
c9319a35ee | ||
|
cdb9908442 | ||
|
81d1bc2768 | ||
|
2e36fa6849 | ||
|
edd320c3a8 | ||
|
56b953e091 | ||
|
4c4773fe54 | ||
|
36acc18bfb | ||
|
ad670bab68 | ||
|
2a2dbf1971 | ||
|
afa72012e5 | ||
|
c7ef05a85b | ||
|
0f7de716ac | ||
|
231d7607bc | ||
|
8b43ec8058 | ||
|
c17babe0ba | ||
|
1f02d6daca | ||
|
aa34723225 | ||
|
ae8098ad39 | ||
|
05c71c99d1 | ||
|
060edf2d69 | ||
|
d171f04941 | ||
|
51aeb14a87 | ||
|
96f5dea24b | ||
|
3336b50119 | ||
|
36be4ef141 | ||
|
843bab522a | ||
|
af92594d6d | ||
|
f23499eaea | ||
|
d7ce998e7e | ||
|
99d07d6e5a | ||
|
028dcd722c |
212 changed files with 15666 additions and 1379 deletions
7
.github/renovate.json
vendored
7
.github/renovate.json
vendored
|
@ -5,6 +5,9 @@
|
|||
"config:base",
|
||||
":disableRateLimiting"
|
||||
],
|
||||
"baseBranches": [
|
||||
"dev"
|
||||
],
|
||||
"packageRules": [
|
||||
{
|
||||
"matchManagers": [
|
||||
|
@ -14,9 +17,9 @@
|
|||
},
|
||||
{
|
||||
"matchManagers": [
|
||||
"gomod"
|
||||
"dockerfile"
|
||||
],
|
||||
"groupName": "gomod"
|
||||
"groupName": "Dockerfile"
|
||||
}
|
||||
]
|
||||
}
|
43
.github/workflows/debug.yml
vendored
43
.github/workflows/debug.yml
vendored
|
@ -1,43 +0,0 @@
|
|||
name: Debug build
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- dev
|
||||
paths-ignore:
|
||||
- '**.md'
|
||||
- '.github/**'
|
||||
- '!.github/workflows/debug.yml'
|
||||
pull_request:
|
||||
branches:
|
||||
- dev
|
||||
|
||||
jobs:
|
||||
build:
|
||||
name: Debug build
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- name: Get latest go version
|
||||
id: version
|
||||
run: |
|
||||
echo ::set-output name=go_version::$(curl -s https://raw.githubusercontent.com/actions/go-versions/main/versions-manifest.json | grep -oE '"version": "[0-9]{1}.[0-9]{1,}(.[0-9]{1,})?"' | head -1 | cut -d':' -f2 | sed 's/ //g; s/"//g')
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: ${{ steps.version.outputs.go_version }}
|
||||
- name: Add cache to Go proxy
|
||||
run: |
|
||||
version=`git rev-parse HEAD`
|
||||
mkdir build
|
||||
pushd build
|
||||
go mod init build
|
||||
go get -v github.com/sagernet/sing@$version
|
||||
popd
|
||||
continue-on-error: true
|
||||
- name: Build
|
||||
run: |
|
||||
make test
|
16
.github/workflows/lint.yml
vendored
16
.github/workflows/lint.yml
vendored
|
@ -1,8 +1,9 @@
|
|||
name: Lint
|
||||
name: lint
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
- dev
|
||||
paths-ignore:
|
||||
- '**.md'
|
||||
|
@ -10,6 +11,7 @@ on:
|
|||
- '!.github/workflows/lint.yml'
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
- dev
|
||||
|
||||
jobs:
|
||||
|
@ -21,21 +23,17 @@ jobs:
|
|||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- name: Get latest go version
|
||||
id: version
|
||||
run: |
|
||||
echo ::set-output name=go_version::$(curl -s https://raw.githubusercontent.com/actions/go-versions/main/versions-manifest.json | grep -oE '"version": "[0-9]{1}.[0-9]{1,}(.[0-9]{1,})?"' | head -1 | cut -d':' -f2 | sed 's/ //g; s/"//g')
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v4
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: ${{ steps.version.outputs.go_version }}
|
||||
go-version: ^1.23
|
||||
- name: Cache go module
|
||||
uses: actions/cache@v3
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: |
|
||||
~/go/pkg/mod
|
||||
key: go-${{ hashFiles('**/go.sum') }}
|
||||
- name: golangci-lint
|
||||
uses: golangci/golangci-lint-action@v3
|
||||
uses: golangci/golangci-lint-action@v6
|
||||
with:
|
||||
version: latest
|
112
.github/workflows/test.yml
vendored
Normal file
112
.github/workflows/test.yml
vendored
Normal file
|
@ -0,0 +1,112 @@
|
|||
name: test
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
- dev
|
||||
paths-ignore:
|
||||
- '**.md'
|
||||
- '.github/**'
|
||||
- '!.github/workflows/debug.yml'
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
- dev
|
||||
|
||||
jobs:
|
||||
build:
|
||||
name: Linux
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: ^1.23
|
||||
- name: Build
|
||||
run: |
|
||||
make test
|
||||
build_go120:
|
||||
name: Linux (Go 1.20)
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: ~1.20
|
||||
continue-on-error: true
|
||||
- name: Build
|
||||
run: |
|
||||
make test
|
||||
build_go121:
|
||||
name: Linux (Go 1.21)
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: ~1.21
|
||||
continue-on-error: true
|
||||
- name: Build
|
||||
run: |
|
||||
make test
|
||||
build_go122:
|
||||
name: Linux (Go 1.22)
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: ~1.22
|
||||
continue-on-error: true
|
||||
- name: Build
|
||||
run: |
|
||||
make test
|
||||
build_windows:
|
||||
name: Windows
|
||||
runs-on: windows-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: ^1.23
|
||||
continue-on-error: true
|
||||
- name: Build
|
||||
run: |
|
||||
make test
|
||||
build_darwin:
|
||||
name: macOS
|
||||
runs-on: macos-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: ^1.23
|
||||
continue-on-error: true
|
||||
- name: Build
|
||||
run: |
|
||||
make test
|
|
@ -5,6 +5,8 @@ linters:
|
|||
- govet
|
||||
- gci
|
||||
- staticcheck
|
||||
- paralleltest
|
||||
- ineffassign
|
||||
|
||||
linters-settings:
|
||||
gci:
|
||||
|
@ -14,4 +16,9 @@ linters-settings:
|
|||
- prefix(github.com/sagernet/)
|
||||
- default
|
||||
staticcheck:
|
||||
go: '1.20'
|
||||
checks:
|
||||
- all
|
||||
- -SA1003
|
||||
|
||||
run:
|
||||
go: "1.23"
|
12
Makefile
12
Makefile
|
@ -8,14 +8,14 @@ fmt_install:
|
|||
go install -v github.com/daixiang0/gci@latest
|
||||
|
||||
lint:
|
||||
GOOS=linux golangci-lint run ./...
|
||||
GOOS=android golangci-lint run ./...
|
||||
GOOS=windows golangci-lint run ./...
|
||||
GOOS=darwin golangci-lint run ./...
|
||||
GOOS=freebsd golangci-lint run ./...
|
||||
GOOS=linux golangci-lint run
|
||||
GOOS=android golangci-lint run
|
||||
GOOS=windows golangci-lint run
|
||||
GOOS=darwin golangci-lint run
|
||||
GOOS=freebsd golangci-lint run
|
||||
|
||||
lint_install:
|
||||
go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest
|
||||
|
||||
test:
|
||||
go test -v ./...
|
||||
go test ./...
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
# sing
|
||||
|
||||

|
||||

|
||||
|
||||
Do you hear the people sing?
|
|
@ -10,26 +10,37 @@ type TypedValue[T any] struct {
|
|||
value atomic.Value
|
||||
}
|
||||
|
||||
// typedValue is a struct with determined type to resolve atomic.Value usages with interface types
|
||||
// https://github.com/golang/go/issues/22550
|
||||
//
|
||||
// The intention to have an atomic value store for errors. However, running this code panics:
|
||||
// panic: sync/atomic: store of inconsistently typed value into Value
|
||||
// This is because atomic.Value requires that the underlying concrete type be the same (which is a reasonable expectation for its implementation).
|
||||
// When going through the atomic.Value.Store method call, the fact that both these are of the error interface is lost.
|
||||
type typedValue[T any] struct {
|
||||
value T
|
||||
}
|
||||
|
||||
func (t *TypedValue[T]) Load() T {
|
||||
value := t.value.Load()
|
||||
if value == nil {
|
||||
return common.DefaultValue[T]()
|
||||
}
|
||||
return value.(T)
|
||||
return value.(typedValue[T]).value
|
||||
}
|
||||
|
||||
func (t *TypedValue[T]) Store(value T) {
|
||||
t.value.Store(value)
|
||||
t.value.Store(typedValue[T]{value})
|
||||
}
|
||||
|
||||
func (t *TypedValue[T]) Swap(new T) T {
|
||||
old := t.value.Swap(new)
|
||||
old := t.value.Swap(typedValue[T]{new})
|
||||
if old == nil {
|
||||
return common.DefaultValue[T]()
|
||||
}
|
||||
return old.(T)
|
||||
return old.(typedValue[T]).value
|
||||
}
|
||||
|
||||
func (t *TypedValue[T]) CompareAndSwap(old, new T) bool {
|
||||
return t.value.CompareAndSwap(old, new)
|
||||
return t.value.CompareAndSwap(typedValue[T]{old}, typedValue[T]{new})
|
||||
}
|
||||
|
|
|
@ -1,38 +1,30 @@
|
|||
package auth
|
||||
|
||||
type Authenticator interface {
|
||||
Verify(user string, pass string) bool
|
||||
Users() []string
|
||||
}
|
||||
import "github.com/sagernet/sing/common"
|
||||
|
||||
type User struct {
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
Username string
|
||||
Password string
|
||||
}
|
||||
|
||||
type inMemoryAuthenticator struct {
|
||||
storage map[string]string
|
||||
usernames []string
|
||||
type Authenticator struct {
|
||||
userMap map[string][]string
|
||||
}
|
||||
|
||||
func (au *inMemoryAuthenticator) Verify(username string, password string) bool {
|
||||
realPass, ok := au.storage[username]
|
||||
return ok && realPass == password
|
||||
}
|
||||
|
||||
func (au *inMemoryAuthenticator) Users() []string { return au.usernames }
|
||||
|
||||
func NewAuthenticator(users []User) Authenticator {
|
||||
func NewAuthenticator(users []User) *Authenticator {
|
||||
if len(users) == 0 {
|
||||
return nil
|
||||
}
|
||||
au := &inMemoryAuthenticator{
|
||||
storage: make(map[string]string),
|
||||
usernames: make([]string, 0, len(users)),
|
||||
au := &Authenticator{
|
||||
userMap: make(map[string][]string),
|
||||
}
|
||||
for _, user := range users {
|
||||
au.storage[user.Username] = user.Password
|
||||
au.usernames = append(au.usernames, user.Username)
|
||||
au.userMap[user.Username] = append(au.userMap[user.Username], user.Password)
|
||||
}
|
||||
return au
|
||||
}
|
||||
|
||||
func (au *Authenticator) Verify(username string, password string) bool {
|
||||
passwordList, ok := au.userMap[username]
|
||||
return ok && common.Contains(passwordList, password)
|
||||
}
|
||||
|
|
|
@ -2,11 +2,10 @@ package baderror
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
)
|
||||
|
||||
func Contains(err error, msgList ...string) bool {
|
||||
|
@ -22,8 +21,7 @@ func WrapH2(err error) error {
|
|||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
err = E.Unwrap(err)
|
||||
if err == io.ErrUnexpectedEOF {
|
||||
if errors.Is(err, io.ErrUnexpectedEOF) {
|
||||
return io.EOF
|
||||
}
|
||||
if Contains(err, "client disconnected", "body closed by handler", "response body closed", "; CANCEL") {
|
||||
|
|
3
common/binary/README.md
Normal file
3
common/binary/README.md
Normal file
|
@ -0,0 +1,3 @@
|
|||
# binary
|
||||
|
||||
mod from go 1.22.3
|
817
common/binary/binary.go
Normal file
817
common/binary/binary.go
Normal file
|
@ -0,0 +1,817 @@
|
|||
// Copyright 2009 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package binary implements simple translation between numbers and byte
|
||||
// sequences and encoding and decoding of varints.
|
||||
//
|
||||
// Numbers are translated by reading and writing fixed-size values.
|
||||
// A fixed-size value is either a fixed-size arithmetic
|
||||
// type (bool, int8, uint8, int16, float32, complex64, ...)
|
||||
// or an array or struct containing only fixed-size values.
|
||||
//
|
||||
// The varint functions encode and decode single integer values using
|
||||
// a variable-length encoding; smaller values require fewer bytes.
|
||||
// For a specification, see
|
||||
// https://developers.google.com/protocol-buffers/docs/encoding.
|
||||
//
|
||||
// This package favors simplicity over efficiency. Clients that require
|
||||
// high-performance serialization, especially for large data structures,
|
||||
// should look at more advanced solutions such as the [encoding/gob]
|
||||
// package or [google.golang.org/protobuf] for protocol buffers.
|
||||
package binary
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"math"
|
||||
"reflect"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// A ByteOrder specifies how to convert byte slices into
|
||||
// 16-, 32-, or 64-bit unsigned integers.
|
||||
//
|
||||
// It is implemented by [LittleEndian], [BigEndian], and [NativeEndian].
|
||||
type ByteOrder interface {
|
||||
Uint16([]byte) uint16
|
||||
Uint32([]byte) uint32
|
||||
Uint64([]byte) uint64
|
||||
PutUint16([]byte, uint16)
|
||||
PutUint32([]byte, uint32)
|
||||
PutUint64([]byte, uint64)
|
||||
String() string
|
||||
}
|
||||
|
||||
// AppendByteOrder specifies how to append 16-, 32-, or 64-bit unsigned integers
|
||||
// into a byte slice.
|
||||
//
|
||||
// It is implemented by [LittleEndian], [BigEndian], and [NativeEndian].
|
||||
type AppendByteOrder interface {
|
||||
AppendUint16([]byte, uint16) []byte
|
||||
AppendUint32([]byte, uint32) []byte
|
||||
AppendUint64([]byte, uint64) []byte
|
||||
String() string
|
||||
}
|
||||
|
||||
// LittleEndian is the little-endian implementation of [ByteOrder] and [AppendByteOrder].
|
||||
var LittleEndian littleEndian
|
||||
|
||||
// BigEndian is the big-endian implementation of [ByteOrder] and [AppendByteOrder].
|
||||
var BigEndian bigEndian
|
||||
|
||||
type littleEndian struct{}
|
||||
|
||||
func (littleEndian) Uint16(b []byte) uint16 {
|
||||
_ = b[1] // bounds check hint to compiler; see golang.org/issue/14808
|
||||
return uint16(b[0]) | uint16(b[1])<<8
|
||||
}
|
||||
|
||||
func (littleEndian) PutUint16(b []byte, v uint16) {
|
||||
_ = b[1] // early bounds check to guarantee safety of writes below
|
||||
b[0] = byte(v)
|
||||
b[1] = byte(v >> 8)
|
||||
}
|
||||
|
||||
func (littleEndian) AppendUint16(b []byte, v uint16) []byte {
|
||||
return append(b,
|
||||
byte(v),
|
||||
byte(v>>8),
|
||||
)
|
||||
}
|
||||
|
||||
func (littleEndian) Uint32(b []byte) uint32 {
|
||||
_ = b[3] // bounds check hint to compiler; see golang.org/issue/14808
|
||||
return uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24
|
||||
}
|
||||
|
||||
func (littleEndian) PutUint32(b []byte, v uint32) {
|
||||
_ = b[3] // early bounds check to guarantee safety of writes below
|
||||
b[0] = byte(v)
|
||||
b[1] = byte(v >> 8)
|
||||
b[2] = byte(v >> 16)
|
||||
b[3] = byte(v >> 24)
|
||||
}
|
||||
|
||||
func (littleEndian) AppendUint32(b []byte, v uint32) []byte {
|
||||
return append(b,
|
||||
byte(v),
|
||||
byte(v>>8),
|
||||
byte(v>>16),
|
||||
byte(v>>24),
|
||||
)
|
||||
}
|
||||
|
||||
func (littleEndian) Uint64(b []byte) uint64 {
|
||||
_ = b[7] // bounds check hint to compiler; see golang.org/issue/14808
|
||||
return uint64(b[0]) | uint64(b[1])<<8 | uint64(b[2])<<16 | uint64(b[3])<<24 |
|
||||
uint64(b[4])<<32 | uint64(b[5])<<40 | uint64(b[6])<<48 | uint64(b[7])<<56
|
||||
}
|
||||
|
||||
func (littleEndian) PutUint64(b []byte, v uint64) {
|
||||
_ = b[7] // early bounds check to guarantee safety of writes below
|
||||
b[0] = byte(v)
|
||||
b[1] = byte(v >> 8)
|
||||
b[2] = byte(v >> 16)
|
||||
b[3] = byte(v >> 24)
|
||||
b[4] = byte(v >> 32)
|
||||
b[5] = byte(v >> 40)
|
||||
b[6] = byte(v >> 48)
|
||||
b[7] = byte(v >> 56)
|
||||
}
|
||||
|
||||
func (littleEndian) AppendUint64(b []byte, v uint64) []byte {
|
||||
return append(b,
|
||||
byte(v),
|
||||
byte(v>>8),
|
||||
byte(v>>16),
|
||||
byte(v>>24),
|
||||
byte(v>>32),
|
||||
byte(v>>40),
|
||||
byte(v>>48),
|
||||
byte(v>>56),
|
||||
)
|
||||
}
|
||||
|
||||
func (littleEndian) String() string { return "LittleEndian" }
|
||||
|
||||
func (littleEndian) GoString() string { return "binary.LittleEndian" }
|
||||
|
||||
type bigEndian struct{}
|
||||
|
||||
func (bigEndian) Uint16(b []byte) uint16 {
|
||||
_ = b[1] // bounds check hint to compiler; see golang.org/issue/14808
|
||||
return uint16(b[1]) | uint16(b[0])<<8
|
||||
}
|
||||
|
||||
func (bigEndian) PutUint16(b []byte, v uint16) {
|
||||
_ = b[1] // early bounds check to guarantee safety of writes below
|
||||
b[0] = byte(v >> 8)
|
||||
b[1] = byte(v)
|
||||
}
|
||||
|
||||
func (bigEndian) AppendUint16(b []byte, v uint16) []byte {
|
||||
return append(b,
|
||||
byte(v>>8),
|
||||
byte(v),
|
||||
)
|
||||
}
|
||||
|
||||
func (bigEndian) Uint32(b []byte) uint32 {
|
||||
_ = b[3] // bounds check hint to compiler; see golang.org/issue/14808
|
||||
return uint32(b[3]) | uint32(b[2])<<8 | uint32(b[1])<<16 | uint32(b[0])<<24
|
||||
}
|
||||
|
||||
func (bigEndian) PutUint32(b []byte, v uint32) {
|
||||
_ = b[3] // early bounds check to guarantee safety of writes below
|
||||
b[0] = byte(v >> 24)
|
||||
b[1] = byte(v >> 16)
|
||||
b[2] = byte(v >> 8)
|
||||
b[3] = byte(v)
|
||||
}
|
||||
|
||||
func (bigEndian) AppendUint32(b []byte, v uint32) []byte {
|
||||
return append(b,
|
||||
byte(v>>24),
|
||||
byte(v>>16),
|
||||
byte(v>>8),
|
||||
byte(v),
|
||||
)
|
||||
}
|
||||
|
||||
func (bigEndian) Uint64(b []byte) uint64 {
|
||||
_ = b[7] // bounds check hint to compiler; see golang.org/issue/14808
|
||||
return uint64(b[7]) | uint64(b[6])<<8 | uint64(b[5])<<16 | uint64(b[4])<<24 |
|
||||
uint64(b[3])<<32 | uint64(b[2])<<40 | uint64(b[1])<<48 | uint64(b[0])<<56
|
||||
}
|
||||
|
||||
func (bigEndian) PutUint64(b []byte, v uint64) {
|
||||
_ = b[7] // early bounds check to guarantee safety of writes below
|
||||
b[0] = byte(v >> 56)
|
||||
b[1] = byte(v >> 48)
|
||||
b[2] = byte(v >> 40)
|
||||
b[3] = byte(v >> 32)
|
||||
b[4] = byte(v >> 24)
|
||||
b[5] = byte(v >> 16)
|
||||
b[6] = byte(v >> 8)
|
||||
b[7] = byte(v)
|
||||
}
|
||||
|
||||
func (bigEndian) AppendUint64(b []byte, v uint64) []byte {
|
||||
return append(b,
|
||||
byte(v>>56),
|
||||
byte(v>>48),
|
||||
byte(v>>40),
|
||||
byte(v>>32),
|
||||
byte(v>>24),
|
||||
byte(v>>16),
|
||||
byte(v>>8),
|
||||
byte(v),
|
||||
)
|
||||
}
|
||||
|
||||
func (bigEndian) String() string { return "BigEndian" }
|
||||
|
||||
func (bigEndian) GoString() string { return "binary.BigEndian" }
|
||||
|
||||
func (nativeEndian) String() string { return "NativeEndian" }
|
||||
|
||||
func (nativeEndian) GoString() string { return "binary.NativeEndian" }
|
||||
|
||||
// Read reads structured binary data from r into data.
|
||||
// Data must be a pointer to a fixed-size value or a slice
|
||||
// of fixed-size values.
|
||||
// Bytes read from r are decoded using the specified byte order
|
||||
// and written to successive fields of the data.
|
||||
// When decoding boolean values, a zero byte is decoded as false, and
|
||||
// any other non-zero byte is decoded as true.
|
||||
// When reading into structs, the field data for fields with
|
||||
// blank (_) field names is skipped; i.e., blank field names
|
||||
// may be used for padding.
|
||||
// When reading into a struct, all non-blank fields must be exported
|
||||
// or Read may panic.
|
||||
//
|
||||
// The error is [io.EOF] only if no bytes were read.
|
||||
// If an [io.EOF] happens after reading some but not all the bytes,
|
||||
// Read returns [io.ErrUnexpectedEOF].
|
||||
func Read(r io.Reader, order ByteOrder, data any) error {
|
||||
// Fast path for basic types and slices.
|
||||
if n := intDataSize(data); n != 0 {
|
||||
bs := make([]byte, n)
|
||||
if _, err := io.ReadFull(r, bs); err != nil {
|
||||
return err
|
||||
}
|
||||
switch data := data.(type) {
|
||||
case *bool:
|
||||
*data = bs[0] != 0
|
||||
case *int8:
|
||||
*data = int8(bs[0])
|
||||
case *uint8:
|
||||
*data = bs[0]
|
||||
case *int16:
|
||||
*data = int16(order.Uint16(bs))
|
||||
case *uint16:
|
||||
*data = order.Uint16(bs)
|
||||
case *int32:
|
||||
*data = int32(order.Uint32(bs))
|
||||
case *uint32:
|
||||
*data = order.Uint32(bs)
|
||||
case *int64:
|
||||
*data = int64(order.Uint64(bs))
|
||||
case *uint64:
|
||||
*data = order.Uint64(bs)
|
||||
case *float32:
|
||||
*data = math.Float32frombits(order.Uint32(bs))
|
||||
case *float64:
|
||||
*data = math.Float64frombits(order.Uint64(bs))
|
||||
case []bool:
|
||||
for i, x := range bs { // Easier to loop over the input for 8-bit values.
|
||||
data[i] = x != 0
|
||||
}
|
||||
case []int8:
|
||||
for i, x := range bs {
|
||||
data[i] = int8(x)
|
||||
}
|
||||
case []uint8:
|
||||
copy(data, bs)
|
||||
case []int16:
|
||||
for i := range data {
|
||||
data[i] = int16(order.Uint16(bs[2*i:]))
|
||||
}
|
||||
case []uint16:
|
||||
for i := range data {
|
||||
data[i] = order.Uint16(bs[2*i:])
|
||||
}
|
||||
case []int32:
|
||||
for i := range data {
|
||||
data[i] = int32(order.Uint32(bs[4*i:]))
|
||||
}
|
||||
case []uint32:
|
||||
for i := range data {
|
||||
data[i] = order.Uint32(bs[4*i:])
|
||||
}
|
||||
case []int64:
|
||||
for i := range data {
|
||||
data[i] = int64(order.Uint64(bs[8*i:]))
|
||||
}
|
||||
case []uint64:
|
||||
for i := range data {
|
||||
data[i] = order.Uint64(bs[8*i:])
|
||||
}
|
||||
case []float32:
|
||||
for i := range data {
|
||||
data[i] = math.Float32frombits(order.Uint32(bs[4*i:]))
|
||||
}
|
||||
case []float64:
|
||||
for i := range data {
|
||||
data[i] = math.Float64frombits(order.Uint64(bs[8*i:]))
|
||||
}
|
||||
default:
|
||||
n = 0 // fast path doesn't apply
|
||||
}
|
||||
if n != 0 {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to reflect-based decoding.
|
||||
v := reflect.ValueOf(data)
|
||||
size := -1
|
||||
switch v.Kind() {
|
||||
case reflect.Pointer:
|
||||
v = v.Elem()
|
||||
size = dataSize(v)
|
||||
case reflect.Slice:
|
||||
size = dataSize(v)
|
||||
}
|
||||
if size < 0 {
|
||||
return errors.New("binary.Read: invalid type " + reflect.TypeOf(data).String())
|
||||
}
|
||||
d := &decoder{order: order, buf: make([]byte, size)}
|
||||
if _, err := io.ReadFull(r, d.buf); err != nil {
|
||||
return err
|
||||
}
|
||||
d.value(v)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Write writes the binary representation of data into w.
|
||||
// Data must be a fixed-size value or a slice of fixed-size
|
||||
// values, or a pointer to such data.
|
||||
// Boolean values encode as one byte: 1 for true, and 0 for false.
|
||||
// Bytes written to w are encoded using the specified byte order
|
||||
// and read from successive fields of the data.
|
||||
// When writing structs, zero values are written for fields
|
||||
// with blank (_) field names.
|
||||
func Write(w io.Writer, order ByteOrder, data any) error {
|
||||
// Fast path for basic types and slices.
|
||||
if n := intDataSize(data); n != 0 {
|
||||
bs := make([]byte, n)
|
||||
switch v := data.(type) {
|
||||
case *bool:
|
||||
if *v {
|
||||
bs[0] = 1
|
||||
} else {
|
||||
bs[0] = 0
|
||||
}
|
||||
case bool:
|
||||
if v {
|
||||
bs[0] = 1
|
||||
} else {
|
||||
bs[0] = 0
|
||||
}
|
||||
case []bool:
|
||||
for i, x := range v {
|
||||
if x {
|
||||
bs[i] = 1
|
||||
} else {
|
||||
bs[i] = 0
|
||||
}
|
||||
}
|
||||
case *int8:
|
||||
bs[0] = byte(*v)
|
||||
case int8:
|
||||
bs[0] = byte(v)
|
||||
case []int8:
|
||||
for i, x := range v {
|
||||
bs[i] = byte(x)
|
||||
}
|
||||
case *uint8:
|
||||
bs[0] = *v
|
||||
case uint8:
|
||||
bs[0] = v
|
||||
case []uint8:
|
||||
bs = v
|
||||
case *int16:
|
||||
order.PutUint16(bs, uint16(*v))
|
||||
case int16:
|
||||
order.PutUint16(bs, uint16(v))
|
||||
case []int16:
|
||||
for i, x := range v {
|
||||
order.PutUint16(bs[2*i:], uint16(x))
|
||||
}
|
||||
case *uint16:
|
||||
order.PutUint16(bs, *v)
|
||||
case uint16:
|
||||
order.PutUint16(bs, v)
|
||||
case []uint16:
|
||||
for i, x := range v {
|
||||
order.PutUint16(bs[2*i:], x)
|
||||
}
|
||||
case *int32:
|
||||
order.PutUint32(bs, uint32(*v))
|
||||
case int32:
|
||||
order.PutUint32(bs, uint32(v))
|
||||
case []int32:
|
||||
for i, x := range v {
|
||||
order.PutUint32(bs[4*i:], uint32(x))
|
||||
}
|
||||
case *uint32:
|
||||
order.PutUint32(bs, *v)
|
||||
case uint32:
|
||||
order.PutUint32(bs, v)
|
||||
case []uint32:
|
||||
for i, x := range v {
|
||||
order.PutUint32(bs[4*i:], x)
|
||||
}
|
||||
case *int64:
|
||||
order.PutUint64(bs, uint64(*v))
|
||||
case int64:
|
||||
order.PutUint64(bs, uint64(v))
|
||||
case []int64:
|
||||
for i, x := range v {
|
||||
order.PutUint64(bs[8*i:], uint64(x))
|
||||
}
|
||||
case *uint64:
|
||||
order.PutUint64(bs, *v)
|
||||
case uint64:
|
||||
order.PutUint64(bs, v)
|
||||
case []uint64:
|
||||
for i, x := range v {
|
||||
order.PutUint64(bs[8*i:], x)
|
||||
}
|
||||
case *float32:
|
||||
order.PutUint32(bs, math.Float32bits(*v))
|
||||
case float32:
|
||||
order.PutUint32(bs, math.Float32bits(v))
|
||||
case []float32:
|
||||
for i, x := range v {
|
||||
order.PutUint32(bs[4*i:], math.Float32bits(x))
|
||||
}
|
||||
case *float64:
|
||||
order.PutUint64(bs, math.Float64bits(*v))
|
||||
case float64:
|
||||
order.PutUint64(bs, math.Float64bits(v))
|
||||
case []float64:
|
||||
for i, x := range v {
|
||||
order.PutUint64(bs[8*i:], math.Float64bits(x))
|
||||
}
|
||||
}
|
||||
_, err := w.Write(bs)
|
||||
return err
|
||||
}
|
||||
|
||||
// Fallback to reflect-based encoding.
|
||||
v := reflect.Indirect(reflect.ValueOf(data))
|
||||
size := dataSize(v)
|
||||
if size < 0 {
|
||||
return errors.New("binary.Write: some values are not fixed-sized in type " + reflect.TypeOf(data).String())
|
||||
}
|
||||
buf := make([]byte, size)
|
||||
e := &encoder{order: order, buf: buf}
|
||||
e.value(v)
|
||||
_, err := w.Write(buf)
|
||||
return err
|
||||
}
|
||||
|
||||
// Size returns how many bytes [Write] would generate to encode the value v, which
|
||||
// must be a fixed-size value or a slice of fixed-size values, or a pointer to such data.
|
||||
// If v is neither of these, Size returns -1.
|
||||
func Size(v any) int {
|
||||
return dataSize(reflect.Indirect(reflect.ValueOf(v)))
|
||||
}
|
||||
|
||||
var structSize sync.Map // map[reflect.Type]int
|
||||
|
||||
// dataSize returns the number of bytes the actual data represented by v occupies in memory.
|
||||
// For compound structures, it sums the sizes of the elements. Thus, for instance, for a slice
|
||||
// it returns the length of the slice times the element size and does not count the memory
|
||||
// occupied by the header. If the type of v is not acceptable, dataSize returns -1.
|
||||
func dataSize(v reflect.Value) int {
|
||||
switch v.Kind() {
|
||||
case reflect.Slice:
|
||||
if s := sizeof(v.Type().Elem()); s >= 0 {
|
||||
return s * v.Len()
|
||||
}
|
||||
|
||||
case reflect.Struct:
|
||||
t := v.Type()
|
||||
if size, ok := structSize.Load(t); ok {
|
||||
return size.(int)
|
||||
}
|
||||
size := sizeof(t)
|
||||
structSize.Store(t, size)
|
||||
return size
|
||||
|
||||
default:
|
||||
if v.IsValid() {
|
||||
return sizeof(v.Type())
|
||||
}
|
||||
}
|
||||
|
||||
return -1
|
||||
}
|
||||
|
||||
// sizeof returns the size >= 0 of variables for the given type or -1 if the type is not acceptable.
|
||||
func sizeof(t reflect.Type) int {
|
||||
switch t.Kind() {
|
||||
case reflect.Array:
|
||||
if s := sizeof(t.Elem()); s >= 0 {
|
||||
return s * t.Len()
|
||||
}
|
||||
|
||||
case reflect.Struct:
|
||||
sum := 0
|
||||
for i, n := 0, t.NumField(); i < n; i++ {
|
||||
s := sizeof(t.Field(i).Type)
|
||||
if s < 0 {
|
||||
return -1
|
||||
}
|
||||
sum += s
|
||||
}
|
||||
return sum
|
||||
|
||||
case reflect.Bool,
|
||||
reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
|
||||
reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
|
||||
reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128:
|
||||
return int(t.Size())
|
||||
}
|
||||
|
||||
return -1
|
||||
}
|
||||
|
||||
type coder struct {
|
||||
order ByteOrder
|
||||
buf []byte
|
||||
offset int
|
||||
}
|
||||
|
||||
type (
|
||||
decoder coder
|
||||
encoder coder
|
||||
)
|
||||
|
||||
func (d *decoder) bool() bool {
|
||||
x := d.buf[d.offset]
|
||||
d.offset++
|
||||
return x != 0
|
||||
}
|
||||
|
||||
func (e *encoder) bool(x bool) {
|
||||
if x {
|
||||
e.buf[e.offset] = 1
|
||||
} else {
|
||||
e.buf[e.offset] = 0
|
||||
}
|
||||
e.offset++
|
||||
}
|
||||
|
||||
func (d *decoder) uint8() uint8 {
|
||||
x := d.buf[d.offset]
|
||||
d.offset++
|
||||
return x
|
||||
}
|
||||
|
||||
func (e *encoder) uint8(x uint8) {
|
||||
e.buf[e.offset] = x
|
||||
e.offset++
|
||||
}
|
||||
|
||||
func (d *decoder) uint16() uint16 {
|
||||
x := d.order.Uint16(d.buf[d.offset : d.offset+2])
|
||||
d.offset += 2
|
||||
return x
|
||||
}
|
||||
|
||||
func (e *encoder) uint16(x uint16) {
|
||||
e.order.PutUint16(e.buf[e.offset:e.offset+2], x)
|
||||
e.offset += 2
|
||||
}
|
||||
|
||||
func (d *decoder) uint32() uint32 {
|
||||
x := d.order.Uint32(d.buf[d.offset : d.offset+4])
|
||||
d.offset += 4
|
||||
return x
|
||||
}
|
||||
|
||||
func (e *encoder) uint32(x uint32) {
|
||||
e.order.PutUint32(e.buf[e.offset:e.offset+4], x)
|
||||
e.offset += 4
|
||||
}
|
||||
|
||||
func (d *decoder) uint64() uint64 {
|
||||
x := d.order.Uint64(d.buf[d.offset : d.offset+8])
|
||||
d.offset += 8
|
||||
return x
|
||||
}
|
||||
|
||||
func (e *encoder) uint64(x uint64) {
|
||||
e.order.PutUint64(e.buf[e.offset:e.offset+8], x)
|
||||
e.offset += 8
|
||||
}
|
||||
|
||||
func (d *decoder) int8() int8 { return int8(d.uint8()) }
|
||||
|
||||
func (e *encoder) int8(x int8) { e.uint8(uint8(x)) }
|
||||
|
||||
func (d *decoder) int16() int16 { return int16(d.uint16()) }
|
||||
|
||||
func (e *encoder) int16(x int16) { e.uint16(uint16(x)) }
|
||||
|
||||
func (d *decoder) int32() int32 { return int32(d.uint32()) }
|
||||
|
||||
func (e *encoder) int32(x int32) { e.uint32(uint32(x)) }
|
||||
|
||||
func (d *decoder) int64() int64 { return int64(d.uint64()) }
|
||||
|
||||
func (e *encoder) int64(x int64) { e.uint64(uint64(x)) }
|
||||
|
||||
func (d *decoder) value(v reflect.Value) {
|
||||
switch v.Kind() {
|
||||
case reflect.Array:
|
||||
l := v.Len()
|
||||
for i := 0; i < l; i++ {
|
||||
d.value(v.Index(i))
|
||||
}
|
||||
|
||||
case reflect.Struct:
|
||||
t := v.Type()
|
||||
l := v.NumField()
|
||||
for i := 0; i < l; i++ {
|
||||
// Note: Calling v.CanSet() below is an optimization.
|
||||
// It would be sufficient to check the field name,
|
||||
// but creating the StructField info for each field is
|
||||
// costly (run "go test -bench=ReadStruct" and compare
|
||||
// results when making changes to this code).
|
||||
if v := v.Field(i); v.CanSet() || t.Field(i).Name != "_" {
|
||||
d.value(v)
|
||||
} else {
|
||||
d.skip(v)
|
||||
}
|
||||
}
|
||||
|
||||
case reflect.Slice:
|
||||
l := v.Len()
|
||||
for i := 0; i < l; i++ {
|
||||
d.value(v.Index(i))
|
||||
}
|
||||
|
||||
case reflect.Bool:
|
||||
v.SetBool(d.bool())
|
||||
|
||||
case reflect.Int8:
|
||||
v.SetInt(int64(d.int8()))
|
||||
case reflect.Int16:
|
||||
v.SetInt(int64(d.int16()))
|
||||
case reflect.Int32:
|
||||
v.SetInt(int64(d.int32()))
|
||||
case reflect.Int64:
|
||||
v.SetInt(d.int64())
|
||||
|
||||
case reflect.Uint8:
|
||||
v.SetUint(uint64(d.uint8()))
|
||||
case reflect.Uint16:
|
||||
v.SetUint(uint64(d.uint16()))
|
||||
case reflect.Uint32:
|
||||
v.SetUint(uint64(d.uint32()))
|
||||
case reflect.Uint64:
|
||||
v.SetUint(d.uint64())
|
||||
|
||||
case reflect.Float32:
|
||||
v.SetFloat(float64(math.Float32frombits(d.uint32())))
|
||||
case reflect.Float64:
|
||||
v.SetFloat(math.Float64frombits(d.uint64()))
|
||||
|
||||
case reflect.Complex64:
|
||||
v.SetComplex(complex(
|
||||
float64(math.Float32frombits(d.uint32())),
|
||||
float64(math.Float32frombits(d.uint32())),
|
||||
))
|
||||
case reflect.Complex128:
|
||||
v.SetComplex(complex(
|
||||
math.Float64frombits(d.uint64()),
|
||||
math.Float64frombits(d.uint64()),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
func (e *encoder) value(v reflect.Value) {
|
||||
switch v.Kind() {
|
||||
case reflect.Array:
|
||||
l := v.Len()
|
||||
for i := 0; i < l; i++ {
|
||||
e.value(v.Index(i))
|
||||
}
|
||||
|
||||
case reflect.Struct:
|
||||
t := v.Type()
|
||||
l := v.NumField()
|
||||
for i := 0; i < l; i++ {
|
||||
// see comment for corresponding code in decoder.value()
|
||||
if v := v.Field(i); v.CanSet() || t.Field(i).Name != "_" {
|
||||
e.value(v)
|
||||
} else {
|
||||
e.skip(v)
|
||||
}
|
||||
}
|
||||
|
||||
case reflect.Slice:
|
||||
l := v.Len()
|
||||
for i := 0; i < l; i++ {
|
||||
e.value(v.Index(i))
|
||||
}
|
||||
|
||||
case reflect.Bool:
|
||||
e.bool(v.Bool())
|
||||
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
switch v.Type().Kind() {
|
||||
case reflect.Int8:
|
||||
e.int8(int8(v.Int()))
|
||||
case reflect.Int16:
|
||||
e.int16(int16(v.Int()))
|
||||
case reflect.Int32:
|
||||
e.int32(int32(v.Int()))
|
||||
case reflect.Int64:
|
||||
e.int64(v.Int())
|
||||
}
|
||||
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
|
||||
switch v.Type().Kind() {
|
||||
case reflect.Uint8:
|
||||
e.uint8(uint8(v.Uint()))
|
||||
case reflect.Uint16:
|
||||
e.uint16(uint16(v.Uint()))
|
||||
case reflect.Uint32:
|
||||
e.uint32(uint32(v.Uint()))
|
||||
case reflect.Uint64:
|
||||
e.uint64(v.Uint())
|
||||
}
|
||||
|
||||
case reflect.Float32, reflect.Float64:
|
||||
switch v.Type().Kind() {
|
||||
case reflect.Float32:
|
||||
e.uint32(math.Float32bits(float32(v.Float())))
|
||||
case reflect.Float64:
|
||||
e.uint64(math.Float64bits(v.Float()))
|
||||
}
|
||||
|
||||
case reflect.Complex64, reflect.Complex128:
|
||||
switch v.Type().Kind() {
|
||||
case reflect.Complex64:
|
||||
x := v.Complex()
|
||||
e.uint32(math.Float32bits(float32(real(x))))
|
||||
e.uint32(math.Float32bits(float32(imag(x))))
|
||||
case reflect.Complex128:
|
||||
x := v.Complex()
|
||||
e.uint64(math.Float64bits(real(x)))
|
||||
e.uint64(math.Float64bits(imag(x)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (d *decoder) skip(v reflect.Value) {
|
||||
d.offset += dataSize(v)
|
||||
}
|
||||
|
||||
func (e *encoder) skip(v reflect.Value) {
|
||||
n := dataSize(v)
|
||||
zero := e.buf[e.offset : e.offset+n]
|
||||
for i := range zero {
|
||||
zero[i] = 0
|
||||
}
|
||||
e.offset += n
|
||||
}
|
||||
|
||||
// intDataSize returns the size of the data required to represent the data when encoded.
|
||||
// It returns zero if the type cannot be implemented by the fast path in Read or Write.
|
||||
func intDataSize(data any) int {
|
||||
switch data := data.(type) {
|
||||
case bool, int8, uint8, *bool, *int8, *uint8:
|
||||
return 1
|
||||
case []bool:
|
||||
return len(data)
|
||||
case []int8:
|
||||
return len(data)
|
||||
case []uint8:
|
||||
return len(data)
|
||||
case int16, uint16, *int16, *uint16:
|
||||
return 2
|
||||
case []int16:
|
||||
return 2 * len(data)
|
||||
case []uint16:
|
||||
return 2 * len(data)
|
||||
case int32, uint32, *int32, *uint32:
|
||||
return 4
|
||||
case []int32:
|
||||
return 4 * len(data)
|
||||
case []uint32:
|
||||
return 4 * len(data)
|
||||
case int64, uint64, *int64, *uint64:
|
||||
return 8
|
||||
case []int64:
|
||||
return 8 * len(data)
|
||||
case []uint64:
|
||||
return 8 * len(data)
|
||||
case float32, *float32:
|
||||
return 4
|
||||
case float64, *float64:
|
||||
return 8
|
||||
case []float32:
|
||||
return 4 * len(data)
|
||||
case []float64:
|
||||
return 8 * len(data)
|
||||
}
|
||||
return 0
|
||||
}
|
18
common/binary/export.go
Normal file
18
common/binary/export.go
Normal file
|
@ -0,0 +1,18 @@
|
|||
package binary
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
func DataSize(t reflect.Value) int {
|
||||
return dataSize(t)
|
||||
}
|
||||
|
||||
func EncodeValue(order binary.ByteOrder, buf []byte, v reflect.Value) {
|
||||
(&encoder{order: order, buf: buf}).value(v)
|
||||
}
|
||||
|
||||
func DecodeValue(order binary.ByteOrder, buf []byte, v reflect.Value) {
|
||||
(&decoder{order: order, buf: buf}).value(v)
|
||||
}
|
14
common/binary/native_endian_big.go
Normal file
14
common/binary/native_endian_big.go
Normal file
|
@ -0,0 +1,14 @@
|
|||
// Copyright 2023 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
//go:build armbe || arm64be || m68k || mips || mips64 || mips64p32 || ppc || ppc64 || s390 || s390x || shbe || sparc || sparc64
|
||||
|
||||
package binary
|
||||
|
||||
type nativeEndian struct {
|
||||
bigEndian
|
||||
}
|
||||
|
||||
// NativeEndian is the native-endian implementation of [ByteOrder] and [AppendByteOrder].
|
||||
var NativeEndian nativeEndian
|
14
common/binary/native_endian_little.go
Normal file
14
common/binary/native_endian_little.go
Normal file
|
@ -0,0 +1,14 @@
|
|||
// Copyright 2023 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
//go:build 386 || amd64 || amd64p32 || alpha || arm || arm64 || loong64 || mipsle || mips64le || mips64p32le || nios2 || ppc64le || riscv || riscv64 || sh || wasm
|
||||
|
||||
package binary
|
||||
|
||||
type nativeEndian struct {
|
||||
littleEndian
|
||||
}
|
||||
|
||||
// NativeEndian is the native-endian implementation of [ByteOrder] and [AppendByteOrder].
|
||||
var NativeEndian nativeEndian
|
166
common/binary/varint.go
Normal file
166
common/binary/varint.go
Normal file
|
@ -0,0 +1,166 @@
|
|||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package binary
|
||||
|
||||
// This file implements "varint" encoding of 64-bit integers.
|
||||
// The encoding is:
|
||||
// - unsigned integers are serialized 7 bits at a time, starting with the
|
||||
// least significant bits
|
||||
// - the most significant bit (msb) in each output byte indicates if there
|
||||
// is a continuation byte (msb = 1)
|
||||
// - signed integers are mapped to unsigned integers using "zig-zag"
|
||||
// encoding: Positive values x are written as 2*x + 0, negative values
|
||||
// are written as 2*(^x) + 1; that is, negative numbers are complemented
|
||||
// and whether to complement is encoded in bit 0.
|
||||
//
|
||||
// Design note:
|
||||
// At most 10 bytes are needed for 64-bit values. The encoding could
|
||||
// be more dense: a full 64-bit value needs an extra byte just to hold bit 63.
|
||||
// Instead, the msb of the previous byte could be used to hold bit 63 since we
|
||||
// know there can't be more than 64 bits. This is a trivial improvement and
|
||||
// would reduce the maximum encoding length to 9 bytes. However, it breaks the
|
||||
// invariant that the msb is always the "continuation bit" and thus makes the
|
||||
// format incompatible with a varint encoding for larger numbers (say 128-bit).
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
)
|
||||
|
||||
// MaxVarintLenN is the maximum length of a varint-encoded N-bit integer.
|
||||
const (
|
||||
MaxVarintLen16 = 3
|
||||
MaxVarintLen32 = 5
|
||||
MaxVarintLen64 = 10
|
||||
)
|
||||
|
||||
// AppendUvarint appends the varint-encoded form of x,
|
||||
// as generated by [PutUvarint], to buf and returns the extended buffer.
|
||||
func AppendUvarint(buf []byte, x uint64) []byte {
|
||||
for x >= 0x80 {
|
||||
buf = append(buf, byte(x)|0x80)
|
||||
x >>= 7
|
||||
}
|
||||
return append(buf, byte(x))
|
||||
}
|
||||
|
||||
// PutUvarint encodes a uint64 into buf and returns the number of bytes written.
|
||||
// If the buffer is too small, PutUvarint will panic.
|
||||
func PutUvarint(buf []byte, x uint64) int {
|
||||
i := 0
|
||||
for x >= 0x80 {
|
||||
buf[i] = byte(x) | 0x80
|
||||
x >>= 7
|
||||
i++
|
||||
}
|
||||
buf[i] = byte(x)
|
||||
return i + 1
|
||||
}
|
||||
|
||||
// Uvarint decodes a uint64 from buf and returns that value and the
|
||||
// number of bytes read (> 0). If an error occurred, the value is 0
|
||||
// and the number of bytes n is <= 0 meaning:
|
||||
//
|
||||
// n == 0: buf too small
|
||||
// n < 0: value larger than 64 bits (overflow)
|
||||
// and -n is the number of bytes read
|
||||
func Uvarint(buf []byte) (uint64, int) {
|
||||
var x uint64
|
||||
var s uint
|
||||
for i, b := range buf {
|
||||
if i == MaxVarintLen64 {
|
||||
// Catch byte reads past MaxVarintLen64.
|
||||
// See issue https://golang.org/issues/41185
|
||||
return 0, -(i + 1) // overflow
|
||||
}
|
||||
if b < 0x80 {
|
||||
if i == MaxVarintLen64-1 && b > 1 {
|
||||
return 0, -(i + 1) // overflow
|
||||
}
|
||||
return x | uint64(b)<<s, i + 1
|
||||
}
|
||||
x |= uint64(b&0x7f) << s
|
||||
s += 7
|
||||
}
|
||||
return 0, 0
|
||||
}
|
||||
|
||||
// AppendVarint appends the varint-encoded form of x,
|
||||
// as generated by [PutVarint], to buf and returns the extended buffer.
|
||||
func AppendVarint(buf []byte, x int64) []byte {
|
||||
ux := uint64(x) << 1
|
||||
if x < 0 {
|
||||
ux = ^ux
|
||||
}
|
||||
return AppendUvarint(buf, ux)
|
||||
}
|
||||
|
||||
// PutVarint encodes an int64 into buf and returns the number of bytes written.
|
||||
// If the buffer is too small, PutVarint will panic.
|
||||
func PutVarint(buf []byte, x int64) int {
|
||||
ux := uint64(x) << 1
|
||||
if x < 0 {
|
||||
ux = ^ux
|
||||
}
|
||||
return PutUvarint(buf, ux)
|
||||
}
|
||||
|
||||
// Varint decodes an int64 from buf and returns that value and the
|
||||
// number of bytes read (> 0). If an error occurred, the value is 0
|
||||
// and the number of bytes n is <= 0 with the following meaning:
|
||||
//
|
||||
// n == 0: buf too small
|
||||
// n < 0: value larger than 64 bits (overflow)
|
||||
// and -n is the number of bytes read
|
||||
func Varint(buf []byte) (int64, int) {
|
||||
ux, n := Uvarint(buf) // ok to continue in presence of error
|
||||
x := int64(ux >> 1)
|
||||
if ux&1 != 0 {
|
||||
x = ^x
|
||||
}
|
||||
return x, n
|
||||
}
|
||||
|
||||
var errOverflow = errors.New("binary: varint overflows a 64-bit integer")
|
||||
|
||||
// ReadUvarint reads an encoded unsigned integer from r and returns it as a uint64.
|
||||
// The error is [io.EOF] only if no bytes were read.
|
||||
// If an [io.EOF] happens after reading some but not all the bytes,
|
||||
// ReadUvarint returns [io.ErrUnexpectedEOF].
|
||||
func ReadUvarint(r io.ByteReader) (uint64, error) {
|
||||
var x uint64
|
||||
var s uint
|
||||
for i := 0; i < MaxVarintLen64; i++ {
|
||||
b, err := r.ReadByte()
|
||||
if err != nil {
|
||||
if i > 0 && err == io.EOF {
|
||||
err = io.ErrUnexpectedEOF
|
||||
}
|
||||
return x, err
|
||||
}
|
||||
if b < 0x80 {
|
||||
if i == MaxVarintLen64-1 && b > 1 {
|
||||
return x, errOverflow
|
||||
}
|
||||
return x | uint64(b)<<s, nil
|
||||
}
|
||||
x |= uint64(b&0x7f) << s
|
||||
s += 7
|
||||
}
|
||||
return x, errOverflow
|
||||
}
|
||||
|
||||
// ReadVarint reads an encoded signed integer from r and returns it as an int64.
|
||||
// The error is [io.EOF] only if no bytes were read.
|
||||
// If an [io.EOF] happens after reading some but not all the bytes,
|
||||
// ReadVarint returns [io.ErrUnexpectedEOF].
|
||||
func ReadVarint(r io.ByteReader) (int64, error) {
|
||||
ux, err := ReadUvarint(r) // ok to continue in presence of error
|
||||
x := int64(ux >> 1)
|
||||
if ux&1 != 0 {
|
||||
x = ^x
|
||||
}
|
||||
return x, err
|
||||
}
|
|
@ -8,7 +8,7 @@ import (
|
|||
"sync"
|
||||
)
|
||||
|
||||
var DefaultAllocator = newDefaultAllocer()
|
||||
var DefaultAllocator = newDefaultAllocator()
|
||||
|
||||
type Allocator interface {
|
||||
Get(size int) []byte
|
||||
|
@ -17,22 +17,28 @@ type Allocator interface {
|
|||
|
||||
// defaultAllocator for incoming frames, optimized to prevent overwriting after zeroing
|
||||
type defaultAllocator struct {
|
||||
buffers []sync.Pool
|
||||
buffers [11]sync.Pool
|
||||
}
|
||||
|
||||
// NewAllocator initiates a []byte allocator for frames less than 65536 bytes,
|
||||
// the waste(memory fragmentation) of space allocation is guaranteed to be
|
||||
// no more than 50%.
|
||||
func newDefaultAllocer() Allocator {
|
||||
alloc := new(defaultAllocator)
|
||||
alloc.buffers = make([]sync.Pool, 17) // 1B -> 64K
|
||||
for k := range alloc.buffers {
|
||||
i := k
|
||||
alloc.buffers[k].New = func() any {
|
||||
return make([]byte, 1<<uint32(i))
|
||||
}
|
||||
func newDefaultAllocator() Allocator {
|
||||
return &defaultAllocator{
|
||||
buffers: [...]sync.Pool{ // 64B -> 64K
|
||||
{New: func() any { return new([1 << 6]byte) }},
|
||||
{New: func() any { return new([1 << 7]byte) }},
|
||||
{New: func() any { return new([1 << 8]byte) }},
|
||||
{New: func() any { return new([1 << 9]byte) }},
|
||||
{New: func() any { return new([1 << 10]byte) }},
|
||||
{New: func() any { return new([1 << 11]byte) }},
|
||||
{New: func() any { return new([1 << 12]byte) }},
|
||||
{New: func() any { return new([1 << 13]byte) }},
|
||||
{New: func() any { return new([1 << 14]byte) }},
|
||||
{New: func() any { return new([1 << 15]byte) }},
|
||||
{New: func() any { return new([1 << 16]byte) }},
|
||||
},
|
||||
}
|
||||
return alloc
|
||||
}
|
||||
|
||||
// Get a []byte from pool with most appropriate cap
|
||||
|
@ -41,12 +47,42 @@ func (alloc *defaultAllocator) Get(size int) []byte {
|
|||
return nil
|
||||
}
|
||||
|
||||
bits := msb(size)
|
||||
if size == 1<<bits {
|
||||
return alloc.buffers[bits].Get().([]byte)[:size]
|
||||
var index uint16
|
||||
if size > 64 {
|
||||
index = msb(size)
|
||||
if size != 1<<index {
|
||||
index += 1
|
||||
}
|
||||
index -= 6
|
||||
}
|
||||
|
||||
return alloc.buffers[bits+1].Get().([]byte)[:size]
|
||||
buffer := alloc.buffers[index].Get()
|
||||
switch index {
|
||||
case 0:
|
||||
return buffer.(*[1 << 6]byte)[:size]
|
||||
case 1:
|
||||
return buffer.(*[1 << 7]byte)[:size]
|
||||
case 2:
|
||||
return buffer.(*[1 << 8]byte)[:size]
|
||||
case 3:
|
||||
return buffer.(*[1 << 9]byte)[:size]
|
||||
case 4:
|
||||
return buffer.(*[1 << 10]byte)[:size]
|
||||
case 5:
|
||||
return buffer.(*[1 << 11]byte)[:size]
|
||||
case 6:
|
||||
return buffer.(*[1 << 12]byte)[:size]
|
||||
case 7:
|
||||
return buffer.(*[1 << 13]byte)[:size]
|
||||
case 8:
|
||||
return buffer.(*[1 << 14]byte)[:size]
|
||||
case 9:
|
||||
return buffer.(*[1 << 15]byte)[:size]
|
||||
case 10:
|
||||
return buffer.(*[1 << 16]byte)[:size]
|
||||
default:
|
||||
panic("invalid pool index")
|
||||
}
|
||||
}
|
||||
|
||||
// Put returns a []byte to pool for future use,
|
||||
|
@ -56,10 +92,37 @@ func (alloc *defaultAllocator) Put(buf []byte) error {
|
|||
if cap(buf) == 0 || cap(buf) > 65536 || cap(buf) != 1<<bits {
|
||||
return errors.New("allocator Put() incorrect buffer size")
|
||||
}
|
||||
bits -= 6
|
||||
buf = buf[:cap(buf)]
|
||||
|
||||
//nolint
|
||||
//lint:ignore SA6002 ignore temporarily
|
||||
alloc.buffers[bits].Put(buf)
|
||||
switch bits {
|
||||
case 0:
|
||||
alloc.buffers[bits].Put((*[1 << 6]byte)(buf))
|
||||
case 1:
|
||||
alloc.buffers[bits].Put((*[1 << 7]byte)(buf))
|
||||
case 2:
|
||||
alloc.buffers[bits].Put((*[1 << 8]byte)(buf))
|
||||
case 3:
|
||||
alloc.buffers[bits].Put((*[1 << 9]byte)(buf))
|
||||
case 4:
|
||||
alloc.buffers[bits].Put((*[1 << 10]byte)(buf))
|
||||
case 5:
|
||||
alloc.buffers[bits].Put((*[1 << 11]byte)(buf))
|
||||
case 6:
|
||||
alloc.buffers[bits].Put((*[1 << 12]byte)(buf))
|
||||
case 7:
|
||||
alloc.buffers[bits].Put((*[1 << 13]byte)(buf))
|
||||
case 8:
|
||||
alloc.buffers[bits].Put((*[1 << 14]byte)(buf))
|
||||
case 9:
|
||||
alloc.buffers[bits].Put((*[1 << 15]byte)(buf))
|
||||
case 10:
|
||||
alloc.buffers[bits].Put((*[1 << 16]byte)(buf))
|
||||
default:
|
||||
panic("invalid pool index")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -4,39 +4,36 @@ import (
|
|||
"crypto/rand"
|
||||
"io"
|
||||
"net"
|
||||
"strconv"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/atomic"
|
||||
"github.com/sagernet/sing/common/debug"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
F "github.com/sagernet/sing/common/format"
|
||||
)
|
||||
|
||||
const ReversedHeader = 1024
|
||||
|
||||
type Buffer struct {
|
||||
data []byte
|
||||
start int
|
||||
end int
|
||||
refs int32
|
||||
managed bool
|
||||
closed bool
|
||||
data []byte
|
||||
start int
|
||||
end int
|
||||
capacity int
|
||||
refs atomic.Int32
|
||||
managed bool
|
||||
}
|
||||
|
||||
func New() *Buffer {
|
||||
return &Buffer{
|
||||
data: Get(BufferSize),
|
||||
start: ReversedHeader,
|
||||
end: ReversedHeader,
|
||||
managed: true,
|
||||
data: Get(BufferSize),
|
||||
capacity: BufferSize,
|
||||
managed: true,
|
||||
}
|
||||
}
|
||||
|
||||
func NewPacket() *Buffer {
|
||||
return &Buffer{
|
||||
data: Get(UDPBufferSize),
|
||||
start: ReversedHeader,
|
||||
end: ReversedHeader,
|
||||
managed: true,
|
||||
data: Get(UDPBufferSize),
|
||||
capacity: UDPBufferSize,
|
||||
managed: true,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -45,40 +42,29 @@ func NewSize(size int) *Buffer {
|
|||
return &Buffer{}
|
||||
} else if size > 65535 {
|
||||
return &Buffer{
|
||||
data: make([]byte, size),
|
||||
data: make([]byte, size),
|
||||
capacity: size,
|
||||
}
|
||||
}
|
||||
return &Buffer{
|
||||
data: Get(size),
|
||||
managed: true,
|
||||
data: Get(size),
|
||||
capacity: size,
|
||||
managed: true,
|
||||
}
|
||||
}
|
||||
|
||||
// Deprecated: use New instead.
|
||||
func StackNew() *Buffer {
|
||||
return New()
|
||||
}
|
||||
|
||||
// Deprecated: use NewPacket instead.
|
||||
func StackNewPacket() *Buffer {
|
||||
return NewPacket()
|
||||
}
|
||||
|
||||
// Deprecated: use NewSize instead.
|
||||
func StackNewSize(size int) *Buffer {
|
||||
return NewSize(size)
|
||||
}
|
||||
|
||||
func As(data []byte) *Buffer {
|
||||
return &Buffer{
|
||||
data: data,
|
||||
end: len(data),
|
||||
data: data,
|
||||
end: len(data),
|
||||
capacity: len(data),
|
||||
}
|
||||
}
|
||||
|
||||
func With(data []byte) *Buffer {
|
||||
return &Buffer{
|
||||
data: data,
|
||||
data: data,
|
||||
capacity: len(data),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -92,8 +78,8 @@ func (b *Buffer) SetByte(index int, value byte) {
|
|||
|
||||
func (b *Buffer) Extend(n int) []byte {
|
||||
end := b.end + n
|
||||
if end > cap(b.data) {
|
||||
panic("buffer overflow: cap " + strconv.Itoa(cap(b.data)) + ",end " + strconv.Itoa(b.end) + ", need " + strconv.Itoa(n))
|
||||
if end > b.capacity {
|
||||
panic(F.ToString("buffer overflow: capacity ", b.capacity, ",end ", b.end, ", need ", n))
|
||||
}
|
||||
ext := b.data[b.end:end]
|
||||
b.end = end
|
||||
|
@ -115,14 +101,14 @@ func (b *Buffer) Write(data []byte) (n int, err error) {
|
|||
if b.IsFull() {
|
||||
return 0, io.ErrShortBuffer
|
||||
}
|
||||
n = copy(b.data[b.end:], data)
|
||||
n = copy(b.data[b.end:b.capacity], data)
|
||||
b.end += n
|
||||
return
|
||||
}
|
||||
|
||||
func (b *Buffer) ExtendHeader(n int) []byte {
|
||||
if b.start < n {
|
||||
panic("buffer overflow: cap " + strconv.Itoa(cap(b.data)) + ",start " + strconv.Itoa(b.start) + ", need " + strconv.Itoa(n))
|
||||
panic(F.ToString("buffer overflow: capacity ", b.capacity, ",start ", b.start, ", need ", n))
|
||||
}
|
||||
b.start -= n
|
||||
return b.data[b.start : b.start+n]
|
||||
|
@ -175,7 +161,7 @@ func (b *Buffer) ReadAtLeastFrom(r io.Reader, min int) (int64, error) {
|
|||
}
|
||||
|
||||
func (b *Buffer) ReadFullFrom(r io.Reader, size int) (n int, err error) {
|
||||
if b.end+size > b.Cap() {
|
||||
if b.end+size > b.capacity {
|
||||
return 0, io.ErrShortBuffer
|
||||
}
|
||||
n, err = io.ReadFull(r, b.data[b.end:b.end+size])
|
||||
|
@ -212,7 +198,7 @@ func (b *Buffer) WriteString(s string) (n int, err error) {
|
|||
if b.IsFull() {
|
||||
return 0, io.ErrShortBuffer
|
||||
}
|
||||
n = copy(b.data[b.end:], s)
|
||||
n = copy(b.data[b.end:b.capacity], s)
|
||||
b.end += n
|
||||
return
|
||||
}
|
||||
|
@ -227,13 +213,10 @@ func (b *Buffer) WriteZero() error {
|
|||
}
|
||||
|
||||
func (b *Buffer) WriteZeroN(n int) error {
|
||||
if b.end+n > b.Cap() {
|
||||
if b.end+n > b.capacity {
|
||||
return io.ErrShortBuffer
|
||||
}
|
||||
for i := b.end; i < b.end+n; i++ {
|
||||
b.data[i] = 0
|
||||
}
|
||||
b.end += n
|
||||
common.ClearArray(b.Extend(n))
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -276,40 +259,63 @@ func (b *Buffer) Resize(start, end int) {
|
|||
b.end = b.start + end
|
||||
}
|
||||
|
||||
func (b *Buffer) Reset() {
|
||||
b.start = ReversedHeader
|
||||
b.end = ReversedHeader
|
||||
func (b *Buffer) Reserve(n int) {
|
||||
if n > b.capacity {
|
||||
panic(F.ToString("buffer overflow: capacity ", b.capacity, ", need ", n))
|
||||
}
|
||||
b.capacity -= n
|
||||
}
|
||||
|
||||
func (b *Buffer) FullReset() {
|
||||
func (b *Buffer) OverCap(n int) {
|
||||
if b.capacity+n > len(b.data) {
|
||||
panic(F.ToString("buffer overflow: capacity ", len(b.data), ", need ", b.capacity+n))
|
||||
}
|
||||
b.capacity += n
|
||||
}
|
||||
|
||||
func (b *Buffer) Reset() {
|
||||
b.start = 0
|
||||
b.end = 0
|
||||
b.capacity = len(b.data)
|
||||
}
|
||||
|
||||
// Deprecated: use Reset instead.
|
||||
func (b *Buffer) FullReset() {
|
||||
b.Reset()
|
||||
}
|
||||
|
||||
func (b *Buffer) IncRef() {
|
||||
atomic.AddInt32(&b.refs, 1)
|
||||
b.refs.Add(1)
|
||||
}
|
||||
|
||||
func (b *Buffer) DecRef() {
|
||||
atomic.AddInt32(&b.refs, -1)
|
||||
b.refs.Add(-1)
|
||||
}
|
||||
|
||||
func (b *Buffer) Release() {
|
||||
if b == nil || b.closed || !b.managed {
|
||||
if b == nil || !b.managed {
|
||||
return
|
||||
}
|
||||
if atomic.LoadInt32(&b.refs) > 0 {
|
||||
if b.refs.Load() > 0 {
|
||||
return
|
||||
}
|
||||
common.Must(Put(b.data))
|
||||
*b = Buffer{closed: true}
|
||||
*b = Buffer{}
|
||||
}
|
||||
|
||||
func (b *Buffer) Cut(start int, end int) *Buffer {
|
||||
b.start += start
|
||||
b.end = len(b.data) - end
|
||||
return &Buffer{
|
||||
data: b.data[b.start:b.end],
|
||||
func (b *Buffer) Leak() {
|
||||
if debug.Enabled {
|
||||
if b == nil || !b.managed {
|
||||
return
|
||||
}
|
||||
refs := b.refs.Load()
|
||||
if refs == 0 {
|
||||
panic("leaking buffer")
|
||||
} else {
|
||||
panic(F.ToString("leaking buffer with ", refs, " references"))
|
||||
}
|
||||
} else {
|
||||
b.Release()
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -322,6 +328,10 @@ func (b *Buffer) Len() int {
|
|||
}
|
||||
|
||||
func (b *Buffer) Cap() int {
|
||||
return b.capacity
|
||||
}
|
||||
|
||||
func (b *Buffer) RawCap() int {
|
||||
return len(b.data)
|
||||
}
|
||||
|
||||
|
@ -329,10 +339,6 @@ func (b *Buffer) Bytes() []byte {
|
|||
return b.data[b.start:b.end]
|
||||
}
|
||||
|
||||
func (b *Buffer) Slice() []byte {
|
||||
return b.data
|
||||
}
|
||||
|
||||
func (b *Buffer) From(n int) []byte {
|
||||
return b.data[b.start+n : b.end]
|
||||
}
|
||||
|
@ -350,11 +356,11 @@ func (b *Buffer) Index(start int) []byte {
|
|||
}
|
||||
|
||||
func (b *Buffer) FreeLen() int {
|
||||
return b.Cap() - b.end
|
||||
return b.capacity - b.end
|
||||
}
|
||||
|
||||
func (b *Buffer) FreeBytes() []byte {
|
||||
return b.data[b.end:b.Cap()]
|
||||
return b.data[b.end:b.capacity]
|
||||
}
|
||||
|
||||
func (b *Buffer) IsEmpty() bool {
|
||||
|
@ -362,7 +368,7 @@ func (b *Buffer) IsEmpty() bool {
|
|||
}
|
||||
|
||||
func (b *Buffer) IsFull() bool {
|
||||
return b.end == b.Cap()
|
||||
return b.end == b.capacity
|
||||
}
|
||||
|
||||
func (b *Buffer) ToOwned() *Buffer {
|
||||
|
@ -370,5 +376,6 @@ func (b *Buffer) ToOwned() *Buffer {
|
|||
copy(n.data[b.start:b.end], b.data[b.start:b.end])
|
||||
n.start = b.start
|
||||
n.end = b.end
|
||||
n.capacity = b.capacity
|
||||
return n
|
||||
}
|
||||
|
|
34
common/bufio/addr_bsd.go
Normal file
34
common/bufio/addr_bsd.go
Normal file
|
@ -0,0 +1,34 @@
|
|||
//go:build darwin || dragonfly || freebsd || netbsd || openbsd
|
||||
|
||||
package bufio
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"net/netip"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func ToSockaddr(destination netip.AddrPort) (name unsafe.Pointer, nameLen uint32) {
|
||||
if destination.Addr().Is4() {
|
||||
sa := unix.RawSockaddrInet4{
|
||||
Len: unix.SizeofSockaddrInet4,
|
||||
Family: unix.AF_INET,
|
||||
Addr: destination.Addr().As4(),
|
||||
}
|
||||
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], destination.Port())
|
||||
name = unsafe.Pointer(&sa)
|
||||
nameLen = unix.SizeofSockaddrInet4
|
||||
} else {
|
||||
sa := unix.RawSockaddrInet6{
|
||||
Len: unix.SizeofSockaddrInet6,
|
||||
Family: unix.AF_INET6,
|
||||
Addr: destination.Addr().As16(),
|
||||
}
|
||||
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], destination.Port())
|
||||
name = unsafe.Pointer(&sa)
|
||||
nameLen = unix.SizeofSockaddrInet6
|
||||
}
|
||||
return
|
||||
}
|
|
@ -9,19 +9,20 @@ import (
|
|||
|
||||
type AddrConn struct {
|
||||
net.Conn
|
||||
M.Metadata
|
||||
Source M.Socksaddr
|
||||
Destination M.Socksaddr
|
||||
}
|
||||
|
||||
func (c *AddrConn) LocalAddr() net.Addr {
|
||||
if c.Metadata.Destination.IsValid() {
|
||||
return c.Metadata.Destination.TCPAddr()
|
||||
if c.Destination.IsValid() {
|
||||
return c.Destination.TCPAddr()
|
||||
}
|
||||
return c.Conn.LocalAddr()
|
||||
}
|
||||
|
||||
func (c *AddrConn) RemoteAddr() net.Addr {
|
||||
if c.Metadata.Source.IsValid() {
|
||||
return c.Metadata.Source.TCPAddr()
|
||||
if c.Source.IsValid() {
|
||||
return c.Source.TCPAddr()
|
||||
}
|
||||
return c.Conn.RemoteAddr()
|
||||
}
|
30
common/bufio/addr_linux.go
Normal file
30
common/bufio/addr_linux.go
Normal file
|
@ -0,0 +1,30 @@
|
|||
package bufio
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"net/netip"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func ToSockaddr(destination netip.AddrPort) (name unsafe.Pointer, nameLen uint32) {
|
||||
if destination.Addr().Is4() {
|
||||
sa := unix.RawSockaddrInet4{
|
||||
Family: unix.AF_INET,
|
||||
Addr: destination.Addr().As4(),
|
||||
}
|
||||
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], destination.Port())
|
||||
name = unsafe.Pointer(&sa)
|
||||
nameLen = unix.SizeofSockaddrInet4
|
||||
} else {
|
||||
sa := unix.RawSockaddrInet6{
|
||||
Family: unix.AF_INET6,
|
||||
Addr: destination.Addr().As16(),
|
||||
}
|
||||
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], destination.Port())
|
||||
name = unsafe.Pointer(&sa)
|
||||
nameLen = unix.SizeofSockaddrInet6
|
||||
}
|
||||
return
|
||||
}
|
30
common/bufio/addr_windows.go
Normal file
30
common/bufio/addr_windows.go
Normal file
|
@ -0,0 +1,30 @@
|
|||
package bufio
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"net/netip"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
func ToSockaddr(destination netip.AddrPort) (name unsafe.Pointer, nameLen int32) {
|
||||
if destination.Addr().Is4() {
|
||||
sa := windows.RawSockaddrInet4{
|
||||
Family: windows.AF_INET,
|
||||
Addr: destination.Addr().As4(),
|
||||
}
|
||||
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], destination.Port())
|
||||
name = unsafe.Pointer(&sa)
|
||||
nameLen = int32(unsafe.Sizeof(sa))
|
||||
} else {
|
||||
sa := windows.RawSockaddrInet6{
|
||||
Family: windows.AF_INET6,
|
||||
Addr: destination.Addr().As16(),
|
||||
}
|
||||
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], destination.Port())
|
||||
name = unsafe.Pointer(&sa)
|
||||
nameLen = int32(unsafe.Sizeof(sa))
|
||||
}
|
||||
return
|
||||
}
|
|
@ -8,51 +8,76 @@ import (
|
|||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
type BindPacketConn struct {
|
||||
type BindPacketConn interface {
|
||||
N.NetPacketConn
|
||||
Addr net.Addr
|
||||
net.Conn
|
||||
}
|
||||
|
||||
func NewBindPacketConn(conn net.PacketConn, addr net.Addr) *BindPacketConn {
|
||||
return &BindPacketConn{
|
||||
type bindPacketConn struct {
|
||||
N.NetPacketConn
|
||||
addr net.Addr
|
||||
}
|
||||
|
||||
func NewBindPacketConn(conn net.PacketConn, addr net.Addr) BindPacketConn {
|
||||
return &bindPacketConn{
|
||||
NewPacketConn(conn),
|
||||
addr,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *BindPacketConn) Read(b []byte) (n int, err error) {
|
||||
func (c *bindPacketConn) Read(b []byte) (n int, err error) {
|
||||
n, _, err = c.ReadFrom(b)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *BindPacketConn) Write(b []byte) (n int, err error) {
|
||||
return c.WriteTo(b, c.Addr)
|
||||
func (c *bindPacketConn) Write(b []byte) (n int, err error) {
|
||||
return c.WriteTo(b, c.addr)
|
||||
}
|
||||
|
||||
func (c *BindPacketConn) RemoteAddr() net.Addr {
|
||||
return c.Addr
|
||||
func (c *bindPacketConn) CreateReadWaiter() (N.ReadWaiter, bool) {
|
||||
readWaiter, isReadWaiter := CreatePacketReadWaiter(c.NetPacketConn)
|
||||
if !isReadWaiter {
|
||||
return nil, false
|
||||
}
|
||||
return &bindPacketReadWaiter{readWaiter}, true
|
||||
}
|
||||
|
||||
func (c *BindPacketConn) Upstream() any {
|
||||
func (c *bindPacketConn) RemoteAddr() net.Addr {
|
||||
return c.addr
|
||||
}
|
||||
|
||||
func (c *bindPacketConn) Upstream() any {
|
||||
return c.NetPacketConn
|
||||
}
|
||||
|
||||
var (
|
||||
_ N.NetPacketConn = (*UnbindPacketConn)(nil)
|
||||
_ N.PacketReadWaitCreator = (*UnbindPacketConn)(nil)
|
||||
)
|
||||
|
||||
type UnbindPacketConn struct {
|
||||
N.ExtendedConn
|
||||
Addr M.Socksaddr
|
||||
addr M.Socksaddr
|
||||
}
|
||||
|
||||
func NewUnbindPacketConn(conn net.Conn) *UnbindPacketConn {
|
||||
func NewUnbindPacketConn(conn net.Conn) N.NetPacketConn {
|
||||
return &UnbindPacketConn{
|
||||
NewExtendedConn(conn),
|
||||
M.SocksaddrFromNet(conn.RemoteAddr()),
|
||||
}
|
||||
}
|
||||
|
||||
func NewUnbindPacketConnWithAddr(conn net.Conn, addr M.Socksaddr) N.NetPacketConn {
|
||||
return &UnbindPacketConn{
|
||||
NewExtendedConn(conn),
|
||||
addr,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *UnbindPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
||||
n, err = c.ExtendedConn.Read(p)
|
||||
if err == nil {
|
||||
addr = c.Addr.UDPAddr()
|
||||
addr = c.addr.UDPAddr()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
@ -66,7 +91,7 @@ func (c *UnbindPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksad
|
|||
if err != nil {
|
||||
return
|
||||
}
|
||||
destination = c.Addr
|
||||
destination = c.addr
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -74,6 +99,67 @@ func (c *UnbindPacketConn) WritePacket(buffer *buf.Buffer, _ M.Socksaddr) error
|
|||
return c.ExtendedConn.WriteBuffer(buffer)
|
||||
}
|
||||
|
||||
func (c *UnbindPacketConn) CreateReadWaiter() (N.PacketReadWaiter, bool) {
|
||||
readWaiter, isReadWaiter := CreateReadWaiter(c.ExtendedConn)
|
||||
if !isReadWaiter {
|
||||
return nil, false
|
||||
}
|
||||
return &unbindPacketReadWaiter{readWaiter, c.addr}, true
|
||||
}
|
||||
|
||||
func (c *UnbindPacketConn) Upstream() any {
|
||||
return c.ExtendedConn
|
||||
}
|
||||
|
||||
func NewServerPacketConn(conn net.PacketConn) N.ExtendedConn {
|
||||
return &serverPacketConn{
|
||||
NetPacketConn: NewPacketConn(conn),
|
||||
}
|
||||
}
|
||||
|
||||
type serverPacketConn struct {
|
||||
N.NetPacketConn
|
||||
remoteAddr M.Socksaddr
|
||||
}
|
||||
|
||||
func (c *serverPacketConn) Read(p []byte) (n int, err error) {
|
||||
n, addr, err := c.NetPacketConn.ReadFrom(p)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
c.remoteAddr = M.SocksaddrFromNet(addr)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *serverPacketConn) ReadBuffer(buffer *buf.Buffer) error {
|
||||
destination, err := c.NetPacketConn.ReadPacket(buffer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.remoteAddr = destination
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *serverPacketConn) Write(p []byte) (n int, err error) {
|
||||
return c.NetPacketConn.WriteTo(p, c.remoteAddr.UDPAddr())
|
||||
}
|
||||
|
||||
func (c *serverPacketConn) WriteBuffer(buffer *buf.Buffer) error {
|
||||
return c.NetPacketConn.WritePacket(buffer, c.remoteAddr)
|
||||
}
|
||||
|
||||
func (c *serverPacketConn) RemoteAddr() net.Addr {
|
||||
return c.remoteAddr
|
||||
}
|
||||
|
||||
func (c *serverPacketConn) Upstream() any {
|
||||
return c.NetPacketConn
|
||||
}
|
||||
|
||||
func (c *serverPacketConn) CreateReadWaiter() (N.ReadWaiter, bool) {
|
||||
readWaiter, isReadWaiter := CreatePacketReadWaiter(c.NetPacketConn)
|
||||
if !isReadWaiter {
|
||||
return nil, false
|
||||
}
|
||||
return &serverPacketReadWaiter{c, readWaiter}, true
|
||||
}
|
||||
|
|
62
common/bufio/bind_wait.go
Normal file
62
common/bufio/bind_wait.go
Normal file
|
@ -0,0 +1,62 @@
|
|||
package bufio
|
||||
|
||||
import (
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
var _ N.ReadWaiter = (*bindPacketReadWaiter)(nil)
|
||||
|
||||
type bindPacketReadWaiter struct {
|
||||
readWaiter N.PacketReadWaiter
|
||||
}
|
||||
|
||||
func (w *bindPacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
|
||||
return w.readWaiter.InitializeReadWaiter(options)
|
||||
}
|
||||
|
||||
func (w *bindPacketReadWaiter) WaitReadBuffer() (buffer *buf.Buffer, err error) {
|
||||
buffer, _, err = w.readWaiter.WaitReadPacket()
|
||||
return
|
||||
}
|
||||
|
||||
var _ N.PacketReadWaiter = (*unbindPacketReadWaiter)(nil)
|
||||
|
||||
type unbindPacketReadWaiter struct {
|
||||
readWaiter N.ReadWaiter
|
||||
addr M.Socksaddr
|
||||
}
|
||||
|
||||
func (w *unbindPacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
|
||||
return w.readWaiter.InitializeReadWaiter(options)
|
||||
}
|
||||
|
||||
func (w *unbindPacketReadWaiter) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
|
||||
buffer, err = w.readWaiter.WaitReadBuffer()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
destination = w.addr
|
||||
return
|
||||
}
|
||||
|
||||
var _ N.ReadWaiter = (*serverPacketReadWaiter)(nil)
|
||||
|
||||
type serverPacketReadWaiter struct {
|
||||
*serverPacketConn
|
||||
readWaiter N.PacketReadWaiter
|
||||
}
|
||||
|
||||
func (w *serverPacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
|
||||
return w.readWaiter.InitializeReadWaiter(options)
|
||||
}
|
||||
|
||||
func (w *serverPacketReadWaiter) WaitReadBuffer() (buffer *buf.Buffer, err error) {
|
||||
buffer, destination, err := w.readWaiter.WaitReadPacket()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
w.remoteAddr = destination
|
||||
return
|
||||
}
|
|
@ -4,6 +4,7 @@ import (
|
|||
"io"
|
||||
"sync"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
)
|
||||
|
||||
|
@ -37,7 +38,26 @@ func (w *BufferedWriter) Write(p []byte) (n int, err error) {
|
|||
if err != nil {
|
||||
return
|
||||
}
|
||||
w.buffer.FullReset()
|
||||
w.buffer.Reset()
|
||||
}
|
||||
}
|
||||
|
||||
func (w *BufferedWriter) WriteByte(c byte) error {
|
||||
w.access.Lock()
|
||||
defer w.access.Unlock()
|
||||
if w.buffer == nil {
|
||||
return common.Error(w.upstream.Write([]byte{c}))
|
||||
}
|
||||
for {
|
||||
err := w.buffer.WriteByte(c)
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
_, err = w.upstream.Write(w.buffer.Bytes())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
w.buffer.Reset()
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -3,7 +3,6 @@ package bufio
|
|||
import (
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
|
@ -60,13 +59,6 @@ func (c *CachedConn) WriteTo(w io.Writer) (n int64, err error) {
|
|||
return
|
||||
}
|
||||
|
||||
func (c *CachedConn) SetReadDeadline(t time.Time) error {
|
||||
if c.buffer != nil && !c.buffer.IsEmpty() {
|
||||
return nil
|
||||
}
|
||||
return c.Conn.SetReadDeadline(t)
|
||||
}
|
||||
|
||||
func (c *CachedConn) ReadFrom(r io.Reader) (n int64, err error) {
|
||||
return Copy(c.Conn, r)
|
||||
}
|
||||
|
@ -192,10 +184,12 @@ func (c *CachedPacketConn) ReadCachedPacket() *N.PacketBuffer {
|
|||
if buffer != nil {
|
||||
buffer.DecRef()
|
||||
}
|
||||
return &N.PacketBuffer{
|
||||
packet := N.NewPacketBuffer()
|
||||
*packet = N.PacketBuffer{
|
||||
Buffer: buffer,
|
||||
Destination: c.destination,
|
||||
}
|
||||
return packet
|
||||
}
|
||||
|
||||
func (c *CachedPacketConn) Upstream() any {
|
||||
|
|
|
@ -30,7 +30,7 @@ func (c *ChunkReader) ReadBuffer(buffer *buf.Buffer) error {
|
|||
} else if !c.cache.IsEmpty() {
|
||||
return common.Error(buffer.ReadFrom(c.cache))
|
||||
}
|
||||
c.cache.FullReset()
|
||||
c.cache.Reset()
|
||||
err := c.upstream.ReadBuffer(c.cache)
|
||||
if err != nil {
|
||||
c.cache.Release()
|
||||
|
@ -46,7 +46,7 @@ func (c *ChunkReader) Read(p []byte) (n int, err error) {
|
|||
} else if !c.cache.IsEmpty() {
|
||||
return c.cache.Read(p)
|
||||
}
|
||||
c.cache.FullReset()
|
||||
c.cache.Reset()
|
||||
err = c.upstream.ReadBuffer(c.cache)
|
||||
if err != nil {
|
||||
c.cache.Release()
|
||||
|
@ -70,7 +70,7 @@ func (c *ChunkReader) ReadChunk() (*buf.Buffer, error) {
|
|||
} else if !c.cache.IsEmpty() {
|
||||
return c.cache, nil
|
||||
}
|
||||
c.cache.FullReset()
|
||||
c.cache.Reset()
|
||||
err := c.upstream.ReadBuffer(c.cache)
|
||||
if err != nil {
|
||||
c.cache.Release()
|
||||
|
|
|
@ -35,14 +35,7 @@ func (w *ExtendedUDPConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
|
|||
|
||||
func (w *ExtendedUDPConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
|
||||
defer buffer.Release()
|
||||
if destination.IsFqdn() {
|
||||
udpAddr, err := net.ResolveUDPAddr("udp", destination.String())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return common.Error(w.UDPConn.WriteTo(buffer.Bytes(), udpAddr))
|
||||
}
|
||||
return common.Error(w.UDPConn.WriteToUDP(buffer.Bytes(), destination.UDPAddr()))
|
||||
return common.Error(w.UDPConn.WriteToUDPAddrPort(buffer.Bytes(), destination.AddrPort()))
|
||||
}
|
||||
|
||||
func (w *ExtendedUDPConn) Upstream() any {
|
||||
|
|
|
@ -5,7 +5,6 @@ import (
|
|||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"reflect"
|
||||
"syscall"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
|
@ -13,7 +12,6 @@ import (
|
|||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/sing/common/rw"
|
||||
"github.com/sagernet/sing/common/task"
|
||||
)
|
||||
|
||||
|
@ -31,93 +29,71 @@ func Copy(destination io.Writer, source io.Reader) (n int64, err error) {
|
|||
if cachedSrc, isCached := source.(N.CachedReader); isCached {
|
||||
cachedBuffer := cachedSrc.ReadCached()
|
||||
if cachedBuffer != nil {
|
||||
if !cachedBuffer.IsEmpty() {
|
||||
_, err = destination.Write(cachedBuffer.Bytes())
|
||||
if err != nil {
|
||||
cachedBuffer.Release()
|
||||
return
|
||||
}
|
||||
}
|
||||
dataLen := cachedBuffer.Len()
|
||||
_, err = destination.Write(cachedBuffer.Bytes())
|
||||
cachedBuffer.Release()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
for _, counter := range readCounters {
|
||||
counter(int64(dataLen))
|
||||
}
|
||||
for _, counter := range writeCounters {
|
||||
counter(int64(dataLen))
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
srcSyscallConn, srcIsSyscall := source.(syscall.Conn)
|
||||
dstSyscallConn, dstIsSyscall := destination.(syscall.Conn)
|
||||
if srcIsSyscall && dstIsSyscall {
|
||||
var handled bool
|
||||
handled, n, err = CopyDirect(srcSyscallConn, dstSyscallConn, readCounters, writeCounters)
|
||||
if handled {
|
||||
return
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
return CopyWithCounters(destination, source, originSource, readCounters, writeCounters)
|
||||
}
|
||||
|
||||
func CopyWithCounters(destination io.Writer, source io.Reader, originSource io.Reader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
|
||||
srcSyscallConn, srcIsSyscall := source.(syscall.Conn)
|
||||
dstSyscallConn, dstIsSyscall := destination.(syscall.Conn)
|
||||
if srcIsSyscall && dstIsSyscall {
|
||||
var handled bool
|
||||
handled, n, err = copyDirect(srcSyscallConn, dstSyscallConn, readCounters, writeCounters)
|
||||
if handled {
|
||||
return
|
||||
}
|
||||
}
|
||||
return CopyExtended(originSource, NewExtendedWriter(destination), NewExtendedReader(source), readCounters, writeCounters)
|
||||
}
|
||||
|
||||
func CopyExtended(originSource io.Reader, destination N.ExtendedWriter, source N.ExtendedReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
|
||||
safeSrc := N.IsSafeReader(source)
|
||||
headroom := N.CalculateFrontHeadroom(destination) + N.CalculateRearHeadroom(destination)
|
||||
if safeSrc != nil {
|
||||
if headroom == 0 {
|
||||
return CopyExtendedWithSrcBuffer(originSource, destination, safeSrc, readCounters, writeCounters)
|
||||
}
|
||||
}
|
||||
frontHeadroom := N.CalculateFrontHeadroom(destination)
|
||||
rearHeadroom := N.CalculateRearHeadroom(destination)
|
||||
readWaiter, isReadWaiter := CreateReadWaiter(source)
|
||||
if isReadWaiter {
|
||||
var handled bool
|
||||
handled, n, err = copyWaitWithPool(originSource, destination, readWaiter, readCounters, writeCounters)
|
||||
if handled {
|
||||
return
|
||||
needCopy := readWaiter.InitializeReadWaiter(N.ReadWaitOptions{
|
||||
FrontHeadroom: frontHeadroom,
|
||||
RearHeadroom: rearHeadroom,
|
||||
MTU: N.CalculateMTU(source, destination),
|
||||
})
|
||||
if !needCopy || common.LowMemory {
|
||||
var handled bool
|
||||
handled, n, err = copyWaitWithPool(originSource, destination, readWaiter, readCounters, writeCounters)
|
||||
if handled {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
return CopyExtendedWithPool(originSource, destination, source, readCounters, writeCounters)
|
||||
}
|
||||
|
||||
// Deprecated: not used
|
||||
func CopyExtendedBuffer(originSource io.Writer, destination N.ExtendedWriter, source N.ExtendedReader, buffer *buf.Buffer, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
|
||||
buffer.IncRef()
|
||||
defer buffer.DecRef()
|
||||
frontHeadroom := N.CalculateFrontHeadroom(destination)
|
||||
rearHeadroom := N.CalculateRearHeadroom(destination)
|
||||
readBufferRaw := buffer.Slice()
|
||||
readBuffer := buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom])
|
||||
buffer.Resize(frontHeadroom, 0)
|
||||
buffer.Reserve(rearHeadroom)
|
||||
var notFirstTime bool
|
||||
for {
|
||||
readBuffer.Resize(frontHeadroom, 0)
|
||||
err = source.ReadBuffer(readBuffer)
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
err = nil
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
dataLen := readBuffer.Len()
|
||||
buffer.Resize(readBuffer.Start(), dataLen)
|
||||
err = destination.WriteBuffer(buffer)
|
||||
if err != nil {
|
||||
if !notFirstTime {
|
||||
err = N.ReportHandshakeFailure(originSource, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
n += int64(dataLen)
|
||||
for _, counter := range readCounters {
|
||||
counter(int64(dataLen))
|
||||
}
|
||||
for _, counter := range writeCounters {
|
||||
counter(int64(dataLen))
|
||||
}
|
||||
notFirstTime = true
|
||||
}
|
||||
}
|
||||
|
||||
func CopyExtendedWithSrcBuffer(originSource io.Reader, destination N.ExtendedWriter, source N.ThreadSafeReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
|
||||
var notFirstTime bool
|
||||
for {
|
||||
var buffer *buf.Buffer
|
||||
buffer, err = source.ReadBufferThreadSafe()
|
||||
err = source.ReadBuffer(buffer)
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
err = nil
|
||||
|
@ -126,9 +102,9 @@ func CopyExtendedWithSrcBuffer(originSource io.Reader, destination N.ExtendedWri
|
|||
return
|
||||
}
|
||||
dataLen := buffer.Len()
|
||||
buffer.OverCap(rearHeadroom)
|
||||
err = destination.WriteBuffer(buffer)
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
if !notFirstTime {
|
||||
err = N.ReportHandshakeFailure(originSource, err)
|
||||
}
|
||||
|
@ -146,21 +122,11 @@ func CopyExtendedWithSrcBuffer(originSource io.Reader, destination N.ExtendedWri
|
|||
}
|
||||
|
||||
func CopyExtendedWithPool(originSource io.Reader, destination N.ExtendedWriter, source N.ExtendedReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
|
||||
frontHeadroom := N.CalculateFrontHeadroom(destination)
|
||||
rearHeadroom := N.CalculateRearHeadroom(destination)
|
||||
bufferSize := N.CalculateMTU(source, destination)
|
||||
if bufferSize > 0 {
|
||||
bufferSize += frontHeadroom + rearHeadroom
|
||||
} else {
|
||||
bufferSize = buf.BufferSize
|
||||
}
|
||||
options := N.NewReadWaitOptions(source, destination)
|
||||
var notFirstTime bool
|
||||
for {
|
||||
buffer := buf.NewSize(bufferSize)
|
||||
readBufferRaw := buffer.Slice()
|
||||
readBuffer := buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom])
|
||||
readBuffer.Resize(frontHeadroom, 0)
|
||||
err = source.ReadBuffer(readBuffer)
|
||||
buffer := options.NewBuffer()
|
||||
err = source.ReadBuffer(buffer)
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
if errors.Is(err, io.EOF) {
|
||||
|
@ -169,11 +135,11 @@ func CopyExtendedWithPool(originSource io.Reader, destination N.ExtendedWriter,
|
|||
}
|
||||
return
|
||||
}
|
||||
dataLen := readBuffer.Len()
|
||||
buffer.Resize(readBuffer.Start(), dataLen)
|
||||
dataLen := buffer.Len()
|
||||
options.PostReturn(buffer)
|
||||
err = destination.WriteBuffer(buffer)
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
buffer.Leak()
|
||||
if !notFirstTime {
|
||||
err = N.ReportHandshakeFailure(originSource, err)
|
||||
}
|
||||
|
@ -191,16 +157,12 @@ func CopyExtendedWithPool(originSource io.Reader, destination N.ExtendedWriter,
|
|||
}
|
||||
|
||||
func CopyConn(ctx context.Context, source net.Conn, destination net.Conn) error {
|
||||
return CopyConnContextList([]context.Context{ctx}, source, destination)
|
||||
}
|
||||
|
||||
func CopyConnContextList(contextList []context.Context, source net.Conn, destination net.Conn) error {
|
||||
var group task.Group
|
||||
if _, dstDuplex := common.Cast[rw.WriteCloser](destination); dstDuplex {
|
||||
if _, dstDuplex := common.Cast[N.WriteCloser](destination); dstDuplex {
|
||||
group.Append("upload", func(ctx context.Context) error {
|
||||
err := common.Error(Copy(destination, source))
|
||||
if err == nil {
|
||||
rw.CloseWrite(destination)
|
||||
N.CloseWrite(destination)
|
||||
} else {
|
||||
common.Close(destination)
|
||||
}
|
||||
|
@ -212,11 +174,11 @@ func CopyConnContextList(contextList []context.Context, source net.Conn, destina
|
|||
return common.Error(Copy(destination, source))
|
||||
})
|
||||
}
|
||||
if _, srcDuplex := common.Cast[rw.WriteCloser](source); srcDuplex {
|
||||
if _, srcDuplex := common.Cast[N.WriteCloser](source); srcDuplex {
|
||||
group.Append("download", func(ctx context.Context) error {
|
||||
err := common.Error(Copy(source, destination))
|
||||
if err == nil {
|
||||
rw.CloseWrite(source)
|
||||
N.CloseWrite(source)
|
||||
} else {
|
||||
common.Close(source)
|
||||
}
|
||||
|
@ -231,7 +193,7 @@ func CopyConnContextList(contextList []context.Context, source net.Conn, destina
|
|||
group.Cleanup(func() {
|
||||
common.Close(source, destination)
|
||||
})
|
||||
return group.RunContextList(contextList)
|
||||
return group.Run(ctx)
|
||||
}
|
||||
|
||||
func CopyPacket(destinationConn N.PacketWriter, source N.PacketReader) (n int64, err error) {
|
||||
|
@ -251,33 +213,30 @@ func CopyPacket(destinationConn N.PacketWriter, source N.PacketReader) (n int64,
|
|||
break
|
||||
}
|
||||
if cachedPackets != nil {
|
||||
n, err = WritePacketWithPool(originSource, destinationConn, cachedPackets)
|
||||
n, err = WritePacketWithPool(originSource, destinationConn, cachedPackets, readCounters, writeCounters)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
safeSrc := N.IsSafePacketReader(source)
|
||||
frontHeadroom := N.CalculateFrontHeadroom(destinationConn)
|
||||
rearHeadroom := N.CalculateRearHeadroom(destinationConn)
|
||||
headroom := frontHeadroom + rearHeadroom
|
||||
if safeSrc != nil {
|
||||
if headroom == 0 {
|
||||
var copyN int64
|
||||
copyN, err = CopyPacketWithSrcBuffer(originSource, destinationConn, safeSrc, readCounters, writeCounters, n > 0)
|
||||
n += copyN
|
||||
return
|
||||
}
|
||||
}
|
||||
copeN, err := CopyPacketWithCounters(destinationConn, source, originSource, readCounters, writeCounters)
|
||||
n += copeN
|
||||
return
|
||||
}
|
||||
|
||||
func CopyPacketWithCounters(destinationConn N.PacketWriter, source N.PacketReader, originSource N.PacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
|
||||
var (
|
||||
handled bool
|
||||
copeN int64
|
||||
)
|
||||
readWaiter, isReadWaiter := CreatePacketReadWaiter(source)
|
||||
if isReadWaiter {
|
||||
handled, copeN, err = copyPacketWaitWithPool(originSource, destinationConn, readWaiter, readCounters, writeCounters, n > 0)
|
||||
if handled {
|
||||
n += copeN
|
||||
return
|
||||
needCopy := readWaiter.InitializeReadWaiter(N.NewReadWaitOptions(source, destinationConn))
|
||||
if !needCopy || common.LowMemory {
|
||||
handled, copeN, err = copyPacketWaitWithPool(originSource, destinationConn, readWaiter, readCounters, writeCounters, n > 0)
|
||||
if handled {
|
||||
n += copeN
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
copeN, err = CopyPacketWithPool(originSource, destinationConn, source, readCounters, writeCounters, n > 0)
|
||||
|
@ -285,116 +244,65 @@ func CopyPacket(destinationConn N.PacketWriter, source N.PacketReader) (n int64,
|
|||
return
|
||||
}
|
||||
|
||||
func CopyPacketWithSrcBuffer(originSource N.PacketReader, destinationConn N.PacketWriter, source N.ThreadSafePacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (n int64, err error) {
|
||||
var buffer *buf.Buffer
|
||||
var destination M.Socksaddr
|
||||
func CopyPacketWithPool(originSource N.PacketReader, destination N.PacketWriter, source N.PacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (n int64, err error) {
|
||||
options := N.NewReadWaitOptions(source, destination)
|
||||
var destinationAddress M.Socksaddr
|
||||
for {
|
||||
buffer, destination, err = source.ReadPacketThreadSafe()
|
||||
buffer := options.NewPacketBuffer()
|
||||
destinationAddress, err = source.ReadPacket(buffer)
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
return
|
||||
}
|
||||
if buffer == nil {
|
||||
panic("nil buffer returned from " + reflect.TypeOf(source).String())
|
||||
}
|
||||
dataLen := buffer.Len()
|
||||
if dataLen == 0 {
|
||||
continue
|
||||
}
|
||||
err = destinationConn.WritePacket(buffer, destination)
|
||||
options.PostReturn(buffer)
|
||||
err = destination.WritePacket(buffer, destinationAddress)
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
buffer.Leak()
|
||||
if !notFirstTime {
|
||||
err = N.ReportHandshakeFailure(originSource, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
n += int64(dataLen)
|
||||
for _, counter := range readCounters {
|
||||
counter(int64(dataLen))
|
||||
}
|
||||
for _, counter := range writeCounters {
|
||||
counter(int64(dataLen))
|
||||
}
|
||||
notFirstTime = true
|
||||
}
|
||||
}
|
||||
|
||||
func CopyPacketWithPool(originSource N.PacketReader, destinationConn N.PacketWriter, source N.PacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (n int64, err error) {
|
||||
frontHeadroom := N.CalculateFrontHeadroom(destinationConn)
|
||||
rearHeadroom := N.CalculateRearHeadroom(destinationConn)
|
||||
bufferSize := N.CalculateMTU(source, destinationConn)
|
||||
if bufferSize > 0 {
|
||||
bufferSize += frontHeadroom + rearHeadroom
|
||||
} else {
|
||||
bufferSize = buf.UDPBufferSize
|
||||
}
|
||||
var destination M.Socksaddr
|
||||
for {
|
||||
buffer := buf.NewSize(bufferSize)
|
||||
readBufferRaw := buffer.Slice()
|
||||
readBuffer := buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom])
|
||||
readBuffer.Resize(frontHeadroom, 0)
|
||||
destination, err = source.ReadPacket(readBuffer)
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
return
|
||||
}
|
||||
dataLen := readBuffer.Len()
|
||||
buffer.Resize(readBuffer.Start(), dataLen)
|
||||
err = destinationConn.WritePacket(buffer, destination)
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
if !notFirstTime {
|
||||
err = N.ReportHandshakeFailure(originSource, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
n += int64(dataLen)
|
||||
for _, counter := range readCounters {
|
||||
counter(int64(dataLen))
|
||||
}
|
||||
for _, counter := range writeCounters {
|
||||
counter(int64(dataLen))
|
||||
}
|
||||
notFirstTime = true
|
||||
}
|
||||
}
|
||||
|
||||
func WritePacketWithPool(originSource N.PacketReader, destinationConn N.PacketWriter, packetBuffers []*N.PacketBuffer) (n int64, err error) {
|
||||
frontHeadroom := N.CalculateFrontHeadroom(destinationConn)
|
||||
rearHeadroom := N.CalculateRearHeadroom(destinationConn)
|
||||
func WritePacketWithPool(originSource N.PacketReader, destination N.PacketWriter, packetBuffers []*N.PacketBuffer, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
|
||||
options := N.NewReadWaitOptions(nil, destination)
|
||||
var notFirstTime bool
|
||||
for _, packetBuffer := range packetBuffers {
|
||||
buffer := buf.NewPacket()
|
||||
readBufferRaw := buffer.Slice()
|
||||
readBuffer := buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom])
|
||||
readBuffer.Resize(frontHeadroom, 0)
|
||||
_, err = readBuffer.Write(packetBuffer.Buffer.Bytes())
|
||||
packetBuffer.Buffer.Release()
|
||||
buffer := options.Copy(packetBuffer.Buffer)
|
||||
dataLen := buffer.Len()
|
||||
err = destination.WritePacket(buffer, packetBuffer.Destination)
|
||||
N.PutPacketBuffer(packetBuffer)
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
continue
|
||||
}
|
||||
dataLen := readBuffer.Len()
|
||||
buffer.Resize(readBuffer.Start(), dataLen)
|
||||
err = destinationConn.WritePacket(buffer, packetBuffer.Destination)
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
buffer.Leak()
|
||||
if !notFirstTime {
|
||||
err = N.ReportHandshakeFailure(originSource, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
for _, counter := range readCounters {
|
||||
counter(int64(dataLen))
|
||||
}
|
||||
for _, counter := range writeCounters {
|
||||
counter(int64(dataLen))
|
||||
}
|
||||
n += int64(dataLen)
|
||||
notFirstTime = true
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func CopyPacketConn(ctx context.Context, source N.PacketConn, destination N.PacketConn) error {
|
||||
return CopyPacketConnContextList([]context.Context{ctx}, source, destination)
|
||||
}
|
||||
|
||||
func CopyPacketConnContextList(contextList []context.Context, source N.PacketConn, destination N.PacketConn) error {
|
||||
var group task.Group
|
||||
group.Append("upload", func(ctx context.Context) error {
|
||||
return common.Error(CopyPacket(destination, source))
|
||||
|
@ -406,5 +314,5 @@ func CopyPacketConnContextList(contextList []context.Context, source N.PacketCon
|
|||
common.Close(source, destination)
|
||||
})
|
||||
group.FastFail()
|
||||
return group.RunContextList(contextList)
|
||||
return group.Run(ctx)
|
||||
}
|
||||
|
|
|
@ -1,12 +1,16 @@
|
|||
package bufio
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"syscall"
|
||||
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
func CopyDirect(source syscall.Conn, destination syscall.Conn, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handed bool, n int64, err error) {
|
||||
func copyDirect(source syscall.Conn, destination syscall.Conn, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handed bool, n int64, err error) {
|
||||
rawSource, err := source.SyscallConn()
|
||||
if err != nil {
|
||||
return
|
||||
|
@ -18,3 +22,69 @@ func CopyDirect(source syscall.Conn, destination syscall.Conn, readCounters []N.
|
|||
handed, n, err = splice(rawSource, rawDestination, readCounters, writeCounters)
|
||||
return
|
||||
}
|
||||
|
||||
func copyWaitWithPool(originSource io.Reader, destination N.ExtendedWriter, source N.ReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handled bool, n int64, err error) {
|
||||
handled = true
|
||||
var (
|
||||
buffer *buf.Buffer
|
||||
notFirstTime bool
|
||||
)
|
||||
for {
|
||||
buffer, err = source.WaitReadBuffer()
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
err = nil
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
dataLen := buffer.Len()
|
||||
err = destination.WriteBuffer(buffer)
|
||||
if err != nil {
|
||||
buffer.Leak()
|
||||
if !notFirstTime {
|
||||
err = N.ReportHandshakeFailure(originSource, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
n += int64(dataLen)
|
||||
for _, counter := range readCounters {
|
||||
counter(int64(dataLen))
|
||||
}
|
||||
for _, counter := range writeCounters {
|
||||
counter(int64(dataLen))
|
||||
}
|
||||
notFirstTime = true
|
||||
}
|
||||
}
|
||||
|
||||
func copyPacketWaitWithPool(originSource N.PacketReader, destinationConn N.PacketWriter, source N.PacketReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (handled bool, n int64, err error) {
|
||||
handled = true
|
||||
var (
|
||||
buffer *buf.Buffer
|
||||
destination M.Socksaddr
|
||||
)
|
||||
for {
|
||||
buffer, destination, err = source.WaitReadPacket()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
dataLen := buffer.Len()
|
||||
err = destinationConn.WritePacket(buffer, destination)
|
||||
if err != nil {
|
||||
buffer.Leak()
|
||||
if !notFirstTime {
|
||||
err = N.ReportHandshakeFailure(originSource, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
n += int64(dataLen)
|
||||
for _, counter := range readCounters {
|
||||
counter(int64(dataLen))
|
||||
}
|
||||
for _, counter := range writeCounters {
|
||||
counter(int64(dataLen))
|
||||
}
|
||||
notFirstTime = true
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,7 +3,6 @@
|
|||
package bufio
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"net/netip"
|
||||
"os"
|
||||
|
@ -15,114 +14,14 @@ import (
|
|||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
func copyWaitWithPool(originSource io.Reader, destination N.ExtendedWriter, source N.ReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handled bool, n int64, err error) {
|
||||
handled = true
|
||||
frontHeadroom := N.CalculateFrontHeadroom(destination)
|
||||
rearHeadroom := N.CalculateRearHeadroom(destination)
|
||||
bufferSize := N.CalculateMTU(source, destination)
|
||||
if bufferSize > 0 {
|
||||
bufferSize += frontHeadroom + rearHeadroom
|
||||
} else {
|
||||
bufferSize = buf.BufferSize
|
||||
}
|
||||
var (
|
||||
buffer *buf.Buffer
|
||||
readBuffer *buf.Buffer
|
||||
notFirstTime bool
|
||||
)
|
||||
source.InitializeReadWaiter(func() *buf.Buffer {
|
||||
buffer = buf.NewSize(bufferSize)
|
||||
readBufferRaw := buffer.Slice()
|
||||
readBuffer = buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom])
|
||||
readBuffer.Resize(frontHeadroom, 0)
|
||||
return readBuffer
|
||||
})
|
||||
defer source.InitializeReadWaiter(nil)
|
||||
for {
|
||||
err = source.WaitReadBuffer()
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
err = nil
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
dataLen := readBuffer.Len()
|
||||
buffer.Resize(readBuffer.Start(), dataLen)
|
||||
err = destination.WriteBuffer(buffer)
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
if !notFirstTime {
|
||||
err = N.ReportHandshakeFailure(originSource, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
n += int64(dataLen)
|
||||
for _, counter := range readCounters {
|
||||
counter(int64(dataLen))
|
||||
}
|
||||
for _, counter := range writeCounters {
|
||||
counter(int64(dataLen))
|
||||
}
|
||||
notFirstTime = true
|
||||
}
|
||||
}
|
||||
|
||||
func copyPacketWaitWithPool(originSource N.PacketReader, destinationConn N.PacketWriter, source N.PacketReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (handled bool, n int64, err error) {
|
||||
handled = true
|
||||
frontHeadroom := N.CalculateFrontHeadroom(destinationConn)
|
||||
rearHeadroom := N.CalculateRearHeadroom(destinationConn)
|
||||
bufferSize := N.CalculateMTU(source, destinationConn)
|
||||
if bufferSize > 0 {
|
||||
bufferSize += frontHeadroom + rearHeadroom
|
||||
} else {
|
||||
bufferSize = buf.UDPBufferSize
|
||||
}
|
||||
var (
|
||||
buffer *buf.Buffer
|
||||
readBuffer *buf.Buffer
|
||||
destination M.Socksaddr
|
||||
)
|
||||
source.InitializeReadWaiter(func() *buf.Buffer {
|
||||
buffer = buf.NewSize(bufferSize)
|
||||
readBufferRaw := buffer.Slice()
|
||||
readBuffer = buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom])
|
||||
readBuffer.Resize(frontHeadroom, 0)
|
||||
return readBuffer
|
||||
})
|
||||
defer source.InitializeReadWaiter(nil)
|
||||
for {
|
||||
destination, err = source.WaitReadPacket()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
dataLen := readBuffer.Len()
|
||||
buffer.Resize(readBuffer.Start(), dataLen)
|
||||
err = destinationConn.WritePacket(buffer, destination)
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
if !notFirstTime {
|
||||
err = N.ReportHandshakeFailure(originSource, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
n += int64(dataLen)
|
||||
for _, counter := range readCounters {
|
||||
counter(int64(dataLen))
|
||||
}
|
||||
for _, counter := range writeCounters {
|
||||
counter(int64(dataLen))
|
||||
}
|
||||
notFirstTime = true
|
||||
}
|
||||
}
|
||||
|
||||
var _ N.ReadWaiter = (*syscallReadWaiter)(nil)
|
||||
|
||||
type syscallReadWaiter struct {
|
||||
rawConn syscall.RawConn
|
||||
readErr error
|
||||
readFunc func(fd uintptr) (done bool)
|
||||
buffer *buf.Buffer
|
||||
options N.ReadWaitOptions
|
||||
}
|
||||
|
||||
func createSyscallReadWaiter(reader any) (*syscallReadWaiter, bool) {
|
||||
|
@ -135,47 +34,48 @@ func createSyscallReadWaiter(reader any) (*syscallReadWaiter, bool) {
|
|||
return nil, false
|
||||
}
|
||||
|
||||
func (w *syscallReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buffer) {
|
||||
w.readErr = nil
|
||||
if newBuffer == nil {
|
||||
w.readFunc = nil
|
||||
} else {
|
||||
w.readFunc = func(fd uintptr) (done bool) {
|
||||
buffer := newBuffer()
|
||||
var readN int
|
||||
readN, w.readErr = syscall.Read(int(fd), buffer.FreeBytes())
|
||||
if readN > 0 {
|
||||
buffer.Truncate(readN)
|
||||
} else {
|
||||
buffer.Release()
|
||||
buffer = nil
|
||||
}
|
||||
if w.readErr == syscall.EAGAIN {
|
||||
return false
|
||||
}
|
||||
if readN == 0 {
|
||||
w.readErr = io.EOF
|
||||
}
|
||||
return true
|
||||
func (w *syscallReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
|
||||
w.options = options
|
||||
w.readFunc = func(fd uintptr) (done bool) {
|
||||
buffer := w.options.NewBuffer()
|
||||
var readN int
|
||||
readN, w.readErr = syscall.Read(int(fd), buffer.FreeBytes())
|
||||
if readN > 0 {
|
||||
buffer.Truncate(readN)
|
||||
w.options.PostReturn(buffer)
|
||||
w.buffer = buffer
|
||||
} else {
|
||||
buffer.Release()
|
||||
}
|
||||
//goland:noinspection GoDirectComparisonOfErrors
|
||||
if w.readErr == syscall.EAGAIN {
|
||||
return false
|
||||
}
|
||||
if readN == 0 && w.readErr == nil {
|
||||
w.readErr = io.EOF
|
||||
}
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (w *syscallReadWaiter) WaitReadBuffer() error {
|
||||
func (w *syscallReadWaiter) WaitReadBuffer() (buffer *buf.Buffer, err error) {
|
||||
if w.readFunc == nil {
|
||||
return os.ErrInvalid
|
||||
return nil, os.ErrInvalid
|
||||
}
|
||||
err := w.rawConn.Read(w.readFunc)
|
||||
err = w.rawConn.Read(w.readFunc)
|
||||
if err != nil {
|
||||
return err
|
||||
return
|
||||
}
|
||||
if w.readErr != nil {
|
||||
if w.readErr == io.EOF {
|
||||
return io.EOF
|
||||
return nil, io.EOF
|
||||
}
|
||||
return E.Cause(w.readErr, "raw read")
|
||||
return nil, E.Cause(w.readErr, "raw read")
|
||||
}
|
||||
return nil
|
||||
buffer = w.buffer
|
||||
w.buffer = nil
|
||||
return
|
||||
}
|
||||
|
||||
var _ N.PacketReadWaiter = (*syscallPacketReadWaiter)(nil)
|
||||
|
@ -185,6 +85,8 @@ type syscallPacketReadWaiter struct {
|
|||
readErr error
|
||||
readFrom M.Socksaddr
|
||||
readFunc func(fd uintptr) (done bool)
|
||||
buffer *buf.Buffer
|
||||
options N.ReadWaitOptions
|
||||
}
|
||||
|
||||
func createSyscallPacketReadWaiter(reader any) (*syscallPacketReadWaiter, bool) {
|
||||
|
@ -197,42 +99,37 @@ func createSyscallPacketReadWaiter(reader any) (*syscallPacketReadWaiter, bool)
|
|||
return nil, false
|
||||
}
|
||||
|
||||
func (w *syscallPacketReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buffer) {
|
||||
w.readErr = nil
|
||||
w.readFrom = M.Socksaddr{}
|
||||
if newBuffer == nil {
|
||||
w.readFunc = nil
|
||||
} else {
|
||||
w.readFunc = func(fd uintptr) (done bool) {
|
||||
buffer := newBuffer()
|
||||
var readN int
|
||||
var from syscall.Sockaddr
|
||||
readN, _, _, from, w.readErr = syscall.Recvmsg(int(fd), buffer.FreeBytes(), nil, 0)
|
||||
if readN > 0 {
|
||||
buffer.Truncate(readN)
|
||||
} else {
|
||||
buffer.Release()
|
||||
buffer = nil
|
||||
}
|
||||
if w.readErr == syscall.EAGAIN {
|
||||
return false
|
||||
}
|
||||
if from != nil {
|
||||
switch fromAddr := from.(type) {
|
||||
case *syscall.SockaddrInet4:
|
||||
w.readFrom = M.SocksaddrFrom(netip.AddrFrom4(fromAddr.Addr), uint16(fromAddr.Port))
|
||||
case *syscall.SockaddrInet6:
|
||||
w.readFrom = M.SocksaddrFrom(netip.AddrFrom16(fromAddr.Addr), uint16(fromAddr.Port)).Unwrap()
|
||||
}
|
||||
}
|
||||
return true
|
||||
func (w *syscallPacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
|
||||
w.options = options
|
||||
w.readFunc = func(fd uintptr) (done bool) {
|
||||
buffer := w.options.NewPacketBuffer()
|
||||
var readN int
|
||||
var from syscall.Sockaddr
|
||||
readN, _, _, from, w.readErr = syscall.Recvmsg(int(fd), buffer.FreeBytes(), nil, 0)
|
||||
//goland:noinspection GoDirectComparisonOfErrors
|
||||
if w.readErr != nil {
|
||||
buffer.Release()
|
||||
return w.readErr != syscall.EAGAIN
|
||||
}
|
||||
if readN > 0 {
|
||||
buffer.Truncate(readN)
|
||||
}
|
||||
w.options.PostReturn(buffer)
|
||||
w.buffer = buffer
|
||||
switch fromAddr := from.(type) {
|
||||
case *syscall.SockaddrInet4:
|
||||
w.readFrom = M.SocksaddrFrom(netip.AddrFrom4(fromAddr.Addr), uint16(fromAddr.Port))
|
||||
case *syscall.SockaddrInet6:
|
||||
w.readFrom = M.SocksaddrFrom(netip.AddrFrom16(fromAddr.Addr), uint16(fromAddr.Port)).Unwrap()
|
||||
}
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (w *syscallPacketReadWaiter) WaitReadPacket() (destination M.Socksaddr, err error) {
|
||||
func (w *syscallPacketReadWaiter) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
|
||||
if w.readFunc == nil {
|
||||
return M.Socksaddr{}, os.ErrInvalid
|
||||
return nil, M.Socksaddr{}, os.ErrInvalid
|
||||
}
|
||||
err = w.rawConn.Read(w.readFunc)
|
||||
if err != nil {
|
||||
|
@ -242,6 +139,8 @@ func (w *syscallPacketReadWaiter) WaitReadPacket() (destination M.Socksaddr, err
|
|||
err = E.Cause(w.readErr, "raw read")
|
||||
return
|
||||
}
|
||||
buffer = w.buffer
|
||||
w.buffer = nil
|
||||
destination = w.readFrom
|
||||
return
|
||||
}
|
||||
|
|
77
common/bufio/copy_direct_test.go
Normal file
77
common/bufio/copy_direct_test.go
Normal file
|
@ -0,0 +1,77 @@
|
|||
package bufio
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestCopyWaitTCP(t *testing.T) {
|
||||
t.Parallel()
|
||||
inputConn, outputConn := TCPPipe(t)
|
||||
readWaiter, created := CreateReadWaiter(outputConn)
|
||||
require.True(t, created)
|
||||
require.NotNil(t, readWaiter)
|
||||
readWaiter.InitializeReadWaiter(N.ReadWaitOptions{})
|
||||
require.NoError(t, TCPTest(t, inputConn, &readWaitWrapper{
|
||||
Conn: outputConn,
|
||||
readWaiter: readWaiter,
|
||||
}))
|
||||
}
|
||||
|
||||
type readWaitWrapper struct {
|
||||
net.Conn
|
||||
readWaiter N.ReadWaiter
|
||||
buffer *buf.Buffer
|
||||
}
|
||||
|
||||
func (r *readWaitWrapper) Read(p []byte) (n int, err error) {
|
||||
if r.buffer != nil {
|
||||
if r.buffer.Len() > 0 {
|
||||
return r.buffer.Read(p)
|
||||
}
|
||||
if r.buffer.IsEmpty() {
|
||||
r.buffer.Release()
|
||||
r.buffer = nil
|
||||
}
|
||||
}
|
||||
buffer, err := r.readWaiter.WaitReadBuffer()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
r.buffer = buffer
|
||||
return r.buffer.Read(p)
|
||||
}
|
||||
|
||||
func TestCopyWaitUDP(t *testing.T) {
|
||||
t.Parallel()
|
||||
inputConn, outputConn, outputAddr := UDPPipe(t)
|
||||
readWaiter, created := CreatePacketReadWaiter(NewPacketConn(outputConn))
|
||||
require.True(t, created)
|
||||
require.NotNil(t, readWaiter)
|
||||
readWaiter.InitializeReadWaiter(N.ReadWaitOptions{})
|
||||
require.NoError(t, UDPTest(t, inputConn, &packetReadWaitWrapper{
|
||||
PacketConn: outputConn,
|
||||
readWaiter: readWaiter,
|
||||
}, outputAddr))
|
||||
}
|
||||
|
||||
type packetReadWaitWrapper struct {
|
||||
net.PacketConn
|
||||
readWaiter N.PacketReadWaiter
|
||||
}
|
||||
|
||||
func (r *packetReadWaitWrapper) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
||||
buffer, destination, err := r.readWaiter.WaitReadPacket()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
n = copy(p, buffer.Bytes())
|
||||
buffer.Release()
|
||||
addr = destination.UDPAddr()
|
||||
return
|
||||
}
|
|
@ -2,22 +2,162 @@ package bufio
|
|||
|
||||
import (
|
||||
"io"
|
||||
"net/netip"
|
||||
"os"
|
||||
"syscall"
|
||||
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
func copyWaitWithPool(originSource io.Reader, destination N.ExtendedWriter, source N.ReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handled bool, n int64, err error) {
|
||||
return
|
||||
var _ N.ReadWaiter = (*syscallReadWaiter)(nil)
|
||||
|
||||
type syscallReadWaiter struct {
|
||||
rawConn syscall.RawConn
|
||||
readErr error
|
||||
readFunc func(fd uintptr) (done bool)
|
||||
hasData bool
|
||||
buffer *buf.Buffer
|
||||
options N.ReadWaitOptions
|
||||
}
|
||||
|
||||
func copyPacketWaitWithPool(originSource N.PacketReader, destinationConn N.PacketWriter, source N.PacketReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (handled bool, n int64, err error) {
|
||||
return
|
||||
}
|
||||
|
||||
func createSyscallReadWaiter(reader any) (N.ReadWaiter, bool) {
|
||||
func createSyscallReadWaiter(reader any) (*syscallReadWaiter, bool) {
|
||||
if syscallConn, isSyscallConn := reader.(syscall.Conn); isSyscallConn {
|
||||
rawConn, err := syscallConn.SyscallConn()
|
||||
if err == nil {
|
||||
return &syscallReadWaiter{rawConn: rawConn}, true
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func createSyscallPacketReadWaiter(reader any) (N.PacketReadWaiter, bool) {
|
||||
func (w *syscallReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
|
||||
w.options = options
|
||||
w.readFunc = func(fd uintptr) (done bool) {
|
||||
if !w.hasData {
|
||||
w.hasData = true
|
||||
// golang's internal/poll.FD.RawRead will Use a zero-byte read as a way to get notified when this
|
||||
// socket is readable if we return false. So the `recv` syscall will not block the system thread.
|
||||
return false
|
||||
}
|
||||
buffer := w.options.NewBuffer()
|
||||
var readN int32
|
||||
readN, w.readErr = recv(windows.Handle(fd), buffer.FreeBytes(), 0)
|
||||
if readN > 0 {
|
||||
buffer.Truncate(int(readN))
|
||||
w.options.PostReturn(buffer)
|
||||
w.buffer = buffer
|
||||
} else {
|
||||
buffer.Release()
|
||||
}
|
||||
if w.readErr == windows.WSAEWOULDBLOCK {
|
||||
return false
|
||||
}
|
||||
if readN == 0 && w.readErr == nil {
|
||||
w.readErr = io.EOF
|
||||
}
|
||||
w.hasData = false
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (w *syscallReadWaiter) WaitReadBuffer() (buffer *buf.Buffer, err error) {
|
||||
if w.readFunc == nil {
|
||||
return nil, os.ErrInvalid
|
||||
}
|
||||
err = w.rawConn.Read(w.readFunc)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if w.readErr != nil {
|
||||
if w.readErr == io.EOF {
|
||||
return nil, io.EOF
|
||||
}
|
||||
return nil, E.Cause(w.readErr, "raw read")
|
||||
}
|
||||
buffer = w.buffer
|
||||
w.buffer = nil
|
||||
return
|
||||
}
|
||||
|
||||
var _ N.PacketReadWaiter = (*syscallPacketReadWaiter)(nil)
|
||||
|
||||
type syscallPacketReadWaiter struct {
|
||||
rawConn syscall.RawConn
|
||||
readErr error
|
||||
readFrom M.Socksaddr
|
||||
readFunc func(fd uintptr) (done bool)
|
||||
hasData bool
|
||||
buffer *buf.Buffer
|
||||
options N.ReadWaitOptions
|
||||
}
|
||||
|
||||
func createSyscallPacketReadWaiter(reader any) (*syscallPacketReadWaiter, bool) {
|
||||
if syscallConn, isSyscallConn := reader.(syscall.Conn); isSyscallConn {
|
||||
rawConn, err := syscallConn.SyscallConn()
|
||||
if err == nil {
|
||||
return &syscallPacketReadWaiter{rawConn: rawConn}, true
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (w *syscallPacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
|
||||
w.options = options
|
||||
w.readFunc = func(fd uintptr) (done bool) {
|
||||
if !w.hasData {
|
||||
w.hasData = true
|
||||
// golang's internal/poll.FD.RawRead will Use a zero-byte read as a way to get notified when this
|
||||
// socket is readable if we return false. So the `recvfrom` syscall will not block the system thread.
|
||||
return false
|
||||
}
|
||||
buffer := w.options.NewPacketBuffer()
|
||||
var readN int
|
||||
var from windows.Sockaddr
|
||||
readN, from, w.readErr = windows.Recvfrom(windows.Handle(fd), buffer.FreeBytes(), 0)
|
||||
//goland:noinspection GoDirectComparisonOfErrors
|
||||
if w.readErr != nil {
|
||||
buffer.Release()
|
||||
return w.readErr != windows.WSAEWOULDBLOCK
|
||||
}
|
||||
if readN > 0 {
|
||||
buffer.Truncate(readN)
|
||||
}
|
||||
w.options.PostReturn(buffer)
|
||||
w.buffer = buffer
|
||||
if from != nil {
|
||||
switch fromAddr := from.(type) {
|
||||
case *windows.SockaddrInet4:
|
||||
w.readFrom = M.SocksaddrFrom(netip.AddrFrom4(fromAddr.Addr), uint16(fromAddr.Port))
|
||||
case *windows.SockaddrInet6:
|
||||
w.readFrom = M.SocksaddrFrom(netip.AddrFrom16(fromAddr.Addr), uint16(fromAddr.Port)).Unwrap()
|
||||
}
|
||||
}
|
||||
w.hasData = false
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (w *syscallPacketReadWaiter) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
|
||||
if w.readFunc == nil {
|
||||
return nil, M.Socksaddr{}, os.ErrInvalid
|
||||
}
|
||||
err = w.rawConn.Read(w.readFunc)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if w.readErr != nil {
|
||||
err = E.Cause(w.readErr, "raw read")
|
||||
return
|
||||
}
|
||||
buffer = w.buffer
|
||||
w.buffer = nil
|
||||
destination = w.readFrom
|
||||
return
|
||||
}
|
||||
|
|
|
@ -38,6 +38,10 @@ func (c *SerialConn) ReadBuffer(buffer *buf.Buffer) error {
|
|||
return c.ExtendedConn.ReadBuffer(buffer)
|
||||
}
|
||||
|
||||
func (c *SerialConn) Upstream() any {
|
||||
return c.ExtendedConn
|
||||
}
|
||||
|
||||
type SerialPacketConn struct {
|
||||
N.NetPacketConn
|
||||
access sync.Mutex
|
||||
|
|
|
@ -25,6 +25,45 @@ func ReadPacket(reader N.PacketReader, buffer *buf.Buffer) (n int, addr net.Addr
|
|||
return
|
||||
}
|
||||
|
||||
func ReadBufferSize(reader io.Reader, bufferSize int) (buffer *buf.Buffer, err error) {
|
||||
readWaiter, isReadWaiter := CreateReadWaiter(reader)
|
||||
if isReadWaiter {
|
||||
readWaiter.InitializeReadWaiter(N.ReadWaitOptions{
|
||||
MTU: bufferSize,
|
||||
})
|
||||
return readWaiter.WaitReadBuffer()
|
||||
}
|
||||
buffer = buf.NewSize(bufferSize)
|
||||
if extendedReader, isExtendedReader := reader.(N.ExtendedReader); isExtendedReader {
|
||||
err = extendedReader.ReadBuffer(buffer)
|
||||
} else {
|
||||
_, err = buffer.ReadOnceFrom(reader)
|
||||
}
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
buffer = nil
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func ReadPacketSize(reader N.PacketReader, packetSize int) (buffer *buf.Buffer, destination M.Socksaddr, err error) {
|
||||
readWaiter, isReadWaiter := CreatePacketReadWaiter(reader)
|
||||
if isReadWaiter {
|
||||
readWaiter.InitializeReadWaiter(N.ReadWaitOptions{
|
||||
MTU: packetSize,
|
||||
})
|
||||
buffer, destination, err = readWaiter.WaitReadPacket()
|
||||
return
|
||||
}
|
||||
buffer = buf.NewSize(packetSize)
|
||||
destination, err = reader.ReadPacket(buffer)
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
buffer = nil
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func Write(writer io.Writer, data []byte) (n int, err error) {
|
||||
if extendedWriter, isExtended := writer.(N.ExtendedWriter); isExtended {
|
||||
return WriteBuffer(extendedWriter, buf.As(data))
|
||||
|
|
|
@ -17,13 +17,21 @@ type NATPacketConn interface {
|
|||
func NewUnidirectionalNATPacketConn(conn N.NetPacketConn, origin M.Socksaddr, destination M.Socksaddr) NATPacketConn {
|
||||
return &unidirectionalNATPacketConn{
|
||||
NetPacketConn: conn,
|
||||
origin: origin,
|
||||
destination: destination,
|
||||
origin: socksaddrWithoutPort(origin),
|
||||
destination: socksaddrWithoutPort(destination),
|
||||
}
|
||||
}
|
||||
|
||||
func NewNATPacketConn(conn N.NetPacketConn, origin M.Socksaddr, destination M.Socksaddr) NATPacketConn {
|
||||
return &bidirectionalNATPacketConn{
|
||||
NetPacketConn: conn,
|
||||
origin: socksaddrWithoutPort(origin),
|
||||
destination: socksaddrWithoutPort(destination),
|
||||
}
|
||||
}
|
||||
|
||||
func NewDestinationNATPacketConn(conn N.NetPacketConn, origin M.Socksaddr, destination M.Socksaddr) NATPacketConn {
|
||||
return &destinationNATPacketConn{
|
||||
NetPacketConn: conn,
|
||||
origin: origin,
|
||||
destination: destination,
|
||||
|
@ -37,15 +45,24 @@ type unidirectionalNATPacketConn struct {
|
|||
}
|
||||
|
||||
func (c *unidirectionalNATPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
||||
if M.SocksaddrFromNet(addr) == c.destination {
|
||||
addr = c.origin.UDPAddr()
|
||||
destination := M.SocksaddrFromNet(addr)
|
||||
if socksaddrWithoutPort(destination) == c.destination {
|
||||
destination = M.Socksaddr{
|
||||
Addr: c.origin.Addr,
|
||||
Fqdn: c.origin.Fqdn,
|
||||
Port: destination.Port,
|
||||
}
|
||||
}
|
||||
return c.NetPacketConn.WriteTo(p, addr)
|
||||
return c.NetPacketConn.WriteTo(p, destination.UDPAddr())
|
||||
}
|
||||
|
||||
func (c *unidirectionalNATPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
|
||||
if destination == c.destination {
|
||||
destination = c.origin
|
||||
if socksaddrWithoutPort(destination) == c.destination {
|
||||
destination = M.Socksaddr{
|
||||
Addr: c.origin.Addr,
|
||||
Fqdn: c.origin.Fqdn,
|
||||
Port: destination.Port,
|
||||
}
|
||||
}
|
||||
return c.NetPacketConn.WritePacket(buffer, destination)
|
||||
}
|
||||
|
@ -54,6 +71,10 @@ func (c *unidirectionalNATPacketConn) UpdateDestination(destinationAddress netip
|
|||
c.destination = M.SocksaddrFrom(destinationAddress, c.destination.Port)
|
||||
}
|
||||
|
||||
func (c *unidirectionalNATPacketConn) RemoteAddr() net.Addr {
|
||||
return c.destination.UDPAddr()
|
||||
}
|
||||
|
||||
func (c *unidirectionalNATPacketConn) Upstream() any {
|
||||
return c.NetPacketConn
|
||||
}
|
||||
|
@ -66,30 +87,55 @@ type bidirectionalNATPacketConn struct {
|
|||
|
||||
func (c *bidirectionalNATPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
||||
n, addr, err = c.NetPacketConn.ReadFrom(p)
|
||||
if err == nil && M.SocksaddrFromNet(addr) == c.origin {
|
||||
addr = c.destination.UDPAddr()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
destination := M.SocksaddrFromNet(addr)
|
||||
if socksaddrWithoutPort(destination) == c.origin {
|
||||
destination = M.Socksaddr{
|
||||
Addr: c.destination.Addr,
|
||||
Fqdn: c.destination.Fqdn,
|
||||
Port: destination.Port,
|
||||
}
|
||||
}
|
||||
addr = destination.UDPAddr()
|
||||
return
|
||||
}
|
||||
|
||||
func (c *bidirectionalNATPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
||||
if M.SocksaddrFromNet(addr) == c.destination {
|
||||
addr = c.origin.UDPAddr()
|
||||
destination := M.SocksaddrFromNet(addr)
|
||||
if socksaddrWithoutPort(destination) == c.destination {
|
||||
destination = M.Socksaddr{
|
||||
Addr: c.origin.Addr,
|
||||
Fqdn: c.origin.Fqdn,
|
||||
Port: destination.Port,
|
||||
}
|
||||
}
|
||||
return c.NetPacketConn.WriteTo(p, addr)
|
||||
return c.NetPacketConn.WriteTo(p, destination.UDPAddr())
|
||||
}
|
||||
|
||||
func (c *bidirectionalNATPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
|
||||
destination, err = c.NetPacketConn.ReadPacket(buffer)
|
||||
if destination == c.origin {
|
||||
destination = c.destination
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if socksaddrWithoutPort(destination) == c.origin {
|
||||
destination = M.Socksaddr{
|
||||
Addr: c.destination.Addr,
|
||||
Fqdn: c.destination.Fqdn,
|
||||
Port: destination.Port,
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (c *bidirectionalNATPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
|
||||
if destination == c.destination {
|
||||
destination = c.origin
|
||||
if socksaddrWithoutPort(destination) == c.destination {
|
||||
destination = M.Socksaddr{
|
||||
Addr: c.origin.Addr,
|
||||
Fqdn: c.origin.Fqdn,
|
||||
Port: destination.Port,
|
||||
}
|
||||
}
|
||||
return c.NetPacketConn.WritePacket(buffer, destination)
|
||||
}
|
||||
|
@ -101,3 +147,66 @@ func (c *bidirectionalNATPacketConn) UpdateDestination(destinationAddress netip.
|
|||
func (c *bidirectionalNATPacketConn) Upstream() any {
|
||||
return c.NetPacketConn
|
||||
}
|
||||
|
||||
func (c *bidirectionalNATPacketConn) RemoteAddr() net.Addr {
|
||||
return c.destination.UDPAddr()
|
||||
}
|
||||
|
||||
type destinationNATPacketConn struct {
|
||||
N.NetPacketConn
|
||||
origin M.Socksaddr
|
||||
destination M.Socksaddr
|
||||
}
|
||||
|
||||
func (c *destinationNATPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
||||
n, addr, err = c.NetPacketConn.ReadFrom(p)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if M.SocksaddrFromNet(addr) == c.origin {
|
||||
addr = c.destination.UDPAddr()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (c *destinationNATPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
||||
if M.SocksaddrFromNet(addr) == c.destination {
|
||||
addr = c.origin.UDPAddr()
|
||||
}
|
||||
return c.NetPacketConn.WriteTo(p, addr)
|
||||
}
|
||||
|
||||
func (c *destinationNATPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
|
||||
destination, err = c.NetPacketConn.ReadPacket(buffer)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if destination == c.origin {
|
||||
destination = c.destination
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (c *destinationNATPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
|
||||
if destination == c.destination {
|
||||
destination = c.origin
|
||||
}
|
||||
return c.NetPacketConn.WritePacket(buffer, destination)
|
||||
}
|
||||
|
||||
func (c *destinationNATPacketConn) UpdateDestination(destinationAddress netip.Addr) {
|
||||
c.destination = M.SocksaddrFrom(destinationAddress, c.destination.Port)
|
||||
}
|
||||
|
||||
func (c *destinationNATPacketConn) Upstream() any {
|
||||
return c.NetPacketConn
|
||||
}
|
||||
|
||||
func (c *destinationNATPacketConn) RemoteAddr() net.Addr {
|
||||
return c.destination.UDPAddr()
|
||||
}
|
||||
|
||||
func socksaddrWithoutPort(destination M.Socksaddr) M.Socksaddr {
|
||||
destination.Port = 0
|
||||
return destination
|
||||
}
|
||||
|
|
39
common/bufio/nat_wait.go
Normal file
39
common/bufio/nat_wait.go
Normal file
|
@ -0,0 +1,39 @@
|
|||
package bufio
|
||||
|
||||
import (
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
func (c *bidirectionalNATPacketConn) CreatePacketReadWaiter() (N.PacketReadWaiter, bool) {
|
||||
waiter, created := CreatePacketReadWaiter(c.NetPacketConn)
|
||||
if !created {
|
||||
return nil, false
|
||||
}
|
||||
return &waitBidirectionalNATPacketConn{c, waiter}, true
|
||||
}
|
||||
|
||||
type waitBidirectionalNATPacketConn struct {
|
||||
*bidirectionalNATPacketConn
|
||||
readWaiter N.PacketReadWaiter
|
||||
}
|
||||
|
||||
func (c *waitBidirectionalNATPacketConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
|
||||
return c.readWaiter.InitializeReadWaiter(options)
|
||||
}
|
||||
|
||||
func (c *waitBidirectionalNATPacketConn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
|
||||
buffer, destination, err = c.readWaiter.WaitReadPacket()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if socksaddrWithoutPort(destination) == c.origin {
|
||||
destination = M.Socksaddr{
|
||||
Addr: c.destination.Addr,
|
||||
Fqdn: c.destination.Fqdn,
|
||||
Port: destination.Port,
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
277
common/bufio/net_test.go
Normal file
277
common/bufio/net_test.go
Normal file
|
@ -0,0 +1,277 @@
|
|||
package bufio
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/md5"
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
"github.com/sagernet/sing/common/task"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TCPPipe(t *testing.T) (net.Conn, net.Conn) {
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
var (
|
||||
group task.Group
|
||||
serverConn net.Conn
|
||||
clientConn net.Conn
|
||||
)
|
||||
group.Append0(func(ctx context.Context) error {
|
||||
var serverErr error
|
||||
serverConn, serverErr = listener.Accept()
|
||||
return serverErr
|
||||
})
|
||||
group.Append0(func(ctx context.Context) error {
|
||||
var clientErr error
|
||||
clientConn, clientErr = net.Dial("tcp", listener.Addr().String())
|
||||
return clientErr
|
||||
})
|
||||
err = group.Run(context.Background())
|
||||
require.NoError(t, err)
|
||||
listener.Close()
|
||||
t.Cleanup(func() {
|
||||
serverConn.Close()
|
||||
clientConn.Close()
|
||||
})
|
||||
return serverConn, clientConn
|
||||
}
|
||||
|
||||
func UDPPipe(t *testing.T) (net.PacketConn, net.PacketConn, M.Socksaddr) {
|
||||
serverConn, err := net.ListenPacket("udp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
clientConn, err := net.ListenPacket("udp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
return serverConn, clientConn, M.SocksaddrFromNet(clientConn.LocalAddr())
|
||||
}
|
||||
|
||||
func Timeout(t *testing.T) context.CancelFunc {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
go func() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Error("timeout")
|
||||
}
|
||||
}()
|
||||
return cancel
|
||||
}
|
||||
|
||||
type hashPair struct {
|
||||
sendHash map[int][]byte
|
||||
recvHash map[int][]byte
|
||||
}
|
||||
|
||||
func newLargeDataPair() (chan hashPair, chan hashPair, func(t *testing.T) error) {
|
||||
pingCh := make(chan hashPair)
|
||||
pongCh := make(chan hashPair)
|
||||
test := func(t *testing.T) error {
|
||||
defer close(pingCh)
|
||||
defer close(pongCh)
|
||||
pingOpen := false
|
||||
pongOpen := false
|
||||
var serverPair hashPair
|
||||
var clientPair hashPair
|
||||
|
||||
for {
|
||||
if pingOpen && pongOpen {
|
||||
break
|
||||
}
|
||||
|
||||
select {
|
||||
case serverPair, pingOpen = <-pingCh:
|
||||
assert.True(t, pingOpen)
|
||||
case clientPair, pongOpen = <-pongCh:
|
||||
assert.True(t, pongOpen)
|
||||
case <-time.After(10 * time.Second):
|
||||
return errors.New("timeout")
|
||||
}
|
||||
}
|
||||
|
||||
assert.Equal(t, serverPair.recvHash, clientPair.sendHash)
|
||||
assert.Equal(t, serverPair.sendHash, clientPair.recvHash)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
return pingCh, pongCh, test
|
||||
}
|
||||
|
||||
func TCPTest(t *testing.T, inputConn net.Conn, outputConn net.Conn) error {
|
||||
times := 100
|
||||
chunkSize := int64(64 * 1024)
|
||||
|
||||
pingCh, pongCh, test := newLargeDataPair()
|
||||
writeRandData := func(conn net.Conn) (map[int][]byte, error) {
|
||||
buf := make([]byte, chunkSize)
|
||||
hashMap := map[int][]byte{}
|
||||
for i := 0; i < times; i++ {
|
||||
if _, err := rand.Read(buf[1:]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
buf[0] = byte(i)
|
||||
|
||||
hash := md5.Sum(buf)
|
||||
hashMap[i] = hash[:]
|
||||
|
||||
if _, err := conn.Write(buf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return hashMap, nil
|
||||
}
|
||||
go func() {
|
||||
hashMap := map[int][]byte{}
|
||||
buf := make([]byte, chunkSize)
|
||||
|
||||
for i := 0; i < times; i++ {
|
||||
_, err := io.ReadFull(outputConn, buf)
|
||||
if err != nil {
|
||||
t.Log(err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
hash := md5.Sum(buf)
|
||||
hashMap[int(buf[0])] = hash[:]
|
||||
}
|
||||
|
||||
sendHash, err := writeRandData(outputConn)
|
||||
if err != nil {
|
||||
t.Log(err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
pingCh <- hashPair{
|
||||
sendHash: sendHash,
|
||||
recvHash: hashMap,
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
sendHash, err := writeRandData(inputConn)
|
||||
if err != nil {
|
||||
t.Log(err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
hashMap := map[int][]byte{}
|
||||
buf := make([]byte, chunkSize)
|
||||
|
||||
for i := 0; i < times; i++ {
|
||||
_, err = io.ReadFull(inputConn, buf)
|
||||
if err != nil {
|
||||
t.Log(err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
hash := md5.Sum(buf)
|
||||
hashMap[int(buf[0])] = hash[:]
|
||||
}
|
||||
|
||||
pongCh <- hashPair{
|
||||
sendHash: sendHash,
|
||||
recvHash: hashMap,
|
||||
}
|
||||
}()
|
||||
return test(t)
|
||||
}
|
||||
|
||||
func UDPTest(t *testing.T, inputConn net.PacketConn, outputConn net.PacketConn, outputAddr M.Socksaddr) error {
|
||||
rAddr := outputAddr.UDPAddr()
|
||||
times := 50
|
||||
chunkSize := 9000
|
||||
pingCh, pongCh, test := newLargeDataPair()
|
||||
writeRandData := func(pc net.PacketConn, addr net.Addr) (map[int][]byte, error) {
|
||||
hashMap := map[int][]byte{}
|
||||
mux := sync.Mutex{}
|
||||
for i := 0; i < times; i++ {
|
||||
buf := make([]byte, chunkSize)
|
||||
if _, err := rand.Read(buf[1:]); err != nil {
|
||||
t.Log(err.Error())
|
||||
continue
|
||||
}
|
||||
buf[0] = byte(i)
|
||||
|
||||
hash := md5.Sum(buf)
|
||||
mux.Lock()
|
||||
hashMap[i] = hash[:]
|
||||
mux.Unlock()
|
||||
|
||||
if _, err := pc.WriteTo(buf, addr); err != nil {
|
||||
t.Log(err.Error())
|
||||
}
|
||||
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
|
||||
return hashMap, nil
|
||||
}
|
||||
go func() {
|
||||
var (
|
||||
lAddr net.Addr
|
||||
err error
|
||||
)
|
||||
hashMap := map[int][]byte{}
|
||||
buf := make([]byte, 64*1024)
|
||||
|
||||
for i := 0; i < times; i++ {
|
||||
_, lAddr, err = outputConn.ReadFrom(buf)
|
||||
if err != nil {
|
||||
t.Log(err.Error())
|
||||
return
|
||||
}
|
||||
hash := md5.Sum(buf[:chunkSize])
|
||||
hashMap[int(buf[0])] = hash[:]
|
||||
}
|
||||
sendHash, err := writeRandData(outputConn, lAddr)
|
||||
if err != nil {
|
||||
t.Log(err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
pingCh <- hashPair{
|
||||
sendHash: sendHash,
|
||||
recvHash: hashMap,
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
sendHash, err := writeRandData(inputConn, rAddr)
|
||||
if err != nil {
|
||||
t.Log(err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
hashMap := map[int][]byte{}
|
||||
buf := make([]byte, 64*1024)
|
||||
|
||||
for i := 0; i < times; i++ {
|
||||
_, _, err := inputConn.ReadFrom(buf)
|
||||
if err != nil {
|
||||
t.Log(err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
hash := md5.Sum(buf[:chunkSize])
|
||||
hashMap[int(buf[0])] = hash[:]
|
||||
}
|
||||
|
||||
pongCh <- hashPair{
|
||||
sendHash: sendHash,
|
||||
recvHash: hashMap,
|
||||
}
|
||||
}()
|
||||
|
||||
return test(t)
|
||||
}
|
5
common/bufio/syscall_windows.go
Normal file
5
common/bufio/syscall_windows.go
Normal file
|
@ -0,0 +1,5 @@
|
|||
package bufio
|
||||
|
||||
//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go syscall_windows.go
|
||||
|
||||
//sys recv(s windows.Handle, buf []byte, flags int32) (n int32, err error) [failretval == -1] = ws2_32.recv
|
|
@ -33,10 +33,10 @@ func CreateVectorisedWriter(writer any) (N.VectorisedWriter, bool) {
|
|||
case syscall.Conn:
|
||||
rawConn, err := w.SyscallConn()
|
||||
if err == nil {
|
||||
return &SyscallVectorisedWriter{writer, rawConn}, true
|
||||
return &SyscallVectorisedWriter{upstream: writer, rawConn: rawConn}, true
|
||||
}
|
||||
case syscall.RawConn:
|
||||
return &SyscallVectorisedWriter{writer, w}, true
|
||||
return &SyscallVectorisedWriter{upstream: writer, rawConn: w}, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
@ -48,10 +48,10 @@ func CreateVectorisedPacketWriter(writer any) (N.VectorisedPacketWriter, bool) {
|
|||
case syscall.Conn:
|
||||
rawConn, err := w.SyscallConn()
|
||||
if err == nil {
|
||||
return &SyscallVectorisedPacketWriter{writer, rawConn}, true
|
||||
return &SyscallVectorisedPacketWriter{upstream: writer, rawConn: rawConn}, true
|
||||
}
|
||||
case syscall.RawConn:
|
||||
return &SyscallVectorisedPacketWriter{writer, w}, true
|
||||
return &SyscallVectorisedPacketWriter{upstream: writer, rawConn: w}, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
@ -111,6 +111,7 @@ var _ N.VectorisedWriter = (*SyscallVectorisedWriter)(nil)
|
|||
type SyscallVectorisedWriter struct {
|
||||
upstream any
|
||||
rawConn syscall.RawConn
|
||||
syscallVectorisedWriterFields
|
||||
}
|
||||
|
||||
func (w *SyscallVectorisedWriter) Upstream() any {
|
||||
|
@ -126,6 +127,7 @@ var _ N.VectorisedPacketWriter = (*SyscallVectorisedPacketWriter)(nil)
|
|||
type SyscallVectorisedPacketWriter struct {
|
||||
upstream any
|
||||
rawConn syscall.RawConn
|
||||
syscallVectorisedWriterFields
|
||||
}
|
||||
|
||||
func (w *SyscallVectorisedPacketWriter) Upstream() any {
|
||||
|
|
60
common/bufio/vectorised_test.go
Normal file
60
common/bufio/vectorised_test.go
Normal file
|
@ -0,0 +1,60 @@
|
|||
package bufio
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestWriteVectorised(t *testing.T) {
|
||||
t.Parallel()
|
||||
inputConn, outputConn := TCPPipe(t)
|
||||
vectorisedWriter, created := CreateVectorisedWriter(inputConn)
|
||||
require.True(t, created)
|
||||
require.NotNil(t, vectorisedWriter)
|
||||
var bufA [1024]byte
|
||||
var bufB [1024]byte
|
||||
var bufC [2048]byte
|
||||
_, err := io.ReadFull(rand.Reader, bufA[:])
|
||||
require.NoError(t, err)
|
||||
_, err = io.ReadFull(rand.Reader, bufB[:])
|
||||
require.NoError(t, err)
|
||||
copy(bufC[:], bufA[:])
|
||||
copy(bufC[1024:], bufB[:])
|
||||
finish := Timeout(t)
|
||||
_, err = WriteVectorised(vectorisedWriter, [][]byte{bufA[:], bufB[:]})
|
||||
require.NoError(t, err)
|
||||
output := make([]byte, 2048)
|
||||
_, err = io.ReadFull(outputConn, output)
|
||||
finish()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, bufC[:], output)
|
||||
}
|
||||
|
||||
func TestWriteVectorisedPacket(t *testing.T) {
|
||||
t.Parallel()
|
||||
inputConn, outputConn, outputAddr := UDPPipe(t)
|
||||
vectorisedWriter, created := CreateVectorisedPacketWriter(inputConn)
|
||||
require.True(t, created)
|
||||
require.NotNil(t, vectorisedWriter)
|
||||
var bufA [1024]byte
|
||||
var bufB [1024]byte
|
||||
var bufC [2048]byte
|
||||
_, err := io.ReadFull(rand.Reader, bufA[:])
|
||||
require.NoError(t, err)
|
||||
_, err = io.ReadFull(rand.Reader, bufB[:])
|
||||
require.NoError(t, err)
|
||||
copy(bufC[:], bufA[:])
|
||||
copy(bufC[1024:], bufB[:])
|
||||
finish := Timeout(t)
|
||||
_, err = WriteVectorisedPacket(vectorisedWriter, [][]byte{bufA[:], bufB[:]}, outputAddr)
|
||||
require.NoError(t, err)
|
||||
output := make([]byte, 2048)
|
||||
n, _, err := outputConn.ReadFrom(output)
|
||||
finish()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 2048, n)
|
||||
require.Equal(t, bufC[:], output)
|
||||
}
|
|
@ -3,6 +3,8 @@
|
|||
package bufio
|
||||
|
||||
import (
|
||||
"os"
|
||||
"sync"
|
||||
"unsafe"
|
||||
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
|
@ -11,49 +13,81 @@ import (
|
|||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
type syscallVectorisedWriterFields struct {
|
||||
access sync.Mutex
|
||||
iovecList *[]unix.Iovec
|
||||
}
|
||||
|
||||
func (w *SyscallVectorisedWriter) WriteVectorised(buffers []*buf.Buffer) error {
|
||||
w.access.Lock()
|
||||
defer w.access.Unlock()
|
||||
defer buf.ReleaseMulti(buffers)
|
||||
iovecList := make([]unix.Iovec, 0, len(buffers))
|
||||
for _, buffer := range buffers {
|
||||
var iovec unix.Iovec
|
||||
iovec.Base = &buffer.Bytes()[0]
|
||||
iovec.SetLen(buffer.Len())
|
||||
iovecList = append(iovecList, iovec)
|
||||
var iovecList []unix.Iovec
|
||||
if w.iovecList != nil {
|
||||
iovecList = *w.iovecList
|
||||
}
|
||||
iovecList = iovecList[:0]
|
||||
for index, buffer := range buffers {
|
||||
iovecList = append(iovecList, unix.Iovec{Base: &buffer.Bytes()[0]})
|
||||
iovecList[index].SetLen(buffer.Len())
|
||||
}
|
||||
if w.iovecList == nil {
|
||||
w.iovecList = new([]unix.Iovec)
|
||||
}
|
||||
*w.iovecList = iovecList // cache
|
||||
var innerErr unix.Errno
|
||||
err := w.rawConn.Write(func(fd uintptr) (done bool) {
|
||||
//nolint:staticcheck
|
||||
//goland:noinspection GoDeprecation
|
||||
_, _, innerErr = unix.Syscall(unix.SYS_WRITEV, fd, uintptr(unsafe.Pointer(&iovecList[0])), uintptr(len(iovecList)))
|
||||
return innerErr != unix.EAGAIN && innerErr != unix.EWOULDBLOCK
|
||||
})
|
||||
if innerErr != 0 {
|
||||
err = innerErr
|
||||
err = os.NewSyscallError("SYS_WRITEV", innerErr)
|
||||
}
|
||||
for index := range iovecList {
|
||||
iovecList[index] = unix.Iovec{}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (w *SyscallVectorisedPacketWriter) WriteVectorisedPacket(buffers []*buf.Buffer, destination M.Socksaddr) error {
|
||||
w.access.Lock()
|
||||
defer w.access.Unlock()
|
||||
defer buf.ReleaseMulti(buffers)
|
||||
var sockaddr unix.Sockaddr
|
||||
if destination.IsIPv4() {
|
||||
sockaddr = &unix.SockaddrInet4{
|
||||
Port: int(destination.Port),
|
||||
Addr: destination.Addr.As4(),
|
||||
}
|
||||
} else {
|
||||
sockaddr = &unix.SockaddrInet6{
|
||||
Port: int(destination.Port),
|
||||
Addr: destination.Addr.As16(),
|
||||
}
|
||||
var iovecList []unix.Iovec
|
||||
if w.iovecList != nil {
|
||||
iovecList = *w.iovecList
|
||||
}
|
||||
iovecList = iovecList[:0]
|
||||
for index, buffer := range buffers {
|
||||
iovecList = append(iovecList, unix.Iovec{Base: &buffer.Bytes()[0]})
|
||||
iovecList[index].SetLen(buffer.Len())
|
||||
}
|
||||
if w.iovecList == nil {
|
||||
w.iovecList = new([]unix.Iovec)
|
||||
}
|
||||
*w.iovecList = iovecList // cache
|
||||
var innerErr error
|
||||
err := w.rawConn.Write(func(fd uintptr) (done bool) {
|
||||
_, innerErr = unix.SendmsgBuffers(int(fd), buf.ToSliceMulti(buffers), nil, sockaddr, 0)
|
||||
var msg unix.Msghdr
|
||||
name, nameLen := ToSockaddr(destination.AddrPort())
|
||||
msg.Name = (*byte)(name)
|
||||
msg.Namelen = nameLen
|
||||
if len(iovecList) > 0 {
|
||||
msg.Iov = &iovecList[0]
|
||||
msg.SetIovlen(len(iovecList))
|
||||
}
|
||||
_, innerErr = sendmsg(int(fd), &msg, 0)
|
||||
return innerErr != unix.EAGAIN && innerErr != unix.EWOULDBLOCK
|
||||
})
|
||||
if innerErr != nil {
|
||||
err = innerErr
|
||||
}
|
||||
for index := range iovecList {
|
||||
iovecList[index] = unix.Iovec{}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
//go:linkname sendmsg golang.org/x/sys/unix.sendmsg
|
||||
func sendmsg(s int, msg *unix.Msghdr, flags int) (n int, err error)
|
||||
|
|
|
@ -1,62 +1,93 @@
|
|||
package bufio
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
type syscallVectorisedWriterFields struct {
|
||||
access sync.Mutex
|
||||
iovecList *[]windows.WSABuf
|
||||
}
|
||||
|
||||
func (w *SyscallVectorisedWriter) WriteVectorised(buffers []*buf.Buffer) error {
|
||||
w.access.Lock()
|
||||
defer w.access.Unlock()
|
||||
defer buf.ReleaseMulti(buffers)
|
||||
iovecList := make([]*windows.WSABuf, len(buffers))
|
||||
for i, buffer := range buffers {
|
||||
iovecList[i] = &windows.WSABuf{
|
||||
Len: uint32(buffer.Len()),
|
||||
Buf: &buffer.Bytes()[0],
|
||||
}
|
||||
var iovecList []windows.WSABuf
|
||||
if w.iovecList != nil {
|
||||
iovecList = *w.iovecList
|
||||
}
|
||||
iovecList = iovecList[:0]
|
||||
for _, buffer := range buffers {
|
||||
iovecList = append(iovecList, windows.WSABuf{
|
||||
Buf: &buffer.Bytes()[0],
|
||||
Len: uint32(buffer.Len()),
|
||||
})
|
||||
}
|
||||
if w.iovecList == nil {
|
||||
w.iovecList = new([]windows.WSABuf)
|
||||
}
|
||||
*w.iovecList = iovecList // cache
|
||||
var n uint32
|
||||
var innerErr error
|
||||
err := w.rawConn.Write(func(fd uintptr) (done bool) {
|
||||
innerErr = windows.WSASend(windows.Handle(fd), iovecList[0], uint32(len(iovecList)), &n, 0, nil, nil)
|
||||
innerErr = windows.WSASend(windows.Handle(fd), &iovecList[0], uint32(len(iovecList)), &n, 0, nil, nil)
|
||||
return innerErr != windows.WSAEWOULDBLOCK
|
||||
})
|
||||
if innerErr != nil {
|
||||
err = innerErr
|
||||
}
|
||||
for index := range iovecList {
|
||||
iovecList[index] = windows.WSABuf{}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (w *SyscallVectorisedPacketWriter) WriteVectorisedPacket(buffers []*buf.Buffer, destination M.Socksaddr) error {
|
||||
w.access.Lock()
|
||||
defer w.access.Unlock()
|
||||
defer buf.ReleaseMulti(buffers)
|
||||
iovecList := make([]*windows.WSABuf, len(buffers))
|
||||
for i, buffer := range buffers {
|
||||
iovecList[i] = &windows.WSABuf{
|
||||
Len: uint32(buffer.Len()),
|
||||
var iovecList []windows.WSABuf
|
||||
if w.iovecList != nil {
|
||||
iovecList = *w.iovecList
|
||||
}
|
||||
iovecList = iovecList[:0]
|
||||
for _, buffer := range buffers {
|
||||
iovecList = append(iovecList, windows.WSABuf{
|
||||
Buf: &buffer.Bytes()[0],
|
||||
}
|
||||
Len: uint32(buffer.Len()),
|
||||
})
|
||||
}
|
||||
var sockaddr windows.Sockaddr
|
||||
if destination.IsIPv4() {
|
||||
sockaddr = &windows.SockaddrInet4{
|
||||
Port: int(destination.Port),
|
||||
Addr: destination.Addr.As4(),
|
||||
}
|
||||
} else {
|
||||
sockaddr = &windows.SockaddrInet6{
|
||||
Port: int(destination.Port),
|
||||
Addr: destination.Addr.As16(),
|
||||
}
|
||||
if w.iovecList == nil {
|
||||
w.iovecList = new([]windows.WSABuf)
|
||||
}
|
||||
*w.iovecList = iovecList // cache
|
||||
var n uint32
|
||||
var innerErr error
|
||||
err := w.rawConn.Write(func(fd uintptr) (done bool) {
|
||||
innerErr = windows.WSASendto(windows.Handle(fd), iovecList[0], uint32(len(iovecList)), &n, 0, sockaddr, nil, nil)
|
||||
name, nameLen := ToSockaddr(destination.AddrPort())
|
||||
innerErr = windows.WSASendTo(
|
||||
windows.Handle(fd),
|
||||
&iovecList[0],
|
||||
uint32(len(iovecList)),
|
||||
&n,
|
||||
0,
|
||||
(*windows.RawSockaddrAny)(name),
|
||||
nameLen,
|
||||
nil,
|
||||
nil)
|
||||
return innerErr != windows.WSAEWOULDBLOCK
|
||||
})
|
||||
if innerErr != nil {
|
||||
err = innerErr
|
||||
}
|
||||
for index := range iovecList {
|
||||
iovecList[index] = windows.WSABuf{}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
|
57
common/bufio/zsyscall_windows.go
Normal file
57
common/bufio/zsyscall_windows.go
Normal file
|
@ -0,0 +1,57 @@
|
|||
// Code generated by 'go generate'; DO NOT EDIT.
|
||||
|
||||
package bufio
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
var _ unsafe.Pointer
|
||||
|
||||
// Do the interface allocations only once for common
|
||||
// Errno values.
|
||||
const (
|
||||
errnoERROR_IO_PENDING = 997
|
||||
)
|
||||
|
||||
var (
|
||||
errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
|
||||
errERROR_EINVAL error = syscall.EINVAL
|
||||
)
|
||||
|
||||
// errnoErr returns common boxed Errno values, to prevent
|
||||
// allocations at runtime.
|
||||
func errnoErr(e syscall.Errno) error {
|
||||
switch e {
|
||||
case 0:
|
||||
return errERROR_EINVAL
|
||||
case errnoERROR_IO_PENDING:
|
||||
return errERROR_IO_PENDING
|
||||
}
|
||||
// TODO: add more here, after collecting data on the common
|
||||
// error values see on Windows. (perhaps when running
|
||||
// all.bat?)
|
||||
return e
|
||||
}
|
||||
|
||||
var (
|
||||
modws2_32 = windows.NewLazySystemDLL("ws2_32.dll")
|
||||
|
||||
procrecv = modws2_32.NewProc("recv")
|
||||
)
|
||||
|
||||
func recv(s windows.Handle, buf []byte, flags int32) (n int32, err error) {
|
||||
var _p0 *byte
|
||||
if len(buf) > 0 {
|
||||
_p0 = &buf[0]
|
||||
}
|
||||
r0, _, e1 := syscall.Syscall6(procrecv.Addr(), 4, uintptr(s), uintptr(unsafe.Pointer(_p0)), uintptr(len(buf)), uintptr(flags), 0, 0)
|
||||
n = int32(r0)
|
||||
if n == -1 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
|
@ -41,9 +41,9 @@ func (i *Instance) Timeout() time.Duration {
|
|||
return i.timeout
|
||||
}
|
||||
|
||||
func (i *Instance) SetTimeout(timeout time.Duration) {
|
||||
func (i *Instance) SetTimeout(timeout time.Duration) bool {
|
||||
i.timeout = timeout
|
||||
i.Update()
|
||||
return i.Update()
|
||||
}
|
||||
|
||||
func (i *Instance) wait() {
|
||||
|
|
|
@ -13,7 +13,7 @@ import (
|
|||
type PacketConn interface {
|
||||
N.PacketConn
|
||||
Timeout() time.Duration
|
||||
SetTimeout(timeout time.Duration)
|
||||
SetTimeout(timeout time.Duration) bool
|
||||
}
|
||||
|
||||
type TimerPacketConn struct {
|
||||
|
@ -21,13 +21,15 @@ type TimerPacketConn struct {
|
|||
instance *Instance
|
||||
}
|
||||
|
||||
func NewPacketConn(ctx context.Context, conn N.PacketConn, timeout time.Duration) (context.Context, PacketConn) {
|
||||
func NewPacketConn(ctx context.Context, conn N.PacketConn, timeout time.Duration) (context.Context, N.PacketConn) {
|
||||
if timeoutConn, isTimeoutConn := common.Cast[PacketConn](conn); isTimeoutConn {
|
||||
oldTimeout := timeoutConn.Timeout()
|
||||
if timeout < oldTimeout {
|
||||
timeoutConn.SetTimeout(timeout)
|
||||
if oldTimeout > 0 && timeout >= oldTimeout {
|
||||
return ctx, conn
|
||||
}
|
||||
if timeoutConn.SetTimeout(timeout) {
|
||||
return ctx, conn
|
||||
}
|
||||
return ctx, timeoutConn
|
||||
}
|
||||
err := conn.SetReadDeadline(time.Time{})
|
||||
if err == nil {
|
||||
|
@ -58,8 +60,8 @@ func (c *TimerPacketConn) Timeout() time.Duration {
|
|||
return c.instance.Timeout()
|
||||
}
|
||||
|
||||
func (c *TimerPacketConn) SetTimeout(timeout time.Duration) {
|
||||
c.instance.SetTimeout(timeout)
|
||||
func (c *TimerPacketConn) SetTimeout(timeout time.Duration) bool {
|
||||
return c.instance.SetTimeout(timeout)
|
||||
}
|
||||
|
||||
func (c *TimerPacketConn) Close() error {
|
||||
|
|
|
@ -2,6 +2,7 @@ package canceler
|
|||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
|
@ -31,7 +32,7 @@ func (c *TimeoutPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksa
|
|||
for {
|
||||
err = c.PacketConn.SetReadDeadline(time.Now().Add(c.timeout))
|
||||
if err != nil {
|
||||
return M.Socksaddr{}, err
|
||||
return
|
||||
}
|
||||
destination, err = c.PacketConn.ReadPacket(buffer)
|
||||
if err == nil {
|
||||
|
@ -43,7 +44,7 @@ func (c *TimeoutPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksa
|
|||
return
|
||||
}
|
||||
} else {
|
||||
return M.Socksaddr{}, err
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -60,12 +61,13 @@ func (c *TimeoutPacketConn) Timeout() time.Duration {
|
|||
return c.timeout
|
||||
}
|
||||
|
||||
func (c *TimeoutPacketConn) SetTimeout(timeout time.Duration) {
|
||||
func (c *TimeoutPacketConn) SetTimeout(timeout time.Duration) bool {
|
||||
c.timeout = timeout
|
||||
c.PacketConn.SetReadDeadline(time.Now())
|
||||
return c.PacketConn.SetReadDeadline(time.Now()) == nil
|
||||
}
|
||||
|
||||
func (c *TimeoutPacketConn) Close() error {
|
||||
c.cancel(net.ErrClosed)
|
||||
return c.PacketConn.Close()
|
||||
}
|
||||
|
||||
|
|
11
common/clear.go
Normal file
11
common/clear.go
Normal file
|
@ -0,0 +1,11 @@
|
|||
//go:build go1.21
|
||||
|
||||
package common
|
||||
|
||||
func ClearArray[T ~[]E, E any](t T) {
|
||||
clear(t)
|
||||
}
|
||||
|
||||
func ClearMap[T ~map[K]V, K comparable, V any](t T) {
|
||||
clear(t)
|
||||
}
|
16
common/clear_compat.go
Normal file
16
common/clear_compat.go
Normal file
|
@ -0,0 +1,16 @@
|
|||
//go:build !go1.21
|
||||
|
||||
package common
|
||||
|
||||
func ClearArray[T ~[]E, E any](t T) {
|
||||
var defaultValue E
|
||||
for i := range t {
|
||||
t[i] = defaultValue
|
||||
}
|
||||
}
|
||||
|
||||
func ClearMap[T ~map[K]V, K comparable, V any](t T) {
|
||||
for k := range t {
|
||||
delete(t, k)
|
||||
}
|
||||
}
|
|
@ -157,6 +157,18 @@ func IndexIndexed[T any](arr []T, block func(index int, it T) bool) int {
|
|||
return -1
|
||||
}
|
||||
|
||||
func Equal[S ~[]E, E comparable](s1, s2 S) bool {
|
||||
if len(s1) != len(s2) {
|
||||
return false
|
||||
}
|
||||
for i := range s1 {
|
||||
if s1[i] != s2[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
//go:norace
|
||||
func Dup[T any](obj T) T {
|
||||
pointer := uintptr(unsafe.Pointer(&obj))
|
||||
|
@ -268,6 +280,14 @@ func Reverse[T any](arr []T) []T {
|
|||
return arr
|
||||
}
|
||||
|
||||
func ReverseMap[K comparable, V comparable](m map[K]V) map[V]K {
|
||||
ret := make(map[V]K, len(m))
|
||||
for k, v := range m {
|
||||
ret[v] = k
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
func Done(ctx context.Context) bool {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
|
@ -336,6 +356,10 @@ func DefaultValue[T any]() T {
|
|||
return defaultValue
|
||||
}
|
||||
|
||||
func Ptr[T any](obj T) *T {
|
||||
return &obj
|
||||
}
|
||||
|
||||
func Close(closers ...any) error {
|
||||
var retErr error
|
||||
for _, closer := range closers {
|
||||
|
@ -358,22 +382,3 @@ func Close(closers ...any) error {
|
|||
}
|
||||
return retErr
|
||||
}
|
||||
|
||||
type Starter interface {
|
||||
Start() error
|
||||
}
|
||||
|
||||
func Start(starters ...any) error {
|
||||
for _, rawStarter := range starters {
|
||||
if rawStarter == nil {
|
||||
continue
|
||||
}
|
||||
if starter, isStarter := rawStarter.(Starter); isStarter {
|
||||
err := starter.Start()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"reflect"
|
||||
)
|
||||
|
||||
// Deprecated: not used
|
||||
func SelectContext(contextList []context.Context) (int, error) {
|
||||
if len(contextList) == 1 {
|
||||
<-contextList[0].Done()
|
||||
|
|
|
@ -10,7 +10,7 @@ import (
|
|||
|
||||
func BindToInterface(finder InterfaceFinder, interfaceName string, interfaceIndex int) Func {
|
||||
return func(network, address string, conn syscall.RawConn) error {
|
||||
return BindToInterface0(finder, conn, network, address, interfaceName, interfaceIndex)
|
||||
return BindToInterface0(finder, conn, network, address, interfaceName, interfaceIndex, false)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -20,16 +20,16 @@ func BindToInterfaceFunc(finder InterfaceFinder, block func(network string, addr
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return BindToInterface0(finder, conn, network, address, interfaceName, interfaceIndex)
|
||||
return BindToInterface0(finder, conn, network, address, interfaceName, interfaceIndex, false)
|
||||
}
|
||||
}
|
||||
|
||||
func BindToInterface0(finder InterfaceFinder, conn syscall.RawConn, network string, address string, interfaceName string, interfaceIndex int) error {
|
||||
func BindToInterface0(finder InterfaceFinder, conn syscall.RawConn, network string, address string, interfaceName string, interfaceIndex int, preferInterfaceName bool) error {
|
||||
if interfaceName == "" && interfaceIndex == -1 {
|
||||
return E.New("interface not found: ", interfaceName)
|
||||
}
|
||||
if addr := M.ParseSocksaddr(address).Addr; addr.IsValid() && N.IsVirtual(addr) {
|
||||
return nil
|
||||
}
|
||||
return bindToInterface(conn, network, address, finder, interfaceName, interfaceIndex)
|
||||
return bindToInterface(conn, network, address, finder, interfaceName, interfaceIndex, preferInterfaceName)
|
||||
}
|
||||
|
|
|
@ -7,17 +7,17 @@ import (
|
|||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func bindToInterface(conn syscall.RawConn, network string, address string, finder InterfaceFinder, interfaceName string, interfaceIndex int) error {
|
||||
func bindToInterface(conn syscall.RawConn, network string, address string, finder InterfaceFinder, interfaceName string, interfaceIndex int, preferInterfaceName bool) error {
|
||||
return Raw(conn, func(fd uintptr) error {
|
||||
var err error
|
||||
if interfaceIndex == -1 {
|
||||
if finder == nil {
|
||||
return os.ErrInvalid
|
||||
}
|
||||
interfaceIndex, err = finder.InterfaceIndexByName(interfaceName)
|
||||
iif, err := finder.ByName(interfaceName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
interfaceIndex = iif.Index
|
||||
}
|
||||
switch network {
|
||||
case "tcp6", "udp6":
|
||||
|
|
|
@ -1,30 +1,59 @@
|
|||
package control
|
||||
|
||||
import "net"
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"unsafe"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
)
|
||||
|
||||
type InterfaceFinder interface {
|
||||
InterfaceIndexByName(name string) (int, error)
|
||||
InterfaceNameByIndex(index int) (string, error)
|
||||
Update() error
|
||||
Interfaces() []Interface
|
||||
ByName(name string) (*Interface, error)
|
||||
ByIndex(index int) (*Interface, error)
|
||||
ByAddr(addr netip.Addr) (*Interface, error)
|
||||
}
|
||||
|
||||
func DefaultInterfaceFinder() InterfaceFinder {
|
||||
return (*netInterfaceFinder)(nil)
|
||||
type Interface struct {
|
||||
Index int
|
||||
MTU int
|
||||
Name string
|
||||
HardwareAddr net.HardwareAddr
|
||||
Flags net.Flags
|
||||
Addresses []netip.Prefix
|
||||
}
|
||||
|
||||
type netInterfaceFinder struct{}
|
||||
func (i Interface) Equals(other Interface) bool {
|
||||
return i.Index == other.Index &&
|
||||
i.MTU == other.MTU &&
|
||||
i.Name == other.Name &&
|
||||
common.Equal(i.HardwareAddr, other.HardwareAddr) &&
|
||||
i.Flags == other.Flags &&
|
||||
common.Equal(i.Addresses, other.Addresses)
|
||||
}
|
||||
|
||||
func (w *netInterfaceFinder) InterfaceIndexByName(name string) (int, error) {
|
||||
netInterface, err := net.InterfaceByName(name)
|
||||
func (i Interface) NetInterface() net.Interface {
|
||||
return *(*net.Interface)(unsafe.Pointer(&i))
|
||||
}
|
||||
|
||||
func InterfaceFromNet(iif net.Interface) (Interface, error) {
|
||||
ifAddrs, err := iif.Addrs()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
return Interface{}, err
|
||||
}
|
||||
return netInterface.Index, nil
|
||||
return InterfaceFromNetAddrs(iif, common.Map(ifAddrs, M.PrefixFromNet)), nil
|
||||
}
|
||||
|
||||
func (w *netInterfaceFinder) InterfaceNameByIndex(index int) (string, error) {
|
||||
netInterface, err := net.InterfaceByIndex(index)
|
||||
if err != nil {
|
||||
return "", err
|
||||
func InterfaceFromNetAddrs(iif net.Interface, addresses []netip.Prefix) Interface {
|
||||
return Interface{
|
||||
Index: iif.Index,
|
||||
MTU: iif.MTU,
|
||||
Name: iif.Name,
|
||||
HardwareAddr: iif.HardwareAddr,
|
||||
Flags: iif.Flags,
|
||||
Addresses: addresses,
|
||||
}
|
||||
return netInterface.Name, nil
|
||||
}
|
||||
|
|
89
common/control/bind_finder_default.go
Normal file
89
common/control/bind_finder_default.go
Normal file
|
@ -0,0 +1,89 @@
|
|||
package control
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
)
|
||||
|
||||
var _ InterfaceFinder = (*DefaultInterfaceFinder)(nil)
|
||||
|
||||
type DefaultInterfaceFinder struct {
|
||||
interfaces []Interface
|
||||
}
|
||||
|
||||
func NewDefaultInterfaceFinder() *DefaultInterfaceFinder {
|
||||
return &DefaultInterfaceFinder{}
|
||||
}
|
||||
|
||||
func (f *DefaultInterfaceFinder) Update() error {
|
||||
netIfs, err := net.Interfaces()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
interfaces := make([]Interface, 0, len(netIfs))
|
||||
for _, netIf := range netIfs {
|
||||
var iif Interface
|
||||
iif, err = InterfaceFromNet(netIf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
interfaces = append(interfaces, iif)
|
||||
}
|
||||
f.interfaces = interfaces
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *DefaultInterfaceFinder) UpdateInterfaces(interfaces []Interface) {
|
||||
f.interfaces = interfaces
|
||||
}
|
||||
|
||||
func (f *DefaultInterfaceFinder) Interfaces() []Interface {
|
||||
return f.interfaces
|
||||
}
|
||||
|
||||
func (f *DefaultInterfaceFinder) ByName(name string) (*Interface, error) {
|
||||
for _, netInterface := range f.interfaces {
|
||||
if netInterface.Name == name {
|
||||
return &netInterface, nil
|
||||
}
|
||||
}
|
||||
_, err := net.InterfaceByName(name)
|
||||
if err == nil {
|
||||
err = f.Update()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return f.ByName(name)
|
||||
}
|
||||
return nil, &net.OpError{Op: "route", Net: "ip+net", Source: nil, Addr: &net.IPAddr{IP: nil}, Err: E.New("no such network interface")}
|
||||
}
|
||||
|
||||
func (f *DefaultInterfaceFinder) ByIndex(index int) (*Interface, error) {
|
||||
for _, netInterface := range f.interfaces {
|
||||
if netInterface.Index == index {
|
||||
return &netInterface, nil
|
||||
}
|
||||
}
|
||||
_, err := net.InterfaceByIndex(index)
|
||||
if err == nil {
|
||||
err = f.Update()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return f.ByIndex(index)
|
||||
}
|
||||
return nil, &net.OpError{Op: "route", Net: "ip+net", Source: nil, Addr: &net.IPAddr{IP: nil}, Err: E.New("no such network interface")}
|
||||
}
|
||||
|
||||
func (f *DefaultInterfaceFinder) ByAddr(addr netip.Addr) (*Interface, error) {
|
||||
for _, netInterface := range f.interfaces {
|
||||
for _, prefix := range netInterface.Addresses {
|
||||
if prefix.Contains(addr) {
|
||||
return &netInterface, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil, &net.OpError{Op: "route", Net: "ip+net", Source: nil, Addr: &net.IPAddr{IP: addr.AsSlice()}, Err: E.New("no such network interface")}
|
||||
}
|
|
@ -12,20 +12,20 @@ import (
|
|||
|
||||
var ifIndexDisabled atomic.Bool
|
||||
|
||||
func bindToInterface(conn syscall.RawConn, network string, address string, finder InterfaceFinder, interfaceName string, interfaceIndex int) error {
|
||||
func bindToInterface(conn syscall.RawConn, network string, address string, finder InterfaceFinder, interfaceName string, interfaceIndex int, preferInterfaceName bool) error {
|
||||
return Raw(conn, func(fd uintptr) error {
|
||||
var err error
|
||||
if !ifIndexDisabled.Load() {
|
||||
if !preferInterfaceName && !ifIndexDisabled.Load() {
|
||||
if interfaceIndex == -1 {
|
||||
if finder == nil {
|
||||
if interfaceName == "" {
|
||||
return os.ErrInvalid
|
||||
}
|
||||
interfaceIndex, err = finder.InterfaceIndexByName(interfaceName)
|
||||
iif, err := finder.ByName(interfaceName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
interfaceIndex = iif.Index
|
||||
}
|
||||
err = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_BINDTOIFINDEX, interfaceIndex)
|
||||
err := unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_BINDTOIFINDEX, interfaceIndex)
|
||||
if err == nil {
|
||||
return nil
|
||||
} else if E.IsMulti(err, unix.ENOPROTOOPT, unix.EINVAL) {
|
||||
|
@ -35,13 +35,7 @@ func bindToInterface(conn syscall.RawConn, network string, address string, finde
|
|||
}
|
||||
}
|
||||
if interfaceName == "" {
|
||||
if finder == nil {
|
||||
return os.ErrInvalid
|
||||
}
|
||||
interfaceName, err = finder.InterfaceNameByIndex(interfaceIndex)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return os.ErrInvalid
|
||||
}
|
||||
return unix.BindToDevice(int(fd), interfaceName)
|
||||
})
|
||||
|
|
|
@ -4,6 +4,6 @@ package control
|
|||
|
||||
import "syscall"
|
||||
|
||||
func bindToInterface(conn syscall.RawConn, network string, address string, finder InterfaceFinder, interfaceName string, interfaceIndex int) error {
|
||||
func bindToInterface(conn syscall.RawConn, network string, address string, finder InterfaceFinder, interfaceName string, interfaceIndex int, preferInterfaceName bool) error {
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -9,21 +9,21 @@ import (
|
|||
M "github.com/sagernet/sing/common/metadata"
|
||||
)
|
||||
|
||||
func bindToInterface(conn syscall.RawConn, network string, address string, finder InterfaceFinder, interfaceName string, interfaceIndex int) error {
|
||||
func bindToInterface(conn syscall.RawConn, network string, address string, finder InterfaceFinder, interfaceName string, interfaceIndex int, preferInterfaceName bool) error {
|
||||
return Raw(conn, func(fd uintptr) error {
|
||||
var err error
|
||||
if interfaceIndex == -1 {
|
||||
if finder == nil {
|
||||
return os.ErrInvalid
|
||||
}
|
||||
interfaceIndex, err = finder.InterfaceIndexByName(interfaceName)
|
||||
iif, err := finder.ByName(interfaceName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
interfaceIndex = iif.Index
|
||||
}
|
||||
handle := syscall.Handle(fd)
|
||||
if M.ParseSocksaddr(address).AddrString() == "" {
|
||||
err = bind4(handle, interfaceIndex)
|
||||
err := bind4(handle, interfaceIndex)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -4,19 +4,26 @@ import (
|
|||
"os"
|
||||
"syscall"
|
||||
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func DisableUDPFragment() Func {
|
||||
return func(network, address string, conn syscall.RawConn) error {
|
||||
if N.NetworkName(network) != N.NetworkUDP {
|
||||
return nil
|
||||
}
|
||||
return Raw(conn, func(fd uintptr) error {
|
||||
switch network {
|
||||
case "udp4":
|
||||
if err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_DONTFRAG, 1); err != nil {
|
||||
if network == "udp" || network == "udp4" {
|
||||
err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_DONTFRAG, 1)
|
||||
if err != nil {
|
||||
return os.NewSyscallError("SETSOCKOPT IP_DONTFRAG", err)
|
||||
}
|
||||
case "udp6":
|
||||
if err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_DONTFRAG, 1); err != nil {
|
||||
}
|
||||
if network == "udp" || network == "udp6" {
|
||||
err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_DONTFRAG, 1)
|
||||
if err != nil {
|
||||
return os.NewSyscallError("SETSOCKOPT IPV6_DONTFRAG", err)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -11,17 +11,19 @@ import (
|
|||
|
||||
func DisableUDPFragment() Func {
|
||||
return func(network, address string, conn syscall.RawConn) error {
|
||||
switch N.NetworkName(network) {
|
||||
case N.NetworkUDP:
|
||||
default:
|
||||
if N.NetworkName(network) != N.NetworkUDP {
|
||||
return nil
|
||||
}
|
||||
return Raw(conn, func(fd uintptr) error {
|
||||
if err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_MTU_DISCOVER, unix.IP_PMTUDISC_DO); err != nil {
|
||||
return os.NewSyscallError("SETSOCKOPT IP_MTU_DISCOVER IP_PMTUDISC_DO", err)
|
||||
if network == "udp" || network == "udp4" {
|
||||
err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_MTU_DISCOVER, unix.IP_PMTUDISC_DO)
|
||||
if err != nil {
|
||||
return os.NewSyscallError("SETSOCKOPT IP_MTU_DISCOVER IP_PMTUDISC_DO", err)
|
||||
}
|
||||
}
|
||||
if network == "udp6" {
|
||||
if err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_MTU_DISCOVER, unix.IP_PMTUDISC_DO); err != nil {
|
||||
if network == "udp" || network == "udp6" {
|
||||
err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_MTU_DISCOVER, unix.IP_PMTUDISC_DO)
|
||||
if err != nil {
|
||||
return os.NewSyscallError("SETSOCKOPT IPV6_MTU_DISCOVER IP_PMTUDISC_DO", err)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -25,17 +25,19 @@ const (
|
|||
|
||||
func DisableUDPFragment() Func {
|
||||
return func(network, address string, conn syscall.RawConn) error {
|
||||
switch N.NetworkName(network) {
|
||||
case N.NetworkUDP:
|
||||
default:
|
||||
if N.NetworkName(network) != N.NetworkUDP {
|
||||
return nil
|
||||
}
|
||||
return Raw(conn, func(fd uintptr) error {
|
||||
if err := windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IP, IP_MTU_DISCOVER, IP_PMTUDISC_DO); err != nil {
|
||||
return os.NewSyscallError("SETSOCKOPT IP_MTU_DISCOVER IP_PMTUDISC_DO", err)
|
||||
if network == "udp" || network == "udp4" {
|
||||
err := windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IP, IP_MTU_DISCOVER, IP_PMTUDISC_DO)
|
||||
if err != nil {
|
||||
return os.NewSyscallError("SETSOCKOPT IP_MTU_DISCOVER IP_PMTUDISC_DO", err)
|
||||
}
|
||||
}
|
||||
if network == "udp6" {
|
||||
if err := windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IPV6, IPV6_MTU_DISCOVER, IP_PMTUDISC_DO); err != nil {
|
||||
if network == "udp" || network == "udp6" {
|
||||
err := windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IPV6, IPV6_MTU_DISCOVER, IP_PMTUDISC_DO)
|
||||
if err != nil {
|
||||
return os.NewSyscallError("SETSOCKOPT IPV6_MTU_DISCOVER IP_PMTUDISC_DO", err)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,6 +3,7 @@ package control
|
|||
import (
|
||||
"syscall"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
)
|
||||
|
||||
|
@ -30,6 +31,14 @@ func Conn(conn syscall.Conn, block func(fd uintptr) error) error {
|
|||
return Raw(rawConn, block)
|
||||
}
|
||||
|
||||
func Conn0[T any](conn syscall.Conn, block func(fd uintptr) (T, error)) (T, error) {
|
||||
rawConn, err := conn.SyscallConn()
|
||||
if err != nil {
|
||||
return common.DefaultValue[T](), err
|
||||
}
|
||||
return Raw0[T](rawConn, block)
|
||||
}
|
||||
|
||||
func Raw(rawConn syscall.RawConn, block func(fd uintptr) error) error {
|
||||
var innerErr error
|
||||
err := rawConn.Control(func(fd uintptr) {
|
||||
|
@ -37,3 +46,14 @@ func Raw(rawConn syscall.RawConn, block func(fd uintptr) error) error {
|
|||
})
|
||||
return E.Errors(innerErr, err)
|
||||
}
|
||||
|
||||
func Raw0[T any](rawConn syscall.RawConn, block func(fd uintptr) (T, error)) (T, error) {
|
||||
var (
|
||||
value T
|
||||
innerErr error
|
||||
)
|
||||
err := rawConn.Control(func(fd uintptr) {
|
||||
value, innerErr = block(fd)
|
||||
})
|
||||
return value, E.Errors(innerErr, err)
|
||||
}
|
||||
|
|
|
@ -4,10 +4,10 @@ import (
|
|||
"syscall"
|
||||
)
|
||||
|
||||
func RoutingMark(mark int) Func {
|
||||
func RoutingMark(mark uint32) Func {
|
||||
return func(network, address string, conn syscall.RawConn) error {
|
||||
return Raw(conn, func(fd uintptr) error {
|
||||
return syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_MARK, mark)
|
||||
return syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_MARK, int(mark))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2,6 +2,6 @@
|
|||
|
||||
package control
|
||||
|
||||
func RoutingMark(mark int) Func {
|
||||
func RoutingMark(mark uint32) Func {
|
||||
return nil
|
||||
}
|
||||
|
|
58
common/control/redirect_darwin.go
Normal file
58
common/control/redirect_darwin.go
Normal file
|
@ -0,0 +1,58 @@
|
|||
package control
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"net"
|
||||
"net/netip"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
const (
|
||||
PF_OUT = 0x2
|
||||
DIOCNATLOOK = 0xc0544417
|
||||
)
|
||||
|
||||
func GetOriginalDestination(conn net.Conn) (netip.AddrPort, error) {
|
||||
pfFd, err := syscall.Open("/dev/pf", 0, syscall.O_RDONLY)
|
||||
if err != nil {
|
||||
return netip.AddrPort{}, err
|
||||
}
|
||||
defer syscall.Close(pfFd)
|
||||
nl := struct {
|
||||
saddr, daddr, rsaddr, rdaddr [16]byte
|
||||
sxport, dxport, rsxport, rdxport [4]byte
|
||||
af, proto, protoVariant, direction uint8
|
||||
}{
|
||||
af: syscall.AF_INET,
|
||||
proto: syscall.IPPROTO_TCP,
|
||||
direction: PF_OUT,
|
||||
}
|
||||
localAddr := M.SocksaddrFromNet(conn.LocalAddr()).Unwrap()
|
||||
removeAddr := M.SocksaddrFromNet(conn.RemoteAddr()).Unwrap()
|
||||
if localAddr.IsIPv4() {
|
||||
copy(nl.saddr[:net.IPv4len], removeAddr.Addr.AsSlice())
|
||||
copy(nl.daddr[:net.IPv4len], localAddr.Addr.AsSlice())
|
||||
nl.af = syscall.AF_INET
|
||||
} else {
|
||||
copy(nl.saddr[:], removeAddr.Addr.AsSlice())
|
||||
copy(nl.daddr[:], localAddr.Addr.AsSlice())
|
||||
nl.af = syscall.AF_INET6
|
||||
}
|
||||
binary.BigEndian.PutUint16(nl.sxport[:], removeAddr.Port)
|
||||
binary.BigEndian.PutUint16(nl.dxport[:], localAddr.Port)
|
||||
if _, _, errno := unix.Syscall(syscall.SYS_IOCTL, uintptr(pfFd), DIOCNATLOOK, uintptr(unsafe.Pointer(&nl))); errno != 0 {
|
||||
return netip.AddrPort{}, errno
|
||||
}
|
||||
var address netip.Addr
|
||||
if nl.af == unix.AF_INET {
|
||||
address = M.AddrFromIP(nl.rdaddr[:net.IPv4len])
|
||||
} else {
|
||||
address = netip.AddrFrom16(nl.rdaddr)
|
||||
}
|
||||
return netip.AddrPortFrom(address, binary.BigEndian.Uint16(nl.rdxport[:])), nil
|
||||
}
|
38
common/control/redirect_linux.go
Normal file
38
common/control/redirect_linux.go
Normal file
|
@ -0,0 +1,38 @@
|
|||
package control
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"syscall"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func GetOriginalDestination(conn net.Conn) (netip.AddrPort, error) {
|
||||
syscallConn, loaded := common.Cast[syscall.Conn](conn)
|
||||
if !loaded {
|
||||
return netip.AddrPort{}, os.ErrInvalid
|
||||
}
|
||||
return Conn0[netip.AddrPort](syscallConn, func(fd uintptr) (netip.AddrPort, error) {
|
||||
if M.SocksaddrFromNet(conn.RemoteAddr()).Unwrap().IsIPv4() {
|
||||
raw, err := unix.GetsockoptIPv6Mreq(int(fd), unix.IPPROTO_IP, unix.SO_ORIGINAL_DST)
|
||||
if err != nil {
|
||||
return netip.AddrPort{}, err
|
||||
}
|
||||
return netip.AddrPortFrom(M.AddrFromIP(raw.Multiaddr[4:8]), uint16(raw.Multiaddr[2])<<8+uint16(raw.Multiaddr[3])), nil
|
||||
} else {
|
||||
raw, err := unix.GetsockoptIPv6MTUInfo(int(fd), unix.IPPROTO_IPV6, unix.SO_ORIGINAL_DST)
|
||||
if err != nil {
|
||||
return netip.AddrPort{}, err
|
||||
}
|
||||
var port [2]byte
|
||||
binary.BigEndian.PutUint16(port[:], raw.Addr.Port)
|
||||
return netip.AddrPortFrom(M.AddrFromIP(raw.Addr.Addr[:]), binary.LittleEndian.Uint16(port[:])), nil
|
||||
}
|
||||
})
|
||||
}
|
13
common/control/redirect_other.go
Normal file
13
common/control/redirect_other.go
Normal file
|
@ -0,0 +1,13 @@
|
|||
//go:build !linux && !darwin
|
||||
|
||||
package control
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
)
|
||||
|
||||
func GetOriginalDestination(conn net.Conn) (netip.AddrPort, error) {
|
||||
return netip.AddrPort{}, os.ErrInvalid
|
||||
}
|
30
common/control/tcp_keep_alive_linux.go
Normal file
30
common/control/tcp_keep_alive_linux.go
Normal file
|
@ -0,0 +1,30 @@
|
|||
package control
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
"time"
|
||||
_ "unsafe"
|
||||
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func SetKeepAlivePeriod(idle time.Duration, interval time.Duration) Func {
|
||||
return func(network, address string, conn syscall.RawConn) error {
|
||||
if N.NetworkName(network) != N.NetworkTCP {
|
||||
return nil
|
||||
}
|
||||
return Raw(conn, func(fd uintptr) error {
|
||||
return E.Errors(
|
||||
unix.SetsockoptInt(int(fd), unix.IPPROTO_TCP, unix.TCP_KEEPIDLE, int(roundDurationUp(idle, time.Second))),
|
||||
unix.SetsockoptInt(int(fd), unix.IPPROTO_TCP, unix.TCP_KEEPINTVL, int(roundDurationUp(interval, time.Second))),
|
||||
)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func roundDurationUp(d time.Duration, to time.Duration) time.Duration {
|
||||
return (d + to - 1) / to
|
||||
}
|
11
common/control/tcp_keep_alive_stub.go
Normal file
11
common/control/tcp_keep_alive_stub.go
Normal file
|
@ -0,0 +1,11 @@
|
|||
//go:build !linux
|
||||
|
||||
package control
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
func SetKeepAlivePeriod(idle time.Duration, interval time.Duration) Func {
|
||||
return nil
|
||||
}
|
56
common/control/tproxy_linux.go
Normal file
56
common/control/tproxy_linux.go
Normal file
|
@ -0,0 +1,56 @@
|
|||
package control
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"net/netip"
|
||||
"syscall"
|
||||
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func TProxy(fd uintptr, family int) error {
|
||||
err := syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1)
|
||||
if err == nil {
|
||||
err = syscall.SetsockoptInt(int(fd), syscall.SOL_IP, syscall.IP_TRANSPARENT, 1)
|
||||
}
|
||||
if err == nil && family == unix.AF_INET6 {
|
||||
err = syscall.SetsockoptInt(int(fd), syscall.SOL_IPV6, unix.IPV6_TRANSPARENT, 1)
|
||||
}
|
||||
if err == nil {
|
||||
err = syscall.SetsockoptInt(int(fd), syscall.SOL_IP, syscall.IP_RECVORIGDSTADDR, 1)
|
||||
}
|
||||
if err == nil && family == unix.AF_INET6 {
|
||||
err = syscall.SetsockoptInt(int(fd), syscall.SOL_IPV6, unix.IPV6_RECVORIGDSTADDR, 1)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func TProxyWriteBack() Func {
|
||||
return func(network, address string, conn syscall.RawConn) error {
|
||||
return Raw(conn, func(fd uintptr) error {
|
||||
if M.ParseSocksaddr(address).Addr.Is6() {
|
||||
return syscall.SetsockoptInt(int(fd), syscall.SOL_IPV6, unix.IPV6_TRANSPARENT, 1)
|
||||
} else {
|
||||
return syscall.SetsockoptInt(int(fd), syscall.SOL_IP, syscall.IP_TRANSPARENT, 1)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func GetOriginalDestinationFromOOB(oob []byte) (netip.AddrPort, error) {
|
||||
controlMessages, err := unix.ParseSocketControlMessage(oob)
|
||||
if err != nil {
|
||||
return netip.AddrPort{}, err
|
||||
}
|
||||
for _, message := range controlMessages {
|
||||
if message.Header.Level == unix.SOL_IP && message.Header.Type == unix.IP_RECVORIGDSTADDR {
|
||||
return netip.AddrPortFrom(M.AddrFromIP(message.Data[4:8]), binary.BigEndian.Uint16(message.Data[2:4])), nil
|
||||
} else if message.Header.Level == unix.SOL_IPV6 && message.Header.Type == unix.IPV6_RECVORIGDSTADDR {
|
||||
return netip.AddrPortFrom(M.AddrFromIP(message.Data[8:24]), binary.BigEndian.Uint16(message.Data[2:4])), nil
|
||||
}
|
||||
}
|
||||
return netip.AddrPort{}, E.New("not found")
|
||||
}
|
20
common/control/tproxy_other.go
Normal file
20
common/control/tproxy_other.go
Normal file
|
@ -0,0 +1,20 @@
|
|||
//go:build !linux
|
||||
|
||||
package control
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"os"
|
||||
)
|
||||
|
||||
func TProxy(fd uintptr, isIPv6 bool) error {
|
||||
return os.ErrInvalid
|
||||
}
|
||||
|
||||
func TProxyWriteBack() Func {
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetOriginalDestinationFromOOB(oob []byte) (netip.AddrPort, error) {
|
||||
return netip.AddrPort{}, os.ErrInvalid
|
||||
}
|
67
common/domain/adguard_matcher_test.go
Normal file
67
common/domain/adguard_matcher_test.go
Normal file
|
@ -0,0 +1,67 @@
|
|||
package domain_test
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"testing"
|
||||
|
||||
"github.com/sagernet/sing/common/domain"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAdGuardMatcher(t *testing.T) {
|
||||
t.Parallel()
|
||||
ruleLines := []string{
|
||||
"||example.org^",
|
||||
"|example.com^",
|
||||
"example.net^",
|
||||
"||example.edu",
|
||||
"||example.edu.tw^",
|
||||
"|example.gov",
|
||||
"example.arpa",
|
||||
}
|
||||
matcher := domain.NewAdGuardMatcher(ruleLines)
|
||||
require.NotNil(t, matcher)
|
||||
matchDomain := []string{
|
||||
"example.org",
|
||||
"www.example.org",
|
||||
"example.com",
|
||||
"example.net",
|
||||
"isexample.net",
|
||||
"www.example.net",
|
||||
"example.edu",
|
||||
"example.edu.cn",
|
||||
"example.edu.tw",
|
||||
"www.example.edu",
|
||||
"www.example.edu.cn",
|
||||
"example.gov",
|
||||
"example.gov.cn",
|
||||
"example.arpa",
|
||||
"www.example.arpa",
|
||||
"isexample.arpa",
|
||||
"example.arpa.cn",
|
||||
"www.example.arpa.cn",
|
||||
"isexample.arpa.cn",
|
||||
}
|
||||
notMatchDomain := []string{
|
||||
"example.org.cn",
|
||||
"notexample.org",
|
||||
"example.com.cn",
|
||||
"www.example.com.cn",
|
||||
"example.net.cn",
|
||||
"notexample.edu",
|
||||
"notexample.edu.cn",
|
||||
"www.example.gov",
|
||||
"notexample.gov",
|
||||
}
|
||||
for _, domain := range matchDomain {
|
||||
require.True(t, matcher.Match(domain), domain)
|
||||
}
|
||||
for _, domain := range notMatchDomain {
|
||||
require.False(t, matcher.Match(domain), domain)
|
||||
}
|
||||
dLines := matcher.Dump()
|
||||
sort.Strings(ruleLines)
|
||||
sort.Strings(dLines)
|
||||
require.Equal(t, ruleLines, dLines)
|
||||
}
|
172
common/domain/adgurad_matcher.go
Normal file
172
common/domain/adgurad_matcher.go
Normal file
|
@ -0,0 +1,172 @@
|
|||
package domain
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/varbin"
|
||||
)
|
||||
|
||||
const (
|
||||
anyLabel = '*'
|
||||
suffixLabel = '\b'
|
||||
)
|
||||
|
||||
type AdGuardMatcher struct {
|
||||
set *succinctSet
|
||||
}
|
||||
|
||||
func NewAdGuardMatcher(ruleLines []string) *AdGuardMatcher {
|
||||
ruleList := make([]string, 0, len(ruleLines))
|
||||
for _, ruleLine := range ruleLines {
|
||||
var (
|
||||
isSuffix bool // ||
|
||||
hasStart bool // |
|
||||
hasEnd bool // ^
|
||||
)
|
||||
if strings.HasPrefix(ruleLine, "||") {
|
||||
ruleLine = ruleLine[2:]
|
||||
isSuffix = true
|
||||
} else if strings.HasPrefix(ruleLine, "|") {
|
||||
ruleLine = ruleLine[1:]
|
||||
hasStart = true
|
||||
}
|
||||
if strings.HasSuffix(ruleLine, "^") {
|
||||
ruleLine = ruleLine[:len(ruleLine)-1]
|
||||
hasEnd = true
|
||||
}
|
||||
if isSuffix {
|
||||
ruleLine = string(rootLabel) + ruleLine
|
||||
} else if !hasStart {
|
||||
ruleLine = string(prefixLabel) + ruleLine
|
||||
}
|
||||
if !hasEnd {
|
||||
if strings.HasSuffix(ruleLine, ".") {
|
||||
ruleLine = ruleLine[:len(ruleLine)-1]
|
||||
}
|
||||
ruleLine += string(suffixLabel)
|
||||
}
|
||||
ruleList = append(ruleList, reverseDomain(ruleLine))
|
||||
}
|
||||
ruleList = common.Uniq(ruleList)
|
||||
sort.Strings(ruleList)
|
||||
return &AdGuardMatcher{newSuccinctSet(ruleList)}
|
||||
}
|
||||
|
||||
func ReadAdGuardMatcher(reader varbin.Reader) (*AdGuardMatcher, error) {
|
||||
set, err := readSuccinctSet(reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &AdGuardMatcher{set}, nil
|
||||
}
|
||||
|
||||
func (m *AdGuardMatcher) Write(writer varbin.Writer) error {
|
||||
return m.set.Write(writer)
|
||||
}
|
||||
|
||||
func (m *AdGuardMatcher) Match(domain string) bool {
|
||||
key := reverseDomain(domain)
|
||||
if m.has([]byte(key), 0, 0) {
|
||||
return true
|
||||
}
|
||||
for {
|
||||
if m.has([]byte(string(suffixLabel)+key), 0, 0) {
|
||||
return true
|
||||
}
|
||||
idx := strings.IndexByte(key, '.')
|
||||
if idx == -1 {
|
||||
return false
|
||||
}
|
||||
key = key[idx+1:]
|
||||
}
|
||||
}
|
||||
|
||||
func (m *AdGuardMatcher) has(key []byte, nodeId, bmIdx int) bool {
|
||||
for i := 0; i < len(key); i++ {
|
||||
currentChar := key[i]
|
||||
for ; ; bmIdx++ {
|
||||
if getBit(m.set.labelBitmap, bmIdx) != 0 {
|
||||
return false
|
||||
}
|
||||
nextLabel := m.set.labels[bmIdx-nodeId]
|
||||
if nextLabel == prefixLabel {
|
||||
return true
|
||||
}
|
||||
if nextLabel == rootLabel {
|
||||
nextNodeId := countZeros(m.set.labelBitmap, m.set.ranks, bmIdx+1)
|
||||
hasNext := getBit(m.set.leaves, nextNodeId) != 0
|
||||
if currentChar == '.' && hasNext {
|
||||
return true
|
||||
}
|
||||
}
|
||||
if nextLabel == currentChar {
|
||||
break
|
||||
}
|
||||
if nextLabel == anyLabel {
|
||||
idx := bytes.IndexRune(key[i:], '.')
|
||||
nextNodeId := countZeros(m.set.labelBitmap, m.set.ranks, bmIdx+1)
|
||||
if idx == -1 {
|
||||
if getBit(m.set.leaves, nextNodeId) != 0 {
|
||||
return true
|
||||
}
|
||||
idx = 0
|
||||
}
|
||||
nextBmIdx := selectIthOne(m.set.labelBitmap, m.set.ranks, m.set.selects, nextNodeId-1) + 1
|
||||
if m.has(key[i+idx:], nextNodeId, nextBmIdx) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
nodeId = countZeros(m.set.labelBitmap, m.set.ranks, bmIdx+1)
|
||||
bmIdx = selectIthOne(m.set.labelBitmap, m.set.ranks, m.set.selects, nodeId-1) + 1
|
||||
}
|
||||
if getBit(m.set.leaves, nodeId) != 0 {
|
||||
return true
|
||||
}
|
||||
for ; ; bmIdx++ {
|
||||
if getBit(m.set.labelBitmap, bmIdx) != 0 {
|
||||
return false
|
||||
}
|
||||
nextLabel := m.set.labels[bmIdx-nodeId]
|
||||
if nextLabel == prefixLabel || nextLabel == rootLabel {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *AdGuardMatcher) Dump() (ruleLines []string) {
|
||||
for _, key := range m.set.keys() {
|
||||
key = reverseDomain(key)
|
||||
var (
|
||||
isSuffix bool
|
||||
hasStart bool
|
||||
hasEnd bool
|
||||
)
|
||||
if key[0] == prefixLabel {
|
||||
key = key[1:]
|
||||
} else if key[0] == rootLabel {
|
||||
key = key[1:]
|
||||
isSuffix = true
|
||||
} else {
|
||||
hasStart = true
|
||||
}
|
||||
if key[len(key)-1] == suffixLabel {
|
||||
key = key[:len(key)-1]
|
||||
} else {
|
||||
hasEnd = true
|
||||
}
|
||||
if isSuffix {
|
||||
key = "||" + key
|
||||
} else if hasStart {
|
||||
key = "|" + key
|
||||
}
|
||||
if hasEnd {
|
||||
key += "^"
|
||||
}
|
||||
ruleLines = append(ruleLines, key)
|
||||
}
|
||||
return
|
||||
}
|
|
@ -3,21 +3,39 @@ package domain
|
|||
import (
|
||||
"sort"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/sagernet/sing/common/varbin"
|
||||
)
|
||||
|
||||
const (
|
||||
prefixLabel = '\r'
|
||||
rootLabel = '\n'
|
||||
)
|
||||
|
||||
type Matcher struct {
|
||||
set *succinctSet
|
||||
}
|
||||
|
||||
func NewMatcher(domains []string, domainSuffix []string) *Matcher {
|
||||
domainList := make([]string, 0, len(domains)+len(domainSuffix))
|
||||
func NewMatcher(domains []string, domainSuffix []string, generateLegacy bool) *Matcher {
|
||||
domainList := make([]string, 0, len(domains)+2*len(domainSuffix))
|
||||
seen := make(map[string]bool, len(domainList))
|
||||
for _, domain := range domainSuffix {
|
||||
if seen[domain] {
|
||||
continue
|
||||
}
|
||||
seen[domain] = true
|
||||
domainList = append(domainList, reverseDomainSuffix(domain))
|
||||
if domain[0] == '.' {
|
||||
domainList = append(domainList, reverseDomain(string(prefixLabel)+domain))
|
||||
} else if generateLegacy {
|
||||
domainList = append(domainList, reverseDomain(domain))
|
||||
suffixDomain := "." + domain
|
||||
if !seen[suffixDomain] {
|
||||
seen[suffixDomain] = true
|
||||
domainList = append(domainList, reverseDomain(string(prefixLabel)+suffixDomain))
|
||||
}
|
||||
} else {
|
||||
domainList = append(domainList, reverseDomain(string(rootLabel)+domain))
|
||||
}
|
||||
}
|
||||
for _, domain := range domains {
|
||||
if seen[domain] {
|
||||
|
@ -27,13 +45,94 @@ func NewMatcher(domains []string, domainSuffix []string) *Matcher {
|
|||
domainList = append(domainList, reverseDomain(domain))
|
||||
}
|
||||
sort.Strings(domainList)
|
||||
return &Matcher{
|
||||
newSuccinctSet(domainList),
|
||||
return &Matcher{newSuccinctSet(domainList)}
|
||||
}
|
||||
|
||||
func ReadMatcher(reader varbin.Reader) (*Matcher, error) {
|
||||
set, err := readSuccinctSet(reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Matcher{set}, nil
|
||||
}
|
||||
|
||||
func (m *Matcher) Write(writer varbin.Writer) error {
|
||||
return m.set.Write(writer)
|
||||
}
|
||||
|
||||
func (m *Matcher) Match(domain string) bool {
|
||||
return m.set.Has(reverseDomain(domain))
|
||||
return m.has(reverseDomain(domain))
|
||||
}
|
||||
|
||||
func (m *Matcher) has(key string) bool {
|
||||
var nodeId, bmIdx int
|
||||
for i := 0; i < len(key); i++ {
|
||||
currentChar := key[i]
|
||||
for ; ; bmIdx++ {
|
||||
if getBit(m.set.labelBitmap, bmIdx) != 0 {
|
||||
return false
|
||||
}
|
||||
nextLabel := m.set.labels[bmIdx-nodeId]
|
||||
if nextLabel == prefixLabel {
|
||||
return true
|
||||
}
|
||||
if nextLabel == rootLabel {
|
||||
nextNodeId := countZeros(m.set.labelBitmap, m.set.ranks, bmIdx+1)
|
||||
hasNext := getBit(m.set.leaves, nextNodeId) != 0
|
||||
if currentChar == '.' && hasNext {
|
||||
return true
|
||||
}
|
||||
}
|
||||
if nextLabel == currentChar {
|
||||
break
|
||||
}
|
||||
}
|
||||
nodeId = countZeros(m.set.labelBitmap, m.set.ranks, bmIdx+1)
|
||||
bmIdx = selectIthOne(m.set.labelBitmap, m.set.ranks, m.set.selects, nodeId-1) + 1
|
||||
}
|
||||
if getBit(m.set.leaves, nodeId) != 0 {
|
||||
return true
|
||||
}
|
||||
for ; ; bmIdx++ {
|
||||
if getBit(m.set.labelBitmap, bmIdx) != 0 {
|
||||
return false
|
||||
}
|
||||
nextLabel := m.set.labels[bmIdx-nodeId]
|
||||
if nextLabel == prefixLabel || nextLabel == rootLabel {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Matcher) Dump() (domainList []string, prefixList []string) {
|
||||
domainMap := make(map[string]bool)
|
||||
prefixMap := make(map[string]bool)
|
||||
for _, key := range m.set.keys() {
|
||||
key = reverseDomain(key)
|
||||
if key[0] == prefixLabel {
|
||||
prefixMap[key[1:]] = true
|
||||
} else if key[0] == rootLabel {
|
||||
prefixList = append(prefixList, key[1:])
|
||||
} else {
|
||||
domainMap[key] = true
|
||||
}
|
||||
}
|
||||
for rawPrefix := range prefixMap {
|
||||
if rawPrefix[0] == '.' {
|
||||
if rootDomain := rawPrefix[1:]; domainMap[rootDomain] {
|
||||
delete(domainMap, rootDomain)
|
||||
prefixList = append(prefixList, rootDomain)
|
||||
continue
|
||||
}
|
||||
}
|
||||
prefixList = append(prefixList, rawPrefix)
|
||||
}
|
||||
for domain := range domainMap {
|
||||
domainList = append(domainList, domain)
|
||||
}
|
||||
sort.Strings(domainList)
|
||||
sort.Strings(prefixList)
|
||||
return domainList, prefixList
|
||||
}
|
||||
|
||||
func reverseDomain(domain string) string {
|
||||
|
@ -46,15 +145,3 @@ func reverseDomain(domain string) string {
|
|||
}
|
||||
return string(b)
|
||||
}
|
||||
|
||||
func reverseDomainSuffix(domain string) string {
|
||||
l := len(domain)
|
||||
b := make([]byte, l+1)
|
||||
for i := 0; i < l; {
|
||||
r, n := utf8.DecodeRuneInString(domain[i:])
|
||||
i += n
|
||||
utf8.EncodeRune(b[l-i:], r)
|
||||
}
|
||||
b[l] = prefixLabel
|
||||
return string(b)
|
||||
}
|
||||
|
|
80
common/domain/matcher_test.go
Normal file
80
common/domain/matcher_test.go
Normal file
|
@ -0,0 +1,80 @@
|
|||
package domain_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"sort"
|
||||
"testing"
|
||||
|
||||
"github.com/sagernet/sing/common/domain"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestMatcher(t *testing.T) {
|
||||
t.Parallel()
|
||||
testDomain := []string{"example.com", "example.org"}
|
||||
testDomainSuffix := []string{".com.cn", ".org.cn", "sagernet.org"}
|
||||
matcher := domain.NewMatcher(testDomain, testDomainSuffix, false)
|
||||
require.NotNil(t, matcher)
|
||||
require.True(t, matcher.Match("example.com"))
|
||||
require.True(t, matcher.Match("example.org"))
|
||||
require.False(t, matcher.Match("example.cn"))
|
||||
require.True(t, matcher.Match("example.com.cn"))
|
||||
require.True(t, matcher.Match("example.org.cn"))
|
||||
require.False(t, matcher.Match("com.cn"))
|
||||
require.False(t, matcher.Match("org.cn"))
|
||||
require.True(t, matcher.Match("sagernet.org"))
|
||||
require.True(t, matcher.Match("sing-box.sagernet.org"))
|
||||
dDomain, dDomainSuffix := matcher.Dump()
|
||||
require.Equal(t, testDomain, dDomain)
|
||||
require.Equal(t, testDomainSuffix, dDomainSuffix)
|
||||
}
|
||||
|
||||
func TestMatcherLegacy(t *testing.T) {
|
||||
t.Parallel()
|
||||
testDomain := []string{"example.com", "example.org"}
|
||||
testDomainSuffix := []string{".com.cn", ".org.cn", "sagernet.org"}
|
||||
matcher := domain.NewMatcher(testDomain, testDomainSuffix, true)
|
||||
require.NotNil(t, matcher)
|
||||
require.True(t, matcher.Match("example.com"))
|
||||
require.True(t, matcher.Match("example.org"))
|
||||
require.False(t, matcher.Match("example.cn"))
|
||||
require.True(t, matcher.Match("example.com.cn"))
|
||||
require.True(t, matcher.Match("example.org.cn"))
|
||||
require.False(t, matcher.Match("com.cn"))
|
||||
require.False(t, matcher.Match("org.cn"))
|
||||
require.True(t, matcher.Match("sagernet.org"))
|
||||
require.True(t, matcher.Match("sing-box.sagernet.org"))
|
||||
dDomain, dDomainSuffix := matcher.Dump()
|
||||
require.Equal(t, testDomain, dDomain)
|
||||
require.Equal(t, testDomainSuffix, dDomainSuffix)
|
||||
}
|
||||
|
||||
type simpleRuleSet struct {
|
||||
Rules []struct {
|
||||
Domain []string `json:"domain"`
|
||||
DomainSuffix []string `json:"domain_suffix"`
|
||||
}
|
||||
}
|
||||
|
||||
func TestDumpLarge(t *testing.T) {
|
||||
t.Parallel()
|
||||
response, err := http.Get("https://raw.githubusercontent.com/MetaCubeX/meta-rules-dat/sing/geo/geosite/cn.json")
|
||||
require.NoError(t, err)
|
||||
defer response.Body.Close()
|
||||
var ruleSet simpleRuleSet
|
||||
err = json.NewDecoder(response.Body).Decode(&ruleSet)
|
||||
require.NoError(t, err)
|
||||
domainList := ruleSet.Rules[0].Domain
|
||||
domainSuffixList := ruleSet.Rules[0].DomainSuffix
|
||||
require.Len(t, ruleSet.Rules, 1)
|
||||
require.True(t, len(domainList)+len(domainSuffixList) > 0)
|
||||
sort.Strings(domainList)
|
||||
sort.Strings(domainSuffixList)
|
||||
matcher := domain.NewMatcher(domainList, domainSuffixList, false)
|
||||
require.NotNil(t, matcher)
|
||||
dDomain, dDomainSuffix := matcher.Dump()
|
||||
require.Equal(t, domainList, dDomain)
|
||||
require.Equal(t, domainSuffixList, dDomainSuffix)
|
||||
}
|
|
@ -1,10 +1,11 @@
|
|||
package domain
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"math/bits"
|
||||
)
|
||||
|
||||
const prefixLabel = '\r'
|
||||
"github.com/sagernet/sing/common/varbin"
|
||||
)
|
||||
|
||||
// mod from https://github.com/openacid/succinct
|
||||
|
||||
|
@ -42,36 +43,61 @@ func newSuccinctSet(keys []string) *succinctSet {
|
|||
return ss
|
||||
}
|
||||
|
||||
func (ss *succinctSet) Has(key string) bool {
|
||||
var nodeId, bmIdx int
|
||||
for i := 0; i < len(key); i++ {
|
||||
currentChar := key[i]
|
||||
func (ss *succinctSet) keys() []string {
|
||||
var result []string
|
||||
var currentKey []byte
|
||||
var bmIdx, nodeId int
|
||||
|
||||
var traverse func(int, int)
|
||||
traverse = func(nodeId, bmIdx int) {
|
||||
if getBit(ss.leaves, nodeId) != 0 {
|
||||
result = append(result, string(currentKey))
|
||||
}
|
||||
|
||||
for ; ; bmIdx++ {
|
||||
if getBit(ss.labelBitmap, bmIdx) != 0 {
|
||||
return false
|
||||
return
|
||||
}
|
||||
nextLabel := ss.labels[bmIdx-nodeId]
|
||||
if nextLabel == prefixLabel {
|
||||
return true
|
||||
}
|
||||
if nextLabel == currentChar {
|
||||
break
|
||||
}
|
||||
}
|
||||
nodeId = countZeros(ss.labelBitmap, ss.ranks, bmIdx+1)
|
||||
bmIdx = selectIthOne(ss.labelBitmap, ss.ranks, ss.selects, nodeId-1) + 1
|
||||
}
|
||||
if getBit(ss.leaves, nodeId) != 0 {
|
||||
return true
|
||||
}
|
||||
for ; ; bmIdx++ {
|
||||
if getBit(ss.labelBitmap, bmIdx) != 0 {
|
||||
return false
|
||||
}
|
||||
if ss.labels[bmIdx-nodeId] == prefixLabel {
|
||||
return true
|
||||
currentKey = append(currentKey, nextLabel)
|
||||
nextNodeId := countZeros(ss.labelBitmap, ss.ranks, bmIdx+1)
|
||||
nextBmIdx := selectIthOne(ss.labelBitmap, ss.ranks, ss.selects, nextNodeId-1) + 1
|
||||
traverse(nextNodeId, nextBmIdx)
|
||||
currentKey = currentKey[:len(currentKey)-1]
|
||||
}
|
||||
}
|
||||
|
||||
traverse(nodeId, bmIdx)
|
||||
return result
|
||||
}
|
||||
|
||||
type succinctSetData struct {
|
||||
Reserved uint8
|
||||
Leaves []uint64
|
||||
LabelBitmap []uint64
|
||||
Labels []byte
|
||||
}
|
||||
|
||||
func readSuccinctSet(reader varbin.Reader) (*succinctSet, error) {
|
||||
matcher, err := varbin.ReadValue[succinctSetData](reader, binary.BigEndian)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
set := &succinctSet{
|
||||
leaves: matcher.Leaves,
|
||||
labelBitmap: matcher.LabelBitmap,
|
||||
labels: matcher.Labels,
|
||||
}
|
||||
set.init()
|
||||
return set, nil
|
||||
}
|
||||
|
||||
func (ss *succinctSet) Write(writer varbin.Writer) error {
|
||||
return varbin.Write(writer, binary.BigEndian, succinctSetData{
|
||||
Leaves: ss.leaves,
|
||||
LabelBitmap: ss.labelBitmap,
|
||||
Labels: ss.labels,
|
||||
})
|
||||
}
|
||||
|
||||
func setBit(bm *[]uint64, i int, v int) {
|
||||
|
|
|
@ -12,3 +12,16 @@ func (e *causeError) Error() string {
|
|||
func (e *causeError) Unwrap() error {
|
||||
return e.cause
|
||||
}
|
||||
|
||||
type causeError1 struct {
|
||||
error
|
||||
cause error
|
||||
}
|
||||
|
||||
func (e *causeError1) Error() string {
|
||||
return e.error.Error() + ": " + e.cause.Error()
|
||||
}
|
||||
|
||||
func (e *causeError1) Unwrap() []error {
|
||||
return []error{e.error, e.cause}
|
||||
}
|
||||
|
|
|
@ -12,6 +12,7 @@ import (
|
|||
F "github.com/sagernet/sing/common/format"
|
||||
)
|
||||
|
||||
// Deprecated: wtf is this?
|
||||
type Handler interface {
|
||||
NewError(ctx context.Context, err error)
|
||||
}
|
||||
|
@ -31,6 +32,13 @@ func Cause(cause error, message ...any) error {
|
|||
return &causeError{F.ToString(message...), cause}
|
||||
}
|
||||
|
||||
func Cause1(err error, cause error) error {
|
||||
if cause == nil {
|
||||
panic("cause on an nil error")
|
||||
}
|
||||
return &causeError1{err, cause}
|
||||
}
|
||||
|
||||
func Extend(cause error, message ...any) error {
|
||||
if cause == nil {
|
||||
panic("extend on an nil error")
|
||||
|
@ -39,11 +47,11 @@ func Extend(cause error, message ...any) error {
|
|||
}
|
||||
|
||||
func IsClosedOrCanceled(err error) bool {
|
||||
return IsMulti(err, io.EOF, net.ErrClosed, io.ErrClosedPipe, os.ErrClosed, syscall.EPIPE, syscall.ECONNRESET, context.Canceled, context.DeadlineExceeded)
|
||||
return IsClosed(err) || IsCanceled(err) || IsTimeout(err)
|
||||
}
|
||||
|
||||
func IsClosed(err error) bool {
|
||||
return IsMulti(err, io.EOF, net.ErrClosed, io.ErrClosedPipe, os.ErrClosed, syscall.EPIPE, syscall.ECONNRESET)
|
||||
return IsMulti(err, io.EOF, net.ErrClosed, io.ErrClosedPipe, os.ErrClosed, syscall.EPIPE, syscall.ECONNRESET, syscall.ENOTCONN)
|
||||
}
|
||||
|
||||
func IsCanceled(err error) bool {
|
||||
|
|
|
@ -1,24 +1,14 @@
|
|||
package exceptions
|
||||
|
||||
import "github.com/sagernet/sing/common"
|
||||
import (
|
||||
"errors"
|
||||
|
||||
type HasInnerError interface {
|
||||
Unwrap() error
|
||||
}
|
||||
"github.com/sagernet/sing/common"
|
||||
)
|
||||
|
||||
// Deprecated: Use errors.Unwrap instead.
|
||||
func Unwrap(err error) error {
|
||||
for {
|
||||
inner, ok := err.(HasInnerError)
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
innerErr := inner.Unwrap()
|
||||
if innerErr == nil {
|
||||
break
|
||||
}
|
||||
err = innerErr
|
||||
}
|
||||
return err
|
||||
return errors.Unwrap(err)
|
||||
}
|
||||
|
||||
func Cast[T any](err error) (T, bool) {
|
||||
|
|
|
@ -37,10 +37,13 @@ func Errors(errors ...error) error {
|
|||
}
|
||||
|
||||
func Expand(err error) []error {
|
||||
if multiErr, isMultiErr := err.(MultiError); isMultiErr {
|
||||
return ExpandAll(multiErr.Unwrap())
|
||||
if err == nil {
|
||||
return nil
|
||||
} else if multiErr, isMultiErr := err.(MultiError); isMultiErr {
|
||||
return ExpandAll(common.FilterNotNil(multiErr.Unwrap()))
|
||||
} else {
|
||||
return []error{err}
|
||||
}
|
||||
return []error{err}
|
||||
}
|
||||
|
||||
func ExpandAll(errs []error) []error {
|
||||
|
@ -60,12 +63,5 @@ func IsMulti(err error, targetList ...error) bool {
|
|||
return true
|
||||
}
|
||||
}
|
||||
err = Unwrap(err)
|
||||
multiErr, isMulti := err.(MultiError)
|
||||
if !isMulti {
|
||||
return false
|
||||
}
|
||||
return common.All(multiErr.Unwrap(), func(it error) bool {
|
||||
return IsMulti(it, targetList...)
|
||||
})
|
||||
return false
|
||||
}
|
||||
|
|
|
@ -1,17 +1,21 @@
|
|||
package exceptions
|
||||
|
||||
import "net"
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
)
|
||||
|
||||
type TimeoutError interface {
|
||||
Timeout() bool
|
||||
}
|
||||
|
||||
func IsTimeout(err error) bool {
|
||||
if netErr, isNetErr := err.(net.Error); isNetErr {
|
||||
//goland:noinspection GoDeprecation
|
||||
var netErr net.Error
|
||||
if errors.As(err, &netErr) {
|
||||
//nolint:staticcheck
|
||||
return netErr.Temporary() && netErr.Timeout()
|
||||
} else if timeoutErr, isTimeout := Cast[TimeoutError](err); isTimeout {
|
||||
}
|
||||
if timeoutErr, isTimeout := Cast[TimeoutError](err); isTimeout {
|
||||
return timeoutErr.Timeout()
|
||||
}
|
||||
return false
|
||||
|
|
59
common/json/badjson/array.go
Normal file
59
common/json/badjson/array.go
Normal file
|
@ -0,0 +1,59 @@
|
|||
package badjson
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/json"
|
||||
)
|
||||
|
||||
type JSONArray []any
|
||||
|
||||
func (a JSONArray) IsEmpty() bool {
|
||||
if len(a) == 0 {
|
||||
return true
|
||||
}
|
||||
return common.All(a, func(it any) bool {
|
||||
if valueInterface, valueMaybeEmpty := it.(isEmpty); valueMaybeEmpty && valueInterface.IsEmpty() {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
})
|
||||
}
|
||||
|
||||
func (a JSONArray) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal([]any(a))
|
||||
}
|
||||
|
||||
func (a *JSONArray) UnmarshalJSON(content []byte) error {
|
||||
decoder := json.NewDecoder(bytes.NewReader(content))
|
||||
arrayStart, err := decoder.Token()
|
||||
if err != nil {
|
||||
return err
|
||||
} else if arrayStart != json.Delim('[') {
|
||||
return E.New("excepted array start, but got ", arrayStart)
|
||||
}
|
||||
err = a.decodeJSON(decoder)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
arrayEnd, err := decoder.Token()
|
||||
if err != nil {
|
||||
return err
|
||||
} else if arrayEnd != json.Delim(']') {
|
||||
return E.New("excepted array end, but got ", arrayEnd)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *JSONArray) decodeJSON(decoder *json.Decoder) error {
|
||||
for decoder.More() {
|
||||
item, err := decodeJSON(decoder)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*a = append(*a, item)
|
||||
}
|
||||
return nil
|
||||
}
|
5
common/json/badjson/empty.go
Normal file
5
common/json/badjson/empty.go
Normal file
|
@ -0,0 +1,5 @@
|
|||
package badjson
|
||||
|
||||
type isEmpty interface {
|
||||
IsEmpty() bool
|
||||
}
|
55
common/json/badjson/json.go
Normal file
55
common/json/badjson/json.go
Normal file
|
@ -0,0 +1,55 @@
|
|||
package badjson
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/json"
|
||||
)
|
||||
|
||||
func Decode(ctx context.Context, content []byte) (any, error) {
|
||||
decoder := json.NewDecoderContext(ctx, bytes.NewReader(content))
|
||||
return decodeJSON(decoder)
|
||||
}
|
||||
|
||||
func decodeJSON(decoder *json.Decoder) (any, error) {
|
||||
rawToken, err := decoder.Token()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
switch token := rawToken.(type) {
|
||||
case json.Delim:
|
||||
switch token {
|
||||
case '{':
|
||||
var object JSONObject
|
||||
err = object.decodeJSON(decoder)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rawToken, err = decoder.Token()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if rawToken != json.Delim('}') {
|
||||
return nil, E.New("excepted object end, but got ", rawToken)
|
||||
}
|
||||
return &object, nil
|
||||
case '[':
|
||||
var array JSONArray
|
||||
err = array.decodeJSON(decoder)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rawToken, err = decoder.Token()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if rawToken != json.Delim(']') {
|
||||
return nil, E.New("excepted array end, but got ", rawToken)
|
||||
}
|
||||
return array, nil
|
||||
default:
|
||||
return nil, E.New("excepted object or array end: ", token)
|
||||
}
|
||||
}
|
||||
return rawToken, nil
|
||||
}
|
142
common/json/badjson/merge.go
Normal file
142
common/json/badjson/merge.go
Normal file
|
@ -0,0 +1,142 @@
|
|||
package badjson
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"reflect"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/json"
|
||||
)
|
||||
|
||||
func Omitempty[T any](ctx context.Context, value T) (T, error) {
|
||||
objectContent, err := json.MarshalContext(ctx, value)
|
||||
if err != nil {
|
||||
return common.DefaultValue[T](), E.Cause(err, "marshal object")
|
||||
}
|
||||
rawNewObject, err := Decode(ctx, objectContent)
|
||||
if err != nil {
|
||||
return common.DefaultValue[T](), err
|
||||
}
|
||||
newObjectContent, err := json.MarshalContext(ctx, rawNewObject)
|
||||
if err != nil {
|
||||
return common.DefaultValue[T](), E.Cause(err, "marshal new object")
|
||||
}
|
||||
var newObject T
|
||||
err = json.UnmarshalContext(ctx, newObjectContent, &newObject)
|
||||
if err != nil {
|
||||
return common.DefaultValue[T](), E.Cause(err, "unmarshal new object")
|
||||
}
|
||||
return newObject, nil
|
||||
}
|
||||
|
||||
func Merge[T any](ctx context.Context, source T, destination T, disableAppend bool) (T, error) {
|
||||
rawSource, err := json.MarshalContext(ctx, source)
|
||||
if err != nil {
|
||||
return common.DefaultValue[T](), E.Cause(err, "marshal source")
|
||||
}
|
||||
rawDestination, err := json.MarshalContext(ctx, destination)
|
||||
if err != nil {
|
||||
return common.DefaultValue[T](), E.Cause(err, "marshal destination")
|
||||
}
|
||||
return MergeFrom[T](ctx, rawSource, rawDestination, disableAppend)
|
||||
}
|
||||
|
||||
func MergeFromSource[T any](ctx context.Context, rawSource json.RawMessage, destination T, disableAppend bool) (T, error) {
|
||||
if rawSource == nil {
|
||||
return destination, nil
|
||||
}
|
||||
rawDestination, err := json.MarshalContext(ctx, destination)
|
||||
if err != nil {
|
||||
return common.DefaultValue[T](), E.Cause(err, "marshal destination")
|
||||
}
|
||||
return MergeFrom[T](ctx, rawSource, rawDestination, disableAppend)
|
||||
}
|
||||
|
||||
func MergeFromDestination[T any](ctx context.Context, source T, rawDestination json.RawMessage, disableAppend bool) (T, error) {
|
||||
if rawDestination == nil {
|
||||
return source, nil
|
||||
}
|
||||
rawSource, err := json.MarshalContext(ctx, source)
|
||||
if err != nil {
|
||||
return common.DefaultValue[T](), E.Cause(err, "marshal source")
|
||||
}
|
||||
return MergeFrom[T](ctx, rawSource, rawDestination, disableAppend)
|
||||
}
|
||||
|
||||
func MergeFrom[T any](ctx context.Context, rawSource json.RawMessage, rawDestination json.RawMessage, disableAppend bool) (T, error) {
|
||||
rawMerged, err := MergeJSON(ctx, rawSource, rawDestination, disableAppend)
|
||||
if err != nil {
|
||||
return common.DefaultValue[T](), E.Cause(err, "merge options")
|
||||
}
|
||||
var merged T
|
||||
err = json.UnmarshalContext(ctx, rawMerged, &merged)
|
||||
if err != nil {
|
||||
return common.DefaultValue[T](), E.Cause(err, "unmarshal merged options")
|
||||
}
|
||||
return merged, nil
|
||||
}
|
||||
|
||||
func MergeJSON(ctx context.Context, rawSource json.RawMessage, rawDestination json.RawMessage, disableAppend bool) (json.RawMessage, error) {
|
||||
if rawSource == nil && rawDestination == nil {
|
||||
return nil, os.ErrInvalid
|
||||
} else if rawSource == nil {
|
||||
return rawDestination, nil
|
||||
} else if rawDestination == nil {
|
||||
return rawSource, nil
|
||||
}
|
||||
source, err := Decode(ctx, rawSource)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "decode source")
|
||||
}
|
||||
destination, err := Decode(ctx, rawDestination)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "decode destination")
|
||||
}
|
||||
if source == nil {
|
||||
return json.MarshalContext(ctx, destination)
|
||||
} else if destination == nil {
|
||||
return json.Marshal(source)
|
||||
}
|
||||
merged, err := mergeJSON(source, destination, disableAppend)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return json.MarshalContext(ctx, merged)
|
||||
}
|
||||
|
||||
func mergeJSON(anySource any, anyDestination any, disableAppend bool) (any, error) {
|
||||
switch destination := anyDestination.(type) {
|
||||
case JSONArray:
|
||||
if !disableAppend {
|
||||
switch source := anySource.(type) {
|
||||
case JSONArray:
|
||||
destination = append(destination, source...)
|
||||
default:
|
||||
destination = append(destination, source)
|
||||
}
|
||||
}
|
||||
return destination, nil
|
||||
case *JSONObject:
|
||||
switch source := anySource.(type) {
|
||||
case *JSONObject:
|
||||
for _, entry := range source.Entries() {
|
||||
oldValue, loaded := destination.Get(entry.Key)
|
||||
if loaded {
|
||||
var err error
|
||||
entry.Value, err = mergeJSON(entry.Value, oldValue, disableAppend)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "merge object item ", entry.Key)
|
||||
}
|
||||
}
|
||||
destination.Put(entry.Key, entry.Value)
|
||||
}
|
||||
default:
|
||||
return nil, E.New("cannot merge json object into ", reflect.TypeOf(source))
|
||||
}
|
||||
return destination, nil
|
||||
default:
|
||||
return destination, nil
|
||||
}
|
||||
}
|
68
common/json/badjson/merge_objects.go
Normal file
68
common/json/badjson/merge_objects.go
Normal file
|
@ -0,0 +1,68 @@
|
|||
package badjson
|
||||
|
||||
import (
|
||||
"context"
|
||||
"reflect"
|
||||
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/json"
|
||||
cJSON "github.com/sagernet/sing/common/json/internal/contextjson"
|
||||
)
|
||||
|
||||
func MarshallObjects(objects ...any) ([]byte, error) {
|
||||
return MarshallObjectsContext(context.Background(), objects...)
|
||||
}
|
||||
|
||||
func MarshallObjectsContext(ctx context.Context, objects ...any) ([]byte, error) {
|
||||
if len(objects) == 1 {
|
||||
return json.Marshal(objects[0])
|
||||
}
|
||||
var content JSONObject
|
||||
for _, object := range objects {
|
||||
objectMap, err := newJSONObject(ctx, object)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
content.PutAll(objectMap)
|
||||
}
|
||||
return content.MarshalJSONContext(ctx)
|
||||
}
|
||||
|
||||
func UnmarshallExcluded(inputContent []byte, parentObject any, object any) error {
|
||||
return UnmarshallExcludedContext(context.Background(), inputContent, parentObject, object)
|
||||
}
|
||||
|
||||
func UnmarshallExcludedContext(ctx context.Context, inputContent []byte, parentObject any, object any) error {
|
||||
var content JSONObject
|
||||
err := content.UnmarshalJSONContext(ctx, inputContent)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, key := range cJSON.ObjectKeys(reflect.TypeOf(parentObject)) {
|
||||
content.Remove(key)
|
||||
}
|
||||
if object == nil {
|
||||
if content.IsEmpty() {
|
||||
return nil
|
||||
}
|
||||
return E.New("unexpected key: ", content.Keys()[0])
|
||||
}
|
||||
inputContent, err = content.MarshalJSONContext(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return json.UnmarshalContextDisallowUnknownFields(ctx, inputContent, object)
|
||||
}
|
||||
|
||||
func newJSONObject(ctx context.Context, object any) (*JSONObject, error) {
|
||||
inputContent, err := json.MarshalContext(ctx, object)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var content JSONObject
|
||||
err = content.UnmarshalJSONContext(ctx, inputContent)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &content, nil
|
||||
}
|
107
common/json/badjson/object.go
Normal file
107
common/json/badjson/object.go
Normal file
|
@ -0,0 +1,107 @@
|
|||
package badjson
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/json"
|
||||
"github.com/sagernet/sing/common/x/collections"
|
||||
"github.com/sagernet/sing/common/x/linkedhashmap"
|
||||
)
|
||||
|
||||
type JSONObject struct {
|
||||
linkedhashmap.Map[string, any]
|
||||
}
|
||||
|
||||
func (m *JSONObject) IsEmpty() bool {
|
||||
if m.Size() == 0 {
|
||||
return true
|
||||
}
|
||||
return common.All(m.Entries(), func(it collections.MapEntry[string, any]) bool {
|
||||
if valueInterface, valueMaybeEmpty := it.Value.(isEmpty); valueMaybeEmpty && valueInterface.IsEmpty() {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
})
|
||||
}
|
||||
|
||||
func (m *JSONObject) MarshalJSON() ([]byte, error) {
|
||||
return m.MarshalJSONContext(context.Background())
|
||||
}
|
||||
|
||||
func (m *JSONObject) MarshalJSONContext(ctx context.Context) ([]byte, error) {
|
||||
buffer := new(bytes.Buffer)
|
||||
buffer.WriteString("{")
|
||||
items := common.Filter(m.Entries(), func(it collections.MapEntry[string, any]) bool {
|
||||
if valueInterface, valueMaybeEmpty := it.Value.(isEmpty); valueMaybeEmpty && valueInterface.IsEmpty() {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
})
|
||||
iLen := len(items)
|
||||
for i, entry := range items {
|
||||
keyContent, err := json.MarshalContext(ctx, entry.Key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
buffer.WriteString(strings.TrimSpace(string(keyContent)))
|
||||
buffer.WriteString(": ")
|
||||
valueContent, err := json.MarshalContext(ctx, entry.Value)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
buffer.WriteString(strings.TrimSpace(string(valueContent)))
|
||||
if i < iLen-1 {
|
||||
buffer.WriteString(", ")
|
||||
}
|
||||
}
|
||||
buffer.WriteString("}")
|
||||
return buffer.Bytes(), nil
|
||||
}
|
||||
|
||||
func (m *JSONObject) UnmarshalJSON(content []byte) error {
|
||||
return m.UnmarshalJSONContext(context.Background(), content)
|
||||
}
|
||||
|
||||
func (m *JSONObject) UnmarshalJSONContext(ctx context.Context, content []byte) error {
|
||||
decoder := json.NewDecoderContext(ctx, bytes.NewReader(content))
|
||||
m.Clear()
|
||||
objectStart, err := decoder.Token()
|
||||
if err != nil {
|
||||
return err
|
||||
} else if objectStart != json.Delim('{') {
|
||||
return E.New("expected json object start, but starts with ", objectStart)
|
||||
}
|
||||
err = m.decodeJSON(decoder)
|
||||
if err != nil {
|
||||
return E.Cause(err, "decode json object content")
|
||||
}
|
||||
objectEnd, err := decoder.Token()
|
||||
if err != nil {
|
||||
return err
|
||||
} else if objectEnd != json.Delim('}') {
|
||||
return E.New("expected json object end, but ends with ", objectEnd)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *JSONObject) decodeJSON(decoder *json.Decoder) error {
|
||||
for decoder.More() {
|
||||
var entryKey string
|
||||
keyToken, err := decoder.Token()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
entryKey = keyToken.(string)
|
||||
var entryValue any
|
||||
entryValue, err = decodeJSON(decoder)
|
||||
if err != nil {
|
||||
return E.Cause(err, "decode value for ", entryKey)
|
||||
}
|
||||
m.Put(entryKey, entryValue)
|
||||
}
|
||||
return nil
|
||||
}
|
95
common/json/badjson/typed.go
Normal file
95
common/json/badjson/typed.go
Normal file
|
@ -0,0 +1,95 @@
|
|||
package badjson
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/json"
|
||||
"github.com/sagernet/sing/common/x/linkedhashmap"
|
||||
)
|
||||
|
||||
type TypedMap[K comparable, V any] struct {
|
||||
linkedhashmap.Map[K, V]
|
||||
}
|
||||
|
||||
func (m TypedMap[K, V]) MarshalJSON() ([]byte, error) {
|
||||
return m.MarshalJSONContext(context.Background())
|
||||
}
|
||||
|
||||
func (m TypedMap[K, V]) MarshalJSONContext(ctx context.Context) ([]byte, error) {
|
||||
buffer := new(bytes.Buffer)
|
||||
buffer.WriteString("{")
|
||||
items := m.Entries()
|
||||
iLen := len(items)
|
||||
for i, entry := range items {
|
||||
keyContent, err := json.MarshalContext(ctx, entry.Key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
buffer.WriteString(strings.TrimSpace(string(keyContent)))
|
||||
buffer.WriteString(": ")
|
||||
valueContent, err := json.MarshalContext(ctx, entry.Value)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
buffer.WriteString(strings.TrimSpace(string(valueContent)))
|
||||
if i < iLen-1 {
|
||||
buffer.WriteString(", ")
|
||||
}
|
||||
}
|
||||
buffer.WriteString("}")
|
||||
return buffer.Bytes(), nil
|
||||
}
|
||||
|
||||
func (m *TypedMap[K, V]) UnmarshalJSON(content []byte) error {
|
||||
return m.UnmarshalJSONContext(context.Background(), content)
|
||||
}
|
||||
|
||||
func (m *TypedMap[K, V]) UnmarshalJSONContext(ctx context.Context, content []byte) error {
|
||||
decoder := json.NewDecoderContext(ctx, bytes.NewReader(content))
|
||||
m.Clear()
|
||||
objectStart, err := decoder.Token()
|
||||
if err != nil {
|
||||
return err
|
||||
} else if objectStart != json.Delim('{') {
|
||||
return E.New("expected json object start, but starts with ", objectStart)
|
||||
}
|
||||
err = m.decodeJSON(ctx, decoder)
|
||||
if err != nil {
|
||||
return E.Cause(err, "decode json object content")
|
||||
}
|
||||
objectEnd, err := decoder.Token()
|
||||
if err != nil {
|
||||
return err
|
||||
} else if objectEnd != json.Delim('}') {
|
||||
return E.New("expected json object end, but ends with ", objectEnd)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *TypedMap[K, V]) decodeJSON(ctx context.Context, decoder *json.Decoder) error {
|
||||
for decoder.More() {
|
||||
keyToken, err := decoder.Token()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
keyContent, err := json.MarshalContext(ctx, keyToken)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var entryKey K
|
||||
err = json.UnmarshalContext(ctx, keyContent, &entryKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var entryValue V
|
||||
err = decoder.Decode(&entryValue)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.Put(entryKey, entryValue)
|
||||
}
|
||||
return nil
|
||||
}
|
32
common/json/badoption/duration.go
Normal file
32
common/json/badoption/duration.go
Normal file
|
@ -0,0 +1,32 @@
|
|||
package badoption
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing/common/json"
|
||||
"github.com/sagernet/sing/common/json/badoption/internal/my_time"
|
||||
)
|
||||
|
||||
type Duration time.Duration
|
||||
|
||||
func (d Duration) Build() time.Duration {
|
||||
return time.Duration(d)
|
||||
}
|
||||
|
||||
func (d Duration) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal((time.Duration)(d).String())
|
||||
}
|
||||
|
||||
func (d *Duration) UnmarshalJSON(bytes []byte) error {
|
||||
var value string
|
||||
err := json.Unmarshal(bytes, &value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
duration, err := my_time.ParseDuration(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*d = Duration(duration)
|
||||
return nil
|
||||
}
|
15
common/json/badoption/http.go
Normal file
15
common/json/badoption/http.go
Normal file
|
@ -0,0 +1,15 @@
|
|||
package badoption
|
||||
|
||||
import "net/http"
|
||||
|
||||
type HTTPHeader map[string]Listable[string]
|
||||
|
||||
func (h HTTPHeader) Build() http.Header {
|
||||
header := make(http.Header)
|
||||
for name, values := range h {
|
||||
for _, value := range values {
|
||||
header.Add(name, value)
|
||||
}
|
||||
}
|
||||
return header
|
||||
}
|
226
common/json/badoption/internal/my_time/format.go
Normal file
226
common/json/badoption/internal/my_time/format.go
Normal file
|
@ -0,0 +1,226 @@
|
|||
package my_time
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Copyright 2010 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
const durationDay = 24 * time.Hour
|
||||
|
||||
var unitMap = map[string]uint64{
|
||||
"ns": uint64(time.Nanosecond),
|
||||
"us": uint64(time.Microsecond),
|
||||
"µs": uint64(time.Microsecond), // U+00B5 = micro symbol
|
||||
"μs": uint64(time.Microsecond), // U+03BC = Greek letter mu
|
||||
"ms": uint64(time.Millisecond),
|
||||
"s": uint64(time.Second),
|
||||
"m": uint64(time.Minute),
|
||||
"h": uint64(time.Hour),
|
||||
"d": uint64(durationDay),
|
||||
}
|
||||
|
||||
// ParseDuration parses a duration string.
|
||||
// A duration string is a possibly signed sequence of
|
||||
// decimal numbers, each with optional fraction and a unit suffix,
|
||||
// such as "300ms", "-1.5h" or "2h45m".
|
||||
// Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h".
|
||||
func ParseDuration(s string) (time.Duration, error) {
|
||||
// [-+]?([0-9]*(\.[0-9]*)?[a-z]+)+
|
||||
orig := s
|
||||
var d uint64
|
||||
neg := false
|
||||
|
||||
// Consume [-+]?
|
||||
if s != "" {
|
||||
c := s[0]
|
||||
if c == '-' || c == '+' {
|
||||
neg = c == '-'
|
||||
s = s[1:]
|
||||
}
|
||||
}
|
||||
// Special case: if all that is left is "0", this is zero.
|
||||
if s == "0" {
|
||||
return 0, nil
|
||||
}
|
||||
if s == "" {
|
||||
return 0, errors.New("time: invalid duration " + quote(orig))
|
||||
}
|
||||
for s != "" {
|
||||
var (
|
||||
v, f uint64 // integers before, after decimal point
|
||||
scale float64 = 1 // value = v + f/scale
|
||||
)
|
||||
|
||||
var err error
|
||||
|
||||
// The next character must be [0-9.]
|
||||
if !(s[0] == '.' || '0' <= s[0] && s[0] <= '9') {
|
||||
return 0, errors.New("time: invalid duration " + quote(orig))
|
||||
}
|
||||
// Consume [0-9]*
|
||||
pl := len(s)
|
||||
v, s, err = leadingInt(s)
|
||||
if err != nil {
|
||||
return 0, errors.New("time: invalid duration " + quote(orig))
|
||||
}
|
||||
pre := pl != len(s) // whether we consumed anything before a period
|
||||
|
||||
// Consume (\.[0-9]*)?
|
||||
post := false
|
||||
if s != "" && s[0] == '.' {
|
||||
s = s[1:]
|
||||
pl := len(s)
|
||||
f, scale, s = leadingFraction(s)
|
||||
post = pl != len(s)
|
||||
}
|
||||
if !pre && !post {
|
||||
// no digits (e.g. ".s" or "-.s")
|
||||
return 0, errors.New("time: invalid duration " + quote(orig))
|
||||
}
|
||||
|
||||
// Consume unit.
|
||||
i := 0
|
||||
for ; i < len(s); i++ {
|
||||
c := s[i]
|
||||
if c == '.' || '0' <= c && c <= '9' {
|
||||
break
|
||||
}
|
||||
}
|
||||
if i == 0 {
|
||||
return 0, errors.New("time: missing unit in duration " + quote(orig))
|
||||
}
|
||||
u := s[:i]
|
||||
s = s[i:]
|
||||
unit, ok := unitMap[u]
|
||||
if !ok {
|
||||
return 0, errors.New("time: unknown unit " + quote(u) + " in duration " + quote(orig))
|
||||
}
|
||||
if v > 1<<63/unit {
|
||||
// overflow
|
||||
return 0, errors.New("time: invalid duration " + quote(orig))
|
||||
}
|
||||
v *= unit
|
||||
if f > 0 {
|
||||
// float64 is needed to be nanosecond accurate for fractions of hours.
|
||||
// v >= 0 && (f*unit/scale) <= 3.6e+12 (ns/h, h is the largest unit)
|
||||
v += uint64(float64(f) * (float64(unit) / scale))
|
||||
if v > 1<<63 {
|
||||
// overflow
|
||||
return 0, errors.New("time: invalid duration " + quote(orig))
|
||||
}
|
||||
}
|
||||
d += v
|
||||
if d > 1<<63 {
|
||||
return 0, errors.New("time: invalid duration " + quote(orig))
|
||||
}
|
||||
}
|
||||
if neg {
|
||||
return -time.Duration(d), nil
|
||||
}
|
||||
if d > 1<<63-1 {
|
||||
return 0, errors.New("time: invalid duration " + quote(orig))
|
||||
}
|
||||
return time.Duration(d), nil
|
||||
}
|
||||
|
||||
var errLeadingInt = errors.New("time: bad [0-9]*") // never printed
|
||||
|
||||
// leadingInt consumes the leading [0-9]* from s.
|
||||
func leadingInt[bytes []byte | string](s bytes) (x uint64, rem bytes, err error) {
|
||||
i := 0
|
||||
for ; i < len(s); i++ {
|
||||
c := s[i]
|
||||
if c < '0' || c > '9' {
|
||||
break
|
||||
}
|
||||
if x > 1<<63/10 {
|
||||
// overflow
|
||||
return 0, rem, errLeadingInt
|
||||
}
|
||||
x = x*10 + uint64(c) - '0'
|
||||
if x > 1<<63 {
|
||||
// overflow
|
||||
return 0, rem, errLeadingInt
|
||||
}
|
||||
}
|
||||
return x, s[i:], nil
|
||||
}
|
||||
|
||||
// leadingFraction consumes the leading [0-9]* from s.
|
||||
// It is used only for fractions, so does not return an error on overflow,
|
||||
// it just stops accumulating precision.
|
||||
func leadingFraction(s string) (x uint64, scale float64, rem string) {
|
||||
i := 0
|
||||
scale = 1
|
||||
overflow := false
|
||||
for ; i < len(s); i++ {
|
||||
c := s[i]
|
||||
if c < '0' || c > '9' {
|
||||
break
|
||||
}
|
||||
if overflow {
|
||||
continue
|
||||
}
|
||||
if x > (1<<63-1)/10 {
|
||||
// It's possible for overflow to give a positive number, so take care.
|
||||
overflow = true
|
||||
continue
|
||||
}
|
||||
y := x*10 + uint64(c) - '0'
|
||||
if y > 1<<63 {
|
||||
overflow = true
|
||||
continue
|
||||
}
|
||||
x = y
|
||||
scale *= 10
|
||||
}
|
||||
return x, scale, s[i:]
|
||||
}
|
||||
|
||||
// These are borrowed from unicode/utf8 and strconv and replicate behavior in
|
||||
// that package, since we can't take a dependency on either.
|
||||
const (
|
||||
lowerhex = "0123456789abcdef"
|
||||
runeSelf = 0x80
|
||||
runeError = '\uFFFD'
|
||||
)
|
||||
|
||||
func quote(s string) string {
|
||||
buf := make([]byte, 1, len(s)+2) // slice will be at least len(s) + quotes
|
||||
buf[0] = '"'
|
||||
for i, c := range s {
|
||||
if c >= runeSelf || c < ' ' {
|
||||
// This means you are asking us to parse a time.Duration or
|
||||
// time.Location with unprintable or non-ASCII characters in it.
|
||||
// We don't expect to hit this case very often. We could try to
|
||||
// reproduce strconv.Quote's behavior with full fidelity but
|
||||
// given how rarely we expect to hit these edge cases, speed and
|
||||
// conciseness are better.
|
||||
var width int
|
||||
if c == runeError {
|
||||
width = 1
|
||||
if i+2 < len(s) && s[i:i+3] == string(runeError) {
|
||||
width = 3
|
||||
}
|
||||
} else {
|
||||
width = len(string(c))
|
||||
}
|
||||
for j := 0; j < width; j++ {
|
||||
buf = append(buf, `\x`...)
|
||||
buf = append(buf, lowerhex[s[i+j]>>4])
|
||||
buf = append(buf, lowerhex[s[i+j]&0xF])
|
||||
}
|
||||
} else {
|
||||
if c == '"' || c == '\\' {
|
||||
buf = append(buf, '\\')
|
||||
}
|
||||
buf = append(buf, string(c)...)
|
||||
}
|
||||
}
|
||||
buf = append(buf, '"')
|
||||
return string(buf)
|
||||
}
|
35
common/json/badoption/listable.go
Normal file
35
common/json/badoption/listable.go
Normal file
|
@ -0,0 +1,35 @@
|
|||
package badoption
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/json"
|
||||
)
|
||||
|
||||
type Listable[T any] []T
|
||||
|
||||
func (l Listable[T]) MarshalJSONContext(ctx context.Context) ([]byte, error) {
|
||||
arrayList := []T(l)
|
||||
if len(arrayList) == 1 {
|
||||
return json.Marshal(arrayList[0])
|
||||
}
|
||||
return json.MarshalContext(ctx, arrayList)
|
||||
}
|
||||
|
||||
func (l *Listable[T]) UnmarshalJSONContext(ctx context.Context, content []byte) error {
|
||||
if string(content) == "null" {
|
||||
return nil
|
||||
}
|
||||
var singleItem T
|
||||
err := json.UnmarshalContextDisallowUnknownFields(ctx, content, &singleItem)
|
||||
if err == nil {
|
||||
*l = []T{singleItem}
|
||||
return nil
|
||||
}
|
||||
newErr := json.UnmarshalContextDisallowUnknownFields(ctx, content, (*[]T)(l))
|
||||
if newErr == nil {
|
||||
return nil
|
||||
}
|
||||
return E.Errors(err, newErr)
|
||||
}
|
98
common/json/badoption/netip.go
Normal file
98
common/json/badoption/netip.go
Normal file
|
@ -0,0 +1,98 @@
|
|||
package badoption
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
|
||||
"github.com/sagernet/sing/common/json"
|
||||
)
|
||||
|
||||
type Addr netip.Addr
|
||||
|
||||
func (a *Addr) Build(defaultAddr netip.Addr) netip.Addr {
|
||||
if a == nil {
|
||||
return defaultAddr
|
||||
}
|
||||
return netip.Addr(*a)
|
||||
}
|
||||
|
||||
func (a *Addr) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(netip.Addr(*a).String())
|
||||
}
|
||||
|
||||
func (a *Addr) UnmarshalJSON(content []byte) error {
|
||||
var value string
|
||||
err := json.Unmarshal(content, &value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
addr, err := netip.ParseAddr(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*a = Addr(addr)
|
||||
return nil
|
||||
}
|
||||
|
||||
type Prefix netip.Prefix
|
||||
|
||||
func (p *Prefix) Build(defaultPrefix netip.Prefix) netip.Prefix {
|
||||
if p == nil {
|
||||
return defaultPrefix
|
||||
}
|
||||
return netip.Prefix(*p)
|
||||
}
|
||||
|
||||
func (p *Prefix) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(netip.Prefix(*p).String())
|
||||
}
|
||||
|
||||
func (p *Prefix) UnmarshalJSON(content []byte) error {
|
||||
var value string
|
||||
err := json.Unmarshal(content, &value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
prefix, err := netip.ParsePrefix(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*p = Prefix(prefix)
|
||||
return nil
|
||||
}
|
||||
|
||||
type Prefixable netip.Prefix
|
||||
|
||||
func (p *Prefixable) Build(defaultPrefix netip.Prefix) netip.Prefix {
|
||||
if p == nil {
|
||||
return defaultPrefix
|
||||
}
|
||||
return netip.Prefix(*p)
|
||||
}
|
||||
|
||||
func (p *Prefixable) MarshalJSON() ([]byte, error) {
|
||||
prefix := netip.Prefix(*p)
|
||||
if prefix.Bits() == prefix.Addr().BitLen() {
|
||||
return json.Marshal(prefix.Addr().String())
|
||||
} else {
|
||||
return json.Marshal(prefix.String())
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Prefixable) UnmarshalJSON(content []byte) error {
|
||||
var value string
|
||||
err := json.Unmarshal(content, &value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
prefix, prefixErr := netip.ParsePrefix(value)
|
||||
if prefixErr == nil {
|
||||
*p = Prefixable(prefix)
|
||||
return nil
|
||||
}
|
||||
addr, addrErr := netip.ParseAddr(value)
|
||||
if addrErr == nil {
|
||||
*p = Prefixable(netip.PrefixFrom(addr, addr.BitLen()))
|
||||
return nil
|
||||
}
|
||||
return prefixErr
|
||||
}
|
31
common/json/badoption/regexp.go
Normal file
31
common/json/badoption/regexp.go
Normal file
|
@ -0,0 +1,31 @@
|
|||
package badoption
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
|
||||
"github.com/sagernet/sing/common/json"
|
||||
)
|
||||
|
||||
type Regexp regexp.Regexp
|
||||
|
||||
func (r *Regexp) Build() *regexp.Regexp {
|
||||
return (*regexp.Regexp)(r)
|
||||
}
|
||||
|
||||
func (r *Regexp) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal((*regexp.Regexp)(r).String())
|
||||
}
|
||||
|
||||
func (r *Regexp) UnmarshalJSON(content []byte) error {
|
||||
var stringValue string
|
||||
err := json.Unmarshal(content, &stringValue)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
regex, err := regexp.Compile(stringValue)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*r = Regexp(*regex)
|
||||
return nil
|
||||
}
|
128
common/json/comment.go
Normal file
128
common/json/comment.go
Normal file
|
@ -0,0 +1,128 @@
|
|||
package json
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"io"
|
||||
)
|
||||
|
||||
// kanged from v2ray
|
||||
|
||||
type commentFilterState = byte
|
||||
|
||||
const (
|
||||
commentFilterStateContent commentFilterState = iota
|
||||
commentFilterStateEscape
|
||||
commentFilterStateDoubleQuote
|
||||
commentFilterStateDoubleQuoteEscape
|
||||
commentFilterStateSingleQuote
|
||||
commentFilterStateSingleQuoteEscape
|
||||
commentFilterStateComment
|
||||
commentFilterStateSlash
|
||||
commentFilterStateMultilineComment
|
||||
commentFilterStateMultilineCommentStar
|
||||
)
|
||||
|
||||
type CommentFilter struct {
|
||||
br *bufio.Reader
|
||||
state commentFilterState
|
||||
}
|
||||
|
||||
func NewCommentFilter(reader io.Reader) io.Reader {
|
||||
return &CommentFilter{br: bufio.NewReader(reader)}
|
||||
}
|
||||
|
||||
func (v *CommentFilter) Read(b []byte) (int, error) {
|
||||
p := b[:0]
|
||||
for len(p) < len(b)-2 {
|
||||
x, err := v.br.ReadByte()
|
||||
if err != nil {
|
||||
if len(p) == 0 {
|
||||
return 0, err
|
||||
}
|
||||
return len(p), nil
|
||||
}
|
||||
switch v.state {
|
||||
case commentFilterStateContent:
|
||||
switch x {
|
||||
case '"':
|
||||
v.state = commentFilterStateDoubleQuote
|
||||
p = append(p, x)
|
||||
case '\'':
|
||||
v.state = commentFilterStateSingleQuote
|
||||
p = append(p, x)
|
||||
case '\\':
|
||||
v.state = commentFilterStateEscape
|
||||
case '#':
|
||||
v.state = commentFilterStateComment
|
||||
case '/':
|
||||
v.state = commentFilterStateSlash
|
||||
default:
|
||||
p = append(p, x)
|
||||
}
|
||||
case commentFilterStateEscape:
|
||||
p = append(p, '\\', x)
|
||||
v.state = commentFilterStateContent
|
||||
case commentFilterStateDoubleQuote:
|
||||
switch x {
|
||||
case '"':
|
||||
v.state = commentFilterStateContent
|
||||
p = append(p, x)
|
||||
case '\\':
|
||||
v.state = commentFilterStateDoubleQuoteEscape
|
||||
default:
|
||||
p = append(p, x)
|
||||
}
|
||||
case commentFilterStateDoubleQuoteEscape:
|
||||
p = append(p, '\\', x)
|
||||
v.state = commentFilterStateDoubleQuote
|
||||
case commentFilterStateSingleQuote:
|
||||
switch x {
|
||||
case '\'':
|
||||
v.state = commentFilterStateContent
|
||||
p = append(p, x)
|
||||
case '\\':
|
||||
v.state = commentFilterStateSingleQuoteEscape
|
||||
default:
|
||||
p = append(p, x)
|
||||
}
|
||||
case commentFilterStateSingleQuoteEscape:
|
||||
p = append(p, '\\', x)
|
||||
v.state = commentFilterStateSingleQuote
|
||||
case commentFilterStateComment:
|
||||
if x == '\n' {
|
||||
v.state = commentFilterStateContent
|
||||
p = append(p, '\n')
|
||||
}
|
||||
case commentFilterStateSlash:
|
||||
switch x {
|
||||
case '/':
|
||||
v.state = commentFilterStateComment
|
||||
case '*':
|
||||
v.state = commentFilterStateMultilineComment
|
||||
default:
|
||||
p = append(p, '/', x)
|
||||
}
|
||||
case commentFilterStateMultilineComment:
|
||||
switch x {
|
||||
case '*':
|
||||
v.state = commentFilterStateMultilineCommentStar
|
||||
case '\n':
|
||||
p = append(p, '\n')
|
||||
}
|
||||
case commentFilterStateMultilineCommentStar:
|
||||
switch x {
|
||||
case '/':
|
||||
v.state = commentFilterStateContent
|
||||
case '*':
|
||||
// Stay
|
||||
case '\n':
|
||||
p = append(p, '\n')
|
||||
default:
|
||||
v.state = commentFilterStateMultilineComment
|
||||
}
|
||||
default:
|
||||
panic("Unknown state.")
|
||||
}
|
||||
}
|
||||
return len(p), nil
|
||||
}
|
23
common/json/context.go
Normal file
23
common/json/context.go
Normal file
|
@ -0,0 +1,23 @@
|
|||
//go:build go1.20 && !without_contextjson
|
||||
|
||||
package json
|
||||
|
||||
import (
|
||||
"github.com/sagernet/sing/common/json/internal/contextjson"
|
||||
)
|
||||
|
||||
var (
|
||||
Marshal = json.Marshal
|
||||
Unmarshal = json.Unmarshal
|
||||
NewEncoder = json.NewEncoder
|
||||
NewDecoder = json.NewDecoder
|
||||
)
|
||||
|
||||
type (
|
||||
Encoder = json.Encoder
|
||||
Decoder = json.Decoder
|
||||
Token = json.Token
|
||||
Delim = json.Delim
|
||||
SyntaxError = json.SyntaxError
|
||||
RawMessage = json.RawMessage
|
||||
)
|
23
common/json/context_ext.go
Normal file
23
common/json/context_ext.go
Normal file
|
@ -0,0 +1,23 @@
|
|||
package json
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/sagernet/sing/common/json/internal/contextjson"
|
||||
)
|
||||
|
||||
var (
|
||||
MarshalContext = json.MarshalContext
|
||||
UnmarshalContext = json.UnmarshalContext
|
||||
NewEncoderContext = json.NewEncoderContext
|
||||
NewDecoderContext = json.NewDecoderContext
|
||||
UnmarshalContextDisallowUnknownFields = json.UnmarshalContextDisallowUnknownFields
|
||||
)
|
||||
|
||||
type ContextMarshaler interface {
|
||||
MarshalJSONContext(ctx context.Context) ([]byte, error)
|
||||
}
|
||||
|
||||
type ContextUnmarshaler interface {
|
||||
UnmarshalJSONContext(ctx context.Context, content []byte) error
|
||||
}
|
3
common/json/internal/contextjson/README.md
Normal file
3
common/json/internal/contextjson/README.md
Normal file
|
@ -0,0 +1,3 @@
|
|||
# contextjson
|
||||
|
||||
mod from go1.21.4
|
11
common/json/internal/contextjson/context.go
Normal file
11
common/json/internal/contextjson/context.go
Normal file
|
@ -0,0 +1,11 @@
|
|||
package json
|
||||
|
||||
import "context"
|
||||
|
||||
type ContextMarshaler interface {
|
||||
MarshalJSONContext(ctx context.Context) ([]byte, error)
|
||||
}
|
||||
|
||||
type ContextUnmarshaler interface {
|
||||
UnmarshalJSONContext(ctx context.Context, content []byte) error
|
||||
}
|
43
common/json/internal/contextjson/context_test.go
Normal file
43
common/json/internal/contextjson/context_test.go
Normal file
|
@ -0,0 +1,43 @@
|
|||
package json_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/sagernet/sing/common/json/internal/contextjson"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type myStruct struct {
|
||||
value string
|
||||
}
|
||||
|
||||
func (m *myStruct) MarshalJSONContext(ctx context.Context) ([]byte, error) {
|
||||
return json.Marshal(ctx.Value("key").(string))
|
||||
}
|
||||
|
||||
func (m *myStruct) UnmarshalJSONContext(ctx context.Context, content []byte) error {
|
||||
m.value = ctx.Value("key").(string)
|
||||
return nil
|
||||
}
|
||||
|
||||
//nolint:staticcheck
|
||||
func TestMarshalContext(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.WithValue(context.Background(), "key", "value")
|
||||
var s myStruct
|
||||
b, err := json.MarshalContext(ctx, &s)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []byte(`"value"`), b)
|
||||
}
|
||||
|
||||
//nolint:staticcheck
|
||||
func TestUnmarshalContext(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.WithValue(context.Background(), "key", "value")
|
||||
var s myStruct
|
||||
err := json.UnmarshalContext(ctx, []byte(`{}`), &s)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "value", s.value)
|
||||
}
|
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue