mirror of
https://github.com/SagerNet/sing.git
synced 2025-04-04 20:37:40 +03:00
Merge branch 'SagerNet:dev' into dev
This commit is contained in:
commit
5d7890f308
155 changed files with 10203 additions and 1390 deletions
1
.github/renovate.json
vendored
1
.github/renovate.json
vendored
|
@ -1,6 +1,7 @@
|
|||
{
|
||||
"$schema": "https://docs.renovatebot.com/renovate-schema.json",
|
||||
"commitMessagePrefix": "[dependencies]",
|
||||
"branchName": "main",
|
||||
"extends": [
|
||||
"config:base",
|
||||
":disableRateLimiting"
|
||||
|
|
81
.github/workflows/debug.yml
vendored
81
.github/workflows/debug.yml
vendored
|
@ -3,6 +3,7 @@ name: Debug build
|
|||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
- dev
|
||||
paths-ignore:
|
||||
- '**.md'
|
||||
|
@ -10,33 +11,85 @@ on:
|
|||
- '!.github/workflows/debug.yml'
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
- dev
|
||||
|
||||
jobs:
|
||||
build:
|
||||
name: Debug build
|
||||
name: Linux Debug build
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- name: Get latest go version
|
||||
id: version
|
||||
run: |
|
||||
echo ::set-output name=go_version::$(curl -s https://raw.githubusercontent.com/actions/go-versions/main/versions-manifest.json | grep -oE '"version": "[0-9]{1}.[0-9]{1,}(.[0-9]{1,})?"' | head -1 | cut -d':' -f2 | sed 's/ //g; s/"//g')
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: ${{ steps.version.outputs.go_version }}
|
||||
- name: Add cache to Go proxy
|
||||
go-version: ^1.22
|
||||
- name: Build
|
||||
run: |
|
||||
version=`git rev-parse HEAD`
|
||||
mkdir build
|
||||
pushd build
|
||||
go mod init build
|
||||
go get -v github.com/sagernet/sing@$version
|
||||
popd
|
||||
make test
|
||||
build_go120:
|
||||
name: Linux Debug build (Go 1.20)
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: ~1.20
|
||||
continue-on-error: true
|
||||
- name: Build
|
||||
run: |
|
||||
make test
|
||||
build_go121:
|
||||
name: Linux Debug build (Go 1.21)
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: ~1.21
|
||||
continue-on-error: true
|
||||
- name: Build
|
||||
run: |
|
||||
make test
|
||||
build__windows:
|
||||
name: Windows Debug build
|
||||
runs-on: windows-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: ^1.22
|
||||
continue-on-error: true
|
||||
- name: Build
|
||||
run: |
|
||||
make test
|
||||
build_darwin:
|
||||
name: macOS Debug build
|
||||
runs-on: macos-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: ^1.22
|
||||
continue-on-error: true
|
||||
- name: Build
|
||||
run: |
|
||||
|
|
10
.github/workflows/lint.yml
vendored
10
.github/workflows/lint.yml
vendored
|
@ -3,6 +3,7 @@ name: Lint
|
|||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
- dev
|
||||
paths-ignore:
|
||||
- '**.md'
|
||||
|
@ -10,6 +11,7 @@ on:
|
|||
- '!.github/workflows/lint.yml'
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
- dev
|
||||
|
||||
jobs:
|
||||
|
@ -18,17 +20,13 @@ jobs:
|
|||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- name: Get latest go version
|
||||
id: version
|
||||
run: |
|
||||
echo ::set-output name=go_version::$(curl -s https://raw.githubusercontent.com/actions/go-versions/main/versions-manifest.json | grep -oE '"version": "[0-9]{1}.[0-9]{1,}(.[0-9]{1,})?"' | head -1 | cut -d':' -f2 | sed 's/ //g; s/"//g')
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: ${{ steps.version.outputs.go_version }}
|
||||
go-version: ^1.22
|
||||
- name: Cache go module
|
||||
uses: actions/cache@v3
|
||||
with:
|
||||
|
|
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -1,2 +1,3 @@
|
|||
/.idea/
|
||||
/vendor/
|
||||
.DS_Store
|
||||
|
|
2
Makefile
2
Makefile
|
@ -18,4 +18,4 @@ lint_install:
|
|||
go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest
|
||||
|
||||
test:
|
||||
go test -v ./...
|
||||
go test $(shell go list ./... | grep -v /internal/)
|
||||
|
|
|
@ -18,7 +18,6 @@ var _ xml.TokenReader = (*Reader)(nil)
|
|||
type Reader struct {
|
||||
reader *bytes.Reader
|
||||
stringRefs []string
|
||||
attrs []xml.Attr
|
||||
}
|
||||
|
||||
func NewReader(content []byte) (xml.TokenReader, bool) {
|
||||
|
@ -47,7 +46,7 @@ func (r *Reader) Token() (token xml.Token, err error) {
|
|||
return
|
||||
}
|
||||
var attrs []xml.Attr
|
||||
attrs, err = r.pullAttributes()
|
||||
attrs, err = r.readAttributes()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
@ -93,35 +92,41 @@ func (r *Reader) Token() (token xml.Token, err error) {
|
|||
_, err = r.readUTF()
|
||||
return
|
||||
case ATTRIBUTE:
|
||||
return nil, E.New("unexpected attribute")
|
||||
_, err = r.readAttribute()
|
||||
return
|
||||
}
|
||||
return nil, E.New("unknown token type ", tokenType, " with type ", eventType)
|
||||
}
|
||||
|
||||
func (r *Reader) pullAttributes() ([]xml.Attr, error) {
|
||||
err := r.pullAttribute()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
func (r *Reader) readAttributes() ([]xml.Attr, error) {
|
||||
var attrs []xml.Attr
|
||||
for {
|
||||
attr, err := r.readAttribute()
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
attrs = append(attrs, attr)
|
||||
}
|
||||
attrs := r.attrs
|
||||
r.attrs = nil
|
||||
return attrs, nil
|
||||
}
|
||||
|
||||
func (r *Reader) pullAttribute() error {
|
||||
func (r *Reader) readAttribute() (xml.Attr, error) {
|
||||
event, err := r.reader.ReadByte()
|
||||
if err != nil {
|
||||
return nil
|
||||
return xml.Attr{}, nil
|
||||
}
|
||||
tokenType := event & 0x0f
|
||||
eventType := event & 0xf0
|
||||
if tokenType != ATTRIBUTE {
|
||||
return r.reader.UnreadByte()
|
||||
}
|
||||
var name string
|
||||
name, err = r.readInternedUTF()
|
||||
err = r.reader.UnreadByte()
|
||||
if err != nil {
|
||||
return err
|
||||
return xml.Attr{}, nil
|
||||
}
|
||||
return xml.Attr{}, io.EOF
|
||||
}
|
||||
name, err := r.readInternedUTF()
|
||||
if err != nil {
|
||||
return xml.Attr{}, err
|
||||
}
|
||||
var value string
|
||||
switch eventType {
|
||||
|
@ -134,74 +139,73 @@ func (r *Reader) pullAttribute() error {
|
|||
case TypeString:
|
||||
value, err = r.readUTF()
|
||||
if err != nil {
|
||||
return err
|
||||
return xml.Attr{}, err
|
||||
}
|
||||
case TypeStringInterned:
|
||||
value, err = r.readInternedUTF()
|
||||
if err != nil {
|
||||
return err
|
||||
return xml.Attr{}, err
|
||||
}
|
||||
case TypeBytesHex:
|
||||
var data []byte
|
||||
data, err = r.readBytes()
|
||||
if err != nil {
|
||||
return err
|
||||
return xml.Attr{}, err
|
||||
}
|
||||
value = hex.EncodeToString(data)
|
||||
case TypeBytesBase64:
|
||||
var data []byte
|
||||
data, err = r.readBytes()
|
||||
if err != nil {
|
||||
return err
|
||||
return xml.Attr{}, err
|
||||
}
|
||||
value = base64.StdEncoding.EncodeToString(data)
|
||||
case TypeInt:
|
||||
var data int32
|
||||
err = binary.Read(r.reader, binary.BigEndian, &data)
|
||||
if err != nil {
|
||||
return err
|
||||
return xml.Attr{}, err
|
||||
}
|
||||
value = strconv.FormatInt(int64(data), 10)
|
||||
case TypeIntHex:
|
||||
var data int32
|
||||
err = binary.Read(r.reader, binary.BigEndian, &data)
|
||||
if err != nil {
|
||||
return err
|
||||
return xml.Attr{}, err
|
||||
}
|
||||
value = "0x" + strconv.FormatInt(int64(data), 16)
|
||||
case TypeLong:
|
||||
var data int64
|
||||
err = binary.Read(r.reader, binary.BigEndian, &data)
|
||||
if err != nil {
|
||||
return err
|
||||
return xml.Attr{}, err
|
||||
}
|
||||
value = strconv.FormatInt(data, 10)
|
||||
case TypeLongHex:
|
||||
var data int64
|
||||
err = binary.Read(r.reader, binary.BigEndian, &data)
|
||||
if err != nil {
|
||||
return err
|
||||
return xml.Attr{}, err
|
||||
}
|
||||
value = "0x" + strconv.FormatInt(data, 16)
|
||||
case TypeFloat:
|
||||
var data float32
|
||||
err = binary.Read(r.reader, binary.BigEndian, &data)
|
||||
if err != nil {
|
||||
return err
|
||||
return xml.Attr{}, err
|
||||
}
|
||||
value = strconv.FormatFloat(float64(data), 'g', -1, 32)
|
||||
case TypeDouble:
|
||||
var data float64
|
||||
err = binary.Read(r.reader, binary.BigEndian, &data)
|
||||
if err != nil {
|
||||
return err
|
||||
return xml.Attr{}, err
|
||||
}
|
||||
value = strconv.FormatFloat(data, 'g', -1, 64)
|
||||
default:
|
||||
return E.New("unexpected attribute type, ", eventType)
|
||||
return xml.Attr{}, E.New("unexpected attribute type, ", eventType)
|
||||
}
|
||||
r.attrs = append(r.attrs, xml.Attr{Name: xml.Name{Local: name}, Value: value})
|
||||
return r.pullAttribute()
|
||||
return xml.Attr{Name: xml.Name{Local: name}, Value: value}, nil
|
||||
}
|
||||
|
||||
func (r *Reader) readUnsignedShort() (uint16, error) {
|
||||
|
|
|
@ -10,26 +10,37 @@ type TypedValue[T any] struct {
|
|||
value atomic.Value
|
||||
}
|
||||
|
||||
// typedValue is a struct with determined type to resolve atomic.Value usages with interface types
|
||||
// https://github.com/golang/go/issues/22550
|
||||
//
|
||||
// The intention to have an atomic value store for errors. However, running this code panics:
|
||||
// panic: sync/atomic: store of inconsistently typed value into Value
|
||||
// This is because atomic.Value requires that the underlying concrete type be the same (which is a reasonable expectation for its implementation).
|
||||
// When going through the atomic.Value.Store method call, the fact that both these are of the error interface is lost.
|
||||
type typedValue[T any] struct {
|
||||
value T
|
||||
}
|
||||
|
||||
func (t *TypedValue[T]) Load() T {
|
||||
value := t.value.Load()
|
||||
if value == nil {
|
||||
return common.DefaultValue[T]()
|
||||
}
|
||||
return value.(T)
|
||||
return value.(typedValue[T]).value
|
||||
}
|
||||
|
||||
func (t *TypedValue[T]) Store(value T) {
|
||||
t.value.Store(value)
|
||||
t.value.Store(typedValue[T]{value})
|
||||
}
|
||||
|
||||
func (t *TypedValue[T]) Swap(new T) T {
|
||||
old := t.value.Swap(new)
|
||||
old := t.value.Swap(typedValue[T]{new})
|
||||
if old == nil {
|
||||
return common.DefaultValue[T]()
|
||||
}
|
||||
return old.(T)
|
||||
return old.(typedValue[T]).value
|
||||
}
|
||||
|
||||
func (t *TypedValue[T]) CompareAndSwap(old, new T) bool {
|
||||
return t.value.CompareAndSwap(old, new)
|
||||
return t.value.CompareAndSwap(typedValue[T]{old}, typedValue[T]{new})
|
||||
}
|
||||
|
|
|
@ -1,38 +1,30 @@
|
|||
package auth
|
||||
|
||||
type Authenticator interface {
|
||||
Verify(user string, pass string) bool
|
||||
Users() []string
|
||||
}
|
||||
import "github.com/sagernet/sing/common"
|
||||
|
||||
type User struct {
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
Username string
|
||||
Password string
|
||||
}
|
||||
|
||||
type inMemoryAuthenticator struct {
|
||||
storage map[string]string
|
||||
usernames []string
|
||||
type Authenticator struct {
|
||||
userMap map[string][]string
|
||||
}
|
||||
|
||||
func (au *inMemoryAuthenticator) Verify(username string, password string) bool {
|
||||
realPass, ok := au.storage[username]
|
||||
return ok && realPass == password
|
||||
}
|
||||
|
||||
func (au *inMemoryAuthenticator) Users() []string { return au.usernames }
|
||||
|
||||
func NewAuthenticator(users []User) Authenticator {
|
||||
func NewAuthenticator(users []User) *Authenticator {
|
||||
if len(users) == 0 {
|
||||
return nil
|
||||
}
|
||||
au := &inMemoryAuthenticator{
|
||||
storage: make(map[string]string),
|
||||
usernames: make([]string, 0, len(users)),
|
||||
au := &Authenticator{
|
||||
userMap: make(map[string][]string),
|
||||
}
|
||||
for _, user := range users {
|
||||
au.storage[user.Username] = user.Password
|
||||
au.usernames = append(au.usernames, user.Username)
|
||||
au.userMap[user.Username] = append(au.userMap[user.Username], user.Password)
|
||||
}
|
||||
return au
|
||||
}
|
||||
|
||||
func (au *Authenticator) Verify(username string, password string) bool {
|
||||
passwordList, ok := au.userMap[username]
|
||||
return ok && common.Contains(passwordList, password)
|
||||
}
|
||||
|
|
|
@ -55,7 +55,10 @@ func WrapQUIC(err error) error {
|
|||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
if Contains(err, "canceled with error code 0") {
|
||||
if Contains(err,
|
||||
"canceled by remote with error code 0",
|
||||
"canceled by local with error code 0",
|
||||
) {
|
||||
return net.ErrClosed
|
||||
}
|
||||
return err
|
||||
|
|
3
common/binary/README.md
Normal file
3
common/binary/README.md
Normal file
|
@ -0,0 +1,3 @@
|
|||
# binary
|
||||
|
||||
mod from go 1.22.3
|
817
common/binary/binary.go
Normal file
817
common/binary/binary.go
Normal file
|
@ -0,0 +1,817 @@
|
|||
// Copyright 2009 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package binary implements simple translation between numbers and byte
|
||||
// sequences and encoding and decoding of varints.
|
||||
//
|
||||
// Numbers are translated by reading and writing fixed-size values.
|
||||
// A fixed-size value is either a fixed-size arithmetic
|
||||
// type (bool, int8, uint8, int16, float32, complex64, ...)
|
||||
// or an array or struct containing only fixed-size values.
|
||||
//
|
||||
// The varint functions encode and decode single integer values using
|
||||
// a variable-length encoding; smaller values require fewer bytes.
|
||||
// For a specification, see
|
||||
// https://developers.google.com/protocol-buffers/docs/encoding.
|
||||
//
|
||||
// This package favors simplicity over efficiency. Clients that require
|
||||
// high-performance serialization, especially for large data structures,
|
||||
// should look at more advanced solutions such as the [encoding/gob]
|
||||
// package or [google.golang.org/protobuf] for protocol buffers.
|
||||
package binary
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"math"
|
||||
"reflect"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// A ByteOrder specifies how to convert byte slices into
|
||||
// 16-, 32-, or 64-bit unsigned integers.
|
||||
//
|
||||
// It is implemented by [LittleEndian], [BigEndian], and [NativeEndian].
|
||||
type ByteOrder interface {
|
||||
Uint16([]byte) uint16
|
||||
Uint32([]byte) uint32
|
||||
Uint64([]byte) uint64
|
||||
PutUint16([]byte, uint16)
|
||||
PutUint32([]byte, uint32)
|
||||
PutUint64([]byte, uint64)
|
||||
String() string
|
||||
}
|
||||
|
||||
// AppendByteOrder specifies how to append 16-, 32-, or 64-bit unsigned integers
|
||||
// into a byte slice.
|
||||
//
|
||||
// It is implemented by [LittleEndian], [BigEndian], and [NativeEndian].
|
||||
type AppendByteOrder interface {
|
||||
AppendUint16([]byte, uint16) []byte
|
||||
AppendUint32([]byte, uint32) []byte
|
||||
AppendUint64([]byte, uint64) []byte
|
||||
String() string
|
||||
}
|
||||
|
||||
// LittleEndian is the little-endian implementation of [ByteOrder] and [AppendByteOrder].
|
||||
var LittleEndian littleEndian
|
||||
|
||||
// BigEndian is the big-endian implementation of [ByteOrder] and [AppendByteOrder].
|
||||
var BigEndian bigEndian
|
||||
|
||||
type littleEndian struct{}
|
||||
|
||||
func (littleEndian) Uint16(b []byte) uint16 {
|
||||
_ = b[1] // bounds check hint to compiler; see golang.org/issue/14808
|
||||
return uint16(b[0]) | uint16(b[1])<<8
|
||||
}
|
||||
|
||||
func (littleEndian) PutUint16(b []byte, v uint16) {
|
||||
_ = b[1] // early bounds check to guarantee safety of writes below
|
||||
b[0] = byte(v)
|
||||
b[1] = byte(v >> 8)
|
||||
}
|
||||
|
||||
func (littleEndian) AppendUint16(b []byte, v uint16) []byte {
|
||||
return append(b,
|
||||
byte(v),
|
||||
byte(v>>8),
|
||||
)
|
||||
}
|
||||
|
||||
func (littleEndian) Uint32(b []byte) uint32 {
|
||||
_ = b[3] // bounds check hint to compiler; see golang.org/issue/14808
|
||||
return uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24
|
||||
}
|
||||
|
||||
func (littleEndian) PutUint32(b []byte, v uint32) {
|
||||
_ = b[3] // early bounds check to guarantee safety of writes below
|
||||
b[0] = byte(v)
|
||||
b[1] = byte(v >> 8)
|
||||
b[2] = byte(v >> 16)
|
||||
b[3] = byte(v >> 24)
|
||||
}
|
||||
|
||||
func (littleEndian) AppendUint32(b []byte, v uint32) []byte {
|
||||
return append(b,
|
||||
byte(v),
|
||||
byte(v>>8),
|
||||
byte(v>>16),
|
||||
byte(v>>24),
|
||||
)
|
||||
}
|
||||
|
||||
func (littleEndian) Uint64(b []byte) uint64 {
|
||||
_ = b[7] // bounds check hint to compiler; see golang.org/issue/14808
|
||||
return uint64(b[0]) | uint64(b[1])<<8 | uint64(b[2])<<16 | uint64(b[3])<<24 |
|
||||
uint64(b[4])<<32 | uint64(b[5])<<40 | uint64(b[6])<<48 | uint64(b[7])<<56
|
||||
}
|
||||
|
||||
func (littleEndian) PutUint64(b []byte, v uint64) {
|
||||
_ = b[7] // early bounds check to guarantee safety of writes below
|
||||
b[0] = byte(v)
|
||||
b[1] = byte(v >> 8)
|
||||
b[2] = byte(v >> 16)
|
||||
b[3] = byte(v >> 24)
|
||||
b[4] = byte(v >> 32)
|
||||
b[5] = byte(v >> 40)
|
||||
b[6] = byte(v >> 48)
|
||||
b[7] = byte(v >> 56)
|
||||
}
|
||||
|
||||
func (littleEndian) AppendUint64(b []byte, v uint64) []byte {
|
||||
return append(b,
|
||||
byte(v),
|
||||
byte(v>>8),
|
||||
byte(v>>16),
|
||||
byte(v>>24),
|
||||
byte(v>>32),
|
||||
byte(v>>40),
|
||||
byte(v>>48),
|
||||
byte(v>>56),
|
||||
)
|
||||
}
|
||||
|
||||
func (littleEndian) String() string { return "LittleEndian" }
|
||||
|
||||
func (littleEndian) GoString() string { return "binary.LittleEndian" }
|
||||
|
||||
type bigEndian struct{}
|
||||
|
||||
func (bigEndian) Uint16(b []byte) uint16 {
|
||||
_ = b[1] // bounds check hint to compiler; see golang.org/issue/14808
|
||||
return uint16(b[1]) | uint16(b[0])<<8
|
||||
}
|
||||
|
||||
func (bigEndian) PutUint16(b []byte, v uint16) {
|
||||
_ = b[1] // early bounds check to guarantee safety of writes below
|
||||
b[0] = byte(v >> 8)
|
||||
b[1] = byte(v)
|
||||
}
|
||||
|
||||
func (bigEndian) AppendUint16(b []byte, v uint16) []byte {
|
||||
return append(b,
|
||||
byte(v>>8),
|
||||
byte(v),
|
||||
)
|
||||
}
|
||||
|
||||
func (bigEndian) Uint32(b []byte) uint32 {
|
||||
_ = b[3] // bounds check hint to compiler; see golang.org/issue/14808
|
||||
return uint32(b[3]) | uint32(b[2])<<8 | uint32(b[1])<<16 | uint32(b[0])<<24
|
||||
}
|
||||
|
||||
func (bigEndian) PutUint32(b []byte, v uint32) {
|
||||
_ = b[3] // early bounds check to guarantee safety of writes below
|
||||
b[0] = byte(v >> 24)
|
||||
b[1] = byte(v >> 16)
|
||||
b[2] = byte(v >> 8)
|
||||
b[3] = byte(v)
|
||||
}
|
||||
|
||||
func (bigEndian) AppendUint32(b []byte, v uint32) []byte {
|
||||
return append(b,
|
||||
byte(v>>24),
|
||||
byte(v>>16),
|
||||
byte(v>>8),
|
||||
byte(v),
|
||||
)
|
||||
}
|
||||
|
||||
func (bigEndian) Uint64(b []byte) uint64 {
|
||||
_ = b[7] // bounds check hint to compiler; see golang.org/issue/14808
|
||||
return uint64(b[7]) | uint64(b[6])<<8 | uint64(b[5])<<16 | uint64(b[4])<<24 |
|
||||
uint64(b[3])<<32 | uint64(b[2])<<40 | uint64(b[1])<<48 | uint64(b[0])<<56
|
||||
}
|
||||
|
||||
func (bigEndian) PutUint64(b []byte, v uint64) {
|
||||
_ = b[7] // early bounds check to guarantee safety of writes below
|
||||
b[0] = byte(v >> 56)
|
||||
b[1] = byte(v >> 48)
|
||||
b[2] = byte(v >> 40)
|
||||
b[3] = byte(v >> 32)
|
||||
b[4] = byte(v >> 24)
|
||||
b[5] = byte(v >> 16)
|
||||
b[6] = byte(v >> 8)
|
||||
b[7] = byte(v)
|
||||
}
|
||||
|
||||
func (bigEndian) AppendUint64(b []byte, v uint64) []byte {
|
||||
return append(b,
|
||||
byte(v>>56),
|
||||
byte(v>>48),
|
||||
byte(v>>40),
|
||||
byte(v>>32),
|
||||
byte(v>>24),
|
||||
byte(v>>16),
|
||||
byte(v>>8),
|
||||
byte(v),
|
||||
)
|
||||
}
|
||||
|
||||
func (bigEndian) String() string { return "BigEndian" }
|
||||
|
||||
func (bigEndian) GoString() string { return "binary.BigEndian" }
|
||||
|
||||
func (nativeEndian) String() string { return "NativeEndian" }
|
||||
|
||||
func (nativeEndian) GoString() string { return "binary.NativeEndian" }
|
||||
|
||||
// Read reads structured binary data from r into data.
|
||||
// Data must be a pointer to a fixed-size value or a slice
|
||||
// of fixed-size values.
|
||||
// Bytes read from r are decoded using the specified byte order
|
||||
// and written to successive fields of the data.
|
||||
// When decoding boolean values, a zero byte is decoded as false, and
|
||||
// any other non-zero byte is decoded as true.
|
||||
// When reading into structs, the field data for fields with
|
||||
// blank (_) field names is skipped; i.e., blank field names
|
||||
// may be used for padding.
|
||||
// When reading into a struct, all non-blank fields must be exported
|
||||
// or Read may panic.
|
||||
//
|
||||
// The error is [io.EOF] only if no bytes were read.
|
||||
// If an [io.EOF] happens after reading some but not all the bytes,
|
||||
// Read returns [io.ErrUnexpectedEOF].
|
||||
func Read(r io.Reader, order ByteOrder, data any) error {
|
||||
// Fast path for basic types and slices.
|
||||
if n := intDataSize(data); n != 0 {
|
||||
bs := make([]byte, n)
|
||||
if _, err := io.ReadFull(r, bs); err != nil {
|
||||
return err
|
||||
}
|
||||
switch data := data.(type) {
|
||||
case *bool:
|
||||
*data = bs[0] != 0
|
||||
case *int8:
|
||||
*data = int8(bs[0])
|
||||
case *uint8:
|
||||
*data = bs[0]
|
||||
case *int16:
|
||||
*data = int16(order.Uint16(bs))
|
||||
case *uint16:
|
||||
*data = order.Uint16(bs)
|
||||
case *int32:
|
||||
*data = int32(order.Uint32(bs))
|
||||
case *uint32:
|
||||
*data = order.Uint32(bs)
|
||||
case *int64:
|
||||
*data = int64(order.Uint64(bs))
|
||||
case *uint64:
|
||||
*data = order.Uint64(bs)
|
||||
case *float32:
|
||||
*data = math.Float32frombits(order.Uint32(bs))
|
||||
case *float64:
|
||||
*data = math.Float64frombits(order.Uint64(bs))
|
||||
case []bool:
|
||||
for i, x := range bs { // Easier to loop over the input for 8-bit values.
|
||||
data[i] = x != 0
|
||||
}
|
||||
case []int8:
|
||||
for i, x := range bs {
|
||||
data[i] = int8(x)
|
||||
}
|
||||
case []uint8:
|
||||
copy(data, bs)
|
||||
case []int16:
|
||||
for i := range data {
|
||||
data[i] = int16(order.Uint16(bs[2*i:]))
|
||||
}
|
||||
case []uint16:
|
||||
for i := range data {
|
||||
data[i] = order.Uint16(bs[2*i:])
|
||||
}
|
||||
case []int32:
|
||||
for i := range data {
|
||||
data[i] = int32(order.Uint32(bs[4*i:]))
|
||||
}
|
||||
case []uint32:
|
||||
for i := range data {
|
||||
data[i] = order.Uint32(bs[4*i:])
|
||||
}
|
||||
case []int64:
|
||||
for i := range data {
|
||||
data[i] = int64(order.Uint64(bs[8*i:]))
|
||||
}
|
||||
case []uint64:
|
||||
for i := range data {
|
||||
data[i] = order.Uint64(bs[8*i:])
|
||||
}
|
||||
case []float32:
|
||||
for i := range data {
|
||||
data[i] = math.Float32frombits(order.Uint32(bs[4*i:]))
|
||||
}
|
||||
case []float64:
|
||||
for i := range data {
|
||||
data[i] = math.Float64frombits(order.Uint64(bs[8*i:]))
|
||||
}
|
||||
default:
|
||||
n = 0 // fast path doesn't apply
|
||||
}
|
||||
if n != 0 {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to reflect-based decoding.
|
||||
v := reflect.ValueOf(data)
|
||||
size := -1
|
||||
switch v.Kind() {
|
||||
case reflect.Pointer:
|
||||
v = v.Elem()
|
||||
size = dataSize(v)
|
||||
case reflect.Slice:
|
||||
size = dataSize(v)
|
||||
}
|
||||
if size < 0 {
|
||||
return errors.New("binary.Read: invalid type " + reflect.TypeOf(data).String())
|
||||
}
|
||||
d := &decoder{order: order, buf: make([]byte, size)}
|
||||
if _, err := io.ReadFull(r, d.buf); err != nil {
|
||||
return err
|
||||
}
|
||||
d.value(v)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Write writes the binary representation of data into w.
|
||||
// Data must be a fixed-size value or a slice of fixed-size
|
||||
// values, or a pointer to such data.
|
||||
// Boolean values encode as one byte: 1 for true, and 0 for false.
|
||||
// Bytes written to w are encoded using the specified byte order
|
||||
// and read from successive fields of the data.
|
||||
// When writing structs, zero values are written for fields
|
||||
// with blank (_) field names.
|
||||
func Write(w io.Writer, order ByteOrder, data any) error {
|
||||
// Fast path for basic types and slices.
|
||||
if n := intDataSize(data); n != 0 {
|
||||
bs := make([]byte, n)
|
||||
switch v := data.(type) {
|
||||
case *bool:
|
||||
if *v {
|
||||
bs[0] = 1
|
||||
} else {
|
||||
bs[0] = 0
|
||||
}
|
||||
case bool:
|
||||
if v {
|
||||
bs[0] = 1
|
||||
} else {
|
||||
bs[0] = 0
|
||||
}
|
||||
case []bool:
|
||||
for i, x := range v {
|
||||
if x {
|
||||
bs[i] = 1
|
||||
} else {
|
||||
bs[i] = 0
|
||||
}
|
||||
}
|
||||
case *int8:
|
||||
bs[0] = byte(*v)
|
||||
case int8:
|
||||
bs[0] = byte(v)
|
||||
case []int8:
|
||||
for i, x := range v {
|
||||
bs[i] = byte(x)
|
||||
}
|
||||
case *uint8:
|
||||
bs[0] = *v
|
||||
case uint8:
|
||||
bs[0] = v
|
||||
case []uint8:
|
||||
bs = v
|
||||
case *int16:
|
||||
order.PutUint16(bs, uint16(*v))
|
||||
case int16:
|
||||
order.PutUint16(bs, uint16(v))
|
||||
case []int16:
|
||||
for i, x := range v {
|
||||
order.PutUint16(bs[2*i:], uint16(x))
|
||||
}
|
||||
case *uint16:
|
||||
order.PutUint16(bs, *v)
|
||||
case uint16:
|
||||
order.PutUint16(bs, v)
|
||||
case []uint16:
|
||||
for i, x := range v {
|
||||
order.PutUint16(bs[2*i:], x)
|
||||
}
|
||||
case *int32:
|
||||
order.PutUint32(bs, uint32(*v))
|
||||
case int32:
|
||||
order.PutUint32(bs, uint32(v))
|
||||
case []int32:
|
||||
for i, x := range v {
|
||||
order.PutUint32(bs[4*i:], uint32(x))
|
||||
}
|
||||
case *uint32:
|
||||
order.PutUint32(bs, *v)
|
||||
case uint32:
|
||||
order.PutUint32(bs, v)
|
||||
case []uint32:
|
||||
for i, x := range v {
|
||||
order.PutUint32(bs[4*i:], x)
|
||||
}
|
||||
case *int64:
|
||||
order.PutUint64(bs, uint64(*v))
|
||||
case int64:
|
||||
order.PutUint64(bs, uint64(v))
|
||||
case []int64:
|
||||
for i, x := range v {
|
||||
order.PutUint64(bs[8*i:], uint64(x))
|
||||
}
|
||||
case *uint64:
|
||||
order.PutUint64(bs, *v)
|
||||
case uint64:
|
||||
order.PutUint64(bs, v)
|
||||
case []uint64:
|
||||
for i, x := range v {
|
||||
order.PutUint64(bs[8*i:], x)
|
||||
}
|
||||
case *float32:
|
||||
order.PutUint32(bs, math.Float32bits(*v))
|
||||
case float32:
|
||||
order.PutUint32(bs, math.Float32bits(v))
|
||||
case []float32:
|
||||
for i, x := range v {
|
||||
order.PutUint32(bs[4*i:], math.Float32bits(x))
|
||||
}
|
||||
case *float64:
|
||||
order.PutUint64(bs, math.Float64bits(*v))
|
||||
case float64:
|
||||
order.PutUint64(bs, math.Float64bits(v))
|
||||
case []float64:
|
||||
for i, x := range v {
|
||||
order.PutUint64(bs[8*i:], math.Float64bits(x))
|
||||
}
|
||||
}
|
||||
_, err := w.Write(bs)
|
||||
return err
|
||||
}
|
||||
|
||||
// Fallback to reflect-based encoding.
|
||||
v := reflect.Indirect(reflect.ValueOf(data))
|
||||
size := dataSize(v)
|
||||
if size < 0 {
|
||||
return errors.New("binary.Write: some values are not fixed-sized in type " + reflect.TypeOf(data).String())
|
||||
}
|
||||
buf := make([]byte, size)
|
||||
e := &encoder{order: order, buf: buf}
|
||||
e.value(v)
|
||||
_, err := w.Write(buf)
|
||||
return err
|
||||
}
|
||||
|
||||
// Size returns how many bytes [Write] would generate to encode the value v, which
|
||||
// must be a fixed-size value or a slice of fixed-size values, or a pointer to such data.
|
||||
// If v is neither of these, Size returns -1.
|
||||
func Size(v any) int {
|
||||
return dataSize(reflect.Indirect(reflect.ValueOf(v)))
|
||||
}
|
||||
|
||||
var structSize sync.Map // map[reflect.Type]int
|
||||
|
||||
// dataSize returns the number of bytes the actual data represented by v occupies in memory.
|
||||
// For compound structures, it sums the sizes of the elements. Thus, for instance, for a slice
|
||||
// it returns the length of the slice times the element size and does not count the memory
|
||||
// occupied by the header. If the type of v is not acceptable, dataSize returns -1.
|
||||
func dataSize(v reflect.Value) int {
|
||||
switch v.Kind() {
|
||||
case reflect.Slice:
|
||||
if s := sizeof(v.Type().Elem()); s >= 0 {
|
||||
return s * v.Len()
|
||||
}
|
||||
|
||||
case reflect.Struct:
|
||||
t := v.Type()
|
||||
if size, ok := structSize.Load(t); ok {
|
||||
return size.(int)
|
||||
}
|
||||
size := sizeof(t)
|
||||
structSize.Store(t, size)
|
||||
return size
|
||||
|
||||
default:
|
||||
if v.IsValid() {
|
||||
return sizeof(v.Type())
|
||||
}
|
||||
}
|
||||
|
||||
return -1
|
||||
}
|
||||
|
||||
// sizeof returns the size >= 0 of variables for the given type or -1 if the type is not acceptable.
|
||||
func sizeof(t reflect.Type) int {
|
||||
switch t.Kind() {
|
||||
case reflect.Array:
|
||||
if s := sizeof(t.Elem()); s >= 0 {
|
||||
return s * t.Len()
|
||||
}
|
||||
|
||||
case reflect.Struct:
|
||||
sum := 0
|
||||
for i, n := 0, t.NumField(); i < n; i++ {
|
||||
s := sizeof(t.Field(i).Type)
|
||||
if s < 0 {
|
||||
return -1
|
||||
}
|
||||
sum += s
|
||||
}
|
||||
return sum
|
||||
|
||||
case reflect.Bool,
|
||||
reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
|
||||
reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
|
||||
reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128:
|
||||
return int(t.Size())
|
||||
}
|
||||
|
||||
return -1
|
||||
}
|
||||
|
||||
type coder struct {
|
||||
order ByteOrder
|
||||
buf []byte
|
||||
offset int
|
||||
}
|
||||
|
||||
type (
|
||||
decoder coder
|
||||
encoder coder
|
||||
)
|
||||
|
||||
func (d *decoder) bool() bool {
|
||||
x := d.buf[d.offset]
|
||||
d.offset++
|
||||
return x != 0
|
||||
}
|
||||
|
||||
func (e *encoder) bool(x bool) {
|
||||
if x {
|
||||
e.buf[e.offset] = 1
|
||||
} else {
|
||||
e.buf[e.offset] = 0
|
||||
}
|
||||
e.offset++
|
||||
}
|
||||
|
||||
func (d *decoder) uint8() uint8 {
|
||||
x := d.buf[d.offset]
|
||||
d.offset++
|
||||
return x
|
||||
}
|
||||
|
||||
func (e *encoder) uint8(x uint8) {
|
||||
e.buf[e.offset] = x
|
||||
e.offset++
|
||||
}
|
||||
|
||||
func (d *decoder) uint16() uint16 {
|
||||
x := d.order.Uint16(d.buf[d.offset : d.offset+2])
|
||||
d.offset += 2
|
||||
return x
|
||||
}
|
||||
|
||||
func (e *encoder) uint16(x uint16) {
|
||||
e.order.PutUint16(e.buf[e.offset:e.offset+2], x)
|
||||
e.offset += 2
|
||||
}
|
||||
|
||||
func (d *decoder) uint32() uint32 {
|
||||
x := d.order.Uint32(d.buf[d.offset : d.offset+4])
|
||||
d.offset += 4
|
||||
return x
|
||||
}
|
||||
|
||||
func (e *encoder) uint32(x uint32) {
|
||||
e.order.PutUint32(e.buf[e.offset:e.offset+4], x)
|
||||
e.offset += 4
|
||||
}
|
||||
|
||||
func (d *decoder) uint64() uint64 {
|
||||
x := d.order.Uint64(d.buf[d.offset : d.offset+8])
|
||||
d.offset += 8
|
||||
return x
|
||||
}
|
||||
|
||||
func (e *encoder) uint64(x uint64) {
|
||||
e.order.PutUint64(e.buf[e.offset:e.offset+8], x)
|
||||
e.offset += 8
|
||||
}
|
||||
|
||||
func (d *decoder) int8() int8 { return int8(d.uint8()) }
|
||||
|
||||
func (e *encoder) int8(x int8) { e.uint8(uint8(x)) }
|
||||
|
||||
func (d *decoder) int16() int16 { return int16(d.uint16()) }
|
||||
|
||||
func (e *encoder) int16(x int16) { e.uint16(uint16(x)) }
|
||||
|
||||
func (d *decoder) int32() int32 { return int32(d.uint32()) }
|
||||
|
||||
func (e *encoder) int32(x int32) { e.uint32(uint32(x)) }
|
||||
|
||||
func (d *decoder) int64() int64 { return int64(d.uint64()) }
|
||||
|
||||
func (e *encoder) int64(x int64) { e.uint64(uint64(x)) }
|
||||
|
||||
func (d *decoder) value(v reflect.Value) {
|
||||
switch v.Kind() {
|
||||
case reflect.Array:
|
||||
l := v.Len()
|
||||
for i := 0; i < l; i++ {
|
||||
d.value(v.Index(i))
|
||||
}
|
||||
|
||||
case reflect.Struct:
|
||||
t := v.Type()
|
||||
l := v.NumField()
|
||||
for i := 0; i < l; i++ {
|
||||
// Note: Calling v.CanSet() below is an optimization.
|
||||
// It would be sufficient to check the field name,
|
||||
// but creating the StructField info for each field is
|
||||
// costly (run "go test -bench=ReadStruct" and compare
|
||||
// results when making changes to this code).
|
||||
if v := v.Field(i); v.CanSet() || t.Field(i).Name != "_" {
|
||||
d.value(v)
|
||||
} else {
|
||||
d.skip(v)
|
||||
}
|
||||
}
|
||||
|
||||
case reflect.Slice:
|
||||
l := v.Len()
|
||||
for i := 0; i < l; i++ {
|
||||
d.value(v.Index(i))
|
||||
}
|
||||
|
||||
case reflect.Bool:
|
||||
v.SetBool(d.bool())
|
||||
|
||||
case reflect.Int8:
|
||||
v.SetInt(int64(d.int8()))
|
||||
case reflect.Int16:
|
||||
v.SetInt(int64(d.int16()))
|
||||
case reflect.Int32:
|
||||
v.SetInt(int64(d.int32()))
|
||||
case reflect.Int64:
|
||||
v.SetInt(d.int64())
|
||||
|
||||
case reflect.Uint8:
|
||||
v.SetUint(uint64(d.uint8()))
|
||||
case reflect.Uint16:
|
||||
v.SetUint(uint64(d.uint16()))
|
||||
case reflect.Uint32:
|
||||
v.SetUint(uint64(d.uint32()))
|
||||
case reflect.Uint64:
|
||||
v.SetUint(d.uint64())
|
||||
|
||||
case reflect.Float32:
|
||||
v.SetFloat(float64(math.Float32frombits(d.uint32())))
|
||||
case reflect.Float64:
|
||||
v.SetFloat(math.Float64frombits(d.uint64()))
|
||||
|
||||
case reflect.Complex64:
|
||||
v.SetComplex(complex(
|
||||
float64(math.Float32frombits(d.uint32())),
|
||||
float64(math.Float32frombits(d.uint32())),
|
||||
))
|
||||
case reflect.Complex128:
|
||||
v.SetComplex(complex(
|
||||
math.Float64frombits(d.uint64()),
|
||||
math.Float64frombits(d.uint64()),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
func (e *encoder) value(v reflect.Value) {
|
||||
switch v.Kind() {
|
||||
case reflect.Array:
|
||||
l := v.Len()
|
||||
for i := 0; i < l; i++ {
|
||||
e.value(v.Index(i))
|
||||
}
|
||||
|
||||
case reflect.Struct:
|
||||
t := v.Type()
|
||||
l := v.NumField()
|
||||
for i := 0; i < l; i++ {
|
||||
// see comment for corresponding code in decoder.value()
|
||||
if v := v.Field(i); v.CanSet() || t.Field(i).Name != "_" {
|
||||
e.value(v)
|
||||
} else {
|
||||
e.skip(v)
|
||||
}
|
||||
}
|
||||
|
||||
case reflect.Slice:
|
||||
l := v.Len()
|
||||
for i := 0; i < l; i++ {
|
||||
e.value(v.Index(i))
|
||||
}
|
||||
|
||||
case reflect.Bool:
|
||||
e.bool(v.Bool())
|
||||
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
switch v.Type().Kind() {
|
||||
case reflect.Int8:
|
||||
e.int8(int8(v.Int()))
|
||||
case reflect.Int16:
|
||||
e.int16(int16(v.Int()))
|
||||
case reflect.Int32:
|
||||
e.int32(int32(v.Int()))
|
||||
case reflect.Int64:
|
||||
e.int64(v.Int())
|
||||
}
|
||||
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
|
||||
switch v.Type().Kind() {
|
||||
case reflect.Uint8:
|
||||
e.uint8(uint8(v.Uint()))
|
||||
case reflect.Uint16:
|
||||
e.uint16(uint16(v.Uint()))
|
||||
case reflect.Uint32:
|
||||
e.uint32(uint32(v.Uint()))
|
||||
case reflect.Uint64:
|
||||
e.uint64(v.Uint())
|
||||
}
|
||||
|
||||
case reflect.Float32, reflect.Float64:
|
||||
switch v.Type().Kind() {
|
||||
case reflect.Float32:
|
||||
e.uint32(math.Float32bits(float32(v.Float())))
|
||||
case reflect.Float64:
|
||||
e.uint64(math.Float64bits(v.Float()))
|
||||
}
|
||||
|
||||
case reflect.Complex64, reflect.Complex128:
|
||||
switch v.Type().Kind() {
|
||||
case reflect.Complex64:
|
||||
x := v.Complex()
|
||||
e.uint32(math.Float32bits(float32(real(x))))
|
||||
e.uint32(math.Float32bits(float32(imag(x))))
|
||||
case reflect.Complex128:
|
||||
x := v.Complex()
|
||||
e.uint64(math.Float64bits(real(x)))
|
||||
e.uint64(math.Float64bits(imag(x)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (d *decoder) skip(v reflect.Value) {
|
||||
d.offset += dataSize(v)
|
||||
}
|
||||
|
||||
func (e *encoder) skip(v reflect.Value) {
|
||||
n := dataSize(v)
|
||||
zero := e.buf[e.offset : e.offset+n]
|
||||
for i := range zero {
|
||||
zero[i] = 0
|
||||
}
|
||||
e.offset += n
|
||||
}
|
||||
|
||||
// intDataSize returns the size of the data required to represent the data when encoded.
|
||||
// It returns zero if the type cannot be implemented by the fast path in Read or Write.
|
||||
func intDataSize(data any) int {
|
||||
switch data := data.(type) {
|
||||
case bool, int8, uint8, *bool, *int8, *uint8:
|
||||
return 1
|
||||
case []bool:
|
||||
return len(data)
|
||||
case []int8:
|
||||
return len(data)
|
||||
case []uint8:
|
||||
return len(data)
|
||||
case int16, uint16, *int16, *uint16:
|
||||
return 2
|
||||
case []int16:
|
||||
return 2 * len(data)
|
||||
case []uint16:
|
||||
return 2 * len(data)
|
||||
case int32, uint32, *int32, *uint32:
|
||||
return 4
|
||||
case []int32:
|
||||
return 4 * len(data)
|
||||
case []uint32:
|
||||
return 4 * len(data)
|
||||
case int64, uint64, *int64, *uint64:
|
||||
return 8
|
||||
case []int64:
|
||||
return 8 * len(data)
|
||||
case []uint64:
|
||||
return 8 * len(data)
|
||||
case float32, *float32:
|
||||
return 4
|
||||
case float64, *float64:
|
||||
return 8
|
||||
case []float32:
|
||||
return 4 * len(data)
|
||||
case []float64:
|
||||
return 8 * len(data)
|
||||
}
|
||||
return 0
|
||||
}
|
14
common/binary/native_endian_big.go
Normal file
14
common/binary/native_endian_big.go
Normal file
|
@ -0,0 +1,14 @@
|
|||
// Copyright 2023 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
//go:build armbe || arm64be || m68k || mips || mips64 || mips64p32 || ppc || ppc64 || s390 || s390x || shbe || sparc || sparc64
|
||||
|
||||
package binary
|
||||
|
||||
type nativeEndian struct {
|
||||
bigEndian
|
||||
}
|
||||
|
||||
// NativeEndian is the native-endian implementation of [ByteOrder] and [AppendByteOrder].
|
||||
var NativeEndian nativeEndian
|
14
common/binary/native_endian_little.go
Normal file
14
common/binary/native_endian_little.go
Normal file
|
@ -0,0 +1,14 @@
|
|||
// Copyright 2023 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
//go:build 386 || amd64 || amd64p32 || alpha || arm || arm64 || loong64 || mipsle || mips64le || mips64p32le || nios2 || ppc64le || riscv || riscv64 || sh || wasm
|
||||
|
||||
package binary
|
||||
|
||||
type nativeEndian struct {
|
||||
littleEndian
|
||||
}
|
||||
|
||||
// NativeEndian is the native-endian implementation of [ByteOrder] and [AppendByteOrder].
|
||||
var NativeEndian nativeEndian
|
305
common/binary/variant_data.go
Normal file
305
common/binary/variant_data.go
Normal file
|
@ -0,0 +1,305 @@
|
|||
package binary
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"errors"
|
||||
"io"
|
||||
"reflect"
|
||||
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
)
|
||||
|
||||
func ReadDataSlice(r *bufio.Reader, order ByteOrder, data ...any) error {
|
||||
for index, item := range data {
|
||||
err := ReadData(r, order, item)
|
||||
if err != nil {
|
||||
return E.Cause(err, "[", index, "]")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func ReadData(r *bufio.Reader, order ByteOrder, data any) error {
|
||||
switch dataPtr := data.(type) {
|
||||
case *[]uint8:
|
||||
bytesLen, err := ReadUvarint(r)
|
||||
if err != nil {
|
||||
return E.Cause(err, "bytes length")
|
||||
}
|
||||
newBytes := make([]uint8, bytesLen)
|
||||
_, err = io.ReadFull(r, newBytes)
|
||||
if err != nil {
|
||||
return E.Cause(err, "bytes value")
|
||||
}
|
||||
*dataPtr = newBytes
|
||||
default:
|
||||
if intBaseDataSize(data) != 0 {
|
||||
return Read(r, order, data)
|
||||
}
|
||||
}
|
||||
dataValue := reflect.ValueOf(data)
|
||||
if dataValue.Kind() == reflect.Pointer {
|
||||
dataValue = dataValue.Elem()
|
||||
}
|
||||
return readData(r, order, dataValue)
|
||||
}
|
||||
|
||||
func readData(r *bufio.Reader, order ByteOrder, data reflect.Value) error {
|
||||
switch data.Kind() {
|
||||
case reflect.Pointer:
|
||||
pointerValue, err := r.ReadByte()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if pointerValue == 0 {
|
||||
data.SetZero()
|
||||
return nil
|
||||
}
|
||||
if data.IsNil() {
|
||||
data.Set(reflect.New(data.Type().Elem()))
|
||||
}
|
||||
return readData(r, order, data.Elem())
|
||||
case reflect.String:
|
||||
stringLength, err := ReadUvarint(r)
|
||||
if err != nil {
|
||||
return E.Cause(err, "string length")
|
||||
}
|
||||
if stringLength == 0 {
|
||||
data.SetZero()
|
||||
} else {
|
||||
stringData := make([]byte, stringLength)
|
||||
_, err = io.ReadFull(r, stringData)
|
||||
if err != nil {
|
||||
return E.Cause(err, "string value")
|
||||
}
|
||||
data.SetString(string(stringData))
|
||||
}
|
||||
case reflect.Array:
|
||||
arrayLen := data.Len()
|
||||
for i := 0; i < arrayLen; i++ {
|
||||
err := readData(r, order, data.Index(i))
|
||||
if err != nil {
|
||||
return E.Cause(err, "[", i, "]")
|
||||
}
|
||||
}
|
||||
case reflect.Slice:
|
||||
sliceLength, err := ReadUvarint(r)
|
||||
if err != nil {
|
||||
return E.Cause(err, "slice length")
|
||||
}
|
||||
if !data.IsNil() && data.Cap() >= int(sliceLength) {
|
||||
data.SetLen(int(sliceLength))
|
||||
} else if sliceLength > 0 {
|
||||
data.Set(reflect.MakeSlice(data.Type(), int(sliceLength), int(sliceLength)))
|
||||
}
|
||||
if sliceLength > 0 {
|
||||
if data.Type().Elem().Kind() == reflect.Uint8 {
|
||||
_, err = io.ReadFull(r, data.Bytes())
|
||||
if err != nil {
|
||||
return E.Cause(err, "bytes value")
|
||||
}
|
||||
} else {
|
||||
for index := 0; index < int(sliceLength); index++ {
|
||||
err = readData(r, order, data.Index(index))
|
||||
if err != nil {
|
||||
return E.Cause(err, "[", index, "]")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
case reflect.Map:
|
||||
mapLength, err := ReadUvarint(r)
|
||||
if err != nil {
|
||||
return E.Cause(err, "map length")
|
||||
}
|
||||
data.Set(reflect.MakeMap(data.Type()))
|
||||
for index := 0; index < int(mapLength); index++ {
|
||||
key := reflect.New(data.Type().Key()).Elem()
|
||||
err = readData(r, order, key)
|
||||
if err != nil {
|
||||
return E.Cause(err, "[", index, "].key")
|
||||
}
|
||||
value := reflect.New(data.Type().Elem()).Elem()
|
||||
err = readData(r, order, value)
|
||||
if err != nil {
|
||||
return E.Cause(err, "[", index, "].value")
|
||||
}
|
||||
data.SetMapIndex(key, value)
|
||||
}
|
||||
case reflect.Struct:
|
||||
fieldType := data.Type()
|
||||
fieldLen := data.NumField()
|
||||
for i := 0; i < fieldLen; i++ {
|
||||
field := data.Field(i)
|
||||
fieldName := fieldType.Field(i).Name
|
||||
if field.CanSet() || fieldName != "_" {
|
||||
err := readData(r, order, field)
|
||||
if err != nil {
|
||||
return E.Cause(err, fieldName)
|
||||
}
|
||||
}
|
||||
}
|
||||
default:
|
||||
size := dataSize(data)
|
||||
if size < 0 {
|
||||
return errors.New("invalid type " + reflect.TypeOf(data).String())
|
||||
}
|
||||
d := &decoder{order: order, buf: make([]byte, size)}
|
||||
_, err := io.ReadFull(r, d.buf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
d.value(data)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func WriteDataSlice(writer *bufio.Writer, order ByteOrder, data ...any) error {
|
||||
for index, item := range data {
|
||||
err := WriteData(writer, order, item)
|
||||
if err != nil {
|
||||
return E.Cause(err, "[", index, "]")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func WriteData(writer *bufio.Writer, order ByteOrder, data any) error {
|
||||
switch dataPtr := data.(type) {
|
||||
case []uint8:
|
||||
_, err := writer.Write(AppendUvarint(writer.AvailableBuffer(), uint64(len(dataPtr))))
|
||||
if err != nil {
|
||||
return E.Cause(err, "bytes length")
|
||||
}
|
||||
_, err = writer.Write(dataPtr)
|
||||
if err != nil {
|
||||
return E.Cause(err, "bytes value")
|
||||
}
|
||||
default:
|
||||
if intBaseDataSize(data) != 0 {
|
||||
return Write(writer, order, data)
|
||||
}
|
||||
}
|
||||
return writeData(writer, order, reflect.Indirect(reflect.ValueOf(data)))
|
||||
}
|
||||
|
||||
func writeData(writer *bufio.Writer, order ByteOrder, data reflect.Value) error {
|
||||
switch data.Kind() {
|
||||
case reflect.Pointer:
|
||||
if data.IsNil() {
|
||||
err := writer.WriteByte(0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
err := writer.WriteByte(1)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return writeData(writer, order, data.Elem())
|
||||
}
|
||||
case reflect.String:
|
||||
stringValue := data.String()
|
||||
_, err := writer.Write(AppendUvarint(writer.AvailableBuffer(), uint64(len(stringValue))))
|
||||
if err != nil {
|
||||
return E.Cause(err, "string length")
|
||||
}
|
||||
if stringValue != "" {
|
||||
_, err = writer.WriteString(stringValue)
|
||||
if err != nil {
|
||||
return E.Cause(err, "string value")
|
||||
}
|
||||
}
|
||||
case reflect.Array:
|
||||
dataLen := data.Len()
|
||||
for i := 0; i < dataLen; i++ {
|
||||
err := writeData(writer, order, data.Index(i))
|
||||
if err != nil {
|
||||
return E.Cause(err, "[", i, "]")
|
||||
}
|
||||
}
|
||||
case reflect.Slice:
|
||||
dataLen := data.Len()
|
||||
_, err := writer.Write(AppendUvarint(writer.AvailableBuffer(), uint64(dataLen)))
|
||||
if err != nil {
|
||||
return E.Cause(err, "slice length")
|
||||
}
|
||||
if dataLen > 0 {
|
||||
if data.Type().Elem().Kind() == reflect.Uint8 {
|
||||
_, err = writer.Write(data.Bytes())
|
||||
if err != nil {
|
||||
return E.Cause(err, "bytes value")
|
||||
}
|
||||
} else {
|
||||
for i := 0; i < dataLen; i++ {
|
||||
err = writeData(writer, order, data.Index(i))
|
||||
if err != nil {
|
||||
return E.Cause(err, "[", i, "]")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
case reflect.Map:
|
||||
dataLen := data.Len()
|
||||
_, err := writer.Write(AppendUvarint(writer.AvailableBuffer(), uint64(dataLen)))
|
||||
if err != nil {
|
||||
return E.Cause(err, "map length")
|
||||
}
|
||||
if dataLen > 0 {
|
||||
for index, key := range data.MapKeys() {
|
||||
err = writeData(writer, order, key)
|
||||
if err != nil {
|
||||
return E.Cause(err, "[", index, "].key")
|
||||
}
|
||||
err = writeData(writer, order, data.MapIndex(key))
|
||||
if err != nil {
|
||||
return E.Cause(err, "[", index, "].value")
|
||||
}
|
||||
}
|
||||
}
|
||||
case reflect.Struct:
|
||||
fieldType := data.Type()
|
||||
fieldLen := data.NumField()
|
||||
for i := 0; i < fieldLen; i++ {
|
||||
field := data.Field(i)
|
||||
fieldName := fieldType.Field(i).Name
|
||||
if field.CanSet() || fieldName != "_" {
|
||||
err := writeData(writer, order, field)
|
||||
if err != nil {
|
||||
return E.Cause(err, fieldName)
|
||||
}
|
||||
}
|
||||
}
|
||||
default:
|
||||
size := dataSize(data)
|
||||
if size < 0 {
|
||||
return errors.New("binary.Write: some values are not fixed-sized in type " + data.Type().String())
|
||||
}
|
||||
buf := make([]byte, size)
|
||||
e := &encoder{order: order, buf: buf}
|
||||
e.value(data)
|
||||
_, err := writer.Write(buf)
|
||||
if err != nil {
|
||||
return E.Cause(err, reflect.TypeOf(data).String())
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func intBaseDataSize(data any) int {
|
||||
switch data.(type) {
|
||||
case bool, int8, uint8:
|
||||
return 1
|
||||
case int16, uint16:
|
||||
return 2
|
||||
case int32, uint32:
|
||||
return 4
|
||||
case int64, uint64:
|
||||
return 8
|
||||
case float32:
|
||||
return 4
|
||||
case float64:
|
||||
return 8
|
||||
}
|
||||
return 0
|
||||
}
|
166
common/binary/varint.go
Normal file
166
common/binary/varint.go
Normal file
|
@ -0,0 +1,166 @@
|
|||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package binary
|
||||
|
||||
// This file implements "varint" encoding of 64-bit integers.
|
||||
// The encoding is:
|
||||
// - unsigned integers are serialized 7 bits at a time, starting with the
|
||||
// least significant bits
|
||||
// - the most significant bit (msb) in each output byte indicates if there
|
||||
// is a continuation byte (msb = 1)
|
||||
// - signed integers are mapped to unsigned integers using "zig-zag"
|
||||
// encoding: Positive values x are written as 2*x + 0, negative values
|
||||
// are written as 2*(^x) + 1; that is, negative numbers are complemented
|
||||
// and whether to complement is encoded in bit 0.
|
||||
//
|
||||
// Design note:
|
||||
// At most 10 bytes are needed for 64-bit values. The encoding could
|
||||
// be more dense: a full 64-bit value needs an extra byte just to hold bit 63.
|
||||
// Instead, the msb of the previous byte could be used to hold bit 63 since we
|
||||
// know there can't be more than 64 bits. This is a trivial improvement and
|
||||
// would reduce the maximum encoding length to 9 bytes. However, it breaks the
|
||||
// invariant that the msb is always the "continuation bit" and thus makes the
|
||||
// format incompatible with a varint encoding for larger numbers (say 128-bit).
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
)
|
||||
|
||||
// MaxVarintLenN is the maximum length of a varint-encoded N-bit integer.
|
||||
const (
|
||||
MaxVarintLen16 = 3
|
||||
MaxVarintLen32 = 5
|
||||
MaxVarintLen64 = 10
|
||||
)
|
||||
|
||||
// AppendUvarint appends the varint-encoded form of x,
|
||||
// as generated by [PutUvarint], to buf and returns the extended buffer.
|
||||
func AppendUvarint(buf []byte, x uint64) []byte {
|
||||
for x >= 0x80 {
|
||||
buf = append(buf, byte(x)|0x80)
|
||||
x >>= 7
|
||||
}
|
||||
return append(buf, byte(x))
|
||||
}
|
||||
|
||||
// PutUvarint encodes a uint64 into buf and returns the number of bytes written.
|
||||
// If the buffer is too small, PutUvarint will panic.
|
||||
func PutUvarint(buf []byte, x uint64) int {
|
||||
i := 0
|
||||
for x >= 0x80 {
|
||||
buf[i] = byte(x) | 0x80
|
||||
x >>= 7
|
||||
i++
|
||||
}
|
||||
buf[i] = byte(x)
|
||||
return i + 1
|
||||
}
|
||||
|
||||
// Uvarint decodes a uint64 from buf and returns that value and the
|
||||
// number of bytes read (> 0). If an error occurred, the value is 0
|
||||
// and the number of bytes n is <= 0 meaning:
|
||||
//
|
||||
// n == 0: buf too small
|
||||
// n < 0: value larger than 64 bits (overflow)
|
||||
// and -n is the number of bytes read
|
||||
func Uvarint(buf []byte) (uint64, int) {
|
||||
var x uint64
|
||||
var s uint
|
||||
for i, b := range buf {
|
||||
if i == MaxVarintLen64 {
|
||||
// Catch byte reads past MaxVarintLen64.
|
||||
// See issue https://golang.org/issues/41185
|
||||
return 0, -(i + 1) // overflow
|
||||
}
|
||||
if b < 0x80 {
|
||||
if i == MaxVarintLen64-1 && b > 1 {
|
||||
return 0, -(i + 1) // overflow
|
||||
}
|
||||
return x | uint64(b)<<s, i + 1
|
||||
}
|
||||
x |= uint64(b&0x7f) << s
|
||||
s += 7
|
||||
}
|
||||
return 0, 0
|
||||
}
|
||||
|
||||
// AppendVarint appends the varint-encoded form of x,
|
||||
// as generated by [PutVarint], to buf and returns the extended buffer.
|
||||
func AppendVarint(buf []byte, x int64) []byte {
|
||||
ux := uint64(x) << 1
|
||||
if x < 0 {
|
||||
ux = ^ux
|
||||
}
|
||||
return AppendUvarint(buf, ux)
|
||||
}
|
||||
|
||||
// PutVarint encodes an int64 into buf and returns the number of bytes written.
|
||||
// If the buffer is too small, PutVarint will panic.
|
||||
func PutVarint(buf []byte, x int64) int {
|
||||
ux := uint64(x) << 1
|
||||
if x < 0 {
|
||||
ux = ^ux
|
||||
}
|
||||
return PutUvarint(buf, ux)
|
||||
}
|
||||
|
||||
// Varint decodes an int64 from buf and returns that value and the
|
||||
// number of bytes read (> 0). If an error occurred, the value is 0
|
||||
// and the number of bytes n is <= 0 with the following meaning:
|
||||
//
|
||||
// n == 0: buf too small
|
||||
// n < 0: value larger than 64 bits (overflow)
|
||||
// and -n is the number of bytes read
|
||||
func Varint(buf []byte) (int64, int) {
|
||||
ux, n := Uvarint(buf) // ok to continue in presence of error
|
||||
x := int64(ux >> 1)
|
||||
if ux&1 != 0 {
|
||||
x = ^x
|
||||
}
|
||||
return x, n
|
||||
}
|
||||
|
||||
var errOverflow = errors.New("binary: varint overflows a 64-bit integer")
|
||||
|
||||
// ReadUvarint reads an encoded unsigned integer from r and returns it as a uint64.
|
||||
// The error is [io.EOF] only if no bytes were read.
|
||||
// If an [io.EOF] happens after reading some but not all the bytes,
|
||||
// ReadUvarint returns [io.ErrUnexpectedEOF].
|
||||
func ReadUvarint(r io.ByteReader) (uint64, error) {
|
||||
var x uint64
|
||||
var s uint
|
||||
for i := 0; i < MaxVarintLen64; i++ {
|
||||
b, err := r.ReadByte()
|
||||
if err != nil {
|
||||
if i > 0 && err == io.EOF {
|
||||
err = io.ErrUnexpectedEOF
|
||||
}
|
||||
return x, err
|
||||
}
|
||||
if b < 0x80 {
|
||||
if i == MaxVarintLen64-1 && b > 1 {
|
||||
return x, errOverflow
|
||||
}
|
||||
return x | uint64(b)<<s, nil
|
||||
}
|
||||
x |= uint64(b&0x7f) << s
|
||||
s += 7
|
||||
}
|
||||
return x, errOverflow
|
||||
}
|
||||
|
||||
// ReadVarint reads an encoded signed integer from r and returns it as an int64.
|
||||
// The error is [io.EOF] only if no bytes were read.
|
||||
// If an [io.EOF] happens after reading some but not all the bytes,
|
||||
// ReadVarint returns [io.ErrUnexpectedEOF].
|
||||
func ReadVarint(r io.ByteReader) (int64, error) {
|
||||
ux, err := ReadUvarint(r) // ok to continue in presence of error
|
||||
x := int64(ux >> 1)
|
||||
if ux&1 != 0 {
|
||||
x = ^x
|
||||
}
|
||||
return x, err
|
||||
}
|
|
@ -8,7 +8,7 @@ import (
|
|||
"sync"
|
||||
)
|
||||
|
||||
var DefaultAllocator = newDefaultAllocer()
|
||||
var DefaultAllocator = newDefaultAllocator()
|
||||
|
||||
type Allocator interface {
|
||||
Get(size int) []byte
|
||||
|
@ -17,22 +17,28 @@ type Allocator interface {
|
|||
|
||||
// defaultAllocator for incoming frames, optimized to prevent overwriting after zeroing
|
||||
type defaultAllocator struct {
|
||||
buffers []sync.Pool
|
||||
buffers [11]sync.Pool
|
||||
}
|
||||
|
||||
// NewAllocator initiates a []byte allocator for frames less than 65536 bytes,
|
||||
// the waste(memory fragmentation) of space allocation is guaranteed to be
|
||||
// no more than 50%.
|
||||
func newDefaultAllocer() Allocator {
|
||||
alloc := new(defaultAllocator)
|
||||
alloc.buffers = make([]sync.Pool, 17) // 1B -> 64K
|
||||
for k := range alloc.buffers {
|
||||
i := k
|
||||
alloc.buffers[k].New = func() any {
|
||||
return make([]byte, 1<<uint32(i))
|
||||
func newDefaultAllocator() Allocator {
|
||||
return &defaultAllocator{
|
||||
buffers: [...]sync.Pool{ // 64B -> 64K
|
||||
{New: func() any { return new([1 << 6]byte) }},
|
||||
{New: func() any { return new([1 << 7]byte) }},
|
||||
{New: func() any { return new([1 << 8]byte) }},
|
||||
{New: func() any { return new([1 << 9]byte) }},
|
||||
{New: func() any { return new([1 << 10]byte) }},
|
||||
{New: func() any { return new([1 << 11]byte) }},
|
||||
{New: func() any { return new([1 << 12]byte) }},
|
||||
{New: func() any { return new([1 << 13]byte) }},
|
||||
{New: func() any { return new([1 << 14]byte) }},
|
||||
{New: func() any { return new([1 << 15]byte) }},
|
||||
{New: func() any { return new([1 << 16]byte) }},
|
||||
},
|
||||
}
|
||||
}
|
||||
return alloc
|
||||
}
|
||||
|
||||
// Get a []byte from pool with most appropriate cap
|
||||
|
@ -41,12 +47,42 @@ func (alloc *defaultAllocator) Get(size int) []byte {
|
|||
return nil
|
||||
}
|
||||
|
||||
bits := msb(size)
|
||||
if size == 1<<bits {
|
||||
return alloc.buffers[bits].Get().([]byte)[:size]
|
||||
var index uint16
|
||||
if size > 64 {
|
||||
index = msb(size)
|
||||
if size != 1<<index {
|
||||
index += 1
|
||||
}
|
||||
index -= 6
|
||||
}
|
||||
|
||||
return alloc.buffers[bits+1].Get().([]byte)[:size]
|
||||
buffer := alloc.buffers[index].Get()
|
||||
switch index {
|
||||
case 0:
|
||||
return buffer.(*[1 << 6]byte)[:size]
|
||||
case 1:
|
||||
return buffer.(*[1 << 7]byte)[:size]
|
||||
case 2:
|
||||
return buffer.(*[1 << 8]byte)[:size]
|
||||
case 3:
|
||||
return buffer.(*[1 << 9]byte)[:size]
|
||||
case 4:
|
||||
return buffer.(*[1 << 10]byte)[:size]
|
||||
case 5:
|
||||
return buffer.(*[1 << 11]byte)[:size]
|
||||
case 6:
|
||||
return buffer.(*[1 << 12]byte)[:size]
|
||||
case 7:
|
||||
return buffer.(*[1 << 13]byte)[:size]
|
||||
case 8:
|
||||
return buffer.(*[1 << 14]byte)[:size]
|
||||
case 9:
|
||||
return buffer.(*[1 << 15]byte)[:size]
|
||||
case 10:
|
||||
return buffer.(*[1 << 16]byte)[:size]
|
||||
default:
|
||||
panic("invalid pool index")
|
||||
}
|
||||
}
|
||||
|
||||
// Put returns a []byte to pool for future use,
|
||||
|
@ -56,10 +92,37 @@ func (alloc *defaultAllocator) Put(buf []byte) error {
|
|||
if cap(buf) == 0 || cap(buf) > 65536 || cap(buf) != 1<<bits {
|
||||
return errors.New("allocator Put() incorrect buffer size")
|
||||
}
|
||||
bits -= 6
|
||||
buf = buf[:cap(buf)]
|
||||
|
||||
//nolint
|
||||
//lint:ignore SA6002 ignore temporarily
|
||||
alloc.buffers[bits].Put(buf)
|
||||
switch bits {
|
||||
case 0:
|
||||
alloc.buffers[bits].Put((*[1 << 6]byte)(buf))
|
||||
case 1:
|
||||
alloc.buffers[bits].Put((*[1 << 7]byte)(buf))
|
||||
case 2:
|
||||
alloc.buffers[bits].Put((*[1 << 8]byte)(buf))
|
||||
case 3:
|
||||
alloc.buffers[bits].Put((*[1 << 9]byte)(buf))
|
||||
case 4:
|
||||
alloc.buffers[bits].Put((*[1 << 10]byte)(buf))
|
||||
case 5:
|
||||
alloc.buffers[bits].Put((*[1 << 11]byte)(buf))
|
||||
case 6:
|
||||
alloc.buffers[bits].Put((*[1 << 12]byte)(buf))
|
||||
case 7:
|
||||
alloc.buffers[bits].Put((*[1 << 13]byte)(buf))
|
||||
case 8:
|
||||
alloc.buffers[bits].Put((*[1 << 14]byte)(buf))
|
||||
case 9:
|
||||
alloc.buffers[bits].Put((*[1 << 15]byte)(buf))
|
||||
case 10:
|
||||
alloc.buffers[bits].Put((*[1 << 16]byte)(buf))
|
||||
default:
|
||||
panic("invalid pool index")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -4,29 +4,27 @@ import (
|
|||
"crypto/rand"
|
||||
"io"
|
||||
"net"
|
||||
"strconv"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/atomic"
|
||||
"github.com/sagernet/sing/common/debug"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
F "github.com/sagernet/sing/common/format"
|
||||
)
|
||||
|
||||
const ReversedHeader = 1024
|
||||
|
||||
type Buffer struct {
|
||||
data []byte
|
||||
start int
|
||||
end int
|
||||
refs int32
|
||||
capacity int
|
||||
refs atomic.Int32
|
||||
managed bool
|
||||
closed bool
|
||||
}
|
||||
|
||||
func New() *Buffer {
|
||||
return &Buffer{
|
||||
data: Get(BufferSize),
|
||||
start: ReversedHeader,
|
||||
end: ReversedHeader,
|
||||
capacity: BufferSize,
|
||||
managed: true,
|
||||
}
|
||||
}
|
||||
|
@ -34,8 +32,7 @@ func New() *Buffer {
|
|||
func NewPacket() *Buffer {
|
||||
return &Buffer{
|
||||
data: Get(UDPBufferSize),
|
||||
start: ReversedHeader,
|
||||
end: ReversedHeader,
|
||||
capacity: UDPBufferSize,
|
||||
managed: true,
|
||||
}
|
||||
}
|
||||
|
@ -46,61 +43,28 @@ func NewSize(size int) *Buffer {
|
|||
} else if size > 65535 {
|
||||
return &Buffer{
|
||||
data: make([]byte, size),
|
||||
capacity: size,
|
||||
}
|
||||
}
|
||||
return &Buffer{
|
||||
data: Get(size),
|
||||
capacity: size,
|
||||
managed: true,
|
||||
}
|
||||
}
|
||||
|
||||
func StackNew() *Buffer {
|
||||
if common.UnsafeBuffer {
|
||||
return &Buffer{
|
||||
data: make([]byte, BufferSize),
|
||||
start: ReversedHeader,
|
||||
end: ReversedHeader,
|
||||
}
|
||||
} else {
|
||||
return New()
|
||||
}
|
||||
}
|
||||
|
||||
func StackNewPacket() *Buffer {
|
||||
if common.UnsafeBuffer {
|
||||
return &Buffer{
|
||||
data: make([]byte, UDPBufferSize),
|
||||
start: ReversedHeader,
|
||||
end: ReversedHeader,
|
||||
}
|
||||
} else {
|
||||
return NewPacket()
|
||||
}
|
||||
}
|
||||
|
||||
func StackNewSize(size int) *Buffer {
|
||||
if size == 0 {
|
||||
return &Buffer{}
|
||||
}
|
||||
if common.UnsafeBuffer {
|
||||
return &Buffer{
|
||||
data: Make(size),
|
||||
}
|
||||
} else {
|
||||
return NewSize(size)
|
||||
}
|
||||
}
|
||||
|
||||
func As(data []byte) *Buffer {
|
||||
return &Buffer{
|
||||
data: data,
|
||||
end: len(data),
|
||||
capacity: len(data),
|
||||
}
|
||||
}
|
||||
|
||||
func With(data []byte) *Buffer {
|
||||
return &Buffer{
|
||||
data: data,
|
||||
capacity: len(data),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -114,8 +78,8 @@ func (b *Buffer) SetByte(index int, value byte) {
|
|||
|
||||
func (b *Buffer) Extend(n int) []byte {
|
||||
end := b.end + n
|
||||
if end > cap(b.data) {
|
||||
panic("buffer overflow: cap " + strconv.Itoa(cap(b.data)) + ",end " + strconv.Itoa(b.end) + ", need " + strconv.Itoa(n))
|
||||
if end > b.capacity {
|
||||
panic(F.ToString("buffer overflow: capacity ", b.capacity, ",end ", b.end, ", need ", n))
|
||||
}
|
||||
ext := b.data[b.end:end]
|
||||
b.end = end
|
||||
|
@ -137,14 +101,14 @@ func (b *Buffer) Write(data []byte) (n int, err error) {
|
|||
if b.IsFull() {
|
||||
return 0, io.ErrShortBuffer
|
||||
}
|
||||
n = copy(b.data[b.end:], data)
|
||||
n = copy(b.data[b.end:b.capacity], data)
|
||||
b.end += n
|
||||
return
|
||||
}
|
||||
|
||||
func (b *Buffer) ExtendHeader(n int) []byte {
|
||||
if b.start < n {
|
||||
panic("buffer overflow: cap " + strconv.Itoa(cap(b.data)) + ",start " + strconv.Itoa(b.start) + ", need " + strconv.Itoa(n))
|
||||
panic(F.ToString("buffer overflow: capacity ", b.capacity, ",start ", b.start, ", need ", n))
|
||||
}
|
||||
b.start -= n
|
||||
return b.data[b.start : b.start+n]
|
||||
|
@ -197,7 +161,7 @@ func (b *Buffer) ReadAtLeastFrom(r io.Reader, min int) (int64, error) {
|
|||
}
|
||||
|
||||
func (b *Buffer) ReadFullFrom(r io.Reader, size int) (n int, err error) {
|
||||
if b.end+size > b.Cap() {
|
||||
if b.end+size > b.capacity {
|
||||
return 0, io.ErrShortBuffer
|
||||
}
|
||||
n, err = io.ReadFull(r, b.data[b.end:b.end+size])
|
||||
|
@ -234,7 +198,7 @@ func (b *Buffer) WriteString(s string) (n int, err error) {
|
|||
if b.IsFull() {
|
||||
return 0, io.ErrShortBuffer
|
||||
}
|
||||
n = copy(b.data[b.end:], s)
|
||||
n = copy(b.data[b.end:b.capacity], s)
|
||||
b.end += n
|
||||
return
|
||||
}
|
||||
|
@ -249,13 +213,10 @@ func (b *Buffer) WriteZero() error {
|
|||
}
|
||||
|
||||
func (b *Buffer) WriteZeroN(n int) error {
|
||||
if b.end+n > b.Cap() {
|
||||
if b.end+n > b.capacity {
|
||||
return io.ErrShortBuffer
|
||||
}
|
||||
for i := b.end; i <= b.end+n; i++ {
|
||||
b.data[i] = 0
|
||||
}
|
||||
b.end += n
|
||||
common.ClearArray(b.Extend(n))
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -298,40 +259,63 @@ func (b *Buffer) Resize(start, end int) {
|
|||
b.end = b.start + end
|
||||
}
|
||||
|
||||
func (b *Buffer) Reset() {
|
||||
b.start = ReversedHeader
|
||||
b.end = ReversedHeader
|
||||
func (b *Buffer) Reserve(n int) {
|
||||
if n > b.capacity {
|
||||
panic(F.ToString("buffer overflow: capacity ", b.capacity, ", need ", n))
|
||||
}
|
||||
b.capacity -= n
|
||||
}
|
||||
|
||||
func (b *Buffer) FullReset() {
|
||||
func (b *Buffer) OverCap(n int) {
|
||||
if b.capacity+n > len(b.data) {
|
||||
panic(F.ToString("buffer overflow: capacity ", len(b.data), ", need ", b.capacity+n))
|
||||
}
|
||||
b.capacity += n
|
||||
}
|
||||
|
||||
func (b *Buffer) Reset() {
|
||||
b.start = 0
|
||||
b.end = 0
|
||||
b.capacity = len(b.data)
|
||||
}
|
||||
|
||||
// Deprecated: use Reset instead.
|
||||
func (b *Buffer) FullReset() {
|
||||
b.Reset()
|
||||
}
|
||||
|
||||
func (b *Buffer) IncRef() {
|
||||
atomic.AddInt32(&b.refs, 1)
|
||||
b.refs.Add(1)
|
||||
}
|
||||
|
||||
func (b *Buffer) DecRef() {
|
||||
atomic.AddInt32(&b.refs, -1)
|
||||
b.refs.Add(-1)
|
||||
}
|
||||
|
||||
func (b *Buffer) Release() {
|
||||
if b == nil || b.closed || !b.managed {
|
||||
if b == nil || !b.managed {
|
||||
return
|
||||
}
|
||||
if atomic.LoadInt32(&b.refs) > 0 {
|
||||
if b.refs.Load() > 0 {
|
||||
return
|
||||
}
|
||||
common.Must(Put(b.data))
|
||||
*b = Buffer{closed: true}
|
||||
*b = Buffer{}
|
||||
}
|
||||
|
||||
func (b *Buffer) Cut(start int, end int) *Buffer {
|
||||
b.start += start
|
||||
b.end = len(b.data) - end
|
||||
return &Buffer{
|
||||
data: b.data[b.start:b.end],
|
||||
func (b *Buffer) Leak() {
|
||||
if debug.Enabled {
|
||||
if b == nil || !b.managed {
|
||||
return
|
||||
}
|
||||
refs := b.refs.Load()
|
||||
if refs == 0 {
|
||||
panic("leaking buffer")
|
||||
} else {
|
||||
panic(F.ToString("leaking buffer with ", refs, " references"))
|
||||
}
|
||||
} else {
|
||||
b.Release()
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -344,6 +328,10 @@ func (b *Buffer) Len() int {
|
|||
}
|
||||
|
||||
func (b *Buffer) Cap() int {
|
||||
return b.capacity
|
||||
}
|
||||
|
||||
func (b *Buffer) RawCap() int {
|
||||
return len(b.data)
|
||||
}
|
||||
|
||||
|
@ -351,10 +339,6 @@ func (b *Buffer) Bytes() []byte {
|
|||
return b.data[b.start:b.end]
|
||||
}
|
||||
|
||||
func (b *Buffer) Slice() []byte {
|
||||
return b.data
|
||||
}
|
||||
|
||||
func (b *Buffer) From(n int) []byte {
|
||||
return b.data[b.start+n : b.end]
|
||||
}
|
||||
|
@ -372,11 +356,11 @@ func (b *Buffer) Index(start int) []byte {
|
|||
}
|
||||
|
||||
func (b *Buffer) FreeLen() int {
|
||||
return b.Cap() - b.end
|
||||
return b.capacity - b.end
|
||||
}
|
||||
|
||||
func (b *Buffer) FreeBytes() []byte {
|
||||
return b.data[b.end:b.Cap()]
|
||||
return b.data[b.end:b.capacity]
|
||||
}
|
||||
|
||||
func (b *Buffer) IsEmpty() bool {
|
||||
|
@ -384,7 +368,7 @@ func (b *Buffer) IsEmpty() bool {
|
|||
}
|
||||
|
||||
func (b *Buffer) IsFull() bool {
|
||||
return b.end == b.Cap()
|
||||
return b.end == b.capacity
|
||||
}
|
||||
|
||||
func (b *Buffer) ToOwned() *Buffer {
|
||||
|
@ -392,5 +376,6 @@ func (b *Buffer) ToOwned() *Buffer {
|
|||
copy(n.data[b.start:b.end], b.data[b.start:b.end])
|
||||
n.start = b.start
|
||||
n.end = b.end
|
||||
n.capacity = b.capacity
|
||||
return n
|
||||
}
|
||||
|
|
|
@ -1,9 +0,0 @@
|
|||
package buf
|
||||
|
||||
import "encoding/hex"
|
||||
|
||||
func EncodeHexString(src []byte) string {
|
||||
dst := Make(hex.EncodedLen(len(src)))
|
||||
hex.Encode(dst, src)
|
||||
return string(dst)
|
||||
}
|
|
@ -11,46 +11,7 @@ func Put(buf []byte) error {
|
|||
return DefaultAllocator.Put(buf)
|
||||
}
|
||||
|
||||
// Deprecated: use array instead.
|
||||
func Make(size int) []byte {
|
||||
if size == 0 {
|
||||
return nil
|
||||
}
|
||||
var buffer []byte
|
||||
switch {
|
||||
case size <= 2:
|
||||
buffer = make([]byte, 2)
|
||||
case size <= 4:
|
||||
buffer = make([]byte, 4)
|
||||
case size <= 8:
|
||||
buffer = make([]byte, 8)
|
||||
case size <= 16:
|
||||
buffer = make([]byte, 16)
|
||||
case size <= 32:
|
||||
buffer = make([]byte, 32)
|
||||
case size <= 64:
|
||||
buffer = make([]byte, 64)
|
||||
case size <= 128:
|
||||
buffer = make([]byte, 128)
|
||||
case size <= 256:
|
||||
buffer = make([]byte, 256)
|
||||
case size <= 512:
|
||||
buffer = make([]byte, 512)
|
||||
case size <= 1024:
|
||||
buffer = make([]byte, 1024)
|
||||
case size <= 2048:
|
||||
buffer = make([]byte, 2048)
|
||||
case size <= 4096:
|
||||
buffer = make([]byte, 4096)
|
||||
case size <= 8192:
|
||||
buffer = make([]byte, 8192)
|
||||
case size <= 16384:
|
||||
buffer = make([]byte, 16384)
|
||||
case size <= 32768:
|
||||
buffer = make([]byte, 32768)
|
||||
case size <= 65535:
|
||||
buffer = make([]byte, 65535)
|
||||
default:
|
||||
return make([]byte, size)
|
||||
}
|
||||
return buffer[:size]
|
||||
}
|
||||
|
|
|
@ -1,34 +0,0 @@
|
|||
//go:build !disable_unsafe
|
||||
|
||||
package buf
|
||||
|
||||
import (
|
||||
"unsafe"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
)
|
||||
|
||||
type dbgVar struct {
|
||||
name string
|
||||
value *int32
|
||||
}
|
||||
|
||||
//go:linkname dbgvars runtime.dbgvars
|
||||
var dbgvars any
|
||||
|
||||
// go.info.runtime.dbgvars: relocation target go.info.[]github.com/sagernet/sing/common/buf.dbgVar not defined
|
||||
// var dbgvars []dbgVar
|
||||
|
||||
func init() {
|
||||
if !common.UnsafeBuffer {
|
||||
return
|
||||
}
|
||||
debugVars := *(*[]dbgVar)(unsafe.Pointer(&dbgvars))
|
||||
for _, v := range debugVars {
|
||||
if v.name == "invalidptr" {
|
||||
*v.value = 0
|
||||
return
|
||||
}
|
||||
}
|
||||
panic("can't disable invalidptr")
|
||||
}
|
34
common/bufio/addr_bsd.go
Normal file
34
common/bufio/addr_bsd.go
Normal file
|
@ -0,0 +1,34 @@
|
|||
//go:build darwin || dragonfly || freebsd || netbsd || openbsd
|
||||
|
||||
package bufio
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"net/netip"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func ToSockaddr(destination netip.AddrPort) (name unsafe.Pointer, nameLen uint32) {
|
||||
if destination.Addr().Is4() {
|
||||
sa := unix.RawSockaddrInet4{
|
||||
Len: unix.SizeofSockaddrInet4,
|
||||
Family: unix.AF_INET,
|
||||
Addr: destination.Addr().As4(),
|
||||
}
|
||||
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], destination.Port())
|
||||
name = unsafe.Pointer(&sa)
|
||||
nameLen = unix.SizeofSockaddrInet4
|
||||
} else {
|
||||
sa := unix.RawSockaddrInet6{
|
||||
Len: unix.SizeofSockaddrInet6,
|
||||
Family: unix.AF_INET6,
|
||||
Addr: destination.Addr().As16(),
|
||||
}
|
||||
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], destination.Port())
|
||||
name = unsafe.Pointer(&sa)
|
||||
nameLen = unix.SizeofSockaddrInet6
|
||||
}
|
||||
return
|
||||
}
|
30
common/bufio/addr_linux.go
Normal file
30
common/bufio/addr_linux.go
Normal file
|
@ -0,0 +1,30 @@
|
|||
package bufio
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"net/netip"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func ToSockaddr(destination netip.AddrPort) (name unsafe.Pointer, nameLen uint32) {
|
||||
if destination.Addr().Is4() {
|
||||
sa := unix.RawSockaddrInet4{
|
||||
Family: unix.AF_INET,
|
||||
Addr: destination.Addr().As4(),
|
||||
}
|
||||
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], destination.Port())
|
||||
name = unsafe.Pointer(&sa)
|
||||
nameLen = unix.SizeofSockaddrInet4
|
||||
} else {
|
||||
sa := unix.RawSockaddrInet6{
|
||||
Family: unix.AF_INET6,
|
||||
Addr: destination.Addr().As16(),
|
||||
}
|
||||
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], destination.Port())
|
||||
name = unsafe.Pointer(&sa)
|
||||
nameLen = unix.SizeofSockaddrInet6
|
||||
}
|
||||
return
|
||||
}
|
30
common/bufio/addr_windows.go
Normal file
30
common/bufio/addr_windows.go
Normal file
|
@ -0,0 +1,30 @@
|
|||
package bufio
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"net/netip"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
func ToSockaddr(destination netip.AddrPort) (name unsafe.Pointer, nameLen int32) {
|
||||
if destination.Addr().Is4() {
|
||||
sa := windows.RawSockaddrInet4{
|
||||
Family: windows.AF_INET,
|
||||
Addr: destination.Addr().As4(),
|
||||
}
|
||||
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], destination.Port())
|
||||
name = unsafe.Pointer(&sa)
|
||||
nameLen = int32(unsafe.Sizeof(sa))
|
||||
} else {
|
||||
sa := windows.RawSockaddrInet6{
|
||||
Family: windows.AF_INET6,
|
||||
Addr: destination.Addr().As16(),
|
||||
}
|
||||
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], destination.Port())
|
||||
name = unsafe.Pointer(&sa)
|
||||
nameLen = int32(unsafe.Sizeof(sa))
|
||||
}
|
||||
return
|
||||
}
|
|
@ -8,51 +8,76 @@ import (
|
|||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
type BindPacketConn struct {
|
||||
type BindPacketConn interface {
|
||||
N.NetPacketConn
|
||||
Addr net.Addr
|
||||
net.Conn
|
||||
}
|
||||
|
||||
func NewBindPacketConn(conn net.PacketConn, addr net.Addr) *BindPacketConn {
|
||||
return &BindPacketConn{
|
||||
type bindPacketConn struct {
|
||||
N.NetPacketConn
|
||||
addr net.Addr
|
||||
}
|
||||
|
||||
func NewBindPacketConn(conn net.PacketConn, addr net.Addr) BindPacketConn {
|
||||
return &bindPacketConn{
|
||||
NewPacketConn(conn),
|
||||
addr,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *BindPacketConn) Read(b []byte) (n int, err error) {
|
||||
func (c *bindPacketConn) Read(b []byte) (n int, err error) {
|
||||
n, _, err = c.ReadFrom(b)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *BindPacketConn) Write(b []byte) (n int, err error) {
|
||||
return c.WriteTo(b, c.Addr)
|
||||
func (c *bindPacketConn) Write(b []byte) (n int, err error) {
|
||||
return c.WriteTo(b, c.addr)
|
||||
}
|
||||
|
||||
func (c *BindPacketConn) RemoteAddr() net.Addr {
|
||||
return c.Addr
|
||||
func (c *bindPacketConn) CreateReadWaiter() (N.ReadWaiter, bool) {
|
||||
readWaiter, isReadWaiter := CreatePacketReadWaiter(c.NetPacketConn)
|
||||
if !isReadWaiter {
|
||||
return nil, false
|
||||
}
|
||||
return &bindPacketReadWaiter{readWaiter}, true
|
||||
}
|
||||
|
||||
func (c *BindPacketConn) Upstream() any {
|
||||
func (c *bindPacketConn) RemoteAddr() net.Addr {
|
||||
return c.addr
|
||||
}
|
||||
|
||||
func (c *bindPacketConn) Upstream() any {
|
||||
return c.NetPacketConn
|
||||
}
|
||||
|
||||
var (
|
||||
_ N.NetPacketConn = (*UnbindPacketConn)(nil)
|
||||
_ N.PacketReadWaitCreator = (*UnbindPacketConn)(nil)
|
||||
)
|
||||
|
||||
type UnbindPacketConn struct {
|
||||
N.ExtendedConn
|
||||
Addr M.Socksaddr
|
||||
addr M.Socksaddr
|
||||
}
|
||||
|
||||
func NewUnbindPacketConn(conn net.Conn) *UnbindPacketConn {
|
||||
func NewUnbindPacketConn(conn net.Conn) N.NetPacketConn {
|
||||
return &UnbindPacketConn{
|
||||
NewExtendedConn(conn),
|
||||
M.SocksaddrFromNet(conn.RemoteAddr()),
|
||||
}
|
||||
}
|
||||
|
||||
func NewUnbindPacketConnWithAddr(conn net.Conn, addr M.Socksaddr) N.NetPacketConn {
|
||||
return &UnbindPacketConn{
|
||||
NewExtendedConn(conn),
|
||||
addr,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *UnbindPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
||||
n, err = c.ExtendedConn.Read(p)
|
||||
if err == nil {
|
||||
addr = c.Addr.UDPAddr()
|
||||
addr = c.addr.UDPAddr()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
@ -66,7 +91,7 @@ func (c *UnbindPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksad
|
|||
if err != nil {
|
||||
return
|
||||
}
|
||||
destination = c.Addr
|
||||
destination = c.addr
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -74,6 +99,67 @@ func (c *UnbindPacketConn) WritePacket(buffer *buf.Buffer, _ M.Socksaddr) error
|
|||
return c.ExtendedConn.WriteBuffer(buffer)
|
||||
}
|
||||
|
||||
func (c *UnbindPacketConn) CreateReadWaiter() (N.PacketReadWaiter, bool) {
|
||||
readWaiter, isReadWaiter := CreateReadWaiter(c.ExtendedConn)
|
||||
if !isReadWaiter {
|
||||
return nil, false
|
||||
}
|
||||
return &unbindPacketReadWaiter{readWaiter, c.addr}, true
|
||||
}
|
||||
|
||||
func (c *UnbindPacketConn) Upstream() any {
|
||||
return c.ExtendedConn
|
||||
}
|
||||
|
||||
func NewServerPacketConn(conn net.PacketConn) N.ExtendedConn {
|
||||
return &serverPacketConn{
|
||||
NetPacketConn: NewPacketConn(conn),
|
||||
}
|
||||
}
|
||||
|
||||
type serverPacketConn struct {
|
||||
N.NetPacketConn
|
||||
remoteAddr M.Socksaddr
|
||||
}
|
||||
|
||||
func (c *serverPacketConn) Read(p []byte) (n int, err error) {
|
||||
n, addr, err := c.NetPacketConn.ReadFrom(p)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
c.remoteAddr = M.SocksaddrFromNet(addr)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *serverPacketConn) ReadBuffer(buffer *buf.Buffer) error {
|
||||
destination, err := c.NetPacketConn.ReadPacket(buffer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.remoteAddr = destination
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *serverPacketConn) Write(p []byte) (n int, err error) {
|
||||
return c.NetPacketConn.WriteTo(p, c.remoteAddr.UDPAddr())
|
||||
}
|
||||
|
||||
func (c *serverPacketConn) WriteBuffer(buffer *buf.Buffer) error {
|
||||
return c.NetPacketConn.WritePacket(buffer, c.remoteAddr)
|
||||
}
|
||||
|
||||
func (c *serverPacketConn) RemoteAddr() net.Addr {
|
||||
return c.remoteAddr
|
||||
}
|
||||
|
||||
func (c *serverPacketConn) Upstream() any {
|
||||
return c.NetPacketConn
|
||||
}
|
||||
|
||||
func (c *serverPacketConn) CreateReadWaiter() (N.ReadWaiter, bool) {
|
||||
readWaiter, isReadWaiter := CreatePacketReadWaiter(c.NetPacketConn)
|
||||
if !isReadWaiter {
|
||||
return nil, false
|
||||
}
|
||||
return &serverPacketReadWaiter{c, readWaiter}, true
|
||||
}
|
||||
|
|
62
common/bufio/bind_wait.go
Normal file
62
common/bufio/bind_wait.go
Normal file
|
@ -0,0 +1,62 @@
|
|||
package bufio
|
||||
|
||||
import (
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
var _ N.ReadWaiter = (*bindPacketReadWaiter)(nil)
|
||||
|
||||
type bindPacketReadWaiter struct {
|
||||
readWaiter N.PacketReadWaiter
|
||||
}
|
||||
|
||||
func (w *bindPacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
|
||||
return w.readWaiter.InitializeReadWaiter(options)
|
||||
}
|
||||
|
||||
func (w *bindPacketReadWaiter) WaitReadBuffer() (buffer *buf.Buffer, err error) {
|
||||
buffer, _, err = w.readWaiter.WaitReadPacket()
|
||||
return
|
||||
}
|
||||
|
||||
var _ N.PacketReadWaiter = (*unbindPacketReadWaiter)(nil)
|
||||
|
||||
type unbindPacketReadWaiter struct {
|
||||
readWaiter N.ReadWaiter
|
||||
addr M.Socksaddr
|
||||
}
|
||||
|
||||
func (w *unbindPacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
|
||||
return w.readWaiter.InitializeReadWaiter(options)
|
||||
}
|
||||
|
||||
func (w *unbindPacketReadWaiter) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
|
||||
buffer, err = w.readWaiter.WaitReadBuffer()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
destination = w.addr
|
||||
return
|
||||
}
|
||||
|
||||
var _ N.ReadWaiter = (*serverPacketReadWaiter)(nil)
|
||||
|
||||
type serverPacketReadWaiter struct {
|
||||
*serverPacketConn
|
||||
readWaiter N.PacketReadWaiter
|
||||
}
|
||||
|
||||
func (w *serverPacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
|
||||
return w.readWaiter.InitializeReadWaiter(options)
|
||||
}
|
||||
|
||||
func (w *serverPacketReadWaiter) WaitReadBuffer() (buffer *buf.Buffer, err error) {
|
||||
buffer, destination, err := w.readWaiter.WaitReadPacket()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
w.remoteAddr = destination
|
||||
return
|
||||
}
|
|
@ -37,7 +37,7 @@ func (w *BufferedWriter) Write(p []byte) (n int, err error) {
|
|||
if err != nil {
|
||||
return
|
||||
}
|
||||
w.buffer.FullReset()
|
||||
w.buffer.Reset()
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -30,7 +30,7 @@ func (c *ChunkReader) ReadBuffer(buffer *buf.Buffer) error {
|
|||
} else if !c.cache.IsEmpty() {
|
||||
return common.Error(buffer.ReadFrom(c.cache))
|
||||
}
|
||||
c.cache.FullReset()
|
||||
c.cache.Reset()
|
||||
err := c.upstream.ReadBuffer(c.cache)
|
||||
if err != nil {
|
||||
c.cache.Release()
|
||||
|
@ -46,7 +46,7 @@ func (c *ChunkReader) Read(p []byte) (n int, err error) {
|
|||
} else if !c.cache.IsEmpty() {
|
||||
return c.cache.Read(p)
|
||||
}
|
||||
c.cache.FullReset()
|
||||
c.cache.Reset()
|
||||
err = c.upstream.ReadBuffer(c.cache)
|
||||
if err != nil {
|
||||
c.cache.Release()
|
||||
|
@ -70,7 +70,7 @@ func (c *ChunkReader) ReadChunk() (*buf.Buffer, error) {
|
|||
} else if !c.cache.IsEmpty() {
|
||||
return c.cache, nil
|
||||
}
|
||||
c.cache.FullReset()
|
||||
c.cache.Reset()
|
||||
err := c.upstream.ReadBuffer(c.cache)
|
||||
if err != nil {
|
||||
c.cache.Release()
|
||||
|
|
|
@ -22,7 +22,7 @@ func Copy(destination io.Writer, source io.Reader) (n int64, err error) {
|
|||
} else if destination == nil {
|
||||
return 0, E.New("nil writer")
|
||||
}
|
||||
originDestination := destination
|
||||
originSource := source
|
||||
var readCounters, writeCounters []N.CountFunc
|
||||
for {
|
||||
source, readCounters = N.UnwrapCountReader(source, readCounters)
|
||||
|
@ -45,105 +45,61 @@ func Copy(destination io.Writer, source io.Reader) (n int64, err error) {
|
|||
dstSyscallConn, dstIsSyscall := destination.(syscall.Conn)
|
||||
if srcIsSyscall && dstIsSyscall {
|
||||
var handled bool
|
||||
handled, n, err = CopyDirect(srcSyscallConn, dstSyscallConn, readCounters, writeCounters)
|
||||
handled, n, err = copyDirect(srcSyscallConn, dstSyscallConn, readCounters, writeCounters)
|
||||
if handled {
|
||||
return
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
return CopyExtended(originDestination, NewExtendedWriter(destination), NewExtendedReader(source), readCounters, writeCounters)
|
||||
return CopyExtended(originSource, NewExtendedWriter(destination), NewExtendedReader(source), readCounters, writeCounters)
|
||||
}
|
||||
|
||||
func CopyExtended(originDestination io.Writer, destination N.ExtendedWriter, source N.ExtendedReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
|
||||
safeSrc := N.IsSafeReader(source)
|
||||
headroom := N.CalculateFrontHeadroom(destination) + N.CalculateRearHeadroom(destination)
|
||||
if safeSrc != nil {
|
||||
if headroom == 0 {
|
||||
return CopyExtendedWithSrcBuffer(originDestination, destination, safeSrc, readCounters, writeCounters)
|
||||
}
|
||||
}
|
||||
func CopyExtended(originSource io.Reader, destination N.ExtendedWriter, source N.ExtendedReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
|
||||
frontHeadroom := N.CalculateFrontHeadroom(destination)
|
||||
rearHeadroom := N.CalculateRearHeadroom(destination)
|
||||
readWaiter, isReadWaiter := CreateReadWaiter(source)
|
||||
if isReadWaiter {
|
||||
needCopy := readWaiter.InitializeReadWaiter(N.ReadWaitOptions{
|
||||
FrontHeadroom: frontHeadroom,
|
||||
RearHeadroom: rearHeadroom,
|
||||
MTU: N.CalculateMTU(source, destination),
|
||||
})
|
||||
if !needCopy || common.LowMemory {
|
||||
var handled bool
|
||||
handled, n, err = copyWaitWithPool(originDestination, destination, readWaiter, readCounters, writeCounters)
|
||||
handled, n, err = copyWaitWithPool(originSource, destination, readWaiter, readCounters, writeCounters)
|
||||
if handled {
|
||||
return
|
||||
}
|
||||
}
|
||||
if !common.UnsafeBuffer || N.IsUnsafeWriter(destination) {
|
||||
return CopyExtendedWithPool(originDestination, destination, source, readCounters, writeCounters)
|
||||
}
|
||||
bufferSize := N.CalculateMTU(source, destination)
|
||||
if bufferSize > 0 {
|
||||
bufferSize += headroom
|
||||
} else {
|
||||
bufferSize = buf.BufferSize
|
||||
}
|
||||
_buffer := buf.StackNewSize(bufferSize)
|
||||
defer common.KeepAlive(_buffer)
|
||||
buffer := common.Dup(_buffer)
|
||||
defer buffer.Release()
|
||||
return CopyExtendedBuffer(originDestination, destination, source, buffer, readCounters, writeCounters)
|
||||
return CopyExtendedWithPool(originSource, destination, source, readCounters, writeCounters)
|
||||
}
|
||||
|
||||
func CopyExtendedBuffer(originDestination io.Writer, destination N.ExtendedWriter, source N.ExtendedReader, buffer *buf.Buffer, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
|
||||
func CopyExtendedBuffer(originSource io.Writer, destination N.ExtendedWriter, source N.ExtendedReader, buffer *buf.Buffer, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
|
||||
buffer.IncRef()
|
||||
defer buffer.DecRef()
|
||||
frontHeadroom := N.CalculateFrontHeadroom(destination)
|
||||
rearHeadroom := N.CalculateRearHeadroom(destination)
|
||||
readBufferRaw := buffer.Slice()
|
||||
readBuffer := buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom])
|
||||
buffer.Resize(frontHeadroom, 0)
|
||||
buffer.Reserve(rearHeadroom)
|
||||
var notFirstTime bool
|
||||
for {
|
||||
readBuffer.Resize(frontHeadroom, 0)
|
||||
err = source.ReadBuffer(readBuffer)
|
||||
err = source.ReadBuffer(buffer)
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
err = nil
|
||||
return
|
||||
}
|
||||
if !notFirstTime {
|
||||
err = N.HandshakeFailure(originDestination, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
dataLen := readBuffer.Len()
|
||||
buffer.Resize(readBuffer.Start(), dataLen)
|
||||
err = destination.WriteBuffer(buffer)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
n += int64(dataLen)
|
||||
for _, counter := range readCounters {
|
||||
counter(int64(dataLen))
|
||||
}
|
||||
for _, counter := range writeCounters {
|
||||
counter(int64(dataLen))
|
||||
}
|
||||
notFirstTime = true
|
||||
}
|
||||
}
|
||||
|
||||
func CopyExtendedWithSrcBuffer(originDestination io.Writer, destination N.ExtendedWriter, source N.ThreadSafeReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
|
||||
var notFirstTime bool
|
||||
for {
|
||||
var buffer *buf.Buffer
|
||||
buffer, err = source.ReadBufferThreadSafe()
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
err = nil
|
||||
return
|
||||
}
|
||||
if !notFirstTime {
|
||||
err = N.HandshakeFailure(originDestination, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
dataLen := buffer.Len()
|
||||
buffer.OverCap(rearHeadroom)
|
||||
err = destination.WriteBuffer(buffer)
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
if !notFirstTime {
|
||||
err = N.ReportHandshakeFailure(originSource, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
n += int64(dataLen)
|
||||
|
@ -157,7 +113,7 @@ func CopyExtendedWithSrcBuffer(originDestination io.Writer, destination N.Extend
|
|||
}
|
||||
}
|
||||
|
||||
func CopyExtendedWithPool(originDestination io.Writer, destination N.ExtendedWriter, source N.ExtendedReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
|
||||
func CopyExtendedWithPool(originSource io.Reader, destination N.ExtendedWriter, source N.ExtendedReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
|
||||
frontHeadroom := N.CalculateFrontHeadroom(destination)
|
||||
rearHeadroom := N.CalculateRearHeadroom(destination)
|
||||
bufferSize := N.CalculateMTU(source, destination)
|
||||
|
@ -169,26 +125,25 @@ func CopyExtendedWithPool(originDestination io.Writer, destination N.ExtendedWri
|
|||
var notFirstTime bool
|
||||
for {
|
||||
buffer := buf.NewSize(bufferSize)
|
||||
readBufferRaw := buffer.Slice()
|
||||
readBuffer := buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom])
|
||||
readBuffer.Resize(frontHeadroom, 0)
|
||||
err = source.ReadBuffer(readBuffer)
|
||||
buffer.Resize(frontHeadroom, 0)
|
||||
buffer.Reserve(rearHeadroom)
|
||||
err = source.ReadBuffer(buffer)
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
if errors.Is(err, io.EOF) {
|
||||
err = nil
|
||||
return
|
||||
}
|
||||
if !notFirstTime {
|
||||
err = N.HandshakeFailure(originDestination, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
dataLen := readBuffer.Len()
|
||||
buffer.Resize(readBuffer.Start(), dataLen)
|
||||
dataLen := buffer.Len()
|
||||
buffer.OverCap(rearHeadroom)
|
||||
err = destination.WriteBuffer(buffer)
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
buffer.Leak()
|
||||
if !notFirstTime {
|
||||
err = N.ReportHandshakeFailure(originSource, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
n += int64(dataLen)
|
||||
|
@ -249,6 +204,7 @@ func CopyConnContextList(contextList []context.Context, source net.Conn, destina
|
|||
func CopyPacket(destinationConn N.PacketWriter, source N.PacketReader) (n int64, err error) {
|
||||
var readCounters, writeCounters []N.CountFunc
|
||||
var cachedPackets []*N.PacketBuffer
|
||||
originSource := source
|
||||
for {
|
||||
source, readCounters = N.UnwrapCountPacketReader(source, readCounters)
|
||||
destinationConn, writeCounters = N.UnwrapCountPacketWriter(destinationConn, writeCounters)
|
||||
|
@ -262,113 +218,38 @@ func CopyPacket(destinationConn N.PacketWriter, source N.PacketReader) (n int64,
|
|||
break
|
||||
}
|
||||
if cachedPackets != nil {
|
||||
n, err = WritePacketWithPool(destinationConn, cachedPackets)
|
||||
n, err = WritePacketWithPool(originSource, destinationConn, cachedPackets)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
safeSrc := N.IsSafePacketReader(source)
|
||||
frontHeadroom := N.CalculateFrontHeadroom(destinationConn)
|
||||
rearHeadroom := N.CalculateRearHeadroom(destinationConn)
|
||||
headroom := frontHeadroom + rearHeadroom
|
||||
if safeSrc != nil {
|
||||
if headroom == 0 {
|
||||
var copyN int64
|
||||
copyN, err = CopyPacketWithSrcBuffer(destinationConn, safeSrc, readCounters, writeCounters)
|
||||
n += copyN
|
||||
return
|
||||
}
|
||||
}
|
||||
readWaiter, isReadWaiter := CreatePacketReadWaiter(source)
|
||||
if isReadWaiter {
|
||||
var (
|
||||
handled bool
|
||||
copeN int64
|
||||
)
|
||||
handled, copeN, err = copyPacketWaitWithPool(destinationConn, readWaiter, readCounters, writeCounters)
|
||||
readWaiter, isReadWaiter := CreatePacketReadWaiter(source)
|
||||
if isReadWaiter {
|
||||
needCopy := readWaiter.InitializeReadWaiter(N.ReadWaitOptions{
|
||||
FrontHeadroom: frontHeadroom,
|
||||
RearHeadroom: rearHeadroom,
|
||||
MTU: N.CalculateMTU(source, destinationConn),
|
||||
})
|
||||
if !needCopy || common.LowMemory {
|
||||
handled, copeN, err = copyPacketWaitWithPool(originSource, destinationConn, readWaiter, readCounters, writeCounters, n > 0)
|
||||
if handled {
|
||||
n += copeN
|
||||
return
|
||||
}
|
||||
}
|
||||
if N.IsUnsafeWriter(destinationConn) {
|
||||
return CopyPacketWithPool(destinationConn, source, readCounters, writeCounters)
|
||||
}
|
||||
bufferSize := N.CalculateMTU(source, destinationConn)
|
||||
if bufferSize > 0 {
|
||||
bufferSize += headroom
|
||||
} else {
|
||||
bufferSize = buf.UDPBufferSize
|
||||
}
|
||||
_buffer := buf.StackNewSize(bufferSize)
|
||||
defer common.KeepAlive(_buffer)
|
||||
buffer := common.Dup(_buffer)
|
||||
defer buffer.Release()
|
||||
buffer.IncRef()
|
||||
defer buffer.DecRef()
|
||||
var destination M.Socksaddr
|
||||
var notFirstTime bool
|
||||
readBufferRaw := buffer.Slice()
|
||||
readBuffer := buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom])
|
||||
for {
|
||||
readBuffer.Resize(frontHeadroom, 0)
|
||||
destination, err = source.ReadPacket(readBuffer)
|
||||
if err != nil {
|
||||
if !notFirstTime {
|
||||
err = N.HandshakeFailure(destinationConn, err)
|
||||
}
|
||||
copeN, err = CopyPacketWithPool(originSource, destinationConn, source, readCounters, writeCounters, n > 0)
|
||||
n += copeN
|
||||
return
|
||||
}
|
||||
dataLen := readBuffer.Len()
|
||||
buffer.Resize(readBuffer.Start(), dataLen)
|
||||
err = destinationConn.WritePacket(buffer, destination)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
n += int64(dataLen)
|
||||
for _, counter := range readCounters {
|
||||
counter(int64(dataLen))
|
||||
}
|
||||
for _, counter := range writeCounters {
|
||||
counter(int64(dataLen))
|
||||
}
|
||||
notFirstTime = true
|
||||
}
|
||||
}
|
||||
|
||||
func CopyPacketWithSrcBuffer(destinationConn N.PacketWriter, source N.ThreadSafePacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
|
||||
var buffer *buf.Buffer
|
||||
var destination M.Socksaddr
|
||||
var notFirstTime bool
|
||||
for {
|
||||
buffer, destination, err = source.ReadPacketThreadSafe()
|
||||
if err != nil {
|
||||
if !notFirstTime {
|
||||
err = N.HandshakeFailure(destinationConn, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
dataLen := buffer.Len()
|
||||
if dataLen == 0 {
|
||||
continue
|
||||
}
|
||||
err = destinationConn.WritePacket(buffer, destination)
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
return
|
||||
}
|
||||
n += int64(dataLen)
|
||||
for _, counter := range readCounters {
|
||||
counter(int64(dataLen))
|
||||
}
|
||||
for _, counter := range writeCounters {
|
||||
counter(int64(dataLen))
|
||||
}
|
||||
notFirstTime = true
|
||||
}
|
||||
}
|
||||
|
||||
func CopyPacketWithPool(destinationConn N.PacketWriter, source N.PacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
|
||||
func CopyPacketWithPool(originSource N.PacketReader, destinationConn N.PacketWriter, source N.PacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (n int64, err error) {
|
||||
frontHeadroom := N.CalculateFrontHeadroom(destinationConn)
|
||||
rearHeadroom := N.CalculateRearHeadroom(destinationConn)
|
||||
bufferSize := N.CalculateMTU(source, destinationConn)
|
||||
|
@ -378,25 +259,23 @@ func CopyPacketWithPool(destinationConn N.PacketWriter, source N.PacketReader, r
|
|||
bufferSize = buf.UDPBufferSize
|
||||
}
|
||||
var destination M.Socksaddr
|
||||
var notFirstTime bool
|
||||
for {
|
||||
buffer := buf.NewSize(bufferSize)
|
||||
readBufferRaw := buffer.Slice()
|
||||
readBuffer := buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom])
|
||||
readBuffer.Resize(frontHeadroom, 0)
|
||||
destination, err = source.ReadPacket(readBuffer)
|
||||
buffer.Resize(frontHeadroom, 0)
|
||||
buffer.Reserve(rearHeadroom)
|
||||
destination, err = source.ReadPacket(buffer)
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
if !notFirstTime {
|
||||
err = N.HandshakeFailure(destinationConn, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
dataLen := readBuffer.Len()
|
||||
buffer.Resize(readBuffer.Start(), dataLen)
|
||||
dataLen := buffer.Len()
|
||||
buffer.OverCap(rearHeadroom)
|
||||
err = destinationConn.WritePacket(buffer, destination)
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
buffer.Leak()
|
||||
if !notFirstTime {
|
||||
err = N.ReportHandshakeFailure(originSource, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
n += int64(dataLen)
|
||||
|
@ -410,24 +289,28 @@ func CopyPacketWithPool(destinationConn N.PacketWriter, source N.PacketReader, r
|
|||
}
|
||||
}
|
||||
|
||||
func WritePacketWithPool(destinationConn N.PacketWriter, packetBuffers []*N.PacketBuffer) (n int64, err error) {
|
||||
func WritePacketWithPool(originSource N.PacketReader, destinationConn N.PacketWriter, packetBuffers []*N.PacketBuffer) (n int64, err error) {
|
||||
frontHeadroom := N.CalculateFrontHeadroom(destinationConn)
|
||||
rearHeadroom := N.CalculateRearHeadroom(destinationConn)
|
||||
var notFirstTime bool
|
||||
for _, packetBuffer := range packetBuffers {
|
||||
buffer := buf.NewPacket()
|
||||
readBufferRaw := buffer.Slice()
|
||||
readBuffer := buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom])
|
||||
readBuffer.Resize(frontHeadroom, 0)
|
||||
_, err = readBuffer.Write(packetBuffer.Buffer.Bytes())
|
||||
buffer.Resize(frontHeadroom, 0)
|
||||
buffer.Reserve(rearHeadroom)
|
||||
_, err = buffer.Write(packetBuffer.Buffer.Bytes())
|
||||
packetBuffer.Buffer.Release()
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
continue
|
||||
}
|
||||
dataLen := readBuffer.Len()
|
||||
buffer.Resize(readBuffer.Start(), dataLen)
|
||||
dataLen := buffer.Len()
|
||||
buffer.OverCap(rearHeadroom)
|
||||
err = destinationConn.WritePacket(buffer, packetBuffer.Destination)
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
buffer.Leak()
|
||||
if !notFirstTime {
|
||||
err = N.ReportHandshakeFailure(originSource, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
n += int64(dataLen)
|
||||
|
|
|
@ -1,12 +1,16 @@
|
|||
package bufio
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"syscall"
|
||||
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
func CopyDirect(source syscall.Conn, destination syscall.Conn, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handed bool, n int64, err error) {
|
||||
func copyDirect(source syscall.Conn, destination syscall.Conn, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handed bool, n int64, err error) {
|
||||
rawSource, err := source.SyscallConn()
|
||||
if err != nil {
|
||||
return
|
||||
|
@ -18,3 +22,69 @@ func CopyDirect(source syscall.Conn, destination syscall.Conn, readCounters []N.
|
|||
handed, n, err = splice(rawSource, rawDestination, readCounters, writeCounters)
|
||||
return
|
||||
}
|
||||
|
||||
func copyWaitWithPool(originSource io.Reader, destination N.ExtendedWriter, source N.ReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handled bool, n int64, err error) {
|
||||
handled = true
|
||||
var (
|
||||
buffer *buf.Buffer
|
||||
notFirstTime bool
|
||||
)
|
||||
for {
|
||||
buffer, err = source.WaitReadBuffer()
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
err = nil
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
dataLen := buffer.Len()
|
||||
err = destination.WriteBuffer(buffer)
|
||||
if err != nil {
|
||||
buffer.Leak()
|
||||
if !notFirstTime {
|
||||
err = N.ReportHandshakeFailure(originSource, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
n += int64(dataLen)
|
||||
for _, counter := range readCounters {
|
||||
counter(int64(dataLen))
|
||||
}
|
||||
for _, counter := range writeCounters {
|
||||
counter(int64(dataLen))
|
||||
}
|
||||
notFirstTime = true
|
||||
}
|
||||
}
|
||||
|
||||
func copyPacketWaitWithPool(originSource N.PacketReader, destinationConn N.PacketWriter, source N.PacketReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (handled bool, n int64, err error) {
|
||||
handled = true
|
||||
var (
|
||||
buffer *buf.Buffer
|
||||
destination M.Socksaddr
|
||||
)
|
||||
for {
|
||||
buffer, destination, err = source.WaitReadPacket()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
dataLen := buffer.Len()
|
||||
err = destinationConn.WritePacket(buffer, destination)
|
||||
if err != nil {
|
||||
buffer.Leak()
|
||||
if !notFirstTime {
|
||||
err = N.ReportHandshakeFailure(originSource, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
n += int64(dataLen)
|
||||
for _, counter := range readCounters {
|
||||
counter(int64(dataLen))
|
||||
}
|
||||
for _, counter := range writeCounters {
|
||||
counter(int64(dataLen))
|
||||
}
|
||||
notFirstTime = true
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,7 +3,6 @@
|
|||
package bufio
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"net/netip"
|
||||
"os"
|
||||
|
@ -15,115 +14,14 @@ import (
|
|||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
func copyWaitWithPool(originDestination io.Writer, destination N.ExtendedWriter, source N.ReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handled bool, n int64, err error) {
|
||||
handled = true
|
||||
frontHeadroom := N.CalculateFrontHeadroom(destination)
|
||||
rearHeadroom := N.CalculateRearHeadroom(destination)
|
||||
bufferSize := N.CalculateMTU(source, destination)
|
||||
if bufferSize > 0 {
|
||||
bufferSize += frontHeadroom + rearHeadroom
|
||||
} else {
|
||||
bufferSize = buf.BufferSize
|
||||
}
|
||||
var (
|
||||
buffer *buf.Buffer
|
||||
readBuffer *buf.Buffer
|
||||
notFirstTime bool
|
||||
)
|
||||
source.InitializeReadWaiter(func() *buf.Buffer {
|
||||
buffer = buf.NewSize(bufferSize)
|
||||
readBufferRaw := buffer.Slice()
|
||||
readBuffer = buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom])
|
||||
readBuffer.Resize(frontHeadroom, 0)
|
||||
return readBuffer
|
||||
})
|
||||
defer source.InitializeReadWaiter(nil)
|
||||
for {
|
||||
err = source.WaitReadBuffer()
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
err = nil
|
||||
return
|
||||
}
|
||||
if !notFirstTime {
|
||||
err = N.HandshakeFailure(originDestination, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
dataLen := readBuffer.Len()
|
||||
buffer.Resize(readBuffer.Start(), dataLen)
|
||||
err = destination.WriteBuffer(buffer)
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
return
|
||||
}
|
||||
n += int64(dataLen)
|
||||
for _, counter := range readCounters {
|
||||
counter(int64(dataLen))
|
||||
}
|
||||
for _, counter := range writeCounters {
|
||||
counter(int64(dataLen))
|
||||
}
|
||||
notFirstTime = true
|
||||
}
|
||||
}
|
||||
|
||||
func copyPacketWaitWithPool(destinationConn N.PacketWriter, source N.PacketReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handled bool, n int64, err error) {
|
||||
handled = true
|
||||
frontHeadroom := N.CalculateFrontHeadroom(destinationConn)
|
||||
rearHeadroom := N.CalculateRearHeadroom(destinationConn)
|
||||
bufferSize := N.CalculateMTU(source, destinationConn)
|
||||
if bufferSize > 0 {
|
||||
bufferSize += frontHeadroom + rearHeadroom
|
||||
} else {
|
||||
bufferSize = buf.UDPBufferSize
|
||||
}
|
||||
var (
|
||||
buffer *buf.Buffer
|
||||
readBuffer *buf.Buffer
|
||||
destination M.Socksaddr
|
||||
notFirstTime bool
|
||||
)
|
||||
source.InitializeReadWaiter(func() *buf.Buffer {
|
||||
buffer = buf.NewSize(bufferSize)
|
||||
readBufferRaw := buffer.Slice()
|
||||
readBuffer = buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom])
|
||||
readBuffer.Resize(frontHeadroom, 0)
|
||||
return readBuffer
|
||||
})
|
||||
defer source.InitializeReadWaiter(nil)
|
||||
for {
|
||||
destination, err = source.WaitReadPacket()
|
||||
if err != nil {
|
||||
if !notFirstTime {
|
||||
err = N.HandshakeFailure(destinationConn, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
dataLen := readBuffer.Len()
|
||||
buffer.Resize(readBuffer.Start(), dataLen)
|
||||
err = destinationConn.WritePacket(buffer, destination)
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
return
|
||||
}
|
||||
n += int64(dataLen)
|
||||
for _, counter := range readCounters {
|
||||
counter(int64(dataLen))
|
||||
}
|
||||
for _, counter := range writeCounters {
|
||||
counter(int64(dataLen))
|
||||
}
|
||||
notFirstTime = true
|
||||
}
|
||||
}
|
||||
|
||||
var _ N.ReadWaiter = (*syscallReadWaiter)(nil)
|
||||
|
||||
type syscallReadWaiter struct {
|
||||
rawConn syscall.RawConn
|
||||
readErr error
|
||||
readFunc func(fd uintptr) (done bool)
|
||||
buffer *buf.Buffer
|
||||
options N.ReadWaitOptions
|
||||
}
|
||||
|
||||
func createSyscallReadWaiter(reader any) (*syscallReadWaiter, bool) {
|
||||
|
@ -136,47 +34,48 @@ func createSyscallReadWaiter(reader any) (*syscallReadWaiter, bool) {
|
|||
return nil, false
|
||||
}
|
||||
|
||||
func (w *syscallReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buffer) {
|
||||
w.readErr = nil
|
||||
if newBuffer == nil {
|
||||
w.readFunc = nil
|
||||
} else {
|
||||
func (w *syscallReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
|
||||
w.options = options
|
||||
w.readFunc = func(fd uintptr) (done bool) {
|
||||
buffer := newBuffer()
|
||||
buffer := w.options.NewBuffer()
|
||||
var readN int
|
||||
readN, w.readErr = syscall.Read(int(fd), buffer.FreeBytes())
|
||||
if readN > 0 {
|
||||
buffer.Truncate(readN)
|
||||
w.options.PostReturn(buffer)
|
||||
w.buffer = buffer
|
||||
} else {
|
||||
buffer.Release()
|
||||
buffer = nil
|
||||
}
|
||||
//goland:noinspection GoDirectComparisonOfErrors
|
||||
if w.readErr == syscall.EAGAIN {
|
||||
return false
|
||||
}
|
||||
if readN == 0 {
|
||||
if readN == 0 && w.readErr == nil {
|
||||
w.readErr = io.EOF
|
||||
}
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (w *syscallReadWaiter) WaitReadBuffer() error {
|
||||
func (w *syscallReadWaiter) WaitReadBuffer() (buffer *buf.Buffer, err error) {
|
||||
if w.readFunc == nil {
|
||||
return os.ErrInvalid
|
||||
return nil, os.ErrInvalid
|
||||
}
|
||||
err := w.rawConn.Read(w.readFunc)
|
||||
err = w.rawConn.Read(w.readFunc)
|
||||
if err != nil {
|
||||
return err
|
||||
return
|
||||
}
|
||||
if w.readErr != nil {
|
||||
if w.readErr == io.EOF {
|
||||
return io.EOF
|
||||
return nil, io.EOF
|
||||
}
|
||||
return E.Cause(w.readErr, "raw read")
|
||||
return nil, E.Cause(w.readErr, "raw read")
|
||||
}
|
||||
return nil
|
||||
buffer = w.buffer
|
||||
w.buffer = nil
|
||||
return
|
||||
}
|
||||
|
||||
var _ N.PacketReadWaiter = (*syscallPacketReadWaiter)(nil)
|
||||
|
@ -186,6 +85,8 @@ type syscallPacketReadWaiter struct {
|
|||
readErr error
|
||||
readFrom M.Socksaddr
|
||||
readFunc func(fd uintptr) (done bool)
|
||||
buffer *buf.Buffer
|
||||
options N.ReadWaitOptions
|
||||
}
|
||||
|
||||
func createSyscallPacketReadWaiter(reader any) (*syscallPacketReadWaiter, bool) {
|
||||
|
@ -198,42 +99,37 @@ func createSyscallPacketReadWaiter(reader any) (*syscallPacketReadWaiter, bool)
|
|||
return nil, false
|
||||
}
|
||||
|
||||
func (w *syscallPacketReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buffer) {
|
||||
w.readErr = nil
|
||||
w.readFrom = M.Socksaddr{}
|
||||
if newBuffer == nil {
|
||||
w.readFunc = nil
|
||||
} else {
|
||||
func (w *syscallPacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
|
||||
w.options = options
|
||||
w.readFunc = func(fd uintptr) (done bool) {
|
||||
buffer := newBuffer()
|
||||
buffer := w.options.NewPacketBuffer()
|
||||
var readN int
|
||||
var from syscall.Sockaddr
|
||||
readN, _, _, from, w.readErr = syscall.Recvmsg(int(fd), buffer.FreeBytes(), nil, 0)
|
||||
//goland:noinspection GoDirectComparisonOfErrors
|
||||
if w.readErr != nil {
|
||||
buffer.Release()
|
||||
return w.readErr != syscall.EAGAIN
|
||||
}
|
||||
if readN > 0 {
|
||||
buffer.Truncate(readN)
|
||||
} else {
|
||||
buffer.Release()
|
||||
buffer = nil
|
||||
}
|
||||
if w.readErr == syscall.EAGAIN {
|
||||
return false
|
||||
}
|
||||
if from != nil {
|
||||
w.options.PostReturn(buffer)
|
||||
w.buffer = buffer
|
||||
switch fromAddr := from.(type) {
|
||||
case *syscall.SockaddrInet4:
|
||||
w.readFrom = M.SocksaddrFrom(netip.AddrFrom4(fromAddr.Addr), uint16(fromAddr.Port))
|
||||
case *syscall.SockaddrInet6:
|
||||
w.readFrom = M.SocksaddrFrom(netip.AddrFrom16(fromAddr.Addr), uint16(fromAddr.Port))
|
||||
}
|
||||
w.readFrom = M.SocksaddrFrom(netip.AddrFrom16(fromAddr.Addr), uint16(fromAddr.Port)).Unwrap()
|
||||
}
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (w *syscallPacketReadWaiter) WaitReadPacket() (destination M.Socksaddr, err error) {
|
||||
func (w *syscallPacketReadWaiter) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
|
||||
if w.readFunc == nil {
|
||||
return M.Socksaddr{}, os.ErrInvalid
|
||||
return nil, M.Socksaddr{}, os.ErrInvalid
|
||||
}
|
||||
err = w.rawConn.Read(w.readFunc)
|
||||
if err != nil {
|
||||
|
@ -243,6 +139,8 @@ func (w *syscallPacketReadWaiter) WaitReadPacket() (destination M.Socksaddr, err
|
|||
err = E.Cause(w.readErr, "raw read")
|
||||
return
|
||||
}
|
||||
buffer = w.buffer
|
||||
w.buffer = nil
|
||||
destination = w.readFrom
|
||||
return
|
||||
}
|
||||
|
|
77
common/bufio/copy_direct_test.go
Normal file
77
common/bufio/copy_direct_test.go
Normal file
|
@ -0,0 +1,77 @@
|
|||
package bufio
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestCopyWaitTCP(t *testing.T) {
|
||||
t.Parallel()
|
||||
inputConn, outputConn := TCPPipe(t)
|
||||
readWaiter, created := CreateReadWaiter(outputConn)
|
||||
require.True(t, created)
|
||||
require.NotNil(t, readWaiter)
|
||||
readWaiter.InitializeReadWaiter(N.ReadWaitOptions{})
|
||||
require.NoError(t, TCPTest(t, inputConn, &readWaitWrapper{
|
||||
Conn: outputConn,
|
||||
readWaiter: readWaiter,
|
||||
}))
|
||||
}
|
||||
|
||||
type readWaitWrapper struct {
|
||||
net.Conn
|
||||
readWaiter N.ReadWaiter
|
||||
buffer *buf.Buffer
|
||||
}
|
||||
|
||||
func (r *readWaitWrapper) Read(p []byte) (n int, err error) {
|
||||
if r.buffer != nil {
|
||||
if r.buffer.Len() > 0 {
|
||||
return r.buffer.Read(p)
|
||||
}
|
||||
if r.buffer.IsEmpty() {
|
||||
r.buffer.Release()
|
||||
r.buffer = nil
|
||||
}
|
||||
}
|
||||
buffer, err := r.readWaiter.WaitReadBuffer()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
r.buffer = buffer
|
||||
return r.buffer.Read(p)
|
||||
}
|
||||
|
||||
func TestCopyWaitUDP(t *testing.T) {
|
||||
t.Parallel()
|
||||
inputConn, outputConn, outputAddr := UDPPipe(t)
|
||||
readWaiter, created := CreatePacketReadWaiter(NewPacketConn(outputConn))
|
||||
require.True(t, created)
|
||||
require.NotNil(t, readWaiter)
|
||||
readWaiter.InitializeReadWaiter(N.ReadWaitOptions{})
|
||||
require.NoError(t, UDPTest(t, inputConn, &packetReadWaitWrapper{
|
||||
PacketConn: outputConn,
|
||||
readWaiter: readWaiter,
|
||||
}, outputAddr))
|
||||
}
|
||||
|
||||
type packetReadWaitWrapper struct {
|
||||
net.PacketConn
|
||||
readWaiter N.PacketReadWaiter
|
||||
}
|
||||
|
||||
func (r *packetReadWaitWrapper) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
||||
buffer, destination, err := r.readWaiter.WaitReadPacket()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
n = copy(p, buffer.Bytes())
|
||||
buffer.Release()
|
||||
addr = destination.UDPAddr()
|
||||
return
|
||||
}
|
|
@ -2,22 +2,206 @@ package bufio
|
|||
|
||||
import (
|
||||
"io"
|
||||
"net/netip"
|
||||
"os"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
func copyWaitWithPool(originDestination io.Writer, destination N.ExtendedWriter, source N.ReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handled bool, n int64, err error) {
|
||||
var modws2_32 = windows.NewLazySystemDLL("ws2_32.dll")
|
||||
|
||||
var procrecv = modws2_32.NewProc("recv")
|
||||
|
||||
// Do the interface allocations only once for common
|
||||
// Errno values.
|
||||
const (
|
||||
errnoERROR_IO_PENDING = 997
|
||||
)
|
||||
|
||||
var (
|
||||
errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
|
||||
errERROR_EINVAL error = syscall.EINVAL
|
||||
)
|
||||
|
||||
// errnoErr returns common boxed Errno values, to prevent
|
||||
// allocations at runtime.
|
||||
func errnoErr(e syscall.Errno) error {
|
||||
switch e {
|
||||
case 0:
|
||||
return errERROR_EINVAL
|
||||
case errnoERROR_IO_PENDING:
|
||||
return errERROR_IO_PENDING
|
||||
}
|
||||
// TODO: add more here, after collecting data on the common
|
||||
// error values see on Windows. (perhaps when running
|
||||
// all.bat?)
|
||||
return e
|
||||
}
|
||||
|
||||
func recv(s windows.Handle, buf []byte, flags int32) (n int32, err error) {
|
||||
var _p0 *byte
|
||||
if len(buf) > 0 {
|
||||
_p0 = &buf[0]
|
||||
}
|
||||
r0, _, e1 := syscall.SyscallN(procrecv.Addr(), uintptr(s), uintptr(unsafe.Pointer(_p0)), uintptr(len(buf)), uintptr(flags))
|
||||
n = int32(r0)
|
||||
if n == -1 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func copyPacketWaitWithPool(destinationConn N.PacketWriter, source N.PacketReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handled bool, n int64, err error) {
|
||||
var _ N.ReadWaiter = (*syscallReadWaiter)(nil)
|
||||
|
||||
type syscallReadWaiter struct {
|
||||
rawConn syscall.RawConn
|
||||
readErr error
|
||||
readFunc func(fd uintptr) (done bool)
|
||||
hasData bool
|
||||
buffer *buf.Buffer
|
||||
options N.ReadWaitOptions
|
||||
}
|
||||
|
||||
func createSyscallReadWaiter(reader any) (*syscallReadWaiter, bool) {
|
||||
if syscallConn, isSyscallConn := reader.(syscall.Conn); isSyscallConn {
|
||||
rawConn, err := syscallConn.SyscallConn()
|
||||
if err == nil {
|
||||
return &syscallReadWaiter{rawConn: rawConn}, true
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (w *syscallReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
|
||||
w.options = options
|
||||
w.readFunc = func(fd uintptr) (done bool) {
|
||||
if !w.hasData {
|
||||
w.hasData = true
|
||||
// golang's internal/poll.FD.RawRead will Use a zero-byte read as a way to get notified when this
|
||||
// socket is readable if we return false. So the `recv` syscall will not block the system thread.
|
||||
return false
|
||||
}
|
||||
buffer := w.options.NewBuffer()
|
||||
var readN int32
|
||||
readN, w.readErr = recv(windows.Handle(fd), buffer.FreeBytes(), 0)
|
||||
if readN > 0 {
|
||||
buffer.Truncate(int(readN))
|
||||
w.options.PostReturn(buffer)
|
||||
w.buffer = buffer
|
||||
} else {
|
||||
buffer.Release()
|
||||
}
|
||||
if w.readErr == windows.WSAEWOULDBLOCK {
|
||||
return false
|
||||
}
|
||||
if readN == 0 && w.readErr == nil {
|
||||
w.readErr = io.EOF
|
||||
}
|
||||
w.hasData = false
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (w *syscallReadWaiter) WaitReadBuffer() (buffer *buf.Buffer, err error) {
|
||||
if w.readFunc == nil {
|
||||
return nil, os.ErrInvalid
|
||||
}
|
||||
err = w.rawConn.Read(w.readFunc)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if w.readErr != nil {
|
||||
if w.readErr == io.EOF {
|
||||
return nil, io.EOF
|
||||
}
|
||||
return nil, E.Cause(w.readErr, "raw read")
|
||||
}
|
||||
buffer = w.buffer
|
||||
w.buffer = nil
|
||||
return
|
||||
}
|
||||
|
||||
func createSyscallReadWaiter(reader any) (N.ReadWaiter, bool) {
|
||||
var _ N.PacketReadWaiter = (*syscallPacketReadWaiter)(nil)
|
||||
|
||||
type syscallPacketReadWaiter struct {
|
||||
rawConn syscall.RawConn
|
||||
readErr error
|
||||
readFrom M.Socksaddr
|
||||
readFunc func(fd uintptr) (done bool)
|
||||
hasData bool
|
||||
buffer *buf.Buffer
|
||||
options N.ReadWaitOptions
|
||||
}
|
||||
|
||||
func createSyscallPacketReadWaiter(reader any) (*syscallPacketReadWaiter, bool) {
|
||||
if syscallConn, isSyscallConn := reader.(syscall.Conn); isSyscallConn {
|
||||
rawConn, err := syscallConn.SyscallConn()
|
||||
if err == nil {
|
||||
return &syscallPacketReadWaiter{rawConn: rawConn}, true
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func createSyscallPacketReadWaiter(reader any) (N.PacketReadWaiter, bool) {
|
||||
return nil, false
|
||||
func (w *syscallPacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
|
||||
w.options = options
|
||||
w.readFunc = func(fd uintptr) (done bool) {
|
||||
if !w.hasData {
|
||||
w.hasData = true
|
||||
// golang's internal/poll.FD.RawRead will Use a zero-byte read as a way to get notified when this
|
||||
// socket is readable if we return false. So the `recvfrom` syscall will not block the system thread.
|
||||
return false
|
||||
}
|
||||
buffer := w.options.NewPacketBuffer()
|
||||
var readN int
|
||||
var from windows.Sockaddr
|
||||
readN, from, w.readErr = windows.Recvfrom(windows.Handle(fd), buffer.FreeBytes(), 0)
|
||||
if readN > 0 {
|
||||
buffer.Truncate(readN)
|
||||
w.options.PostReturn(buffer)
|
||||
w.buffer = buffer
|
||||
} else {
|
||||
buffer.Release()
|
||||
}
|
||||
if w.readErr == windows.WSAEWOULDBLOCK {
|
||||
return false
|
||||
}
|
||||
if from != nil {
|
||||
switch fromAddr := from.(type) {
|
||||
case *windows.SockaddrInet4:
|
||||
w.readFrom = M.SocksaddrFrom(netip.AddrFrom4(fromAddr.Addr), uint16(fromAddr.Port))
|
||||
case *windows.SockaddrInet6:
|
||||
w.readFrom = M.SocksaddrFrom(netip.AddrFrom16(fromAddr.Addr), uint16(fromAddr.Port)).Unwrap()
|
||||
}
|
||||
}
|
||||
w.hasData = false
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (w *syscallPacketReadWaiter) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
|
||||
if w.readFunc == nil {
|
||||
return nil, M.Socksaddr{}, os.ErrInvalid
|
||||
}
|
||||
err = w.rawConn.Read(w.readFunc)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if w.readErr != nil {
|
||||
err = E.Cause(w.readErr, "raw read")
|
||||
return
|
||||
}
|
||||
buffer = w.buffer
|
||||
w.buffer = nil
|
||||
destination = w.readFrom
|
||||
return
|
||||
}
|
||||
|
|
|
@ -14,18 +14,18 @@ type Conn struct {
|
|||
reader Reader
|
||||
}
|
||||
|
||||
func NewConn(conn net.Conn) *Conn {
|
||||
func NewConn(conn net.Conn) N.ExtendedConn {
|
||||
if deadlineConn, isDeadline := conn.(*Conn); isDeadline {
|
||||
return deadlineConn
|
||||
}
|
||||
return &Conn{ExtendedConn: bufio.NewExtendedConn(conn), reader: NewReader(conn)}
|
||||
return NewSerialConn(&Conn{ExtendedConn: bufio.NewExtendedConn(conn), reader: NewReader(conn)})
|
||||
}
|
||||
|
||||
func NewFallbackConn(conn net.Conn) *Conn {
|
||||
func NewFallbackConn(conn net.Conn) N.ExtendedConn {
|
||||
if deadlineConn, isDeadline := conn.(*Conn); isDeadline {
|
||||
return deadlineConn
|
||||
}
|
||||
return &Conn{ExtendedConn: bufio.NewExtendedConn(conn), reader: NewFallbackReader(conn)}
|
||||
return NewSerialConn(&Conn{ExtendedConn: bufio.NewExtendedConn(conn), reader: NewFallbackReader(conn)})
|
||||
}
|
||||
|
||||
func (c *Conn) Read(p []byte) (n int, err error) {
|
||||
|
|
|
@ -14,18 +14,18 @@ type PacketConn struct {
|
|||
reader PacketReader
|
||||
}
|
||||
|
||||
func NewPacketConn(conn N.NetPacketConn) *PacketConn {
|
||||
func NewPacketConn(conn N.NetPacketConn) N.NetPacketConn {
|
||||
if deadlineConn, isDeadline := conn.(*PacketConn); isDeadline {
|
||||
return deadlineConn
|
||||
}
|
||||
return &PacketConn{NetPacketConn: conn, reader: NewPacketReader(conn)}
|
||||
return NewSerialPacketConn(&PacketConn{NetPacketConn: conn, reader: NewPacketReader(conn)})
|
||||
}
|
||||
|
||||
func NewFallbackPacketConn(conn N.NetPacketConn) *PacketConn {
|
||||
func NewFallbackPacketConn(conn N.NetPacketConn) N.NetPacketConn {
|
||||
if deadlineConn, isDeadline := conn.(*PacketConn); isDeadline {
|
||||
return deadlineConn
|
||||
}
|
||||
return &PacketConn{NetPacketConn: conn, reader: NewFallbackPacketReader(conn)}
|
||||
return NewSerialPacketConn(&PacketConn{NetPacketConn: conn, reader: NewFallbackPacketReader(conn)})
|
||||
}
|
||||
|
||||
func (c *PacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
||||
|
|
|
@ -52,14 +52,13 @@ func (r *packetReader) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
|||
default:
|
||||
}
|
||||
select {
|
||||
case result := <-r.result:
|
||||
return r.pipeReturnFrom(result, p)
|
||||
case <-r.pipeDeadline.wait():
|
||||
return 0, nil, os.ErrDeadlineExceeded
|
||||
case <-r.done:
|
||||
go r.pipeReadFrom(len(p))
|
||||
default:
|
||||
}
|
||||
return r.readFrom(p)
|
||||
}
|
||||
|
||||
func (r *packetReader) readFrom(p []byte) (n int, addr net.Addr, err error) {
|
||||
select {
|
||||
case result := <-r.result:
|
||||
return r.pipeReturnFrom(result, p)
|
||||
|
@ -106,14 +105,13 @@ func (r *packetReader) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr,
|
|||
default:
|
||||
}
|
||||
select {
|
||||
case result := <-r.result:
|
||||
return r.pipeReturnFromBuffer(result, buffer)
|
||||
case <-r.pipeDeadline.wait():
|
||||
return M.Socksaddr{}, os.ErrDeadlineExceeded
|
||||
case <-r.done:
|
||||
go r.pipeReadFromBuffer(buffer.FreeLen())
|
||||
default:
|
||||
go r.pipeReadFrom(buffer.FreeLen())
|
||||
}
|
||||
return r.readPacket(buffer)
|
||||
}
|
||||
|
||||
func (r *packetReader) readPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
|
||||
select {
|
||||
case result := <-r.result:
|
||||
return r.pipeReturnFromBuffer(result, buffer)
|
||||
|
@ -134,17 +132,6 @@ func (r *packetReader) pipeReturnFromBuffer(result *packetReadResult, buffer *bu
|
|||
}
|
||||
}
|
||||
|
||||
func (r *packetReader) pipeReadFromBuffer(pLen int) {
|
||||
buffer := buf.NewSize(pLen)
|
||||
destination, err := r.TimeoutPacketReader.ReadPacket(buffer)
|
||||
r.result <- &packetReadResult{
|
||||
buffer: buffer,
|
||||
destination: destination,
|
||||
err: err,
|
||||
}
|
||||
r.done <- struct{}{}
|
||||
}
|
||||
|
||||
func (r *packetReader) SetReadDeadline(t time.Time) error {
|
||||
r.deadline.Store(t)
|
||||
r.pipeDeadline.set(t)
|
||||
|
|
|
@ -2,6 +2,7 @@ package deadline
|
|||
|
||||
import (
|
||||
"net"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing/common/atomic"
|
||||
|
@ -25,12 +26,15 @@ func (r *fallbackPacketReader) ReadFrom(p []byte) (n int, addr net.Addr, err err
|
|||
return r.pipeReturnFrom(result, p)
|
||||
default:
|
||||
}
|
||||
select {
|
||||
case result := <-r.result:
|
||||
return r.pipeReturnFrom(result, p)
|
||||
case <-r.pipeDeadline.wait():
|
||||
return 0, nil, os.ErrDeadlineExceeded
|
||||
case <-r.done:
|
||||
if r.disablePipe.Load() {
|
||||
return r.TimeoutPacketReader.ReadFrom(p)
|
||||
}
|
||||
select {
|
||||
case <-r.done:
|
||||
if r.deadline.Load().IsZero() {
|
||||
} else if r.deadline.Load().IsZero() {
|
||||
r.done <- struct{}{}
|
||||
r.inRead.Store(true)
|
||||
defer r.inRead.Store(false)
|
||||
|
@ -38,9 +42,13 @@ func (r *fallbackPacketReader) ReadFrom(p []byte) (n int, addr net.Addr, err err
|
|||
return
|
||||
}
|
||||
go r.pipeReadFrom(len(p))
|
||||
default:
|
||||
}
|
||||
return r.readFrom(p)
|
||||
select {
|
||||
case result := <-r.result:
|
||||
return r.pipeReturnFrom(result, p)
|
||||
case <-r.pipeDeadline.wait():
|
||||
return 0, nil, os.ErrDeadlineExceeded
|
||||
}
|
||||
}
|
||||
|
||||
func (r *fallbackPacketReader) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
|
||||
|
@ -49,22 +57,29 @@ func (r *fallbackPacketReader) ReadPacket(buffer *buf.Buffer) (destination M.Soc
|
|||
return r.pipeReturnFromBuffer(result, buffer)
|
||||
default:
|
||||
}
|
||||
select {
|
||||
case result := <-r.result:
|
||||
return r.pipeReturnFromBuffer(result, buffer)
|
||||
case <-r.pipeDeadline.wait():
|
||||
return M.Socksaddr{}, os.ErrDeadlineExceeded
|
||||
case <-r.done:
|
||||
if r.disablePipe.Load() {
|
||||
return r.TimeoutPacketReader.ReadPacket(buffer)
|
||||
}
|
||||
select {
|
||||
case <-r.done:
|
||||
if r.deadline.Load().IsZero() {
|
||||
} else if r.deadline.Load().IsZero() {
|
||||
r.done <- struct{}{}
|
||||
r.inRead.Store(true)
|
||||
defer r.inRead.Store(false)
|
||||
destination, err = r.TimeoutPacketReader.ReadPacket(buffer)
|
||||
return
|
||||
}
|
||||
go r.pipeReadFromBuffer(buffer.FreeLen())
|
||||
default:
|
||||
go r.pipeReadFrom(buffer.FreeLen())
|
||||
}
|
||||
select {
|
||||
case result := <-r.result:
|
||||
return r.pipeReturnFromBuffer(result, buffer)
|
||||
case <-r.pipeDeadline.wait():
|
||||
return M.Socksaddr{}, os.ErrDeadlineExceeded
|
||||
}
|
||||
return r.readPacket(buffer)
|
||||
}
|
||||
|
||||
func (r *fallbackPacketReader) SetReadDeadline(t time.Time) error {
|
||||
|
|
|
@ -54,14 +54,13 @@ func (r *reader) Read(p []byte) (n int, err error) {
|
|||
default:
|
||||
}
|
||||
select {
|
||||
case result := <-r.result:
|
||||
return r.pipeReturn(result, p)
|
||||
case <-r.pipeDeadline.wait():
|
||||
return 0, os.ErrDeadlineExceeded
|
||||
case <-r.done:
|
||||
go r.pipeRead(len(p))
|
||||
default:
|
||||
}
|
||||
return r.read(p)
|
||||
}
|
||||
|
||||
func (r *reader) read(p []byte) (n int, err error) {
|
||||
select {
|
||||
case result := <-r.result:
|
||||
return r.pipeReturn(result, p)
|
||||
|
@ -99,14 +98,13 @@ func (r *reader) ReadBuffer(buffer *buf.Buffer) error {
|
|||
default:
|
||||
}
|
||||
select {
|
||||
case result := <-r.result:
|
||||
return r.pipeReturnBuffer(result, buffer)
|
||||
case <-r.pipeDeadline.wait():
|
||||
return os.ErrDeadlineExceeded
|
||||
case <-r.done:
|
||||
go r.pipeReadBuffer(buffer.FreeLen())
|
||||
default:
|
||||
go r.pipeRead(buffer.FreeLen())
|
||||
}
|
||||
return r.readBuffer(buffer)
|
||||
}
|
||||
|
||||
func (r *reader) readBuffer(buffer *buf.Buffer) error {
|
||||
select {
|
||||
case result := <-r.result:
|
||||
return r.pipeReturnBuffer(result, buffer)
|
||||
|
@ -127,16 +125,6 @@ func (r *reader) pipeReturnBuffer(result *readResult, buffer *buf.Buffer) error
|
|||
}
|
||||
}
|
||||
|
||||
func (r *reader) pipeReadBuffer(pLen int) {
|
||||
cacheBuffer := buf.NewSize(pLen)
|
||||
err := r.ExtendedReader.ReadBuffer(cacheBuffer)
|
||||
r.result <- &readResult{
|
||||
buffer: cacheBuffer,
|
||||
err: err,
|
||||
}
|
||||
r.done <- struct{}{}
|
||||
}
|
||||
|
||||
func (r *reader) SetReadDeadline(t time.Time) error {
|
||||
r.deadline.Store(t)
|
||||
r.pipeDeadline.set(t)
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package deadline
|
||||
|
||||
import (
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing/common/atomic"
|
||||
|
@ -23,12 +24,15 @@ func (r *fallbackReader) Read(p []byte) (n int, err error) {
|
|||
return r.pipeReturn(result, p)
|
||||
default:
|
||||
}
|
||||
select {
|
||||
case result := <-r.result:
|
||||
return r.pipeReturn(result, p)
|
||||
case <-r.pipeDeadline.wait():
|
||||
return 0, os.ErrDeadlineExceeded
|
||||
case <-r.done:
|
||||
if r.disablePipe.Load() {
|
||||
return r.ExtendedReader.Read(p)
|
||||
}
|
||||
select {
|
||||
case <-r.done:
|
||||
if r.deadline.Load().IsZero() {
|
||||
} else if r.deadline.Load().IsZero() {
|
||||
r.done <- struct{}{}
|
||||
r.inRead.Store(true)
|
||||
defer r.inRead.Store(false)
|
||||
|
@ -36,9 +40,13 @@ func (r *fallbackReader) Read(p []byte) (n int, err error) {
|
|||
return
|
||||
}
|
||||
go r.pipeRead(len(p))
|
||||
default:
|
||||
}
|
||||
return r.reader.read(p)
|
||||
select {
|
||||
case result := <-r.result:
|
||||
return r.pipeReturn(result, p)
|
||||
case <-r.pipeDeadline.wait():
|
||||
return 0, os.ErrDeadlineExceeded
|
||||
}
|
||||
}
|
||||
|
||||
func (r *fallbackReader) ReadBuffer(buffer *buf.Buffer) error {
|
||||
|
@ -47,21 +55,28 @@ func (r *fallbackReader) ReadBuffer(buffer *buf.Buffer) error {
|
|||
return r.pipeReturnBuffer(result, buffer)
|
||||
default:
|
||||
}
|
||||
select {
|
||||
case result := <-r.result:
|
||||
return r.pipeReturnBuffer(result, buffer)
|
||||
case <-r.pipeDeadline.wait():
|
||||
return os.ErrDeadlineExceeded
|
||||
case <-r.done:
|
||||
if r.disablePipe.Load() {
|
||||
return r.ExtendedReader.ReadBuffer(buffer)
|
||||
}
|
||||
select {
|
||||
case <-r.done:
|
||||
if r.deadline.Load().IsZero() {
|
||||
} else if r.deadline.Load().IsZero() {
|
||||
r.done <- struct{}{}
|
||||
r.inRead.Store(true)
|
||||
defer r.inRead.Store(false)
|
||||
return r.ExtendedReader.ReadBuffer(buffer)
|
||||
}
|
||||
go r.pipeReadBuffer(buffer.FreeLen())
|
||||
default:
|
||||
go r.pipeRead(buffer.FreeLen())
|
||||
}
|
||||
select {
|
||||
case result := <-r.result:
|
||||
return r.pipeReturnBuffer(result, buffer)
|
||||
case <-r.pipeDeadline.wait():
|
||||
return os.ErrDeadlineExceeded
|
||||
}
|
||||
return r.readBuffer(buffer)
|
||||
}
|
||||
|
||||
func (r *fallbackReader) SetReadDeadline(t time.Time) error {
|
||||
|
|
75
common/bufio/deadline/serial.go
Normal file
75
common/bufio/deadline/serial.go
Normal file
|
@ -0,0 +1,75 @@
|
|||
package deadline
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
"github.com/sagernet/sing/common/debug"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
type SerialConn struct {
|
||||
N.ExtendedConn
|
||||
access sync.Mutex
|
||||
}
|
||||
|
||||
func NewSerialConn(conn N.ExtendedConn) N.ExtendedConn {
|
||||
if !debug.Enabled {
|
||||
return conn
|
||||
}
|
||||
return &SerialConn{ExtendedConn: conn}
|
||||
}
|
||||
|
||||
func (c *SerialConn) Read(p []byte) (n int, err error) {
|
||||
if !c.access.TryLock() {
|
||||
panic("concurrent read on deadline conn")
|
||||
}
|
||||
defer c.access.Unlock()
|
||||
return c.ExtendedConn.Read(p)
|
||||
}
|
||||
|
||||
func (c *SerialConn) ReadBuffer(buffer *buf.Buffer) error {
|
||||
if !c.access.TryLock() {
|
||||
panic("concurrent read on deadline conn")
|
||||
}
|
||||
defer c.access.Unlock()
|
||||
return c.ExtendedConn.ReadBuffer(buffer)
|
||||
}
|
||||
|
||||
func (c *SerialConn) Upstream() any {
|
||||
return c.ExtendedConn
|
||||
}
|
||||
|
||||
type SerialPacketConn struct {
|
||||
N.NetPacketConn
|
||||
access sync.Mutex
|
||||
}
|
||||
|
||||
func NewSerialPacketConn(conn N.NetPacketConn) N.NetPacketConn {
|
||||
if !debug.Enabled {
|
||||
return conn
|
||||
}
|
||||
return &SerialPacketConn{NetPacketConn: conn}
|
||||
}
|
||||
|
||||
func (c *SerialPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
||||
if !c.access.TryLock() {
|
||||
panic("concurrent read on deadline conn")
|
||||
}
|
||||
defer c.access.Unlock()
|
||||
return c.NetPacketConn.ReadFrom(p)
|
||||
}
|
||||
|
||||
func (c *SerialPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
|
||||
if !c.access.TryLock() {
|
||||
panic("concurrent read on deadline conn")
|
||||
}
|
||||
defer c.access.Unlock()
|
||||
return c.NetPacketConn.ReadPacket(buffer)
|
||||
}
|
||||
|
||||
func (c *SerialPacketConn) Upstream() any {
|
||||
return c.NetPacketConn
|
||||
}
|
|
@ -3,6 +3,7 @@ package bufio
|
|||
import (
|
||||
"net"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
|
@ -12,13 +13,17 @@ var _ N.NetPacketConn = (*FallbackPacketConn)(nil)
|
|||
|
||||
type FallbackPacketConn struct {
|
||||
N.PacketConn
|
||||
writer N.NetPacketWriter
|
||||
}
|
||||
|
||||
func NewNetPacketConn(conn N.PacketConn) N.NetPacketConn {
|
||||
if packetConn, loaded := conn.(N.NetPacketConn); loaded {
|
||||
return packetConn
|
||||
}
|
||||
return &FallbackPacketConn{PacketConn: conn}
|
||||
return &FallbackPacketConn{
|
||||
PacketConn: conn,
|
||||
writer: NewNetPacketWriter(conn),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *FallbackPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
||||
|
@ -36,11 +41,7 @@ func (c *FallbackPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error
|
|||
}
|
||||
|
||||
func (c *FallbackPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
||||
err = c.WritePacket(buf.As(p), M.SocksaddrFromNet(addr))
|
||||
if err == nil {
|
||||
n = len(p)
|
||||
}
|
||||
return
|
||||
return c.writer.WriteTo(p, addr)
|
||||
}
|
||||
|
||||
func (c *FallbackPacketConn) ReaderReplaceable() bool {
|
||||
|
@ -54,3 +55,50 @@ func (c *FallbackPacketConn) WriterReplaceable() bool {
|
|||
func (c *FallbackPacketConn) Upstream() any {
|
||||
return c.PacketConn
|
||||
}
|
||||
|
||||
func (c *FallbackPacketConn) UpstreamWriter() any {
|
||||
return c.writer
|
||||
}
|
||||
|
||||
var _ N.NetPacketWriter = (*FallbackPacketWriter)(nil)
|
||||
|
||||
type FallbackPacketWriter struct {
|
||||
N.PacketWriter
|
||||
frontHeadroom int
|
||||
rearHeadroom int
|
||||
}
|
||||
|
||||
func NewNetPacketWriter(writer N.PacketWriter) N.NetPacketWriter {
|
||||
if packetWriter, loaded := writer.(N.NetPacketWriter); loaded {
|
||||
return packetWriter
|
||||
}
|
||||
return &FallbackPacketWriter{
|
||||
PacketWriter: writer,
|
||||
frontHeadroom: N.CalculateFrontHeadroom(writer),
|
||||
rearHeadroom: N.CalculateRearHeadroom(writer),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *FallbackPacketWriter) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
||||
if c.frontHeadroom > 0 || c.rearHeadroom > 0 {
|
||||
buffer := buf.NewSize(len(p) + c.frontHeadroom + c.rearHeadroom)
|
||||
buffer.Resize(c.frontHeadroom, 0)
|
||||
common.Must1(buffer.Write(p))
|
||||
err = c.PacketWriter.WritePacket(buffer, M.SocksaddrFromNet(addr))
|
||||
} else {
|
||||
err = c.PacketWriter.WritePacket(buf.As(p), M.SocksaddrFromNet(addr))
|
||||
}
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
n = len(p)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *FallbackPacketWriter) WriterReplaceable() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *FallbackPacketWriter) Upstream() any {
|
||||
return c.PacketWriter
|
||||
}
|
||||
|
|
|
@ -37,13 +37,7 @@ func WriteBuffer(writer N.ExtendedWriter, buffer *buf.Buffer) (n int, err error)
|
|||
frontHeadroom := N.CalculateFrontHeadroom(writer)
|
||||
rearHeadroom := N.CalculateRearHeadroom(writer)
|
||||
if frontHeadroom > buffer.Start() || rearHeadroom > buffer.FreeLen() {
|
||||
bufferSize := N.CalculateMTU(nil, writer)
|
||||
if bufferSize > 0 {
|
||||
bufferSize += frontHeadroom + rearHeadroom
|
||||
} else {
|
||||
bufferSize = buf.BufferSize
|
||||
}
|
||||
newBuffer := buf.NewSize(bufferSize)
|
||||
newBuffer := buf.NewSize(buffer.Len() + frontHeadroom + rearHeadroom)
|
||||
newBuffer.Resize(frontHeadroom, 0)
|
||||
common.Must1(newBuffer.Write(buffer.Bytes()))
|
||||
buffer.Release()
|
||||
|
@ -69,13 +63,7 @@ func WritePacketBuffer(writer N.PacketWriter, buffer *buf.Buffer, destination M.
|
|||
frontHeadroom := N.CalculateFrontHeadroom(writer)
|
||||
rearHeadroom := N.CalculateRearHeadroom(writer)
|
||||
if frontHeadroom > buffer.Start() || rearHeadroom > buffer.FreeLen() {
|
||||
bufferSize := N.CalculateMTU(nil, writer)
|
||||
if bufferSize > 0 {
|
||||
bufferSize += frontHeadroom + rearHeadroom
|
||||
} else {
|
||||
bufferSize = buf.BufferSize
|
||||
}
|
||||
newBuffer := buf.NewSize(bufferSize)
|
||||
newBuffer := buf.NewSize(buffer.Len() + frontHeadroom + rearHeadroom)
|
||||
newBuffer.Resize(frontHeadroom, 0)
|
||||
common.Must1(newBuffer.Write(buffer.Bytes()))
|
||||
buffer.Release()
|
||||
|
|
|
@ -9,54 +9,142 @@ import (
|
|||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
type NATPacketConn struct {
|
||||
type NATPacketConn interface {
|
||||
N.NetPacketConn
|
||||
UpdateDestination(destinationAddress netip.Addr)
|
||||
}
|
||||
|
||||
func NewUnidirectionalNATPacketConn(conn N.NetPacketConn, origin M.Socksaddr, destination M.Socksaddr) NATPacketConn {
|
||||
return &unidirectionalNATPacketConn{
|
||||
NetPacketConn: conn,
|
||||
origin: socksaddrWithoutPort(origin),
|
||||
destination: socksaddrWithoutPort(destination),
|
||||
}
|
||||
}
|
||||
|
||||
func NewNATPacketConn(conn N.NetPacketConn, origin M.Socksaddr, destination M.Socksaddr) NATPacketConn {
|
||||
return &bidirectionalNATPacketConn{
|
||||
NetPacketConn: conn,
|
||||
origin: socksaddrWithoutPort(origin),
|
||||
destination: socksaddrWithoutPort(destination),
|
||||
}
|
||||
}
|
||||
|
||||
type unidirectionalNATPacketConn struct {
|
||||
N.NetPacketConn
|
||||
origin M.Socksaddr
|
||||
destination M.Socksaddr
|
||||
}
|
||||
|
||||
func NewNATPacketConn(conn N.NetPacketConn, origin M.Socksaddr, destination M.Socksaddr) *NATPacketConn {
|
||||
return &NATPacketConn{
|
||||
NetPacketConn: conn,
|
||||
origin: origin,
|
||||
destination: destination,
|
||||
func (c *unidirectionalNATPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
||||
destination := M.SocksaddrFromNet(addr)
|
||||
if socksaddrWithoutPort(destination) == c.destination {
|
||||
destination = M.Socksaddr{
|
||||
Addr: c.origin.Addr,
|
||||
Fqdn: c.origin.Fqdn,
|
||||
Port: destination.Port,
|
||||
}
|
||||
}
|
||||
return c.NetPacketConn.WriteTo(p, destination.UDPAddr())
|
||||
}
|
||||
|
||||
func (c *NATPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
||||
n, addr, err = c.NetPacketConn.ReadFrom(p)
|
||||
if err == nil && M.SocksaddrFromNet(addr) == c.origin {
|
||||
addr = c.destination.UDPAddr()
|
||||
func (c *unidirectionalNATPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
|
||||
if socksaddrWithoutPort(destination) == c.destination {
|
||||
destination = M.Socksaddr{
|
||||
Addr: c.origin.Addr,
|
||||
Fqdn: c.origin.Fqdn,
|
||||
Port: destination.Port,
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (c *NATPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
||||
if M.SocksaddrFromNet(addr) == c.destination {
|
||||
addr = c.origin.UDPAddr()
|
||||
}
|
||||
return c.NetPacketConn.WriteTo(p, addr)
|
||||
}
|
||||
|
||||
func (c *NATPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
|
||||
destination, err = c.NetPacketConn.ReadPacket(buffer)
|
||||
if destination == c.origin {
|
||||
destination = c.destination
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (c *NATPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
|
||||
if destination == c.destination {
|
||||
destination = c.origin
|
||||
}
|
||||
return c.NetPacketConn.WritePacket(buffer, destination)
|
||||
}
|
||||
|
||||
func (c *NATPacketConn) UpdateDestination(destinationAddress netip.Addr) {
|
||||
func (c *unidirectionalNATPacketConn) UpdateDestination(destinationAddress netip.Addr) {
|
||||
c.destination = M.SocksaddrFrom(destinationAddress, c.destination.Port)
|
||||
}
|
||||
|
||||
func (c *NATPacketConn) Upstream() any {
|
||||
func (c *unidirectionalNATPacketConn) RemoteAddr() net.Addr {
|
||||
return c.destination.UDPAddr()
|
||||
}
|
||||
|
||||
func (c *unidirectionalNATPacketConn) Upstream() any {
|
||||
return c.NetPacketConn
|
||||
}
|
||||
|
||||
type bidirectionalNATPacketConn struct {
|
||||
N.NetPacketConn
|
||||
origin M.Socksaddr
|
||||
destination M.Socksaddr
|
||||
}
|
||||
|
||||
func (c *bidirectionalNATPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
||||
n, addr, err = c.NetPacketConn.ReadFrom(p)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
destination := M.SocksaddrFromNet(addr)
|
||||
if socksaddrWithoutPort(destination) == c.origin {
|
||||
destination = M.Socksaddr{
|
||||
Addr: c.destination.Addr,
|
||||
Fqdn: c.destination.Fqdn,
|
||||
Port: destination.Port,
|
||||
}
|
||||
}
|
||||
addr = destination.UDPAddr()
|
||||
return
|
||||
}
|
||||
|
||||
func (c *bidirectionalNATPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
||||
destination := M.SocksaddrFromNet(addr)
|
||||
if socksaddrWithoutPort(destination) == c.destination {
|
||||
destination = M.Socksaddr{
|
||||
Addr: c.origin.Addr,
|
||||
Fqdn: c.origin.Fqdn,
|
||||
Port: destination.Port,
|
||||
}
|
||||
}
|
||||
return c.NetPacketConn.WriteTo(p, destination.UDPAddr())
|
||||
}
|
||||
|
||||
func (c *bidirectionalNATPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
|
||||
destination, err = c.NetPacketConn.ReadPacket(buffer)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if socksaddrWithoutPort(destination) == c.origin {
|
||||
destination = M.Socksaddr{
|
||||
Addr: c.destination.Addr,
|
||||
Fqdn: c.destination.Fqdn,
|
||||
Port: destination.Port,
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (c *bidirectionalNATPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
|
||||
if socksaddrWithoutPort(destination) == c.destination {
|
||||
destination = M.Socksaddr{
|
||||
Addr: c.origin.Addr,
|
||||
Fqdn: c.origin.Fqdn,
|
||||
Port: destination.Port,
|
||||
}
|
||||
}
|
||||
return c.NetPacketConn.WritePacket(buffer, destination)
|
||||
}
|
||||
|
||||
func (c *bidirectionalNATPacketConn) UpdateDestination(destinationAddress netip.Addr) {
|
||||
c.destination = M.SocksaddrFrom(destinationAddress, c.destination.Port)
|
||||
}
|
||||
|
||||
func (c *bidirectionalNATPacketConn) Upstream() any {
|
||||
return c.NetPacketConn
|
||||
}
|
||||
|
||||
func (c *bidirectionalNATPacketConn) RemoteAddr() net.Addr {
|
||||
return c.destination.UDPAddr()
|
||||
}
|
||||
|
||||
func socksaddrWithoutPort(destination M.Socksaddr) M.Socksaddr {
|
||||
destination.Port = 0
|
||||
return destination
|
||||
}
|
||||
|
|
39
common/bufio/nat_wait.go
Normal file
39
common/bufio/nat_wait.go
Normal file
|
@ -0,0 +1,39 @@
|
|||
package bufio
|
||||
|
||||
import (
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
func (c *bidirectionalNATPacketConn) CreatePacketReadWaiter() (N.PacketReadWaiter, bool) {
|
||||
waiter, created := CreatePacketReadWaiter(c.NetPacketConn)
|
||||
if !created {
|
||||
return nil, false
|
||||
}
|
||||
return &waitBidirectionalNATPacketConn{c, waiter}, true
|
||||
}
|
||||
|
||||
type waitBidirectionalNATPacketConn struct {
|
||||
*bidirectionalNATPacketConn
|
||||
readWaiter N.PacketReadWaiter
|
||||
}
|
||||
|
||||
func (c *waitBidirectionalNATPacketConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
|
||||
return c.readWaiter.InitializeReadWaiter(options)
|
||||
}
|
||||
|
||||
func (c *waitBidirectionalNATPacketConn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
|
||||
buffer, destination, err = c.readWaiter.WaitReadPacket()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if socksaddrWithoutPort(destination) == c.origin {
|
||||
destination = M.Socksaddr{
|
||||
Addr: c.destination.Addr,
|
||||
Fqdn: c.destination.Fqdn,
|
||||
Port: destination.Port,
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
277
common/bufio/net_test.go
Normal file
277
common/bufio/net_test.go
Normal file
|
@ -0,0 +1,277 @@
|
|||
package bufio
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/md5"
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
"github.com/sagernet/sing/common/task"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TCPPipe(t *testing.T) (net.Conn, net.Conn) {
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
var (
|
||||
group task.Group
|
||||
serverConn net.Conn
|
||||
clientConn net.Conn
|
||||
)
|
||||
group.Append0(func(ctx context.Context) error {
|
||||
var serverErr error
|
||||
serverConn, serverErr = listener.Accept()
|
||||
return serverErr
|
||||
})
|
||||
group.Append0(func(ctx context.Context) error {
|
||||
var clientErr error
|
||||
clientConn, clientErr = net.Dial("tcp", listener.Addr().String())
|
||||
return clientErr
|
||||
})
|
||||
err = group.Run()
|
||||
require.NoError(t, err)
|
||||
listener.Close()
|
||||
t.Cleanup(func() {
|
||||
serverConn.Close()
|
||||
clientConn.Close()
|
||||
})
|
||||
return serverConn, clientConn
|
||||
}
|
||||
|
||||
func UDPPipe(t *testing.T) (net.PacketConn, net.PacketConn, M.Socksaddr) {
|
||||
serverConn, err := net.ListenPacket("udp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
clientConn, err := net.ListenPacket("udp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
return serverConn, clientConn, M.SocksaddrFromNet(clientConn.LocalAddr())
|
||||
}
|
||||
|
||||
func Timeout(t *testing.T) context.CancelFunc {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
go func() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Error("timeout")
|
||||
}
|
||||
}()
|
||||
return cancel
|
||||
}
|
||||
|
||||
type hashPair struct {
|
||||
sendHash map[int][]byte
|
||||
recvHash map[int][]byte
|
||||
}
|
||||
|
||||
func newLargeDataPair() (chan hashPair, chan hashPair, func(t *testing.T) error) {
|
||||
pingCh := make(chan hashPair)
|
||||
pongCh := make(chan hashPair)
|
||||
test := func(t *testing.T) error {
|
||||
defer close(pingCh)
|
||||
defer close(pongCh)
|
||||
pingOpen := false
|
||||
pongOpen := false
|
||||
var serverPair hashPair
|
||||
var clientPair hashPair
|
||||
|
||||
for {
|
||||
if pingOpen && pongOpen {
|
||||
break
|
||||
}
|
||||
|
||||
select {
|
||||
case serverPair, pingOpen = <-pingCh:
|
||||
assert.True(t, pingOpen)
|
||||
case clientPair, pongOpen = <-pongCh:
|
||||
assert.True(t, pongOpen)
|
||||
case <-time.After(10 * time.Second):
|
||||
return errors.New("timeout")
|
||||
}
|
||||
}
|
||||
|
||||
assert.Equal(t, serverPair.recvHash, clientPair.sendHash)
|
||||
assert.Equal(t, serverPair.sendHash, clientPair.recvHash)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
return pingCh, pongCh, test
|
||||
}
|
||||
|
||||
func TCPTest(t *testing.T, inputConn net.Conn, outputConn net.Conn) error {
|
||||
times := 100
|
||||
chunkSize := int64(64 * 1024)
|
||||
|
||||
pingCh, pongCh, test := newLargeDataPair()
|
||||
writeRandData := func(conn net.Conn) (map[int][]byte, error) {
|
||||
buf := make([]byte, chunkSize)
|
||||
hashMap := map[int][]byte{}
|
||||
for i := 0; i < times; i++ {
|
||||
if _, err := rand.Read(buf[1:]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
buf[0] = byte(i)
|
||||
|
||||
hash := md5.Sum(buf)
|
||||
hashMap[i] = hash[:]
|
||||
|
||||
if _, err := conn.Write(buf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return hashMap, nil
|
||||
}
|
||||
go func() {
|
||||
hashMap := map[int][]byte{}
|
||||
buf := make([]byte, chunkSize)
|
||||
|
||||
for i := 0; i < times; i++ {
|
||||
_, err := io.ReadFull(outputConn, buf)
|
||||
if err != nil {
|
||||
t.Log(err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
hash := md5.Sum(buf)
|
||||
hashMap[int(buf[0])] = hash[:]
|
||||
}
|
||||
|
||||
sendHash, err := writeRandData(outputConn)
|
||||
if err != nil {
|
||||
t.Log(err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
pingCh <- hashPair{
|
||||
sendHash: sendHash,
|
||||
recvHash: hashMap,
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
sendHash, err := writeRandData(inputConn)
|
||||
if err != nil {
|
||||
t.Log(err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
hashMap := map[int][]byte{}
|
||||
buf := make([]byte, chunkSize)
|
||||
|
||||
for i := 0; i < times; i++ {
|
||||
_, err = io.ReadFull(inputConn, buf)
|
||||
if err != nil {
|
||||
t.Log(err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
hash := md5.Sum(buf)
|
||||
hashMap[int(buf[0])] = hash[:]
|
||||
}
|
||||
|
||||
pongCh <- hashPair{
|
||||
sendHash: sendHash,
|
||||
recvHash: hashMap,
|
||||
}
|
||||
}()
|
||||
return test(t)
|
||||
}
|
||||
|
||||
func UDPTest(t *testing.T, inputConn net.PacketConn, outputConn net.PacketConn, outputAddr M.Socksaddr) error {
|
||||
rAddr := outputAddr.UDPAddr()
|
||||
times := 50
|
||||
chunkSize := 9000
|
||||
pingCh, pongCh, test := newLargeDataPair()
|
||||
writeRandData := func(pc net.PacketConn, addr net.Addr) (map[int][]byte, error) {
|
||||
hashMap := map[int][]byte{}
|
||||
mux := sync.Mutex{}
|
||||
for i := 0; i < times; i++ {
|
||||
buf := make([]byte, chunkSize)
|
||||
if _, err := rand.Read(buf[1:]); err != nil {
|
||||
t.Log(err.Error())
|
||||
continue
|
||||
}
|
||||
buf[0] = byte(i)
|
||||
|
||||
hash := md5.Sum(buf)
|
||||
mux.Lock()
|
||||
hashMap[i] = hash[:]
|
||||
mux.Unlock()
|
||||
|
||||
if _, err := pc.WriteTo(buf, addr); err != nil {
|
||||
t.Log(err.Error())
|
||||
}
|
||||
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
|
||||
return hashMap, nil
|
||||
}
|
||||
go func() {
|
||||
var (
|
||||
lAddr net.Addr
|
||||
err error
|
||||
)
|
||||
hashMap := map[int][]byte{}
|
||||
buf := make([]byte, 64*1024)
|
||||
|
||||
for i := 0; i < times; i++ {
|
||||
_, lAddr, err = outputConn.ReadFrom(buf)
|
||||
if err != nil {
|
||||
t.Log(err.Error())
|
||||
return
|
||||
}
|
||||
hash := md5.Sum(buf[:chunkSize])
|
||||
hashMap[int(buf[0])] = hash[:]
|
||||
}
|
||||
sendHash, err := writeRandData(outputConn, lAddr)
|
||||
if err != nil {
|
||||
t.Log(err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
pingCh <- hashPair{
|
||||
sendHash: sendHash,
|
||||
recvHash: hashMap,
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
sendHash, err := writeRandData(inputConn, rAddr)
|
||||
if err != nil {
|
||||
t.Log(err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
hashMap := map[int][]byte{}
|
||||
buf := make([]byte, 64*1024)
|
||||
|
||||
for i := 0; i < times; i++ {
|
||||
_, _, err := inputConn.ReadFrom(buf)
|
||||
if err != nil {
|
||||
t.Log(err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
hash := md5.Sum(buf[:chunkSize])
|
||||
hashMap[int(buf[0])] = hash[:]
|
||||
}
|
||||
|
||||
pongCh <- hashPair{
|
||||
sendHash: sendHash,
|
||||
recvHash: hashMap,
|
||||
}
|
||||
}()
|
||||
|
||||
return test(t)
|
||||
}
|
|
@ -1,127 +0,0 @@
|
|||
package bufio
|
||||
|
||||
import (
|
||||
"io"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
func CopyTimes(dst io.Writer, src io.Reader, times int) (n int64, err error) {
|
||||
return CopyExtendedTimes(NewExtendedWriter(N.UnwrapWriter(dst)), NewExtendedReader(N.UnwrapReader(src)), times)
|
||||
}
|
||||
|
||||
func CopyExtendedTimes(dst N.ExtendedWriter, src N.ExtendedReader, times int) (n int64, err error) {
|
||||
frontHeadroom := N.CalculateFrontHeadroom(dst)
|
||||
rearHeadroom := N.CalculateRearHeadroom(dst)
|
||||
bufferSize := N.CalculateMTU(src, dst)
|
||||
if bufferSize > 0 {
|
||||
bufferSize += frontHeadroom + rearHeadroom
|
||||
} else {
|
||||
bufferSize = buf.BufferSize
|
||||
}
|
||||
dstUnsafe := N.IsUnsafeWriter(dst)
|
||||
var buffer *buf.Buffer
|
||||
if !dstUnsafe {
|
||||
_buffer := buf.StackNewSize(bufferSize)
|
||||
defer common.KeepAlive(_buffer)
|
||||
buffer = common.Dup(_buffer)
|
||||
defer buffer.Release()
|
||||
buffer.IncRef()
|
||||
defer buffer.DecRef()
|
||||
}
|
||||
notFirstTime := true
|
||||
for i := 0; i < times; i++ {
|
||||
if dstUnsafe {
|
||||
buffer = buf.NewSize(bufferSize)
|
||||
}
|
||||
readBufferRaw := buffer.Slice()
|
||||
readBuffer := buf.With(readBufferRaw[:cap(readBufferRaw)-rearHeadroom])
|
||||
readBuffer.Resize(frontHeadroom, 0)
|
||||
err = src.ReadBuffer(readBuffer)
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
if !notFirstTime {
|
||||
err = N.HandshakeFailure(dst, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
dataLen := readBuffer.Len()
|
||||
buffer.Resize(readBuffer.Start(), dataLen)
|
||||
err = dst.WriteBuffer(buffer)
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
return
|
||||
}
|
||||
n += int64(dataLen)
|
||||
notFirstTime = true
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
type ReadFromWriter interface {
|
||||
io.ReaderFrom
|
||||
io.Writer
|
||||
}
|
||||
|
||||
func ReadFrom0(readerFrom ReadFromWriter, reader io.Reader) (n int64, err error) {
|
||||
n, err = CopyTimes(readerFrom, reader, 1)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var rn int64
|
||||
rn, err = readerFrom.ReadFrom(reader)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
n += rn
|
||||
return
|
||||
}
|
||||
|
||||
func ReadFromN(readerFrom ReadFromWriter, reader io.Reader, times int) (n int64, err error) {
|
||||
n, err = CopyTimes(readerFrom, reader, times)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var rn int64
|
||||
rn, err = readerFrom.ReadFrom(reader)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
n += rn
|
||||
return
|
||||
}
|
||||
|
||||
type WriteToReader interface {
|
||||
io.WriterTo
|
||||
io.Reader
|
||||
}
|
||||
|
||||
func WriteTo0(writerTo WriteToReader, writer io.Writer) (n int64, err error) {
|
||||
n, err = CopyTimes(writer, writerTo, 1)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var wn int64
|
||||
wn, err = writerTo.WriteTo(writer)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
n += wn
|
||||
return
|
||||
}
|
||||
|
||||
func WriteToN(writerTo WriteToReader, writer io.Writer, times int) (n int64, err error) {
|
||||
n, err = CopyTimes(writer, writerTo, times)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var wn int64
|
||||
wn, err = writerTo.WriteTo(writer)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
n += wn
|
||||
return
|
||||
}
|
|
@ -33,10 +33,10 @@ func CreateVectorisedWriter(writer any) (N.VectorisedWriter, bool) {
|
|||
case syscall.Conn:
|
||||
rawConn, err := w.SyscallConn()
|
||||
if err == nil {
|
||||
return &SyscallVectorisedWriter{writer, rawConn}, true
|
||||
return &SyscallVectorisedWriter{upstream: writer, rawConn: rawConn}, true
|
||||
}
|
||||
case syscall.RawConn:
|
||||
return &SyscallVectorisedWriter{writer, w}, true
|
||||
return &SyscallVectorisedWriter{upstream: writer, rawConn: w}, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
@ -48,10 +48,10 @@ func CreateVectorisedPacketWriter(writer any) (N.VectorisedPacketWriter, bool) {
|
|||
case syscall.Conn:
|
||||
rawConn, err := w.SyscallConn()
|
||||
if err == nil {
|
||||
return &SyscallVectorisedPacketWriter{writer, rawConn}, true
|
||||
return &SyscallVectorisedPacketWriter{upstream: writer, rawConn: rawConn}, true
|
||||
}
|
||||
case syscall.RawConn:
|
||||
return &SyscallVectorisedPacketWriter{writer, w}, true
|
||||
return &SyscallVectorisedPacketWriter{upstream: writer, rawConn: w}, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
@ -74,9 +74,7 @@ func (w *BufferedVectorisedWriter) WriteVectorised(buffers []*buf.Buffer) error
|
|||
if bufferLen > 65535 {
|
||||
bufferBytes = make([]byte, bufferLen)
|
||||
} else {
|
||||
_buffer := buf.StackNewSize(bufferLen)
|
||||
defer common.KeepAlive(_buffer)
|
||||
buffer := common.Dup(_buffer)
|
||||
buffer := buf.NewSize(bufferLen)
|
||||
defer buffer.Release()
|
||||
bufferBytes = buffer.FreeBytes()
|
||||
}
|
||||
|
@ -113,6 +111,7 @@ var _ N.VectorisedWriter = (*SyscallVectorisedWriter)(nil)
|
|||
type SyscallVectorisedWriter struct {
|
||||
upstream any
|
||||
rawConn syscall.RawConn
|
||||
syscallVectorisedWriterFields
|
||||
}
|
||||
|
||||
func (w *SyscallVectorisedWriter) Upstream() any {
|
||||
|
@ -128,6 +127,7 @@ var _ N.VectorisedPacketWriter = (*SyscallVectorisedPacketWriter)(nil)
|
|||
type SyscallVectorisedPacketWriter struct {
|
||||
upstream any
|
||||
rawConn syscall.RawConn
|
||||
syscallVectorisedWriterFields
|
||||
}
|
||||
|
||||
func (w *SyscallVectorisedPacketWriter) Upstream() any {
|
||||
|
|
60
common/bufio/vectorised_test.go
Normal file
60
common/bufio/vectorised_test.go
Normal file
|
@ -0,0 +1,60 @@
|
|||
package bufio
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestWriteVectorised(t *testing.T) {
|
||||
t.Parallel()
|
||||
inputConn, outputConn := TCPPipe(t)
|
||||
vectorisedWriter, created := CreateVectorisedWriter(inputConn)
|
||||
require.True(t, created)
|
||||
require.NotNil(t, vectorisedWriter)
|
||||
var bufA [1024]byte
|
||||
var bufB [1024]byte
|
||||
var bufC [2048]byte
|
||||
_, err := io.ReadFull(rand.Reader, bufA[:])
|
||||
require.NoError(t, err)
|
||||
_, err = io.ReadFull(rand.Reader, bufB[:])
|
||||
require.NoError(t, err)
|
||||
copy(bufC[:], bufA[:])
|
||||
copy(bufC[1024:], bufB[:])
|
||||
finish := Timeout(t)
|
||||
_, err = WriteVectorised(vectorisedWriter, [][]byte{bufA[:], bufB[:]})
|
||||
require.NoError(t, err)
|
||||
output := make([]byte, 2048)
|
||||
_, err = io.ReadFull(outputConn, output)
|
||||
finish()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, bufC[:], output)
|
||||
}
|
||||
|
||||
func TestWriteVectorisedPacket(t *testing.T) {
|
||||
t.Parallel()
|
||||
inputConn, outputConn, outputAddr := UDPPipe(t)
|
||||
vectorisedWriter, created := CreateVectorisedPacketWriter(inputConn)
|
||||
require.True(t, created)
|
||||
require.NotNil(t, vectorisedWriter)
|
||||
var bufA [1024]byte
|
||||
var bufB [1024]byte
|
||||
var bufC [2048]byte
|
||||
_, err := io.ReadFull(rand.Reader, bufA[:])
|
||||
require.NoError(t, err)
|
||||
_, err = io.ReadFull(rand.Reader, bufB[:])
|
||||
require.NoError(t, err)
|
||||
copy(bufC[:], bufA[:])
|
||||
copy(bufC[1024:], bufB[:])
|
||||
finish := Timeout(t)
|
||||
_, err = WriteVectorisedPacket(vectorisedWriter, [][]byte{bufA[:], bufB[:]}, outputAddr)
|
||||
require.NoError(t, err)
|
||||
output := make([]byte, 2048)
|
||||
n, _, err := outputConn.ReadFrom(output)
|
||||
finish()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 2048, n)
|
||||
require.Equal(t, bufC[:], output)
|
||||
}
|
|
@ -3,6 +3,8 @@
|
|||
package bufio
|
||||
|
||||
import (
|
||||
"os"
|
||||
"sync"
|
||||
"unsafe"
|
||||
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
|
@ -11,15 +13,28 @@ import (
|
|||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
type syscallVectorisedWriterFields struct {
|
||||
access sync.Mutex
|
||||
iovecList *[]unix.Iovec
|
||||
}
|
||||
|
||||
func (w *SyscallVectorisedWriter) WriteVectorised(buffers []*buf.Buffer) error {
|
||||
w.access.Lock()
|
||||
defer w.access.Unlock()
|
||||
defer buf.ReleaseMulti(buffers)
|
||||
iovecList := make([]unix.Iovec, 0, len(buffers))
|
||||
for _, buffer := range buffers {
|
||||
var iovec unix.Iovec
|
||||
iovec.Base = &buffer.Bytes()[0]
|
||||
iovec.SetLen(buffer.Len())
|
||||
iovecList = append(iovecList, iovec)
|
||||
var iovecList []unix.Iovec
|
||||
if w.iovecList != nil {
|
||||
iovecList = *w.iovecList
|
||||
}
|
||||
iovecList = iovecList[:0]
|
||||
for index, buffer := range buffers {
|
||||
iovecList = append(iovecList, unix.Iovec{Base: &buffer.Bytes()[0]})
|
||||
iovecList[index].SetLen(buffer.Len())
|
||||
}
|
||||
if w.iovecList == nil {
|
||||
w.iovecList = new([]unix.Iovec)
|
||||
}
|
||||
*w.iovecList = iovecList // cache
|
||||
var innerErr unix.Errno
|
||||
err := w.rawConn.Write(func(fd uintptr) (done bool) {
|
||||
//nolint:staticcheck
|
||||
|
@ -28,32 +43,52 @@ func (w *SyscallVectorisedWriter) WriteVectorised(buffers []*buf.Buffer) error {
|
|||
return innerErr != unix.EAGAIN && innerErr != unix.EWOULDBLOCK
|
||||
})
|
||||
if innerErr != 0 {
|
||||
err = innerErr
|
||||
err = os.NewSyscallError("SYS_WRITEV", innerErr)
|
||||
}
|
||||
for index := range iovecList {
|
||||
iovecList[index] = unix.Iovec{}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (w *SyscallVectorisedPacketWriter) WriteVectorisedPacket(buffers []*buf.Buffer, destination M.Socksaddr) error {
|
||||
w.access.Lock()
|
||||
defer w.access.Unlock()
|
||||
defer buf.ReleaseMulti(buffers)
|
||||
var sockaddr unix.Sockaddr
|
||||
if destination.IsIPv4() {
|
||||
sockaddr = &unix.SockaddrInet4{
|
||||
Port: int(destination.Port),
|
||||
Addr: destination.Addr.As4(),
|
||||
var iovecList []unix.Iovec
|
||||
if w.iovecList != nil {
|
||||
iovecList = *w.iovecList
|
||||
}
|
||||
} else {
|
||||
sockaddr = &unix.SockaddrInet6{
|
||||
Port: int(destination.Port),
|
||||
Addr: destination.Addr.As16(),
|
||||
iovecList = iovecList[:0]
|
||||
for index, buffer := range buffers {
|
||||
iovecList = append(iovecList, unix.Iovec{Base: &buffer.Bytes()[0]})
|
||||
iovecList[index].SetLen(buffer.Len())
|
||||
}
|
||||
if w.iovecList == nil {
|
||||
w.iovecList = new([]unix.Iovec)
|
||||
}
|
||||
*w.iovecList = iovecList // cache
|
||||
var innerErr error
|
||||
err := w.rawConn.Write(func(fd uintptr) (done bool) {
|
||||
_, innerErr = unix.SendmsgBuffers(int(fd), buf.ToSliceMulti(buffers), nil, sockaddr, 0)
|
||||
var msg unix.Msghdr
|
||||
name, nameLen := ToSockaddr(destination.AddrPort())
|
||||
msg.Name = (*byte)(name)
|
||||
msg.Namelen = nameLen
|
||||
if len(iovecList) > 0 {
|
||||
msg.Iov = &iovecList[0]
|
||||
msg.SetIovlen(len(iovecList))
|
||||
}
|
||||
_, innerErr = sendmsg(int(fd), &msg, 0)
|
||||
return innerErr != unix.EAGAIN && innerErr != unix.EWOULDBLOCK
|
||||
})
|
||||
if innerErr != nil {
|
||||
err = innerErr
|
||||
}
|
||||
for index := range iovecList {
|
||||
iovecList[index] = unix.Iovec{}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
//go:linkname sendmsg golang.org/x/sys/unix.sendmsg
|
||||
func sendmsg(s int, msg *unix.Msghdr, flags int) (n int, err error)
|
||||
|
|
|
@ -1,62 +1,93 @@
|
|||
package bufio
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
type syscallVectorisedWriterFields struct {
|
||||
access sync.Mutex
|
||||
iovecList *[]windows.WSABuf
|
||||
}
|
||||
|
||||
func (w *SyscallVectorisedWriter) WriteVectorised(buffers []*buf.Buffer) error {
|
||||
w.access.Lock()
|
||||
defer w.access.Unlock()
|
||||
defer buf.ReleaseMulti(buffers)
|
||||
iovecList := make([]*windows.WSABuf, len(buffers))
|
||||
for i, buffer := range buffers {
|
||||
iovecList[i] = &windows.WSABuf{
|
||||
Len: uint32(buffer.Len()),
|
||||
var iovecList []windows.WSABuf
|
||||
if w.iovecList != nil {
|
||||
iovecList = *w.iovecList
|
||||
}
|
||||
iovecList = iovecList[:0]
|
||||
for _, buffer := range buffers {
|
||||
iovecList = append(iovecList, windows.WSABuf{
|
||||
Buf: &buffer.Bytes()[0],
|
||||
Len: uint32(buffer.Len()),
|
||||
})
|
||||
}
|
||||
if w.iovecList == nil {
|
||||
w.iovecList = new([]windows.WSABuf)
|
||||
}
|
||||
*w.iovecList = iovecList // cache
|
||||
var n uint32
|
||||
var innerErr error
|
||||
err := w.rawConn.Write(func(fd uintptr) (done bool) {
|
||||
innerErr = windows.WSASend(windows.Handle(fd), iovecList[0], uint32(len(iovecList)), &n, 0, nil, nil)
|
||||
innerErr = windows.WSASend(windows.Handle(fd), &iovecList[0], uint32(len(iovecList)), &n, 0, nil, nil)
|
||||
return innerErr != windows.WSAEWOULDBLOCK
|
||||
})
|
||||
if innerErr != nil {
|
||||
err = innerErr
|
||||
}
|
||||
for index := range iovecList {
|
||||
iovecList[index] = windows.WSABuf{}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (w *SyscallVectorisedPacketWriter) WriteVectorisedPacket(buffers []*buf.Buffer, destination M.Socksaddr) error {
|
||||
w.access.Lock()
|
||||
defer w.access.Unlock()
|
||||
defer buf.ReleaseMulti(buffers)
|
||||
iovecList := make([]*windows.WSABuf, len(buffers))
|
||||
for i, buffer := range buffers {
|
||||
iovecList[i] = &windows.WSABuf{
|
||||
Len: uint32(buffer.Len()),
|
||||
var iovecList []windows.WSABuf
|
||||
if w.iovecList != nil {
|
||||
iovecList = *w.iovecList
|
||||
}
|
||||
iovecList = iovecList[:0]
|
||||
for _, buffer := range buffers {
|
||||
iovecList = append(iovecList, windows.WSABuf{
|
||||
Buf: &buffer.Bytes()[0],
|
||||
Len: uint32(buffer.Len()),
|
||||
})
|
||||
}
|
||||
if w.iovecList == nil {
|
||||
w.iovecList = new([]windows.WSABuf)
|
||||
}
|
||||
var sockaddr windows.Sockaddr
|
||||
if destination.IsIPv4() {
|
||||
sockaddr = &windows.SockaddrInet4{
|
||||
Port: int(destination.Port),
|
||||
Addr: destination.Addr.As4(),
|
||||
}
|
||||
} else {
|
||||
sockaddr = &windows.SockaddrInet6{
|
||||
Port: int(destination.Port),
|
||||
Addr: destination.Addr.As16(),
|
||||
}
|
||||
}
|
||||
*w.iovecList = iovecList // cache
|
||||
var n uint32
|
||||
var innerErr error
|
||||
err := w.rawConn.Write(func(fd uintptr) (done bool) {
|
||||
innerErr = windows.WSASendto(windows.Handle(fd), iovecList[0], uint32(len(iovecList)), &n, 0, sockaddr, nil, nil)
|
||||
name, nameLen := ToSockaddr(destination.AddrPort())
|
||||
innerErr = windows.WSASendTo(
|
||||
windows.Handle(fd),
|
||||
&iovecList[0],
|
||||
uint32(len(iovecList)),
|
||||
&n,
|
||||
0,
|
||||
(*windows.RawSockaddrAny)(name),
|
||||
nameLen,
|
||||
nil,
|
||||
nil)
|
||||
return innerErr != windows.WSAEWOULDBLOCK
|
||||
})
|
||||
if innerErr != nil {
|
||||
err = innerErr
|
||||
}
|
||||
for index := range iovecList {
|
||||
iovecList[index] = windows.WSABuf{}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
|
8
common/cache/lrucache.go
vendored
8
common/cache/lrucache.go
vendored
|
@ -258,6 +258,14 @@ func (c *LruCache[K, V]) Delete(key K) {
|
|||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
func (c *LruCache[K, V]) Clear() {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
for element := c.lru.Front(); element != nil; element = element.Next() {
|
||||
c.deleteElement(element)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *LruCache[K, V]) maybeDeleteOldest() {
|
||||
if !c.staleReturn && c.maxAge > 0 {
|
||||
now := time.Now().Unix()
|
||||
|
|
|
@ -21,13 +21,13 @@ type TimerPacketConn struct {
|
|||
instance *Instance
|
||||
}
|
||||
|
||||
func NewPacketConn(ctx context.Context, conn N.PacketConn, timeout time.Duration) (context.Context, PacketConn) {
|
||||
func NewPacketConn(ctx context.Context, conn N.PacketConn, timeout time.Duration) (context.Context, N.PacketConn) {
|
||||
if timeoutConn, isTimeoutConn := common.Cast[PacketConn](conn); isTimeoutConn {
|
||||
oldTimeout := timeoutConn.Timeout()
|
||||
if timeout < oldTimeout {
|
||||
timeoutConn.SetTimeout(timeout)
|
||||
}
|
||||
return ctx, timeoutConn
|
||||
return ctx, conn
|
||||
}
|
||||
err := conn.SetReadDeadline(time.Time{})
|
||||
if err == nil {
|
||||
|
|
|
@ -2,6 +2,7 @@ package canceler
|
|||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
|
@ -31,7 +32,7 @@ func (c *TimeoutPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksa
|
|||
for {
|
||||
err = c.PacketConn.SetReadDeadline(time.Now().Add(c.timeout))
|
||||
if err != nil {
|
||||
return M.Socksaddr{}, err
|
||||
return
|
||||
}
|
||||
destination, err = c.PacketConn.ReadPacket(buffer)
|
||||
if err == nil {
|
||||
|
@ -43,7 +44,7 @@ func (c *TimeoutPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksa
|
|||
return
|
||||
}
|
||||
} else {
|
||||
return M.Socksaddr{}, err
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -66,6 +67,7 @@ func (c *TimeoutPacketConn) SetTimeout(timeout time.Duration) {
|
|||
}
|
||||
|
||||
func (c *TimeoutPacketConn) Close() error {
|
||||
c.cancel(net.ErrClosed)
|
||||
return c.PacketConn.Close()
|
||||
}
|
||||
|
||||
|
|
11
common/clear.go
Normal file
11
common/clear.go
Normal file
|
@ -0,0 +1,11 @@
|
|||
//go:build go1.21
|
||||
|
||||
package common
|
||||
|
||||
func ClearArray[T ~[]E, E any](t T) {
|
||||
clear(t)
|
||||
}
|
||||
|
||||
func ClearMap[T ~map[K]V, K comparable, V any](t T) {
|
||||
clear(t)
|
||||
}
|
16
common/clear_compat.go
Normal file
16
common/clear_compat.go
Normal file
|
@ -0,0 +1,16 @@
|
|||
//go:build !go1.21
|
||||
|
||||
package common
|
||||
|
||||
func ClearArray[T ~[]E, E any](t T) {
|
||||
var defaultValue E
|
||||
for i := range t {
|
||||
t[i] = defaultValue
|
||||
}
|
||||
}
|
||||
|
||||
func ClearMap[T ~map[K]V, K comparable, V any](t T) {
|
||||
for k := range t {
|
||||
delete(t, k)
|
||||
}
|
||||
}
|
|
@ -159,20 +159,14 @@ func IndexIndexed[T any](arr []T, block func(index int, it T) bool) int {
|
|||
|
||||
//go:norace
|
||||
func Dup[T any](obj T) T {
|
||||
if UnsafeBuffer {
|
||||
pointer := uintptr(unsafe.Pointer(&obj))
|
||||
//nolint:staticcheck
|
||||
//goland:noinspection GoVetUnsafePointer
|
||||
return *(*T)(unsafe.Pointer(pointer))
|
||||
} else {
|
||||
return obj
|
||||
}
|
||||
}
|
||||
|
||||
func KeepAlive(obj any) {
|
||||
if UnsafeBuffer {
|
||||
runtime.KeepAlive(obj)
|
||||
}
|
||||
}
|
||||
|
||||
func Uniq[T comparable](arr []T) []T {
|
||||
|
@ -342,6 +336,10 @@ func DefaultValue[T any]() T {
|
|||
return defaultValue
|
||||
}
|
||||
|
||||
func Ptr[T any](obj T) *T {
|
||||
return &obj
|
||||
}
|
||||
|
||||
func Close(closers ...any) error {
|
||||
var retErr error
|
||||
for _, closer := range closers {
|
||||
|
|
|
@ -1,59 +1,35 @@
|
|||
package control
|
||||
|
||||
import (
|
||||
"os"
|
||||
"runtime"
|
||||
"syscall"
|
||||
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
func BindToInterface(finder InterfaceFinder, interfaceName string, interfaceIndex int) Func {
|
||||
return func(network, address string, conn syscall.RawConn) error {
|
||||
return BindToInterface0(finder, conn, network, address, interfaceName, interfaceIndex)
|
||||
return BindToInterface0(finder, conn, network, address, interfaceName, interfaceIndex, false)
|
||||
}
|
||||
}
|
||||
|
||||
func BindToInterfaceFunc(finder InterfaceFinder, block func(network string, address string) (interfaceName string, interfaceIndex int)) Func {
|
||||
func BindToInterfaceFunc(finder InterfaceFinder, block func(network string, address string) (interfaceName string, interfaceIndex int, err error)) Func {
|
||||
return func(network, address string, conn syscall.RawConn) error {
|
||||
interfaceName, interfaceIndex := block(network, address)
|
||||
return BindToInterface0(finder, conn, network, address, interfaceName, interfaceIndex)
|
||||
}
|
||||
}
|
||||
|
||||
const useInterfaceName = runtime.GOOS == "linux" || runtime.GOOS == "android"
|
||||
|
||||
func BindToInterface0(finder InterfaceFinder, conn syscall.RawConn, network string, address string, interfaceName string, interfaceIndex int) error {
|
||||
if addr := M.ParseSocksaddr(address).Addr; addr.IsValid() && N.IsVirtual(addr) {
|
||||
return nil
|
||||
}
|
||||
if interfaceName == "" && interfaceIndex == -1 {
|
||||
return nil
|
||||
}
|
||||
if interfaceName != "" && useInterfaceName || interfaceIndex != -1 && !useInterfaceName {
|
||||
return bindToInterface(conn, network, address, interfaceName, interfaceIndex)
|
||||
}
|
||||
if finder == nil {
|
||||
return os.ErrInvalid
|
||||
}
|
||||
var err error
|
||||
if useInterfaceName {
|
||||
interfaceName, err = finder.InterfaceNameByIndex(interfaceIndex)
|
||||
} else {
|
||||
interfaceIndex, err = finder.InterfaceIndexByName(interfaceName)
|
||||
}
|
||||
interfaceName, interfaceIndex, err := block(network, address)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if useInterfaceName {
|
||||
if interfaceName == "" {
|
||||
return nil
|
||||
return BindToInterface0(finder, conn, network, address, interfaceName, interfaceIndex, false)
|
||||
}
|
||||
} else {
|
||||
if interfaceIndex == -1 {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return bindToInterface(conn, network, address, interfaceName, interfaceIndex)
|
||||
}
|
||||
|
||||
func BindToInterface0(finder InterfaceFinder, conn syscall.RawConn, network string, address string, interfaceName string, interfaceIndex int, preferInterfaceName bool) error {
|
||||
if interfaceName == "" && interfaceIndex == -1 {
|
||||
return E.New("interface not found: ", interfaceName)
|
||||
}
|
||||
if addr := M.ParseSocksaddr(address).Addr; addr.IsValid() && N.IsVirtual(addr) {
|
||||
return nil
|
||||
}
|
||||
return bindToInterface(conn, network, address, finder, interfaceName, interfaceIndex, preferInterfaceName)
|
||||
}
|
||||
|
|
|
@ -1,16 +1,24 @@
|
|||
package control
|
||||
|
||||
import (
|
||||
"os"
|
||||
"syscall"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func bindToInterface(conn syscall.RawConn, network string, address string, interfaceName string, interfaceIndex int) error {
|
||||
if interfaceIndex == -1 {
|
||||
return nil
|
||||
}
|
||||
func bindToInterface(conn syscall.RawConn, network string, address string, finder InterfaceFinder, interfaceName string, interfaceIndex int, preferInterfaceName bool) error {
|
||||
return Raw(conn, func(fd uintptr) error {
|
||||
var err error
|
||||
if interfaceIndex == -1 {
|
||||
if finder == nil {
|
||||
return os.ErrInvalid
|
||||
}
|
||||
interfaceIndex, err = finder.InterfaceIndexByName(interfaceName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
switch network {
|
||||
case "tcp6", "udp6":
|
||||
return unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_BOUND_IF, interfaceIndex)
|
||||
|
|
|
@ -1,30 +1,21 @@
|
|||
package control
|
||||
|
||||
import "net"
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
type InterfaceFinder interface {
|
||||
Interfaces() []Interface
|
||||
InterfaceIndexByName(name string) (int, error)
|
||||
InterfaceNameByIndex(index int) (string, error)
|
||||
InterfaceByAddr(addr netip.Addr) (*Interface, error)
|
||||
}
|
||||
|
||||
func DefaultInterfaceFinder() InterfaceFinder {
|
||||
return (*netInterfaceFinder)(nil)
|
||||
}
|
||||
|
||||
type netInterfaceFinder struct{}
|
||||
|
||||
func (w *netInterfaceFinder) InterfaceIndexByName(name string) (int, error) {
|
||||
netInterface, err := net.InterfaceByName(name)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return netInterface.Index, nil
|
||||
}
|
||||
|
||||
func (w *netInterfaceFinder) InterfaceNameByIndex(index int) (string, error) {
|
||||
netInterface, err := net.InterfaceByIndex(index)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return netInterface.Name, nil
|
||||
type Interface struct {
|
||||
Index int
|
||||
MTU int
|
||||
Name string
|
||||
Addresses []netip.Prefix
|
||||
HardwareAddr net.HardwareAddr
|
||||
}
|
||||
|
|
104
common/control/bind_finder_default.go
Normal file
104
common/control/bind_finder_default.go
Normal file
|
@ -0,0 +1,104 @@
|
|||
package control
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
_ "unsafe"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
)
|
||||
|
||||
var _ InterfaceFinder = (*DefaultInterfaceFinder)(nil)
|
||||
|
||||
type DefaultInterfaceFinder struct {
|
||||
interfaces []Interface
|
||||
}
|
||||
|
||||
func NewDefaultInterfaceFinder() *DefaultInterfaceFinder {
|
||||
return &DefaultInterfaceFinder{}
|
||||
}
|
||||
|
||||
func (f *DefaultInterfaceFinder) Update() error {
|
||||
netIfs, err := net.Interfaces()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
interfaces := make([]Interface, 0, len(netIfs))
|
||||
for _, netIf := range netIfs {
|
||||
ifAddrs, err := netIf.Addrs()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
interfaces = append(interfaces, Interface{
|
||||
Index: netIf.Index,
|
||||
MTU: netIf.MTU,
|
||||
Name: netIf.Name,
|
||||
Addresses: common.Map(ifAddrs, M.PrefixFromNet),
|
||||
HardwareAddr: netIf.HardwareAddr,
|
||||
})
|
||||
}
|
||||
f.interfaces = interfaces
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *DefaultInterfaceFinder) UpdateInterfaces(interfaces []Interface) {
|
||||
f.interfaces = interfaces
|
||||
}
|
||||
|
||||
func (f *DefaultInterfaceFinder) Interfaces() []Interface {
|
||||
return f.interfaces
|
||||
}
|
||||
|
||||
func (f *DefaultInterfaceFinder) InterfaceIndexByName(name string) (int, error) {
|
||||
for _, netInterface := range f.interfaces {
|
||||
if netInterface.Name == name {
|
||||
return netInterface.Index, nil
|
||||
}
|
||||
}
|
||||
netInterface, err := net.InterfaceByName(name)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
f.Update()
|
||||
return netInterface.Index, nil
|
||||
}
|
||||
|
||||
func (f *DefaultInterfaceFinder) InterfaceNameByIndex(index int) (string, error) {
|
||||
for _, netInterface := range f.interfaces {
|
||||
if netInterface.Index == index {
|
||||
return netInterface.Name, nil
|
||||
}
|
||||
}
|
||||
netInterface, err := net.InterfaceByIndex(index)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
f.Update()
|
||||
return netInterface.Name, nil
|
||||
}
|
||||
|
||||
//go:linkname errNoSuchInterface net.errNoSuchInterface
|
||||
var errNoSuchInterface error
|
||||
|
||||
func (f *DefaultInterfaceFinder) InterfaceByAddr(addr netip.Addr) (*Interface, error) {
|
||||
for _, netInterface := range f.interfaces {
|
||||
for _, prefix := range netInterface.Addresses {
|
||||
if prefix.Contains(addr) {
|
||||
return &netInterface, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
err := f.Update()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, netInterface := range f.interfaces {
|
||||
for _, prefix := range netInterface.Addresses {
|
||||
if prefix.Contains(addr) {
|
||||
return &netInterface, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil, &net.OpError{Op: "route", Net: "ip+net", Source: nil, Addr: &net.IPAddr{IP: addr.AsSlice()}, Err: errNoSuchInterface}
|
||||
}
|
|
@ -1,13 +1,42 @@
|
|||
package control
|
||||
|
||||
import (
|
||||
"os"
|
||||
"syscall"
|
||||
|
||||
"github.com/sagernet/sing/common/atomic"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func bindToInterface(conn syscall.RawConn, network string, address string, interfaceName string, interfaceIndex int) error {
|
||||
var ifIndexDisabled atomic.Bool
|
||||
|
||||
func bindToInterface(conn syscall.RawConn, network string, address string, finder InterfaceFinder, interfaceName string, interfaceIndex int, preferInterfaceName bool) error {
|
||||
return Raw(conn, func(fd uintptr) error {
|
||||
if !preferInterfaceName && !ifIndexDisabled.Load() {
|
||||
if interfaceIndex == -1 {
|
||||
if interfaceName == "" {
|
||||
return os.ErrInvalid
|
||||
}
|
||||
var err error
|
||||
interfaceIndex, err = finder.InterfaceIndexByName(interfaceName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
err := unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_BINDTOIFINDEX, interfaceIndex)
|
||||
if err == nil {
|
||||
return nil
|
||||
} else if E.IsMulti(err, unix.ENOPROTOOPT, unix.EINVAL) {
|
||||
ifIndexDisabled.Store(true)
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if interfaceName == "" {
|
||||
return os.ErrInvalid
|
||||
}
|
||||
return unix.BindToDevice(int(fd), interfaceName)
|
||||
})
|
||||
}
|
||||
|
|
|
@ -4,6 +4,6 @@ package control
|
|||
|
||||
import "syscall"
|
||||
|
||||
func bindToInterface(conn syscall.RawConn, network string, address string, interfaceName string, interfaceIndex int) error {
|
||||
func bindToInterface(conn syscall.RawConn, network string, address string, finder InterfaceFinder, interfaceName string, interfaceIndex int, preferInterfaceName bool) error {
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -2,17 +2,28 @@ package control
|
|||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"os"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
)
|
||||
|
||||
func bindToInterface(conn syscall.RawConn, network string, address string, interfaceName string, interfaceIndex int) error {
|
||||
func bindToInterface(conn syscall.RawConn, network string, address string, finder InterfaceFinder, interfaceName string, interfaceIndex int, preferInterfaceName bool) error {
|
||||
return Raw(conn, func(fd uintptr) error {
|
||||
var err error
|
||||
if interfaceIndex == -1 {
|
||||
if finder == nil {
|
||||
return os.ErrInvalid
|
||||
}
|
||||
interfaceIndex, err = finder.InterfaceIndexByName(interfaceName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
handle := syscall.Handle(fd)
|
||||
if M.ParseSocksaddr(address).AddrString() == "" {
|
||||
err := bind4(handle, interfaceIndex)
|
||||
err = bind4(handle, interfaceIndex)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -3,6 +3,7 @@ package control
|
|||
import (
|
||||
"syscall"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
)
|
||||
|
||||
|
@ -30,6 +31,14 @@ func Conn(conn syscall.Conn, block func(fd uintptr) error) error {
|
|||
return Raw(rawConn, block)
|
||||
}
|
||||
|
||||
func Conn0[T any](conn syscall.Conn, block func(fd uintptr) (T, error)) (T, error) {
|
||||
rawConn, err := conn.SyscallConn()
|
||||
if err != nil {
|
||||
return common.DefaultValue[T](), err
|
||||
}
|
||||
return Raw0[T](rawConn, block)
|
||||
}
|
||||
|
||||
func Raw(rawConn syscall.RawConn, block func(fd uintptr) error) error {
|
||||
var innerErr error
|
||||
err := rawConn.Control(func(fd uintptr) {
|
||||
|
@ -37,3 +46,14 @@ func Raw(rawConn syscall.RawConn, block func(fd uintptr) error) error {
|
|||
})
|
||||
return E.Errors(innerErr, err)
|
||||
}
|
||||
|
||||
func Raw0[T any](rawConn syscall.RawConn, block func(fd uintptr) (T, error)) (T, error) {
|
||||
var (
|
||||
value T
|
||||
innerErr error
|
||||
)
|
||||
err := rawConn.Control(func(fd uintptr) {
|
||||
value, innerErr = block(fd)
|
||||
})
|
||||
return value, E.Errors(innerErr, err)
|
||||
}
|
||||
|
|
|
@ -4,10 +4,10 @@ import (
|
|||
"syscall"
|
||||
)
|
||||
|
||||
func RoutingMark(mark int) Func {
|
||||
func RoutingMark(mark uint32) Func {
|
||||
return func(network, address string, conn syscall.RawConn) error {
|
||||
return Raw(conn, func(fd uintptr) error {
|
||||
return syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_MARK, mark)
|
||||
return syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_MARK, int(mark))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2,6 +2,6 @@
|
|||
|
||||
package control
|
||||
|
||||
func RoutingMark(mark int) Func {
|
||||
func RoutingMark(mark uint32) Func {
|
||||
return nil
|
||||
}
|
||||
|
|
58
common/control/redirect_darwin.go
Normal file
58
common/control/redirect_darwin.go
Normal file
|
@ -0,0 +1,58 @@
|
|||
package control
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"net"
|
||||
"net/netip"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
const (
|
||||
PF_OUT = 0x2
|
||||
DIOCNATLOOK = 0xc0544417
|
||||
)
|
||||
|
||||
func GetOriginalDestination(conn net.Conn) (netip.AddrPort, error) {
|
||||
pfFd, err := syscall.Open("/dev/pf", 0, syscall.O_RDONLY)
|
||||
if err != nil {
|
||||
return netip.AddrPort{}, err
|
||||
}
|
||||
defer syscall.Close(pfFd)
|
||||
nl := struct {
|
||||
saddr, daddr, rsaddr, rdaddr [16]byte
|
||||
sxport, dxport, rsxport, rdxport [4]byte
|
||||
af, proto, protoVariant, direction uint8
|
||||
}{
|
||||
af: syscall.AF_INET,
|
||||
proto: syscall.IPPROTO_TCP,
|
||||
direction: PF_OUT,
|
||||
}
|
||||
localAddr := M.SocksaddrFromNet(conn.LocalAddr()).Unwrap()
|
||||
removeAddr := M.SocksaddrFromNet(conn.RemoteAddr()).Unwrap()
|
||||
if localAddr.IsIPv4() {
|
||||
copy(nl.saddr[:net.IPv4len], removeAddr.Addr.AsSlice())
|
||||
copy(nl.daddr[:net.IPv4len], localAddr.Addr.AsSlice())
|
||||
nl.af = syscall.AF_INET
|
||||
} else {
|
||||
copy(nl.saddr[:], removeAddr.Addr.AsSlice())
|
||||
copy(nl.daddr[:], localAddr.Addr.AsSlice())
|
||||
nl.af = syscall.AF_INET6
|
||||
}
|
||||
binary.BigEndian.PutUint16(nl.sxport[:], removeAddr.Port)
|
||||
binary.BigEndian.PutUint16(nl.dxport[:], localAddr.Port)
|
||||
if _, _, errno := unix.Syscall(syscall.SYS_IOCTL, uintptr(pfFd), DIOCNATLOOK, uintptr(unsafe.Pointer(&nl))); errno != 0 {
|
||||
return netip.AddrPort{}, errno
|
||||
}
|
||||
var address netip.Addr
|
||||
if nl.af == unix.AF_INET {
|
||||
address = M.AddrFromIP(nl.rdaddr[:net.IPv4len])
|
||||
} else {
|
||||
address = netip.AddrFrom16(nl.rdaddr)
|
||||
}
|
||||
return netip.AddrPortFrom(address, binary.BigEndian.Uint16(nl.rdxport[:])), nil
|
||||
}
|
38
common/control/redirect_linux.go
Normal file
38
common/control/redirect_linux.go
Normal file
|
@ -0,0 +1,38 @@
|
|||
package control
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"syscall"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func GetOriginalDestination(conn net.Conn) (netip.AddrPort, error) {
|
||||
syscallConn, loaded := common.Cast[syscall.Conn](conn)
|
||||
if !loaded {
|
||||
return netip.AddrPort{}, os.ErrInvalid
|
||||
}
|
||||
return Conn0[netip.AddrPort](syscallConn, func(fd uintptr) (netip.AddrPort, error) {
|
||||
if M.SocksaddrFromNet(conn.RemoteAddr()).Unwrap().IsIPv4() {
|
||||
raw, err := unix.GetsockoptIPv6Mreq(int(fd), unix.IPPROTO_IP, unix.SO_ORIGINAL_DST)
|
||||
if err != nil {
|
||||
return netip.AddrPort{}, err
|
||||
}
|
||||
return netip.AddrPortFrom(M.AddrFromIP(raw.Multiaddr[4:8]), uint16(raw.Multiaddr[2])<<8+uint16(raw.Multiaddr[3])), nil
|
||||
} else {
|
||||
raw, err := unix.GetsockoptIPv6MTUInfo(int(fd), unix.IPPROTO_IPV6, unix.SO_ORIGINAL_DST)
|
||||
if err != nil {
|
||||
return netip.AddrPort{}, err
|
||||
}
|
||||
var port [2]byte
|
||||
binary.BigEndian.PutUint16(port[:], raw.Addr.Port)
|
||||
return netip.AddrPortFrom(M.AddrFromIP(raw.Addr.Addr[:]), binary.LittleEndian.Uint16(port[:])), nil
|
||||
}
|
||||
})
|
||||
}
|
13
common/control/redirect_other.go
Normal file
13
common/control/redirect_other.go
Normal file
|
@ -0,0 +1,13 @@
|
|||
//go:build !linux && !darwin
|
||||
|
||||
package control
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
)
|
||||
|
||||
func GetOriginalDestination(conn net.Conn) (netip.AddrPort, error) {
|
||||
return netip.AddrPort{}, os.ErrInvalid
|
||||
}
|
30
common/control/tcp_keep_alive_linux.go
Normal file
30
common/control/tcp_keep_alive_linux.go
Normal file
|
@ -0,0 +1,30 @@
|
|||
package control
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
"time"
|
||||
_ "unsafe"
|
||||
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func SetKeepAlivePeriod(idle time.Duration, interval time.Duration) Func {
|
||||
return func(network, address string, conn syscall.RawConn) error {
|
||||
if N.NetworkName(network) != N.NetworkTCP {
|
||||
return nil
|
||||
}
|
||||
return Raw(conn, func(fd uintptr) error {
|
||||
return E.Errors(
|
||||
unix.SetsockoptInt(int(fd), unix.IPPROTO_TCP, unix.TCP_KEEPIDLE, int(roundDurationUp(idle, time.Second))),
|
||||
unix.SetsockoptInt(int(fd), unix.IPPROTO_TCP, unix.TCP_KEEPINTVL, int(roundDurationUp(interval, time.Second))),
|
||||
)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func roundDurationUp(d time.Duration, to time.Duration) time.Duration {
|
||||
return (d + to - 1) / to
|
||||
}
|
11
common/control/tcp_keep_alive_stub.go
Normal file
11
common/control/tcp_keep_alive_stub.go
Normal file
|
@ -0,0 +1,11 @@
|
|||
//go:build !linux
|
||||
|
||||
package control
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
func SetKeepAlivePeriod(idle time.Duration, interval time.Duration) Func {
|
||||
return nil
|
||||
}
|
56
common/control/tproxy_linux.go
Normal file
56
common/control/tproxy_linux.go
Normal file
|
@ -0,0 +1,56 @@
|
|||
package control
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"net/netip"
|
||||
"syscall"
|
||||
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func TProxy(fd uintptr, family int) error {
|
||||
err := syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1)
|
||||
if err == nil {
|
||||
err = syscall.SetsockoptInt(int(fd), syscall.SOL_IP, syscall.IP_TRANSPARENT, 1)
|
||||
}
|
||||
if err == nil && family == unix.AF_INET6 {
|
||||
err = syscall.SetsockoptInt(int(fd), syscall.SOL_IPV6, unix.IPV6_TRANSPARENT, 1)
|
||||
}
|
||||
if err == nil {
|
||||
err = syscall.SetsockoptInt(int(fd), syscall.SOL_IP, syscall.IP_RECVORIGDSTADDR, 1)
|
||||
}
|
||||
if err == nil && family == unix.AF_INET6 {
|
||||
err = syscall.SetsockoptInt(int(fd), syscall.SOL_IPV6, unix.IPV6_RECVORIGDSTADDR, 1)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func TProxyWriteBack() Func {
|
||||
return func(network, address string, conn syscall.RawConn) error {
|
||||
return Raw(conn, func(fd uintptr) error {
|
||||
if M.ParseSocksaddr(address).Addr.Is6() {
|
||||
return syscall.SetsockoptInt(int(fd), syscall.SOL_IPV6, unix.IPV6_TRANSPARENT, 1)
|
||||
} else {
|
||||
return syscall.SetsockoptInt(int(fd), syscall.SOL_IP, syscall.IP_TRANSPARENT, 1)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func GetOriginalDestinationFromOOB(oob []byte) (netip.AddrPort, error) {
|
||||
controlMessages, err := unix.ParseSocketControlMessage(oob)
|
||||
if err != nil {
|
||||
return netip.AddrPort{}, err
|
||||
}
|
||||
for _, message := range controlMessages {
|
||||
if message.Header.Level == unix.SOL_IP && message.Header.Type == unix.IP_RECVORIGDSTADDR {
|
||||
return netip.AddrPortFrom(M.AddrFromIP(message.Data[4:8]), binary.BigEndian.Uint16(message.Data[2:4])), nil
|
||||
} else if message.Header.Level == unix.SOL_IPV6 && message.Header.Type == unix.IPV6_RECVORIGDSTADDR {
|
||||
return netip.AddrPortFrom(M.AddrFromIP(message.Data[8:24]), binary.BigEndian.Uint16(message.Data[2:4])), nil
|
||||
}
|
||||
}
|
||||
return netip.AddrPort{}, E.New("not found")
|
||||
}
|
20
common/control/tproxy_other.go
Normal file
20
common/control/tproxy_other.go
Normal file
|
@ -0,0 +1,20 @@
|
|||
//go:build !linux
|
||||
|
||||
package control
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"os"
|
||||
)
|
||||
|
||||
func TProxy(fd uintptr, isIPv6 bool) error {
|
||||
return os.ErrInvalid
|
||||
}
|
||||
|
||||
func TProxyWriteBack() Func {
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetOriginalDestinationFromOOB(oob []byte) (netip.AddrPort, error) {
|
||||
return netip.AddrPort{}, os.ErrInvalid
|
||||
}
|
|
@ -1,8 +1,12 @@
|
|||
package domain
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"sort"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/sagernet/sing/common/rw"
|
||||
)
|
||||
|
||||
type Matcher struct {
|
||||
|
@ -10,14 +14,19 @@ type Matcher struct {
|
|||
}
|
||||
|
||||
func NewMatcher(domains []string, domainSuffix []string) *Matcher {
|
||||
domainList := make([]string, 0, len(domains)+len(domainSuffix))
|
||||
domainList := make([]string, 0, len(domains)+2*len(domainSuffix))
|
||||
seen := make(map[string]bool, len(domainList))
|
||||
for _, domain := range domainSuffix {
|
||||
if seen[domain] {
|
||||
continue
|
||||
}
|
||||
seen[domain] = true
|
||||
if domain[0] == '.' {
|
||||
domainList = append(domainList, reverseDomainSuffix(domain))
|
||||
} else {
|
||||
domainList = append(domainList, reverseDomain(domain))
|
||||
domainList = append(domainList, reverseRootDomainSuffix(domain))
|
||||
}
|
||||
}
|
||||
for _, domain := range domains {
|
||||
if seen[domain] {
|
||||
|
@ -27,15 +36,87 @@ func NewMatcher(domains []string, domainSuffix []string) *Matcher {
|
|||
domainList = append(domainList, reverseDomain(domain))
|
||||
}
|
||||
sort.Strings(domainList)
|
||||
return &Matcher{
|
||||
newSuccinctSet(domainList),
|
||||
return &Matcher{newSuccinctSet(domainList)}
|
||||
}
|
||||
|
||||
func ReadMatcher(reader io.Reader) (*Matcher, error) {
|
||||
var version uint8
|
||||
err := binary.Read(reader, binary.BigEndian, &version)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
leavesLength, err := rw.ReadUVariant(reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
leaves := make([]uint64, leavesLength)
|
||||
err = binary.Read(reader, binary.BigEndian, leaves)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
labelBitmapLength, err := rw.ReadUVariant(reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
labelBitmap := make([]uint64, labelBitmapLength)
|
||||
err = binary.Read(reader, binary.BigEndian, labelBitmap)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
labelsLength, err := rw.ReadUVariant(reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
labels := make([]byte, labelsLength)
|
||||
_, err = io.ReadFull(reader, labels)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
set := &succinctSet{
|
||||
leaves: leaves,
|
||||
labelBitmap: labelBitmap,
|
||||
labels: labels,
|
||||
}
|
||||
set.init()
|
||||
return &Matcher{set}, nil
|
||||
}
|
||||
|
||||
func (m *Matcher) Match(domain string) bool {
|
||||
return m.set.Has(reverseDomain(domain))
|
||||
}
|
||||
|
||||
func (m *Matcher) Write(writer io.Writer) error {
|
||||
err := binary.Write(writer, binary.BigEndian, byte(1))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = rw.WriteUVariant(writer, uint64(len(m.set.leaves)))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = binary.Write(writer, binary.BigEndian, m.set.leaves)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = rw.WriteUVariant(writer, uint64(len(m.set.labelBitmap)))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = binary.Write(writer, binary.BigEndian, m.set.labelBitmap)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = rw.WriteUVariant(writer, uint64(len(m.set.labels)))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = writer.Write(m.set.labels)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func reverseDomain(domain string) string {
|
||||
l := len(domain)
|
||||
b := make([]byte, l)
|
||||
|
@ -58,3 +139,16 @@ func reverseDomainSuffix(domain string) string {
|
|||
b[l] = prefixLabel
|
||||
return string(b)
|
||||
}
|
||||
|
||||
func reverseRootDomainSuffix(domain string) string {
|
||||
l := len(domain)
|
||||
b := make([]byte, l+2)
|
||||
for i := 0; i < l; {
|
||||
r, n := utf8.DecodeRuneInString(domain[i:])
|
||||
i += n
|
||||
utf8.EncodeRune(b[l-i:], r)
|
||||
}
|
||||
b[l] = '.'
|
||||
b[l+1] = prefixLabel
|
||||
return string(b)
|
||||
}
|
||||
|
|
|
@ -6,9 +6,6 @@ type causeError struct {
|
|||
}
|
||||
|
||||
func (e *causeError) Error() string {
|
||||
if e.cause == nil {
|
||||
return e.message
|
||||
}
|
||||
return e.message + ": " + e.cause.Error()
|
||||
}
|
||||
|
||||
|
|
|
@ -26,14 +26,14 @@ func New(message ...any) error {
|
|||
|
||||
func Cause(cause error, message ...any) error {
|
||||
if cause == nil {
|
||||
return nil
|
||||
panic("cause on an nil error")
|
||||
}
|
||||
return &causeError{F.ToString(message...), cause}
|
||||
}
|
||||
|
||||
func Extend(cause error, message ...any) error {
|
||||
if cause == nil {
|
||||
return nil
|
||||
panic("extend on an nil error")
|
||||
}
|
||||
return &extendedError{F.ToString(message...), cause}
|
||||
}
|
||||
|
|
|
@ -23,6 +23,7 @@ func (e *multiError) Unwrap() []error {
|
|||
func Errors(errors ...error) error {
|
||||
errors = common.FilterNotNil(errors)
|
||||
errors = ExpandAll(errors)
|
||||
errors = common.FilterNotNil(errors)
|
||||
errors = common.UniqBy(errors, error.Error)
|
||||
switch len(errors) {
|
||||
case 0:
|
||||
|
@ -36,10 +37,13 @@ func Errors(errors ...error) error {
|
|||
}
|
||||
|
||||
func Expand(err error) []error {
|
||||
if multiErr, isMultiErr := err.(MultiError); isMultiErr {
|
||||
return ExpandAll(multiErr.Unwrap())
|
||||
}
|
||||
if err == nil {
|
||||
return nil
|
||||
} else if multiErr, isMultiErr := err.(MultiError); isMultiErr {
|
||||
return ExpandAll(common.FilterNotNil(multiErr.Unwrap()))
|
||||
} else {
|
||||
return []error{err}
|
||||
}
|
||||
}
|
||||
|
||||
func ExpandAll(errs []error) []error {
|
||||
|
|
59
common/json/badjson/array.go
Normal file
59
common/json/badjson/array.go
Normal file
|
@ -0,0 +1,59 @@
|
|||
package badjson
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/json"
|
||||
)
|
||||
|
||||
type JSONArray []any
|
||||
|
||||
func (a JSONArray) IsEmpty() bool {
|
||||
if len(a) == 0 {
|
||||
return true
|
||||
}
|
||||
return common.All(a, func(it any) bool {
|
||||
if valueInterface, valueMaybeEmpty := it.(isEmpty); valueMaybeEmpty && valueInterface.IsEmpty() {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
})
|
||||
}
|
||||
|
||||
func (a JSONArray) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal([]any(a))
|
||||
}
|
||||
|
||||
func (a *JSONArray) UnmarshalJSON(content []byte) error {
|
||||
decoder := json.NewDecoder(bytes.NewReader(content))
|
||||
arrayStart, err := decoder.Token()
|
||||
if err != nil {
|
||||
return err
|
||||
} else if arrayStart != json.Delim('[') {
|
||||
return E.New("excepted array start, but got ", arrayStart)
|
||||
}
|
||||
err = a.decodeJSON(decoder)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
arrayEnd, err := decoder.Token()
|
||||
if err != nil {
|
||||
return err
|
||||
} else if arrayEnd != json.Delim(']') {
|
||||
return E.New("excepted array end, but got ", arrayEnd)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *JSONArray) decodeJSON(decoder *json.Decoder) error {
|
||||
for decoder.More() {
|
||||
item, err := decodeJSON(decoder)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*a = append(*a, item)
|
||||
}
|
||||
return nil
|
||||
}
|
5
common/json/badjson/empty.go
Normal file
5
common/json/badjson/empty.go
Normal file
|
@ -0,0 +1,5 @@
|
|||
package badjson
|
||||
|
||||
type isEmpty interface {
|
||||
IsEmpty() bool
|
||||
}
|
54
common/json/badjson/json.go
Normal file
54
common/json/badjson/json.go
Normal file
|
@ -0,0 +1,54 @@
|
|||
package badjson
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/json"
|
||||
)
|
||||
|
||||
func Decode(content []byte) (any, error) {
|
||||
decoder := json.NewDecoder(bytes.NewReader(content))
|
||||
return decodeJSON(decoder)
|
||||
}
|
||||
|
||||
func decodeJSON(decoder *json.Decoder) (any, error) {
|
||||
rawToken, err := decoder.Token()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
switch token := rawToken.(type) {
|
||||
case json.Delim:
|
||||
switch token {
|
||||
case '{':
|
||||
var object JSONObject
|
||||
err = object.decodeJSON(decoder)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rawToken, err = decoder.Token()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if rawToken != json.Delim('}') {
|
||||
return nil, E.New("excepted object end, but got ", rawToken)
|
||||
}
|
||||
return &object, nil
|
||||
case '[':
|
||||
var array JSONArray
|
||||
err = array.decodeJSON(decoder)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rawToken, err = decoder.Token()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if rawToken != json.Delim(']') {
|
||||
return nil, E.New("excepted array end, but got ", rawToken)
|
||||
}
|
||||
return array, nil
|
||||
default:
|
||||
return nil, E.New("excepted object or array end: ", token)
|
||||
}
|
||||
}
|
||||
return rawToken, nil
|
||||
}
|
139
common/json/badjson/merge.go
Normal file
139
common/json/badjson/merge.go
Normal file
|
@ -0,0 +1,139 @@
|
|||
package badjson
|
||||
|
||||
import (
|
||||
"os"
|
||||
"reflect"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/json"
|
||||
)
|
||||
|
||||
func Omitempty[T any](value T) (T, error) {
|
||||
objectContent, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
return common.DefaultValue[T](), E.Cause(err, "marshal object")
|
||||
}
|
||||
rawNewObject, err := Decode(objectContent)
|
||||
if err != nil {
|
||||
return common.DefaultValue[T](), err
|
||||
}
|
||||
newObjectContent, err := json.Marshal(rawNewObject)
|
||||
if err != nil {
|
||||
return common.DefaultValue[T](), E.Cause(err, "marshal new object")
|
||||
}
|
||||
var newObject T
|
||||
err = json.Unmarshal(newObjectContent, &newObject)
|
||||
if err != nil {
|
||||
return common.DefaultValue[T](), E.Cause(err, "unmarshal new object")
|
||||
}
|
||||
return newObject, nil
|
||||
}
|
||||
|
||||
func Merge[T any](source T, destination T) (T, error) {
|
||||
rawSource, err := json.Marshal(source)
|
||||
if err != nil {
|
||||
return common.DefaultValue[T](), E.Cause(err, "marshal source")
|
||||
}
|
||||
rawDestination, err := json.Marshal(destination)
|
||||
if err != nil {
|
||||
return common.DefaultValue[T](), E.Cause(err, "marshal destination")
|
||||
}
|
||||
return MergeFrom[T](rawSource, rawDestination)
|
||||
}
|
||||
|
||||
func MergeFromSource[T any](rawSource json.RawMessage, destination T) (T, error) {
|
||||
if rawSource == nil {
|
||||
return destination, nil
|
||||
}
|
||||
rawDestination, err := json.Marshal(destination)
|
||||
if err != nil {
|
||||
return common.DefaultValue[T](), E.Cause(err, "marshal destination")
|
||||
}
|
||||
return MergeFrom[T](rawSource, rawDestination)
|
||||
}
|
||||
|
||||
func MergeFromDestination[T any](source T, rawDestination json.RawMessage) (T, error) {
|
||||
if rawDestination == nil {
|
||||
return source, nil
|
||||
}
|
||||
rawSource, err := json.Marshal(source)
|
||||
if err != nil {
|
||||
return common.DefaultValue[T](), E.Cause(err, "marshal source")
|
||||
}
|
||||
return MergeFrom[T](rawSource, rawDestination)
|
||||
}
|
||||
|
||||
func MergeFrom[T any](rawSource json.RawMessage, rawDestination json.RawMessage) (T, error) {
|
||||
rawMerged, err := MergeJSON(rawSource, rawDestination)
|
||||
if err != nil {
|
||||
return common.DefaultValue[T](), E.Cause(err, "merge options")
|
||||
}
|
||||
var merged T
|
||||
err = json.Unmarshal(rawMerged, &merged)
|
||||
if err != nil {
|
||||
return common.DefaultValue[T](), E.Cause(err, "unmarshal merged options")
|
||||
}
|
||||
return merged, nil
|
||||
}
|
||||
|
||||
func MergeJSON(rawSource json.RawMessage, rawDestination json.RawMessage) (json.RawMessage, error) {
|
||||
if rawSource == nil && rawDestination == nil {
|
||||
return nil, os.ErrInvalid
|
||||
} else if rawSource == nil {
|
||||
return rawDestination, nil
|
||||
} else if rawDestination == nil {
|
||||
return rawSource, nil
|
||||
}
|
||||
source, err := Decode(rawSource)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "decode source")
|
||||
}
|
||||
destination, err := Decode(rawDestination)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "decode destination")
|
||||
}
|
||||
if source == nil {
|
||||
return json.Marshal(destination)
|
||||
} else if destination == nil {
|
||||
return json.Marshal(source)
|
||||
}
|
||||
merged, err := mergeJSON(source, destination)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return json.Marshal(merged)
|
||||
}
|
||||
|
||||
func mergeJSON(anySource any, anyDestination any) (any, error) {
|
||||
switch destination := anyDestination.(type) {
|
||||
case JSONArray:
|
||||
switch source := anySource.(type) {
|
||||
case JSONArray:
|
||||
destination = append(destination, source...)
|
||||
default:
|
||||
destination = append(destination, source)
|
||||
}
|
||||
return destination, nil
|
||||
case *JSONObject:
|
||||
switch source := anySource.(type) {
|
||||
case *JSONObject:
|
||||
for _, entry := range source.Entries() {
|
||||
oldValue, loaded := destination.Get(entry.Key)
|
||||
if loaded {
|
||||
var err error
|
||||
entry.Value, err = mergeJSON(entry.Value, oldValue)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "merge object item ", entry.Key)
|
||||
}
|
||||
}
|
||||
destination.Put(entry.Key, entry.Value)
|
||||
}
|
||||
default:
|
||||
return nil, E.New("cannot merge json object into ", reflect.TypeOf(source))
|
||||
}
|
||||
return destination, nil
|
||||
default:
|
||||
return destination, nil
|
||||
}
|
||||
}
|
98
common/json/badjson/object.go
Normal file
98
common/json/badjson/object.go
Normal file
|
@ -0,0 +1,98 @@
|
|||
package badjson
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"strings"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/json"
|
||||
"github.com/sagernet/sing/common/x/collections"
|
||||
"github.com/sagernet/sing/common/x/linkedhashmap"
|
||||
)
|
||||
|
||||
type JSONObject struct {
|
||||
linkedhashmap.Map[string, any]
|
||||
}
|
||||
|
||||
func (m *JSONObject) IsEmpty() bool {
|
||||
if m.Size() == 0 {
|
||||
return true
|
||||
}
|
||||
return common.All(m.Entries(), func(it collections.MapEntry[string, any]) bool {
|
||||
if valueInterface, valueMaybeEmpty := it.Value.(isEmpty); valueMaybeEmpty && valueInterface.IsEmpty() {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
})
|
||||
}
|
||||
|
||||
func (m *JSONObject) MarshalJSON() ([]byte, error) {
|
||||
buffer := new(bytes.Buffer)
|
||||
buffer.WriteString("{")
|
||||
items := common.Filter(m.Entries(), func(it collections.MapEntry[string, any]) bool {
|
||||
if valueInterface, valueMaybeEmpty := it.Value.(isEmpty); valueMaybeEmpty && valueInterface.IsEmpty() {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
})
|
||||
iLen := len(items)
|
||||
for i, entry := range items {
|
||||
keyContent, err := json.Marshal(entry.Key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
buffer.WriteString(strings.TrimSpace(string(keyContent)))
|
||||
buffer.WriteString(": ")
|
||||
valueContent, err := json.Marshal(entry.Value)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
buffer.WriteString(strings.TrimSpace(string(valueContent)))
|
||||
if i < iLen-1 {
|
||||
buffer.WriteString(", ")
|
||||
}
|
||||
}
|
||||
buffer.WriteString("}")
|
||||
return buffer.Bytes(), nil
|
||||
}
|
||||
|
||||
func (m *JSONObject) UnmarshalJSON(content []byte) error {
|
||||
decoder := json.NewDecoder(bytes.NewReader(content))
|
||||
m.Clear()
|
||||
objectStart, err := decoder.Token()
|
||||
if err != nil {
|
||||
return err
|
||||
} else if objectStart != json.Delim('{') {
|
||||
return E.New("expected json object start, but starts with ", objectStart)
|
||||
}
|
||||
err = m.decodeJSON(decoder)
|
||||
if err != nil {
|
||||
return E.Cause(err, "decode json object content")
|
||||
}
|
||||
objectEnd, err := decoder.Token()
|
||||
if err != nil {
|
||||
return err
|
||||
} else if objectEnd != json.Delim('}') {
|
||||
return E.New("expected json object end, but ends with ", objectEnd)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *JSONObject) decodeJSON(decoder *json.Decoder) error {
|
||||
for decoder.More() {
|
||||
var entryKey string
|
||||
keyToken, err := decoder.Token()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
entryKey = keyToken.(string)
|
||||
var entryValue any
|
||||
entryValue, err = decodeJSON(decoder)
|
||||
if err != nil {
|
||||
return E.Cause(err, "decode value for ", entryKey)
|
||||
}
|
||||
m.Put(entryKey, entryValue)
|
||||
}
|
||||
return nil
|
||||
}
|
86
common/json/badjson/typed.go
Normal file
86
common/json/badjson/typed.go
Normal file
|
@ -0,0 +1,86 @@
|
|||
package badjson
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"strings"
|
||||
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/json"
|
||||
"github.com/sagernet/sing/common/x/linkedhashmap"
|
||||
)
|
||||
|
||||
type TypedMap[K comparable, V any] struct {
|
||||
linkedhashmap.Map[K, V]
|
||||
}
|
||||
|
||||
func (m TypedMap[K, V]) MarshalJSON() ([]byte, error) {
|
||||
buffer := new(bytes.Buffer)
|
||||
buffer.WriteString("{")
|
||||
items := m.Entries()
|
||||
iLen := len(items)
|
||||
for i, entry := range items {
|
||||
keyContent, err := json.Marshal(entry.Key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
buffer.WriteString(strings.TrimSpace(string(keyContent)))
|
||||
buffer.WriteString(": ")
|
||||
valueContent, err := json.Marshal(entry.Value)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
buffer.WriteString(strings.TrimSpace(string(valueContent)))
|
||||
if i < iLen-1 {
|
||||
buffer.WriteString(", ")
|
||||
}
|
||||
}
|
||||
buffer.WriteString("}")
|
||||
return buffer.Bytes(), nil
|
||||
}
|
||||
|
||||
func (m *TypedMap[K, V]) UnmarshalJSON(content []byte) error {
|
||||
decoder := json.NewDecoder(bytes.NewReader(content))
|
||||
m.Clear()
|
||||
objectStart, err := decoder.Token()
|
||||
if err != nil {
|
||||
return err
|
||||
} else if objectStart != json.Delim('{') {
|
||||
return E.New("expected json object start, but starts with ", objectStart)
|
||||
}
|
||||
err = m.decodeJSON(decoder)
|
||||
if err != nil {
|
||||
return E.Cause(err, "decode json object content")
|
||||
}
|
||||
objectEnd, err := decoder.Token()
|
||||
if err != nil {
|
||||
return err
|
||||
} else if objectEnd != json.Delim('}') {
|
||||
return E.New("expected json object end, but ends with ", objectEnd)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *TypedMap[K, V]) decodeJSON(decoder *json.Decoder) error {
|
||||
for decoder.More() {
|
||||
keyToken, err := decoder.Token()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
keyContent, err := json.Marshal(keyToken)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var entryKey K
|
||||
err = json.Unmarshal(keyContent, &entryKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var entryValue V
|
||||
err = decoder.Decode(&entryValue)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.Put(entryKey, entryValue)
|
||||
}
|
||||
return nil
|
||||
}
|
128
common/json/comment.go
Normal file
128
common/json/comment.go
Normal file
|
@ -0,0 +1,128 @@
|
|||
package json
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"io"
|
||||
)
|
||||
|
||||
// kanged from v2ray
|
||||
|
||||
type commentFilterState = byte
|
||||
|
||||
const (
|
||||
commentFilterStateContent commentFilterState = iota
|
||||
commentFilterStateEscape
|
||||
commentFilterStateDoubleQuote
|
||||
commentFilterStateDoubleQuoteEscape
|
||||
commentFilterStateSingleQuote
|
||||
commentFilterStateSingleQuoteEscape
|
||||
commentFilterStateComment
|
||||
commentFilterStateSlash
|
||||
commentFilterStateMultilineComment
|
||||
commentFilterStateMultilineCommentStar
|
||||
)
|
||||
|
||||
type CommentFilter struct {
|
||||
br *bufio.Reader
|
||||
state commentFilterState
|
||||
}
|
||||
|
||||
func NewCommentFilter(reader io.Reader) io.Reader {
|
||||
return &CommentFilter{br: bufio.NewReader(reader)}
|
||||
}
|
||||
|
||||
func (v *CommentFilter) Read(b []byte) (int, error) {
|
||||
p := b[:0]
|
||||
for len(p) < len(b)-2 {
|
||||
x, err := v.br.ReadByte()
|
||||
if err != nil {
|
||||
if len(p) == 0 {
|
||||
return 0, err
|
||||
}
|
||||
return len(p), nil
|
||||
}
|
||||
switch v.state {
|
||||
case commentFilterStateContent:
|
||||
switch x {
|
||||
case '"':
|
||||
v.state = commentFilterStateDoubleQuote
|
||||
p = append(p, x)
|
||||
case '\'':
|
||||
v.state = commentFilterStateSingleQuote
|
||||
p = append(p, x)
|
||||
case '\\':
|
||||
v.state = commentFilterStateEscape
|
||||
case '#':
|
||||
v.state = commentFilterStateComment
|
||||
case '/':
|
||||
v.state = commentFilterStateSlash
|
||||
default:
|
||||
p = append(p, x)
|
||||
}
|
||||
case commentFilterStateEscape:
|
||||
p = append(p, '\\', x)
|
||||
v.state = commentFilterStateContent
|
||||
case commentFilterStateDoubleQuote:
|
||||
switch x {
|
||||
case '"':
|
||||
v.state = commentFilterStateContent
|
||||
p = append(p, x)
|
||||
case '\\':
|
||||
v.state = commentFilterStateDoubleQuoteEscape
|
||||
default:
|
||||
p = append(p, x)
|
||||
}
|
||||
case commentFilterStateDoubleQuoteEscape:
|
||||
p = append(p, '\\', x)
|
||||
v.state = commentFilterStateDoubleQuote
|
||||
case commentFilterStateSingleQuote:
|
||||
switch x {
|
||||
case '\'':
|
||||
v.state = commentFilterStateContent
|
||||
p = append(p, x)
|
||||
case '\\':
|
||||
v.state = commentFilterStateSingleQuoteEscape
|
||||
default:
|
||||
p = append(p, x)
|
||||
}
|
||||
case commentFilterStateSingleQuoteEscape:
|
||||
p = append(p, '\\', x)
|
||||
v.state = commentFilterStateSingleQuote
|
||||
case commentFilterStateComment:
|
||||
if x == '\n' {
|
||||
v.state = commentFilterStateContent
|
||||
p = append(p, '\n')
|
||||
}
|
||||
case commentFilterStateSlash:
|
||||
switch x {
|
||||
case '/':
|
||||
v.state = commentFilterStateComment
|
||||
case '*':
|
||||
v.state = commentFilterStateMultilineComment
|
||||
default:
|
||||
p = append(p, '/', x)
|
||||
}
|
||||
case commentFilterStateMultilineComment:
|
||||
switch x {
|
||||
case '*':
|
||||
v.state = commentFilterStateMultilineCommentStar
|
||||
case '\n':
|
||||
p = append(p, '\n')
|
||||
}
|
||||
case commentFilterStateMultilineCommentStar:
|
||||
switch x {
|
||||
case '/':
|
||||
v.state = commentFilterStateContent
|
||||
case '*':
|
||||
// Stay
|
||||
case '\n':
|
||||
p = append(p, '\n')
|
||||
default:
|
||||
v.state = commentFilterStateMultilineComment
|
||||
}
|
||||
default:
|
||||
panic("Unknown state.")
|
||||
}
|
||||
}
|
||||
return len(p), nil
|
||||
}
|
23
common/json/context.go
Normal file
23
common/json/context.go
Normal file
|
@ -0,0 +1,23 @@
|
|||
//go:build go1.20 && !without_contextjson
|
||||
|
||||
package json
|
||||
|
||||
import (
|
||||
"github.com/sagernet/sing/common/json/internal/contextjson"
|
||||
)
|
||||
|
||||
var (
|
||||
Marshal = json.Marshal
|
||||
Unmarshal = json.Unmarshal
|
||||
NewEncoder = json.NewEncoder
|
||||
NewDecoder = json.NewDecoder
|
||||
)
|
||||
|
||||
type (
|
||||
Encoder = json.Encoder
|
||||
Decoder = json.Decoder
|
||||
Token = json.Token
|
||||
Delim = json.Delim
|
||||
SyntaxError = json.SyntaxError
|
||||
RawMessage = json.RawMessage
|
||||
)
|
3
common/json/internal/contextjson/README.md
Normal file
3
common/json/internal/contextjson/README.md
Normal file
|
@ -0,0 +1,3 @@
|
|||
# contextjson
|
||||
|
||||
mod from go1.21.4
|
1325
common/json/internal/contextjson/decode.go
Normal file
1325
common/json/internal/contextjson/decode.go
Normal file
File diff suppressed because it is too large
Load diff
49
common/json/internal/contextjson/decode_context.go
Normal file
49
common/json/internal/contextjson/decode_context.go
Normal file
|
@ -0,0 +1,49 @@
|
|||
package json
|
||||
|
||||
import "strconv"
|
||||
|
||||
type decodeContext struct {
|
||||
parent *decodeContext
|
||||
index int
|
||||
key string
|
||||
}
|
||||
|
||||
func (d *decodeState) formatContext() string {
|
||||
var description string
|
||||
context := d.context
|
||||
var appendDot bool
|
||||
for context != nil {
|
||||
if appendDot {
|
||||
description = "." + description
|
||||
}
|
||||
if context.key != "" {
|
||||
description = context.key + description
|
||||
appendDot = true
|
||||
} else {
|
||||
description = "[" + strconv.Itoa(context.index) + "]" + description
|
||||
appendDot = false
|
||||
}
|
||||
context = context.parent
|
||||
}
|
||||
return description
|
||||
}
|
||||
|
||||
type contextError struct {
|
||||
parent error
|
||||
context string
|
||||
index bool
|
||||
}
|
||||
|
||||
func (c *contextError) Unwrap() error {
|
||||
return c.parent
|
||||
}
|
||||
|
||||
func (c *contextError) Error() string {
|
||||
//goland:noinspection GoTypeAssertionOnErrors
|
||||
switch c.parent.(type) {
|
||||
case *contextError:
|
||||
return c.context + "." + c.parent.Error()
|
||||
default:
|
||||
return c.context + ": " + c.parent.Error()
|
||||
}
|
||||
}
|
1283
common/json/internal/contextjson/encode.go
Normal file
1283
common/json/internal/contextjson/encode.go
Normal file
File diff suppressed because it is too large
Load diff
48
common/json/internal/contextjson/fold.go
Normal file
48
common/json/internal/contextjson/fold.go
Normal file
|
@ -0,0 +1,48 @@
|
|||
// Copyright 2013 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package json
|
||||
|
||||
import (
|
||||
"unicode"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
// foldName returns a folded string such that foldName(x) == foldName(y)
|
||||
// is identical to bytes.EqualFold(x, y).
|
||||
func foldName(in []byte) []byte {
|
||||
// This is inlinable to take advantage of "function outlining".
|
||||
var arr [32]byte // large enough for most JSON names
|
||||
return appendFoldedName(arr[:0], in)
|
||||
}
|
||||
|
||||
func appendFoldedName(out, in []byte) []byte {
|
||||
for i := 0; i < len(in); {
|
||||
// Handle single-byte ASCII.
|
||||
if c := in[i]; c < utf8.RuneSelf {
|
||||
if 'a' <= c && c <= 'z' {
|
||||
c -= 'a' - 'A'
|
||||
}
|
||||
out = append(out, c)
|
||||
i++
|
||||
continue
|
||||
}
|
||||
// Handle multi-byte Unicode.
|
||||
r, n := utf8.DecodeRune(in[i:])
|
||||
out = utf8.AppendRune(out, foldRune(r))
|
||||
i += n
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// foldRune is returns the smallest rune for all runes in the same fold set.
|
||||
func foldRune(r rune) rune {
|
||||
for {
|
||||
r2 := unicode.SimpleFold(r)
|
||||
if r2 <= r {
|
||||
return r2
|
||||
}
|
||||
r = r2
|
||||
}
|
||||
}
|
179
common/json/internal/contextjson/indent.go
Normal file
179
common/json/internal/contextjson/indent.go
Normal file
|
@ -0,0 +1,179 @@
|
|||
// Copyright 2010 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package json
|
||||
|
||||
import "bytes"
|
||||
|
||||
// TODO(https://go.dev/issue/53685): Use bytes.Buffer.AvailableBuffer instead.
|
||||
func availableBuffer(b *bytes.Buffer) []byte {
|
||||
return b.Bytes()[b.Len():]
|
||||
}
|
||||
|
||||
// HTMLEscape appends to dst the JSON-encoded src with <, >, &, U+2028 and U+2029
|
||||
// characters inside string literals changed to \u003c, \u003e, \u0026, \u2028, \u2029
|
||||
// so that the JSON will be safe to embed inside HTML <script> tags.
|
||||
// For historical reasons, web browsers don't honor standard HTML
|
||||
// escaping within <script> tags, so an alternative JSON encoding must be used.
|
||||
func HTMLEscape(dst *bytes.Buffer, src []byte) {
|
||||
dst.Grow(len(src))
|
||||
dst.Write(appendHTMLEscape(availableBuffer(dst), src))
|
||||
}
|
||||
|
||||
func appendHTMLEscape(dst, src []byte) []byte {
|
||||
// The characters can only appear in string literals,
|
||||
// so just scan the string one byte at a time.
|
||||
start := 0
|
||||
for i, c := range src {
|
||||
if c == '<' || c == '>' || c == '&' {
|
||||
dst = append(dst, src[start:i]...)
|
||||
dst = append(dst, '\\', 'u', '0', '0', hex[c>>4], hex[c&0xF])
|
||||
start = i + 1
|
||||
}
|
||||
// Convert U+2028 and U+2029 (E2 80 A8 and E2 80 A9).
|
||||
if c == 0xE2 && i+2 < len(src) && src[i+1] == 0x80 && src[i+2]&^1 == 0xA8 {
|
||||
dst = append(dst, src[start:i]...)
|
||||
dst = append(dst, '\\', 'u', '2', '0', '2', hex[src[i+2]&0xF])
|
||||
start = i + len("\u2029")
|
||||
}
|
||||
}
|
||||
return append(dst, src[start:]...)
|
||||
}
|
||||
|
||||
// Compact appends to dst the JSON-encoded src with
|
||||
// insignificant space characters elided.
|
||||
func Compact(dst *bytes.Buffer, src []byte) error {
|
||||
dst.Grow(len(src))
|
||||
b := availableBuffer(dst)
|
||||
b, err := appendCompact(b, src, false)
|
||||
dst.Write(b)
|
||||
return err
|
||||
}
|
||||
|
||||
func appendCompact(dst, src []byte, escape bool) ([]byte, error) {
|
||||
origLen := len(dst)
|
||||
scan := newScanner()
|
||||
defer freeScanner(scan)
|
||||
start := 0
|
||||
for i, c := range src {
|
||||
if escape && (c == '<' || c == '>' || c == '&') {
|
||||
dst = append(dst, src[start:i]...)
|
||||
dst = append(dst, '\\', 'u', '0', '0', hex[c>>4], hex[c&0xF])
|
||||
start = i + 1
|
||||
}
|
||||
// Convert U+2028 and U+2029 (E2 80 A8 and E2 80 A9).
|
||||
if escape && c == 0xE2 && i+2 < len(src) && src[i+1] == 0x80 && src[i+2]&^1 == 0xA8 {
|
||||
dst = append(dst, src[start:i]...)
|
||||
dst = append(dst, '\\', 'u', '2', '0', '2', hex[src[i+2]&0xF])
|
||||
start = i + len("\u2029")
|
||||
}
|
||||
v := scan.step(scan, c)
|
||||
if v >= scanSkipSpace {
|
||||
if v == scanError {
|
||||
break
|
||||
}
|
||||
dst = append(dst, src[start:i]...)
|
||||
start = i + 1
|
||||
}
|
||||
}
|
||||
if scan.eof() == scanError {
|
||||
return dst[:origLen], scan.err
|
||||
}
|
||||
dst = append(dst, src[start:]...)
|
||||
return dst, nil
|
||||
}
|
||||
|
||||
func appendNewline(dst []byte, prefix, indent string, depth int) []byte {
|
||||
dst = append(dst, '\n')
|
||||
dst = append(dst, prefix...)
|
||||
for i := 0; i < depth; i++ {
|
||||
dst = append(dst, indent...)
|
||||
}
|
||||
return dst
|
||||
}
|
||||
|
||||
// indentGrowthFactor specifies the growth factor of indenting JSON input.
|
||||
// Empirically, the growth factor was measured to be between 1.4x to 1.8x
|
||||
// for some set of compacted JSON with the indent being a single tab.
|
||||
// Specify a growth factor slightly larger than what is observed
|
||||
// to reduce probability of allocation in appendIndent.
|
||||
// A factor no higher than 2 ensures that wasted space never exceeds 50%.
|
||||
const indentGrowthFactor = 2
|
||||
|
||||
// Indent appends to dst an indented form of the JSON-encoded src.
|
||||
// Each element in a JSON object or array begins on a new,
|
||||
// indented line beginning with prefix followed by one or more
|
||||
// copies of indent according to the indentation nesting.
|
||||
// The data appended to dst does not begin with the prefix nor
|
||||
// any indentation, to make it easier to embed inside other formatted JSON data.
|
||||
// Although leading space characters (space, tab, carriage return, newline)
|
||||
// at the beginning of src are dropped, trailing space characters
|
||||
// at the end of src are preserved and copied to dst.
|
||||
// For example, if src has no trailing spaces, neither will dst;
|
||||
// if src ends in a trailing newline, so will dst.
|
||||
func Indent(dst *bytes.Buffer, src []byte, prefix, indent string) error {
|
||||
dst.Grow(indentGrowthFactor * len(src))
|
||||
b := availableBuffer(dst)
|
||||
b, err := appendIndent(b, src, prefix, indent)
|
||||
dst.Write(b)
|
||||
return err
|
||||
}
|
||||
|
||||
func appendIndent(dst, src []byte, prefix, indent string) ([]byte, error) {
|
||||
origLen := len(dst)
|
||||
scan := newScanner()
|
||||
defer freeScanner(scan)
|
||||
needIndent := false
|
||||
depth := 0
|
||||
for _, c := range src {
|
||||
scan.bytes++
|
||||
v := scan.step(scan, c)
|
||||
if v == scanSkipSpace {
|
||||
continue
|
||||
}
|
||||
if v == scanError {
|
||||
break
|
||||
}
|
||||
if needIndent && v != scanEndObject && v != scanEndArray {
|
||||
needIndent = false
|
||||
depth++
|
||||
dst = appendNewline(dst, prefix, indent, depth)
|
||||
}
|
||||
|
||||
// Emit semantically uninteresting bytes
|
||||
// (in particular, punctuation in strings) unmodified.
|
||||
if v == scanContinue {
|
||||
dst = append(dst, c)
|
||||
continue
|
||||
}
|
||||
|
||||
// Add spacing around real punctuation.
|
||||
switch c {
|
||||
case '{', '[':
|
||||
// delay indent so that empty object and array are formatted as {} and [].
|
||||
needIndent = true
|
||||
dst = append(dst, c)
|
||||
case ',':
|
||||
dst = append(dst, c)
|
||||
dst = appendNewline(dst, prefix, indent, depth)
|
||||
case ':':
|
||||
dst = append(dst, c, ' ')
|
||||
case '}', ']':
|
||||
if needIndent {
|
||||
// suppress indent in empty object/array
|
||||
needIndent = false
|
||||
} else {
|
||||
depth--
|
||||
dst = appendNewline(dst, prefix, indent, depth)
|
||||
}
|
||||
dst = append(dst, c)
|
||||
default:
|
||||
dst = append(dst, c)
|
||||
}
|
||||
}
|
||||
if scan.eof() == scanError {
|
||||
return dst[:origLen], scan.err
|
||||
}
|
||||
return dst, nil
|
||||
}
|
610
common/json/internal/contextjson/scanner.go
Normal file
610
common/json/internal/contextjson/scanner.go
Normal file
|
@ -0,0 +1,610 @@
|
|||
// Copyright 2010 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package json
|
||||
|
||||
// JSON value parser state machine.
|
||||
// Just about at the limit of what is reasonable to write by hand.
|
||||
// Some parts are a bit tedious, but overall it nicely factors out the
|
||||
// otherwise common code from the multiple scanning functions
|
||||
// in this package (Compact, Indent, checkValid, etc).
|
||||
//
|
||||
// This file starts with two simple examples using the scanner
|
||||
// before diving into the scanner itself.
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Valid reports whether data is a valid JSON encoding.
|
||||
func Valid(data []byte) bool {
|
||||
scan := newScanner()
|
||||
defer freeScanner(scan)
|
||||
return checkValid(data, scan) == nil
|
||||
}
|
||||
|
||||
// checkValid verifies that data is valid JSON-encoded data.
|
||||
// scan is passed in for use by checkValid to avoid an allocation.
|
||||
// checkValid returns nil or a SyntaxError.
|
||||
func checkValid(data []byte, scan *scanner) error {
|
||||
scan.reset()
|
||||
for _, c := range data {
|
||||
scan.bytes++
|
||||
if scan.step(scan, c) == scanError {
|
||||
return scan.err
|
||||
}
|
||||
}
|
||||
if scan.eof() == scanError {
|
||||
return scan.err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// A SyntaxError is a description of a JSON syntax error.
|
||||
// Unmarshal will return a SyntaxError if the JSON can't be parsed.
|
||||
type SyntaxError struct {
|
||||
msg string // description of error
|
||||
Offset int64 // error occurred after reading Offset bytes
|
||||
}
|
||||
|
||||
func (e *SyntaxError) Error() string { return e.msg }
|
||||
|
||||
// A scanner is a JSON scanning state machine.
|
||||
// Callers call scan.reset and then pass bytes in one at a time
|
||||
// by calling scan.step(&scan, c) for each byte.
|
||||
// The return value, referred to as an opcode, tells the
|
||||
// caller about significant parsing events like beginning
|
||||
// and ending literals, objects, and arrays, so that the
|
||||
// caller can follow along if it wishes.
|
||||
// The return value scanEnd indicates that a single top-level
|
||||
// JSON value has been completed, *before* the byte that
|
||||
// just got passed in. (The indication must be delayed in order
|
||||
// to recognize the end of numbers: is 123 a whole value or
|
||||
// the beginning of 12345e+6?).
|
||||
type scanner struct {
|
||||
// The step is a func to be called to execute the next transition.
|
||||
// Also tried using an integer constant and a single func
|
||||
// with a switch, but using the func directly was 10% faster
|
||||
// on a 64-bit Mac Mini, and it's nicer to read.
|
||||
step func(*scanner, byte) int
|
||||
|
||||
// Reached end of top-level value.
|
||||
endTop bool
|
||||
|
||||
// Stack of what we're in the middle of - array values, object keys, object values.
|
||||
parseState []int
|
||||
|
||||
// Error that happened, if any.
|
||||
err error
|
||||
|
||||
// total bytes consumed, updated by decoder.Decode (and deliberately
|
||||
// not set to zero by scan.reset)
|
||||
bytes int64
|
||||
}
|
||||
|
||||
var scannerPool = sync.Pool{
|
||||
New: func() any {
|
||||
return &scanner{}
|
||||
},
|
||||
}
|
||||
|
||||
func newScanner() *scanner {
|
||||
scan := scannerPool.Get().(*scanner)
|
||||
// scan.reset by design doesn't set bytes to zero
|
||||
scan.bytes = 0
|
||||
scan.reset()
|
||||
return scan
|
||||
}
|
||||
|
||||
func freeScanner(scan *scanner) {
|
||||
// Avoid hanging on to too much memory in extreme cases.
|
||||
if len(scan.parseState) > 1024 {
|
||||
scan.parseState = nil
|
||||
}
|
||||
scannerPool.Put(scan)
|
||||
}
|
||||
|
||||
// These values are returned by the state transition functions
|
||||
// assigned to scanner.state and the method scanner.eof.
|
||||
// They give details about the current state of the scan that
|
||||
// callers might be interested to know about.
|
||||
// It is okay to ignore the return value of any particular
|
||||
// call to scanner.state: if one call returns scanError,
|
||||
// every subsequent call will return scanError too.
|
||||
const (
|
||||
// Continue.
|
||||
scanContinue = iota // uninteresting byte
|
||||
scanBeginLiteral // end implied by next result != scanContinue
|
||||
scanBeginObject // begin object
|
||||
scanObjectKey // just finished object key (string)
|
||||
scanObjectValue // just finished non-last object value
|
||||
scanEndObject // end object (implies scanObjectValue if possible)
|
||||
scanBeginArray // begin array
|
||||
scanArrayValue // just finished array value
|
||||
scanEndArray // end array (implies scanArrayValue if possible)
|
||||
scanSkipSpace // space byte; can skip; known to be last "continue" result
|
||||
|
||||
// Stop.
|
||||
scanEnd // top-level value ended *before* this byte; known to be first "stop" result
|
||||
scanError // hit an error, scanner.err.
|
||||
)
|
||||
|
||||
// These values are stored in the parseState stack.
|
||||
// They give the current state of a composite value
|
||||
// being scanned. If the parser is inside a nested value
|
||||
// the parseState describes the nested state, outermost at entry 0.
|
||||
const (
|
||||
parseObjectKey = iota // parsing object key (before colon)
|
||||
parseObjectValue // parsing object value (after colon)
|
||||
parseArrayValue // parsing array value
|
||||
)
|
||||
|
||||
// This limits the max nesting depth to prevent stack overflow.
|
||||
// This is permitted by https://tools.ietf.org/html/rfc7159#section-9
|
||||
const maxNestingDepth = 10000
|
||||
|
||||
// reset prepares the scanner for use.
|
||||
// It must be called before calling s.step.
|
||||
func (s *scanner) reset() {
|
||||
s.step = stateBeginValue
|
||||
s.parseState = s.parseState[0:0]
|
||||
s.err = nil
|
||||
s.endTop = false
|
||||
}
|
||||
|
||||
// eof tells the scanner that the end of input has been reached.
|
||||
// It returns a scan status just as s.step does.
|
||||
func (s *scanner) eof() int {
|
||||
if s.err != nil {
|
||||
return scanError
|
||||
}
|
||||
if s.endTop {
|
||||
return scanEnd
|
||||
}
|
||||
s.step(s, ' ')
|
||||
if s.endTop {
|
||||
return scanEnd
|
||||
}
|
||||
if s.err == nil {
|
||||
s.err = &SyntaxError{"unexpected end of JSON input", s.bytes}
|
||||
}
|
||||
return scanError
|
||||
}
|
||||
|
||||
// pushParseState pushes a new parse state p onto the parse stack.
|
||||
// an error state is returned if maxNestingDepth was exceeded, otherwise successState is returned.
|
||||
func (s *scanner) pushParseState(c byte, newParseState int, successState int) int {
|
||||
s.parseState = append(s.parseState, newParseState)
|
||||
if len(s.parseState) <= maxNestingDepth {
|
||||
return successState
|
||||
}
|
||||
return s.error(c, "exceeded max depth")
|
||||
}
|
||||
|
||||
// popParseState pops a parse state (already obtained) off the stack
|
||||
// and updates s.step accordingly.
|
||||
func (s *scanner) popParseState() {
|
||||
n := len(s.parseState) - 1
|
||||
s.parseState = s.parseState[0:n]
|
||||
if n == 0 {
|
||||
s.step = stateEndTop
|
||||
s.endTop = true
|
||||
} else {
|
||||
s.step = stateEndValue
|
||||
}
|
||||
}
|
||||
|
||||
func isSpace(c byte) bool {
|
||||
return c <= ' ' && (c == ' ' || c == '\t' || c == '\r' || c == '\n')
|
||||
}
|
||||
|
||||
// stateBeginValueOrEmpty is the state after reading `[`.
|
||||
func stateBeginValueOrEmpty(s *scanner, c byte) int {
|
||||
if isSpace(c) {
|
||||
return scanSkipSpace
|
||||
}
|
||||
if c == ']' {
|
||||
return stateEndValue(s, c)
|
||||
}
|
||||
return stateBeginValue(s, c)
|
||||
}
|
||||
|
||||
// stateBeginValue is the state at the beginning of the input.
|
||||
func stateBeginValue(s *scanner, c byte) int {
|
||||
if isSpace(c) {
|
||||
return scanSkipSpace
|
||||
}
|
||||
switch c {
|
||||
case '{':
|
||||
s.step = stateBeginStringOrEmpty
|
||||
return s.pushParseState(c, parseObjectKey, scanBeginObject)
|
||||
case '[':
|
||||
s.step = stateBeginValueOrEmpty
|
||||
return s.pushParseState(c, parseArrayValue, scanBeginArray)
|
||||
case '"':
|
||||
s.step = stateInString
|
||||
return scanBeginLiteral
|
||||
case '-':
|
||||
s.step = stateNeg
|
||||
return scanBeginLiteral
|
||||
case '0': // beginning of 0.123
|
||||
s.step = state0
|
||||
return scanBeginLiteral
|
||||
case 't': // beginning of true
|
||||
s.step = stateT
|
||||
return scanBeginLiteral
|
||||
case 'f': // beginning of false
|
||||
s.step = stateF
|
||||
return scanBeginLiteral
|
||||
case 'n': // beginning of null
|
||||
s.step = stateN
|
||||
return scanBeginLiteral
|
||||
}
|
||||
if '1' <= c && c <= '9' { // beginning of 1234.5
|
||||
s.step = state1
|
||||
return scanBeginLiteral
|
||||
}
|
||||
return s.error(c, "looking for beginning of value")
|
||||
}
|
||||
|
||||
// stateBeginStringOrEmpty is the state after reading `{`.
|
||||
func stateBeginStringOrEmpty(s *scanner, c byte) int {
|
||||
if isSpace(c) {
|
||||
return scanSkipSpace
|
||||
}
|
||||
if c == '}' {
|
||||
n := len(s.parseState)
|
||||
s.parseState[n-1] = parseObjectValue
|
||||
return stateEndValue(s, c)
|
||||
}
|
||||
return stateBeginString(s, c)
|
||||
}
|
||||
|
||||
// stateBeginString is the state after reading `{"key": value,`.
|
||||
func stateBeginString(s *scanner, c byte) int {
|
||||
if isSpace(c) {
|
||||
return scanSkipSpace
|
||||
}
|
||||
if c == '"' {
|
||||
s.step = stateInString
|
||||
return scanBeginLiteral
|
||||
}
|
||||
return s.error(c, "looking for beginning of object key string")
|
||||
}
|
||||
|
||||
// stateEndValue is the state after completing a value,
|
||||
// such as after reading `{}` or `true` or `["x"`.
|
||||
func stateEndValue(s *scanner, c byte) int {
|
||||
n := len(s.parseState)
|
||||
if n == 0 {
|
||||
// Completed top-level before the current byte.
|
||||
s.step = stateEndTop
|
||||
s.endTop = true
|
||||
return stateEndTop(s, c)
|
||||
}
|
||||
if isSpace(c) {
|
||||
s.step = stateEndValue
|
||||
return scanSkipSpace
|
||||
}
|
||||
ps := s.parseState[n-1]
|
||||
switch ps {
|
||||
case parseObjectKey:
|
||||
if c == ':' {
|
||||
s.parseState[n-1] = parseObjectValue
|
||||
s.step = stateBeginValue
|
||||
return scanObjectKey
|
||||
}
|
||||
return s.error(c, "after object key")
|
||||
case parseObjectValue:
|
||||
if c == ',' {
|
||||
s.parseState[n-1] = parseObjectKey
|
||||
s.step = stateBeginStringOrEmpty
|
||||
return scanObjectValue
|
||||
}
|
||||
if c == '}' {
|
||||
s.popParseState()
|
||||
return scanEndObject
|
||||
}
|
||||
return s.error(c, "after object key:value pair")
|
||||
case parseArrayValue:
|
||||
if c == ',' {
|
||||
s.step = stateBeginValueOrEmpty
|
||||
return scanArrayValue
|
||||
}
|
||||
if c == ']' {
|
||||
s.popParseState()
|
||||
return scanEndArray
|
||||
}
|
||||
return s.error(c, "after array element")
|
||||
}
|
||||
return s.error(c, "")
|
||||
}
|
||||
|
||||
// stateEndTop is the state after finishing the top-level value,
|
||||
// such as after reading `{}` or `[1,2,3]`.
|
||||
// Only space characters should be seen now.
|
||||
func stateEndTop(s *scanner, c byte) int {
|
||||
if !isSpace(c) {
|
||||
// Complain about non-space byte on next call.
|
||||
s.error(c, "after top-level value")
|
||||
}
|
||||
return scanEnd
|
||||
}
|
||||
|
||||
// stateInString is the state after reading `"`.
|
||||
func stateInString(s *scanner, c byte) int {
|
||||
if c == '"' {
|
||||
s.step = stateEndValue
|
||||
return scanContinue
|
||||
}
|
||||
if c == '\\' {
|
||||
s.step = stateInStringEsc
|
||||
return scanContinue
|
||||
}
|
||||
if c < 0x20 {
|
||||
return s.error(c, "in string literal")
|
||||
}
|
||||
return scanContinue
|
||||
}
|
||||
|
||||
// stateInStringEsc is the state after reading `"\` during a quoted string.
|
||||
func stateInStringEsc(s *scanner, c byte) int {
|
||||
switch c {
|
||||
case 'b', 'f', 'n', 'r', 't', '\\', '/', '"':
|
||||
s.step = stateInString
|
||||
return scanContinue
|
||||
case 'u':
|
||||
s.step = stateInStringEscU
|
||||
return scanContinue
|
||||
}
|
||||
return s.error(c, "in string escape code")
|
||||
}
|
||||
|
||||
// stateInStringEscU is the state after reading `"\u` during a quoted string.
|
||||
func stateInStringEscU(s *scanner, c byte) int {
|
||||
if '0' <= c && c <= '9' || 'a' <= c && c <= 'f' || 'A' <= c && c <= 'F' {
|
||||
s.step = stateInStringEscU1
|
||||
return scanContinue
|
||||
}
|
||||
// numbers
|
||||
return s.error(c, "in \\u hexadecimal character escape")
|
||||
}
|
||||
|
||||
// stateInStringEscU1 is the state after reading `"\u1` during a quoted string.
|
||||
func stateInStringEscU1(s *scanner, c byte) int {
|
||||
if '0' <= c && c <= '9' || 'a' <= c && c <= 'f' || 'A' <= c && c <= 'F' {
|
||||
s.step = stateInStringEscU12
|
||||
return scanContinue
|
||||
}
|
||||
// numbers
|
||||
return s.error(c, "in \\u hexadecimal character escape")
|
||||
}
|
||||
|
||||
// stateInStringEscU12 is the state after reading `"\u12` during a quoted string.
|
||||
func stateInStringEscU12(s *scanner, c byte) int {
|
||||
if '0' <= c && c <= '9' || 'a' <= c && c <= 'f' || 'A' <= c && c <= 'F' {
|
||||
s.step = stateInStringEscU123
|
||||
return scanContinue
|
||||
}
|
||||
// numbers
|
||||
return s.error(c, "in \\u hexadecimal character escape")
|
||||
}
|
||||
|
||||
// stateInStringEscU123 is the state after reading `"\u123` during a quoted string.
|
||||
func stateInStringEscU123(s *scanner, c byte) int {
|
||||
if '0' <= c && c <= '9' || 'a' <= c && c <= 'f' || 'A' <= c && c <= 'F' {
|
||||
s.step = stateInString
|
||||
return scanContinue
|
||||
}
|
||||
// numbers
|
||||
return s.error(c, "in \\u hexadecimal character escape")
|
||||
}
|
||||
|
||||
// stateNeg is the state after reading `-` during a number.
|
||||
func stateNeg(s *scanner, c byte) int {
|
||||
if c == '0' {
|
||||
s.step = state0
|
||||
return scanContinue
|
||||
}
|
||||
if '1' <= c && c <= '9' {
|
||||
s.step = state1
|
||||
return scanContinue
|
||||
}
|
||||
return s.error(c, "in numeric literal")
|
||||
}
|
||||
|
||||
// state1 is the state after reading a non-zero integer during a number,
|
||||
// such as after reading `1` or `100` but not `0`.
|
||||
func state1(s *scanner, c byte) int {
|
||||
if '0' <= c && c <= '9' {
|
||||
s.step = state1
|
||||
return scanContinue
|
||||
}
|
||||
return state0(s, c)
|
||||
}
|
||||
|
||||
// state0 is the state after reading `0` during a number.
|
||||
func state0(s *scanner, c byte) int {
|
||||
if c == '.' {
|
||||
s.step = stateDot
|
||||
return scanContinue
|
||||
}
|
||||
if c == 'e' || c == 'E' {
|
||||
s.step = stateE
|
||||
return scanContinue
|
||||
}
|
||||
return stateEndValue(s, c)
|
||||
}
|
||||
|
||||
// stateDot is the state after reading the integer and decimal point in a number,
|
||||
// such as after reading `1.`.
|
||||
func stateDot(s *scanner, c byte) int {
|
||||
if '0' <= c && c <= '9' {
|
||||
s.step = stateDot0
|
||||
return scanContinue
|
||||
}
|
||||
return s.error(c, "after decimal point in numeric literal")
|
||||
}
|
||||
|
||||
// stateDot0 is the state after reading the integer, decimal point, and subsequent
|
||||
// digits of a number, such as after reading `3.14`.
|
||||
func stateDot0(s *scanner, c byte) int {
|
||||
if '0' <= c && c <= '9' {
|
||||
return scanContinue
|
||||
}
|
||||
if c == 'e' || c == 'E' {
|
||||
s.step = stateE
|
||||
return scanContinue
|
||||
}
|
||||
return stateEndValue(s, c)
|
||||
}
|
||||
|
||||
// stateE is the state after reading the mantissa and e in a number,
|
||||
// such as after reading `314e` or `0.314e`.
|
||||
func stateE(s *scanner, c byte) int {
|
||||
if c == '+' || c == '-' {
|
||||
s.step = stateESign
|
||||
return scanContinue
|
||||
}
|
||||
return stateESign(s, c)
|
||||
}
|
||||
|
||||
// stateESign is the state after reading the mantissa, e, and sign in a number,
|
||||
// such as after reading `314e-` or `0.314e+`.
|
||||
func stateESign(s *scanner, c byte) int {
|
||||
if '0' <= c && c <= '9' {
|
||||
s.step = stateE0
|
||||
return scanContinue
|
||||
}
|
||||
return s.error(c, "in exponent of numeric literal")
|
||||
}
|
||||
|
||||
// stateE0 is the state after reading the mantissa, e, optional sign,
|
||||
// and at least one digit of the exponent in a number,
|
||||
// such as after reading `314e-2` or `0.314e+1` or `3.14e0`.
|
||||
func stateE0(s *scanner, c byte) int {
|
||||
if '0' <= c && c <= '9' {
|
||||
return scanContinue
|
||||
}
|
||||
return stateEndValue(s, c)
|
||||
}
|
||||
|
||||
// stateT is the state after reading `t`.
|
||||
func stateT(s *scanner, c byte) int {
|
||||
if c == 'r' {
|
||||
s.step = stateTr
|
||||
return scanContinue
|
||||
}
|
||||
return s.error(c, "in literal true (expecting 'r')")
|
||||
}
|
||||
|
||||
// stateTr is the state after reading `tr`.
|
||||
func stateTr(s *scanner, c byte) int {
|
||||
if c == 'u' {
|
||||
s.step = stateTru
|
||||
return scanContinue
|
||||
}
|
||||
return s.error(c, "in literal true (expecting 'u')")
|
||||
}
|
||||
|
||||
// stateTru is the state after reading `tru`.
|
||||
func stateTru(s *scanner, c byte) int {
|
||||
if c == 'e' {
|
||||
s.step = stateEndValue
|
||||
return scanContinue
|
||||
}
|
||||
return s.error(c, "in literal true (expecting 'e')")
|
||||
}
|
||||
|
||||
// stateF is the state after reading `f`.
|
||||
func stateF(s *scanner, c byte) int {
|
||||
if c == 'a' {
|
||||
s.step = stateFa
|
||||
return scanContinue
|
||||
}
|
||||
return s.error(c, "in literal false (expecting 'a')")
|
||||
}
|
||||
|
||||
// stateFa is the state after reading `fa`.
|
||||
func stateFa(s *scanner, c byte) int {
|
||||
if c == 'l' {
|
||||
s.step = stateFal
|
||||
return scanContinue
|
||||
}
|
||||
return s.error(c, "in literal false (expecting 'l')")
|
||||
}
|
||||
|
||||
// stateFal is the state after reading `fal`.
|
||||
func stateFal(s *scanner, c byte) int {
|
||||
if c == 's' {
|
||||
s.step = stateFals
|
||||
return scanContinue
|
||||
}
|
||||
return s.error(c, "in literal false (expecting 's')")
|
||||
}
|
||||
|
||||
// stateFals is the state after reading `fals`.
|
||||
func stateFals(s *scanner, c byte) int {
|
||||
if c == 'e' {
|
||||
s.step = stateEndValue
|
||||
return scanContinue
|
||||
}
|
||||
return s.error(c, "in literal false (expecting 'e')")
|
||||
}
|
||||
|
||||
// stateN is the state after reading `n`.
|
||||
func stateN(s *scanner, c byte) int {
|
||||
if c == 'u' {
|
||||
s.step = stateNu
|
||||
return scanContinue
|
||||
}
|
||||
return s.error(c, "in literal null (expecting 'u')")
|
||||
}
|
||||
|
||||
// stateNu is the state after reading `nu`.
|
||||
func stateNu(s *scanner, c byte) int {
|
||||
if c == 'l' {
|
||||
s.step = stateNul
|
||||
return scanContinue
|
||||
}
|
||||
return s.error(c, "in literal null (expecting 'l')")
|
||||
}
|
||||
|
||||
// stateNul is the state after reading `nul`.
|
||||
func stateNul(s *scanner, c byte) int {
|
||||
if c == 'l' {
|
||||
s.step = stateEndValue
|
||||
return scanContinue
|
||||
}
|
||||
return s.error(c, "in literal null (expecting 'l')")
|
||||
}
|
||||
|
||||
// stateError is the state after reaching a syntax error,
|
||||
// such as after reading `[1}` or `5.1.2`.
|
||||
func stateError(s *scanner, c byte) int {
|
||||
return scanError
|
||||
}
|
||||
|
||||
// error records an error and switches to the error state.
|
||||
func (s *scanner) error(c byte, context string) int {
|
||||
s.step = stateError
|
||||
s.err = &SyntaxError{"invalid character " + quoteChar(c) + " " + context, s.bytes}
|
||||
return scanError
|
||||
}
|
||||
|
||||
// quoteChar formats c as a quoted character literal.
|
||||
func quoteChar(c byte) string {
|
||||
// special cases - different from quoted strings
|
||||
if c == '\'' {
|
||||
return `'\''`
|
||||
}
|
||||
if c == '"' {
|
||||
return `'"'`
|
||||
}
|
||||
|
||||
// use quoted string with different quotation marks
|
||||
s := strconv.Quote(string(c))
|
||||
return "'" + s[1:len(s)-1] + "'"
|
||||
}
|
554
common/json/internal/contextjson/stream.go
Normal file
554
common/json/internal/contextjson/stream.go
Normal file
|
@ -0,0 +1,554 @@
|
|||
// Copyright 2010 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package json
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"io"
|
||||
)
|
||||
|
||||
// A Decoder reads and decodes JSON values from an input stream.
|
||||
type Decoder struct {
|
||||
r io.Reader
|
||||
buf []byte
|
||||
d decodeState
|
||||
scanp int // start of unread data in buf
|
||||
scanned int64 // amount of data already scanned
|
||||
scan scanner
|
||||
err error
|
||||
|
||||
tokenState int
|
||||
tokenStack []int
|
||||
}
|
||||
|
||||
// NewDecoder returns a new decoder that reads from r.
|
||||
//
|
||||
// The decoder introduces its own buffering and may
|
||||
// read data from r beyond the JSON values requested.
|
||||
func NewDecoder(r io.Reader) *Decoder {
|
||||
return &Decoder{r: r}
|
||||
}
|
||||
|
||||
// UseNumber causes the Decoder to unmarshal a number into an interface{} as a
|
||||
// Number instead of as a float64.
|
||||
func (dec *Decoder) UseNumber() { dec.d.useNumber = true }
|
||||
|
||||
// DisallowUnknownFields causes the Decoder to return an error when the destination
|
||||
// is a struct and the input contains object keys which do not match any
|
||||
// non-ignored, exported fields in the destination.
|
||||
func (dec *Decoder) DisallowUnknownFields() { dec.d.disallowUnknownFields = true }
|
||||
|
||||
// Decode reads the next JSON-encoded value from its
|
||||
// input and stores it in the value pointed to by v.
|
||||
//
|
||||
// See the documentation for Unmarshal for details about
|
||||
// the conversion of JSON into a Go value.
|
||||
func (dec *Decoder) Decode(v any) error {
|
||||
if dec.err != nil {
|
||||
return dec.err
|
||||
}
|
||||
|
||||
if err := dec.tokenPrepareForDecode(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !dec.tokenValueAllowed() {
|
||||
return &SyntaxError{msg: "not at beginning of value", Offset: dec.InputOffset()}
|
||||
}
|
||||
|
||||
// Read whole value into buffer.
|
||||
n, err := dec.readValue()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
dec.d.init(dec.buf[dec.scanp : dec.scanp+n])
|
||||
dec.scanp += n
|
||||
|
||||
// Don't save err from unmarshal into dec.err:
|
||||
// the connection is still usable since we read a complete JSON
|
||||
// object from it before the error happened.
|
||||
err = dec.d.unmarshal(v)
|
||||
|
||||
// fixup token streaming state
|
||||
dec.tokenValueEnd()
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// Buffered returns a reader of the data remaining in the Decoder's
|
||||
// buffer. The reader is valid until the next call to Decode.
|
||||
func (dec *Decoder) Buffered() io.Reader {
|
||||
return bytes.NewReader(dec.buf[dec.scanp:])
|
||||
}
|
||||
|
||||
// readValue reads a JSON value into dec.buf.
|
||||
// It returns the length of the encoding.
|
||||
func (dec *Decoder) readValue() (int, error) {
|
||||
dec.scan.reset()
|
||||
|
||||
scanp := dec.scanp
|
||||
var err error
|
||||
Input:
|
||||
// help the compiler see that scanp is never negative, so it can remove
|
||||
// some bounds checks below.
|
||||
for scanp >= 0 {
|
||||
|
||||
// Look in the buffer for a new value.
|
||||
for ; scanp < len(dec.buf); scanp++ {
|
||||
c := dec.buf[scanp]
|
||||
dec.scan.bytes++
|
||||
switch dec.scan.step(&dec.scan, c) {
|
||||
case scanEnd:
|
||||
// scanEnd is delayed one byte so we decrement
|
||||
// the scanner bytes count by 1 to ensure that
|
||||
// this value is correct in the next call of Decode.
|
||||
dec.scan.bytes--
|
||||
break Input
|
||||
case scanEndObject, scanEndArray:
|
||||
// scanEnd is delayed one byte.
|
||||
// We might block trying to get that byte from src,
|
||||
// so instead invent a space byte.
|
||||
if stateEndValue(&dec.scan, ' ') == scanEnd {
|
||||
scanp++
|
||||
break Input
|
||||
}
|
||||
case scanError:
|
||||
dec.err = dec.scan.err
|
||||
return 0, dec.scan.err
|
||||
}
|
||||
}
|
||||
|
||||
// Did the last read have an error?
|
||||
// Delayed until now to allow buffer scan.
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
if dec.scan.step(&dec.scan, ' ') == scanEnd {
|
||||
break Input
|
||||
}
|
||||
if nonSpace(dec.buf) {
|
||||
err = io.ErrUnexpectedEOF
|
||||
}
|
||||
}
|
||||
dec.err = err
|
||||
return 0, err
|
||||
}
|
||||
|
||||
n := scanp - dec.scanp
|
||||
err = dec.refill()
|
||||
scanp = dec.scanp + n
|
||||
}
|
||||
return scanp - dec.scanp, nil
|
||||
}
|
||||
|
||||
func (dec *Decoder) refill() error {
|
||||
// Make room to read more into the buffer.
|
||||
// First slide down data already consumed.
|
||||
if dec.scanp > 0 {
|
||||
dec.scanned += int64(dec.scanp)
|
||||
n := copy(dec.buf, dec.buf[dec.scanp:])
|
||||
dec.buf = dec.buf[:n]
|
||||
dec.scanp = 0
|
||||
}
|
||||
|
||||
return dec.refill0()
|
||||
}
|
||||
|
||||
func (dec *Decoder) refill0() error {
|
||||
// Grow buffer if not large enough.
|
||||
const minRead = 512
|
||||
if cap(dec.buf)-len(dec.buf) < minRead {
|
||||
newBuf := make([]byte, len(dec.buf), 2*cap(dec.buf)+minRead)
|
||||
copy(newBuf, dec.buf)
|
||||
dec.buf = newBuf
|
||||
}
|
||||
|
||||
// Read. Delay error for next iteration (after scan).
|
||||
n, err := dec.r.Read(dec.buf[len(dec.buf):cap(dec.buf)])
|
||||
dec.buf = dec.buf[0 : len(dec.buf)+n]
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func nonSpace(b []byte) bool {
|
||||
for _, c := range b {
|
||||
if !isSpace(c) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// An Encoder writes JSON values to an output stream.
|
||||
type Encoder struct {
|
||||
w io.Writer
|
||||
err error
|
||||
escapeHTML bool
|
||||
|
||||
indentBuf []byte
|
||||
indentPrefix string
|
||||
indentValue string
|
||||
}
|
||||
|
||||
// NewEncoder returns a new encoder that writes to w.
|
||||
func NewEncoder(w io.Writer) *Encoder {
|
||||
return &Encoder{w: w, escapeHTML: true}
|
||||
}
|
||||
|
||||
// Encode writes the JSON encoding of v to the stream,
|
||||
// followed by a newline character.
|
||||
//
|
||||
// See the documentation for Marshal for details about the
|
||||
// conversion of Go values to JSON.
|
||||
func (enc *Encoder) Encode(v any) error {
|
||||
if enc.err != nil {
|
||||
return enc.err
|
||||
}
|
||||
|
||||
e := newEncodeState()
|
||||
defer encodeStatePool.Put(e)
|
||||
|
||||
err := e.marshal(v, encOpts{escapeHTML: enc.escapeHTML})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Terminate each value with a newline.
|
||||
// This makes the output look a little nicer
|
||||
// when debugging, and some kind of space
|
||||
// is required if the encoded value was a number,
|
||||
// so that the reader knows there aren't more
|
||||
// digits coming.
|
||||
e.WriteByte('\n')
|
||||
|
||||
b := e.Bytes()
|
||||
if enc.indentPrefix != "" || enc.indentValue != "" {
|
||||
enc.indentBuf, err = appendIndent(enc.indentBuf[:0], b, enc.indentPrefix, enc.indentValue)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
b = enc.indentBuf
|
||||
}
|
||||
if _, err = enc.w.Write(b); err != nil {
|
||||
enc.err = err
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// SetIndent instructs the encoder to format each subsequent encoded
|
||||
// value as if indented by the package-level function Indent(dst, src, prefix, indent).
|
||||
// Calling SetIndent("", "") disables indentation.
|
||||
func (enc *Encoder) SetIndent(prefix, indent string) {
|
||||
enc.indentPrefix = prefix
|
||||
enc.indentValue = indent
|
||||
}
|
||||
|
||||
// SetEscapeHTML specifies whether problematic HTML characters
|
||||
// should be escaped inside JSON quoted strings.
|
||||
// The default behavior is to escape &, <, and > to \u0026, \u003c, and \u003e
|
||||
// to avoid certain safety problems that can arise when embedding JSON in HTML.
|
||||
//
|
||||
// In non-HTML settings where the escaping interferes with the readability
|
||||
// of the output, SetEscapeHTML(false) disables this behavior.
|
||||
func (enc *Encoder) SetEscapeHTML(on bool) {
|
||||
enc.escapeHTML = on
|
||||
}
|
||||
|
||||
// RawMessage is a raw encoded JSON value.
|
||||
// It implements Marshaler and Unmarshaler and can
|
||||
// be used to delay JSON decoding or precompute a JSON encoding.
|
||||
type RawMessage []byte
|
||||
|
||||
// MarshalJSON returns m as the JSON encoding of m.
|
||||
func (m RawMessage) MarshalJSON() ([]byte, error) {
|
||||
if m == nil {
|
||||
return []byte("null"), nil
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// UnmarshalJSON sets *m to a copy of data.
|
||||
func (m *RawMessage) UnmarshalJSON(data []byte) error {
|
||||
if m == nil {
|
||||
return errors.New("json.RawMessage: UnmarshalJSON on nil pointer")
|
||||
}
|
||||
*m = append((*m)[0:0], data...)
|
||||
return nil
|
||||
}
|
||||
|
||||
var (
|
||||
_ Marshaler = (*RawMessage)(nil)
|
||||
_ Unmarshaler = (*RawMessage)(nil)
|
||||
)
|
||||
|
||||
// A Token holds a value of one of these types:
|
||||
//
|
||||
// Delim, for the four JSON delimiters [ ] { }
|
||||
// bool, for JSON booleans
|
||||
// float64, for JSON numbers
|
||||
// Number, for JSON numbers
|
||||
// string, for JSON string literals
|
||||
// nil, for JSON null
|
||||
type Token any
|
||||
|
||||
const (
|
||||
tokenTopValue = iota
|
||||
tokenArrayStart
|
||||
tokenArrayValue
|
||||
tokenArrayComma
|
||||
tokenObjectStart
|
||||
tokenObjectKey
|
||||
tokenObjectColon
|
||||
tokenObjectValue
|
||||
tokenObjectComma
|
||||
)
|
||||
|
||||
// advance tokenstate from a separator state to a value state
|
||||
func (dec *Decoder) tokenPrepareForDecode() error {
|
||||
// Note: Not calling peek before switch, to avoid
|
||||
// putting peek into the standard Decode path.
|
||||
// peek is only called when using the Token API.
|
||||
switch dec.tokenState {
|
||||
case tokenArrayComma:
|
||||
c, err := dec.peek()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if c != ',' {
|
||||
return &SyntaxError{"expected comma after array element", dec.InputOffset()}
|
||||
}
|
||||
dec.scanp++
|
||||
dec.tokenState = tokenArrayValue
|
||||
case tokenObjectColon:
|
||||
c, err := dec.peek()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if c != ':' {
|
||||
return &SyntaxError{"expected colon after object key", dec.InputOffset()}
|
||||
}
|
||||
dec.scanp++
|
||||
dec.tokenState = tokenObjectValue
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (dec *Decoder) tokenValueAllowed() bool {
|
||||
switch dec.tokenState {
|
||||
case tokenTopValue, tokenArrayStart, tokenArrayValue, tokenObjectValue:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (dec *Decoder) tokenValueEnd() {
|
||||
switch dec.tokenState {
|
||||
case tokenArrayStart, tokenArrayValue:
|
||||
dec.tokenState = tokenArrayComma
|
||||
case tokenObjectValue:
|
||||
dec.tokenState = tokenObjectComma
|
||||
}
|
||||
}
|
||||
|
||||
// A Delim is a JSON array or object delimiter, one of [ ] { or }.
|
||||
type Delim rune
|
||||
|
||||
func (d Delim) String() string {
|
||||
return string(d)
|
||||
}
|
||||
|
||||
// Token returns the next JSON token in the input stream.
|
||||
// At the end of the input stream, Token returns nil, io.EOF.
|
||||
//
|
||||
// Token guarantees that the delimiters [ ] { } it returns are
|
||||
// properly nested and matched: if Token encounters an unexpected
|
||||
// delimiter in the input, it will return an error.
|
||||
//
|
||||
// The input stream consists of basic JSON values—bool, string,
|
||||
// number, and null—along with delimiters [ ] { } of type Delim
|
||||
// to mark the start and end of arrays and objects.
|
||||
// Commas and colons are elided.
|
||||
func (dec *Decoder) Token() (Token, error) {
|
||||
for {
|
||||
c, err := dec.peek()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
switch c {
|
||||
case '[':
|
||||
if !dec.tokenValueAllowed() {
|
||||
return dec.tokenError(c)
|
||||
}
|
||||
dec.scanp++
|
||||
dec.tokenStack = append(dec.tokenStack, dec.tokenState)
|
||||
dec.tokenState = tokenArrayStart
|
||||
return Delim('['), nil
|
||||
|
||||
case ']':
|
||||
if dec.tokenState != tokenArrayStart && dec.tokenState != tokenArrayComma {
|
||||
return dec.tokenError(c)
|
||||
}
|
||||
dec.scanp++
|
||||
dec.tokenState = dec.tokenStack[len(dec.tokenStack)-1]
|
||||
dec.tokenStack = dec.tokenStack[:len(dec.tokenStack)-1]
|
||||
dec.tokenValueEnd()
|
||||
return Delim(']'), nil
|
||||
|
||||
case '{':
|
||||
if !dec.tokenValueAllowed() {
|
||||
return dec.tokenError(c)
|
||||
}
|
||||
dec.scanp++
|
||||
dec.tokenStack = append(dec.tokenStack, dec.tokenState)
|
||||
dec.tokenState = tokenObjectStart
|
||||
return Delim('{'), nil
|
||||
|
||||
case '}':
|
||||
if dec.tokenState != tokenObjectStart && dec.tokenState != tokenObjectComma && dec.tokenState != tokenObjectKey {
|
||||
return dec.tokenError(c)
|
||||
}
|
||||
dec.scanp++
|
||||
dec.tokenState = dec.tokenStack[len(dec.tokenStack)-1]
|
||||
dec.tokenStack = dec.tokenStack[:len(dec.tokenStack)-1]
|
||||
dec.tokenValueEnd()
|
||||
return Delim('}'), nil
|
||||
case ':':
|
||||
if dec.tokenState != tokenObjectColon {
|
||||
return dec.tokenError(c)
|
||||
}
|
||||
dec.scanp++
|
||||
dec.tokenState = tokenObjectValue
|
||||
continue
|
||||
|
||||
case ',':
|
||||
if dec.tokenState == tokenArrayComma {
|
||||
dec.scanp++
|
||||
dec.tokenState = tokenArrayValue
|
||||
continue
|
||||
}
|
||||
if dec.tokenState == tokenObjectComma {
|
||||
dec.scanp++
|
||||
dec.tokenState = tokenObjectKey
|
||||
continue
|
||||
}
|
||||
return dec.tokenError(c)
|
||||
|
||||
case '"':
|
||||
if dec.tokenState == tokenObjectStart || dec.tokenState == tokenObjectKey {
|
||||
var x string
|
||||
old := dec.tokenState
|
||||
dec.tokenState = tokenTopValue
|
||||
err := dec.Decode(&x)
|
||||
dec.tokenState = old
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
dec.tokenState = tokenObjectColon
|
||||
return x, nil
|
||||
}
|
||||
fallthrough
|
||||
|
||||
default:
|
||||
if !dec.tokenValueAllowed() {
|
||||
return dec.tokenError(c)
|
||||
}
|
||||
var x any
|
||||
if err := dec.Decode(&x); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return x, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (dec *Decoder) tokenError(c byte) (Token, error) {
|
||||
var context string
|
||||
switch dec.tokenState {
|
||||
case tokenTopValue:
|
||||
context = " looking for beginning of value"
|
||||
case tokenArrayStart, tokenArrayValue, tokenObjectValue:
|
||||
context = " looking for beginning of value"
|
||||
case tokenArrayComma:
|
||||
context = " after array element"
|
||||
case tokenObjectKey:
|
||||
context = " looking for beginning of object key string"
|
||||
case tokenObjectColon:
|
||||
context = " after object key"
|
||||
case tokenObjectComma:
|
||||
context = " after object key:value pair"
|
||||
}
|
||||
return nil, &SyntaxError{"invalid character " + quoteChar(c) + context, dec.InputOffset()}
|
||||
}
|
||||
|
||||
// More reports whether there is another element in the
|
||||
// current array or object being parsed.
|
||||
func (dec *Decoder) More() bool {
|
||||
c, err := dec.peek()
|
||||
// return err == nil && c != ']' && c != '}'
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
if c == ']' || c == '}' {
|
||||
return false
|
||||
}
|
||||
if c == ',' {
|
||||
scanp := dec.scanp
|
||||
dec.scanp++
|
||||
c, err = dec.peekNoRefill()
|
||||
dec.scanp = scanp
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
if c == ']' || c == '}' {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (dec *Decoder) peek() (byte, error) {
|
||||
var err error
|
||||
for {
|
||||
for i := dec.scanp; i < len(dec.buf); i++ {
|
||||
c := dec.buf[i]
|
||||
if isSpace(c) {
|
||||
continue
|
||||
}
|
||||
dec.scanp = i
|
||||
return c, nil
|
||||
}
|
||||
// buffer has been scanned, now report any error
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
err = dec.refill()
|
||||
}
|
||||
}
|
||||
|
||||
func (dec *Decoder) peekNoRefill() (byte, error) {
|
||||
var err error
|
||||
for {
|
||||
for i := dec.scanp; i < len(dec.buf); i++ {
|
||||
c := dec.buf[i]
|
||||
if isSpace(c) {
|
||||
continue
|
||||
}
|
||||
dec.scanp = i
|
||||
return c, nil
|
||||
}
|
||||
// buffer has been scanned, now report any error
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
err = dec.refill0()
|
||||
}
|
||||
}
|
||||
|
||||
// InputOffset returns the input stream byte offset of the current decoder position.
|
||||
// The offset gives the location of the end of the most recently returned token
|
||||
// and the beginning of the next token.
|
||||
func (dec *Decoder) InputOffset() int64 {
|
||||
return dec.scanned + int64(dec.scanp)
|
||||
}
|
218
common/json/internal/contextjson/tables.go
Normal file
218
common/json/internal/contextjson/tables.go
Normal file
|
@ -0,0 +1,218 @@
|
|||
// Copyright 2016 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package json
|
||||
|
||||
import "unicode/utf8"
|
||||
|
||||
// safeSet holds the value true if the ASCII character with the given array
|
||||
// position can be represented inside a JSON string without any further
|
||||
// escaping.
|
||||
//
|
||||
// All values are true except for the ASCII control characters (0-31), the
|
||||
// double quote ("), and the backslash character ("\").
|
||||
var safeSet = [utf8.RuneSelf]bool{
|
||||
' ': true,
|
||||
'!': true,
|
||||
'"': false,
|
||||
'#': true,
|
||||
'$': true,
|
||||
'%': true,
|
||||
'&': true,
|
||||
'\'': true,
|
||||
'(': true,
|
||||
')': true,
|
||||
'*': true,
|
||||
'+': true,
|
||||
',': true,
|
||||
'-': true,
|
||||
'.': true,
|
||||
'/': true,
|
||||
'0': true,
|
||||
'1': true,
|
||||
'2': true,
|
||||
'3': true,
|
||||
'4': true,
|
||||
'5': true,
|
||||
'6': true,
|
||||
'7': true,
|
||||
'8': true,
|
||||
'9': true,
|
||||
':': true,
|
||||
';': true,
|
||||
'<': true,
|
||||
'=': true,
|
||||
'>': true,
|
||||
'?': true,
|
||||
'@': true,
|
||||
'A': true,
|
||||
'B': true,
|
||||
'C': true,
|
||||
'D': true,
|
||||
'E': true,
|
||||
'F': true,
|
||||
'G': true,
|
||||
'H': true,
|
||||
'I': true,
|
||||
'J': true,
|
||||
'K': true,
|
||||
'L': true,
|
||||
'M': true,
|
||||
'N': true,
|
||||
'O': true,
|
||||
'P': true,
|
||||
'Q': true,
|
||||
'R': true,
|
||||
'S': true,
|
||||
'T': true,
|
||||
'U': true,
|
||||
'V': true,
|
||||
'W': true,
|
||||
'X': true,
|
||||
'Y': true,
|
||||
'Z': true,
|
||||
'[': true,
|
||||
'\\': false,
|
||||
']': true,
|
||||
'^': true,
|
||||
'_': true,
|
||||
'`': true,
|
||||
'a': true,
|
||||
'b': true,
|
||||
'c': true,
|
||||
'd': true,
|
||||
'e': true,
|
||||
'f': true,
|
||||
'g': true,
|
||||
'h': true,
|
||||
'i': true,
|
||||
'j': true,
|
||||
'k': true,
|
||||
'l': true,
|
||||
'm': true,
|
||||
'n': true,
|
||||
'o': true,
|
||||
'p': true,
|
||||
'q': true,
|
||||
'r': true,
|
||||
's': true,
|
||||
't': true,
|
||||
'u': true,
|
||||
'v': true,
|
||||
'w': true,
|
||||
'x': true,
|
||||
'y': true,
|
||||
'z': true,
|
||||
'{': true,
|
||||
'|': true,
|
||||
'}': true,
|
||||
'~': true,
|
||||
'\u007f': true,
|
||||
}
|
||||
|
||||
// htmlSafeSet holds the value true if the ASCII character with the given
|
||||
// array position can be safely represented inside a JSON string, embedded
|
||||
// inside of HTML <script> tags, without any additional escaping.
|
||||
//
|
||||
// All values are true except for the ASCII control characters (0-31), the
|
||||
// double quote ("), the backslash character ("\"), HTML opening and closing
|
||||
// tags ("<" and ">"), and the ampersand ("&").
|
||||
var htmlSafeSet = [utf8.RuneSelf]bool{
|
||||
' ': true,
|
||||
'!': true,
|
||||
'"': false,
|
||||
'#': true,
|
||||
'$': true,
|
||||
'%': true,
|
||||
'&': false,
|
||||
'\'': true,
|
||||
'(': true,
|
||||
')': true,
|
||||
'*': true,
|
||||
'+': true,
|
||||
',': true,
|
||||
'-': true,
|
||||
'.': true,
|
||||
'/': true,
|
||||
'0': true,
|
||||
'1': true,
|
||||
'2': true,
|
||||
'3': true,
|
||||
'4': true,
|
||||
'5': true,
|
||||
'6': true,
|
||||
'7': true,
|
||||
'8': true,
|
||||
'9': true,
|
||||
':': true,
|
||||
';': true,
|
||||
'<': false,
|
||||
'=': true,
|
||||
'>': false,
|
||||
'?': true,
|
||||
'@': true,
|
||||
'A': true,
|
||||
'B': true,
|
||||
'C': true,
|
||||
'D': true,
|
||||
'E': true,
|
||||
'F': true,
|
||||
'G': true,
|
||||
'H': true,
|
||||
'I': true,
|
||||
'J': true,
|
||||
'K': true,
|
||||
'L': true,
|
||||
'M': true,
|
||||
'N': true,
|
||||
'O': true,
|
||||
'P': true,
|
||||
'Q': true,
|
||||
'R': true,
|
||||
'S': true,
|
||||
'T': true,
|
||||
'U': true,
|
||||
'V': true,
|
||||
'W': true,
|
||||
'X': true,
|
||||
'Y': true,
|
||||
'Z': true,
|
||||
'[': true,
|
||||
'\\': false,
|
||||
']': true,
|
||||
'^': true,
|
||||
'_': true,
|
||||
'`': true,
|
||||
'a': true,
|
||||
'b': true,
|
||||
'c': true,
|
||||
'd': true,
|
||||
'e': true,
|
||||
'f': true,
|
||||
'g': true,
|
||||
'h': true,
|
||||
'i': true,
|
||||
'j': true,
|
||||
'k': true,
|
||||
'l': true,
|
||||
'm': true,
|
||||
'n': true,
|
||||
'o': true,
|
||||
'p': true,
|
||||
'q': true,
|
||||
'r': true,
|
||||
's': true,
|
||||
't': true,
|
||||
'u': true,
|
||||
'v': true,
|
||||
'w': true,
|
||||
'x': true,
|
||||
'y': true,
|
||||
'z': true,
|
||||
'{': true,
|
||||
'|': true,
|
||||
'}': true,
|
||||
'~': true,
|
||||
'\u007f': true,
|
||||
}
|
38
common/json/internal/contextjson/tags.go
Normal file
38
common/json/internal/contextjson/tags.go
Normal file
|
@ -0,0 +1,38 @@
|
|||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package json
|
||||
|
||||
import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
// tagOptions is the string following a comma in a struct field's "json"
|
||||
// tag, or the empty string. It does not include the leading comma.
|
||||
type tagOptions string
|
||||
|
||||
// parseTag splits a struct field's json tag into its name and
|
||||
// comma-separated options.
|
||||
func parseTag(tag string) (string, tagOptions) {
|
||||
tag, opt, _ := strings.Cut(tag, ",")
|
||||
return tag, tagOptions(opt)
|
||||
}
|
||||
|
||||
// Contains reports whether a comma-separated list of options
|
||||
// contains a particular substr flag. substr must be surrounded by a
|
||||
// string boundary or commas.
|
||||
func (o tagOptions) Contains(optionName string) bool {
|
||||
if len(o) == 0 {
|
||||
return false
|
||||
}
|
||||
s := string(o)
|
||||
for s != "" {
|
||||
var name string
|
||||
name, s, _ = strings.Cut(s, ",")
|
||||
if name == optionName {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
21
common/json/std.go
Normal file
21
common/json/std.go
Normal file
|
@ -0,0 +1,21 @@
|
|||
//go:build !go1.20 || without_contextjson
|
||||
|
||||
package json
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
var (
|
||||
Marshal = json.Marshal
|
||||
Unmarshal = json.Unmarshal
|
||||
NewEncoder = json.NewEncoder
|
||||
NewDecoder = json.NewDecoder
|
||||
)
|
||||
|
||||
type (
|
||||
Encoder = json.Encoder
|
||||
Decoder = json.Decoder
|
||||
Token = json.Token
|
||||
Delim = json.Delim
|
||||
SyntaxError = json.SyntaxError
|
||||
RawMessage = json.RawMessage
|
||||
)
|
25
common/json/unmarshal.go
Normal file
25
common/json/unmarshal.go
Normal file
|
@ -0,0 +1,25 @@
|
|||
package json
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"strings"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
)
|
||||
|
||||
func UnmarshalExtended[T any](content []byte) (T, error) {
|
||||
decoder := NewDecoder(NewCommentFilter(bytes.NewReader(content)))
|
||||
var value T
|
||||
err := decoder.Decode(&value)
|
||||
if err == nil {
|
||||
return value, err
|
||||
}
|
||||
if syntaxError, isSyntaxError := err.(*SyntaxError); isSyntaxError {
|
||||
prefix := string(content[:syntaxError.Offset])
|
||||
row := strings.Count(prefix, "\n") + 1
|
||||
column := len(prefix) - strings.LastIndex(prefix, "\n") - 1
|
||||
return common.DefaultValue[T](), E.Extend(syntaxError, "row ", row, ", column ", column)
|
||||
}
|
||||
return common.DefaultValue[T](), err
|
||||
}
|
16
common/memory/memory.go
Normal file
16
common/memory/memory.go
Normal file
|
@ -0,0 +1,16 @@
|
|||
package memory
|
||||
|
||||
import "runtime"
|
||||
|
||||
func Total() uint64 {
|
||||
if nativeAvailable {
|
||||
return usageNative()
|
||||
}
|
||||
return Inuse()
|
||||
}
|
||||
|
||||
func Inuse() uint64 {
|
||||
var memStats runtime.MemStats
|
||||
runtime.ReadMemStats(&memStats)
|
||||
return memStats.StackInuse + memStats.HeapInuse + memStats.HeapIdle - memStats.HeapReleased
|
||||
}
|
18
common/memory/memory_darwin.go
Normal file
18
common/memory/memory_darwin.go
Normal file
|
@ -0,0 +1,18 @@
|
|||
package memory
|
||||
|
||||
// #include <mach/mach.h>
|
||||
import "C"
|
||||
import "unsafe"
|
||||
|
||||
const nativeAvailable = true
|
||||
|
||||
func usageNative() uint64 {
|
||||
var memoryUsageInByte uint64
|
||||
var vmInfo C.task_vm_info_data_t
|
||||
var count C.mach_msg_type_number_t = C.TASK_VM_INFO_COUNT
|
||||
var kernelReturn C.kern_return_t = C.task_info(C.vm_map_t(C.mach_task_self_), C.TASK_VM_INFO, (*C.integer_t)(unsafe.Pointer(&vmInfo)), &count)
|
||||
if kernelReturn == C.KERN_SUCCESS {
|
||||
memoryUsageInByte = uint64(vmInfo.phys_footprint)
|
||||
}
|
||||
return memoryUsageInByte
|
||||
}
|
9
common/memory/memory_stub.go
Normal file
9
common/memory/memory_stub.go
Normal file
|
@ -0,0 +1,9 @@
|
|||
//go:build (darwin && !cgo) || !darwin
|
||||
|
||||
package memory
|
||||
|
||||
const nativeAvailable = false
|
||||
|
||||
func usageNative() uint64 {
|
||||
return 0
|
||||
}
|
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue