mirror of
https://github.com/SagerNet/sing.git
synced 2025-04-06 13:27:39 +03:00
Compare commits
9 commits
dev
...
v0.5.0-alp
Author | SHA1 | Date | |
---|---|---|---|
|
37bee34b73 | ||
|
c86c25365c | ||
|
3e02be0e9c | ||
|
89c9d91019 | ||
|
dd8cd39ef5 | ||
|
1482471859 | ||
|
cb5ca6e926 | ||
|
db031f7aef | ||
|
1c6ec119f1 |
146 changed files with 1268 additions and 7035 deletions
8
.github/renovate.json
vendored
8
.github/renovate.json
vendored
|
@ -1,13 +1,11 @@
|
||||||
{
|
{
|
||||||
"$schema": "https://docs.renovatebot.com/renovate-schema.json",
|
"$schema": "https://docs.renovatebot.com/renovate-schema.json",
|
||||||
"commitMessagePrefix": "[dependencies]",
|
"commitMessagePrefix": "[dependencies]",
|
||||||
|
"branchName": "main",
|
||||||
"extends": [
|
"extends": [
|
||||||
"config:base",
|
"config:base",
|
||||||
":disableRateLimiting"
|
":disableRateLimiting"
|
||||||
],
|
],
|
||||||
"baseBranches": [
|
|
||||||
"dev"
|
|
||||||
],
|
|
||||||
"packageRules": [
|
"packageRules": [
|
||||||
{
|
{
|
||||||
"matchManagers": [
|
"matchManagers": [
|
||||||
|
@ -17,9 +15,9 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"matchManagers": [
|
"matchManagers": [
|
||||||
"dockerfile"
|
"gomod"
|
||||||
],
|
],
|
||||||
"groupName": "Dockerfile"
|
"groupName": "gomod"
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
}
|
}
|
|
@ -1,4 +1,4 @@
|
||||||
name: test
|
name: Debug build
|
||||||
|
|
||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
|
@ -16,7 +16,7 @@ on:
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
build:
|
build:
|
||||||
name: Linux
|
name: Linux Debug build
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
|
@ -24,14 +24,14 @@ jobs:
|
||||||
with:
|
with:
|
||||||
fetch-depth: 0
|
fetch-depth: 0
|
||||||
- name: Setup Go
|
- name: Setup Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v4
|
||||||
with:
|
with:
|
||||||
go-version: ^1.23
|
go-version: ^1.22
|
||||||
- name: Build
|
- name: Build
|
||||||
run: |
|
run: |
|
||||||
make test
|
make test
|
||||||
build_go120:
|
build_go120:
|
||||||
name: Linux (Go 1.20)
|
name: Linux Debug build (Go 1.20)
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
|
@ -39,7 +39,7 @@ jobs:
|
||||||
with:
|
with:
|
||||||
fetch-depth: 0
|
fetch-depth: 0
|
||||||
- name: Setup Go
|
- name: Setup Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v4
|
||||||
with:
|
with:
|
||||||
go-version: ~1.20
|
go-version: ~1.20
|
||||||
continue-on-error: true
|
continue-on-error: true
|
||||||
|
@ -47,7 +47,7 @@ jobs:
|
||||||
run: |
|
run: |
|
||||||
make test
|
make test
|
||||||
build_go121:
|
build_go121:
|
||||||
name: Linux (Go 1.21)
|
name: Linux Debug build (Go 1.21)
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
|
@ -55,31 +55,15 @@ jobs:
|
||||||
with:
|
with:
|
||||||
fetch-depth: 0
|
fetch-depth: 0
|
||||||
- name: Setup Go
|
- name: Setup Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v4
|
||||||
with:
|
with:
|
||||||
go-version: ~1.21
|
go-version: ~1.21
|
||||||
continue-on-error: true
|
continue-on-error: true
|
||||||
- name: Build
|
- name: Build
|
||||||
run: |
|
run: |
|
||||||
make test
|
make test
|
||||||
build_go122:
|
build__windows:
|
||||||
name: Linux (Go 1.22)
|
name: Windows Debug build
|
||||||
runs-on: ubuntu-latest
|
|
||||||
steps:
|
|
||||||
- name: Checkout
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
with:
|
|
||||||
fetch-depth: 0
|
|
||||||
- name: Setup Go
|
|
||||||
uses: actions/setup-go@v5
|
|
||||||
with:
|
|
||||||
go-version: ~1.22
|
|
||||||
continue-on-error: true
|
|
||||||
- name: Build
|
|
||||||
run: |
|
|
||||||
make test
|
|
||||||
build_windows:
|
|
||||||
name: Windows
|
|
||||||
runs-on: windows-latest
|
runs-on: windows-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
|
@ -87,15 +71,15 @@ jobs:
|
||||||
with:
|
with:
|
||||||
fetch-depth: 0
|
fetch-depth: 0
|
||||||
- name: Setup Go
|
- name: Setup Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v4
|
||||||
with:
|
with:
|
||||||
go-version: ^1.23
|
go-version: ^1.22
|
||||||
continue-on-error: true
|
continue-on-error: true
|
||||||
- name: Build
|
- name: Build
|
||||||
run: |
|
run: |
|
||||||
make test
|
make test
|
||||||
build_darwin:
|
build_darwin:
|
||||||
name: macOS
|
name: macOS Debug build
|
||||||
runs-on: macos-latest
|
runs-on: macos-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
|
@ -103,9 +87,9 @@ jobs:
|
||||||
with:
|
with:
|
||||||
fetch-depth: 0
|
fetch-depth: 0
|
||||||
- name: Setup Go
|
- name: Setup Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v4
|
||||||
with:
|
with:
|
||||||
go-version: ^1.23
|
go-version: ^1.22
|
||||||
continue-on-error: true
|
continue-on-error: true
|
||||||
- name: Build
|
- name: Build
|
||||||
run: |
|
run: |
|
10
.github/workflows/lint.yml
vendored
10
.github/workflows/lint.yml
vendored
|
@ -1,4 +1,4 @@
|
||||||
name: lint
|
name: Lint
|
||||||
|
|
||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
|
@ -24,16 +24,16 @@ jobs:
|
||||||
with:
|
with:
|
||||||
fetch-depth: 0
|
fetch-depth: 0
|
||||||
- name: Setup Go
|
- name: Setup Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v4
|
||||||
with:
|
with:
|
||||||
go-version: ^1.23
|
go-version: ^1.22
|
||||||
- name: Cache go module
|
- name: Cache go module
|
||||||
uses: actions/cache@v4
|
uses: actions/cache@v3
|
||||||
with:
|
with:
|
||||||
path: |
|
path: |
|
||||||
~/go/pkg/mod
|
~/go/pkg/mod
|
||||||
key: go-${{ hashFiles('**/go.sum') }}
|
key: go-${{ hashFiles('**/go.sum') }}
|
||||||
- name: golangci-lint
|
- name: golangci-lint
|
||||||
uses: golangci/golangci-lint-action@v6
|
uses: golangci/golangci-lint-action@v3
|
||||||
with:
|
with:
|
||||||
version: latest
|
version: latest
|
|
@ -5,8 +5,6 @@ linters:
|
||||||
- govet
|
- govet
|
||||||
- gci
|
- gci
|
||||||
- staticcheck
|
- staticcheck
|
||||||
- paralleltest
|
|
||||||
- ineffassign
|
|
||||||
|
|
||||||
linters-settings:
|
linters-settings:
|
||||||
gci:
|
gci:
|
||||||
|
@ -16,9 +14,4 @@ linters-settings:
|
||||||
- prefix(github.com/sagernet/)
|
- prefix(github.com/sagernet/)
|
||||||
- default
|
- default
|
||||||
staticcheck:
|
staticcheck:
|
||||||
checks:
|
go: '1.20'
|
||||||
- all
|
|
||||||
- -SA1003
|
|
||||||
|
|
||||||
run:
|
|
||||||
go: "1.23"
|
|
||||||
|
|
12
Makefile
12
Makefile
|
@ -8,14 +8,14 @@ fmt_install:
|
||||||
go install -v github.com/daixiang0/gci@latest
|
go install -v github.com/daixiang0/gci@latest
|
||||||
|
|
||||||
lint:
|
lint:
|
||||||
GOOS=linux golangci-lint run
|
GOOS=linux golangci-lint run ./...
|
||||||
GOOS=android golangci-lint run
|
GOOS=android golangci-lint run ./...
|
||||||
GOOS=windows golangci-lint run
|
GOOS=windows golangci-lint run ./...
|
||||||
GOOS=darwin golangci-lint run
|
GOOS=darwin golangci-lint run ./...
|
||||||
GOOS=freebsd golangci-lint run
|
GOOS=freebsd golangci-lint run ./...
|
||||||
|
|
||||||
lint_install:
|
lint_install:
|
||||||
go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest
|
go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest
|
||||||
|
|
||||||
test:
|
test:
|
||||||
go test ./...
|
go test $(shell go list ./... | grep -v /internal/)
|
||||||
|
|
|
@ -1,6 +1,3 @@
|
||||||
# sing
|
# sing
|
||||||
|
|
||||||

|
|
||||||

|
|
||||||
|
|
||||||
Do you hear the people sing?
|
Do you hear the people sing?
|
|
@ -2,10 +2,11 @@ package baderror
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Contains(err error, msgList ...string) bool {
|
func Contains(err error, msgList ...string) bool {
|
||||||
|
@ -21,7 +22,8 @@ func WrapH2(err error) error {
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
if errors.Is(err, io.ErrUnexpectedEOF) {
|
err = E.Unwrap(err)
|
||||||
|
if err == io.ErrUnexpectedEOF {
|
||||||
return io.EOF
|
return io.EOF
|
||||||
}
|
}
|
||||||
if Contains(err, "client disconnected", "body closed by handler", "response body closed", "; CANCEL") {
|
if Contains(err, "client disconnected", "body closed by handler", "response body closed", "; CANCEL") {
|
||||||
|
|
|
@ -1,18 +0,0 @@
|
||||||
package binary
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/binary"
|
|
||||||
"reflect"
|
|
||||||
)
|
|
||||||
|
|
||||||
func DataSize(t reflect.Value) int {
|
|
||||||
return dataSize(t)
|
|
||||||
}
|
|
||||||
|
|
||||||
func EncodeValue(order binary.ByteOrder, buf []byte, v reflect.Value) {
|
|
||||||
(&encoder{order: order, buf: buf}).value(v)
|
|
||||||
}
|
|
||||||
|
|
||||||
func DecodeValue(order binary.ByteOrder, buf []byte, v reflect.Value) {
|
|
||||||
(&decoder{order: order, buf: buf}).value(v)
|
|
||||||
}
|
|
305
common/binary/variant_data.go
Normal file
305
common/binary/variant_data.go
Normal file
|
@ -0,0 +1,305 @@
|
||||||
|
package binary
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"reflect"
|
||||||
|
|
||||||
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
|
)
|
||||||
|
|
||||||
|
func ReadDataSlice(r *bufio.Reader, order ByteOrder, data ...any) error {
|
||||||
|
for index, item := range data {
|
||||||
|
err := ReadData(r, order, item)
|
||||||
|
if err != nil {
|
||||||
|
return E.Cause(err, "[", index, "]")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func ReadData(r *bufio.Reader, order ByteOrder, data any) error {
|
||||||
|
switch dataPtr := data.(type) {
|
||||||
|
case *[]uint8:
|
||||||
|
bytesLen, err := ReadUvarint(r)
|
||||||
|
if err != nil {
|
||||||
|
return E.Cause(err, "bytes length")
|
||||||
|
}
|
||||||
|
newBytes := make([]uint8, bytesLen)
|
||||||
|
_, err = io.ReadFull(r, newBytes)
|
||||||
|
if err != nil {
|
||||||
|
return E.Cause(err, "bytes value")
|
||||||
|
}
|
||||||
|
*dataPtr = newBytes
|
||||||
|
default:
|
||||||
|
if intBaseDataSize(data) != 0 {
|
||||||
|
return Read(r, order, data)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
dataValue := reflect.ValueOf(data)
|
||||||
|
if dataValue.Kind() == reflect.Pointer {
|
||||||
|
dataValue = dataValue.Elem()
|
||||||
|
}
|
||||||
|
return readData(r, order, dataValue)
|
||||||
|
}
|
||||||
|
|
||||||
|
func readData(r *bufio.Reader, order ByteOrder, data reflect.Value) error {
|
||||||
|
switch data.Kind() {
|
||||||
|
case reflect.Pointer:
|
||||||
|
pointerValue, err := r.ReadByte()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if pointerValue == 0 {
|
||||||
|
data.SetZero()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if data.IsNil() {
|
||||||
|
data.Set(reflect.New(data.Type().Elem()))
|
||||||
|
}
|
||||||
|
return readData(r, order, data.Elem())
|
||||||
|
case reflect.String:
|
||||||
|
stringLength, err := ReadUvarint(r)
|
||||||
|
if err != nil {
|
||||||
|
return E.Cause(err, "string length")
|
||||||
|
}
|
||||||
|
if stringLength == 0 {
|
||||||
|
data.SetZero()
|
||||||
|
} else {
|
||||||
|
stringData := make([]byte, stringLength)
|
||||||
|
_, err = io.ReadFull(r, stringData)
|
||||||
|
if err != nil {
|
||||||
|
return E.Cause(err, "string value")
|
||||||
|
}
|
||||||
|
data.SetString(string(stringData))
|
||||||
|
}
|
||||||
|
case reflect.Array:
|
||||||
|
arrayLen := data.Len()
|
||||||
|
for i := 0; i < arrayLen; i++ {
|
||||||
|
err := readData(r, order, data.Index(i))
|
||||||
|
if err != nil {
|
||||||
|
return E.Cause(err, "[", i, "]")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case reflect.Slice:
|
||||||
|
sliceLength, err := ReadUvarint(r)
|
||||||
|
if err != nil {
|
||||||
|
return E.Cause(err, "slice length")
|
||||||
|
}
|
||||||
|
if !data.IsNil() && data.Cap() >= int(sliceLength) {
|
||||||
|
data.SetLen(int(sliceLength))
|
||||||
|
} else if sliceLength > 0 {
|
||||||
|
data.Set(reflect.MakeSlice(data.Type(), int(sliceLength), int(sliceLength)))
|
||||||
|
}
|
||||||
|
if sliceLength > 0 {
|
||||||
|
if data.Type().Elem().Kind() == reflect.Uint8 {
|
||||||
|
_, err = io.ReadFull(r, data.Bytes())
|
||||||
|
if err != nil {
|
||||||
|
return E.Cause(err, "bytes value")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for index := 0; index < int(sliceLength); index++ {
|
||||||
|
err = readData(r, order, data.Index(index))
|
||||||
|
if err != nil {
|
||||||
|
return E.Cause(err, "[", index, "]")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case reflect.Map:
|
||||||
|
mapLength, err := ReadUvarint(r)
|
||||||
|
if err != nil {
|
||||||
|
return E.Cause(err, "map length")
|
||||||
|
}
|
||||||
|
data.Set(reflect.MakeMap(data.Type()))
|
||||||
|
for index := 0; index < int(mapLength); index++ {
|
||||||
|
key := reflect.New(data.Type().Key()).Elem()
|
||||||
|
err = readData(r, order, key)
|
||||||
|
if err != nil {
|
||||||
|
return E.Cause(err, "[", index, "].key")
|
||||||
|
}
|
||||||
|
value := reflect.New(data.Type().Elem()).Elem()
|
||||||
|
err = readData(r, order, value)
|
||||||
|
if err != nil {
|
||||||
|
return E.Cause(err, "[", index, "].value")
|
||||||
|
}
|
||||||
|
data.SetMapIndex(key, value)
|
||||||
|
}
|
||||||
|
case reflect.Struct:
|
||||||
|
fieldType := data.Type()
|
||||||
|
fieldLen := data.NumField()
|
||||||
|
for i := 0; i < fieldLen; i++ {
|
||||||
|
field := data.Field(i)
|
||||||
|
fieldName := fieldType.Field(i).Name
|
||||||
|
if field.CanSet() || fieldName != "_" {
|
||||||
|
err := readData(r, order, field)
|
||||||
|
if err != nil {
|
||||||
|
return E.Cause(err, fieldName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
size := dataSize(data)
|
||||||
|
if size < 0 {
|
||||||
|
return errors.New("invalid type " + reflect.TypeOf(data).String())
|
||||||
|
}
|
||||||
|
d := &decoder{order: order, buf: make([]byte, size)}
|
||||||
|
_, err := io.ReadFull(r, d.buf)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
d.value(data)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func WriteDataSlice(writer *bufio.Writer, order ByteOrder, data ...any) error {
|
||||||
|
for index, item := range data {
|
||||||
|
err := WriteData(writer, order, item)
|
||||||
|
if err != nil {
|
||||||
|
return E.Cause(err, "[", index, "]")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func WriteData(writer *bufio.Writer, order ByteOrder, data any) error {
|
||||||
|
switch dataPtr := data.(type) {
|
||||||
|
case []uint8:
|
||||||
|
_, err := writer.Write(AppendUvarint(writer.AvailableBuffer(), uint64(len(dataPtr))))
|
||||||
|
if err != nil {
|
||||||
|
return E.Cause(err, "bytes length")
|
||||||
|
}
|
||||||
|
_, err = writer.Write(dataPtr)
|
||||||
|
if err != nil {
|
||||||
|
return E.Cause(err, "bytes value")
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
if intBaseDataSize(data) != 0 {
|
||||||
|
return Write(writer, order, data)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return writeData(writer, order, reflect.Indirect(reflect.ValueOf(data)))
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeData(writer *bufio.Writer, order ByteOrder, data reflect.Value) error {
|
||||||
|
switch data.Kind() {
|
||||||
|
case reflect.Pointer:
|
||||||
|
if data.IsNil() {
|
||||||
|
err := writer.WriteByte(0)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
err := writer.WriteByte(1)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return writeData(writer, order, data.Elem())
|
||||||
|
}
|
||||||
|
case reflect.String:
|
||||||
|
stringValue := data.String()
|
||||||
|
_, err := writer.Write(AppendUvarint(writer.AvailableBuffer(), uint64(len(stringValue))))
|
||||||
|
if err != nil {
|
||||||
|
return E.Cause(err, "string length")
|
||||||
|
}
|
||||||
|
if stringValue != "" {
|
||||||
|
_, err = writer.WriteString(stringValue)
|
||||||
|
if err != nil {
|
||||||
|
return E.Cause(err, "string value")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case reflect.Array:
|
||||||
|
dataLen := data.Len()
|
||||||
|
for i := 0; i < dataLen; i++ {
|
||||||
|
err := writeData(writer, order, data.Index(i))
|
||||||
|
if err != nil {
|
||||||
|
return E.Cause(err, "[", i, "]")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case reflect.Slice:
|
||||||
|
dataLen := data.Len()
|
||||||
|
_, err := writer.Write(AppendUvarint(writer.AvailableBuffer(), uint64(dataLen)))
|
||||||
|
if err != nil {
|
||||||
|
return E.Cause(err, "slice length")
|
||||||
|
}
|
||||||
|
if dataLen > 0 {
|
||||||
|
if data.Type().Elem().Kind() == reflect.Uint8 {
|
||||||
|
_, err = writer.Write(data.Bytes())
|
||||||
|
if err != nil {
|
||||||
|
return E.Cause(err, "bytes value")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for i := 0; i < dataLen; i++ {
|
||||||
|
err = writeData(writer, order, data.Index(i))
|
||||||
|
if err != nil {
|
||||||
|
return E.Cause(err, "[", i, "]")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case reflect.Map:
|
||||||
|
dataLen := data.Len()
|
||||||
|
_, err := writer.Write(AppendUvarint(writer.AvailableBuffer(), uint64(dataLen)))
|
||||||
|
if err != nil {
|
||||||
|
return E.Cause(err, "map length")
|
||||||
|
}
|
||||||
|
if dataLen > 0 {
|
||||||
|
for index, key := range data.MapKeys() {
|
||||||
|
err = writeData(writer, order, key)
|
||||||
|
if err != nil {
|
||||||
|
return E.Cause(err, "[", index, "].key")
|
||||||
|
}
|
||||||
|
err = writeData(writer, order, data.MapIndex(key))
|
||||||
|
if err != nil {
|
||||||
|
return E.Cause(err, "[", index, "].value")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case reflect.Struct:
|
||||||
|
fieldType := data.Type()
|
||||||
|
fieldLen := data.NumField()
|
||||||
|
for i := 0; i < fieldLen; i++ {
|
||||||
|
field := data.Field(i)
|
||||||
|
fieldName := fieldType.Field(i).Name
|
||||||
|
if field.CanSet() || fieldName != "_" {
|
||||||
|
err := writeData(writer, order, field)
|
||||||
|
if err != nil {
|
||||||
|
return E.Cause(err, fieldName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
size := dataSize(data)
|
||||||
|
if size < 0 {
|
||||||
|
return errors.New("binary.Write: some values are not fixed-sized in type " + data.Type().String())
|
||||||
|
}
|
||||||
|
buf := make([]byte, size)
|
||||||
|
e := &encoder{order: order, buf: buf}
|
||||||
|
e.value(data)
|
||||||
|
_, err := writer.Write(buf)
|
||||||
|
if err != nil {
|
||||||
|
return E.Cause(err, reflect.TypeOf(data).String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func intBaseDataSize(data any) int {
|
||||||
|
switch data.(type) {
|
||||||
|
case bool, int8, uint8:
|
||||||
|
return 1
|
||||||
|
case int16, uint16:
|
||||||
|
return 2
|
||||||
|
case int32, uint32:
|
||||||
|
return 4
|
||||||
|
case int64, uint64:
|
||||||
|
return 8
|
||||||
|
case float32:
|
||||||
|
return 4
|
||||||
|
case float64:
|
||||||
|
return 8
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
|
@ -9,20 +9,19 @@ import (
|
||||||
|
|
||||||
type AddrConn struct {
|
type AddrConn struct {
|
||||||
net.Conn
|
net.Conn
|
||||||
Source M.Socksaddr
|
M.Metadata
|
||||||
Destination M.Socksaddr
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *AddrConn) LocalAddr() net.Addr {
|
func (c *AddrConn) LocalAddr() net.Addr {
|
||||||
if c.Destination.IsValid() {
|
if c.Metadata.Destination.IsValid() {
|
||||||
return c.Destination.TCPAddr()
|
return c.Metadata.Destination.TCPAddr()
|
||||||
}
|
}
|
||||||
return c.Conn.LocalAddr()
|
return c.Conn.LocalAddr()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *AddrConn) RemoteAddr() net.Addr {
|
func (c *AddrConn) RemoteAddr() net.Addr {
|
||||||
if c.Source.IsValid() {
|
if c.Metadata.Source.IsValid() {
|
||||||
return c.Source.TCPAddr()
|
return c.Metadata.Source.TCPAddr()
|
||||||
}
|
}
|
||||||
return c.Conn.RemoteAddr()
|
return c.Conn.RemoteAddr()
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,7 +4,6 @@ import (
|
||||||
"io"
|
"io"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/sagernet/sing/common"
|
|
||||||
"github.com/sagernet/sing/common/buf"
|
"github.com/sagernet/sing/common/buf"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -42,25 +41,6 @@ func (w *BufferedWriter) Write(p []byte) (n int, err error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *BufferedWriter) WriteByte(c byte) error {
|
|
||||||
w.access.Lock()
|
|
||||||
defer w.access.Unlock()
|
|
||||||
if w.buffer == nil {
|
|
||||||
return common.Error(w.upstream.Write([]byte{c}))
|
|
||||||
}
|
|
||||||
for {
|
|
||||||
err := w.buffer.WriteByte(c)
|
|
||||||
if err == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
_, err = w.upstream.Write(w.buffer.Bytes())
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
w.buffer.Reset()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *BufferedWriter) Fallthrough() error {
|
func (w *BufferedWriter) Fallthrough() error {
|
||||||
w.access.Lock()
|
w.access.Lock()
|
||||||
defer w.access.Unlock()
|
defer w.access.Unlock()
|
||||||
|
|
|
@ -3,6 +3,7 @@ package bufio
|
||||||
import (
|
import (
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/sagernet/sing/common/buf"
|
"github.com/sagernet/sing/common/buf"
|
||||||
M "github.com/sagernet/sing/common/metadata"
|
M "github.com/sagernet/sing/common/metadata"
|
||||||
|
@ -59,6 +60,13 @@ func (c *CachedConn) WriteTo(w io.Writer) (n int64, err error) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *CachedConn) SetReadDeadline(t time.Time) error {
|
||||||
|
if c.buffer != nil && !c.buffer.IsEmpty() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return c.Conn.SetReadDeadline(t)
|
||||||
|
}
|
||||||
|
|
||||||
func (c *CachedConn) ReadFrom(r io.Reader) (n int64, err error) {
|
func (c *CachedConn) ReadFrom(r io.Reader) (n int64, err error) {
|
||||||
return Copy(c.Conn, r)
|
return Copy(c.Conn, r)
|
||||||
}
|
}
|
||||||
|
@ -184,12 +192,10 @@ func (c *CachedPacketConn) ReadCachedPacket() *N.PacketBuffer {
|
||||||
if buffer != nil {
|
if buffer != nil {
|
||||||
buffer.DecRef()
|
buffer.DecRef()
|
||||||
}
|
}
|
||||||
packet := N.NewPacketBuffer()
|
return &N.PacketBuffer{
|
||||||
*packet = N.PacketBuffer{
|
|
||||||
Buffer: buffer,
|
Buffer: buffer,
|
||||||
Destination: c.destination,
|
Destination: c.destination,
|
||||||
}
|
}
|
||||||
return packet
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *CachedPacketConn) Upstream() any {
|
func (c *CachedPacketConn) Upstream() any {
|
||||||
|
|
|
@ -35,7 +35,14 @@ func (w *ExtendedUDPConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
|
||||||
|
|
||||||
func (w *ExtendedUDPConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
|
func (w *ExtendedUDPConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
|
||||||
defer buffer.Release()
|
defer buffer.Release()
|
||||||
return common.Error(w.UDPConn.WriteToUDPAddrPort(buffer.Bytes(), destination.AddrPort()))
|
if destination.IsFqdn() {
|
||||||
|
udpAddr, err := net.ResolveUDPAddr("udp", destination.String())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return common.Error(w.UDPConn.WriteTo(buffer.Bytes(), udpAddr))
|
||||||
|
}
|
||||||
|
return common.Error(w.UDPConn.WriteToUDP(buffer.Bytes(), destination.UDPAddr()))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *ExtendedUDPConn) Upstream() any {
|
func (w *ExtendedUDPConn) Upstream() any {
|
||||||
|
|
|
@ -12,6 +12,7 @@ import (
|
||||||
E "github.com/sagernet/sing/common/exceptions"
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
M "github.com/sagernet/sing/common/metadata"
|
M "github.com/sagernet/sing/common/metadata"
|
||||||
N "github.com/sagernet/sing/common/network"
|
N "github.com/sagernet/sing/common/network"
|
||||||
|
"github.com/sagernet/sing/common/rw"
|
||||||
"github.com/sagernet/sing/common/task"
|
"github.com/sagernet/sing/common/task"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -29,35 +30,27 @@ func Copy(destination io.Writer, source io.Reader) (n int64, err error) {
|
||||||
if cachedSrc, isCached := source.(N.CachedReader); isCached {
|
if cachedSrc, isCached := source.(N.CachedReader); isCached {
|
||||||
cachedBuffer := cachedSrc.ReadCached()
|
cachedBuffer := cachedSrc.ReadCached()
|
||||||
if cachedBuffer != nil {
|
if cachedBuffer != nil {
|
||||||
dataLen := cachedBuffer.Len()
|
if !cachedBuffer.IsEmpty() {
|
||||||
_, err = destination.Write(cachedBuffer.Bytes())
|
_, err = destination.Write(cachedBuffer.Bytes())
|
||||||
|
if err != nil {
|
||||||
|
cachedBuffer.Release()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
cachedBuffer.Release()
|
cachedBuffer.Release()
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
for _, counter := range readCounters {
|
|
||||||
counter(int64(dataLen))
|
|
||||||
}
|
|
||||||
for _, counter := range writeCounters {
|
|
||||||
counter(int64(dataLen))
|
|
||||||
}
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
break
|
srcSyscallConn, srcIsSyscall := source.(syscall.Conn)
|
||||||
}
|
dstSyscallConn, dstIsSyscall := destination.(syscall.Conn)
|
||||||
return CopyWithCounters(destination, source, originSource, readCounters, writeCounters)
|
if srcIsSyscall && dstIsSyscall {
|
||||||
}
|
var handled bool
|
||||||
|
handled, n, err = copyDirect(srcSyscallConn, dstSyscallConn, readCounters, writeCounters)
|
||||||
func CopyWithCounters(destination io.Writer, source io.Reader, originSource io.Reader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
|
if handled {
|
||||||
srcSyscallConn, srcIsSyscall := source.(syscall.Conn)
|
return
|
||||||
dstSyscallConn, dstIsSyscall := destination.(syscall.Conn)
|
}
|
||||||
if srcIsSyscall && dstIsSyscall {
|
|
||||||
var handled bool
|
|
||||||
handled, n, err = copyDirect(srcSyscallConn, dstSyscallConn, readCounters, writeCounters)
|
|
||||||
if handled {
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
break
|
||||||
}
|
}
|
||||||
return CopyExtended(originSource, NewExtendedWriter(destination), NewExtendedReader(source), readCounters, writeCounters)
|
return CopyExtended(originSource, NewExtendedWriter(destination), NewExtendedReader(source), readCounters, writeCounters)
|
||||||
}
|
}
|
||||||
|
@ -83,7 +76,6 @@ func CopyExtended(originSource io.Reader, destination N.ExtendedWriter, source N
|
||||||
return CopyExtendedWithPool(originSource, destination, source, readCounters, writeCounters)
|
return CopyExtendedWithPool(originSource, destination, source, readCounters, writeCounters)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Deprecated: not used
|
|
||||||
func CopyExtendedBuffer(originSource io.Writer, destination N.ExtendedWriter, source N.ExtendedReader, buffer *buf.Buffer, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
|
func CopyExtendedBuffer(originSource io.Writer, destination N.ExtendedWriter, source N.ExtendedReader, buffer *buf.Buffer, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
|
||||||
buffer.IncRef()
|
buffer.IncRef()
|
||||||
defer buffer.DecRef()
|
defer buffer.DecRef()
|
||||||
|
@ -122,10 +114,19 @@ func CopyExtendedBuffer(originSource io.Writer, destination N.ExtendedWriter, so
|
||||||
}
|
}
|
||||||
|
|
||||||
func CopyExtendedWithPool(originSource io.Reader, destination N.ExtendedWriter, source N.ExtendedReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
|
func CopyExtendedWithPool(originSource io.Reader, destination N.ExtendedWriter, source N.ExtendedReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
|
||||||
options := N.NewReadWaitOptions(source, destination)
|
frontHeadroom := N.CalculateFrontHeadroom(destination)
|
||||||
|
rearHeadroom := N.CalculateRearHeadroom(destination)
|
||||||
|
bufferSize := N.CalculateMTU(source, destination)
|
||||||
|
if bufferSize > 0 {
|
||||||
|
bufferSize += frontHeadroom + rearHeadroom
|
||||||
|
} else {
|
||||||
|
bufferSize = buf.BufferSize
|
||||||
|
}
|
||||||
var notFirstTime bool
|
var notFirstTime bool
|
||||||
for {
|
for {
|
||||||
buffer := options.NewBuffer()
|
buffer := buf.NewSize(bufferSize)
|
||||||
|
buffer.Resize(frontHeadroom, 0)
|
||||||
|
buffer.Reserve(rearHeadroom)
|
||||||
err = source.ReadBuffer(buffer)
|
err = source.ReadBuffer(buffer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
buffer.Release()
|
buffer.Release()
|
||||||
|
@ -136,7 +137,7 @@ func CopyExtendedWithPool(originSource io.Reader, destination N.ExtendedWriter,
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
dataLen := buffer.Len()
|
dataLen := buffer.Len()
|
||||||
options.PostReturn(buffer)
|
buffer.OverCap(rearHeadroom)
|
||||||
err = destination.WriteBuffer(buffer)
|
err = destination.WriteBuffer(buffer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
buffer.Leak()
|
buffer.Leak()
|
||||||
|
@ -157,12 +158,16 @@ func CopyExtendedWithPool(originSource io.Reader, destination N.ExtendedWriter,
|
||||||
}
|
}
|
||||||
|
|
||||||
func CopyConn(ctx context.Context, source net.Conn, destination net.Conn) error {
|
func CopyConn(ctx context.Context, source net.Conn, destination net.Conn) error {
|
||||||
|
return CopyConnContextList([]context.Context{ctx}, source, destination)
|
||||||
|
}
|
||||||
|
|
||||||
|
func CopyConnContextList(contextList []context.Context, source net.Conn, destination net.Conn) error {
|
||||||
var group task.Group
|
var group task.Group
|
||||||
if _, dstDuplex := common.Cast[N.WriteCloser](destination); dstDuplex {
|
if _, dstDuplex := common.Cast[rw.WriteCloser](destination); dstDuplex {
|
||||||
group.Append("upload", func(ctx context.Context) error {
|
group.Append("upload", func(ctx context.Context) error {
|
||||||
err := common.Error(Copy(destination, source))
|
err := common.Error(Copy(destination, source))
|
||||||
if err == nil {
|
if err == nil {
|
||||||
N.CloseWrite(destination)
|
rw.CloseWrite(destination)
|
||||||
} else {
|
} else {
|
||||||
common.Close(destination)
|
common.Close(destination)
|
||||||
}
|
}
|
||||||
|
@ -174,11 +179,11 @@ func CopyConn(ctx context.Context, source net.Conn, destination net.Conn) error
|
||||||
return common.Error(Copy(destination, source))
|
return common.Error(Copy(destination, source))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
if _, srcDuplex := common.Cast[N.WriteCloser](source); srcDuplex {
|
if _, srcDuplex := common.Cast[rw.WriteCloser](source); srcDuplex {
|
||||||
group.Append("download", func(ctx context.Context) error {
|
group.Append("download", func(ctx context.Context) error {
|
||||||
err := common.Error(Copy(source, destination))
|
err := common.Error(Copy(source, destination))
|
||||||
if err == nil {
|
if err == nil {
|
||||||
N.CloseWrite(source)
|
rw.CloseWrite(source)
|
||||||
} else {
|
} else {
|
||||||
common.Close(source)
|
common.Close(source)
|
||||||
}
|
}
|
||||||
|
@ -193,7 +198,7 @@ func CopyConn(ctx context.Context, source net.Conn, destination net.Conn) error
|
||||||
group.Cleanup(func() {
|
group.Cleanup(func() {
|
||||||
common.Close(source, destination)
|
common.Close(source, destination)
|
||||||
})
|
})
|
||||||
return group.Run(ctx)
|
return group.RunContextList(contextList)
|
||||||
}
|
}
|
||||||
|
|
||||||
func CopyPacket(destinationConn N.PacketWriter, source N.PacketReader) (n int64, err error) {
|
func CopyPacket(destinationConn N.PacketWriter, source N.PacketReader) (n int64, err error) {
|
||||||
|
@ -213,24 +218,24 @@ func CopyPacket(destinationConn N.PacketWriter, source N.PacketReader) (n int64,
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
if cachedPackets != nil {
|
if cachedPackets != nil {
|
||||||
n, err = WritePacketWithPool(originSource, destinationConn, cachedPackets, readCounters, writeCounters)
|
n, err = WritePacketWithPool(originSource, destinationConn, cachedPackets)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
copeN, err := CopyPacketWithCounters(destinationConn, source, originSource, readCounters, writeCounters)
|
frontHeadroom := N.CalculateFrontHeadroom(destinationConn)
|
||||||
n += copeN
|
rearHeadroom := N.CalculateRearHeadroom(destinationConn)
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func CopyPacketWithCounters(destinationConn N.PacketWriter, source N.PacketReader, originSource N.PacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
|
|
||||||
var (
|
var (
|
||||||
handled bool
|
handled bool
|
||||||
copeN int64
|
copeN int64
|
||||||
)
|
)
|
||||||
readWaiter, isReadWaiter := CreatePacketReadWaiter(source)
|
readWaiter, isReadWaiter := CreatePacketReadWaiter(source)
|
||||||
if isReadWaiter {
|
if isReadWaiter {
|
||||||
needCopy := readWaiter.InitializeReadWaiter(N.NewReadWaitOptions(source, destinationConn))
|
needCopy := readWaiter.InitializeReadWaiter(N.ReadWaitOptions{
|
||||||
|
FrontHeadroom: frontHeadroom,
|
||||||
|
RearHeadroom: rearHeadroom,
|
||||||
|
MTU: N.CalculateMTU(source, destinationConn),
|
||||||
|
})
|
||||||
if !needCopy || common.LowMemory {
|
if !needCopy || common.LowMemory {
|
||||||
handled, copeN, err = copyPacketWaitWithPool(originSource, destinationConn, readWaiter, readCounters, writeCounters, n > 0)
|
handled, copeN, err = copyPacketWaitWithPool(originSource, destinationConn, readWaiter, readCounters, writeCounters, n > 0)
|
||||||
if handled {
|
if handled {
|
||||||
|
@ -244,19 +249,28 @@ func CopyPacketWithCounters(destinationConn N.PacketWriter, source N.PacketReade
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func CopyPacketWithPool(originSource N.PacketReader, destination N.PacketWriter, source N.PacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (n int64, err error) {
|
func CopyPacketWithPool(originSource N.PacketReader, destinationConn N.PacketWriter, source N.PacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (n int64, err error) {
|
||||||
options := N.NewReadWaitOptions(source, destination)
|
frontHeadroom := N.CalculateFrontHeadroom(destinationConn)
|
||||||
var destinationAddress M.Socksaddr
|
rearHeadroom := N.CalculateRearHeadroom(destinationConn)
|
||||||
|
bufferSize := N.CalculateMTU(source, destinationConn)
|
||||||
|
if bufferSize > 0 {
|
||||||
|
bufferSize += frontHeadroom + rearHeadroom
|
||||||
|
} else {
|
||||||
|
bufferSize = buf.UDPBufferSize
|
||||||
|
}
|
||||||
|
var destination M.Socksaddr
|
||||||
for {
|
for {
|
||||||
buffer := options.NewPacketBuffer()
|
buffer := buf.NewSize(bufferSize)
|
||||||
destinationAddress, err = source.ReadPacket(buffer)
|
buffer.Resize(frontHeadroom, 0)
|
||||||
|
buffer.Reserve(rearHeadroom)
|
||||||
|
destination, err = source.ReadPacket(buffer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
buffer.Release()
|
buffer.Release()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
dataLen := buffer.Len()
|
dataLen := buffer.Len()
|
||||||
options.PostReturn(buffer)
|
buffer.OverCap(rearHeadroom)
|
||||||
err = destination.WritePacket(buffer, destinationAddress)
|
err = destinationConn.WritePacket(buffer, destination)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
buffer.Leak()
|
buffer.Leak()
|
||||||
if !notFirstTime {
|
if !notFirstTime {
|
||||||
|
@ -264,25 +278,34 @@ func CopyPacketWithPool(originSource N.PacketReader, destination N.PacketWriter,
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
n += int64(dataLen)
|
||||||
for _, counter := range readCounters {
|
for _, counter := range readCounters {
|
||||||
counter(int64(dataLen))
|
counter(int64(dataLen))
|
||||||
}
|
}
|
||||||
for _, counter := range writeCounters {
|
for _, counter := range writeCounters {
|
||||||
counter(int64(dataLen))
|
counter(int64(dataLen))
|
||||||
}
|
}
|
||||||
n += int64(dataLen)
|
|
||||||
notFirstTime = true
|
notFirstTime = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func WritePacketWithPool(originSource N.PacketReader, destination N.PacketWriter, packetBuffers []*N.PacketBuffer, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
|
func WritePacketWithPool(originSource N.PacketReader, destinationConn N.PacketWriter, packetBuffers []*N.PacketBuffer) (n int64, err error) {
|
||||||
options := N.NewReadWaitOptions(nil, destination)
|
frontHeadroom := N.CalculateFrontHeadroom(destinationConn)
|
||||||
|
rearHeadroom := N.CalculateRearHeadroom(destinationConn)
|
||||||
var notFirstTime bool
|
var notFirstTime bool
|
||||||
for _, packetBuffer := range packetBuffers {
|
for _, packetBuffer := range packetBuffers {
|
||||||
buffer := options.Copy(packetBuffer.Buffer)
|
buffer := buf.NewPacket()
|
||||||
|
buffer.Resize(frontHeadroom, 0)
|
||||||
|
buffer.Reserve(rearHeadroom)
|
||||||
|
_, err = buffer.Write(packetBuffer.Buffer.Bytes())
|
||||||
|
packetBuffer.Buffer.Release()
|
||||||
|
if err != nil {
|
||||||
|
buffer.Release()
|
||||||
|
continue
|
||||||
|
}
|
||||||
dataLen := buffer.Len()
|
dataLen := buffer.Len()
|
||||||
err = destination.WritePacket(buffer, packetBuffer.Destination)
|
buffer.OverCap(rearHeadroom)
|
||||||
N.PutPacketBuffer(packetBuffer)
|
err = destinationConn.WritePacket(buffer, packetBuffer.Destination)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
buffer.Leak()
|
buffer.Leak()
|
||||||
if !notFirstTime {
|
if !notFirstTime {
|
||||||
|
@ -290,19 +313,16 @@ func WritePacketWithPool(originSource N.PacketReader, destination N.PacketWriter
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
for _, counter := range readCounters {
|
|
||||||
counter(int64(dataLen))
|
|
||||||
}
|
|
||||||
for _, counter := range writeCounters {
|
|
||||||
counter(int64(dataLen))
|
|
||||||
}
|
|
||||||
n += int64(dataLen)
|
n += int64(dataLen)
|
||||||
notFirstTime = true
|
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func CopyPacketConn(ctx context.Context, source N.PacketConn, destination N.PacketConn) error {
|
func CopyPacketConn(ctx context.Context, source N.PacketConn, destination N.PacketConn) error {
|
||||||
|
return CopyPacketConnContextList([]context.Context{ctx}, source, destination)
|
||||||
|
}
|
||||||
|
|
||||||
|
func CopyPacketConnContextList(contextList []context.Context, source N.PacketConn, destination N.PacketConn) error {
|
||||||
var group task.Group
|
var group task.Group
|
||||||
group.Append("upload", func(ctx context.Context) error {
|
group.Append("upload", func(ctx context.Context) error {
|
||||||
return common.Error(CopyPacket(destination, source))
|
return common.Error(CopyPacket(destination, source))
|
||||||
|
@ -314,5 +334,5 @@ func CopyPacketConn(ctx context.Context, source N.PacketConn, destination N.Pack
|
||||||
common.Close(source, destination)
|
common.Close(source, destination)
|
||||||
})
|
})
|
||||||
group.FastFail()
|
group.FastFail()
|
||||||
return group.Run(ctx)
|
return group.RunContextList(contextList)
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,6 +5,7 @@ import (
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
"github.com/sagernet/sing/common/buf"
|
"github.com/sagernet/sing/common/buf"
|
||||||
E "github.com/sagernet/sing/common/exceptions"
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
|
@ -14,6 +15,49 @@ import (
|
||||||
"golang.org/x/sys/windows"
|
"golang.org/x/sys/windows"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var modws2_32 = windows.NewLazySystemDLL("ws2_32.dll")
|
||||||
|
|
||||||
|
var procrecv = modws2_32.NewProc("recv")
|
||||||
|
|
||||||
|
// Do the interface allocations only once for common
|
||||||
|
// Errno values.
|
||||||
|
const (
|
||||||
|
errnoERROR_IO_PENDING = 997
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
|
||||||
|
errERROR_EINVAL error = syscall.EINVAL
|
||||||
|
)
|
||||||
|
|
||||||
|
// errnoErr returns common boxed Errno values, to prevent
|
||||||
|
// allocations at runtime.
|
||||||
|
func errnoErr(e syscall.Errno) error {
|
||||||
|
switch e {
|
||||||
|
case 0:
|
||||||
|
return errERROR_EINVAL
|
||||||
|
case errnoERROR_IO_PENDING:
|
||||||
|
return errERROR_IO_PENDING
|
||||||
|
}
|
||||||
|
// TODO: add more here, after collecting data on the common
|
||||||
|
// error values see on Windows. (perhaps when running
|
||||||
|
// all.bat?)
|
||||||
|
return e
|
||||||
|
}
|
||||||
|
|
||||||
|
func recv(s windows.Handle, buf []byte, flags int32) (n int32, err error) {
|
||||||
|
var _p0 *byte
|
||||||
|
if len(buf) > 0 {
|
||||||
|
_p0 = &buf[0]
|
||||||
|
}
|
||||||
|
r0, _, e1 := syscall.SyscallN(procrecv.Addr(), uintptr(s), uintptr(unsafe.Pointer(_p0)), uintptr(len(buf)), uintptr(flags))
|
||||||
|
n = int32(r0)
|
||||||
|
if n == -1 {
|
||||||
|
err = errnoErr(e1)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
var _ N.ReadWaiter = (*syscallReadWaiter)(nil)
|
var _ N.ReadWaiter = (*syscallReadWaiter)(nil)
|
||||||
|
|
||||||
type syscallReadWaiter struct {
|
type syscallReadWaiter struct {
|
||||||
|
@ -120,16 +164,16 @@ func (w *syscallPacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions
|
||||||
var readN int
|
var readN int
|
||||||
var from windows.Sockaddr
|
var from windows.Sockaddr
|
||||||
readN, from, w.readErr = windows.Recvfrom(windows.Handle(fd), buffer.FreeBytes(), 0)
|
readN, from, w.readErr = windows.Recvfrom(windows.Handle(fd), buffer.FreeBytes(), 0)
|
||||||
//goland:noinspection GoDirectComparisonOfErrors
|
|
||||||
if w.readErr != nil {
|
|
||||||
buffer.Release()
|
|
||||||
return w.readErr != windows.WSAEWOULDBLOCK
|
|
||||||
}
|
|
||||||
if readN > 0 {
|
if readN > 0 {
|
||||||
buffer.Truncate(readN)
|
buffer.Truncate(readN)
|
||||||
|
w.options.PostReturn(buffer)
|
||||||
|
w.buffer = buffer
|
||||||
|
} else {
|
||||||
|
buffer.Release()
|
||||||
|
}
|
||||||
|
if w.readErr == windows.WSAEWOULDBLOCK {
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
w.options.PostReturn(buffer)
|
|
||||||
w.buffer = buffer
|
|
||||||
if from != nil {
|
if from != nil {
|
||||||
switch fromAddr := from.(type) {
|
switch fromAddr := from.(type) {
|
||||||
case *windows.SockaddrInet4:
|
case *windows.SockaddrInet4:
|
||||||
|
|
|
@ -25,45 +25,6 @@ func ReadPacket(reader N.PacketReader, buffer *buf.Buffer) (n int, addr net.Addr
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func ReadBufferSize(reader io.Reader, bufferSize int) (buffer *buf.Buffer, err error) {
|
|
||||||
readWaiter, isReadWaiter := CreateReadWaiter(reader)
|
|
||||||
if isReadWaiter {
|
|
||||||
readWaiter.InitializeReadWaiter(N.ReadWaitOptions{
|
|
||||||
MTU: bufferSize,
|
|
||||||
})
|
|
||||||
return readWaiter.WaitReadBuffer()
|
|
||||||
}
|
|
||||||
buffer = buf.NewSize(bufferSize)
|
|
||||||
if extendedReader, isExtendedReader := reader.(N.ExtendedReader); isExtendedReader {
|
|
||||||
err = extendedReader.ReadBuffer(buffer)
|
|
||||||
} else {
|
|
||||||
_, err = buffer.ReadOnceFrom(reader)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
buffer.Release()
|
|
||||||
buffer = nil
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func ReadPacketSize(reader N.PacketReader, packetSize int) (buffer *buf.Buffer, destination M.Socksaddr, err error) {
|
|
||||||
readWaiter, isReadWaiter := CreatePacketReadWaiter(reader)
|
|
||||||
if isReadWaiter {
|
|
||||||
readWaiter.InitializeReadWaiter(N.ReadWaitOptions{
|
|
||||||
MTU: packetSize,
|
|
||||||
})
|
|
||||||
buffer, destination, err = readWaiter.WaitReadPacket()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
buffer = buf.NewSize(packetSize)
|
|
||||||
destination, err = reader.ReadPacket(buffer)
|
|
||||||
if err != nil {
|
|
||||||
buffer.Release()
|
|
||||||
buffer = nil
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func Write(writer io.Writer, data []byte) (n int, err error) {
|
func Write(writer io.Writer, data []byte) (n int, err error) {
|
||||||
if extendedWriter, isExtended := writer.(N.ExtendedWriter); isExtended {
|
if extendedWriter, isExtended := writer.(N.ExtendedWriter); isExtended {
|
||||||
return WriteBuffer(extendedWriter, buf.As(data))
|
return WriteBuffer(extendedWriter, buf.As(data))
|
||||||
|
|
|
@ -30,14 +30,6 @@ func NewNATPacketConn(conn N.NetPacketConn, origin M.Socksaddr, destination M.So
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewDestinationNATPacketConn(conn N.NetPacketConn, origin M.Socksaddr, destination M.Socksaddr) NATPacketConn {
|
|
||||||
return &destinationNATPacketConn{
|
|
||||||
NetPacketConn: conn,
|
|
||||||
origin: origin,
|
|
||||||
destination: destination,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type unidirectionalNATPacketConn struct {
|
type unidirectionalNATPacketConn struct {
|
||||||
N.NetPacketConn
|
N.NetPacketConn
|
||||||
origin M.Socksaddr
|
origin M.Socksaddr
|
||||||
|
@ -152,60 +144,6 @@ func (c *bidirectionalNATPacketConn) RemoteAddr() net.Addr {
|
||||||
return c.destination.UDPAddr()
|
return c.destination.UDPAddr()
|
||||||
}
|
}
|
||||||
|
|
||||||
type destinationNATPacketConn struct {
|
|
||||||
N.NetPacketConn
|
|
||||||
origin M.Socksaddr
|
|
||||||
destination M.Socksaddr
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *destinationNATPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
|
||||||
n, addr, err = c.NetPacketConn.ReadFrom(p)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if M.SocksaddrFromNet(addr) == c.origin {
|
|
||||||
addr = c.destination.UDPAddr()
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *destinationNATPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
|
||||||
if M.SocksaddrFromNet(addr) == c.destination {
|
|
||||||
addr = c.origin.UDPAddr()
|
|
||||||
}
|
|
||||||
return c.NetPacketConn.WriteTo(p, addr)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *destinationNATPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
|
|
||||||
destination, err = c.NetPacketConn.ReadPacket(buffer)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if destination == c.origin {
|
|
||||||
destination = c.destination
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *destinationNATPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
|
|
||||||
if destination == c.destination {
|
|
||||||
destination = c.origin
|
|
||||||
}
|
|
||||||
return c.NetPacketConn.WritePacket(buffer, destination)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *destinationNATPacketConn) UpdateDestination(destinationAddress netip.Addr) {
|
|
||||||
c.destination = M.SocksaddrFrom(destinationAddress, c.destination.Port)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *destinationNATPacketConn) Upstream() any {
|
|
||||||
return c.NetPacketConn
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *destinationNATPacketConn) RemoteAddr() net.Addr {
|
|
||||||
return c.destination.UDPAddr()
|
|
||||||
}
|
|
||||||
|
|
||||||
func socksaddrWithoutPort(destination M.Socksaddr) M.Socksaddr {
|
func socksaddrWithoutPort(destination M.Socksaddr) M.Socksaddr {
|
||||||
destination.Port = 0
|
destination.Port = 0
|
||||||
return destination
|
return destination
|
||||||
|
|
|
@ -36,7 +36,7 @@ func TCPPipe(t *testing.T) (net.Conn, net.Conn) {
|
||||||
clientConn, clientErr = net.Dial("tcp", listener.Addr().String())
|
clientConn, clientErr = net.Dial("tcp", listener.Addr().String())
|
||||||
return clientErr
|
return clientErr
|
||||||
})
|
})
|
||||||
err = group.Run(context.Background())
|
err = group.Run()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
listener.Close()
|
listener.Close()
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
|
|
|
@ -1,5 +0,0 @@
|
||||||
package bufio
|
|
||||||
|
|
||||||
//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go syscall_windows.go
|
|
||||||
|
|
||||||
//sys recv(s windows.Handle, buf []byte, flags int32) (n int32, err error) [failretval == -1] = ws2_32.recv
|
|
|
@ -38,6 +38,7 @@ func (w *SyscallVectorisedWriter) WriteVectorised(buffers []*buf.Buffer) error {
|
||||||
var innerErr unix.Errno
|
var innerErr unix.Errno
|
||||||
err := w.rawConn.Write(func(fd uintptr) (done bool) {
|
err := w.rawConn.Write(func(fd uintptr) (done bool) {
|
||||||
//nolint:staticcheck
|
//nolint:staticcheck
|
||||||
|
//goland:noinspection GoDeprecation
|
||||||
_, _, innerErr = unix.Syscall(unix.SYS_WRITEV, fd, uintptr(unsafe.Pointer(&iovecList[0])), uintptr(len(iovecList)))
|
_, _, innerErr = unix.Syscall(unix.SYS_WRITEV, fd, uintptr(unsafe.Pointer(&iovecList[0])), uintptr(len(iovecList)))
|
||||||
return innerErr != unix.EAGAIN && innerErr != unix.EWOULDBLOCK
|
return innerErr != unix.EAGAIN && innerErr != unix.EWOULDBLOCK
|
||||||
})
|
})
|
||||||
|
|
|
@ -1,57 +0,0 @@
|
||||||
// Code generated by 'go generate'; DO NOT EDIT.
|
|
||||||
|
|
||||||
package bufio
|
|
||||||
|
|
||||||
import (
|
|
||||||
"syscall"
|
|
||||||
"unsafe"
|
|
||||||
|
|
||||||
"golang.org/x/sys/windows"
|
|
||||||
)
|
|
||||||
|
|
||||||
var _ unsafe.Pointer
|
|
||||||
|
|
||||||
// Do the interface allocations only once for common
|
|
||||||
// Errno values.
|
|
||||||
const (
|
|
||||||
errnoERROR_IO_PENDING = 997
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
|
|
||||||
errERROR_EINVAL error = syscall.EINVAL
|
|
||||||
)
|
|
||||||
|
|
||||||
// errnoErr returns common boxed Errno values, to prevent
|
|
||||||
// allocations at runtime.
|
|
||||||
func errnoErr(e syscall.Errno) error {
|
|
||||||
switch e {
|
|
||||||
case 0:
|
|
||||||
return errERROR_EINVAL
|
|
||||||
case errnoERROR_IO_PENDING:
|
|
||||||
return errERROR_IO_PENDING
|
|
||||||
}
|
|
||||||
// TODO: add more here, after collecting data on the common
|
|
||||||
// error values see on Windows. (perhaps when running
|
|
||||||
// all.bat?)
|
|
||||||
return e
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
modws2_32 = windows.NewLazySystemDLL("ws2_32.dll")
|
|
||||||
|
|
||||||
procrecv = modws2_32.NewProc("recv")
|
|
||||||
)
|
|
||||||
|
|
||||||
func recv(s windows.Handle, buf []byte, flags int32) (n int32, err error) {
|
|
||||||
var _p0 *byte
|
|
||||||
if len(buf) > 0 {
|
|
||||||
_p0 = &buf[0]
|
|
||||||
}
|
|
||||||
r0, _, e1 := syscall.Syscall6(procrecv.Addr(), 4, uintptr(s), uintptr(unsafe.Pointer(_p0)), uintptr(len(buf)), uintptr(flags), 0, 0)
|
|
||||||
n = int32(r0)
|
|
||||||
if n == -1 {
|
|
||||||
err = errnoErr(e1)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
|
@ -41,9 +41,9 @@ func (i *Instance) Timeout() time.Duration {
|
||||||
return i.timeout
|
return i.timeout
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *Instance) SetTimeout(timeout time.Duration) bool {
|
func (i *Instance) SetTimeout(timeout time.Duration) {
|
||||||
i.timeout = timeout
|
i.timeout = timeout
|
||||||
return i.Update()
|
i.Update()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *Instance) wait() {
|
func (i *Instance) wait() {
|
||||||
|
|
|
@ -13,7 +13,7 @@ import (
|
||||||
type PacketConn interface {
|
type PacketConn interface {
|
||||||
N.PacketConn
|
N.PacketConn
|
||||||
Timeout() time.Duration
|
Timeout() time.Duration
|
||||||
SetTimeout(timeout time.Duration) bool
|
SetTimeout(timeout time.Duration)
|
||||||
}
|
}
|
||||||
|
|
||||||
type TimerPacketConn struct {
|
type TimerPacketConn struct {
|
||||||
|
@ -24,12 +24,10 @@ type TimerPacketConn struct {
|
||||||
func NewPacketConn(ctx context.Context, conn N.PacketConn, timeout time.Duration) (context.Context, N.PacketConn) {
|
func NewPacketConn(ctx context.Context, conn N.PacketConn, timeout time.Duration) (context.Context, N.PacketConn) {
|
||||||
if timeoutConn, isTimeoutConn := common.Cast[PacketConn](conn); isTimeoutConn {
|
if timeoutConn, isTimeoutConn := common.Cast[PacketConn](conn); isTimeoutConn {
|
||||||
oldTimeout := timeoutConn.Timeout()
|
oldTimeout := timeoutConn.Timeout()
|
||||||
if oldTimeout > 0 && timeout >= oldTimeout {
|
if timeout < oldTimeout {
|
||||||
return ctx, conn
|
timeoutConn.SetTimeout(timeout)
|
||||||
}
|
|
||||||
if timeoutConn.SetTimeout(timeout) {
|
|
||||||
return ctx, conn
|
|
||||||
}
|
}
|
||||||
|
return ctx, conn
|
||||||
}
|
}
|
||||||
err := conn.SetReadDeadline(time.Time{})
|
err := conn.SetReadDeadline(time.Time{})
|
||||||
if err == nil {
|
if err == nil {
|
||||||
|
@ -60,8 +58,8 @@ func (c *TimerPacketConn) Timeout() time.Duration {
|
||||||
return c.instance.Timeout()
|
return c.instance.Timeout()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *TimerPacketConn) SetTimeout(timeout time.Duration) bool {
|
func (c *TimerPacketConn) SetTimeout(timeout time.Duration) {
|
||||||
return c.instance.SetTimeout(timeout)
|
c.instance.SetTimeout(timeout)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *TimerPacketConn) Close() error {
|
func (c *TimerPacketConn) Close() error {
|
||||||
|
|
|
@ -61,9 +61,9 @@ func (c *TimeoutPacketConn) Timeout() time.Duration {
|
||||||
return c.timeout
|
return c.timeout
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *TimeoutPacketConn) SetTimeout(timeout time.Duration) bool {
|
func (c *TimeoutPacketConn) SetTimeout(timeout time.Duration) {
|
||||||
c.timeout = timeout
|
c.timeout = timeout
|
||||||
return c.PacketConn.SetReadDeadline(time.Now()) == nil
|
c.PacketConn.SetReadDeadline(time.Now())
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *TimeoutPacketConn) Close() error {
|
func (c *TimeoutPacketConn) Close() error {
|
||||||
|
|
|
@ -157,18 +157,6 @@ func IndexIndexed[T any](arr []T, block func(index int, it T) bool) int {
|
||||||
return -1
|
return -1
|
||||||
}
|
}
|
||||||
|
|
||||||
func Equal[S ~[]E, E comparable](s1, s2 S) bool {
|
|
||||||
if len(s1) != len(s2) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
for i := range s1 {
|
|
||||||
if s1[i] != s2[i] {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
//go:norace
|
//go:norace
|
||||||
func Dup[T any](obj T) T {
|
func Dup[T any](obj T) T {
|
||||||
pointer := uintptr(unsafe.Pointer(&obj))
|
pointer := uintptr(unsafe.Pointer(&obj))
|
||||||
|
@ -280,14 +268,6 @@ func Reverse[T any](arr []T) []T {
|
||||||
return arr
|
return arr
|
||||||
}
|
}
|
||||||
|
|
||||||
func ReverseMap[K comparable, V comparable](m map[K]V) map[V]K {
|
|
||||||
ret := make(map[V]K, len(m))
|
|
||||||
for k, v := range m {
|
|
||||||
ret[v] = k
|
|
||||||
}
|
|
||||||
return ret
|
|
||||||
}
|
|
||||||
|
|
||||||
func Done(ctx context.Context) bool {
|
func Done(ctx context.Context) bool {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
|
@ -382,3 +362,22 @@ func Close(closers ...any) error {
|
||||||
}
|
}
|
||||||
return retErr
|
return retErr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type Starter interface {
|
||||||
|
Start() error
|
||||||
|
}
|
||||||
|
|
||||||
|
func Start(starters ...any) error {
|
||||||
|
for _, rawStarter := range starters {
|
||||||
|
if rawStarter == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if starter, isStarter := rawStarter.(Starter); isStarter {
|
||||||
|
err := starter.Start()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
|
@ -5,7 +5,6 @@ import (
|
||||||
"reflect"
|
"reflect"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Deprecated: not used
|
|
||||||
func SelectContext(contextList []context.Context) (int, error) {
|
func SelectContext(contextList []context.Context) (int, error) {
|
||||||
if len(contextList) == 1 {
|
if len(contextList) == 1 {
|
||||||
<-contextList[0].Done()
|
<-contextList[0].Done()
|
||||||
|
|
|
@ -9,15 +9,15 @@ import (
|
||||||
|
|
||||||
func bindToInterface(conn syscall.RawConn, network string, address string, finder InterfaceFinder, interfaceName string, interfaceIndex int, preferInterfaceName bool) error {
|
func bindToInterface(conn syscall.RawConn, network string, address string, finder InterfaceFinder, interfaceName string, interfaceIndex int, preferInterfaceName bool) error {
|
||||||
return Raw(conn, func(fd uintptr) error {
|
return Raw(conn, func(fd uintptr) error {
|
||||||
|
var err error
|
||||||
if interfaceIndex == -1 {
|
if interfaceIndex == -1 {
|
||||||
if finder == nil {
|
if finder == nil {
|
||||||
return os.ErrInvalid
|
return os.ErrInvalid
|
||||||
}
|
}
|
||||||
iif, err := finder.ByName(interfaceName)
|
interfaceIndex, err = finder.InterfaceIndexByName(interfaceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
interfaceIndex = iif.Index
|
|
||||||
}
|
}
|
||||||
switch network {
|
switch network {
|
||||||
case "tcp6", "udp6":
|
case "tcp6", "udp6":
|
||||||
|
|
|
@ -3,57 +3,19 @@ package control
|
||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"unsafe"
|
|
||||||
|
|
||||||
"github.com/sagernet/sing/common"
|
|
||||||
M "github.com/sagernet/sing/common/metadata"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type InterfaceFinder interface {
|
type InterfaceFinder interface {
|
||||||
Update() error
|
|
||||||
Interfaces() []Interface
|
Interfaces() []Interface
|
||||||
ByName(name string) (*Interface, error)
|
InterfaceIndexByName(name string) (int, error)
|
||||||
ByIndex(index int) (*Interface, error)
|
InterfaceNameByIndex(index int) (string, error)
|
||||||
ByAddr(addr netip.Addr) (*Interface, error)
|
InterfaceByAddr(addr netip.Addr) (*Interface, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type Interface struct {
|
type Interface struct {
|
||||||
Index int
|
Index int
|
||||||
MTU int
|
MTU int
|
||||||
Name string
|
Name string
|
||||||
HardwareAddr net.HardwareAddr
|
|
||||||
Flags net.Flags
|
|
||||||
Addresses []netip.Prefix
|
Addresses []netip.Prefix
|
||||||
}
|
HardwareAddr net.HardwareAddr
|
||||||
|
|
||||||
func (i Interface) Equals(other Interface) bool {
|
|
||||||
return i.Index == other.Index &&
|
|
||||||
i.MTU == other.MTU &&
|
|
||||||
i.Name == other.Name &&
|
|
||||||
common.Equal(i.HardwareAddr, other.HardwareAddr) &&
|
|
||||||
i.Flags == other.Flags &&
|
|
||||||
common.Equal(i.Addresses, other.Addresses)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i Interface) NetInterface() net.Interface {
|
|
||||||
return *(*net.Interface)(unsafe.Pointer(&i))
|
|
||||||
}
|
|
||||||
|
|
||||||
func InterfaceFromNet(iif net.Interface) (Interface, error) {
|
|
||||||
ifAddrs, err := iif.Addrs()
|
|
||||||
if err != nil {
|
|
||||||
return Interface{}, err
|
|
||||||
}
|
|
||||||
return InterfaceFromNetAddrs(iif, common.Map(ifAddrs, M.PrefixFromNet)), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func InterfaceFromNetAddrs(iif net.Interface, addresses []netip.Prefix) Interface {
|
|
||||||
return Interface{
|
|
||||||
Index: iif.Index,
|
|
||||||
MTU: iif.MTU,
|
|
||||||
Name: iif.Name,
|
|
||||||
HardwareAddr: iif.HardwareAddr,
|
|
||||||
Flags: iif.Flags,
|
|
||||||
Addresses: addresses,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,8 +3,10 @@ package control
|
||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
_ "unsafe"
|
||||||
|
|
||||||
E "github.com/sagernet/sing/common/exceptions"
|
"github.com/sagernet/sing/common"
|
||||||
|
M "github.com/sagernet/sing/common/metadata"
|
||||||
)
|
)
|
||||||
|
|
||||||
var _ InterfaceFinder = (*DefaultInterfaceFinder)(nil)
|
var _ InterfaceFinder = (*DefaultInterfaceFinder)(nil)
|
||||||
|
@ -24,12 +26,17 @@ func (f *DefaultInterfaceFinder) Update() error {
|
||||||
}
|
}
|
||||||
interfaces := make([]Interface, 0, len(netIfs))
|
interfaces := make([]Interface, 0, len(netIfs))
|
||||||
for _, netIf := range netIfs {
|
for _, netIf := range netIfs {
|
||||||
var iif Interface
|
ifAddrs, err := netIf.Addrs()
|
||||||
iif, err = InterfaceFromNet(netIf)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
interfaces = append(interfaces, iif)
|
interfaces = append(interfaces, Interface{
|
||||||
|
Index: netIf.Index,
|
||||||
|
MTU: netIf.MTU,
|
||||||
|
Name: netIf.Name,
|
||||||
|
Addresses: common.Map(ifAddrs, M.PrefixFromNet),
|
||||||
|
HardwareAddr: netIf.HardwareAddr,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
f.interfaces = interfaces
|
f.interfaces = interfaces
|
||||||
return nil
|
return nil
|
||||||
|
@ -43,41 +50,38 @@ func (f *DefaultInterfaceFinder) Interfaces() []Interface {
|
||||||
return f.interfaces
|
return f.interfaces
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *DefaultInterfaceFinder) ByName(name string) (*Interface, error) {
|
func (f *DefaultInterfaceFinder) InterfaceIndexByName(name string) (int, error) {
|
||||||
for _, netInterface := range f.interfaces {
|
for _, netInterface := range f.interfaces {
|
||||||
if netInterface.Name == name {
|
if netInterface.Name == name {
|
||||||
return &netInterface, nil
|
return netInterface.Index, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
_, err := net.InterfaceByName(name)
|
netInterface, err := net.InterfaceByName(name)
|
||||||
if err == nil {
|
if err != nil {
|
||||||
err = f.Update()
|
return 0, err
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return f.ByName(name)
|
|
||||||
}
|
}
|
||||||
return nil, &net.OpError{Op: "route", Net: "ip+net", Source: nil, Addr: &net.IPAddr{IP: nil}, Err: E.New("no such network interface")}
|
f.Update()
|
||||||
|
return netInterface.Index, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *DefaultInterfaceFinder) ByIndex(index int) (*Interface, error) {
|
func (f *DefaultInterfaceFinder) InterfaceNameByIndex(index int) (string, error) {
|
||||||
for _, netInterface := range f.interfaces {
|
for _, netInterface := range f.interfaces {
|
||||||
if netInterface.Index == index {
|
if netInterface.Index == index {
|
||||||
return &netInterface, nil
|
return netInterface.Name, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
_, err := net.InterfaceByIndex(index)
|
netInterface, err := net.InterfaceByIndex(index)
|
||||||
if err == nil {
|
if err != nil {
|
||||||
err = f.Update()
|
return "", err
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return f.ByIndex(index)
|
|
||||||
}
|
}
|
||||||
return nil, &net.OpError{Op: "route", Net: "ip+net", Source: nil, Addr: &net.IPAddr{IP: nil}, Err: E.New("no such network interface")}
|
f.Update()
|
||||||
|
return netInterface.Name, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *DefaultInterfaceFinder) ByAddr(addr netip.Addr) (*Interface, error) {
|
//go:linkname errNoSuchInterface net.errNoSuchInterface
|
||||||
|
var errNoSuchInterface error
|
||||||
|
|
||||||
|
func (f *DefaultInterfaceFinder) InterfaceByAddr(addr netip.Addr) (*Interface, error) {
|
||||||
for _, netInterface := range f.interfaces {
|
for _, netInterface := range f.interfaces {
|
||||||
for _, prefix := range netInterface.Addresses {
|
for _, prefix := range netInterface.Addresses {
|
||||||
if prefix.Contains(addr) {
|
if prefix.Contains(addr) {
|
||||||
|
@ -85,5 +89,16 @@ func (f *DefaultInterfaceFinder) ByAddr(addr netip.Addr) (*Interface, error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil, &net.OpError{Op: "route", Net: "ip+net", Source: nil, Addr: &net.IPAddr{IP: addr.AsSlice()}, Err: E.New("no such network interface")}
|
err := f.Update()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
for _, netInterface := range f.interfaces {
|
||||||
|
for _, prefix := range netInterface.Addresses {
|
||||||
|
if prefix.Contains(addr) {
|
||||||
|
return &netInterface, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, &net.OpError{Op: "route", Net: "ip+net", Source: nil, Addr: &net.IPAddr{IP: addr.AsSlice()}, Err: errNoSuchInterface}
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,11 +19,11 @@ func bindToInterface(conn syscall.RawConn, network string, address string, finde
|
||||||
if interfaceName == "" {
|
if interfaceName == "" {
|
||||||
return os.ErrInvalid
|
return os.ErrInvalid
|
||||||
}
|
}
|
||||||
iif, err := finder.ByName(interfaceName)
|
var err error
|
||||||
|
interfaceIndex, err = finder.InterfaceIndexByName(interfaceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
interfaceIndex = iif.Index
|
|
||||||
}
|
}
|
||||||
err := unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_BINDTOIFINDEX, interfaceIndex)
|
err := unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_BINDTOIFINDEX, interfaceIndex)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
|
|
|
@ -11,19 +11,19 @@ import (
|
||||||
|
|
||||||
func bindToInterface(conn syscall.RawConn, network string, address string, finder InterfaceFinder, interfaceName string, interfaceIndex int, preferInterfaceName bool) error {
|
func bindToInterface(conn syscall.RawConn, network string, address string, finder InterfaceFinder, interfaceName string, interfaceIndex int, preferInterfaceName bool) error {
|
||||||
return Raw(conn, func(fd uintptr) error {
|
return Raw(conn, func(fd uintptr) error {
|
||||||
|
var err error
|
||||||
if interfaceIndex == -1 {
|
if interfaceIndex == -1 {
|
||||||
if finder == nil {
|
if finder == nil {
|
||||||
return os.ErrInvalid
|
return os.ErrInvalid
|
||||||
}
|
}
|
||||||
iif, err := finder.ByName(interfaceName)
|
interfaceIndex, err = finder.InterfaceIndexByName(interfaceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
interfaceIndex = iif.Index
|
|
||||||
}
|
}
|
||||||
handle := syscall.Handle(fd)
|
handle := syscall.Handle(fd)
|
||||||
if M.ParseSocksaddr(address).AddrString() == "" {
|
if M.ParseSocksaddr(address).AddrString() == "" {
|
||||||
err := bind4(handle, interfaceIndex)
|
err = bind4(handle, interfaceIndex)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,26 +4,19 @@ import (
|
||||||
"os"
|
"os"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
|
||||||
N "github.com/sagernet/sing/common/network"
|
|
||||||
|
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
)
|
)
|
||||||
|
|
||||||
func DisableUDPFragment() Func {
|
func DisableUDPFragment() Func {
|
||||||
return func(network, address string, conn syscall.RawConn) error {
|
return func(network, address string, conn syscall.RawConn) error {
|
||||||
if N.NetworkName(network) != N.NetworkUDP {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return Raw(conn, func(fd uintptr) error {
|
return Raw(conn, func(fd uintptr) error {
|
||||||
if network == "udp" || network == "udp4" {
|
switch network {
|
||||||
err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_DONTFRAG, 1)
|
case "udp4":
|
||||||
if err != nil {
|
if err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_DONTFRAG, 1); err != nil {
|
||||||
return os.NewSyscallError("SETSOCKOPT IP_DONTFRAG", err)
|
return os.NewSyscallError("SETSOCKOPT IP_DONTFRAG", err)
|
||||||
}
|
}
|
||||||
}
|
case "udp6":
|
||||||
if network == "udp" || network == "udp6" {
|
if err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_DONTFRAG, 1); err != nil {
|
||||||
err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_DONTFRAG, 1)
|
|
||||||
if err != nil {
|
|
||||||
return os.NewSyscallError("SETSOCKOPT IPV6_DONTFRAG", err)
|
return os.NewSyscallError("SETSOCKOPT IPV6_DONTFRAG", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -11,19 +11,17 @@ import (
|
||||||
|
|
||||||
func DisableUDPFragment() Func {
|
func DisableUDPFragment() Func {
|
||||||
return func(network, address string, conn syscall.RawConn) error {
|
return func(network, address string, conn syscall.RawConn) error {
|
||||||
if N.NetworkName(network) != N.NetworkUDP {
|
switch N.NetworkName(network) {
|
||||||
|
case N.NetworkUDP:
|
||||||
|
default:
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return Raw(conn, func(fd uintptr) error {
|
return Raw(conn, func(fd uintptr) error {
|
||||||
if network == "udp" || network == "udp4" {
|
if err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_MTU_DISCOVER, unix.IP_PMTUDISC_DO); err != nil {
|
||||||
err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_MTU_DISCOVER, unix.IP_PMTUDISC_DO)
|
return os.NewSyscallError("SETSOCKOPT IP_MTU_DISCOVER IP_PMTUDISC_DO", err)
|
||||||
if err != nil {
|
|
||||||
return os.NewSyscallError("SETSOCKOPT IP_MTU_DISCOVER IP_PMTUDISC_DO", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
if network == "udp" || network == "udp6" {
|
if network == "udp6" {
|
||||||
err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_MTU_DISCOVER, unix.IP_PMTUDISC_DO)
|
if err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_MTU_DISCOVER, unix.IP_PMTUDISC_DO); err != nil {
|
||||||
if err != nil {
|
|
||||||
return os.NewSyscallError("SETSOCKOPT IPV6_MTU_DISCOVER IP_PMTUDISC_DO", err)
|
return os.NewSyscallError("SETSOCKOPT IPV6_MTU_DISCOVER IP_PMTUDISC_DO", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -25,19 +25,17 @@ const (
|
||||||
|
|
||||||
func DisableUDPFragment() Func {
|
func DisableUDPFragment() Func {
|
||||||
return func(network, address string, conn syscall.RawConn) error {
|
return func(network, address string, conn syscall.RawConn) error {
|
||||||
if N.NetworkName(network) != N.NetworkUDP {
|
switch N.NetworkName(network) {
|
||||||
|
case N.NetworkUDP:
|
||||||
|
default:
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return Raw(conn, func(fd uintptr) error {
|
return Raw(conn, func(fd uintptr) error {
|
||||||
if network == "udp" || network == "udp4" {
|
if err := windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IP, IP_MTU_DISCOVER, IP_PMTUDISC_DO); err != nil {
|
||||||
err := windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IP, IP_MTU_DISCOVER, IP_PMTUDISC_DO)
|
return os.NewSyscallError("SETSOCKOPT IP_MTU_DISCOVER IP_PMTUDISC_DO", err)
|
||||||
if err != nil {
|
|
||||||
return os.NewSyscallError("SETSOCKOPT IP_MTU_DISCOVER IP_PMTUDISC_DO", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
if network == "udp" || network == "udp6" {
|
if network == "udp6" {
|
||||||
err := windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IPV6, IPV6_MTU_DISCOVER, IP_PMTUDISC_DO)
|
if err := windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IPV6, IPV6_MTU_DISCOVER, IP_PMTUDISC_DO); err != nil {
|
||||||
if err != nil {
|
|
||||||
return os.NewSyscallError("SETSOCKOPT IPV6_MTU_DISCOVER IP_PMTUDISC_DO", err)
|
return os.NewSyscallError("SETSOCKOPT IPV6_MTU_DISCOVER IP_PMTUDISC_DO", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,10 +4,10 @@ import (
|
||||||
"syscall"
|
"syscall"
|
||||||
)
|
)
|
||||||
|
|
||||||
func RoutingMark(mark uint32) Func {
|
func RoutingMark(mark int) Func {
|
||||||
return func(network, address string, conn syscall.RawConn) error {
|
return func(network, address string, conn syscall.RawConn) error {
|
||||||
return Raw(conn, func(fd uintptr) error {
|
return Raw(conn, func(fd uintptr) error {
|
||||||
return syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_MARK, int(mark))
|
return syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_MARK, mark)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,6 +2,6 @@
|
||||||
|
|
||||||
package control
|
package control
|
||||||
|
|
||||||
func RoutingMark(mark uint32) Func {
|
func RoutingMark(mark int) Func {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,67 +0,0 @@
|
||||||
package domain_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"sort"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/sagernet/sing/common/domain"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestAdGuardMatcher(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
ruleLines := []string{
|
|
||||||
"||example.org^",
|
|
||||||
"|example.com^",
|
|
||||||
"example.net^",
|
|
||||||
"||example.edu",
|
|
||||||
"||example.edu.tw^",
|
|
||||||
"|example.gov",
|
|
||||||
"example.arpa",
|
|
||||||
}
|
|
||||||
matcher := domain.NewAdGuardMatcher(ruleLines)
|
|
||||||
require.NotNil(t, matcher)
|
|
||||||
matchDomain := []string{
|
|
||||||
"example.org",
|
|
||||||
"www.example.org",
|
|
||||||
"example.com",
|
|
||||||
"example.net",
|
|
||||||
"isexample.net",
|
|
||||||
"www.example.net",
|
|
||||||
"example.edu",
|
|
||||||
"example.edu.cn",
|
|
||||||
"example.edu.tw",
|
|
||||||
"www.example.edu",
|
|
||||||
"www.example.edu.cn",
|
|
||||||
"example.gov",
|
|
||||||
"example.gov.cn",
|
|
||||||
"example.arpa",
|
|
||||||
"www.example.arpa",
|
|
||||||
"isexample.arpa",
|
|
||||||
"example.arpa.cn",
|
|
||||||
"www.example.arpa.cn",
|
|
||||||
"isexample.arpa.cn",
|
|
||||||
}
|
|
||||||
notMatchDomain := []string{
|
|
||||||
"example.org.cn",
|
|
||||||
"notexample.org",
|
|
||||||
"example.com.cn",
|
|
||||||
"www.example.com.cn",
|
|
||||||
"example.net.cn",
|
|
||||||
"notexample.edu",
|
|
||||||
"notexample.edu.cn",
|
|
||||||
"www.example.gov",
|
|
||||||
"notexample.gov",
|
|
||||||
}
|
|
||||||
for _, domain := range matchDomain {
|
|
||||||
require.True(t, matcher.Match(domain), domain)
|
|
||||||
}
|
|
||||||
for _, domain := range notMatchDomain {
|
|
||||||
require.False(t, matcher.Match(domain), domain)
|
|
||||||
}
|
|
||||||
dLines := matcher.Dump()
|
|
||||||
sort.Strings(ruleLines)
|
|
||||||
sort.Strings(dLines)
|
|
||||||
require.Equal(t, ruleLines, dLines)
|
|
||||||
}
|
|
|
@ -1,172 +0,0 @@
|
||||||
package domain
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"sort"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/sagernet/sing/common"
|
|
||||||
"github.com/sagernet/sing/common/varbin"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
anyLabel = '*'
|
|
||||||
suffixLabel = '\b'
|
|
||||||
)
|
|
||||||
|
|
||||||
type AdGuardMatcher struct {
|
|
||||||
set *succinctSet
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewAdGuardMatcher(ruleLines []string) *AdGuardMatcher {
|
|
||||||
ruleList := make([]string, 0, len(ruleLines))
|
|
||||||
for _, ruleLine := range ruleLines {
|
|
||||||
var (
|
|
||||||
isSuffix bool // ||
|
|
||||||
hasStart bool // |
|
|
||||||
hasEnd bool // ^
|
|
||||||
)
|
|
||||||
if strings.HasPrefix(ruleLine, "||") {
|
|
||||||
ruleLine = ruleLine[2:]
|
|
||||||
isSuffix = true
|
|
||||||
} else if strings.HasPrefix(ruleLine, "|") {
|
|
||||||
ruleLine = ruleLine[1:]
|
|
||||||
hasStart = true
|
|
||||||
}
|
|
||||||
if strings.HasSuffix(ruleLine, "^") {
|
|
||||||
ruleLine = ruleLine[:len(ruleLine)-1]
|
|
||||||
hasEnd = true
|
|
||||||
}
|
|
||||||
if isSuffix {
|
|
||||||
ruleLine = string(rootLabel) + ruleLine
|
|
||||||
} else if !hasStart {
|
|
||||||
ruleLine = string(prefixLabel) + ruleLine
|
|
||||||
}
|
|
||||||
if !hasEnd {
|
|
||||||
if strings.HasSuffix(ruleLine, ".") {
|
|
||||||
ruleLine = ruleLine[:len(ruleLine)-1]
|
|
||||||
}
|
|
||||||
ruleLine += string(suffixLabel)
|
|
||||||
}
|
|
||||||
ruleList = append(ruleList, reverseDomain(ruleLine))
|
|
||||||
}
|
|
||||||
ruleList = common.Uniq(ruleList)
|
|
||||||
sort.Strings(ruleList)
|
|
||||||
return &AdGuardMatcher{newSuccinctSet(ruleList)}
|
|
||||||
}
|
|
||||||
|
|
||||||
func ReadAdGuardMatcher(reader varbin.Reader) (*AdGuardMatcher, error) {
|
|
||||||
set, err := readSuccinctSet(reader)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &AdGuardMatcher{set}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *AdGuardMatcher) Write(writer varbin.Writer) error {
|
|
||||||
return m.set.Write(writer)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *AdGuardMatcher) Match(domain string) bool {
|
|
||||||
key := reverseDomain(domain)
|
|
||||||
if m.has([]byte(key), 0, 0) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
for {
|
|
||||||
if m.has([]byte(string(suffixLabel)+key), 0, 0) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
idx := strings.IndexByte(key, '.')
|
|
||||||
if idx == -1 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
key = key[idx+1:]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *AdGuardMatcher) has(key []byte, nodeId, bmIdx int) bool {
|
|
||||||
for i := 0; i < len(key); i++ {
|
|
||||||
currentChar := key[i]
|
|
||||||
for ; ; bmIdx++ {
|
|
||||||
if getBit(m.set.labelBitmap, bmIdx) != 0 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
nextLabel := m.set.labels[bmIdx-nodeId]
|
|
||||||
if nextLabel == prefixLabel {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
if nextLabel == rootLabel {
|
|
||||||
nextNodeId := countZeros(m.set.labelBitmap, m.set.ranks, bmIdx+1)
|
|
||||||
hasNext := getBit(m.set.leaves, nextNodeId) != 0
|
|
||||||
if currentChar == '.' && hasNext {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if nextLabel == currentChar {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if nextLabel == anyLabel {
|
|
||||||
idx := bytes.IndexRune(key[i:], '.')
|
|
||||||
nextNodeId := countZeros(m.set.labelBitmap, m.set.ranks, bmIdx+1)
|
|
||||||
if idx == -1 {
|
|
||||||
if getBit(m.set.leaves, nextNodeId) != 0 {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
idx = 0
|
|
||||||
}
|
|
||||||
nextBmIdx := selectIthOne(m.set.labelBitmap, m.set.ranks, m.set.selects, nextNodeId-1) + 1
|
|
||||||
if m.has(key[i+idx:], nextNodeId, nextBmIdx) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
nodeId = countZeros(m.set.labelBitmap, m.set.ranks, bmIdx+1)
|
|
||||||
bmIdx = selectIthOne(m.set.labelBitmap, m.set.ranks, m.set.selects, nodeId-1) + 1
|
|
||||||
}
|
|
||||||
if getBit(m.set.leaves, nodeId) != 0 {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
for ; ; bmIdx++ {
|
|
||||||
if getBit(m.set.labelBitmap, bmIdx) != 0 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
nextLabel := m.set.labels[bmIdx-nodeId]
|
|
||||||
if nextLabel == prefixLabel || nextLabel == rootLabel {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *AdGuardMatcher) Dump() (ruleLines []string) {
|
|
||||||
for _, key := range m.set.keys() {
|
|
||||||
key = reverseDomain(key)
|
|
||||||
var (
|
|
||||||
isSuffix bool
|
|
||||||
hasStart bool
|
|
||||||
hasEnd bool
|
|
||||||
)
|
|
||||||
if key[0] == prefixLabel {
|
|
||||||
key = key[1:]
|
|
||||||
} else if key[0] == rootLabel {
|
|
||||||
key = key[1:]
|
|
||||||
isSuffix = true
|
|
||||||
} else {
|
|
||||||
hasStart = true
|
|
||||||
}
|
|
||||||
if key[len(key)-1] == suffixLabel {
|
|
||||||
key = key[:len(key)-1]
|
|
||||||
} else {
|
|
||||||
hasEnd = true
|
|
||||||
}
|
|
||||||
if isSuffix {
|
|
||||||
key = "||" + key
|
|
||||||
} else if hasStart {
|
|
||||||
key = "|" + key
|
|
||||||
}
|
|
||||||
if hasEnd {
|
|
||||||
key += "^"
|
|
||||||
}
|
|
||||||
ruleLines = append(ruleLines, key)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
|
@ -1,22 +1,19 @@
|
||||||
package domain
|
package domain
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"io"
|
||||||
"sort"
|
"sort"
|
||||||
"unicode/utf8"
|
"unicode/utf8"
|
||||||
|
|
||||||
"github.com/sagernet/sing/common/varbin"
|
"github.com/sagernet/sing/common/rw"
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
prefixLabel = '\r'
|
|
||||||
rootLabel = '\n'
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type Matcher struct {
|
type Matcher struct {
|
||||||
set *succinctSet
|
set *succinctSet
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewMatcher(domains []string, domainSuffix []string, generateLegacy bool) *Matcher {
|
func NewMatcher(domains []string, domainSuffix []string) *Matcher {
|
||||||
domainList := make([]string, 0, len(domains)+2*len(domainSuffix))
|
domainList := make([]string, 0, len(domains)+2*len(domainSuffix))
|
||||||
seen := make(map[string]bool, len(domainList))
|
seen := make(map[string]bool, len(domainList))
|
||||||
for _, domain := range domainSuffix {
|
for _, domain := range domainSuffix {
|
||||||
|
@ -25,16 +22,10 @@ func NewMatcher(domains []string, domainSuffix []string, generateLegacy bool) *M
|
||||||
}
|
}
|
||||||
seen[domain] = true
|
seen[domain] = true
|
||||||
if domain[0] == '.' {
|
if domain[0] == '.' {
|
||||||
domainList = append(domainList, reverseDomain(string(prefixLabel)+domain))
|
domainList = append(domainList, reverseDomainSuffix(domain))
|
||||||
} else if generateLegacy {
|
|
||||||
domainList = append(domainList, reverseDomain(domain))
|
|
||||||
suffixDomain := "." + domain
|
|
||||||
if !seen[suffixDomain] {
|
|
||||||
seen[suffixDomain] = true
|
|
||||||
domainList = append(domainList, reverseDomain(string(prefixLabel)+suffixDomain))
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
domainList = append(domainList, reverseDomain(string(rootLabel)+domain))
|
domainList = append(domainList, reverseDomain(domain))
|
||||||
|
domainList = append(domainList, reverseRootDomainSuffix(domain))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for _, domain := range domains {
|
for _, domain := range domains {
|
||||||
|
@ -48,91 +39,82 @@ func NewMatcher(domains []string, domainSuffix []string, generateLegacy bool) *M
|
||||||
return &Matcher{newSuccinctSet(domainList)}
|
return &Matcher{newSuccinctSet(domainList)}
|
||||||
}
|
}
|
||||||
|
|
||||||
func ReadMatcher(reader varbin.Reader) (*Matcher, error) {
|
func ReadMatcher(reader io.Reader) (*Matcher, error) {
|
||||||
set, err := readSuccinctSet(reader)
|
var version uint8
|
||||||
|
err := binary.Read(reader, binary.BigEndian, &version)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
leavesLength, err := rw.ReadUVariant(reader)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
leaves := make([]uint64, leavesLength)
|
||||||
|
err = binary.Read(reader, binary.BigEndian, leaves)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
labelBitmapLength, err := rw.ReadUVariant(reader)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
labelBitmap := make([]uint64, labelBitmapLength)
|
||||||
|
err = binary.Read(reader, binary.BigEndian, labelBitmap)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
labelsLength, err := rw.ReadUVariant(reader)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
labels := make([]byte, labelsLength)
|
||||||
|
_, err = io.ReadFull(reader, labels)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
set := &succinctSet{
|
||||||
|
leaves: leaves,
|
||||||
|
labelBitmap: labelBitmap,
|
||||||
|
labels: labels,
|
||||||
|
}
|
||||||
|
set.init()
|
||||||
return &Matcher{set}, nil
|
return &Matcher{set}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Matcher) Write(writer varbin.Writer) error {
|
|
||||||
return m.set.Write(writer)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Matcher) Match(domain string) bool {
|
func (m *Matcher) Match(domain string) bool {
|
||||||
return m.has(reverseDomain(domain))
|
return m.set.Has(reverseDomain(domain))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Matcher) has(key string) bool {
|
func (m *Matcher) Write(writer io.Writer) error {
|
||||||
var nodeId, bmIdx int
|
err := binary.Write(writer, binary.BigEndian, byte(1))
|
||||||
for i := 0; i < len(key); i++ {
|
if err != nil {
|
||||||
currentChar := key[i]
|
return err
|
||||||
for ; ; bmIdx++ {
|
|
||||||
if getBit(m.set.labelBitmap, bmIdx) != 0 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
nextLabel := m.set.labels[bmIdx-nodeId]
|
|
||||||
if nextLabel == prefixLabel {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
if nextLabel == rootLabel {
|
|
||||||
nextNodeId := countZeros(m.set.labelBitmap, m.set.ranks, bmIdx+1)
|
|
||||||
hasNext := getBit(m.set.leaves, nextNodeId) != 0
|
|
||||||
if currentChar == '.' && hasNext {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if nextLabel == currentChar {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
nodeId = countZeros(m.set.labelBitmap, m.set.ranks, bmIdx+1)
|
|
||||||
bmIdx = selectIthOne(m.set.labelBitmap, m.set.ranks, m.set.selects, nodeId-1) + 1
|
|
||||||
}
|
}
|
||||||
if getBit(m.set.leaves, nodeId) != 0 {
|
err = rw.WriteUVariant(writer, uint64(len(m.set.leaves)))
|
||||||
return true
|
if err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
for ; ; bmIdx++ {
|
err = binary.Write(writer, binary.BigEndian, m.set.leaves)
|
||||||
if getBit(m.set.labelBitmap, bmIdx) != 0 {
|
if err != nil {
|
||||||
return false
|
return err
|
||||||
}
|
|
||||||
nextLabel := m.set.labels[bmIdx-nodeId]
|
|
||||||
if nextLabel == prefixLabel || nextLabel == rootLabel {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
err = rw.WriteUVariant(writer, uint64(len(m.set.labelBitmap)))
|
||||||
|
if err != nil {
|
||||||
func (m *Matcher) Dump() (domainList []string, prefixList []string) {
|
return err
|
||||||
domainMap := make(map[string]bool)
|
|
||||||
prefixMap := make(map[string]bool)
|
|
||||||
for _, key := range m.set.keys() {
|
|
||||||
key = reverseDomain(key)
|
|
||||||
if key[0] == prefixLabel {
|
|
||||||
prefixMap[key[1:]] = true
|
|
||||||
} else if key[0] == rootLabel {
|
|
||||||
prefixList = append(prefixList, key[1:])
|
|
||||||
} else {
|
|
||||||
domainMap[key] = true
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
for rawPrefix := range prefixMap {
|
err = binary.Write(writer, binary.BigEndian, m.set.labelBitmap)
|
||||||
if rawPrefix[0] == '.' {
|
if err != nil {
|
||||||
if rootDomain := rawPrefix[1:]; domainMap[rootDomain] {
|
return err
|
||||||
delete(domainMap, rootDomain)
|
|
||||||
prefixList = append(prefixList, rootDomain)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
prefixList = append(prefixList, rawPrefix)
|
|
||||||
}
|
}
|
||||||
for domain := range domainMap {
|
err = rw.WriteUVariant(writer, uint64(len(m.set.labels)))
|
||||||
domainList = append(domainList, domain)
|
if err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
sort.Strings(domainList)
|
_, err = writer.Write(m.set.labels)
|
||||||
sort.Strings(prefixList)
|
if err != nil {
|
||||||
return domainList, prefixList
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func reverseDomain(domain string) string {
|
func reverseDomain(domain string) string {
|
||||||
|
@ -145,3 +127,28 @@ func reverseDomain(domain string) string {
|
||||||
}
|
}
|
||||||
return string(b)
|
return string(b)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func reverseDomainSuffix(domain string) string {
|
||||||
|
l := len(domain)
|
||||||
|
b := make([]byte, l+1)
|
||||||
|
for i := 0; i < l; {
|
||||||
|
r, n := utf8.DecodeRuneInString(domain[i:])
|
||||||
|
i += n
|
||||||
|
utf8.EncodeRune(b[l-i:], r)
|
||||||
|
}
|
||||||
|
b[l] = prefixLabel
|
||||||
|
return string(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func reverseRootDomainSuffix(domain string) string {
|
||||||
|
l := len(domain)
|
||||||
|
b := make([]byte, l+2)
|
||||||
|
for i := 0; i < l; {
|
||||||
|
r, n := utf8.DecodeRuneInString(domain[i:])
|
||||||
|
i += n
|
||||||
|
utf8.EncodeRune(b[l-i:], r)
|
||||||
|
}
|
||||||
|
b[l] = '.'
|
||||||
|
b[l+1] = prefixLabel
|
||||||
|
return string(b)
|
||||||
|
}
|
||||||
|
|
|
@ -1,80 +0,0 @@
|
||||||
package domain_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"net/http"
|
|
||||||
"sort"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/sagernet/sing/common/domain"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestMatcher(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
testDomain := []string{"example.com", "example.org"}
|
|
||||||
testDomainSuffix := []string{".com.cn", ".org.cn", "sagernet.org"}
|
|
||||||
matcher := domain.NewMatcher(testDomain, testDomainSuffix, false)
|
|
||||||
require.NotNil(t, matcher)
|
|
||||||
require.True(t, matcher.Match("example.com"))
|
|
||||||
require.True(t, matcher.Match("example.org"))
|
|
||||||
require.False(t, matcher.Match("example.cn"))
|
|
||||||
require.True(t, matcher.Match("example.com.cn"))
|
|
||||||
require.True(t, matcher.Match("example.org.cn"))
|
|
||||||
require.False(t, matcher.Match("com.cn"))
|
|
||||||
require.False(t, matcher.Match("org.cn"))
|
|
||||||
require.True(t, matcher.Match("sagernet.org"))
|
|
||||||
require.True(t, matcher.Match("sing-box.sagernet.org"))
|
|
||||||
dDomain, dDomainSuffix := matcher.Dump()
|
|
||||||
require.Equal(t, testDomain, dDomain)
|
|
||||||
require.Equal(t, testDomainSuffix, dDomainSuffix)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMatcherLegacy(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
testDomain := []string{"example.com", "example.org"}
|
|
||||||
testDomainSuffix := []string{".com.cn", ".org.cn", "sagernet.org"}
|
|
||||||
matcher := domain.NewMatcher(testDomain, testDomainSuffix, true)
|
|
||||||
require.NotNil(t, matcher)
|
|
||||||
require.True(t, matcher.Match("example.com"))
|
|
||||||
require.True(t, matcher.Match("example.org"))
|
|
||||||
require.False(t, matcher.Match("example.cn"))
|
|
||||||
require.True(t, matcher.Match("example.com.cn"))
|
|
||||||
require.True(t, matcher.Match("example.org.cn"))
|
|
||||||
require.False(t, matcher.Match("com.cn"))
|
|
||||||
require.False(t, matcher.Match("org.cn"))
|
|
||||||
require.True(t, matcher.Match("sagernet.org"))
|
|
||||||
require.True(t, matcher.Match("sing-box.sagernet.org"))
|
|
||||||
dDomain, dDomainSuffix := matcher.Dump()
|
|
||||||
require.Equal(t, testDomain, dDomain)
|
|
||||||
require.Equal(t, testDomainSuffix, dDomainSuffix)
|
|
||||||
}
|
|
||||||
|
|
||||||
type simpleRuleSet struct {
|
|
||||||
Rules []struct {
|
|
||||||
Domain []string `json:"domain"`
|
|
||||||
DomainSuffix []string `json:"domain_suffix"`
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDumpLarge(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
response, err := http.Get("https://raw.githubusercontent.com/MetaCubeX/meta-rules-dat/sing/geo/geosite/cn.json")
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer response.Body.Close()
|
|
||||||
var ruleSet simpleRuleSet
|
|
||||||
err = json.NewDecoder(response.Body).Decode(&ruleSet)
|
|
||||||
require.NoError(t, err)
|
|
||||||
domainList := ruleSet.Rules[0].Domain
|
|
||||||
domainSuffixList := ruleSet.Rules[0].DomainSuffix
|
|
||||||
require.Len(t, ruleSet.Rules, 1)
|
|
||||||
require.True(t, len(domainList)+len(domainSuffixList) > 0)
|
|
||||||
sort.Strings(domainList)
|
|
||||||
sort.Strings(domainSuffixList)
|
|
||||||
matcher := domain.NewMatcher(domainList, domainSuffixList, false)
|
|
||||||
require.NotNil(t, matcher)
|
|
||||||
dDomain, dDomainSuffix := matcher.Dump()
|
|
||||||
require.Equal(t, domainList, dDomain)
|
|
||||||
require.Equal(t, domainSuffixList, dDomainSuffix)
|
|
||||||
}
|
|
|
@ -1,12 +1,11 @@
|
||||||
package domain
|
package domain
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/binary"
|
|
||||||
"math/bits"
|
"math/bits"
|
||||||
|
|
||||||
"github.com/sagernet/sing/common/varbin"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const prefixLabel = '\r'
|
||||||
|
|
||||||
// mod from https://github.com/openacid/succinct
|
// mod from https://github.com/openacid/succinct
|
||||||
|
|
||||||
type succinctSet struct {
|
type succinctSet struct {
|
||||||
|
@ -43,61 +42,36 @@ func newSuccinctSet(keys []string) *succinctSet {
|
||||||
return ss
|
return ss
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ss *succinctSet) keys() []string {
|
func (ss *succinctSet) Has(key string) bool {
|
||||||
var result []string
|
var nodeId, bmIdx int
|
||||||
var currentKey []byte
|
for i := 0; i < len(key); i++ {
|
||||||
var bmIdx, nodeId int
|
currentChar := key[i]
|
||||||
|
|
||||||
var traverse func(int, int)
|
|
||||||
traverse = func(nodeId, bmIdx int) {
|
|
||||||
if getBit(ss.leaves, nodeId) != 0 {
|
|
||||||
result = append(result, string(currentKey))
|
|
||||||
}
|
|
||||||
|
|
||||||
for ; ; bmIdx++ {
|
for ; ; bmIdx++ {
|
||||||
if getBit(ss.labelBitmap, bmIdx) != 0 {
|
if getBit(ss.labelBitmap, bmIdx) != 0 {
|
||||||
return
|
return false
|
||||||
}
|
}
|
||||||
nextLabel := ss.labels[bmIdx-nodeId]
|
nextLabel := ss.labels[bmIdx-nodeId]
|
||||||
currentKey = append(currentKey, nextLabel)
|
if nextLabel == prefixLabel {
|
||||||
nextNodeId := countZeros(ss.labelBitmap, ss.ranks, bmIdx+1)
|
return true
|
||||||
nextBmIdx := selectIthOne(ss.labelBitmap, ss.ranks, ss.selects, nextNodeId-1) + 1
|
}
|
||||||
traverse(nextNodeId, nextBmIdx)
|
if nextLabel == currentChar {
|
||||||
currentKey = currentKey[:len(currentKey)-1]
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
nodeId = countZeros(ss.labelBitmap, ss.ranks, bmIdx+1)
|
||||||
|
bmIdx = selectIthOne(ss.labelBitmap, ss.ranks, ss.selects, nodeId-1) + 1
|
||||||
|
}
|
||||||
|
if getBit(ss.leaves, nodeId) != 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
for ; ; bmIdx++ {
|
||||||
|
if getBit(ss.labelBitmap, bmIdx) != 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if ss.labels[bmIdx-nodeId] == prefixLabel {
|
||||||
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
traverse(nodeId, bmIdx)
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
type succinctSetData struct {
|
|
||||||
Reserved uint8
|
|
||||||
Leaves []uint64
|
|
||||||
LabelBitmap []uint64
|
|
||||||
Labels []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
func readSuccinctSet(reader varbin.Reader) (*succinctSet, error) {
|
|
||||||
matcher, err := varbin.ReadValue[succinctSetData](reader, binary.BigEndian)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
set := &succinctSet{
|
|
||||||
leaves: matcher.Leaves,
|
|
||||||
labelBitmap: matcher.LabelBitmap,
|
|
||||||
labels: matcher.Labels,
|
|
||||||
}
|
|
||||||
set.init()
|
|
||||||
return set, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ss *succinctSet) Write(writer varbin.Writer) error {
|
|
||||||
return varbin.Write(writer, binary.BigEndian, succinctSetData{
|
|
||||||
Leaves: ss.leaves,
|
|
||||||
LabelBitmap: ss.labelBitmap,
|
|
||||||
Labels: ss.labels,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func setBit(bm *[]uint64, i int, v int) {
|
func setBit(bm *[]uint64, i int, v int) {
|
||||||
|
|
|
@ -12,16 +12,3 @@ func (e *causeError) Error() string {
|
||||||
func (e *causeError) Unwrap() error {
|
func (e *causeError) Unwrap() error {
|
||||||
return e.cause
|
return e.cause
|
||||||
}
|
}
|
||||||
|
|
||||||
type causeError1 struct {
|
|
||||||
error
|
|
||||||
cause error
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *causeError1) Error() string {
|
|
||||||
return e.error.Error() + ": " + e.cause.Error()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *causeError1) Unwrap() []error {
|
|
||||||
return []error{e.error, e.cause}
|
|
||||||
}
|
|
||||||
|
|
|
@ -12,7 +12,6 @@ import (
|
||||||
F "github.com/sagernet/sing/common/format"
|
F "github.com/sagernet/sing/common/format"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Deprecated: wtf is this?
|
|
||||||
type Handler interface {
|
type Handler interface {
|
||||||
NewError(ctx context.Context, err error)
|
NewError(ctx context.Context, err error)
|
||||||
}
|
}
|
||||||
|
@ -32,13 +31,6 @@ func Cause(cause error, message ...any) error {
|
||||||
return &causeError{F.ToString(message...), cause}
|
return &causeError{F.ToString(message...), cause}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Cause1(err error, cause error) error {
|
|
||||||
if cause == nil {
|
|
||||||
panic("cause on an nil error")
|
|
||||||
}
|
|
||||||
return &causeError1{err, cause}
|
|
||||||
}
|
|
||||||
|
|
||||||
func Extend(cause error, message ...any) error {
|
func Extend(cause error, message ...any) error {
|
||||||
if cause == nil {
|
if cause == nil {
|
||||||
panic("extend on an nil error")
|
panic("extend on an nil error")
|
||||||
|
@ -47,11 +39,11 @@ func Extend(cause error, message ...any) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func IsClosedOrCanceled(err error) bool {
|
func IsClosedOrCanceled(err error) bool {
|
||||||
return IsClosed(err) || IsCanceled(err) || IsTimeout(err)
|
return IsMulti(err, io.EOF, net.ErrClosed, io.ErrClosedPipe, os.ErrClosed, syscall.EPIPE, syscall.ECONNRESET, context.Canceled, context.DeadlineExceeded)
|
||||||
}
|
}
|
||||||
|
|
||||||
func IsClosed(err error) bool {
|
func IsClosed(err error) bool {
|
||||||
return IsMulti(err, io.EOF, net.ErrClosed, io.ErrClosedPipe, os.ErrClosed, syscall.EPIPE, syscall.ECONNRESET, syscall.ENOTCONN)
|
return IsMulti(err, io.EOF, net.ErrClosed, io.ErrClosedPipe, os.ErrClosed, syscall.EPIPE, syscall.ECONNRESET)
|
||||||
}
|
}
|
||||||
|
|
||||||
func IsCanceled(err error) bool {
|
func IsCanceled(err error) bool {
|
||||||
|
|
|
@ -1,14 +1,24 @@
|
||||||
package exceptions
|
package exceptions
|
||||||
|
|
||||||
import (
|
import "github.com/sagernet/sing/common"
|
||||||
"errors"
|
|
||||||
|
|
||||||
"github.com/sagernet/sing/common"
|
type HasInnerError interface {
|
||||||
)
|
Unwrap() error
|
||||||
|
}
|
||||||
|
|
||||||
// Deprecated: Use errors.Unwrap instead.
|
|
||||||
func Unwrap(err error) error {
|
func Unwrap(err error) error {
|
||||||
return errors.Unwrap(err)
|
for {
|
||||||
|
inner, ok := err.(HasInnerError)
|
||||||
|
if !ok {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
innerErr := inner.Unwrap()
|
||||||
|
if innerErr == nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
err = innerErr
|
||||||
|
}
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func Cast[T any](err error) (T, bool) {
|
func Cast[T any](err error) (T, bool) {
|
||||||
|
|
|
@ -63,5 +63,12 @@ func IsMulti(err error, targetList ...error) bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return false
|
err = Unwrap(err)
|
||||||
|
multiErr, isMulti := err.(MultiError)
|
||||||
|
if !isMulti {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return common.All(multiErr.Unwrap(), func(it error) bool {
|
||||||
|
return IsMulti(it, targetList...)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,21 +1,17 @@
|
||||||
package exceptions
|
package exceptions
|
||||||
|
|
||||||
import (
|
import "net"
|
||||||
"errors"
|
|
||||||
"net"
|
|
||||||
)
|
|
||||||
|
|
||||||
type TimeoutError interface {
|
type TimeoutError interface {
|
||||||
Timeout() bool
|
Timeout() bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func IsTimeout(err error) bool {
|
func IsTimeout(err error) bool {
|
||||||
var netErr net.Error
|
if netErr, isNetErr := err.(net.Error); isNetErr {
|
||||||
if errors.As(err, &netErr) {
|
//goland:noinspection GoDeprecation
|
||||||
//nolint:staticcheck
|
//nolint:staticcheck
|
||||||
return netErr.Temporary() && netErr.Timeout()
|
return netErr.Temporary() && netErr.Timeout()
|
||||||
}
|
} else if timeoutErr, isTimeout := Cast[TimeoutError](err); isTimeout {
|
||||||
if timeoutErr, isTimeout := Cast[TimeoutError](err); isTimeout {
|
|
||||||
return timeoutErr.Timeout()
|
return timeoutErr.Timeout()
|
||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
|
|
|
@ -2,14 +2,13 @@ package badjson
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
|
||||||
|
|
||||||
E "github.com/sagernet/sing/common/exceptions"
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
"github.com/sagernet/sing/common/json"
|
"github.com/sagernet/sing/common/json"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Decode(ctx context.Context, content []byte) (any, error) {
|
func Decode(content []byte) (any, error) {
|
||||||
decoder := json.NewDecoderContext(ctx, bytes.NewReader(content))
|
decoder := json.NewDecoder(bytes.NewReader(content))
|
||||||
return decodeJSON(decoder)
|
return decodeJSON(decoder)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
package badjson
|
package badjson
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"os"
|
"os"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
|
||||||
|
@ -10,75 +9,75 @@ import (
|
||||||
"github.com/sagernet/sing/common/json"
|
"github.com/sagernet/sing/common/json"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Omitempty[T any](ctx context.Context, value T) (T, error) {
|
func Omitempty[T any](value T) (T, error) {
|
||||||
objectContent, err := json.MarshalContext(ctx, value)
|
objectContent, err := json.Marshal(value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return common.DefaultValue[T](), E.Cause(err, "marshal object")
|
return common.DefaultValue[T](), E.Cause(err, "marshal object")
|
||||||
}
|
}
|
||||||
rawNewObject, err := Decode(ctx, objectContent)
|
rawNewObject, err := Decode(objectContent)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return common.DefaultValue[T](), err
|
return common.DefaultValue[T](), err
|
||||||
}
|
}
|
||||||
newObjectContent, err := json.MarshalContext(ctx, rawNewObject)
|
newObjectContent, err := json.Marshal(rawNewObject)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return common.DefaultValue[T](), E.Cause(err, "marshal new object")
|
return common.DefaultValue[T](), E.Cause(err, "marshal new object")
|
||||||
}
|
}
|
||||||
var newObject T
|
var newObject T
|
||||||
err = json.UnmarshalContext(ctx, newObjectContent, &newObject)
|
err = json.Unmarshal(newObjectContent, &newObject)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return common.DefaultValue[T](), E.Cause(err, "unmarshal new object")
|
return common.DefaultValue[T](), E.Cause(err, "unmarshal new object")
|
||||||
}
|
}
|
||||||
return newObject, nil
|
return newObject, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func Merge[T any](ctx context.Context, source T, destination T, disableAppend bool) (T, error) {
|
func Merge[T any](source T, destination T) (T, error) {
|
||||||
rawSource, err := json.MarshalContext(ctx, source)
|
rawSource, err := json.Marshal(source)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return common.DefaultValue[T](), E.Cause(err, "marshal source")
|
return common.DefaultValue[T](), E.Cause(err, "marshal source")
|
||||||
}
|
}
|
||||||
rawDestination, err := json.MarshalContext(ctx, destination)
|
rawDestination, err := json.Marshal(destination)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return common.DefaultValue[T](), E.Cause(err, "marshal destination")
|
return common.DefaultValue[T](), E.Cause(err, "marshal destination")
|
||||||
}
|
}
|
||||||
return MergeFrom[T](ctx, rawSource, rawDestination, disableAppend)
|
return MergeFrom[T](rawSource, rawDestination)
|
||||||
}
|
}
|
||||||
|
|
||||||
func MergeFromSource[T any](ctx context.Context, rawSource json.RawMessage, destination T, disableAppend bool) (T, error) {
|
func MergeFromSource[T any](rawSource json.RawMessage, destination T) (T, error) {
|
||||||
if rawSource == nil {
|
if rawSource == nil {
|
||||||
return destination, nil
|
return destination, nil
|
||||||
}
|
}
|
||||||
rawDestination, err := json.MarshalContext(ctx, destination)
|
rawDestination, err := json.Marshal(destination)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return common.DefaultValue[T](), E.Cause(err, "marshal destination")
|
return common.DefaultValue[T](), E.Cause(err, "marshal destination")
|
||||||
}
|
}
|
||||||
return MergeFrom[T](ctx, rawSource, rawDestination, disableAppend)
|
return MergeFrom[T](rawSource, rawDestination)
|
||||||
}
|
}
|
||||||
|
|
||||||
func MergeFromDestination[T any](ctx context.Context, source T, rawDestination json.RawMessage, disableAppend bool) (T, error) {
|
func MergeFromDestination[T any](source T, rawDestination json.RawMessage) (T, error) {
|
||||||
if rawDestination == nil {
|
if rawDestination == nil {
|
||||||
return source, nil
|
return source, nil
|
||||||
}
|
}
|
||||||
rawSource, err := json.MarshalContext(ctx, source)
|
rawSource, err := json.Marshal(source)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return common.DefaultValue[T](), E.Cause(err, "marshal source")
|
return common.DefaultValue[T](), E.Cause(err, "marshal source")
|
||||||
}
|
}
|
||||||
return MergeFrom[T](ctx, rawSource, rawDestination, disableAppend)
|
return MergeFrom[T](rawSource, rawDestination)
|
||||||
}
|
}
|
||||||
|
|
||||||
func MergeFrom[T any](ctx context.Context, rawSource json.RawMessage, rawDestination json.RawMessage, disableAppend bool) (T, error) {
|
func MergeFrom[T any](rawSource json.RawMessage, rawDestination json.RawMessage) (T, error) {
|
||||||
rawMerged, err := MergeJSON(ctx, rawSource, rawDestination, disableAppend)
|
rawMerged, err := MergeJSON(rawSource, rawDestination)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return common.DefaultValue[T](), E.Cause(err, "merge options")
|
return common.DefaultValue[T](), E.Cause(err, "merge options")
|
||||||
}
|
}
|
||||||
var merged T
|
var merged T
|
||||||
err = json.UnmarshalContext(ctx, rawMerged, &merged)
|
err = json.Unmarshal(rawMerged, &merged)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return common.DefaultValue[T](), E.Cause(err, "unmarshal merged options")
|
return common.DefaultValue[T](), E.Cause(err, "unmarshal merged options")
|
||||||
}
|
}
|
||||||
return merged, nil
|
return merged, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func MergeJSON(ctx context.Context, rawSource json.RawMessage, rawDestination json.RawMessage, disableAppend bool) (json.RawMessage, error) {
|
func MergeJSON(rawSource json.RawMessage, rawDestination json.RawMessage) (json.RawMessage, error) {
|
||||||
if rawSource == nil && rawDestination == nil {
|
if rawSource == nil && rawDestination == nil {
|
||||||
return nil, os.ErrInvalid
|
return nil, os.ErrInvalid
|
||||||
} else if rawSource == nil {
|
} else if rawSource == nil {
|
||||||
|
@ -86,36 +85,34 @@ func MergeJSON(ctx context.Context, rawSource json.RawMessage, rawDestination js
|
||||||
} else if rawDestination == nil {
|
} else if rawDestination == nil {
|
||||||
return rawSource, nil
|
return rawSource, nil
|
||||||
}
|
}
|
||||||
source, err := Decode(ctx, rawSource)
|
source, err := Decode(rawSource)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, E.Cause(err, "decode source")
|
return nil, E.Cause(err, "decode source")
|
||||||
}
|
}
|
||||||
destination, err := Decode(ctx, rawDestination)
|
destination, err := Decode(rawDestination)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, E.Cause(err, "decode destination")
|
return nil, E.Cause(err, "decode destination")
|
||||||
}
|
}
|
||||||
if source == nil {
|
if source == nil {
|
||||||
return json.MarshalContext(ctx, destination)
|
return json.Marshal(destination)
|
||||||
} else if destination == nil {
|
} else if destination == nil {
|
||||||
return json.Marshal(source)
|
return json.Marshal(source)
|
||||||
}
|
}
|
||||||
merged, err := mergeJSON(source, destination, disableAppend)
|
merged, err := mergeJSON(source, destination)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return json.MarshalContext(ctx, merged)
|
return json.Marshal(merged)
|
||||||
}
|
}
|
||||||
|
|
||||||
func mergeJSON(anySource any, anyDestination any, disableAppend bool) (any, error) {
|
func mergeJSON(anySource any, anyDestination any) (any, error) {
|
||||||
switch destination := anyDestination.(type) {
|
switch destination := anyDestination.(type) {
|
||||||
case JSONArray:
|
case JSONArray:
|
||||||
if !disableAppend {
|
switch source := anySource.(type) {
|
||||||
switch source := anySource.(type) {
|
case JSONArray:
|
||||||
case JSONArray:
|
destination = append(destination, source...)
|
||||||
destination = append(destination, source...)
|
default:
|
||||||
default:
|
destination = append(destination, source)
|
||||||
destination = append(destination, source)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return destination, nil
|
return destination, nil
|
||||||
case *JSONObject:
|
case *JSONObject:
|
||||||
|
@ -125,7 +122,7 @@ func mergeJSON(anySource any, anyDestination any, disableAppend bool) (any, erro
|
||||||
oldValue, loaded := destination.Get(entry.Key)
|
oldValue, loaded := destination.Get(entry.Key)
|
||||||
if loaded {
|
if loaded {
|
||||||
var err error
|
var err error
|
||||||
entry.Value, err = mergeJSON(entry.Value, oldValue, disableAppend)
|
entry.Value, err = mergeJSON(entry.Value, oldValue)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, E.Cause(err, "merge object item ", entry.Key)
|
return nil, E.Cause(err, "merge object item ", entry.Key)
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,68 +0,0 @@
|
||||||
package badjson
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"reflect"
|
|
||||||
|
|
||||||
E "github.com/sagernet/sing/common/exceptions"
|
|
||||||
"github.com/sagernet/sing/common/json"
|
|
||||||
cJSON "github.com/sagernet/sing/common/json/internal/contextjson"
|
|
||||||
)
|
|
||||||
|
|
||||||
func MarshallObjects(objects ...any) ([]byte, error) {
|
|
||||||
return MarshallObjectsContext(context.Background(), objects...)
|
|
||||||
}
|
|
||||||
|
|
||||||
func MarshallObjectsContext(ctx context.Context, objects ...any) ([]byte, error) {
|
|
||||||
if len(objects) == 1 {
|
|
||||||
return json.Marshal(objects[0])
|
|
||||||
}
|
|
||||||
var content JSONObject
|
|
||||||
for _, object := range objects {
|
|
||||||
objectMap, err := newJSONObject(ctx, object)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
content.PutAll(objectMap)
|
|
||||||
}
|
|
||||||
return content.MarshalJSONContext(ctx)
|
|
||||||
}
|
|
||||||
|
|
||||||
func UnmarshallExcluded(inputContent []byte, parentObject any, object any) error {
|
|
||||||
return UnmarshallExcludedContext(context.Background(), inputContent, parentObject, object)
|
|
||||||
}
|
|
||||||
|
|
||||||
func UnmarshallExcludedContext(ctx context.Context, inputContent []byte, parentObject any, object any) error {
|
|
||||||
var content JSONObject
|
|
||||||
err := content.UnmarshalJSONContext(ctx, inputContent)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
for _, key := range cJSON.ObjectKeys(reflect.TypeOf(parentObject)) {
|
|
||||||
content.Remove(key)
|
|
||||||
}
|
|
||||||
if object == nil {
|
|
||||||
if content.IsEmpty() {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return E.New("unexpected key: ", content.Keys()[0])
|
|
||||||
}
|
|
||||||
inputContent, err = content.MarshalJSONContext(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return json.UnmarshalContextDisallowUnknownFields(ctx, inputContent, object)
|
|
||||||
}
|
|
||||||
|
|
||||||
func newJSONObject(ctx context.Context, object any) (*JSONObject, error) {
|
|
||||||
inputContent, err := json.MarshalContext(ctx, object)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
var content JSONObject
|
|
||||||
err = content.UnmarshalJSONContext(ctx, inputContent)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &content, nil
|
|
||||||
}
|
|
|
@ -2,7 +2,6 @@ package badjson
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/sagernet/sing/common"
|
"github.com/sagernet/sing/common"
|
||||||
|
@ -29,10 +28,6 @@ func (m *JSONObject) IsEmpty() bool {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *JSONObject) MarshalJSON() ([]byte, error) {
|
func (m *JSONObject) MarshalJSON() ([]byte, error) {
|
||||||
return m.MarshalJSONContext(context.Background())
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *JSONObject) MarshalJSONContext(ctx context.Context) ([]byte, error) {
|
|
||||||
buffer := new(bytes.Buffer)
|
buffer := new(bytes.Buffer)
|
||||||
buffer.WriteString("{")
|
buffer.WriteString("{")
|
||||||
items := common.Filter(m.Entries(), func(it collections.MapEntry[string, any]) bool {
|
items := common.Filter(m.Entries(), func(it collections.MapEntry[string, any]) bool {
|
||||||
|
@ -43,13 +38,13 @@ func (m *JSONObject) MarshalJSONContext(ctx context.Context) ([]byte, error) {
|
||||||
})
|
})
|
||||||
iLen := len(items)
|
iLen := len(items)
|
||||||
for i, entry := range items {
|
for i, entry := range items {
|
||||||
keyContent, err := json.MarshalContext(ctx, entry.Key)
|
keyContent, err := json.Marshal(entry.Key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
buffer.WriteString(strings.TrimSpace(string(keyContent)))
|
buffer.WriteString(strings.TrimSpace(string(keyContent)))
|
||||||
buffer.WriteString(": ")
|
buffer.WriteString(": ")
|
||||||
valueContent, err := json.MarshalContext(ctx, entry.Value)
|
valueContent, err := json.Marshal(entry.Value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -63,11 +58,7 @@ func (m *JSONObject) MarshalJSONContext(ctx context.Context) ([]byte, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *JSONObject) UnmarshalJSON(content []byte) error {
|
func (m *JSONObject) UnmarshalJSON(content []byte) error {
|
||||||
return m.UnmarshalJSONContext(context.Background(), content)
|
decoder := json.NewDecoder(bytes.NewReader(content))
|
||||||
}
|
|
||||||
|
|
||||||
func (m *JSONObject) UnmarshalJSONContext(ctx context.Context, content []byte) error {
|
|
||||||
decoder := json.NewDecoderContext(ctx, bytes.NewReader(content))
|
|
||||||
m.Clear()
|
m.Clear()
|
||||||
objectStart, err := decoder.Token()
|
objectStart, err := decoder.Token()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -2,7 +2,6 @@ package badjson
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
E "github.com/sagernet/sing/common/exceptions"
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
|
@ -15,22 +14,18 @@ type TypedMap[K comparable, V any] struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m TypedMap[K, V]) MarshalJSON() ([]byte, error) {
|
func (m TypedMap[K, V]) MarshalJSON() ([]byte, error) {
|
||||||
return m.MarshalJSONContext(context.Background())
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m TypedMap[K, V]) MarshalJSONContext(ctx context.Context) ([]byte, error) {
|
|
||||||
buffer := new(bytes.Buffer)
|
buffer := new(bytes.Buffer)
|
||||||
buffer.WriteString("{")
|
buffer.WriteString("{")
|
||||||
items := m.Entries()
|
items := m.Entries()
|
||||||
iLen := len(items)
|
iLen := len(items)
|
||||||
for i, entry := range items {
|
for i, entry := range items {
|
||||||
keyContent, err := json.MarshalContext(ctx, entry.Key)
|
keyContent, err := json.Marshal(entry.Key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
buffer.WriteString(strings.TrimSpace(string(keyContent)))
|
buffer.WriteString(strings.TrimSpace(string(keyContent)))
|
||||||
buffer.WriteString(": ")
|
buffer.WriteString(": ")
|
||||||
valueContent, err := json.MarshalContext(ctx, entry.Value)
|
valueContent, err := json.Marshal(entry.Value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -44,11 +39,7 @@ func (m TypedMap[K, V]) MarshalJSONContext(ctx context.Context) ([]byte, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *TypedMap[K, V]) UnmarshalJSON(content []byte) error {
|
func (m *TypedMap[K, V]) UnmarshalJSON(content []byte) error {
|
||||||
return m.UnmarshalJSONContext(context.Background(), content)
|
decoder := json.NewDecoder(bytes.NewReader(content))
|
||||||
}
|
|
||||||
|
|
||||||
func (m *TypedMap[K, V]) UnmarshalJSONContext(ctx context.Context, content []byte) error {
|
|
||||||
decoder := json.NewDecoderContext(ctx, bytes.NewReader(content))
|
|
||||||
m.Clear()
|
m.Clear()
|
||||||
objectStart, err := decoder.Token()
|
objectStart, err := decoder.Token()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -56,7 +47,7 @@ func (m *TypedMap[K, V]) UnmarshalJSONContext(ctx context.Context, content []byt
|
||||||
} else if objectStart != json.Delim('{') {
|
} else if objectStart != json.Delim('{') {
|
||||||
return E.New("expected json object start, but starts with ", objectStart)
|
return E.New("expected json object start, but starts with ", objectStart)
|
||||||
}
|
}
|
||||||
err = m.decodeJSON(ctx, decoder)
|
err = m.decodeJSON(decoder)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return E.Cause(err, "decode json object content")
|
return E.Cause(err, "decode json object content")
|
||||||
}
|
}
|
||||||
|
@ -69,18 +60,18 @@ func (m *TypedMap[K, V]) UnmarshalJSONContext(ctx context.Context, content []byt
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *TypedMap[K, V]) decodeJSON(ctx context.Context, decoder *json.Decoder) error {
|
func (m *TypedMap[K, V]) decodeJSON(decoder *json.Decoder) error {
|
||||||
for decoder.More() {
|
for decoder.More() {
|
||||||
keyToken, err := decoder.Token()
|
keyToken, err := decoder.Token()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
keyContent, err := json.MarshalContext(ctx, keyToken)
|
keyContent, err := json.Marshal(keyToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
var entryKey K
|
var entryKey K
|
||||||
err = json.UnmarshalContext(ctx, keyContent, &entryKey)
|
err = json.Unmarshal(keyContent, &entryKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,32 +0,0 @@
|
||||||
package badoption
|
|
||||||
|
|
||||||
import (
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/sagernet/sing/common/json"
|
|
||||||
"github.com/sagernet/sing/common/json/badoption/internal/my_time"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Duration time.Duration
|
|
||||||
|
|
||||||
func (d Duration) Build() time.Duration {
|
|
||||||
return time.Duration(d)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d Duration) MarshalJSON() ([]byte, error) {
|
|
||||||
return json.Marshal((time.Duration)(d).String())
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *Duration) UnmarshalJSON(bytes []byte) error {
|
|
||||||
var value string
|
|
||||||
err := json.Unmarshal(bytes, &value)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
duration, err := my_time.ParseDuration(value)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
*d = Duration(duration)
|
|
||||||
return nil
|
|
||||||
}
|
|
|
@ -1,15 +0,0 @@
|
||||||
package badoption
|
|
||||||
|
|
||||||
import "net/http"
|
|
||||||
|
|
||||||
type HTTPHeader map[string]Listable[string]
|
|
||||||
|
|
||||||
func (h HTTPHeader) Build() http.Header {
|
|
||||||
header := make(http.Header)
|
|
||||||
for name, values := range h {
|
|
||||||
for _, value := range values {
|
|
||||||
header.Add(name, value)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return header
|
|
||||||
}
|
|
|
@ -1,226 +0,0 @@
|
||||||
package my_time
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Copyright 2010 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
const durationDay = 24 * time.Hour
|
|
||||||
|
|
||||||
var unitMap = map[string]uint64{
|
|
||||||
"ns": uint64(time.Nanosecond),
|
|
||||||
"us": uint64(time.Microsecond),
|
|
||||||
"µs": uint64(time.Microsecond), // U+00B5 = micro symbol
|
|
||||||
"μs": uint64(time.Microsecond), // U+03BC = Greek letter mu
|
|
||||||
"ms": uint64(time.Millisecond),
|
|
||||||
"s": uint64(time.Second),
|
|
||||||
"m": uint64(time.Minute),
|
|
||||||
"h": uint64(time.Hour),
|
|
||||||
"d": uint64(durationDay),
|
|
||||||
}
|
|
||||||
|
|
||||||
// ParseDuration parses a duration string.
|
|
||||||
// A duration string is a possibly signed sequence of
|
|
||||||
// decimal numbers, each with optional fraction and a unit suffix,
|
|
||||||
// such as "300ms", "-1.5h" or "2h45m".
|
|
||||||
// Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h".
|
|
||||||
func ParseDuration(s string) (time.Duration, error) {
|
|
||||||
// [-+]?([0-9]*(\.[0-9]*)?[a-z]+)+
|
|
||||||
orig := s
|
|
||||||
var d uint64
|
|
||||||
neg := false
|
|
||||||
|
|
||||||
// Consume [-+]?
|
|
||||||
if s != "" {
|
|
||||||
c := s[0]
|
|
||||||
if c == '-' || c == '+' {
|
|
||||||
neg = c == '-'
|
|
||||||
s = s[1:]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// Special case: if all that is left is "0", this is zero.
|
|
||||||
if s == "0" {
|
|
||||||
return 0, nil
|
|
||||||
}
|
|
||||||
if s == "" {
|
|
||||||
return 0, errors.New("time: invalid duration " + quote(orig))
|
|
||||||
}
|
|
||||||
for s != "" {
|
|
||||||
var (
|
|
||||||
v, f uint64 // integers before, after decimal point
|
|
||||||
scale float64 = 1 // value = v + f/scale
|
|
||||||
)
|
|
||||||
|
|
||||||
var err error
|
|
||||||
|
|
||||||
// The next character must be [0-9.]
|
|
||||||
if !(s[0] == '.' || '0' <= s[0] && s[0] <= '9') {
|
|
||||||
return 0, errors.New("time: invalid duration " + quote(orig))
|
|
||||||
}
|
|
||||||
// Consume [0-9]*
|
|
||||||
pl := len(s)
|
|
||||||
v, s, err = leadingInt(s)
|
|
||||||
if err != nil {
|
|
||||||
return 0, errors.New("time: invalid duration " + quote(orig))
|
|
||||||
}
|
|
||||||
pre := pl != len(s) // whether we consumed anything before a period
|
|
||||||
|
|
||||||
// Consume (\.[0-9]*)?
|
|
||||||
post := false
|
|
||||||
if s != "" && s[0] == '.' {
|
|
||||||
s = s[1:]
|
|
||||||
pl := len(s)
|
|
||||||
f, scale, s = leadingFraction(s)
|
|
||||||
post = pl != len(s)
|
|
||||||
}
|
|
||||||
if !pre && !post {
|
|
||||||
// no digits (e.g. ".s" or "-.s")
|
|
||||||
return 0, errors.New("time: invalid duration " + quote(orig))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Consume unit.
|
|
||||||
i := 0
|
|
||||||
for ; i < len(s); i++ {
|
|
||||||
c := s[i]
|
|
||||||
if c == '.' || '0' <= c && c <= '9' {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if i == 0 {
|
|
||||||
return 0, errors.New("time: missing unit in duration " + quote(orig))
|
|
||||||
}
|
|
||||||
u := s[:i]
|
|
||||||
s = s[i:]
|
|
||||||
unit, ok := unitMap[u]
|
|
||||||
if !ok {
|
|
||||||
return 0, errors.New("time: unknown unit " + quote(u) + " in duration " + quote(orig))
|
|
||||||
}
|
|
||||||
if v > 1<<63/unit {
|
|
||||||
// overflow
|
|
||||||
return 0, errors.New("time: invalid duration " + quote(orig))
|
|
||||||
}
|
|
||||||
v *= unit
|
|
||||||
if f > 0 {
|
|
||||||
// float64 is needed to be nanosecond accurate for fractions of hours.
|
|
||||||
// v >= 0 && (f*unit/scale) <= 3.6e+12 (ns/h, h is the largest unit)
|
|
||||||
v += uint64(float64(f) * (float64(unit) / scale))
|
|
||||||
if v > 1<<63 {
|
|
||||||
// overflow
|
|
||||||
return 0, errors.New("time: invalid duration " + quote(orig))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
d += v
|
|
||||||
if d > 1<<63 {
|
|
||||||
return 0, errors.New("time: invalid duration " + quote(orig))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if neg {
|
|
||||||
return -time.Duration(d), nil
|
|
||||||
}
|
|
||||||
if d > 1<<63-1 {
|
|
||||||
return 0, errors.New("time: invalid duration " + quote(orig))
|
|
||||||
}
|
|
||||||
return time.Duration(d), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var errLeadingInt = errors.New("time: bad [0-9]*") // never printed
|
|
||||||
|
|
||||||
// leadingInt consumes the leading [0-9]* from s.
|
|
||||||
func leadingInt[bytes []byte | string](s bytes) (x uint64, rem bytes, err error) {
|
|
||||||
i := 0
|
|
||||||
for ; i < len(s); i++ {
|
|
||||||
c := s[i]
|
|
||||||
if c < '0' || c > '9' {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if x > 1<<63/10 {
|
|
||||||
// overflow
|
|
||||||
return 0, rem, errLeadingInt
|
|
||||||
}
|
|
||||||
x = x*10 + uint64(c) - '0'
|
|
||||||
if x > 1<<63 {
|
|
||||||
// overflow
|
|
||||||
return 0, rem, errLeadingInt
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return x, s[i:], nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// leadingFraction consumes the leading [0-9]* from s.
|
|
||||||
// It is used only for fractions, so does not return an error on overflow,
|
|
||||||
// it just stops accumulating precision.
|
|
||||||
func leadingFraction(s string) (x uint64, scale float64, rem string) {
|
|
||||||
i := 0
|
|
||||||
scale = 1
|
|
||||||
overflow := false
|
|
||||||
for ; i < len(s); i++ {
|
|
||||||
c := s[i]
|
|
||||||
if c < '0' || c > '9' {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if overflow {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if x > (1<<63-1)/10 {
|
|
||||||
// It's possible for overflow to give a positive number, so take care.
|
|
||||||
overflow = true
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
y := x*10 + uint64(c) - '0'
|
|
||||||
if y > 1<<63 {
|
|
||||||
overflow = true
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
x = y
|
|
||||||
scale *= 10
|
|
||||||
}
|
|
||||||
return x, scale, s[i:]
|
|
||||||
}
|
|
||||||
|
|
||||||
// These are borrowed from unicode/utf8 and strconv and replicate behavior in
|
|
||||||
// that package, since we can't take a dependency on either.
|
|
||||||
const (
|
|
||||||
lowerhex = "0123456789abcdef"
|
|
||||||
runeSelf = 0x80
|
|
||||||
runeError = '\uFFFD'
|
|
||||||
)
|
|
||||||
|
|
||||||
func quote(s string) string {
|
|
||||||
buf := make([]byte, 1, len(s)+2) // slice will be at least len(s) + quotes
|
|
||||||
buf[0] = '"'
|
|
||||||
for i, c := range s {
|
|
||||||
if c >= runeSelf || c < ' ' {
|
|
||||||
// This means you are asking us to parse a time.Duration or
|
|
||||||
// time.Location with unprintable or non-ASCII characters in it.
|
|
||||||
// We don't expect to hit this case very often. We could try to
|
|
||||||
// reproduce strconv.Quote's behavior with full fidelity but
|
|
||||||
// given how rarely we expect to hit these edge cases, speed and
|
|
||||||
// conciseness are better.
|
|
||||||
var width int
|
|
||||||
if c == runeError {
|
|
||||||
width = 1
|
|
||||||
if i+2 < len(s) && s[i:i+3] == string(runeError) {
|
|
||||||
width = 3
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
width = len(string(c))
|
|
||||||
}
|
|
||||||
for j := 0; j < width; j++ {
|
|
||||||
buf = append(buf, `\x`...)
|
|
||||||
buf = append(buf, lowerhex[s[i+j]>>4])
|
|
||||||
buf = append(buf, lowerhex[s[i+j]&0xF])
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if c == '"' || c == '\\' {
|
|
||||||
buf = append(buf, '\\')
|
|
||||||
}
|
|
||||||
buf = append(buf, string(c)...)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
buf = append(buf, '"')
|
|
||||||
return string(buf)
|
|
||||||
}
|
|
|
@ -1,35 +0,0 @@
|
||||||
package badoption
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
|
|
||||||
E "github.com/sagernet/sing/common/exceptions"
|
|
||||||
"github.com/sagernet/sing/common/json"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Listable[T any] []T
|
|
||||||
|
|
||||||
func (l Listable[T]) MarshalJSONContext(ctx context.Context) ([]byte, error) {
|
|
||||||
arrayList := []T(l)
|
|
||||||
if len(arrayList) == 1 {
|
|
||||||
return json.Marshal(arrayList[0])
|
|
||||||
}
|
|
||||||
return json.MarshalContext(ctx, arrayList)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *Listable[T]) UnmarshalJSONContext(ctx context.Context, content []byte) error {
|
|
||||||
if string(content) == "null" {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
var singleItem T
|
|
||||||
err := json.UnmarshalContextDisallowUnknownFields(ctx, content, &singleItem)
|
|
||||||
if err == nil {
|
|
||||||
*l = []T{singleItem}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
newErr := json.UnmarshalContextDisallowUnknownFields(ctx, content, (*[]T)(l))
|
|
||||||
if newErr == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return E.Errors(err, newErr)
|
|
||||||
}
|
|
|
@ -1,98 +0,0 @@
|
||||||
package badoption
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/netip"
|
|
||||||
|
|
||||||
"github.com/sagernet/sing/common/json"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Addr netip.Addr
|
|
||||||
|
|
||||||
func (a *Addr) Build(defaultAddr netip.Addr) netip.Addr {
|
|
||||||
if a == nil {
|
|
||||||
return defaultAddr
|
|
||||||
}
|
|
||||||
return netip.Addr(*a)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Addr) MarshalJSON() ([]byte, error) {
|
|
||||||
return json.Marshal(netip.Addr(*a).String())
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Addr) UnmarshalJSON(content []byte) error {
|
|
||||||
var value string
|
|
||||||
err := json.Unmarshal(content, &value)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
addr, err := netip.ParseAddr(value)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
*a = Addr(addr)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type Prefix netip.Prefix
|
|
||||||
|
|
||||||
func (p *Prefix) Build(defaultPrefix netip.Prefix) netip.Prefix {
|
|
||||||
if p == nil {
|
|
||||||
return defaultPrefix
|
|
||||||
}
|
|
||||||
return netip.Prefix(*p)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *Prefix) MarshalJSON() ([]byte, error) {
|
|
||||||
return json.Marshal(netip.Prefix(*p).String())
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *Prefix) UnmarshalJSON(content []byte) error {
|
|
||||||
var value string
|
|
||||||
err := json.Unmarshal(content, &value)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
prefix, err := netip.ParsePrefix(value)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
*p = Prefix(prefix)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type Prefixable netip.Prefix
|
|
||||||
|
|
||||||
func (p *Prefixable) Build(defaultPrefix netip.Prefix) netip.Prefix {
|
|
||||||
if p == nil {
|
|
||||||
return defaultPrefix
|
|
||||||
}
|
|
||||||
return netip.Prefix(*p)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *Prefixable) MarshalJSON() ([]byte, error) {
|
|
||||||
prefix := netip.Prefix(*p)
|
|
||||||
if prefix.Bits() == prefix.Addr().BitLen() {
|
|
||||||
return json.Marshal(prefix.Addr().String())
|
|
||||||
} else {
|
|
||||||
return json.Marshal(prefix.String())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *Prefixable) UnmarshalJSON(content []byte) error {
|
|
||||||
var value string
|
|
||||||
err := json.Unmarshal(content, &value)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
prefix, prefixErr := netip.ParsePrefix(value)
|
|
||||||
if prefixErr == nil {
|
|
||||||
*p = Prefixable(prefix)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
addr, addrErr := netip.ParseAddr(value)
|
|
||||||
if addrErr == nil {
|
|
||||||
*p = Prefixable(netip.PrefixFrom(addr, addr.BitLen()))
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return prefixErr
|
|
||||||
}
|
|
|
@ -1,31 +0,0 @@
|
||||||
package badoption
|
|
||||||
|
|
||||||
import (
|
|
||||||
"regexp"
|
|
||||||
|
|
||||||
"github.com/sagernet/sing/common/json"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Regexp regexp.Regexp
|
|
||||||
|
|
||||||
func (r *Regexp) Build() *regexp.Regexp {
|
|
||||||
return (*regexp.Regexp)(r)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *Regexp) MarshalJSON() ([]byte, error) {
|
|
||||||
return json.Marshal((*regexp.Regexp)(r).String())
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *Regexp) UnmarshalJSON(content []byte) error {
|
|
||||||
var stringValue string
|
|
||||||
err := json.Unmarshal(content, &stringValue)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
regex, err := regexp.Compile(stringValue)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
*r = Regexp(*regex)
|
|
||||||
return nil
|
|
||||||
}
|
|
|
@ -1,23 +0,0 @@
|
||||||
package json
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
|
|
||||||
"github.com/sagernet/sing/common/json/internal/contextjson"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
MarshalContext = json.MarshalContext
|
|
||||||
UnmarshalContext = json.UnmarshalContext
|
|
||||||
NewEncoderContext = json.NewEncoderContext
|
|
||||||
NewDecoderContext = json.NewDecoderContext
|
|
||||||
UnmarshalContextDisallowUnknownFields = json.UnmarshalContextDisallowUnknownFields
|
|
||||||
)
|
|
||||||
|
|
||||||
type ContextMarshaler interface {
|
|
||||||
MarshalJSONContext(ctx context.Context) ([]byte, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
type ContextUnmarshaler interface {
|
|
||||||
UnmarshalJSONContext(ctx context.Context, content []byte) error
|
|
||||||
}
|
|
|
@ -1,11 +0,0 @@
|
||||||
package json
|
|
||||||
|
|
||||||
import "context"
|
|
||||||
|
|
||||||
type ContextMarshaler interface {
|
|
||||||
MarshalJSONContext(ctx context.Context) ([]byte, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
type ContextUnmarshaler interface {
|
|
||||||
UnmarshalJSONContext(ctx context.Context, content []byte) error
|
|
||||||
}
|
|
|
@ -1,43 +0,0 @@
|
||||||
package json_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/sagernet/sing/common/json/internal/contextjson"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
type myStruct struct {
|
|
||||||
value string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *myStruct) MarshalJSONContext(ctx context.Context) ([]byte, error) {
|
|
||||||
return json.Marshal(ctx.Value("key").(string))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *myStruct) UnmarshalJSONContext(ctx context.Context, content []byte) error {
|
|
||||||
m.value = ctx.Value("key").(string)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
//nolint:staticcheck
|
|
||||||
func TestMarshalContext(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
ctx := context.WithValue(context.Background(), "key", "value")
|
|
||||||
var s myStruct
|
|
||||||
b, err := json.MarshalContext(ctx, &s)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, []byte(`"value"`), b)
|
|
||||||
}
|
|
||||||
|
|
||||||
//nolint:staticcheck
|
|
||||||
func TestUnmarshalContext(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
ctx := context.WithValue(context.Background(), "key", "value")
|
|
||||||
var s myStruct
|
|
||||||
err := json.UnmarshalContext(ctx, []byte(`{}`), &s)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, "value", s.value)
|
|
||||||
}
|
|
|
@ -8,7 +8,6 @@
|
||||||
package json
|
package json
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"encoding"
|
"encoding"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
@ -96,15 +95,10 @@ import (
|
||||||
// Instead, they are replaced by the Unicode replacement
|
// Instead, they are replaced by the Unicode replacement
|
||||||
// character U+FFFD.
|
// character U+FFFD.
|
||||||
func Unmarshal(data []byte, v any) error {
|
func Unmarshal(data []byte, v any) error {
|
||||||
return UnmarshalContext(context.Background(), data, v)
|
|
||||||
}
|
|
||||||
|
|
||||||
func UnmarshalContext(ctx context.Context, data []byte, v any) error {
|
|
||||||
// Check for well-formedness.
|
// Check for well-formedness.
|
||||||
// Avoids filling out half a data structure
|
// Avoids filling out half a data structure
|
||||||
// before discovering a JSON syntax error.
|
// before discovering a JSON syntax error.
|
||||||
var d decodeState
|
var d decodeState
|
||||||
d.ctx = ctx
|
|
||||||
err := checkValid(data, &d.scan)
|
err := checkValid(data, &d.scan)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -215,7 +209,6 @@ type errorContext struct {
|
||||||
|
|
||||||
// decodeState represents the state while decoding a JSON value.
|
// decodeState represents the state while decoding a JSON value.
|
||||||
type decodeState struct {
|
type decodeState struct {
|
||||||
ctx context.Context
|
|
||||||
data []byte
|
data []byte
|
||||||
off int // next read offset in data
|
off int // next read offset in data
|
||||||
opcode int // last read result
|
opcode int // last read result
|
||||||
|
@ -435,7 +428,7 @@ func (d *decodeState) valueQuoted() any {
|
||||||
// If it encounters an Unmarshaler, indirect stops and returns that.
|
// If it encounters an Unmarshaler, indirect stops and returns that.
|
||||||
// If decodingNull is true, indirect stops at the first settable pointer so it
|
// If decodingNull is true, indirect stops at the first settable pointer so it
|
||||||
// can be set to nil.
|
// can be set to nil.
|
||||||
func indirect(v reflect.Value, decodingNull bool) (Unmarshaler, ContextUnmarshaler, encoding.TextUnmarshaler, reflect.Value) {
|
func indirect(v reflect.Value, decodingNull bool) (Unmarshaler, encoding.TextUnmarshaler, reflect.Value) {
|
||||||
// Issue #24153 indicates that it is generally not a guaranteed property
|
// Issue #24153 indicates that it is generally not a guaranteed property
|
||||||
// that you may round-trip a reflect.Value by calling Value.Addr().Elem()
|
// that you may round-trip a reflect.Value by calling Value.Addr().Elem()
|
||||||
// and expect the value to still be settable for values derived from
|
// and expect the value to still be settable for values derived from
|
||||||
|
@ -489,14 +482,11 @@ func indirect(v reflect.Value, decodingNull bool) (Unmarshaler, ContextUnmarshal
|
||||||
}
|
}
|
||||||
if v.Type().NumMethod() > 0 && v.CanInterface() {
|
if v.Type().NumMethod() > 0 && v.CanInterface() {
|
||||||
if u, ok := v.Interface().(Unmarshaler); ok {
|
if u, ok := v.Interface().(Unmarshaler); ok {
|
||||||
return u, nil, nil, reflect.Value{}
|
return u, nil, reflect.Value{}
|
||||||
}
|
|
||||||
if cu, ok := v.Interface().(ContextUnmarshaler); ok {
|
|
||||||
return nil, cu, nil, reflect.Value{}
|
|
||||||
}
|
}
|
||||||
if !decodingNull {
|
if !decodingNull {
|
||||||
if u, ok := v.Interface().(encoding.TextUnmarshaler); ok {
|
if u, ok := v.Interface().(encoding.TextUnmarshaler); ok {
|
||||||
return nil, nil, u, reflect.Value{}
|
return nil, u, reflect.Value{}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -508,14 +498,14 @@ func indirect(v reflect.Value, decodingNull bool) (Unmarshaler, ContextUnmarshal
|
||||||
v = v.Elem()
|
v = v.Elem()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil, nil, nil, v
|
return nil, nil, v
|
||||||
}
|
}
|
||||||
|
|
||||||
// array consumes an array from d.data[d.off-1:], decoding into v.
|
// array consumes an array from d.data[d.off-1:], decoding into v.
|
||||||
// The first byte of the array ('[') has been read already.
|
// The first byte of the array ('[') has been read already.
|
||||||
func (d *decodeState) array(v reflect.Value) error {
|
func (d *decodeState) array(v reflect.Value) error {
|
||||||
// Check for unmarshaler.
|
// Check for unmarshaler.
|
||||||
u, cu, ut, pv := indirect(v, false)
|
u, ut, pv := indirect(v, false)
|
||||||
if u != nil {
|
if u != nil {
|
||||||
start := d.readIndex()
|
start := d.readIndex()
|
||||||
d.skip()
|
d.skip()
|
||||||
|
@ -525,15 +515,6 @@ func (d *decodeState) array(v reflect.Value) error {
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
if cu != nil {
|
|
||||||
start := d.readIndex()
|
|
||||||
d.skip()
|
|
||||||
err := cu.UnmarshalJSONContext(d.ctx, d.data[start:d.off])
|
|
||||||
if err != nil {
|
|
||||||
d.saveError(err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if ut != nil {
|
if ut != nil {
|
||||||
d.saveError(&UnmarshalTypeError{Value: "array", Type: v.Type(), Offset: int64(d.off)})
|
d.saveError(&UnmarshalTypeError{Value: "array", Type: v.Type(), Offset: int64(d.off)})
|
||||||
d.skip()
|
d.skip()
|
||||||
|
@ -631,7 +612,7 @@ var (
|
||||||
// The first byte ('{') of the object has been read already.
|
// The first byte ('{') of the object has been read already.
|
||||||
func (d *decodeState) object(v reflect.Value) error {
|
func (d *decodeState) object(v reflect.Value) error {
|
||||||
// Check for unmarshaler.
|
// Check for unmarshaler.
|
||||||
u, cu, ut, pv := indirect(v, false)
|
u, ut, pv := indirect(v, false)
|
||||||
if u != nil {
|
if u != nil {
|
||||||
start := d.readIndex()
|
start := d.readIndex()
|
||||||
d.skip()
|
d.skip()
|
||||||
|
@ -641,15 +622,6 @@ func (d *decodeState) object(v reflect.Value) error {
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
if cu != nil {
|
|
||||||
start := d.readIndex()
|
|
||||||
d.skip()
|
|
||||||
err := cu.UnmarshalJSONContext(d.ctx, d.data[start:d.off])
|
|
||||||
if err != nil {
|
|
||||||
d.saveError(err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if ut != nil {
|
if ut != nil {
|
||||||
d.saveError(&UnmarshalTypeError{Value: "object", Type: v.Type(), Offset: int64(d.off)})
|
d.saveError(&UnmarshalTypeError{Value: "object", Type: v.Type(), Offset: int64(d.off)})
|
||||||
d.skip()
|
d.skip()
|
||||||
|
@ -898,7 +870,7 @@ func (d *decodeState) literalStore(item []byte, v reflect.Value, fromQuoted bool
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
isNull := item[0] == 'n' // null
|
isNull := item[0] == 'n' // null
|
||||||
u, cu, ut, pv := indirect(v, isNull)
|
u, ut, pv := indirect(v, isNull)
|
||||||
if u != nil {
|
if u != nil {
|
||||||
err := u.UnmarshalJSON(item)
|
err := u.UnmarshalJSON(item)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -906,13 +878,6 @@ func (d *decodeState) literalStore(item []byte, v reflect.Value, fromQuoted bool
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
if cu != nil {
|
|
||||||
err := cu.UnmarshalJSONContext(d.ctx, item)
|
|
||||||
if err != nil {
|
|
||||||
d.saveError(err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if ut != nil {
|
if ut != nil {
|
||||||
if item[0] != '"' {
|
if item[0] != '"' {
|
||||||
if fromQuoted {
|
if fromQuoted {
|
||||||
|
|
|
@ -12,7 +12,6 @@ package json
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
|
||||||
"encoding"
|
"encoding"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
@ -157,11 +156,7 @@ import (
|
||||||
// handle them. Passing cyclic structures to Marshal will result in
|
// handle them. Passing cyclic structures to Marshal will result in
|
||||||
// an error.
|
// an error.
|
||||||
func Marshal(v any) ([]byte, error) {
|
func Marshal(v any) ([]byte, error) {
|
||||||
return MarshalContext(context.Background(), v)
|
e := newEncodeState()
|
||||||
}
|
|
||||||
|
|
||||||
func MarshalContext(ctx context.Context, v any) ([]byte, error) {
|
|
||||||
e := newEncodeState(ctx)
|
|
||||||
defer encodeStatePool.Put(e)
|
defer encodeStatePool.Put(e)
|
||||||
|
|
||||||
err := e.marshal(v, encOpts{escapeHTML: true})
|
err := e.marshal(v, encOpts{escapeHTML: true})
|
||||||
|
@ -256,7 +251,6 @@ var hex = "0123456789abcdef"
|
||||||
type encodeState struct {
|
type encodeState struct {
|
||||||
bytes.Buffer // accumulated output
|
bytes.Buffer // accumulated output
|
||||||
|
|
||||||
ctx context.Context
|
|
||||||
// Keep track of what pointers we've seen in the current recursive call
|
// Keep track of what pointers we've seen in the current recursive call
|
||||||
// path, to avoid cycles that could lead to a stack overflow. Only do
|
// path, to avoid cycles that could lead to a stack overflow. Only do
|
||||||
// the relatively expensive map operations if ptrLevel is larger than
|
// the relatively expensive map operations if ptrLevel is larger than
|
||||||
|
@ -270,7 +264,7 @@ const startDetectingCyclesAfter = 1000
|
||||||
|
|
||||||
var encodeStatePool sync.Pool
|
var encodeStatePool sync.Pool
|
||||||
|
|
||||||
func newEncodeState(ctx context.Context) *encodeState {
|
func newEncodeState() *encodeState {
|
||||||
if v := encodeStatePool.Get(); v != nil {
|
if v := encodeStatePool.Get(); v != nil {
|
||||||
e := v.(*encodeState)
|
e := v.(*encodeState)
|
||||||
e.Reset()
|
e.Reset()
|
||||||
|
@ -280,7 +274,7 @@ func newEncodeState(ctx context.Context) *encodeState {
|
||||||
e.ptrLevel = 0
|
e.ptrLevel = 0
|
||||||
return e
|
return e
|
||||||
}
|
}
|
||||||
return &encodeState{ctx: ctx, ptrSeen: make(map[any]struct{})}
|
return &encodeState{ptrSeen: make(map[any]struct{})}
|
||||||
}
|
}
|
||||||
|
|
||||||
// jsonError is an error wrapper type for internal use only.
|
// jsonError is an error wrapper type for internal use only.
|
||||||
|
@ -377,9 +371,8 @@ func typeEncoder(t reflect.Type) encoderFunc {
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
marshalerType = reflect.TypeOf((*Marshaler)(nil)).Elem()
|
marshalerType = reflect.TypeOf((*Marshaler)(nil)).Elem()
|
||||||
contextMarshalerType = reflect.TypeOf((*ContextMarshaler)(nil)).Elem()
|
textMarshalerType = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem()
|
||||||
textMarshalerType = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem()
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// newTypeEncoder constructs an encoderFunc for a type.
|
// newTypeEncoder constructs an encoderFunc for a type.
|
||||||
|
@ -392,15 +385,9 @@ func newTypeEncoder(t reflect.Type, allowAddr bool) encoderFunc {
|
||||||
if t.Kind() != reflect.Pointer && allowAddr && reflect.PointerTo(t).Implements(marshalerType) {
|
if t.Kind() != reflect.Pointer && allowAddr && reflect.PointerTo(t).Implements(marshalerType) {
|
||||||
return newCondAddrEncoder(addrMarshalerEncoder, newTypeEncoder(t, false))
|
return newCondAddrEncoder(addrMarshalerEncoder, newTypeEncoder(t, false))
|
||||||
}
|
}
|
||||||
if t.Kind() != reflect.Pointer && allowAddr && reflect.PointerTo(t).Implements(contextMarshalerType) {
|
|
||||||
return newCondAddrEncoder(addrContextMarshalerEncoder, newTypeEncoder(t, false))
|
|
||||||
}
|
|
||||||
if t.Implements(marshalerType) {
|
if t.Implements(marshalerType) {
|
||||||
return marshalerEncoder
|
return marshalerEncoder
|
||||||
}
|
}
|
||||||
if t.Implements(contextMarshalerType) {
|
|
||||||
return contextMarshalerEncoder
|
|
||||||
}
|
|
||||||
if t.Kind() != reflect.Pointer && allowAddr && reflect.PointerTo(t).Implements(textMarshalerType) {
|
if t.Kind() != reflect.Pointer && allowAddr && reflect.PointerTo(t).Implements(textMarshalerType) {
|
||||||
return newCondAddrEncoder(addrTextMarshalerEncoder, newTypeEncoder(t, false))
|
return newCondAddrEncoder(addrTextMarshalerEncoder, newTypeEncoder(t, false))
|
||||||
}
|
}
|
||||||
|
@ -483,47 +470,6 @@ func addrMarshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func contextMarshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) {
|
|
||||||
if v.Kind() == reflect.Pointer && v.IsNil() {
|
|
||||||
e.WriteString("null")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
m, ok := v.Interface().(ContextMarshaler)
|
|
||||||
if !ok {
|
|
||||||
e.WriteString("null")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
b, err := m.MarshalJSONContext(e.ctx)
|
|
||||||
if err == nil {
|
|
||||||
e.Grow(len(b))
|
|
||||||
out := availableBuffer(&e.Buffer)
|
|
||||||
out, err = appendCompact(out, b, opts.escapeHTML)
|
|
||||||
e.Buffer.Write(out)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
e.error(&MarshalerError{v.Type(), err, "MarshalJSON"})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func addrContextMarshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) {
|
|
||||||
va := v.Addr()
|
|
||||||
if va.IsNil() {
|
|
||||||
e.WriteString("null")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
m := va.Interface().(ContextMarshaler)
|
|
||||||
b, err := m.MarshalJSONContext(e.ctx)
|
|
||||||
if err == nil {
|
|
||||||
e.Grow(len(b))
|
|
||||||
out := availableBuffer(&e.Buffer)
|
|
||||||
out, err = appendCompact(out, b, opts.escapeHTML)
|
|
||||||
e.Buffer.Write(out)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
e.error(&MarshalerError{v.Type(), err, "MarshalJSON"})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func textMarshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) {
|
func textMarshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) {
|
||||||
if v.Kind() == reflect.Pointer && v.IsNil() {
|
if v.Kind() == reflect.Pointer && v.IsNil() {
|
||||||
e.WriteString("null")
|
e.WriteString("null")
|
||||||
|
@ -881,7 +827,7 @@ func newSliceEncoder(t reflect.Type) encoderFunc {
|
||||||
// Byte slices get special treatment; arrays don't.
|
// Byte slices get special treatment; arrays don't.
|
||||||
if t.Elem().Kind() == reflect.Uint8 {
|
if t.Elem().Kind() == reflect.Uint8 {
|
||||||
p := reflect.PointerTo(t.Elem())
|
p := reflect.PointerTo(t.Elem())
|
||||||
if !p.Implements(marshalerType) && !p.Implements(contextMarshalerType) && !p.Implements(textMarshalerType) {
|
if !p.Implements(marshalerType) && !p.Implements(textMarshalerType) {
|
||||||
return encodeByteSlice
|
return encodeByteSlice
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,20 +0,0 @@
|
||||||
package json
|
|
||||||
|
|
||||||
import (
|
|
||||||
"reflect"
|
|
||||||
|
|
||||||
"github.com/sagernet/sing/common"
|
|
||||||
)
|
|
||||||
|
|
||||||
func ObjectKeys(object reflect.Type) []string {
|
|
||||||
switch object.Kind() {
|
|
||||||
case reflect.Pointer:
|
|
||||||
return ObjectKeys(object.Elem())
|
|
||||||
case reflect.Struct:
|
|
||||||
default:
|
|
||||||
panic("invalid non-struct input")
|
|
||||||
}
|
|
||||||
return common.Map(cachedTypeFields(object).list, func(field field) string {
|
|
||||||
return field.name
|
|
||||||
})
|
|
||||||
}
|
|
|
@ -1,26 +0,0 @@
|
||||||
package json_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"reflect"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
json "github.com/sagernet/sing/common/json/internal/contextjson"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
type MyObject struct {
|
|
||||||
Hello string `json:"hello,omitempty"`
|
|
||||||
MyWorld
|
|
||||||
MyWorld2 string `json:"-"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type MyWorld struct {
|
|
||||||
World string `json:"world,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestObjectKeys(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
keys := json.ObjectKeys(reflect.TypeOf(&MyObject{}))
|
|
||||||
require.Equal(t, []string{"hello", "world"}, keys)
|
|
||||||
}
|
|
|
@ -6,7 +6,6 @@ package json
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
|
||||||
"errors"
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
)
|
)
|
||||||
|
@ -30,11 +29,7 @@ type Decoder struct {
|
||||||
// The decoder introduces its own buffering and may
|
// The decoder introduces its own buffering and may
|
||||||
// read data from r beyond the JSON values requested.
|
// read data from r beyond the JSON values requested.
|
||||||
func NewDecoder(r io.Reader) *Decoder {
|
func NewDecoder(r io.Reader) *Decoder {
|
||||||
return NewDecoderContext(context.Background(), r)
|
return &Decoder{r: r}
|
||||||
}
|
|
||||||
|
|
||||||
func NewDecoderContext(ctx context.Context, r io.Reader) *Decoder {
|
|
||||||
return &Decoder{r: r, d: decodeState{ctx: ctx}}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// UseNumber causes the Decoder to unmarshal a number into an interface{} as a
|
// UseNumber causes the Decoder to unmarshal a number into an interface{} as a
|
||||||
|
@ -188,7 +183,6 @@ func nonSpace(b []byte) bool {
|
||||||
|
|
||||||
// An Encoder writes JSON values to an output stream.
|
// An Encoder writes JSON values to an output stream.
|
||||||
type Encoder struct {
|
type Encoder struct {
|
||||||
ctx context.Context
|
|
||||||
w io.Writer
|
w io.Writer
|
||||||
err error
|
err error
|
||||||
escapeHTML bool
|
escapeHTML bool
|
||||||
|
@ -200,11 +194,7 @@ type Encoder struct {
|
||||||
|
|
||||||
// NewEncoder returns a new encoder that writes to w.
|
// NewEncoder returns a new encoder that writes to w.
|
||||||
func NewEncoder(w io.Writer) *Encoder {
|
func NewEncoder(w io.Writer) *Encoder {
|
||||||
return NewEncoderContext(context.Background(), w)
|
return &Encoder{w: w, escapeHTML: true}
|
||||||
}
|
|
||||||
|
|
||||||
func NewEncoderContext(ctx context.Context, w io.Writer) *Encoder {
|
|
||||||
return &Encoder{ctx: ctx, w: w, escapeHTML: true}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Encode writes the JSON encoding of v to the stream,
|
// Encode writes the JSON encoding of v to the stream,
|
||||||
|
@ -217,7 +207,7 @@ func (enc *Encoder) Encode(v any) error {
|
||||||
return enc.err
|
return enc.err
|
||||||
}
|
}
|
||||||
|
|
||||||
e := newEncodeState(enc.ctx)
|
e := newEncodeState()
|
||||||
defer encodeStatePool.Put(e)
|
defer encodeStatePool.Put(e)
|
||||||
|
|
||||||
err := e.marshal(v, encOpts{escapeHTML: enc.escapeHTML})
|
err := e.marshal(v, encOpts{escapeHTML: enc.escapeHTML})
|
||||||
|
|
|
@ -1,26 +0,0 @@
|
||||||
package json
|
|
||||||
|
|
||||||
import "context"
|
|
||||||
|
|
||||||
func UnmarshalDisallowUnknownFields(data []byte, v any) error {
|
|
||||||
var d decodeState
|
|
||||||
d.disallowUnknownFields = true
|
|
||||||
err := checkValid(data, &d.scan)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
d.init(data)
|
|
||||||
return d.unmarshal(v)
|
|
||||||
}
|
|
||||||
|
|
||||||
func UnmarshalContextDisallowUnknownFields(ctx context.Context, data []byte, v any) error {
|
|
||||||
var d decodeState
|
|
||||||
d.ctx = ctx
|
|
||||||
d.disallowUnknownFields = true
|
|
||||||
err := checkValid(data, &d.scan)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
d.init(data)
|
|
||||||
return d.unmarshal(v)
|
|
||||||
}
|
|
|
@ -2,8 +2,6 @@ package json
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/sagernet/sing/common"
|
"github.com/sagernet/sing/common"
|
||||||
|
@ -11,18 +9,13 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func UnmarshalExtended[T any](content []byte) (T, error) {
|
func UnmarshalExtended[T any](content []byte) (T, error) {
|
||||||
return UnmarshalExtendedContext[T](context.Background(), content)
|
decoder := NewDecoder(NewCommentFilter(bytes.NewReader(content)))
|
||||||
}
|
|
||||||
|
|
||||||
func UnmarshalExtendedContext[T any](ctx context.Context, content []byte) (T, error) {
|
|
||||||
decoder := NewDecoderContext(ctx, NewCommentFilter(bytes.NewReader(content)))
|
|
||||||
var value T
|
var value T
|
||||||
err := decoder.Decode(&value)
|
err := decoder.Decode(&value)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return value, err
|
return value, err
|
||||||
}
|
}
|
||||||
var syntaxError *SyntaxError
|
if syntaxError, isSyntaxError := err.(*SyntaxError); isSyntaxError {
|
||||||
if errors.As(err, &syntaxError) {
|
|
||||||
prefix := string(content[:syntaxError.Offset])
|
prefix := string(content[:syntaxError.Offset])
|
||||||
row := strings.Count(prefix, "\n") + 1
|
row := strings.Count(prefix, "\n") + 1
|
||||||
column := len(prefix) - strings.LastIndex(prefix, "\n") - 1
|
column := len(prefix) - strings.LastIndex(prefix, "\n") - 1
|
||||||
|
|
|
@ -1,9 +0,0 @@
|
||||||
//go:build go1.20 && !without_contextjson
|
|
||||||
|
|
||||||
package json
|
|
||||||
|
|
||||||
import (
|
|
||||||
json "github.com/sagernet/sing/common/json/internal/contextjson"
|
|
||||||
)
|
|
||||||
|
|
||||||
var UnmarshalDisallowUnknownFields = json.UnmarshalDisallowUnknownFields
|
|
|
@ -1,13 +0,0 @@
|
||||||
//go:build !go1.20 || without_contextjson
|
|
||||||
|
|
||||||
package json
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
)
|
|
||||||
|
|
||||||
func UnmarshalDisallowUnknownFields(content []byte, value any) error {
|
|
||||||
decoder := NewDecoder(bytes.NewReader(content))
|
|
||||||
decoder.DisallowUnknownFields()
|
|
||||||
return decoder.Decode(value)
|
|
||||||
}
|
|
|
@ -1,6 +1,5 @@
|
||||||
package metadata
|
package metadata
|
||||||
|
|
||||||
// Deprecated: wtf is this?
|
|
||||||
type Metadata struct {
|
type Metadata struct {
|
||||||
Protocol string
|
Protocol string
|
||||||
Source Socksaddr
|
Source Socksaddr
|
||||||
|
|
|
@ -8,6 +8,7 @@ import (
|
||||||
"github.com/sagernet/sing/common"
|
"github.com/sagernet/sing/common"
|
||||||
"github.com/sagernet/sing/common/buf"
|
"github.com/sagernet/sing/common/buf"
|
||||||
E "github.com/sagernet/sing/common/exceptions"
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
|
"github.com/sagernet/sing/common/rw"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -115,7 +116,7 @@ func (s *Serializer) WriteAddrPort(writer io.Writer, destination Socksaddr) erro
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if !isBuffer {
|
if !isBuffer {
|
||||||
err = common.Error(writer.Write(buffer.Bytes()))
|
err = rw.WriteBytes(writer, buffer.Bytes())
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -129,8 +130,7 @@ func (s *Serializer) AddrPortLen(destination Socksaddr) int {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Serializer) ReadAddress(reader io.Reader) (Socksaddr, error) {
|
func (s *Serializer) ReadAddress(reader io.Reader) (Socksaddr, error) {
|
||||||
var af byte
|
af, err := rw.ReadByte(reader)
|
||||||
err := binary.Read(reader, binary.BigEndian, &af)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return Socksaddr{}, err
|
return Socksaddr{}, err
|
||||||
}
|
}
|
||||||
|
@ -164,12 +164,11 @@ func (s *Serializer) ReadAddress(reader io.Reader) (Socksaddr, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Serializer) ReadPort(reader io.Reader) (uint16, error) {
|
func (s *Serializer) ReadPort(reader io.Reader) (uint16, error) {
|
||||||
var port uint16
|
port, err := rw.ReadBytes(reader, 2)
|
||||||
err := binary.Read(reader, binary.BigEndian, &port)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, E.Cause(err, "read port")
|
return 0, E.Cause(err, "read port")
|
||||||
}
|
}
|
||||||
return port, nil
|
return binary.BigEndian.Uint16(port), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Serializer) ReadAddrPort(reader io.Reader) (destination Socksaddr, err error) {
|
func (s *Serializer) ReadAddrPort(reader io.Reader) (destination Socksaddr, err error) {
|
||||||
|
@ -196,17 +195,11 @@ func (s *Serializer) ReadAddrPort(reader io.Reader) (destination Socksaddr, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func ReadSockString(reader io.Reader) (string, error) {
|
func ReadSockString(reader io.Reader) (string, error) {
|
||||||
var strLen byte
|
strLen, err := rw.ReadByte(reader)
|
||||||
err := binary.Read(reader, binary.BigEndian, &strLen)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
strBytes := make([]byte, strLen)
|
return rw.ReadString(reader, int(strLen))
|
||||||
_, err = io.ReadFull(reader, strBytes)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
return string(strBytes), nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func WriteSocksString(buffer *buf.Buffer, str string) error {
|
func WriteSocksString(buffer *buf.Buffer, str string) error {
|
||||||
|
|
|
@ -1,15 +0,0 @@
|
||||||
//go:build go1.21
|
|
||||||
|
|
||||||
package common
|
|
||||||
|
|
||||||
import (
|
|
||||||
"cmp"
|
|
||||||
)
|
|
||||||
|
|
||||||
func Min[T cmp.Ordered](x, y T) T {
|
|
||||||
return min(x, y)
|
|
||||||
}
|
|
||||||
|
|
||||||
func Max[T cmp.Ordered](x, y T) T {
|
|
||||||
return max(x, y)
|
|
||||||
}
|
|
|
@ -1,19 +0,0 @@
|
||||||
//go:build go1.20 && !go1.21
|
|
||||||
|
|
||||||
package common
|
|
||||||
|
|
||||||
import "github.com/sagernet/sing/common/x/constraints"
|
|
||||||
|
|
||||||
func Min[T constraints.Ordered](x, y T) T {
|
|
||||||
if x < y {
|
|
||||||
return x
|
|
||||||
}
|
|
||||||
return y
|
|
||||||
}
|
|
||||||
|
|
||||||
func Max[T constraints.Ordered](x, y T) T {
|
|
||||||
if x < y {
|
|
||||||
return y
|
|
||||||
}
|
|
||||||
return x
|
|
||||||
}
|
|
|
@ -4,7 +4,6 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/sagernet/sing/common"
|
"github.com/sagernet/sing/common"
|
||||||
|
@ -71,39 +70,8 @@ type ExtendedConn interface {
|
||||||
net.Conn
|
net.Conn
|
||||||
}
|
}
|
||||||
|
|
||||||
type CloseHandlerFunc = func(it error)
|
|
||||||
|
|
||||||
func AppendClose(parent CloseHandlerFunc, onClose CloseHandlerFunc) CloseHandlerFunc {
|
|
||||||
if onClose == nil {
|
|
||||||
panic("nil onClose")
|
|
||||||
}
|
|
||||||
if parent == nil {
|
|
||||||
return onClose
|
|
||||||
}
|
|
||||||
return func(it error) {
|
|
||||||
onClose(it)
|
|
||||||
parent(it)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func OnceClose(onClose CloseHandlerFunc) CloseHandlerFunc {
|
|
||||||
var once sync.Once
|
|
||||||
return func(it error) {
|
|
||||||
once.Do(func() {
|
|
||||||
onClose(it)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Deprecated: Use TCPConnectionHandlerEx instead.
|
|
||||||
type TCPConnectionHandler interface {
|
type TCPConnectionHandler interface {
|
||||||
NewConnection(ctx context.Context, conn net.Conn,
|
NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error
|
||||||
//nolint:staticcheck
|
|
||||||
metadata M.Metadata) error
|
|
||||||
}
|
|
||||||
|
|
||||||
type TCPConnectionHandlerEx interface {
|
|
||||||
NewConnectionEx(ctx context.Context, conn net.Conn, source M.Socksaddr, destination M.Socksaddr, onClose CloseHandlerFunc)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type NetPacketConn interface {
|
type NetPacketConn interface {
|
||||||
|
@ -117,26 +85,12 @@ type BindPacketConn interface {
|
||||||
net.Conn
|
net.Conn
|
||||||
}
|
}
|
||||||
|
|
||||||
// Deprecated: Use UDPHandlerEx instead.
|
|
||||||
type UDPHandler interface {
|
type UDPHandler interface {
|
||||||
NewPacket(ctx context.Context, conn PacketConn, buffer *buf.Buffer,
|
NewPacket(ctx context.Context, conn PacketConn, buffer *buf.Buffer, metadata M.Metadata) error
|
||||||
//nolint:staticcheck
|
|
||||||
metadata M.Metadata) error
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type UDPHandlerEx interface {
|
|
||||||
NewPacketEx(buffer *buf.Buffer, source M.Socksaddr)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Deprecated: Use UDPConnectionHandlerEx instead.
|
|
||||||
type UDPConnectionHandler interface {
|
type UDPConnectionHandler interface {
|
||||||
NewPacketConnection(ctx context.Context, conn PacketConn,
|
NewPacketConnection(ctx context.Context, conn PacketConn, metadata M.Metadata) error
|
||||||
//nolint:staticcheck
|
|
||||||
metadata M.Metadata) error
|
|
||||||
}
|
|
||||||
|
|
||||||
type UDPConnectionHandlerEx interface {
|
|
||||||
NewPacketConnectionEx(ctx context.Context, conn PacketConn, source M.Socksaddr, destination M.Socksaddr, onClose CloseHandlerFunc)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type CachedReader interface {
|
type CachedReader interface {
|
||||||
|
@ -147,6 +101,11 @@ type CachedPacketReader interface {
|
||||||
ReadCachedPacket() *PacketBuffer
|
ReadCachedPacket() *PacketBuffer
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type PacketBuffer struct {
|
||||||
|
Buffer *buf.Buffer
|
||||||
|
Destination M.Socksaddr
|
||||||
|
}
|
||||||
|
|
||||||
type WithUpstreamReader interface {
|
type WithUpstreamReader interface {
|
||||||
UpstreamReader() any
|
UpstreamReader() any
|
||||||
}
|
}
|
||||||
|
|
|
@ -13,6 +13,10 @@ type Dialer interface {
|
||||||
ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error)
|
ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type PayloadDialer interface {
|
||||||
|
DialPayloadContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error)
|
||||||
|
}
|
||||||
|
|
||||||
type ParallelDialer interface {
|
type ParallelDialer interface {
|
||||||
Dialer
|
Dialer
|
||||||
DialParallel(ctx context.Context, network string, destination M.Socksaddr, destinationAddresses []netip.Addr) (net.Conn, error)
|
DialParallel(ctx context.Context, network string, destination M.Socksaddr, destinationAddresses []netip.Addr) (net.Conn, error)
|
||||||
|
|
|
@ -15,39 +15,19 @@ type ReadWaitOptions struct {
|
||||||
MTU int
|
MTU int
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewReadWaitOptions(source any, destination any) ReadWaitOptions {
|
|
||||||
return ReadWaitOptions{
|
|
||||||
FrontHeadroom: CalculateFrontHeadroom(destination),
|
|
||||||
RearHeadroom: CalculateRearHeadroom(destination),
|
|
||||||
MTU: CalculateMTU(source, destination),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (o ReadWaitOptions) NeedHeadroom() bool {
|
func (o ReadWaitOptions) NeedHeadroom() bool {
|
||||||
return o.FrontHeadroom > 0 || o.RearHeadroom > 0
|
return o.FrontHeadroom > 0 || o.RearHeadroom > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o ReadWaitOptions) Copy(buffer *buf.Buffer) *buf.Buffer {
|
|
||||||
if o.FrontHeadroom > buffer.Start() ||
|
|
||||||
o.RearHeadroom > buffer.FreeLen() {
|
|
||||||
newBuffer := o.newBuffer(buf.UDPBufferSize, false)
|
|
||||||
newBuffer.Write(buffer.Bytes())
|
|
||||||
buffer.Release()
|
|
||||||
return newBuffer
|
|
||||||
} else {
|
|
||||||
return buffer
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (o ReadWaitOptions) NewBuffer() *buf.Buffer {
|
func (o ReadWaitOptions) NewBuffer() *buf.Buffer {
|
||||||
return o.newBuffer(buf.BufferSize, true)
|
return o.newBuffer(buf.BufferSize)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o ReadWaitOptions) NewPacketBuffer() *buf.Buffer {
|
func (o ReadWaitOptions) NewPacketBuffer() *buf.Buffer {
|
||||||
return o.newBuffer(buf.UDPBufferSize, true)
|
return o.newBuffer(buf.UDPBufferSize)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o ReadWaitOptions) newBuffer(defaultBufferSize int, reserve bool) *buf.Buffer {
|
func (o ReadWaitOptions) newBuffer(defaultBufferSize int) *buf.Buffer {
|
||||||
var bufferSize int
|
var bufferSize int
|
||||||
if o.MTU > 0 {
|
if o.MTU > 0 {
|
||||||
bufferSize = o.MTU + o.FrontHeadroom + o.RearHeadroom
|
bufferSize = o.MTU + o.FrontHeadroom + o.RearHeadroom
|
||||||
|
@ -58,7 +38,7 @@ func (o ReadWaitOptions) newBuffer(defaultBufferSize int, reserve bool) *buf.Buf
|
||||||
if o.FrontHeadroom > 0 {
|
if o.FrontHeadroom > 0 {
|
||||||
buffer.Resize(o.FrontHeadroom, 0)
|
buffer.Resize(o.FrontHeadroom, 0)
|
||||||
}
|
}
|
||||||
if o.RearHeadroom > 0 && reserve {
|
if o.RearHeadroom > 0 {
|
||||||
buffer.Reserve(o.RearHeadroom)
|
buffer.Reserve(o.RearHeadroom)
|
||||||
}
|
}
|
||||||
return buffer
|
return buffer
|
||||||
|
|
|
@ -1,9 +1,6 @@
|
||||||
package network
|
package network
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"io"
|
|
||||||
"net"
|
|
||||||
|
|
||||||
"github.com/sagernet/sing/common"
|
"github.com/sagernet/sing/common"
|
||||||
E "github.com/sagernet/sing/common/exceptions"
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
)
|
)
|
||||||
|
@ -16,75 +13,17 @@ type HandshakeSuccess interface {
|
||||||
HandshakeSuccess() error
|
HandshakeSuccess() error
|
||||||
}
|
}
|
||||||
|
|
||||||
type ConnHandshakeSuccess interface {
|
func ReportHandshakeFailure(conn any, err error) error {
|
||||||
ConnHandshakeSuccess(conn net.Conn) error
|
if handshakeConn, isHandshakeConn := common.Cast[HandshakeFailure](conn); isHandshakeConn {
|
||||||
}
|
|
||||||
|
|
||||||
type PacketConnHandshakeSuccess interface {
|
|
||||||
PacketConnHandshakeSuccess(conn net.PacketConn) error
|
|
||||||
}
|
|
||||||
|
|
||||||
func ReportHandshakeFailure(reporter any, err error) error {
|
|
||||||
if handshakeConn, isHandshakeConn := common.Cast[HandshakeFailure](reporter); isHandshakeConn {
|
|
||||||
return E.Append(err, handshakeConn.HandshakeFailure(err), func(err error) error {
|
return E.Append(err, handshakeConn.HandshakeFailure(err), func(err error) error {
|
||||||
return E.Cause(err, "write handshake failure")
|
return E.Cause(err, "write handshake failure")
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func CloseOnHandshakeFailure(reporter io.Closer, onClose CloseHandlerFunc, err error) error {
|
|
||||||
if err != nil {
|
|
||||||
if handshakeConn, isHandshakeConn := common.Cast[HandshakeFailure](reporter); isHandshakeConn {
|
|
||||||
hErr := handshakeConn.HandshakeFailure(err)
|
|
||||||
err = E.Append(err, hErr, func(err error) error {
|
|
||||||
if closer, isCloser := reporter.(io.Closer); isCloser {
|
|
||||||
err = E.Append(err, closer.Close(), func(err error) error {
|
|
||||||
return E.Cause(err, "close")
|
|
||||||
})
|
|
||||||
}
|
|
||||||
return E.Cause(err, "write handshake failure")
|
|
||||||
})
|
|
||||||
} else {
|
|
||||||
if tcpConn, isTCPConn := common.Cast[interface {
|
|
||||||
SetLinger(sec int) error
|
|
||||||
}](reporter); isTCPConn {
|
|
||||||
tcpConn.SetLinger(0)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
err = E.Append(err, reporter.Close(), func(err error) error {
|
|
||||||
return E.Cause(err, "close")
|
|
||||||
})
|
|
||||||
}
|
|
||||||
if onClose != nil {
|
|
||||||
onClose(err)
|
|
||||||
}
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Deprecated: use ReportConnHandshakeSuccess/ReportPacketConnHandshakeSuccess instead
|
func ReportHandshakeSuccess(conn any) error {
|
||||||
func ReportHandshakeSuccess(reporter any) error {
|
if handshakeConn, isHandshakeConn := common.Cast[HandshakeSuccess](conn); isHandshakeConn {
|
||||||
if handshakeConn, isHandshakeConn := common.Cast[HandshakeSuccess](reporter); isHandshakeConn {
|
|
||||||
return handshakeConn.HandshakeSuccess()
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func ReportConnHandshakeSuccess(reporter any, conn net.Conn) error {
|
|
||||||
if handshakeConn, isHandshakeConn := common.Cast[ConnHandshakeSuccess](reporter); isHandshakeConn {
|
|
||||||
return handshakeConn.ConnHandshakeSuccess(conn)
|
|
||||||
}
|
|
||||||
if handshakeConn, isHandshakeConn := common.Cast[HandshakeSuccess](reporter); isHandshakeConn {
|
|
||||||
return handshakeConn.HandshakeSuccess()
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func ReportPacketConnHandshakeSuccess(reporter any, conn net.PacketConn) error {
|
|
||||||
if handshakeConn, isHandshakeConn := common.Cast[PacketConnHandshakeSuccess](reporter); isHandshakeConn {
|
|
||||||
return handshakeConn.PacketConnHandshakeSuccess(conn)
|
|
||||||
}
|
|
||||||
if handshakeConn, isHandshakeConn := common.Cast[HandshakeSuccess](reporter); isHandshakeConn {
|
|
||||||
return handshakeConn.HandshakeSuccess()
|
return handshakeConn.HandshakeSuccess()
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|
|
@ -1,35 +0,0 @@
|
||||||
package network
|
|
||||||
|
|
||||||
import (
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
"github.com/sagernet/sing/common/buf"
|
|
||||||
M "github.com/sagernet/sing/common/metadata"
|
|
||||||
)
|
|
||||||
|
|
||||||
type PacketBuffer struct {
|
|
||||||
Buffer *buf.Buffer
|
|
||||||
Destination M.Socksaddr
|
|
||||||
}
|
|
||||||
|
|
||||||
var packetPool = sync.Pool{
|
|
||||||
New: func() any {
|
|
||||||
return new(PacketBuffer)
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewPacketBuffer() *PacketBuffer {
|
|
||||||
return packetPool.Get().(*PacketBuffer)
|
|
||||||
}
|
|
||||||
|
|
||||||
func PutPacketBuffer(packet *PacketBuffer) {
|
|
||||||
*packet = PacketBuffer{}
|
|
||||||
packetPool.Put(packet)
|
|
||||||
}
|
|
||||||
|
|
||||||
func ReleaseMultiPacketBuffer(packetBuffers []*PacketBuffer) {
|
|
||||||
for _, packet := range packetBuffers {
|
|
||||||
packet.Buffer.Release()
|
|
||||||
PutPacketBuffer(packet)
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -11,7 +11,6 @@ type ThreadUnsafeWriter interface {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Deprecated: Use ReadWaiter interface instead.
|
// Deprecated: Use ReadWaiter interface instead.
|
||||||
|
|
||||||
type ThreadSafeReader interface {
|
type ThreadSafeReader interface {
|
||||||
// Deprecated: Use ReadWaiter interface instead.
|
// Deprecated: Use ReadWaiter interface instead.
|
||||||
ReadBufferThreadSafe() (buffer *buf.Buffer, err error)
|
ReadBufferThreadSafe() (buffer *buf.Buffer, err error)
|
||||||
|
@ -19,6 +18,7 @@ type ThreadSafeReader interface {
|
||||||
|
|
||||||
// Deprecated: Use ReadWaiter interface instead.
|
// Deprecated: Use ReadWaiter interface instead.
|
||||||
type ThreadSafePacketReader interface {
|
type ThreadSafePacketReader interface {
|
||||||
|
// Deprecated: Use ReadWaiter interface instead.
|
||||||
ReadPacketThreadSafe() (buffer *buf.Buffer, addr M.Socksaddr, err error)
|
ReadPacketThreadSafe() (buffer *buf.Buffer, addr M.Socksaddr, err error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -10,7 +10,6 @@ import (
|
||||||
"github.com/sagernet/sing/common/logger"
|
"github.com/sagernet/sing/common/logger"
|
||||||
M "github.com/sagernet/sing/common/metadata"
|
M "github.com/sagernet/sing/common/metadata"
|
||||||
N "github.com/sagernet/sing/common/network"
|
N "github.com/sagernet/sing/common/network"
|
||||||
"github.com/sagernet/sing/common/x/list"
|
|
||||||
"github.com/sagernet/sing/service"
|
"github.com/sagernet/sing/service"
|
||||||
"github.com/sagernet/sing/service/pause"
|
"github.com/sagernet/sing/service/pause"
|
||||||
)
|
)
|
||||||
|
@ -27,7 +26,6 @@ type Options struct {
|
||||||
Logger logger.Logger
|
Logger logger.Logger
|
||||||
Server M.Socksaddr
|
Server M.Socksaddr
|
||||||
Interval time.Duration
|
Interval time.Duration
|
||||||
Timeout time.Duration
|
|
||||||
WriteToSystem bool
|
WriteToSystem bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -41,11 +39,8 @@ type Service struct {
|
||||||
server M.Socksaddr
|
server M.Socksaddr
|
||||||
writeToSystem bool
|
writeToSystem bool
|
||||||
ticker *time.Ticker
|
ticker *time.Ticker
|
||||||
interval time.Duration
|
|
||||||
timeout time.Duration
|
|
||||||
clockOffset time.Duration
|
clockOffset time.Duration
|
||||||
pause pause.Manager
|
pause pause.Manager
|
||||||
pauseCallback *list.Element[pause.Callback]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewService(options Options) *Service {
|
func NewService(options Options) *Service {
|
||||||
|
@ -85,8 +80,7 @@ func NewService(options Options) *Service {
|
||||||
logger: options.Logger,
|
logger: options.Logger,
|
||||||
writeToSystem: options.WriteToSystem,
|
writeToSystem: options.WriteToSystem,
|
||||||
server: destination,
|
server: destination,
|
||||||
interval: interval,
|
ticker: time.NewTicker(interval),
|
||||||
timeout: options.Timeout,
|
|
||||||
pause: service.FromContext[pause.Manager](ctx),
|
pause: service.FromContext[pause.Manager](ctx),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -94,15 +88,10 @@ func NewService(options Options) *Service {
|
||||||
func (s *Service) Start() error {
|
func (s *Service) Start() error {
|
||||||
err := s.update()
|
err := s.update()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.logger.Error(E.Cause(err, "initialize time"))
|
return E.Cause(err, "initialize time")
|
||||||
} else {
|
|
||||||
s.logger.Info("updated time: ", s.TimeFunc()().Local().Format(TimeLayout))
|
|
||||||
}
|
}
|
||||||
s.ticker = time.NewTicker(s.interval)
|
s.logger.Info("updated time: ", s.TimeFunc()().Local().Format(TimeLayout))
|
||||||
go s.loopUpdate()
|
go s.loopUpdate()
|
||||||
if s.pause != nil {
|
|
||||||
s.pauseCallback = pause.RegisterTicker(s.pause, s.ticker, s.interval, s.updateOnce)
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -125,29 +114,25 @@ func (s *Service) loopUpdate() {
|
||||||
return
|
return
|
||||||
case <-s.ticker.C:
|
case <-s.ticker.C:
|
||||||
}
|
}
|
||||||
s.updateOnce()
|
if s.pause != nil {
|
||||||
}
|
s.pause.WaitActive()
|
||||||
}
|
select {
|
||||||
|
case <-s.ctx.Done():
|
||||||
func (s *Service) updateOnce() {
|
return
|
||||||
err := s.update()
|
default:
|
||||||
if err == nil {
|
}
|
||||||
s.logger.Info("updated time: ", s.TimeFunc()().Local().Format(TimeLayout))
|
}
|
||||||
} else {
|
err := s.update()
|
||||||
s.logger.Error("update time: ", err)
|
if err == nil {
|
||||||
|
s.logger.Debug("updated time: ", s.TimeFunc()().Local().Format(TimeLayout))
|
||||||
|
} else {
|
||||||
|
s.logger.Warn("update time: ", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service) update() error {
|
func (s *Service) update() error {
|
||||||
ctx := s.ctx
|
response, err := Exchange(s.ctx, s.dialer, s.server)
|
||||||
var cancel context.CancelFunc
|
|
||||||
if s.timeout > 0 {
|
|
||||||
ctx, cancel = context.WithTimeout(ctx, s.timeout)
|
|
||||||
}
|
|
||||||
response, err := Exchange(ctx, s.dialer, s.server)
|
|
||||||
if cancel != nil {
|
|
||||||
cancel()
|
|
||||||
}
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -155,7 +140,7 @@ func (s *Service) update() error {
|
||||||
if s.writeToSystem {
|
if s.writeToSystem {
|
||||||
writeErr := SetSystemTime(s.TimeFunc()())
|
writeErr := SetSystemTime(s.TimeFunc()())
|
||||||
if writeErr != nil {
|
if writeErr != nil {
|
||||||
s.logger.Error("write time to system: ", writeErr)
|
s.logger.Warn("write time to system: ", writeErr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|
|
@ -1,6 +0,0 @@
|
||||||
package ntp
|
|
||||||
|
|
||||||
//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go syscall_windows.go
|
|
||||||
|
|
||||||
// https://learn.microsoft.com/en-us/windows/win32/api/sysinfoapi/nf-sysinfoapi-setsystemtime
|
|
||||||
//sys setSystemTime(lpSystemTime *windows.Systemtime) (err error) = kernel32.SetSystemTime
|
|
|
@ -2,12 +2,12 @@ package ntp
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"time"
|
"time"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
"golang.org/x/sys/windows"
|
"golang.org/x/sys/windows"
|
||||||
)
|
)
|
||||||
|
|
||||||
func SetSystemTime(nowTime time.Time) error {
|
func SetSystemTime(nowTime time.Time) error {
|
||||||
nowTime = nowTime.UTC()
|
|
||||||
var systemTime windows.Systemtime
|
var systemTime windows.Systemtime
|
||||||
systemTime.Year = uint16(nowTime.Year())
|
systemTime.Year = uint16(nowTime.Year())
|
||||||
systemTime.Month = uint16(nowTime.Month())
|
systemTime.Month = uint16(nowTime.Month())
|
||||||
|
@ -16,5 +16,17 @@ func SetSystemTime(nowTime time.Time) error {
|
||||||
systemTime.Minute = uint16(nowTime.Minute())
|
systemTime.Minute = uint16(nowTime.Minute())
|
||||||
systemTime.Second = uint16(nowTime.Second())
|
systemTime.Second = uint16(nowTime.Second())
|
||||||
systemTime.Milliseconds = uint16(nowTime.UnixMilli() - nowTime.Unix()*1000)
|
systemTime.Milliseconds = uint16(nowTime.UnixMilli() - nowTime.Unix()*1000)
|
||||||
return setSystemTime(&systemTime)
|
|
||||||
|
dllKernel32 := windows.NewLazySystemDLL("kernel32.dll")
|
||||||
|
proc := dllKernel32.NewProc("SetSystemTime")
|
||||||
|
|
||||||
|
_, _, err := proc.Call(
|
||||||
|
uintptr(unsafe.Pointer(&systemTime)),
|
||||||
|
)
|
||||||
|
|
||||||
|
if err != nil && err.Error() != "The operation completed successfully." {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,52 +0,0 @@
|
||||||
// Code generated by 'go generate'; DO NOT EDIT.
|
|
||||||
|
|
||||||
package ntp
|
|
||||||
|
|
||||||
import (
|
|
||||||
"syscall"
|
|
||||||
"unsafe"
|
|
||||||
|
|
||||||
"golang.org/x/sys/windows"
|
|
||||||
)
|
|
||||||
|
|
||||||
var _ unsafe.Pointer
|
|
||||||
|
|
||||||
// Do the interface allocations only once for common
|
|
||||||
// Errno values.
|
|
||||||
const (
|
|
||||||
errnoERROR_IO_PENDING = 997
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
|
|
||||||
errERROR_EINVAL error = syscall.EINVAL
|
|
||||||
)
|
|
||||||
|
|
||||||
// errnoErr returns common boxed Errno values, to prevent
|
|
||||||
// allocations at runtime.
|
|
||||||
func errnoErr(e syscall.Errno) error {
|
|
||||||
switch e {
|
|
||||||
case 0:
|
|
||||||
return errERROR_EINVAL
|
|
||||||
case errnoERROR_IO_PENDING:
|
|
||||||
return errERROR_IO_PENDING
|
|
||||||
}
|
|
||||||
// TODO: add more here, after collecting data on the common
|
|
||||||
// error values see on Windows. (perhaps when running
|
|
||||||
// all.bat?)
|
|
||||||
return e
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
modkernel32 = windows.NewLazySystemDLL("kernel32.dll")
|
|
||||||
|
|
||||||
procSetSystemTime = modkernel32.NewProc("SetSystemTime")
|
|
||||||
)
|
|
||||||
|
|
||||||
func setSystemTime(lpSystemTime *windows.Systemtime) (err error) {
|
|
||||||
r1, _, e1 := syscall.Syscall(procSetSystemTime.Addr(), 1, uintptr(unsafe.Pointer(lpSystemTime)), 0, 0)
|
|
||||||
if r1 == 0 {
|
|
||||||
err = errnoErr(e1)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
|
@ -1,20 +0,0 @@
|
||||||
//go:build go1.21
|
|
||||||
|
|
||||||
package common
|
|
||||||
|
|
||||||
import "sync"
|
|
||||||
|
|
||||||
// OnceFunc is a wrapper around sync.OnceFunc.
|
|
||||||
func OnceFunc(f func()) func() {
|
|
||||||
return sync.OnceFunc(f)
|
|
||||||
}
|
|
||||||
|
|
||||||
// OnceValue is a wrapper around sync.OnceValue.
|
|
||||||
func OnceValue[T any](f func() T) func() T {
|
|
||||||
return sync.OnceValue(f)
|
|
||||||
}
|
|
||||||
|
|
||||||
// OnceValues is a wrapper around sync.OnceValues.
|
|
||||||
func OnceValues[T1, T2 any](f func() (T1, T2)) func() (T1, T2) {
|
|
||||||
return sync.OnceValues(f)
|
|
||||||
}
|
|
|
@ -1,104 +0,0 @@
|
||||||
// Copyright 2022 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
//go:build go1.20 && !go1.21
|
|
||||||
|
|
||||||
package common
|
|
||||||
|
|
||||||
import "sync"
|
|
||||||
|
|
||||||
// OnceFunc returns a function that invokes f only once. The returned function
|
|
||||||
// may be called concurrently.
|
|
||||||
//
|
|
||||||
// If f panics, the returned function will panic with the same value on every call.
|
|
||||||
func OnceFunc(f func()) func() {
|
|
||||||
var (
|
|
||||||
once sync.Once
|
|
||||||
valid bool
|
|
||||||
p any
|
|
||||||
)
|
|
||||||
// Construct the inner closure just once to reduce costs on the fast path.
|
|
||||||
g := func() {
|
|
||||||
defer func() {
|
|
||||||
p = recover()
|
|
||||||
if !valid {
|
|
||||||
// Re-panic immediately so on the first call the user gets a
|
|
||||||
// complete stack trace into f.
|
|
||||||
panic(p)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
f()
|
|
||||||
f = nil // Do not keep f alive after invoking it.
|
|
||||||
valid = true // Set only if f does not panic.
|
|
||||||
}
|
|
||||||
return func() {
|
|
||||||
once.Do(g)
|
|
||||||
if !valid {
|
|
||||||
panic(p)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// OnceValue returns a function that invokes f only once and returns the value
|
|
||||||
// returned by f. The returned function may be called concurrently.
|
|
||||||
//
|
|
||||||
// If f panics, the returned function will panic with the same value on every call.
|
|
||||||
func OnceValue[T any](f func() T) func() T {
|
|
||||||
var (
|
|
||||||
once sync.Once
|
|
||||||
valid bool
|
|
||||||
p any
|
|
||||||
result T
|
|
||||||
)
|
|
||||||
g := func() {
|
|
||||||
defer func() {
|
|
||||||
p = recover()
|
|
||||||
if !valid {
|
|
||||||
panic(p)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
result = f()
|
|
||||||
f = nil
|
|
||||||
valid = true
|
|
||||||
}
|
|
||||||
return func() T {
|
|
||||||
once.Do(g)
|
|
||||||
if !valid {
|
|
||||||
panic(p)
|
|
||||||
}
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// OnceValues returns a function that invokes f only once and returns the values
|
|
||||||
// returned by f. The returned function may be called concurrently.
|
|
||||||
//
|
|
||||||
// If f panics, the returned function will panic with the same value on every call.
|
|
||||||
func OnceValues[T1, T2 any](f func() (T1, T2)) func() (T1, T2) {
|
|
||||||
var (
|
|
||||||
once sync.Once
|
|
||||||
valid bool
|
|
||||||
p any
|
|
||||||
r1 T1
|
|
||||||
r2 T2
|
|
||||||
)
|
|
||||||
g := func() {
|
|
||||||
defer func() {
|
|
||||||
p = recover()
|
|
||||||
if !valid {
|
|
||||||
panic(p)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
r1, r2 = f()
|
|
||||||
f = nil
|
|
||||||
valid = true
|
|
||||||
}
|
|
||||||
return func() (T1, T2) {
|
|
||||||
once.Do(g)
|
|
||||||
if !valid {
|
|
||||||
panic(p)
|
|
||||||
}
|
|
||||||
return r1, r2
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -14,24 +14,24 @@ import (
|
||||||
N "github.com/sagernet/sing/common/network"
|
N "github.com/sagernet/sing/common/network"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Deadline is an abstraction for handling timeouts.
|
// pipeDeadline is an abstraction for handling timeouts.
|
||||||
type Deadline struct {
|
type pipeDeadline struct {
|
||||||
mu sync.Mutex // Guards timer and cancel
|
mu sync.Mutex // Guards timer and cancel
|
||||||
timer *time.Timer
|
timer *time.Timer
|
||||||
cancel chan struct{} // Must be non-nil
|
cancel chan struct{} // Must be non-nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func MakeDeadline() Deadline {
|
func makePipeDeadline() pipeDeadline {
|
||||||
return Deadline{cancel: make(chan struct{})}
|
return pipeDeadline{cancel: make(chan struct{})}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set sets the point in time when the deadline will time out.
|
// set sets the point in time when the deadline will time out.
|
||||||
// A timeout event is signaled by closing the channel returned by waiter.
|
// A timeout event is signaled by closing the channel returned by waiter.
|
||||||
// Once a timeout has occurred, the deadline can be refreshed by specifying a
|
// Once a timeout has occurred, the deadline can be refreshed by specifying a
|
||||||
// t value in the future.
|
// t value in the future.
|
||||||
//
|
//
|
||||||
// A zero value for t prevents timeout.
|
// A zero value for t prevents timeout.
|
||||||
func (d *Deadline) Set(t time.Time) {
|
func (d *pipeDeadline) set(t time.Time) {
|
||||||
d.mu.Lock()
|
d.mu.Lock()
|
||||||
defer d.mu.Unlock()
|
defer d.mu.Unlock()
|
||||||
|
|
||||||
|
@ -66,8 +66,8 @@ func (d *Deadline) Set(t time.Time) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Wait returns a channel that is closed when the deadline is exceeded.
|
// wait returns a channel that is closed when the deadline is exceeded.
|
||||||
func (d *Deadline) Wait() chan struct{} {
|
func (d *pipeDeadline) wait() chan struct{} {
|
||||||
d.mu.Lock()
|
d.mu.Lock()
|
||||||
defer d.mu.Unlock()
|
defer d.mu.Unlock()
|
||||||
return d.cancel
|
return d.cancel
|
||||||
|
@ -104,8 +104,8 @@ type pipe struct {
|
||||||
localDone chan struct{}
|
localDone chan struct{}
|
||||||
remoteDone <-chan struct{}
|
remoteDone <-chan struct{}
|
||||||
|
|
||||||
readDeadline Deadline
|
readDeadline pipeDeadline
|
||||||
writeDeadline Deadline
|
writeDeadline pipeDeadline
|
||||||
|
|
||||||
readWaitOptions N.ReadWaitOptions
|
readWaitOptions N.ReadWaitOptions
|
||||||
}
|
}
|
||||||
|
@ -127,15 +127,15 @@ func Pipe() (net.Conn, net.Conn) {
|
||||||
rdRx: cb1, rdTx: cn1,
|
rdRx: cb1, rdTx: cn1,
|
||||||
wrTx: cb2, wrRx: cn2,
|
wrTx: cb2, wrRx: cn2,
|
||||||
localDone: done1, remoteDone: done2,
|
localDone: done1, remoteDone: done2,
|
||||||
readDeadline: MakeDeadline(),
|
readDeadline: makePipeDeadline(),
|
||||||
writeDeadline: MakeDeadline(),
|
writeDeadline: makePipeDeadline(),
|
||||||
}
|
}
|
||||||
p2 := &pipe{
|
p2 := &pipe{
|
||||||
rdRx: cb2, rdTx: cn2,
|
rdRx: cb2, rdTx: cn2,
|
||||||
wrTx: cb1, wrRx: cn1,
|
wrTx: cb1, wrRx: cn1,
|
||||||
localDone: done2, remoteDone: done1,
|
localDone: done2, remoteDone: done1,
|
||||||
readDeadline: MakeDeadline(),
|
readDeadline: makePipeDeadline(),
|
||||||
writeDeadline: MakeDeadline(),
|
writeDeadline: makePipeDeadline(),
|
||||||
}
|
}
|
||||||
return p1, p2
|
return p1, p2
|
||||||
}
|
}
|
||||||
|
@ -157,7 +157,7 @@ func (p *pipe) read(b []byte) (n int, err error) {
|
||||||
return 0, io.ErrClosedPipe
|
return 0, io.ErrClosedPipe
|
||||||
case isClosedChan(p.remoteDone):
|
case isClosedChan(p.remoteDone):
|
||||||
return 0, io.EOF
|
return 0, io.EOF
|
||||||
case isClosedChan(p.readDeadline.Wait()):
|
case isClosedChan(p.readDeadline.wait()):
|
||||||
return 0, os.ErrDeadlineExceeded
|
return 0, os.ErrDeadlineExceeded
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -170,7 +170,7 @@ func (p *pipe) read(b []byte) (n int, err error) {
|
||||||
return 0, io.ErrClosedPipe
|
return 0, io.ErrClosedPipe
|
||||||
case <-p.remoteDone:
|
case <-p.remoteDone:
|
||||||
return 0, io.EOF
|
return 0, io.EOF
|
||||||
case <-p.readDeadline.Wait():
|
case <-p.readDeadline.wait():
|
||||||
return 0, os.ErrDeadlineExceeded
|
return 0, os.ErrDeadlineExceeded
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -189,7 +189,7 @@ func (p *pipe) write(b []byte) (n int, err error) {
|
||||||
return 0, io.ErrClosedPipe
|
return 0, io.ErrClosedPipe
|
||||||
case isClosedChan(p.remoteDone):
|
case isClosedChan(p.remoteDone):
|
||||||
return 0, io.ErrClosedPipe
|
return 0, io.ErrClosedPipe
|
||||||
case isClosedChan(p.writeDeadline.Wait()):
|
case isClosedChan(p.writeDeadline.wait()):
|
||||||
return 0, os.ErrDeadlineExceeded
|
return 0, os.ErrDeadlineExceeded
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -205,7 +205,7 @@ func (p *pipe) write(b []byte) (n int, err error) {
|
||||||
return n, io.ErrClosedPipe
|
return n, io.ErrClosedPipe
|
||||||
case <-p.remoteDone:
|
case <-p.remoteDone:
|
||||||
return n, io.ErrClosedPipe
|
return n, io.ErrClosedPipe
|
||||||
case <-p.writeDeadline.Wait():
|
case <-p.writeDeadline.wait():
|
||||||
return n, os.ErrDeadlineExceeded
|
return n, os.ErrDeadlineExceeded
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -216,8 +216,8 @@ func (p *pipe) SetDeadline(t time.Time) error {
|
||||||
if isClosedChan(p.localDone) || isClosedChan(p.remoteDone) {
|
if isClosedChan(p.localDone) || isClosedChan(p.remoteDone) {
|
||||||
return io.ErrClosedPipe
|
return io.ErrClosedPipe
|
||||||
}
|
}
|
||||||
p.readDeadline.Set(t)
|
p.readDeadline.set(t)
|
||||||
p.writeDeadline.Set(t)
|
p.writeDeadline.set(t)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -225,7 +225,7 @@ func (p *pipe) SetReadDeadline(t time.Time) error {
|
||||||
if isClosedChan(p.localDone) || isClosedChan(p.remoteDone) {
|
if isClosedChan(p.localDone) || isClosedChan(p.remoteDone) {
|
||||||
return io.ErrClosedPipe
|
return io.ErrClosedPipe
|
||||||
}
|
}
|
||||||
p.readDeadline.Set(t)
|
p.readDeadline.set(t)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -233,7 +233,7 @@ func (p *pipe) SetWriteDeadline(t time.Time) error {
|
||||||
if isClosedChan(p.localDone) || isClosedChan(p.remoteDone) {
|
if isClosedChan(p.localDone) || isClosedChan(p.remoteDone) {
|
||||||
return io.ErrClosedPipe
|
return io.ErrClosedPipe
|
||||||
}
|
}
|
||||||
p.writeDeadline.Set(t)
|
p.writeDeadline.set(t)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -30,7 +30,7 @@ func (p *pipe) waitReadBuffer() (buffer *buf.Buffer, err error) {
|
||||||
return nil, io.ErrClosedPipe
|
return nil, io.ErrClosedPipe
|
||||||
case isClosedChan(p.remoteDone):
|
case isClosedChan(p.remoteDone):
|
||||||
return nil, io.EOF
|
return nil, io.EOF
|
||||||
case isClosedChan(p.readDeadline.Wait()):
|
case isClosedChan(p.readDeadline.wait()):
|
||||||
return nil, os.ErrDeadlineExceeded
|
return nil, os.ErrDeadlineExceeded
|
||||||
}
|
}
|
||||||
select {
|
select {
|
||||||
|
@ -49,7 +49,7 @@ func (p *pipe) waitReadBuffer() (buffer *buf.Buffer, err error) {
|
||||||
return nil, io.ErrClosedPipe
|
return nil, io.ErrClosedPipe
|
||||||
case <-p.remoteDone:
|
case <-p.remoteDone:
|
||||||
return nil, io.EOF
|
return nil, io.EOF
|
||||||
case <-p.readDeadline.Wait():
|
case <-p.readDeadline.wait():
|
||||||
return nil, os.ErrDeadlineExceeded
|
return nil, os.ErrDeadlineExceeded
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -20,5 +20,6 @@ func InitializeSeed() {
|
||||||
func initializeSeed() {
|
func initializeSeed() {
|
||||||
var seed int64
|
var seed int64
|
||||||
common.Must(binary.Read(rand.Reader, binary.LittleEndian, &seed))
|
common.Must(binary.Read(rand.Reader, binary.LittleEndian, &seed))
|
||||||
|
//goland:noinspection GoDeprecation
|
||||||
mRand.Seed(seed)
|
mRand.Seed(seed)
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,7 +6,6 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestRevertRanges(t *testing.T) {
|
func TestRevertRanges(t *testing.T) {
|
||||||
t.Parallel()
|
|
||||||
for _, testRange := range []struct {
|
for _, testRange := range []struct {
|
||||||
start, end int
|
start, end int
|
||||||
ranges []Range[int]
|
ranges []Range[int]
|
||||||
|
@ -78,7 +77,6 @@ func TestRevertRanges(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestMergeRanges(t *testing.T) {
|
func TestMergeRanges(t *testing.T) {
|
||||||
t.Parallel()
|
|
||||||
for _, testRange := range []struct {
|
for _, testRange := range []struct {
|
||||||
ranges []Range[int]
|
ranges []Range[int]
|
||||||
expected []Range[int]
|
expected []Range[int]
|
||||||
|
@ -146,7 +144,6 @@ func TestMergeRanges(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestExcludeRanges(t *testing.T) {
|
func TestExcludeRanges(t *testing.T) {
|
||||||
t.Parallel()
|
|
||||||
for _, testRange := range []struct {
|
for _, testRange := range []struct {
|
||||||
ranges []Range[int]
|
ranges []Range[int]
|
||||||
exclude []Range[int]
|
exclude []Range[int]
|
||||||
|
|
27
common/rw/count.go
Normal file
27
common/rw/count.go
Normal file
|
@ -0,0 +1,27 @@
|
||||||
|
package rw
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"sync/atomic"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ReadCounter struct {
|
||||||
|
io.Reader
|
||||||
|
count int64
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ReadCounter) Read(p []byte) (n int, err error) {
|
||||||
|
n, err = r.Reader.Read(p)
|
||||||
|
if n > 0 {
|
||||||
|
atomic.AddInt64(&r.count, int64(n))
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ReadCounter) Count() int64 {
|
||||||
|
return r.count
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ReadCounter) Reset() {
|
||||||
|
atomic.StoreInt64(&r.count, 0)
|
||||||
|
}
|
|
@ -1,4 +1,4 @@
|
||||||
package network
|
package rw
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/sagernet/sing/common"
|
"github.com/sagernet/sing/common"
|
|
@ -9,33 +9,8 @@ import (
|
||||||
"github.com/sagernet/sing/common"
|
"github.com/sagernet/sing/common"
|
||||||
)
|
)
|
||||||
|
|
||||||
func IsFile(path string) bool {
|
func FileExists(path string) bool {
|
||||||
stat, err := os.Stat(path)
|
return common.Error(os.Stat(path)) == nil
|
||||||
if err != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return !stat.IsDir()
|
|
||||||
}
|
|
||||||
|
|
||||||
func IsDir(path string) bool {
|
|
||||||
stat, err := os.Stat(path)
|
|
||||||
if err != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return stat.IsDir()
|
|
||||||
}
|
|
||||||
|
|
||||||
func MkdirParent(path string) error {
|
|
||||||
if strings.Contains(path, string(os.PathSeparator)) {
|
|
||||||
parent := path[:strings.LastIndex(path, string(os.PathSeparator))]
|
|
||||||
if !IsDir(parent) {
|
|
||||||
err := os.MkdirAll(parent, 0o755)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func CopyFile(srcPath, dstPath string) error {
|
func CopyFile(srcPath, dstPath string) error {
|
||||||
|
@ -44,29 +19,23 @@ func CopyFile(srcPath, dstPath string) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer srcFile.Close()
|
defer srcFile.Close()
|
||||||
srcStat, err := srcFile.Stat()
|
if strings.Contains(dstPath, "/") {
|
||||||
if err != nil {
|
parent := dstPath[:strings.LastIndex(dstPath, "/")]
|
||||||
return err
|
if !FileExists(parent) {
|
||||||
|
err = os.MkdirAll(parent, 0o755)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
err = MkdirParent(dstPath)
|
dstFile, err := os.Create(dstPath)
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
dstFile, err := os.OpenFile(dstPath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, srcStat.Mode())
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer dstFile.Close()
|
defer dstFile.Close()
|
||||||
_, err = io.Copy(dstFile, srcFile)
|
return common.Error(io.Copy(dstFile, srcFile))
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Deprecated: use IsFile and IsDir instead.
|
|
||||||
func FileExists(path string) bool {
|
|
||||||
return common.Error(os.Stat(path)) == nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Deprecated: use MkdirParent and os.WriteFile instead.
|
|
||||||
func WriteFile(path string, content []byte) error {
|
func WriteFile(path string, content []byte) error {
|
||||||
if strings.Contains(path, "/") {
|
if strings.Contains(path, "/") {
|
||||||
parent := path[:strings.LastIndex(path, "/")]
|
parent := path[:strings.LastIndex(path, "/")]
|
||||||
|
@ -87,7 +56,6 @@ func WriteFile(path string, content []byte) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Deprecated: wtf is this?
|
|
||||||
func ReadJSON(path string, data any) error {
|
func ReadJSON(path string, data any) error {
|
||||||
content, err := os.ReadFile(path)
|
content, err := os.ReadFile(path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -100,7 +68,6 @@ func ReadJSON(path string, data any) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Deprecated: wtf is this?
|
|
||||||
func WriteJSON(path string, data any) error {
|
func WriteJSON(path string, data any) error {
|
||||||
content, err := json.Marshal(data)
|
content, err := json.Marshal(data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -6,16 +6,14 @@ import (
|
||||||
"github.com/sagernet/sing/common"
|
"github.com/sagernet/sing/common"
|
||||||
)
|
)
|
||||||
|
|
||||||
func SkipN(reader io.Reader, size int) error {
|
|
||||||
return common.Error(io.CopyN(Discard, reader, int64(size)))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Deprecated: wtf is this?
|
|
||||||
func Skip(reader io.Reader) error {
|
func Skip(reader io.Reader) error {
|
||||||
return SkipN(reader, 1)
|
return SkipN(reader, 1)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Deprecated: wtf is this?
|
func SkipN(reader io.Reader, size int) error {
|
||||||
|
return common.Error(io.CopyN(Discard, reader, int64(size)))
|
||||||
|
}
|
||||||
|
|
||||||
func ReadByte(reader io.Reader) (byte, error) {
|
func ReadByte(reader io.Reader) (byte, error) {
|
||||||
if br, isBr := reader.(io.ByteReader); isBr {
|
if br, isBr := reader.(io.ByteReader); isBr {
|
||||||
return br.ReadByte()
|
return br.ReadByte()
|
||||||
|
@ -27,7 +25,6 @@ func ReadByte(reader io.Reader) (byte, error) {
|
||||||
return b[0], nil
|
return b[0], nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Deprecated: wtf is this?
|
|
||||||
func ReadBytes(reader io.Reader, size int) ([]byte, error) {
|
func ReadBytes(reader io.Reader, size int) ([]byte, error) {
|
||||||
b := make([]byte, size)
|
b := make([]byte, size)
|
||||||
if err := common.Error(io.ReadFull(reader, b)); err != nil {
|
if err := common.Error(io.ReadFull(reader, b)); err != nil {
|
||||||
|
@ -36,7 +33,6 @@ func ReadBytes(reader io.Reader, size int) ([]byte, error) {
|
||||||
return b, nil
|
return b, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Deprecated: wtf is this?
|
|
||||||
func ReadString(reader io.Reader, size int) (string, error) {
|
func ReadString(reader io.Reader, size int) (string, error) {
|
||||||
b, err := ReadBytes(reader, size)
|
b, err := ReadBytes(reader, size)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -1,14 +1,12 @@
|
||||||
package rw
|
package rw
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/binary"
|
||||||
"io"
|
"io"
|
||||||
|
|
||||||
"github.com/sagernet/sing/common"
|
"github.com/sagernet/sing/common"
|
||||||
"github.com/sagernet/sing/common/binary"
|
|
||||||
"github.com/sagernet/sing/common/varbin"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Deprecated: create a *bufio.Reader instead.
|
|
||||||
type stubByteReader struct {
|
type stubByteReader struct {
|
||||||
io.Reader
|
io.Reader
|
||||||
}
|
}
|
||||||
|
@ -17,7 +15,6 @@ func (r stubByteReader) ReadByte() (byte, error) {
|
||||||
return ReadByte(r.Reader)
|
return ReadByte(r.Reader)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Deprecated: create a *bufio.Reader instead.
|
|
||||||
func ToByteReader(reader io.Reader) io.ByteReader {
|
func ToByteReader(reader io.Reader) io.ByteReader {
|
||||||
if byteReader, ok := reader.(io.ByteReader); ok {
|
if byteReader, ok := reader.(io.ByteReader); ok {
|
||||||
return byteReader
|
return byteReader
|
||||||
|
@ -25,23 +22,40 @@ func ToByteReader(reader io.Reader) io.ByteReader {
|
||||||
return &stubByteReader{reader}
|
return &stubByteReader{reader}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Deprecated: Use binary.ReadUvarint instead.
|
|
||||||
func ReadUVariant(reader io.Reader) (uint64, error) {
|
func ReadUVariant(reader io.Reader) (uint64, error) {
|
||||||
return binary.ReadUvarint(ToByteReader(reader))
|
return binary.ReadUvarint(ToByteReader(reader))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Deprecated: Use varbin.UvarintLen instead.
|
|
||||||
func UVariantLen(x uint64) int {
|
func UVariantLen(x uint64) int {
|
||||||
return varbin.UvarintLen(x)
|
switch {
|
||||||
|
case x < 1<<(7*1):
|
||||||
|
return 1
|
||||||
|
case x < 1<<(7*2):
|
||||||
|
return 2
|
||||||
|
case x < 1<<(7*3):
|
||||||
|
return 3
|
||||||
|
case x < 1<<(7*4):
|
||||||
|
return 4
|
||||||
|
case x < 1<<(7*5):
|
||||||
|
return 5
|
||||||
|
case x < 1<<(7*6):
|
||||||
|
return 6
|
||||||
|
case x < 1<<(7*7):
|
||||||
|
return 7
|
||||||
|
case x < 1<<(7*8):
|
||||||
|
return 8
|
||||||
|
case x < 1<<(7*9):
|
||||||
|
return 9
|
||||||
|
default:
|
||||||
|
return 10
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Deprecated: Use varbin.WriteUvarint instead.
|
|
||||||
func WriteUVariant(writer io.Writer, value uint64) error {
|
func WriteUVariant(writer io.Writer, value uint64) error {
|
||||||
var b [8]byte
|
var b [8]byte
|
||||||
return common.Error(writer.Write(b[:binary.PutUvarint(b[:], value)]))
|
return common.Error(writer.Write(b[:binary.PutUvarint(b[:], value)]))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Deprecated: Use varbin.Write instead.
|
|
||||||
func WriteVString(writer io.Writer, value string) error {
|
func WriteVString(writer io.Writer, value string) error {
|
||||||
err := WriteUVariant(writer, uint64(len(value)))
|
err := WriteUVariant(writer, uint64(len(value)))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -50,7 +64,6 @@ func WriteVString(writer io.Writer, value string) error {
|
||||||
return WriteString(writer, value)
|
return WriteString(writer, value)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Deprecated: Use varbin.ReadValue instead.
|
|
||||||
func ReadVString(reader io.Reader) (string, error) {
|
func ReadVString(reader io.Reader) (string, error) {
|
||||||
length, err := binary.ReadUvarint(ToByteReader(reader))
|
length, err := binary.ReadUvarint(ToByteReader(reader))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -6,10 +6,20 @@ import (
|
||||||
"github.com/sagernet/sing/common"
|
"github.com/sagernet/sing/common"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Deprecated: wtf is this?
|
|
||||||
var ZeroBytes = make([]byte, 1024)
|
var ZeroBytes = make([]byte, 1024)
|
||||||
|
|
||||||
// Deprecated: wtf is this?
|
func WriteByte(writer io.Writer, b byte) error {
|
||||||
|
return common.Error(writer.Write([]byte{b}))
|
||||||
|
}
|
||||||
|
|
||||||
|
func WriteBytes(writer io.Writer, b []byte) error {
|
||||||
|
return common.Error(writer.Write(b))
|
||||||
|
}
|
||||||
|
|
||||||
|
func WriteZero(writer io.Writer) error {
|
||||||
|
return WriteByte(writer, 0)
|
||||||
|
}
|
||||||
|
|
||||||
func WriteZeroN(writer io.Writer, size int) error {
|
func WriteZeroN(writer io.Writer, size int) error {
|
||||||
var index int
|
var index int
|
||||||
for index < size {
|
for index < size {
|
||||||
|
@ -28,22 +38,6 @@ func WriteZeroN(writer io.Writer, size int) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Deprecated: wtf is this?
|
|
||||||
func WriteByte(writer io.Writer, b byte) error {
|
|
||||||
return common.Error(writer.Write([]byte{b}))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Deprecated: wtf is this?
|
|
||||||
func WriteBytes(writer io.Writer, b []byte) error {
|
|
||||||
return common.Error(writer.Write(b))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Deprecated: wtf is this?
|
|
||||||
func WriteZero(writer io.Writer) error {
|
|
||||||
return WriteByte(writer, 0)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Deprecated: wtf is this?
|
|
||||||
func WriteString(writer io.Writer, str string) error {
|
func WriteString(writer io.Writer, str string) error {
|
||||||
return WriteBytes(writer, []byte(str))
|
return WriteBytes(writer, []byte(str))
|
||||||
}
|
}
|
||||||
|
|
|
@ -54,9 +54,17 @@ func (g *Group) Concurrency(n int) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *Group) Run(ctx context.Context) error {
|
func (g *Group) Run(contextList ...context.Context) error {
|
||||||
|
return g.RunContextList(contextList)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *Group) RunContextList(contextList []context.Context) error {
|
||||||
|
if len(contextList) == 0 {
|
||||||
|
contextList = append(contextList, context.Background())
|
||||||
|
}
|
||||||
|
|
||||||
taskContext, taskFinish := common.ContextWithCancelCause(context.Background())
|
taskContext, taskFinish := common.ContextWithCancelCause(context.Background())
|
||||||
taskCancelContext, taskCancel := common.ContextWithCancelCause(ctx)
|
taskCancelContext, taskCancel := common.ContextWithCancelCause(context.Background())
|
||||||
|
|
||||||
var errorAccess sync.Mutex
|
var errorAccess sync.Mutex
|
||||||
var returnError error
|
var returnError error
|
||||||
|
@ -104,12 +112,10 @@ func (g *Group) Run(ctx context.Context) error {
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
var upstreamErr bool
|
selectedContext, upstreamErr := common.SelectContext(append([]context.Context{taskCancelContext}, contextList...))
|
||||||
select {
|
|
||||||
case <-taskCancelContext.Done():
|
if selectedContext != 0 {
|
||||||
case <-ctx.Done():
|
taskCancel(upstreamErr)
|
||||||
upstreamErr = true
|
|
||||||
taskCancel(ctx.Err())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if g.cleanup != nil {
|
if g.cleanup != nil {
|
||||||
|
@ -118,8 +124,10 @@ func (g *Group) Run(ctx context.Context) error {
|
||||||
|
|
||||||
<-taskContext.Done()
|
<-taskContext.Done()
|
||||||
|
|
||||||
if upstreamErr {
|
if selectedContext != 0 {
|
||||||
return ctx.Err()
|
returnError = E.Append(returnError, upstreamErr, func(err error) error {
|
||||||
|
return E.Cause(err, "upstream")
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
return returnError
|
return returnError
|
||||||
|
|
|
@ -2,7 +2,6 @@ package task
|
||||||
|
|
||||||
import "context"
|
import "context"
|
||||||
|
|
||||||
// Deprecated: Use Group instead
|
|
||||||
func Run(ctx context.Context, tasks ...func() error) error {
|
func Run(ctx context.Context, tasks ...func() error) error {
|
||||||
var group Group
|
var group Group
|
||||||
for _, task := range tasks {
|
for _, task := range tasks {
|
||||||
|
@ -14,7 +13,6 @@ func Run(ctx context.Context, tasks ...func() error) error {
|
||||||
return group.Run(ctx)
|
return group.Run(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Deprecated: Use Group instead
|
|
||||||
func Any(ctx context.Context, tasks ...func(ctx context.Context) error) error {
|
func Any(ctx context.Context, tasks ...func(ctx context.Context) error) error {
|
||||||
var group Group
|
var group Group
|
||||||
for _, task := range tasks {
|
for _, task := range tasks {
|
||||||
|
|
|
@ -2,7 +2,6 @@ package udpnat
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"io"
|
"io"
|
||||||
"os"
|
|
||||||
|
|
||||||
"github.com/sagernet/sing/common/buf"
|
"github.com/sagernet/sing/common/buf"
|
||||||
M "github.com/sagernet/sing/common/metadata"
|
M "github.com/sagernet/sing/common/metadata"
|
||||||
|
@ -35,7 +34,5 @@ func (c *conn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, er
|
||||||
return
|
return
|
||||||
case <-c.ctx.Done():
|
case <-c.ctx.Done():
|
||||||
return nil, M.Socksaddr{}, io.ErrClosedPipe
|
return nil, M.Socksaddr{}, io.ErrClosedPipe
|
||||||
case <-c.readDeadline.Wait():
|
|
||||||
return nil, M.Socksaddr{}, os.ErrDeadlineExceeded
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -13,26 +13,20 @@ import (
|
||||||
E "github.com/sagernet/sing/common/exceptions"
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
M "github.com/sagernet/sing/common/metadata"
|
M "github.com/sagernet/sing/common/metadata"
|
||||||
N "github.com/sagernet/sing/common/network"
|
N "github.com/sagernet/sing/common/network"
|
||||||
"github.com/sagernet/sing/common/pipe"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Deprecated: Use N.UDPConnectionHandler instead.
|
|
||||||
//
|
|
||||||
//nolint:staticcheck
|
|
||||||
type Handler interface {
|
type Handler interface {
|
||||||
N.UDPConnectionHandler
|
N.UDPConnectionHandler
|
||||||
E.Handler
|
E.Handler
|
||||||
}
|
}
|
||||||
|
|
||||||
type Service[K comparable] struct {
|
type Service[K comparable] struct {
|
||||||
nat *cache.LruCache[K, *conn]
|
nat *cache.LruCache[K, *conn]
|
||||||
handler Handler
|
handler Handler
|
||||||
handlerEx N.UDPConnectionHandlerEx
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Deprecated: Use NewEx instead.
|
|
||||||
func New[K comparable](maxAge int64, handler Handler) *Service[K] {
|
func New[K comparable](maxAge int64, handler Handler) *Service[K] {
|
||||||
service := &Service[K]{
|
return &Service[K]{
|
||||||
nat: cache.New(
|
nat: cache.New(
|
||||||
cache.WithAge[K, *conn](maxAge),
|
cache.WithAge[K, *conn](maxAge),
|
||||||
cache.WithUpdateAgeOnGet[K, *conn](),
|
cache.WithUpdateAgeOnGet[K, *conn](),
|
||||||
|
@ -42,27 +36,11 @@ func New[K comparable](maxAge int64, handler Handler) *Service[K] {
|
||||||
),
|
),
|
||||||
handler: handler,
|
handler: handler,
|
||||||
}
|
}
|
||||||
return service
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewEx[K comparable](maxAge int64, handler N.UDPConnectionHandlerEx) *Service[K] {
|
|
||||||
service := &Service[K]{
|
|
||||||
nat: cache.New(
|
|
||||||
cache.WithAge[K, *conn](maxAge),
|
|
||||||
cache.WithUpdateAgeOnGet[K, *conn](),
|
|
||||||
cache.WithEvict[K, *conn](func(key K, conn *conn) {
|
|
||||||
conn.Close()
|
|
||||||
}),
|
|
||||||
),
|
|
||||||
handlerEx: handler,
|
|
||||||
}
|
|
||||||
return service
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service[T]) WriteIsThreadUnsafe() {
|
func (s *Service[T]) WriteIsThreadUnsafe() {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Deprecated: don't use
|
|
||||||
func (s *Service[T]) NewPacketDirect(ctx context.Context, key T, conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) {
|
func (s *Service[T]) NewPacketDirect(ctx context.Context, key T, conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) {
|
||||||
s.NewContextPacket(ctx, key, buffer, metadata, func(natConn N.PacketConn) (context.Context, N.PacketWriter) {
|
s.NewContextPacket(ctx, key, buffer, metadata, func(natConn N.PacketConn) (context.Context, N.PacketWriter) {
|
||||||
return ctx, &DirectBackWriter{conn, natConn}
|
return ctx, &DirectBackWriter{conn, natConn}
|
||||||
|
@ -82,31 +60,18 @@ func (w *DirectBackWriter) Upstream() any {
|
||||||
return w.Source
|
return w.Source
|
||||||
}
|
}
|
||||||
|
|
||||||
// Deprecated: use NewPacketEx instead.
|
|
||||||
func (s *Service[T]) NewPacket(ctx context.Context, key T, buffer *buf.Buffer, metadata M.Metadata, init func(natConn N.PacketConn) N.PacketWriter) {
|
func (s *Service[T]) NewPacket(ctx context.Context, key T, buffer *buf.Buffer, metadata M.Metadata, init func(natConn N.PacketConn) N.PacketWriter) {
|
||||||
s.NewContextPacket(ctx, key, buffer, metadata, func(natConn N.PacketConn) (context.Context, N.PacketWriter) {
|
s.NewContextPacket(ctx, key, buffer, metadata, func(natConn N.PacketConn) (context.Context, N.PacketWriter) {
|
||||||
return ctx, init(natConn)
|
return ctx, init(natConn)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service[T]) NewPacketEx(ctx context.Context, key T, buffer *buf.Buffer, source M.Socksaddr, destination M.Socksaddr, init func(natConn N.PacketConn) N.PacketWriter) {
|
|
||||||
s.NewContextPacketEx(ctx, key, buffer, source, destination, func(natConn N.PacketConn) (context.Context, N.PacketWriter) {
|
|
||||||
return ctx, init(natConn)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// Deprecated: Use NewPacketConnectionEx instead.
|
|
||||||
func (s *Service[T]) NewContextPacket(ctx context.Context, key T, buffer *buf.Buffer, metadata M.Metadata, init func(natConn N.PacketConn) (context.Context, N.PacketWriter)) {
|
func (s *Service[T]) NewContextPacket(ctx context.Context, key T, buffer *buf.Buffer, metadata M.Metadata, init func(natConn N.PacketConn) (context.Context, N.PacketWriter)) {
|
||||||
s.NewContextPacketEx(ctx, key, buffer, metadata.Source, metadata.Destination, init)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Service[T]) NewContextPacketEx(ctx context.Context, key T, buffer *buf.Buffer, source M.Socksaddr, destination M.Socksaddr, init func(natConn N.PacketConn) (context.Context, N.PacketWriter)) {
|
|
||||||
c, loaded := s.nat.LoadOrStore(key, func() *conn {
|
c, loaded := s.nat.LoadOrStore(key, func() *conn {
|
||||||
c := &conn{
|
c := &conn{
|
||||||
data: make(chan packet, 64),
|
data: make(chan packet, 64),
|
||||||
localAddr: source,
|
localAddr: metadata.Source,
|
||||||
remoteAddr: destination,
|
remoteAddr: metadata.Destination,
|
||||||
readDeadline: pipe.MakeDeadline(),
|
|
||||||
}
|
}
|
||||||
c.ctx, c.cancel = common.ContextWithCancelCause(ctx)
|
c.ctx, c.cancel = common.ContextWithCancelCause(ctx)
|
||||||
return c
|
return c
|
||||||
|
@ -114,34 +79,26 @@ func (s *Service[T]) NewContextPacketEx(ctx context.Context, key T, buffer *buf.
|
||||||
if !loaded {
|
if !loaded {
|
||||||
ctx, c.source = init(c)
|
ctx, c.source = init(c)
|
||||||
go func() {
|
go func() {
|
||||||
if s.handlerEx != nil {
|
err := s.handler.NewPacketConnection(ctx, c, metadata)
|
||||||
s.handlerEx.NewPacketConnectionEx(ctx, c, source, destination, func(err error) {
|
if err != nil {
|
||||||
s.nat.Delete(key)
|
s.handler.NewError(ctx, err)
|
||||||
})
|
|
||||||
} else {
|
|
||||||
//nolint:staticcheck
|
|
||||||
err := s.handler.NewPacketConnection(ctx, c, M.Metadata{
|
|
||||||
Source: source,
|
|
||||||
Destination: destination,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
s.handler.NewError(ctx, err)
|
|
||||||
}
|
|
||||||
c.Close()
|
|
||||||
s.nat.Delete(key)
|
|
||||||
}
|
}
|
||||||
|
c.Close()
|
||||||
|
s.nat.Delete(key)
|
||||||
}()
|
}()
|
||||||
|
} else {
|
||||||
|
c.localAddr = metadata.Source
|
||||||
}
|
}
|
||||||
if common.Done(c.ctx) {
|
if common.Done(c.ctx) {
|
||||||
s.nat.Delete(key)
|
s.nat.Delete(key)
|
||||||
if !common.Done(ctx) {
|
if !common.Done(ctx) {
|
||||||
s.NewContextPacketEx(ctx, key, buffer, source, destination, init)
|
s.NewContextPacket(ctx, key, buffer, metadata, init)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.data <- packet{
|
c.data <- packet{
|
||||||
data: buffer,
|
data: buffer,
|
||||||
destination: destination,
|
destination: metadata.Destination,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -159,7 +116,6 @@ type conn struct {
|
||||||
localAddr M.Socksaddr
|
localAddr M.Socksaddr
|
||||||
remoteAddr M.Socksaddr
|
remoteAddr M.Socksaddr
|
||||||
source N.PacketWriter
|
source N.PacketWriter
|
||||||
readDeadline pipe.Deadline
|
|
||||||
readWaitOptions N.ReadWaitOptions
|
readWaitOptions N.ReadWaitOptions
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -171,8 +127,6 @@ func (c *conn) ReadPacket(buffer *buf.Buffer) (addr M.Socksaddr, err error) {
|
||||||
return p.destination, err
|
return p.destination, err
|
||||||
case <-c.ctx.Done():
|
case <-c.ctx.Done():
|
||||||
return M.Socksaddr{}, io.ErrClosedPipe
|
return M.Socksaddr{}, io.ErrClosedPipe
|
||||||
case <-c.readDeadline.Wait():
|
|
||||||
return M.Socksaddr{}, os.ErrDeadlineExceeded
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -205,14 +159,17 @@ func (c *conn) SetDeadline(t time.Time) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *conn) SetReadDeadline(t time.Time) error {
|
func (c *conn) SetReadDeadline(t time.Time) error {
|
||||||
c.readDeadline.Set(t)
|
return os.ErrInvalid
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *conn) SetWriteDeadline(t time.Time) error {
|
func (c *conn) SetWriteDeadline(t time.Time) error {
|
||||||
return os.ErrInvalid
|
return os.ErrInvalid
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *conn) NeedAdditionalReadDeadline() bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
func (c *conn) Upstream() any {
|
func (c *conn) Upstream() any {
|
||||||
return c.source
|
return c.source
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,138 +0,0 @@
|
||||||
package udpnat
|
|
||||||
|
|
||||||
import (
|
|
||||||
"io"
|
|
||||||
"net"
|
|
||||||
"net/netip"
|
|
||||||
"os"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/sagernet/sing/common/buf"
|
|
||||||
"github.com/sagernet/sing/common/canceler"
|
|
||||||
M "github.com/sagernet/sing/common/metadata"
|
|
||||||
N "github.com/sagernet/sing/common/network"
|
|
||||||
"github.com/sagernet/sing/common/pipe"
|
|
||||||
"github.com/sagernet/sing/contrab/freelru"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Conn interface {
|
|
||||||
N.PacketConn
|
|
||||||
SetHandler(handler N.UDPHandlerEx)
|
|
||||||
canceler.PacketConn
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ Conn = (*natConn)(nil)
|
|
||||||
|
|
||||||
type natConn struct {
|
|
||||||
cache freelru.Cache[netip.AddrPort, *natConn]
|
|
||||||
writer N.PacketWriter
|
|
||||||
localAddr M.Socksaddr
|
|
||||||
handlerAccess sync.RWMutex
|
|
||||||
handler N.UDPHandlerEx
|
|
||||||
packetChan chan *N.PacketBuffer
|
|
||||||
closeOnce sync.Once
|
|
||||||
doneChan chan struct{}
|
|
||||||
readDeadline pipe.Deadline
|
|
||||||
readWaitOptions N.ReadWaitOptions
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *natConn) ReadPacket(buffer *buf.Buffer) (addr M.Socksaddr, err error) {
|
|
||||||
select {
|
|
||||||
case p := <-c.packetChan:
|
|
||||||
_, err = buffer.ReadOnceFrom(p.Buffer)
|
|
||||||
destination := p.Destination
|
|
||||||
p.Buffer.Release()
|
|
||||||
N.PutPacketBuffer(p)
|
|
||||||
return destination, err
|
|
||||||
case <-c.doneChan:
|
|
||||||
return M.Socksaddr{}, io.ErrClosedPipe
|
|
||||||
case <-c.readDeadline.Wait():
|
|
||||||
return M.Socksaddr{}, os.ErrDeadlineExceeded
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *natConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
|
|
||||||
return c.writer.WritePacket(buffer, destination)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *natConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
|
|
||||||
c.readWaitOptions = options
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *natConn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
|
|
||||||
select {
|
|
||||||
case packet := <-c.packetChan:
|
|
||||||
buffer = c.readWaitOptions.Copy(packet.Buffer)
|
|
||||||
destination = packet.Destination
|
|
||||||
N.PutPacketBuffer(packet)
|
|
||||||
return
|
|
||||||
case <-c.doneChan:
|
|
||||||
return nil, M.Socksaddr{}, io.ErrClosedPipe
|
|
||||||
case <-c.readDeadline.Wait():
|
|
||||||
return nil, M.Socksaddr{}, os.ErrDeadlineExceeded
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *natConn) SetHandler(handler N.UDPHandlerEx) {
|
|
||||||
c.handlerAccess.Lock()
|
|
||||||
c.handler = handler
|
|
||||||
c.readWaitOptions = N.NewReadWaitOptions(c.writer, handler)
|
|
||||||
c.handlerAccess.Unlock()
|
|
||||||
fetch:
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case packet := <-c.packetChan:
|
|
||||||
c.handler.NewPacketEx(packet.Buffer, packet.Destination)
|
|
||||||
N.PutPacketBuffer(packet)
|
|
||||||
continue fetch
|
|
||||||
default:
|
|
||||||
break fetch
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *natConn) Timeout() time.Duration {
|
|
||||||
rawConn, lifetime, loaded := c.cache.PeekWithLifetime(c.localAddr.AddrPort())
|
|
||||||
if !loaded || rawConn != c {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
return time.Until(lifetime)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *natConn) SetTimeout(timeout time.Duration) bool {
|
|
||||||
return c.cache.UpdateLifetime(c.localAddr.AddrPort(), c, timeout)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *natConn) Close() error {
|
|
||||||
c.closeOnce.Do(func() {
|
|
||||||
close(c.doneChan)
|
|
||||||
})
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *natConn) LocalAddr() net.Addr {
|
|
||||||
return c.localAddr
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *natConn) RemoteAddr() net.Addr {
|
|
||||||
return M.Socksaddr{}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *natConn) SetDeadline(t time.Time) error {
|
|
||||||
return os.ErrInvalid
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *natConn) SetReadDeadline(t time.Time) error {
|
|
||||||
c.readDeadline.Set(t)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *natConn) SetWriteDeadline(t time.Time) error {
|
|
||||||
return os.ErrInvalid
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *natConn) Upstream() any {
|
|
||||||
return c.writer
|
|
||||||
}
|
|
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue