Merge branch 'SagerNet:dev' into dev

This commit is contained in:
shij 2024-06-22 12:18:09 +08:00 committed by GitHub
commit 5d7890f308
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
155 changed files with 10203 additions and 1390 deletions

View file

@ -1,6 +1,7 @@
{
"$schema": "https://docs.renovatebot.com/renovate-schema.json",
"commitMessagePrefix": "[dependencies]",
"branchName": "main",
"extends": [
"config:base",
":disableRateLimiting"

View file

@ -3,6 +3,7 @@ name: Debug build
on:
push:
branches:
- main
- dev
paths-ignore:
- '**.md'
@ -10,33 +11,85 @@ on:
- '!.github/workflows/debug.yml'
pull_request:
branches:
- main
- dev
jobs:
build:
name: Debug build
name: Linux Debug build
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v3
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Get latest go version
id: version
run: |
echo ::set-output name=go_version::$(curl -s https://raw.githubusercontent.com/actions/go-versions/main/versions-manifest.json | grep -oE '"version": "[0-9]{1}.[0-9]{1,}(.[0-9]{1,})?"' | head -1 | cut -d':' -f2 | sed 's/ //g; s/"//g')
- name: Setup Go
uses: actions/setup-go@v4
with:
go-version: ${{ steps.version.outputs.go_version }}
- name: Add cache to Go proxy
go-version: ^1.22
- name: Build
run: |
version=`git rev-parse HEAD`
mkdir build
pushd build
go mod init build
go get -v github.com/sagernet/sing@$version
popd
make test
build_go120:
name: Linux Debug build (Go 1.20)
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Setup Go
uses: actions/setup-go@v4
with:
go-version: ~1.20
continue-on-error: true
- name: Build
run: |
make test
build_go121:
name: Linux Debug build (Go 1.21)
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Setup Go
uses: actions/setup-go@v4
with:
go-version: ~1.21
continue-on-error: true
- name: Build
run: |
make test
build__windows:
name: Windows Debug build
runs-on: windows-latest
steps:
- name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Setup Go
uses: actions/setup-go@v4
with:
go-version: ^1.22
continue-on-error: true
- name: Build
run: |
make test
build_darwin:
name: macOS Debug build
runs-on: macos-latest
steps:
- name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Setup Go
uses: actions/setup-go@v4
with:
go-version: ^1.22
continue-on-error: true
- name: Build
run: |

View file

@ -3,6 +3,7 @@ name: Lint
on:
push:
branches:
- main
- dev
paths-ignore:
- '**.md'
@ -10,6 +11,7 @@ on:
- '!.github/workflows/lint.yml'
pull_request:
branches:
- main
- dev
jobs:
@ -18,17 +20,13 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v3
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Get latest go version
id: version
run: |
echo ::set-output name=go_version::$(curl -s https://raw.githubusercontent.com/actions/go-versions/main/versions-manifest.json | grep -oE '"version": "[0-9]{1}.[0-9]{1,}(.[0-9]{1,})?"' | head -1 | cut -d':' -f2 | sed 's/ //g; s/"//g')
- name: Setup Go
uses: actions/setup-go@v4
with:
go-version: ${{ steps.version.outputs.go_version }}
go-version: ^1.22
- name: Cache go module
uses: actions/cache@v3
with:

1
.gitignore vendored
View file

@ -1,2 +1,3 @@
/.idea/
/vendor/
.DS_Store

View file

@ -18,4 +18,4 @@ lint_install:
go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest
test:
go test -v ./...
go test $(shell go list ./... | grep -v /internal/)

View file

@ -18,7 +18,6 @@ var _ xml.TokenReader = (*Reader)(nil)
type Reader struct {
reader *bytes.Reader
stringRefs []string
attrs []xml.Attr
}
func NewReader(content []byte) (xml.TokenReader, bool) {
@ -47,7 +46,7 @@ func (r *Reader) Token() (token xml.Token, err error) {
return
}
var attrs []xml.Attr
attrs, err = r.pullAttributes()
attrs, err = r.readAttributes()
if err != nil {
return
}
@ -93,35 +92,41 @@ func (r *Reader) Token() (token xml.Token, err error) {
_, err = r.readUTF()
return
case ATTRIBUTE:
return nil, E.New("unexpected attribute")
_, err = r.readAttribute()
return
}
return nil, E.New("unknown token type ", tokenType, " with type ", eventType)
}
func (r *Reader) pullAttributes() ([]xml.Attr, error) {
err := r.pullAttribute()
if err != nil {
return nil, err
func (r *Reader) readAttributes() ([]xml.Attr, error) {
var attrs []xml.Attr
for {
attr, err := r.readAttribute()
if err == io.EOF {
break
}
attrs = append(attrs, attr)
}
attrs := r.attrs
r.attrs = nil
return attrs, nil
}
func (r *Reader) pullAttribute() error {
func (r *Reader) readAttribute() (xml.Attr, error) {
event, err := r.reader.ReadByte()
if err != nil {
return nil
return xml.Attr{}, nil
}
tokenType := event & 0x0f
eventType := event & 0xf0
if tokenType != ATTRIBUTE {
return r.reader.UnreadByte()
}
var name string
name, err = r.readInternedUTF()
err = r.reader.UnreadByte()
if err != nil {
return err
return xml.Attr{}, nil
}
return xml.Attr{}, io.EOF
}
name, err := r.readInternedUTF()
if err != nil {
return xml.Attr{}, err
}
var value string
switch eventType {
@ -134,74 +139,73 @@ func (r *Reader) pullAttribute() error {
case TypeString:
value, err = r.readUTF()
if err != nil {
return err
return xml.Attr{}, err
}
case TypeStringInterned:
value, err = r.readInternedUTF()
if err != nil {
return err
return xml.Attr{}, err
}
case TypeBytesHex:
var data []byte
data, err = r.readBytes()
if err != nil {
return err
return xml.Attr{}, err
}
value = hex.EncodeToString(data)
case TypeBytesBase64:
var data []byte
data, err = r.readBytes()
if err != nil {
return err
return xml.Attr{}, err
}
value = base64.StdEncoding.EncodeToString(data)
case TypeInt:
var data int32
err = binary.Read(r.reader, binary.BigEndian, &data)
if err != nil {
return err
return xml.Attr{}, err
}
value = strconv.FormatInt(int64(data), 10)
case TypeIntHex:
var data int32
err = binary.Read(r.reader, binary.BigEndian, &data)
if err != nil {
return err
return xml.Attr{}, err
}
value = "0x" + strconv.FormatInt(int64(data), 16)
case TypeLong:
var data int64
err = binary.Read(r.reader, binary.BigEndian, &data)
if err != nil {
return err
return xml.Attr{}, err
}
value = strconv.FormatInt(data, 10)
case TypeLongHex:
var data int64
err = binary.Read(r.reader, binary.BigEndian, &data)
if err != nil {
return err
return xml.Attr{}, err
}
value = "0x" + strconv.FormatInt(data, 16)
case TypeFloat:
var data float32
err = binary.Read(r.reader, binary.BigEndian, &data)
if err != nil {
return err
return xml.Attr{}, err
}
value = strconv.FormatFloat(float64(data), 'g', -1, 32)
case TypeDouble:
var data float64
err = binary.Read(r.reader, binary.BigEndian, &data)
if err != nil {
return err
return xml.Attr{}, err
}
value = strconv.FormatFloat(data, 'g', -1, 64)
default:
return E.New("unexpected attribute type, ", eventType)
return xml.Attr{}, E.New("unexpected attribute type, ", eventType)
}
r.attrs = append(r.attrs, xml.Attr{Name: xml.Name{Local: name}, Value: value})
return r.pullAttribute()
return xml.Attr{Name: xml.Name{Local: name}, Value: value}, nil
}
func (r *Reader) readUnsignedShort() (uint16, error) {

View file

@ -10,26 +10,37 @@ type TypedValue[T any] struct {
value atomic.Value
}
// typedValue is a struct with determined type to resolve atomic.Value usages with interface types
// https://github.com/golang/go/issues/22550
//
// The intention to have an atomic value store for errors. However, running this code panics:
// panic: sync/atomic: store of inconsistently typed value into Value
// This is because atomic.Value requires that the underlying concrete type be the same (which is a reasonable expectation for its implementation).
// When going through the atomic.Value.Store method call, the fact that both these are of the error interface is lost.
type typedValue[T any] struct {
value T
}
func (t *TypedValue[T]) Load() T {
value := t.value.Load()
if value == nil {
return common.DefaultValue[T]()
}
return value.(T)
return value.(typedValue[T]).value
}
func (t *TypedValue[T]) Store(value T) {
t.value.Store(value)
t.value.Store(typedValue[T]{value})
}
func (t *TypedValue[T]) Swap(new T) T {
old := t.value.Swap(new)
old := t.value.Swap(typedValue[T]{new})
if old == nil {
return common.DefaultValue[T]()
}
return old.(T)
return old.(typedValue[T]).value
}
func (t *TypedValue[T]) CompareAndSwap(old, new T) bool {
return t.value.CompareAndSwap(old, new)
return t.value.CompareAndSwap(typedValue[T]{old}, typedValue[T]{new})
}

View file

@ -1,38 +1,30 @@
package auth
type Authenticator interface {
Verify(user string, pass string) bool
Users() []string
}
import "github.com/sagernet/sing/common"
type User struct {
Username string `json:"username"`
Password string `json:"password"`
Username string
Password string
}
type inMemoryAuthenticator struct {
storage map[string]string
usernames []string
type Authenticator struct {
userMap map[string][]string
}
func (au *inMemoryAuthenticator) Verify(username string, password string) bool {
realPass, ok := au.storage[username]
return ok && realPass == password
}
func (au *inMemoryAuthenticator) Users() []string { return au.usernames }
func NewAuthenticator(users []User) Authenticator {
func NewAuthenticator(users []User) *Authenticator {
if len(users) == 0 {
return nil
}
au := &inMemoryAuthenticator{
storage: make(map[string]string),
usernames: make([]string, 0, len(users)),
au := &Authenticator{
userMap: make(map[string][]string),
}
for _, user := range users {
au.storage[user.Username] = user.Password
au.usernames = append(au.usernames, user.Username)
au.userMap[user.Username] = append(au.userMap[user.Username], user.Password)
}
return au
}
func (au *Authenticator) Verify(username string, password string) bool {
passwordList, ok := au.userMap[username]
return ok && common.Contains(passwordList, password)
}

View file

@ -55,7 +55,10 @@ func WrapQUIC(err error) error {
if err == nil {
return nil
}
if Contains(err, "canceled with error code 0") {
if Contains(err,
"canceled by remote with error code 0",
"canceled by local with error code 0",
) {
return net.ErrClosed
}
return err

3
common/binary/README.md Normal file
View file

@ -0,0 +1,3 @@
# binary
mod from go 1.22.3

817
common/binary/binary.go Normal file
View file

@ -0,0 +1,817 @@
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package binary implements simple translation between numbers and byte
// sequences and encoding and decoding of varints.
//
// Numbers are translated by reading and writing fixed-size values.
// A fixed-size value is either a fixed-size arithmetic
// type (bool, int8, uint8, int16, float32, complex64, ...)
// or an array or struct containing only fixed-size values.
//
// The varint functions encode and decode single integer values using
// a variable-length encoding; smaller values require fewer bytes.
// For a specification, see
// https://developers.google.com/protocol-buffers/docs/encoding.
//
// This package favors simplicity over efficiency. Clients that require
// high-performance serialization, especially for large data structures,
// should look at more advanced solutions such as the [encoding/gob]
// package or [google.golang.org/protobuf] for protocol buffers.
package binary
import (
"errors"
"io"
"math"
"reflect"
"sync"
)
// A ByteOrder specifies how to convert byte slices into
// 16-, 32-, or 64-bit unsigned integers.
//
// It is implemented by [LittleEndian], [BigEndian], and [NativeEndian].
//
// NOTE: this package is a vendored copy of the standard library's
// encoding/binary ("mod from go 1.22.3" per the package README), so the
// interface mirrors the stdlib definition exactly.
type ByteOrder interface {
	Uint16([]byte) uint16
	Uint32([]byte) uint32
	Uint64([]byte) uint64
	PutUint16([]byte, uint16)
	PutUint32([]byte, uint32)
	PutUint64([]byte, uint64)
	String() string
}

// AppendByteOrder specifies how to append 16-, 32-, or 64-bit unsigned integers
// into a byte slice.
//
// It is implemented by [LittleEndian], [BigEndian], and [NativeEndian].
type AppendByteOrder interface {
	AppendUint16([]byte, uint16) []byte
	AppendUint32([]byte, uint32) []byte
	AppendUint64([]byte, uint64) []byte
	String() string
}

// LittleEndian is the little-endian implementation of [ByteOrder] and [AppendByteOrder].
var LittleEndian littleEndian

// BigEndian is the big-endian implementation of [ByteOrder] and [AppendByteOrder].
var BigEndian bigEndian
// littleEndian implements [ByteOrder] and [AppendByteOrder] with the
// least-significant byte stored first.
type littleEndian struct{}

func (littleEndian) Uint16(buf []byte) uint16 {
	_ = buf[1] // bounds check hint to compiler; see golang.org/issue/14808
	return uint16(buf[1])<<8 | uint16(buf[0])
}

func (littleEndian) PutUint16(buf []byte, x uint16) {
	_ = buf[1] // early bounds check to guarantee safety of writes below
	buf[1] = byte(x >> 8)
	buf[0] = byte(x)
}

func (littleEndian) AppendUint16(buf []byte, x uint16) []byte {
	return append(buf, byte(x), byte(x>>8))
}

func (littleEndian) Uint32(buf []byte) uint32 {
	_ = buf[3] // bounds check hint to compiler; see golang.org/issue/14808
	return uint32(buf[3])<<24 | uint32(buf[2])<<16 | uint32(buf[1])<<8 | uint32(buf[0])
}

func (littleEndian) PutUint32(buf []byte, x uint32) {
	_ = buf[3] // early bounds check to guarantee safety of writes below
	buf[3] = byte(x >> 24)
	buf[2] = byte(x >> 16)
	buf[1] = byte(x >> 8)
	buf[0] = byte(x)
}

func (littleEndian) AppendUint32(buf []byte, x uint32) []byte {
	return append(buf, byte(x), byte(x>>8), byte(x>>16), byte(x>>24))
}

func (littleEndian) Uint64(buf []byte) uint64 {
	_ = buf[7] // bounds check hint to compiler; see golang.org/issue/14808
	return uint64(buf[7])<<56 | uint64(buf[6])<<48 | uint64(buf[5])<<40 | uint64(buf[4])<<32 |
		uint64(buf[3])<<24 | uint64(buf[2])<<16 | uint64(buf[1])<<8 | uint64(buf[0])
}

func (littleEndian) PutUint64(buf []byte, x uint64) {
	_ = buf[7] // early bounds check to guarantee safety of writes below
	buf[7] = byte(x >> 56)
	buf[6] = byte(x >> 48)
	buf[5] = byte(x >> 40)
	buf[4] = byte(x >> 32)
	buf[3] = byte(x >> 24)
	buf[2] = byte(x >> 16)
	buf[1] = byte(x >> 8)
	buf[0] = byte(x)
}

func (littleEndian) AppendUint64(buf []byte, x uint64) []byte {
	return append(buf,
		byte(x), byte(x>>8), byte(x>>16), byte(x>>24),
		byte(x>>32), byte(x>>40), byte(x>>48), byte(x>>56),
	)
}

func (littleEndian) String() string { return "LittleEndian" }

func (littleEndian) GoString() string { return "binary.LittleEndian" }
// bigEndian implements [ByteOrder] and [AppendByteOrder] with the
// most-significant byte stored first.
type bigEndian struct{}

func (bigEndian) Uint16(buf []byte) uint16 {
	_ = buf[1] // bounds check hint to compiler; see golang.org/issue/14808
	return uint16(buf[0])<<8 | uint16(buf[1])
}

func (bigEndian) PutUint16(buf []byte, x uint16) {
	_ = buf[1] // early bounds check to guarantee safety of writes below
	buf[1] = byte(x)
	buf[0] = byte(x >> 8)
}

func (bigEndian) AppendUint16(buf []byte, x uint16) []byte {
	return append(buf, byte(x>>8), byte(x))
}

func (bigEndian) Uint32(buf []byte) uint32 {
	_ = buf[3] // bounds check hint to compiler; see golang.org/issue/14808
	return uint32(buf[0])<<24 | uint32(buf[1])<<16 | uint32(buf[2])<<8 | uint32(buf[3])
}

func (bigEndian) PutUint32(buf []byte, x uint32) {
	_ = buf[3] // early bounds check to guarantee safety of writes below
	buf[3] = byte(x)
	buf[2] = byte(x >> 8)
	buf[1] = byte(x >> 16)
	buf[0] = byte(x >> 24)
}

func (bigEndian) AppendUint32(buf []byte, x uint32) []byte {
	return append(buf, byte(x>>24), byte(x>>16), byte(x>>8), byte(x))
}

func (bigEndian) Uint64(buf []byte) uint64 {
	_ = buf[7] // bounds check hint to compiler; see golang.org/issue/14808
	return uint64(buf[0])<<56 | uint64(buf[1])<<48 | uint64(buf[2])<<40 | uint64(buf[3])<<32 |
		uint64(buf[4])<<24 | uint64(buf[5])<<16 | uint64(buf[6])<<8 | uint64(buf[7])
}

func (bigEndian) PutUint64(buf []byte, x uint64) {
	_ = buf[7] // early bounds check to guarantee safety of writes below
	buf[7] = byte(x)
	buf[6] = byte(x >> 8)
	buf[5] = byte(x >> 16)
	buf[4] = byte(x >> 24)
	buf[3] = byte(x >> 32)
	buf[2] = byte(x >> 40)
	buf[1] = byte(x >> 48)
	buf[0] = byte(x >> 56)
}

func (bigEndian) AppendUint64(buf []byte, x uint64) []byte {
	return append(buf,
		byte(x>>56), byte(x>>48), byte(x>>40), byte(x>>32),
		byte(x>>24), byte(x>>16), byte(x>>8), byte(x),
	)
}

func (bigEndian) String() string { return "BigEndian" }

func (bigEndian) GoString() string { return "binary.BigEndian" }
// String and GoString identify the platform's native byte order.
// The nativeEndian type itself is declared in the per-architecture
// build-tagged files (embedding littleEndian or bigEndian).
func (nativeEndian) String() string { return "NativeEndian" }

func (nativeEndian) GoString() string { return "binary.NativeEndian" }
// Read reads structured binary data from r into data.
// Data must be a pointer to a fixed-size value or a slice
// of fixed-size values.
// Bytes read from r are decoded using the specified byte order
// and written to successive fields of the data.
// When decoding boolean values, a zero byte is decoded as false, and
// any other non-zero byte is decoded as true.
// When reading into structs, the field data for fields with
// blank (_) field names is skipped; i.e., blank field names
// may be used for padding.
// When reading into a struct, all non-blank fields must be exported
// or Read may panic.
//
// The error is [io.EOF] only if no bytes were read.
// If an [io.EOF] happens after reading some but not all the bytes,
// Read returns [io.ErrUnexpectedEOF].
func Read(r io.Reader, order ByteOrder, data any) error {
	// Fast path for basic types and slices: intDataSize reports a non-zero
	// size exactly for the types handled by the switch below, so the whole
	// payload can be read in one ReadFull and decoded without reflection.
	if n := intDataSize(data); n != 0 {
		bs := make([]byte, n)
		if _, err := io.ReadFull(r, bs); err != nil {
			return err
		}
		switch data := data.(type) {
		case *bool:
			*data = bs[0] != 0
		case *int8:
			*data = int8(bs[0])
		case *uint8:
			*data = bs[0]
		case *int16:
			*data = int16(order.Uint16(bs))
		case *uint16:
			*data = order.Uint16(bs)
		case *int32:
			*data = int32(order.Uint32(bs))
		case *uint32:
			*data = order.Uint32(bs)
		case *int64:
			*data = int64(order.Uint64(bs))
		case *uint64:
			*data = order.Uint64(bs)
		case *float32:
			*data = math.Float32frombits(order.Uint32(bs))
		case *float64:
			*data = math.Float64frombits(order.Uint64(bs))
		case []bool:
			for i, x := range bs { // Easier to loop over the input for 8-bit values.
				data[i] = x != 0
			}
		case []int8:
			for i, x := range bs {
				data[i] = int8(x)
			}
		case []uint8:
			copy(data, bs)
		case []int16:
			for i := range data {
				data[i] = int16(order.Uint16(bs[2*i:]))
			}
		case []uint16:
			for i := range data {
				data[i] = order.Uint16(bs[2*i:])
			}
		case []int32:
			for i := range data {
				data[i] = int32(order.Uint32(bs[4*i:]))
			}
		case []uint32:
			for i := range data {
				data[i] = order.Uint32(bs[4*i:])
			}
		case []int64:
			for i := range data {
				data[i] = int64(order.Uint64(bs[8*i:]))
			}
		case []uint64:
			for i := range data {
				data[i] = order.Uint64(bs[8*i:])
			}
		case []float32:
			for i := range data {
				data[i] = math.Float32frombits(order.Uint32(bs[4*i:]))
			}
		case []float64:
			for i := range data {
				data[i] = math.Float64frombits(order.Uint64(bs[8*i:]))
			}
		default:
			n = 0 // fast path doesn't apply
		}
		// n is zeroed by the default case as a sentinel; when that happens,
		// fall through to the reflect-based decoder below.
		if n != 0 {
			return nil
		}
	}

	// Fallback to reflect-based decoding.
	v := reflect.ValueOf(data)
	size := -1
	switch v.Kind() {
	case reflect.Pointer:
		v = v.Elem()
		size = dataSize(v)
	case reflect.Slice:
		size = dataSize(v)
	}
	if size < 0 {
		return errors.New("binary.Read: invalid type " + reflect.TypeOf(data).String())
	}
	d := &decoder{order: order, buf: make([]byte, size)}
	if _, err := io.ReadFull(r, d.buf); err != nil {
		return err
	}
	d.value(v)
	return nil
}
// Write writes the binary representation of data into w.
// Data must be a fixed-size value or a slice of fixed-size
// values, or a pointer to such data.
// Boolean values encode as one byte: 1 for true, and 0 for false.
// Bytes written to w are encoded using the specified byte order
// and read from successive fields of the data.
// When writing structs, zero values are written for fields
// with blank (_) field names.
func Write(w io.Writer, order ByteOrder, data any) error {
	// Fast path for basic types and slices: intDataSize reports a non-zero
	// size exactly for the types handled below, so the value can be encoded
	// into one scratch buffer and written with a single w.Write call.
	if n := intDataSize(data); n != 0 {
		bs := make([]byte, n)
		switch v := data.(type) {
		case *bool:
			if *v {
				bs[0] = 1
			} else {
				bs[0] = 0
			}
		case bool:
			if v {
				bs[0] = 1
			} else {
				bs[0] = 0
			}
		case []bool:
			for i, x := range v {
				if x {
					bs[i] = 1
				} else {
					bs[i] = 0
				}
			}
		case *int8:
			bs[0] = byte(*v)
		case int8:
			bs[0] = byte(v)
		case []int8:
			for i, x := range v {
				bs[i] = byte(x)
			}
		case *uint8:
			bs[0] = *v
		case uint8:
			bs[0] = v
		case []uint8:
			bs = v // write the caller's bytes directly; no copy needed
		case *int16:
			order.PutUint16(bs, uint16(*v))
		case int16:
			order.PutUint16(bs, uint16(v))
		case []int16:
			for i, x := range v {
				order.PutUint16(bs[2*i:], uint16(x))
			}
		case *uint16:
			order.PutUint16(bs, *v)
		case uint16:
			order.PutUint16(bs, v)
		case []uint16:
			for i, x := range v {
				order.PutUint16(bs[2*i:], x)
			}
		case *int32:
			order.PutUint32(bs, uint32(*v))
		case int32:
			order.PutUint32(bs, uint32(v))
		case []int32:
			for i, x := range v {
				order.PutUint32(bs[4*i:], uint32(x))
			}
		case *uint32:
			order.PutUint32(bs, *v)
		case uint32:
			order.PutUint32(bs, v)
		case []uint32:
			for i, x := range v {
				order.PutUint32(bs[4*i:], x)
			}
		case *int64:
			order.PutUint64(bs, uint64(*v))
		case int64:
			order.PutUint64(bs, uint64(v))
		case []int64:
			for i, x := range v {
				order.PutUint64(bs[8*i:], uint64(x))
			}
		case *uint64:
			order.PutUint64(bs, *v)
		case uint64:
			order.PutUint64(bs, v)
		case []uint64:
			for i, x := range v {
				order.PutUint64(bs[8*i:], x)
			}
		case *float32:
			order.PutUint32(bs, math.Float32bits(*v))
		case float32:
			order.PutUint32(bs, math.Float32bits(v))
		case []float32:
			for i, x := range v {
				order.PutUint32(bs[4*i:], math.Float32bits(x))
			}
		case *float64:
			order.PutUint64(bs, math.Float64bits(*v))
		case float64:
			order.PutUint64(bs, math.Float64bits(v))
		case []float64:
			for i, x := range v {
				order.PutUint64(bs[8*i:], math.Float64bits(x))
			}
		}
		_, err := w.Write(bs)
		return err
	}

	// Fallback to reflect-based encoding.
	v := reflect.Indirect(reflect.ValueOf(data))
	size := dataSize(v)
	if size < 0 {
		return errors.New("binary.Write: some values are not fixed-sized in type " + reflect.TypeOf(data).String())
	}
	buf := make([]byte, size)
	e := &encoder{order: order, buf: buf}
	e.value(v)
	_, err := w.Write(buf)
	return err
}
// Size returns how many bytes [Write] would generate to encode the value v,
// which must be a fixed-size value or a slice of fixed-size values, or a
// pointer to such data.
// If v is neither of these, Size returns -1.
func Size(v any) int {
	value := reflect.Indirect(reflect.ValueOf(v))
	return dataSize(value)
}
// structSize caches the computed encoded size per struct type, since
// sizeof over a struct walks every field.
var structSize sync.Map // map[reflect.Type]int

// dataSize returns the number of bytes the actual data represented by v occupies in memory.
// For compound structures, it sums the sizes of the elements. Thus, for instance, for a slice
// it returns the length of the slice times the element size and does not count the memory
// occupied by the header. If the type of v is not acceptable, dataSize returns -1.
func dataSize(v reflect.Value) int {
	kind := v.Kind()
	if kind == reflect.Slice {
		elemSize := sizeof(v.Type().Elem())
		if elemSize < 0 {
			return -1
		}
		return elemSize * v.Len()
	}
	if kind == reflect.Struct {
		t := v.Type()
		if cached, ok := structSize.Load(t); ok {
			return cached.(int)
		}
		size := sizeof(t)
		structSize.Store(t, size)
		return size
	}
	if v.IsValid() {
		return sizeof(v.Type())
	}
	return -1
}
// sizeof returns the size >= 0 of variables for the given type or -1 if the type is not acceptable.
func sizeof(t reflect.Type) int {
switch t.Kind() {
case reflect.Array:
if s := sizeof(t.Elem()); s >= 0 {
return s * t.Len()
}
case reflect.Struct:
sum := 0
for i, n := 0, t.NumField(); i < n; i++ {
s := sizeof(t.Field(i).Type)
if s < 0 {
return -1
}
sum += s
}
return sum
case reflect.Bool,
reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128:
return int(t.Size())
}
return -1
}
// coder holds the shared state for the in-memory encoder and decoder:
// the byte order, the backing buffer (sized up front by dataSize), and
// the current cursor offset into it.
type coder struct {
	order  ByteOrder
	buf    []byte
	offset int
}

type (
	decoder coder
	encoder coder
)

// bool decodes one byte at the cursor; any non-zero byte is true.
func (d *decoder) bool() bool {
	x := d.buf[d.offset]
	d.offset++
	return x != 0
}

// bool encodes x as a single byte (1 for true, 0 for false).
func (e *encoder) bool(x bool) {
	if x {
		e.buf[e.offset] = 1
	} else {
		e.buf[e.offset] = 0
	}
	e.offset++
}

func (d *decoder) uint8() uint8 {
	x := d.buf[d.offset]
	d.offset++
	return x
}

func (e *encoder) uint8(x uint8) {
	e.buf[e.offset] = x
	e.offset++
}

func (d *decoder) uint16() uint16 {
	x := d.order.Uint16(d.buf[d.offset : d.offset+2])
	d.offset += 2
	return x
}

func (e *encoder) uint16(x uint16) {
	e.order.PutUint16(e.buf[e.offset:e.offset+2], x)
	e.offset += 2
}

func (d *decoder) uint32() uint32 {
	x := d.order.Uint32(d.buf[d.offset : d.offset+4])
	d.offset += 4
	return x
}

func (e *encoder) uint32(x uint32) {
	e.order.PutUint32(e.buf[e.offset:e.offset+4], x)
	e.offset += 4
}

func (d *decoder) uint64() uint64 {
	x := d.order.Uint64(d.buf[d.offset : d.offset+8])
	d.offset += 8
	return x
}

func (e *encoder) uint64(x uint64) {
	e.order.PutUint64(e.buf[e.offset:e.offset+8], x)
	e.offset += 8
}

// Signed variants are thin bit-casting wrappers over the unsigned ones.
func (d *decoder) int8() int8 { return int8(d.uint8()) }

func (e *encoder) int8(x int8) { e.uint8(uint8(x)) }

func (d *decoder) int16() int16 { return int16(d.uint16()) }

func (e *encoder) int16(x int16) { e.uint16(uint16(x)) }

func (d *decoder) int32() int32 { return int32(d.uint32()) }

func (e *encoder) int32(x int32) { e.uint32(uint32(x)) }

func (d *decoder) int64() int64 { return int64(d.uint64()) }

func (e *encoder) int64(x int64) { e.uint64(uint64(x)) }
// value decodes bytes at the cursor into v, recursing through arrays,
// structs, and slices down to the fixed-size scalar kinds. The buffer
// was sized by dataSize, so no bounds errors are expected here.
func (d *decoder) value(v reflect.Value) {
	switch v.Kind() {
	case reflect.Array:
		l := v.Len()
		for i := 0; i < l; i++ {
			d.value(v.Index(i))
		}
	case reflect.Struct:
		t := v.Type()
		l := v.NumField()
		for i := 0; i < l; i++ {
			// Note: Calling v.CanSet() below is an optimization.
			// It would be sufficient to check the field name,
			// but creating the StructField info for each field is
			// costly (run "go test -bench=ReadStruct" and compare
			// results when making changes to this code).
			if v := v.Field(i); v.CanSet() || t.Field(i).Name != "_" {
				d.value(v)
			} else {
				// Blank (_) fields are skipped but still consume
				// their encoded size, so they act as padding.
				d.skip(v)
			}
		}
	case reflect.Slice:
		l := v.Len()
		for i := 0; i < l; i++ {
			d.value(v.Index(i))
		}
	case reflect.Bool:
		v.SetBool(d.bool())
	case reflect.Int8:
		v.SetInt(int64(d.int8()))
	case reflect.Int16:
		v.SetInt(int64(d.int16()))
	case reflect.Int32:
		v.SetInt(int64(d.int32()))
	case reflect.Int64:
		v.SetInt(d.int64())
	case reflect.Uint8:
		v.SetUint(uint64(d.uint8()))
	case reflect.Uint16:
		v.SetUint(uint64(d.uint16()))
	case reflect.Uint32:
		v.SetUint(uint64(d.uint32()))
	case reflect.Uint64:
		v.SetUint(d.uint64())
	case reflect.Float32:
		v.SetFloat(float64(math.Float32frombits(d.uint32())))
	case reflect.Float64:
		v.SetFloat(math.Float64frombits(d.uint64()))
	case reflect.Complex64:
		// Complex values are encoded as real part then imaginary part.
		v.SetComplex(complex(
			float64(math.Float32frombits(d.uint32())),
			float64(math.Float32frombits(d.uint32())),
		))
	case reflect.Complex128:
		v.SetComplex(complex(
			math.Float64frombits(d.uint64()),
			math.Float64frombits(d.uint64()),
		))
	}
}
// value encodes v into the buffer at the cursor, mirroring
// (*decoder).value: it recurses through arrays, structs, and slices
// down to the fixed-size scalar kinds.
func (e *encoder) value(v reflect.Value) {
	switch v.Kind() {
	case reflect.Array:
		l := v.Len()
		for i := 0; i < l; i++ {
			e.value(v.Index(i))
		}
	case reflect.Struct:
		t := v.Type()
		l := v.NumField()
		for i := 0; i < l; i++ {
			// see comment for corresponding code in decoder.value()
			if v := v.Field(i); v.CanSet() || t.Field(i).Name != "_" {
				e.value(v)
			} else {
				// Blank (_) fields are written as zero bytes.
				e.skip(v)
			}
		}
	case reflect.Slice:
		l := v.Len()
		for i := 0; i < l; i++ {
			e.value(v.Index(i))
		}
	case reflect.Bool:
		e.bool(v.Bool())
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		// The outer case matches the broad integer kinds; the inner switch
		// picks the concrete width. reflect.Int itself has no inner case
		// because dataSize/sizeof reject it before encoding starts.
		switch v.Type().Kind() {
		case reflect.Int8:
			e.int8(int8(v.Int()))
		case reflect.Int16:
			e.int16(int16(v.Int()))
		case reflect.Int32:
			e.int32(int32(v.Int()))
		case reflect.Int64:
			e.int64(v.Int())
		}
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		switch v.Type().Kind() {
		case reflect.Uint8:
			e.uint8(uint8(v.Uint()))
		case reflect.Uint16:
			e.uint16(uint16(v.Uint()))
		case reflect.Uint32:
			e.uint32(uint32(v.Uint()))
		case reflect.Uint64:
			e.uint64(v.Uint())
		}
	case reflect.Float32, reflect.Float64:
		switch v.Type().Kind() {
		case reflect.Float32:
			e.uint32(math.Float32bits(float32(v.Float())))
		case reflect.Float64:
			e.uint64(math.Float64bits(v.Float()))
		}
	case reflect.Complex64, reflect.Complex128:
		// Complex values are encoded as real part then imaginary part.
		switch v.Type().Kind() {
		case reflect.Complex64:
			x := v.Complex()
			e.uint32(math.Float32bits(float32(real(x))))
			e.uint32(math.Float32bits(float32(imag(x))))
		case reflect.Complex128:
			x := v.Complex()
			e.uint64(math.Float64bits(real(x)))
			e.uint64(math.Float64bits(imag(x)))
		}
	}
}
// skip advances the read cursor past the encoded size of v without
// decoding it (used for blank "_" struct fields, which act as padding).
func (d *decoder) skip(v reflect.Value) {
	d.offset += dataSize(v)
}

// skip writes dataSize(v) zero bytes for a blank "_" struct field and
// advances the cursor.
func (e *encoder) skip(v reflect.Value) {
	n := dataSize(v)
	// Use the clear builtin (Go 1.21+; this vendored copy targets 1.22.3)
	// instead of a manual zeroing loop — it lowers to memclr.
	clear(e.buf[e.offset : e.offset+n])
	e.offset += n
}
// intDataSize returns the size of the data required to represent the data when encoded.
// It returns zero if the type cannot be implemented by the fast path in Read or Write.
func intDataSize(data any) int {
	switch data := data.(type) {
	// Scalars and pointers to scalars, grouped by encoded width.
	case bool, int8, uint8, *bool, *int8, *uint8:
		return 1
	case int16, uint16, *int16, *uint16:
		return 2
	case int32, uint32, *int32, *uint32, float32, *float32:
		return 4
	case int64, uint64, *int64, *uint64, float64, *float64:
		return 8
	// Slices: encoded width times element count. Each case must stand
	// alone so len sees the concretely typed value.
	case []bool:
		return len(data)
	case []int8:
		return len(data)
	case []uint8:
		return len(data)
	case []int16:
		return 2 * len(data)
	case []uint16:
		return 2 * len(data)
	case []int32:
		return 4 * len(data)
	case []uint32:
		return 4 * len(data)
	case []float32:
		return 4 * len(data)
	case []int64:
		return 8 * len(data)
	case []uint64:
		return 8 * len(data)
	case []float64:
		return 8 * len(data)
	}
	return 0
}

View file

@ -0,0 +1,14 @@
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// This file is compiled only on big-endian targets; the little-endian
// counterpart defines the same type embedding littleEndian instead.
//go:build armbe || arm64be || m68k || mips || mips64 || mips64p32 || ppc || ppc64 || s390 || s390x || shbe || sparc || sparc64

package binary

// nativeEndian selects the big-endian implementation on these targets.
type nativeEndian struct {
	bigEndian
}

// NativeEndian is the native-endian implementation of [ByteOrder] and [AppendByteOrder].
var NativeEndian nativeEndian

View file

@ -0,0 +1,14 @@
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// This file is compiled only on little-endian targets; the big-endian
// counterpart defines the same type embedding bigEndian instead.
//go:build 386 || amd64 || amd64p32 || alpha || arm || arm64 || loong64 || mipsle || mips64le || mips64p32le || nios2 || ppc64le || riscv || riscv64 || sh || wasm

package binary

// nativeEndian selects the little-endian implementation on these targets.
type nativeEndian struct {
	littleEndian
}

// NativeEndian is the native-endian implementation of [ByteOrder] and [AppendByteOrder].
var NativeEndian nativeEndian

View file

@ -0,0 +1,305 @@
package binary
import (
"bufio"
"errors"
"io"
"reflect"
E "github.com/sagernet/sing/common/exceptions"
)
// ReadDataSlice decodes each item in data in order via ReadData,
// wrapping any failure with the index of the offending item.
func ReadDataSlice(r *bufio.Reader, order ByteOrder, data ...any) error {
	for i := 0; i < len(data); i++ {
		if err := ReadData(r, order, data[i]); err != nil {
			return E.Cause(err, "[", i, "]")
		}
	}
	return nil
}
// ReadData decodes a single value from r: byte slices get a uvarint
// length prefix followed by the raw bytes, fixed-size scalars go through
// the fast path in Read, and everything else is decoded reflectively by
// readData.
func ReadData(r *bufio.Reader, order ByteOrder, data any) error {
	switch dataPtr := data.(type) {
	case *[]uint8:
		bytesLen, err := ReadUvarint(r)
		if err != nil {
			return E.Cause(err, "bytes length")
		}
		newBytes := make([]uint8, bytesLen)
		_, err = io.ReadFull(r, newBytes)
		if err != nil {
			return E.Cause(err, "bytes value")
		}
		*dataPtr = newBytes
		// Fix: must return here. Falling through to readData below would
		// decode the same pointer again and consume a second, spurious
		// uvarint length prefix from the stream.
		return nil
	default:
		if intBaseDataSize(data) != 0 {
			return Read(r, order, data)
		}
	}
	dataValue := reflect.ValueOf(data)
	if dataValue.Kind() == reflect.Pointer {
		dataValue = dataValue.Elem()
	}
	return readData(r, order, dataValue)
}
// readData is the reflection-driven decoder behind [ReadData]. It consumes
// the wire format produced by writeData:
//   - pointers: a presence byte (0 = nil) followed by the pointee
//   - strings, slices, maps: a uvarint length followed by the contents
//   - arrays and structs: elements/fields in declaration order, no header
//   - anything else: fixed-size binary handled by the package decoder
//
// data must be addressable (Set/SetString/SetLen are called on it).
func readData(r *bufio.Reader, order ByteOrder, data reflect.Value) error {
	switch data.Kind() {
	case reflect.Pointer:
		// One presence byte distinguishes nil from a present pointee.
		pointerValue, err := r.ReadByte()
		if err != nil {
			return err
		}
		if pointerValue == 0 {
			data.SetZero()
			return nil
		}
		if data.IsNil() {
			data.Set(reflect.New(data.Type().Elem()))
		}
		return readData(r, order, data.Elem())
	case reflect.String:
		stringLength, err := ReadUvarint(r)
		if err != nil {
			return E.Cause(err, "string length")
		}
		if stringLength == 0 {
			data.SetZero()
		} else {
			stringData := make([]byte, stringLength)
			_, err = io.ReadFull(r, stringData)
			if err != nil {
				return E.Cause(err, "string value")
			}
			data.SetString(string(stringData))
		}
	case reflect.Array:
		// Arrays have a fixed length known from the type, so no length
		// prefix is read; elements are decoded in order.
		arrayLen := data.Len()
		for i := 0; i < arrayLen; i++ {
			err := readData(r, order, data.Index(i))
			if err != nil {
				return E.Cause(err, "[", i, "]")
			}
		}
	case reflect.Slice:
		sliceLength, err := ReadUvarint(r)
		if err != nil {
			return E.Cause(err, "slice length")
		}
		// Reuse the existing backing array when it is large enough;
		// otherwise allocate a fresh slice of the decoded length.
		if !data.IsNil() && data.Cap() >= int(sliceLength) {
			data.SetLen(int(sliceLength))
		} else if sliceLength > 0 {
			data.Set(reflect.MakeSlice(data.Type(), int(sliceLength), int(sliceLength)))
		}
		if sliceLength > 0 {
			if data.Type().Elem().Kind() == reflect.Uint8 {
				// Byte slices are read in one bulk operation.
				_, err = io.ReadFull(r, data.Bytes())
				if err != nil {
					return E.Cause(err, "bytes value")
				}
			} else {
				for index := 0; index < int(sliceLength); index++ {
					err = readData(r, order, data.Index(index))
					if err != nil {
						return E.Cause(err, "[", index, "]")
					}
				}
			}
		}
	case reflect.Map:
		mapLength, err := ReadUvarint(r)
		if err != nil {
			return E.Cause(err, "map length")
		}
		// Any existing map contents are discarded: decoding always starts
		// from a fresh map.
		data.Set(reflect.MakeMap(data.Type()))
		for index := 0; index < int(mapLength); index++ {
			key := reflect.New(data.Type().Key()).Elem()
			err = readData(r, order, key)
			if err != nil {
				return E.Cause(err, "[", index, "].key")
			}
			value := reflect.New(data.Type().Elem()).Elem()
			err = readData(r, order, value)
			if err != nil {
				return E.Cause(err, "[", index, "].value")
			}
			data.SetMapIndex(key, value)
		}
	case reflect.Struct:
		fieldType := data.Type()
		fieldLen := data.NumField()
		for i := 0; i < fieldLen; i++ {
			field := data.Field(i)
			fieldName := fieldType.Field(i).Name
			// Skips only non-settable fields named "_" (blank padding),
			// mirroring encoding/binary's condition.
			// NOTE(review): an unexported field that is not named "_" is
			// still decoded and Set on it will panic in reflect — confirm
			// callers only decode structs with exported fields.
			if field.CanSet() || fieldName != "_" {
				err := readData(r, order, field)
				if err != nil {
					return E.Cause(err, fieldName)
				}
			}
		}
	default:
		// Fixed-size scalar: read exactly dataSize bytes and let the
		// package decoder fill the value using the requested byte order.
		size := dataSize(data)
		if size < 0 {
			return errors.New("invalid type " + reflect.TypeOf(data).String())
		}
		d := &decoder{order: order, buf: make([]byte, size)}
		_, err := io.ReadFull(r, d.buf)
		if err != nil {
			return err
		}
		d.value(data)
	}
	return nil
}
// WriteDataSlice encodes each element of data in order to writer,
// annotating any failure with the failing element's position.
func WriteDataSlice(writer *bufio.Writer, order ByteOrder, data ...any) error {
	for i := 0; i < len(data); i++ {
		if err := WriteData(writer, order, data[i]); err != nil {
			return E.Cause(err, "[", i, "]")
		}
	}
	return nil
}
// WriteData encodes a single value to writer.
//
// []uint8 gets a fast path: a uvarint length followed by the raw bytes.
// Fixed-size scalar values (see intBaseDataSize) are delegated to [Write];
// everything else is encoded reflectively by writeData.
func WriteData(writer *bufio.Writer, order ByteOrder, data any) error {
	switch dataPtr := data.(type) {
	case []uint8:
		_, err := writer.Write(AppendUvarint(writer.AvailableBuffer(), uint64(len(dataPtr))))
		if err != nil {
			return E.Cause(err, "bytes length")
		}
		_, err = writer.Write(dataPtr)
		if err != nil {
			return E.Cause(err, "bytes value")
		}
		// Fix: the byte slice is fully written here. Without this return,
		// control fell through to writeData, which emitted the slice a
		// second time (length prefix and bytes duplicated on the wire).
		return nil
	default:
		if intBaseDataSize(data) != 0 {
			return Write(writer, order, data)
		}
	}
	return writeData(writer, order, reflect.Indirect(reflect.ValueOf(data)))
}
// writeData is the reflection-driven encoder behind [WriteData]. It emits
// the wire format consumed by readData:
//   - pointers: a presence byte (0 = nil, 1 = present) then the pointee
//   - strings, slices, maps: a uvarint length followed by the contents
//   - arrays and structs: elements/fields in declaration order, no header
//   - anything else: fixed-size binary handled by the package encoder
func writeData(writer *bufio.Writer, order ByteOrder, data reflect.Value) error {
	switch data.Kind() {
	case reflect.Pointer:
		if data.IsNil() {
			err := writer.WriteByte(0)
			if err != nil {
				return err
			}
		} else {
			err := writer.WriteByte(1)
			if err != nil {
				return err
			}
			return writeData(writer, order, data.Elem())
		}
	case reflect.String:
		stringValue := data.String()
		// AvailableBuffer lets the uvarint prefix be appended without an
		// extra allocation.
		_, err := writer.Write(AppendUvarint(writer.AvailableBuffer(), uint64(len(stringValue))))
		if err != nil {
			return E.Cause(err, "string length")
		}
		if stringValue != "" {
			_, err = writer.WriteString(stringValue)
			if err != nil {
				return E.Cause(err, "string value")
			}
		}
	case reflect.Array:
		// Arrays have a fixed length known from the type, so no length
		// prefix is written; elements are encoded in order.
		dataLen := data.Len()
		for i := 0; i < dataLen; i++ {
			err := writeData(writer, order, data.Index(i))
			if err != nil {
				return E.Cause(err, "[", i, "]")
			}
		}
	case reflect.Slice:
		dataLen := data.Len()
		_, err := writer.Write(AppendUvarint(writer.AvailableBuffer(), uint64(dataLen)))
		if err != nil {
			return E.Cause(err, "slice length")
		}
		if dataLen > 0 {
			if data.Type().Elem().Kind() == reflect.Uint8 {
				// Byte slices are written in one bulk operation.
				_, err = writer.Write(data.Bytes())
				if err != nil {
					return E.Cause(err, "bytes value")
				}
			} else {
				for i := 0; i < dataLen; i++ {
					err = writeData(writer, order, data.Index(i))
					if err != nil {
						return E.Cause(err, "[", i, "]")
					}
				}
			}
		}
	case reflect.Map:
		dataLen := data.Len()
		_, err := writer.Write(AppendUvarint(writer.AvailableBuffer(), uint64(dataLen)))
		if err != nil {
			return E.Cause(err, "map length")
		}
		if dataLen > 0 {
			// NOTE(review): MapKeys iterates in Go's randomized map order,
			// so the encoded byte stream is not deterministic across runs —
			// confirm no caller relies on stable output.
			for index, key := range data.MapKeys() {
				err = writeData(writer, order, key)
				if err != nil {
					return E.Cause(err, "[", index, "].key")
				}
				err = writeData(writer, order, data.MapIndex(key))
				if err != nil {
					return E.Cause(err, "[", index, "].value")
				}
			}
		}
	case reflect.Struct:
		fieldType := data.Type()
		fieldLen := data.NumField()
		for i := 0; i < fieldLen; i++ {
			field := data.Field(i)
			fieldName := fieldType.Field(i).Name
			// Skips only non-settable fields named "_" (blank padding),
			// mirroring the condition used by readData.
			if field.CanSet() || fieldName != "_" {
				err := writeData(writer, order, field)
				if err != nil {
					return E.Cause(err, fieldName)
				}
			}
		}
	default:
		// Fixed-size scalar: encode into a scratch buffer with the
		// requested byte order, then write it out in one call.
		size := dataSize(data)
		if size < 0 {
			return errors.New("binary.Write: some values are not fixed-sized in type " + data.Type().String())
		}
		buf := make([]byte, size)
		e := &encoder{order: order, buf: buf}
		e.value(data)
		_, err := writer.Write(buf)
		if err != nil {
			return E.Cause(err, reflect.TypeOf(data).String())
		}
	}
	return nil
}
// intBaseDataSize reports the encoded size in bytes of the fixed-size scalar
// types (bool, sized integers, floats). It returns 0 for any other type,
// signalling that the caller must take the reflection path instead.
func intBaseDataSize(data any) int {
	switch data.(type) {
	case bool, int8, uint8:
		return 1
	case int16, uint16:
		return 2
	case int32, uint32, float32:
		return 4
	case int64, uint64, float64:
		return 8
	default:
		return 0
	}
}

166
common/binary/varint.go Normal file
View file

@ -0,0 +1,166 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package binary
// This file implements "varint" encoding of 64-bit integers.
// The encoding is:
// - unsigned integers are serialized 7 bits at a time, starting with the
// least significant bits
// - the most significant bit (msb) in each output byte indicates if there
// is a continuation byte (msb = 1)
// - signed integers are mapped to unsigned integers using "zig-zag"
// encoding: Positive values x are written as 2*x + 0, negative values
// are written as 2*(^x) + 1; that is, negative numbers are complemented
// and whether to complement is encoded in bit 0.
//
// Design note:
// At most 10 bytes are needed for 64-bit values. The encoding could
// be more dense: a full 64-bit value needs an extra byte just to hold bit 63.
// Instead, the msb of the previous byte could be used to hold bit 63 since we
// know there can't be more than 64 bits. This is a trivial improvement and
// would reduce the maximum encoding length to 9 bytes. However, it breaks the
// invariant that the msb is always the "continuation bit" and thus makes the
// format incompatible with a varint encoding for larger numbers (say 128-bit).
import (
"errors"
"io"
)
// MaxVarintLenN is the maximum length of a varint-encoded N-bit integer.
const (
MaxVarintLen16 = 3
MaxVarintLen32 = 5
MaxVarintLen64 = 10
)
// AppendUvarint appends the varint-encoded form of x,
// as generated by [PutUvarint], to buf and returns the extended buffer.
func AppendUvarint(buf []byte, x uint64) []byte {
	for {
		b := byte(x & 0x7f)
		x >>= 7
		if x == 0 {
			// Final group: msb stays clear to terminate the varint.
			return append(buf, b)
		}
		// More groups follow: set the continuation bit.
		buf = append(buf, b|0x80)
	}
}
// PutUvarint encodes a uint64 into buf and returns the number of bytes written.
// If the buffer is too small, PutUvarint will panic.
func PutUvarint(buf []byte, x uint64) int {
	n := 0
	for x >= 0x80 {
		// More than 7 bits remain: emit the low group with the
		// continuation bit set.
		buf[n] = byte(x&0x7f) | 0x80
		x >>= 7
		n++
	}
	buf[n] = byte(x)
	return n + 1
}
// Uvarint decodes a uint64 from buf and returns that value and the
// number of bytes read (> 0). If an error occurred, the value is 0
// and the number of bytes n is <= 0 meaning:
//
// n == 0: buf too small
// n < 0: value larger than 64 bits (overflow)
// and -n is the number of bytes read
func Uvarint(buf []byte) (uint64, int) {
	var x uint64
	var s uint
	for i, b := range buf {
		if i == MaxVarintLen64 {
			// Catch byte reads past MaxVarintLen64.
			// See issue https://golang.org/issues/41185
			return 0, -(i + 1) // overflow
		}
		if b < 0x80 {
			// msb clear: this is the final byte. The 10th byte can only
			// contribute bit 63, so any value above 1 there overflows uint64.
			if i == MaxVarintLen64-1 && b > 1 {
				return 0, -(i + 1) // overflow
			}
			return x | uint64(b)<<s, i + 1
		}
		// msb set: accumulate the low 7 bits and continue with the next byte.
		x |= uint64(b&0x7f) << s
		s += 7
	}
	return 0, 0 // buf exhausted before a terminating byte
}
// AppendVarint appends the varint-encoded form of x,
// as generated by [PutVarint], to buf and returns the extended buffer.
func AppendVarint(buf []byte, x int64) []byte {
	// Zig-zag encode: the sign moves into bit 0 so small magnitudes of
	// either sign stay short on the wire.
	var zigzag uint64
	if x >= 0 {
		zigzag = uint64(x) << 1
	} else {
		zigzag = ^(uint64(x) << 1)
	}
	return AppendUvarint(buf, zigzag)
}
// PutVarint encodes an int64 into buf and returns the number of bytes written.
// If the buffer is too small, PutVarint will panic.
func PutVarint(buf []byte, x int64) int {
	// Zig-zag encode: the sign moves into bit 0 so small magnitudes of
	// either sign stay short on the wire.
	var zigzag uint64
	if x >= 0 {
		zigzag = uint64(x) << 1
	} else {
		zigzag = ^(uint64(x) << 1)
	}
	return PutUvarint(buf, zigzag)
}
// Varint decodes an int64 from buf and returns that value and the
// number of bytes read (> 0). If an error occurred, the value is 0
// and the number of bytes n is <= 0 with the following meaning:
//
// n == 0: buf too small
// n < 0: value larger than 64 bits (overflow)
// and -n is the number of bytes read
func Varint(buf []byte) (int64, int) {
	zigzag, n := Uvarint(buf) // ok to continue in presence of error; n carries the status
	// Undo zig-zag: bit 0 is the sign, the remaining bits the magnitude.
	value := int64(zigzag >> 1)
	if zigzag&1 == 1 {
		value = ^value
	}
	return value, n
}
var errOverflow = errors.New("binary: varint overflows a 64-bit integer")
// ReadUvarint reads an encoded unsigned integer from r and returns it as a uint64.
// The error is [io.EOF] only if no bytes were read.
// If an [io.EOF] happens after reading some but not all the bytes,
// ReadUvarint returns [io.ErrUnexpectedEOF].
func ReadUvarint(r io.ByteReader) (uint64, error) {
	var x uint64
	var s uint
	for i := 0; i < MaxVarintLen64; i++ {
		b, err := r.ReadByte()
		if err != nil {
			// A varint truncated mid-value is a framing error, not a
			// clean end of stream.
			if i > 0 && err == io.EOF {
				err = io.ErrUnexpectedEOF
			}
			return x, err
		}
		if b < 0x80 {
			// msb clear: this is the final byte. The 10th byte can only
			// contribute bit 63, so any value above 1 there overflows uint64.
			if i == MaxVarintLen64-1 && b > 1 {
				return x, errOverflow
			}
			return x | uint64(b)<<s, nil
		}
		// msb set: accumulate the low 7 bits and keep reading.
		x |= uint64(b&0x7f) << s
		s += 7
	}
	return x, errOverflow
}
// ReadVarint reads an encoded signed integer from r and returns it as an int64.
// The error is [io.EOF] only if no bytes were read.
// If an [io.EOF] happens after reading some but not all the bytes,
// ReadVarint returns [io.ErrUnexpectedEOF].
func ReadVarint(r io.ByteReader) (int64, error) {
	zigzag, err := ReadUvarint(r) // ok to continue in presence of error
	// Undo zig-zag: bit 0 is the sign, the remaining bits the magnitude.
	value := int64(zigzag >> 1)
	if zigzag&1 == 1 {
		value = ^value
	}
	return value, err
}

View file

@ -8,7 +8,7 @@ import (
"sync"
)
var DefaultAllocator = newDefaultAllocer()
var DefaultAllocator = newDefaultAllocator()
type Allocator interface {
Get(size int) []byte
@ -17,22 +17,28 @@ type Allocator interface {
// defaultAllocator for incoming frames, optimized to prevent overwriting after zeroing
type defaultAllocator struct {
buffers []sync.Pool
buffers [11]sync.Pool
}
// NewAllocator initiates a []byte allocator for frames less than 65536 bytes,
// the waste(memory fragmentation) of space allocation is guaranteed to be
// no more than 50%.
func newDefaultAllocer() Allocator {
alloc := new(defaultAllocator)
alloc.buffers = make([]sync.Pool, 17) // 1B -> 64K
for k := range alloc.buffers {
i := k
alloc.buffers[k].New = func() any {
return make([]byte, 1<<uint32(i))
func newDefaultAllocator() Allocator {
return &defaultAllocator{
buffers: [...]sync.Pool{ // 64B -> 64K
{New: func() any { return new([1 << 6]byte) }},
{New: func() any { return new([1 << 7]byte) }},
{New: func() any { return new([1 << 8]byte) }},
{New: func() any { return new([1 << 9]byte) }},
{New: func() any { return new([1 << 10]byte) }},
{New: func() any { return new([1 << 11]byte) }},
{New: func() any { return new([1 << 12]byte) }},
{New: func() any { return new([1 << 13]byte) }},
{New: func() any { return new([1 << 14]byte) }},
{New: func() any { return new([1 << 15]byte) }},
{New: func() any { return new([1 << 16]byte) }},
},
}
}
return alloc
}
// Get a []byte from pool with most appropriate cap
@ -41,12 +47,42 @@ func (alloc *defaultAllocator) Get(size int) []byte {
return nil
}
bits := msb(size)
if size == 1<<bits {
return alloc.buffers[bits].Get().([]byte)[:size]
var index uint16
if size > 64 {
index = msb(size)
if size != 1<<index {
index += 1
}
index -= 6
}
return alloc.buffers[bits+1].Get().([]byte)[:size]
buffer := alloc.buffers[index].Get()
switch index {
case 0:
return buffer.(*[1 << 6]byte)[:size]
case 1:
return buffer.(*[1 << 7]byte)[:size]
case 2:
return buffer.(*[1 << 8]byte)[:size]
case 3:
return buffer.(*[1 << 9]byte)[:size]
case 4:
return buffer.(*[1 << 10]byte)[:size]
case 5:
return buffer.(*[1 << 11]byte)[:size]
case 6:
return buffer.(*[1 << 12]byte)[:size]
case 7:
return buffer.(*[1 << 13]byte)[:size]
case 8:
return buffer.(*[1 << 14]byte)[:size]
case 9:
return buffer.(*[1 << 15]byte)[:size]
case 10:
return buffer.(*[1 << 16]byte)[:size]
default:
panic("invalid pool index")
}
}
// Put returns a []byte to pool for future use,
@ -56,10 +92,37 @@ func (alloc *defaultAllocator) Put(buf []byte) error {
if cap(buf) == 0 || cap(buf) > 65536 || cap(buf) != 1<<bits {
return errors.New("allocator Put() incorrect buffer size")
}
bits -= 6
buf = buf[:cap(buf)]
//nolint
//lint:ignore SA6002 ignore temporarily
alloc.buffers[bits].Put(buf)
switch bits {
case 0:
alloc.buffers[bits].Put((*[1 << 6]byte)(buf))
case 1:
alloc.buffers[bits].Put((*[1 << 7]byte)(buf))
case 2:
alloc.buffers[bits].Put((*[1 << 8]byte)(buf))
case 3:
alloc.buffers[bits].Put((*[1 << 9]byte)(buf))
case 4:
alloc.buffers[bits].Put((*[1 << 10]byte)(buf))
case 5:
alloc.buffers[bits].Put((*[1 << 11]byte)(buf))
case 6:
alloc.buffers[bits].Put((*[1 << 12]byte)(buf))
case 7:
alloc.buffers[bits].Put((*[1 << 13]byte)(buf))
case 8:
alloc.buffers[bits].Put((*[1 << 14]byte)(buf))
case 9:
alloc.buffers[bits].Put((*[1 << 15]byte)(buf))
case 10:
alloc.buffers[bits].Put((*[1 << 16]byte)(buf))
default:
panic("invalid pool index")
}
return nil
}

View file

@ -4,29 +4,27 @@ import (
"crypto/rand"
"io"
"net"
"strconv"
"sync/atomic"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/atomic"
"github.com/sagernet/sing/common/debug"
E "github.com/sagernet/sing/common/exceptions"
F "github.com/sagernet/sing/common/format"
)
const ReversedHeader = 1024
type Buffer struct {
data []byte
start int
end int
refs int32
capacity int
refs atomic.Int32
managed bool
closed bool
}
func New() *Buffer {
return &Buffer{
data: Get(BufferSize),
start: ReversedHeader,
end: ReversedHeader,
capacity: BufferSize,
managed: true,
}
}
@ -34,8 +32,7 @@ func New() *Buffer {
func NewPacket() *Buffer {
return &Buffer{
data: Get(UDPBufferSize),
start: ReversedHeader,
end: ReversedHeader,
capacity: UDPBufferSize,
managed: true,
}
}
@ -46,61 +43,28 @@ func NewSize(size int) *Buffer {
} else if size > 65535 {
return &Buffer{
data: make([]byte, size),
capacity: size,
}
}
return &Buffer{
data: Get(size),
capacity: size,
managed: true,
}
}
func StackNew() *Buffer {
if common.UnsafeBuffer {
return &Buffer{
data: make([]byte, BufferSize),
start: ReversedHeader,
end: ReversedHeader,
}
} else {
return New()
}
}
func StackNewPacket() *Buffer {
if common.UnsafeBuffer {
return &Buffer{
data: make([]byte, UDPBufferSize),
start: ReversedHeader,
end: ReversedHeader,
}
} else {
return NewPacket()
}
}
func StackNewSize(size int) *Buffer {
if size == 0 {
return &Buffer{}
}
if common.UnsafeBuffer {
return &Buffer{
data: Make(size),
}
} else {
return NewSize(size)
}
}
func As(data []byte) *Buffer {
return &Buffer{
data: data,
end: len(data),
capacity: len(data),
}
}
func With(data []byte) *Buffer {
return &Buffer{
data: data,
capacity: len(data),
}
}
@ -114,8 +78,8 @@ func (b *Buffer) SetByte(index int, value byte) {
func (b *Buffer) Extend(n int) []byte {
end := b.end + n
if end > cap(b.data) {
panic("buffer overflow: cap " + strconv.Itoa(cap(b.data)) + ",end " + strconv.Itoa(b.end) + ", need " + strconv.Itoa(n))
if end > b.capacity {
panic(F.ToString("buffer overflow: capacity ", b.capacity, ",end ", b.end, ", need ", n))
}
ext := b.data[b.end:end]
b.end = end
@ -137,14 +101,14 @@ func (b *Buffer) Write(data []byte) (n int, err error) {
if b.IsFull() {
return 0, io.ErrShortBuffer
}
n = copy(b.data[b.end:], data)
n = copy(b.data[b.end:b.capacity], data)
b.end += n
return
}
func (b *Buffer) ExtendHeader(n int) []byte {
if b.start < n {
panic("buffer overflow: cap " + strconv.Itoa(cap(b.data)) + ",start " + strconv.Itoa(b.start) + ", need " + strconv.Itoa(n))
panic(F.ToString("buffer overflow: capacity ", b.capacity, ",start ", b.start, ", need ", n))
}
b.start -= n
return b.data[b.start : b.start+n]
@ -197,7 +161,7 @@ func (b *Buffer) ReadAtLeastFrom(r io.Reader, min int) (int64, error) {
}
func (b *Buffer) ReadFullFrom(r io.Reader, size int) (n int, err error) {
if b.end+size > b.Cap() {
if b.end+size > b.capacity {
return 0, io.ErrShortBuffer
}
n, err = io.ReadFull(r, b.data[b.end:b.end+size])
@ -234,7 +198,7 @@ func (b *Buffer) WriteString(s string) (n int, err error) {
if b.IsFull() {
return 0, io.ErrShortBuffer
}
n = copy(b.data[b.end:], s)
n = copy(b.data[b.end:b.capacity], s)
b.end += n
return
}
@ -249,13 +213,10 @@ func (b *Buffer) WriteZero() error {
}
func (b *Buffer) WriteZeroN(n int) error {
if b.end+n > b.Cap() {
if b.end+n > b.capacity {
return io.ErrShortBuffer
}
for i := b.end; i <= b.end+n; i++ {
b.data[i] = 0
}
b.end += n
common.ClearArray(b.Extend(n))
return nil
}
@ -298,40 +259,63 @@ func (b *Buffer) Resize(start, end int) {
b.end = b.start + end
}
func (b *Buffer) Reset() {
b.start = ReversedHeader
b.end = ReversedHeader
func (b *Buffer) Reserve(n int) {
if n > b.capacity {
panic(F.ToString("buffer overflow: capacity ", b.capacity, ", need ", n))
}
b.capacity -= n
}
func (b *Buffer) FullReset() {
func (b *Buffer) OverCap(n int) {
if b.capacity+n > len(b.data) {
panic(F.ToString("buffer overflow: capacity ", len(b.data), ", need ", b.capacity+n))
}
b.capacity += n
}
func (b *Buffer) Reset() {
b.start = 0
b.end = 0
b.capacity = len(b.data)
}
// Deprecated: use Reset instead.
func (b *Buffer) FullReset() {
b.Reset()
}
func (b *Buffer) IncRef() {
atomic.AddInt32(&b.refs, 1)
b.refs.Add(1)
}
func (b *Buffer) DecRef() {
atomic.AddInt32(&b.refs, -1)
b.refs.Add(-1)
}
func (b *Buffer) Release() {
if b == nil || b.closed || !b.managed {
if b == nil || !b.managed {
return
}
if atomic.LoadInt32(&b.refs) > 0 {
if b.refs.Load() > 0 {
return
}
common.Must(Put(b.data))
*b = Buffer{closed: true}
*b = Buffer{}
}
func (b *Buffer) Cut(start int, end int) *Buffer {
b.start += start
b.end = len(b.data) - end
return &Buffer{
data: b.data[b.start:b.end],
func (b *Buffer) Leak() {
if debug.Enabled {
if b == nil || !b.managed {
return
}
refs := b.refs.Load()
if refs == 0 {
panic("leaking buffer")
} else {
panic(F.ToString("leaking buffer with ", refs, " references"))
}
} else {
b.Release()
}
}
@ -344,6 +328,10 @@ func (b *Buffer) Len() int {
}
func (b *Buffer) Cap() int {
return b.capacity
}
func (b *Buffer) RawCap() int {
return len(b.data)
}
@ -351,10 +339,6 @@ func (b *Buffer) Bytes() []byte {
return b.data[b.start:b.end]
}
func (b *Buffer) Slice() []byte {
return b.data
}
func (b *Buffer) From(n int) []byte {
return b.data[b.start+n : b.end]
}
@ -372,11 +356,11 @@ func (b *Buffer) Index(start int) []byte {
}
func (b *Buffer) FreeLen() int {
return b.Cap() - b.end
return b.capacity - b.end
}
func (b *Buffer) FreeBytes() []byte {
return b.data[b.end:b.Cap()]
return b.data[b.end:b.capacity]
}
func (b *Buffer) IsEmpty() bool {
@ -384,7 +368,7 @@ func (b *Buffer) IsEmpty() bool {
}
func (b *Buffer) IsFull() bool {
return b.end == b.Cap()
return b.end == b.capacity
}
func (b *Buffer) ToOwned() *Buffer {
@ -392,5 +376,6 @@ func (b *Buffer) ToOwned() *Buffer {
copy(n.data[b.start:b.end], b.data[b.start:b.end])
n.start = b.start
n.end = b.end
n.capacity = b.capacity
return n
}

View file

@ -1,9 +0,0 @@
package buf
import "encoding/hex"
func EncodeHexString(src []byte) string {
dst := Make(hex.EncodedLen(len(src)))
hex.Encode(dst, src)
return string(dst)
}

View file

@ -11,46 +11,7 @@ func Put(buf []byte) error {
return DefaultAllocator.Put(buf)
}
// Deprecated: use array instead.
func Make(size int) []byte {
if size == 0 {
return nil
}
var buffer []byte
switch {
case size <= 2:
buffer = make([]byte, 2)
case size <= 4:
buffer = make([]byte, 4)
case size <= 8:
buffer = make([]byte, 8)
case size <= 16:
buffer = make([]byte, 16)
case size <= 32:
buffer = make([]byte, 32)
case size <= 64:
buffer = make([]byte, 64)
case size <= 128:
buffer = make([]byte, 128)
case size <= 256:
buffer = make([]byte, 256)
case size <= 512:
buffer = make([]byte, 512)
case size <= 1024:
buffer = make([]byte, 1024)
case size <= 2048:
buffer = make([]byte, 2048)
case size <= 4096:
buffer = make([]byte, 4096)
case size <= 8192:
buffer = make([]byte, 8192)
case size <= 16384:
buffer = make([]byte, 16384)
case size <= 32768:
buffer = make([]byte, 32768)
case size <= 65535:
buffer = make([]byte, 65535)
default:
return make([]byte, size)
}
return buffer[:size]
}

View file

@ -1,34 +0,0 @@
//go:build !disable_unsafe
package buf
import (
"unsafe"
"github.com/sagernet/sing/common"
)
type dbgVar struct {
name string
value *int32
}
//go:linkname dbgvars runtime.dbgvars
var dbgvars any
// go.info.runtime.dbgvars: relocation target go.info.[]github.com/sagernet/sing/common/buf.dbgVar not defined
// var dbgvars []dbgVar
func init() {
if !common.UnsafeBuffer {
return
}
debugVars := *(*[]dbgVar)(unsafe.Pointer(&dbgvars))
for _, v := range debugVars {
if v.name == "invalidptr" {
*v.value = 0
return
}
}
panic("can't disable invalidptr")
}

34
common/bufio/addr_bsd.go Normal file
View file

@ -0,0 +1,34 @@
//go:build darwin || dragonfly || freebsd || netbsd || openbsd
package bufio
import (
"encoding/binary"
"net/netip"
"unsafe"
"golang.org/x/sys/unix"
)
// ToSockaddr converts destination into a raw sockaddr structure
// (RawSockaddrInet4 or RawSockaddrInet6, chosen by the address family) and
// returns an unsafe pointer to it together with the structure's byte size,
// suitable for passing directly to socket syscalls.
//
// Unlike the Linux variant of this function, the BSD-family structs carry a
// Len field, which is populated here.
func ToSockaddr(destination netip.AddrPort) (name unsafe.Pointer, nameLen uint32) {
	if destination.Addr().Is4() {
		sa := unix.RawSockaddrInet4{
			Len: unix.SizeofSockaddrInet4,
			Family: unix.AF_INET,
			Addr: destination.Addr().As4(),
		}
		// Write the port's big-endian (network order) bytes directly over
		// the Port field, regardless of host byte order.
		binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], destination.Port())
		name = unsafe.Pointer(&sa)
		nameLen = unix.SizeofSockaddrInet4
	} else {
		sa := unix.RawSockaddrInet6{
			Len: unix.SizeofSockaddrInet6,
			Family: unix.AF_INET6,
			Addr: destination.Addr().As16(),
		}
		binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], destination.Port())
		name = unsafe.Pointer(&sa)
		nameLen = unix.SizeofSockaddrInet6
	}
	return
}

View file

@ -0,0 +1,30 @@
package bufio
import (
"encoding/binary"
"net/netip"
"unsafe"
"golang.org/x/sys/unix"
)
// ToSockaddr converts destination into a raw sockaddr structure
// (RawSockaddrInet4 or RawSockaddrInet6, chosen by the address family) and
// returns an unsafe pointer to it together with the structure's byte size,
// suitable for passing directly to socket syscalls.
func ToSockaddr(destination netip.AddrPort) (name unsafe.Pointer, nameLen uint32) {
	if destination.Addr().Is4() {
		sa := unix.RawSockaddrInet4{
			Family: unix.AF_INET,
			Addr: destination.Addr().As4(),
		}
		// Write the port's big-endian (network order) bytes directly over
		// the Port field, regardless of host byte order.
		binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], destination.Port())
		name = unsafe.Pointer(&sa)
		nameLen = unix.SizeofSockaddrInet4
	} else {
		sa := unix.RawSockaddrInet6{
			Family: unix.AF_INET6,
			Addr: destination.Addr().As16(),
		}
		binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], destination.Port())
		name = unsafe.Pointer(&sa)
		nameLen = unix.SizeofSockaddrInet6
	}
	return
}

View file

@ -0,0 +1,30 @@
package bufio
import (
"encoding/binary"
"net/netip"
"unsafe"
"golang.org/x/sys/windows"
)
// ToSockaddr converts destination into a raw Windows sockaddr structure
// (RawSockaddrInet4 or RawSockaddrInet6, chosen by the address family) and
// returns an unsafe pointer to it together with the structure's byte size
// (as the int32 Winsock APIs expect).
func ToSockaddr(destination netip.AddrPort) (name unsafe.Pointer, nameLen int32) {
	if destination.Addr().Is4() {
		sa := windows.RawSockaddrInet4{
			Family: windows.AF_INET,
			Addr: destination.Addr().As4(),
		}
		// Write the port's big-endian (network order) bytes directly over
		// the Port field, regardless of host byte order.
		binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], destination.Port())
		name = unsafe.Pointer(&sa)
		nameLen = int32(unsafe.Sizeof(sa))
	} else {
		sa := windows.RawSockaddrInet6{
			Family: windows.AF_INET6,
			Addr: destination.Addr().As16(),
		}
		binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], destination.Port())
		name = unsafe.Pointer(&sa)
		nameLen = int32(unsafe.Sizeof(sa))
	}
	return
}

View file

@ -8,51 +8,76 @@ import (
N "github.com/sagernet/sing/common/network"
)
type BindPacketConn struct {
type BindPacketConn interface {
N.NetPacketConn
Addr net.Addr
net.Conn
}
func NewBindPacketConn(conn net.PacketConn, addr net.Addr) *BindPacketConn {
return &BindPacketConn{
type bindPacketConn struct {
N.NetPacketConn
addr net.Addr
}
func NewBindPacketConn(conn net.PacketConn, addr net.Addr) BindPacketConn {
return &bindPacketConn{
NewPacketConn(conn),
addr,
}
}
func (c *BindPacketConn) Read(b []byte) (n int, err error) {
func (c *bindPacketConn) Read(b []byte) (n int, err error) {
n, _, err = c.ReadFrom(b)
return
}
func (c *BindPacketConn) Write(b []byte) (n int, err error) {
return c.WriteTo(b, c.Addr)
func (c *bindPacketConn) Write(b []byte) (n int, err error) {
return c.WriteTo(b, c.addr)
}
func (c *BindPacketConn) RemoteAddr() net.Addr {
return c.Addr
func (c *bindPacketConn) CreateReadWaiter() (N.ReadWaiter, bool) {
readWaiter, isReadWaiter := CreatePacketReadWaiter(c.NetPacketConn)
if !isReadWaiter {
return nil, false
}
return &bindPacketReadWaiter{readWaiter}, true
}
func (c *BindPacketConn) Upstream() any {
func (c *bindPacketConn) RemoteAddr() net.Addr {
return c.addr
}
func (c *bindPacketConn) Upstream() any {
return c.NetPacketConn
}
var (
_ N.NetPacketConn = (*UnbindPacketConn)(nil)
_ N.PacketReadWaitCreator = (*UnbindPacketConn)(nil)
)
type UnbindPacketConn struct {
N.ExtendedConn
Addr M.Socksaddr
addr M.Socksaddr
}
func NewUnbindPacketConn(conn net.Conn) *UnbindPacketConn {
func NewUnbindPacketConn(conn net.Conn) N.NetPacketConn {
return &UnbindPacketConn{
NewExtendedConn(conn),
M.SocksaddrFromNet(conn.RemoteAddr()),
}
}
func NewUnbindPacketConnWithAddr(conn net.Conn, addr M.Socksaddr) N.NetPacketConn {
return &UnbindPacketConn{
NewExtendedConn(conn),
addr,
}
}
func (c *UnbindPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
n, err = c.ExtendedConn.Read(p)
if err == nil {
addr = c.Addr.UDPAddr()
addr = c.addr.UDPAddr()
}
return
}
@ -66,7 +91,7 @@ func (c *UnbindPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksad
if err != nil {
return
}
destination = c.Addr
destination = c.addr
return
}
@ -74,6 +99,67 @@ func (c *UnbindPacketConn) WritePacket(buffer *buf.Buffer, _ M.Socksaddr) error
return c.ExtendedConn.WriteBuffer(buffer)
}
func (c *UnbindPacketConn) CreateReadWaiter() (N.PacketReadWaiter, bool) {
readWaiter, isReadWaiter := CreateReadWaiter(c.ExtendedConn)
if !isReadWaiter {
return nil, false
}
return &unbindPacketReadWaiter{readWaiter, c.addr}, true
}
func (c *UnbindPacketConn) Upstream() any {
return c.ExtendedConn
}
func NewServerPacketConn(conn net.PacketConn) N.ExtendedConn {
return &serverPacketConn{
NetPacketConn: NewPacketConn(conn),
}
}
type serverPacketConn struct {
N.NetPacketConn
remoteAddr M.Socksaddr
}
func (c *serverPacketConn) Read(p []byte) (n int, err error) {
n, addr, err := c.NetPacketConn.ReadFrom(p)
if err != nil {
return
}
c.remoteAddr = M.SocksaddrFromNet(addr)
return
}
func (c *serverPacketConn) ReadBuffer(buffer *buf.Buffer) error {
destination, err := c.NetPacketConn.ReadPacket(buffer)
if err != nil {
return err
}
c.remoteAddr = destination
return nil
}
func (c *serverPacketConn) Write(p []byte) (n int, err error) {
return c.NetPacketConn.WriteTo(p, c.remoteAddr.UDPAddr())
}
func (c *serverPacketConn) WriteBuffer(buffer *buf.Buffer) error {
return c.NetPacketConn.WritePacket(buffer, c.remoteAddr)
}
func (c *serverPacketConn) RemoteAddr() net.Addr {
return c.remoteAddr
}
func (c *serverPacketConn) Upstream() any {
return c.NetPacketConn
}
func (c *serverPacketConn) CreateReadWaiter() (N.ReadWaiter, bool) {
readWaiter, isReadWaiter := CreatePacketReadWaiter(c.NetPacketConn)
if !isReadWaiter {
return nil, false
}
return &serverPacketReadWaiter{c, readWaiter}, true
}

62
common/bufio/bind_wait.go Normal file
View file

@ -0,0 +1,62 @@
package bufio
import (
"github.com/sagernet/sing/common/buf"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
var _ N.ReadWaiter = (*bindPacketReadWaiter)(nil)

// bindPacketReadWaiter adapts a packet-oriented read waiter to the
// stream-oriented N.ReadWaiter interface by discarding each packet's
// source address.
type bindPacketReadWaiter struct {
	readWaiter N.PacketReadWaiter
}

// InitializeReadWaiter forwards the wait options to the wrapped packet waiter.
func (w *bindPacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
	return w.readWaiter.InitializeReadWaiter(options)
}

// WaitReadBuffer waits for one packet and returns its payload, dropping the
// packet's origin address.
func (w *bindPacketReadWaiter) WaitReadBuffer() (buffer *buf.Buffer, err error) {
	buffer, _, err = w.readWaiter.WaitReadPacket()
	return
}
var _ N.PacketReadWaiter = (*unbindPacketReadWaiter)(nil)

// unbindPacketReadWaiter adapts a stream-oriented read waiter to the
// packet-oriented N.PacketReadWaiter interface by reporting the fixed
// address addr as the source of every read.
type unbindPacketReadWaiter struct {
	readWaiter N.ReadWaiter
	addr       M.Socksaddr
}

// InitializeReadWaiter forwards the wait options to the wrapped waiter.
func (w *unbindPacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
	return w.readWaiter.InitializeReadWaiter(options)
}

// WaitReadPacket waits for stream data and returns it as a packet whose
// destination is the fixed address.
func (w *unbindPacketReadWaiter) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
	buffer, err = w.readWaiter.WaitReadBuffer()
	if err != nil {
		return
	}
	destination = w.addr
	return
}
var _ N.ReadWaiter = (*serverPacketReadWaiter)(nil)

// serverPacketReadWaiter adapts a packet read waiter for serverPacketConn:
// each successful wait records the packet's source as the connection's
// current remote address, so later writes go back to the most recent peer.
type serverPacketReadWaiter struct {
	*serverPacketConn
	readWaiter N.PacketReadWaiter
}

// InitializeReadWaiter forwards the wait options to the wrapped packet waiter.
func (w *serverPacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
	return w.readWaiter.InitializeReadWaiter(options)
}

// WaitReadBuffer waits for one packet, stores its source address on the
// embedded serverPacketConn, and returns the payload.
func (w *serverPacketReadWaiter) WaitReadBuffer() (buffer *buf.Buffer, err error) {
	buffer, destination, err := w.readWaiter.WaitReadPacket()
	if err != nil {
		return
	}
	w.remoteAddr = destination
	return
}

View file

@ -37,7 +37,7 @@ func (w *BufferedWriter) Write(p []byte) (n int, err error) {
if err != nil {
return
}
w.buffer.FullReset()
w.buffer.Reset()
}
}

View file

@ -30,7 +30,7 @@ func (c *ChunkReader) ReadBuffer(buffer *buf.Buffer) error {
} else if !c.cache.IsEmpty() {
return common.Error(buffer.ReadFrom(c.cache))
}
c.cache.FullReset()
c.cache.Reset()
err := c.upstream.ReadBuffer(c.cache)
if err != nil {
c.cache.Release()
@ -46,7 +46,7 @@ func (c *ChunkReader) Read(p []byte) (n int, err error) {
} else if !c.cache.IsEmpty() {
return c.cache.Read(p)
}
c.cache.FullReset()
c.cache.Reset()
err = c.upstream.ReadBuffer(c.cache)
if err != nil {
c.cache.Release()
@ -70,7 +70,7 @@ func (c *ChunkReader) ReadChunk() (*buf.Buffer, error) {
} else if !c.cache.IsEmpty() {
return c.cache, nil
}
c.cache.FullReset()
c.cache.Reset()
err := c.upstream.ReadBuffer(c.cache)
if err != nil {
c.cache.Release()

View file

@ -22,7 +22,7 @@ func Copy(destination io.Writer, source io.Reader) (n int64, err error) {
} else if destination == nil {
return 0, E.New("nil writer")
}
originDestination := destination
originSource := source
var readCounters, writeCounters []N.CountFunc
for {
source, readCounters = N.UnwrapCountReader(source, readCounters)
@ -45,105 +45,61 @@ func Copy(destination io.Writer, source io.Reader) (n int64, err error) {
dstSyscallConn, dstIsSyscall := destination.(syscall.Conn)
if srcIsSyscall && dstIsSyscall {
var handled bool
handled, n, err = CopyDirect(srcSyscallConn, dstSyscallConn, readCounters, writeCounters)
handled, n, err = copyDirect(srcSyscallConn, dstSyscallConn, readCounters, writeCounters)
if handled {
return
}
}
break
}
return CopyExtended(originDestination, NewExtendedWriter(destination), NewExtendedReader(source), readCounters, writeCounters)
return CopyExtended(originSource, NewExtendedWriter(destination), NewExtendedReader(source), readCounters, writeCounters)
}
func CopyExtended(originDestination io.Writer, destination N.ExtendedWriter, source N.ExtendedReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
safeSrc := N.IsSafeReader(source)
headroom := N.CalculateFrontHeadroom(destination) + N.CalculateRearHeadroom(destination)
if safeSrc != nil {
if headroom == 0 {
return CopyExtendedWithSrcBuffer(originDestination, destination, safeSrc, readCounters, writeCounters)
}
}
func CopyExtended(originSource io.Reader, destination N.ExtendedWriter, source N.ExtendedReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
frontHeadroom := N.CalculateFrontHeadroom(destination)
rearHeadroom := N.CalculateRearHeadroom(destination)
readWaiter, isReadWaiter := CreateReadWaiter(source)
if isReadWaiter {
needCopy := readWaiter.InitializeReadWaiter(N.ReadWaitOptions{
FrontHeadroom: frontHeadroom,
RearHeadroom: rearHeadroom,
MTU: N.CalculateMTU(source, destination),
})
if !needCopy || common.LowMemory {
var handled bool
handled, n, err = copyWaitWithPool(originDestination, destination, readWaiter, readCounters, writeCounters)
handled, n, err = copyWaitWithPool(originSource, destination, readWaiter, readCounters, writeCounters)
if handled {
return
}
}
if !common.UnsafeBuffer || N.IsUnsafeWriter(destination) {
return CopyExtendedWithPool(originDestination, destination, source, readCounters, writeCounters)
}
bufferSize := N.CalculateMTU(source, destination)
if bufferSize > 0 {
bufferSize += headroom
} else {
bufferSize = buf.BufferSize
}
_buffer := buf.StackNewSize(bufferSize)
defer common.KeepAlive(_buffer)
buffer := common.Dup(_buffer)
defer buffer.Release()
return CopyExtendedBuffer(originDestination, destination, source, buffer, readCounters, writeCounters)
return CopyExtendedWithPool(originSource, destination, source, readCounters, writeCounters)
}
func CopyExtendedBuffer(originDestination io.Writer, destination N.ExtendedWriter, source N.ExtendedReader, buffer *buf.Buffer, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
func CopyExtendedBuffer(originSource io.Writer, destination N.ExtendedWriter, source N.ExtendedReader, buffer *buf.Buffer, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
buffer.IncRef()
defer buffer.DecRef()
frontHeadroom := N.CalculateFrontHeadroom(destination)
rearHeadroom := N.CalculateRearHeadroom(destination)
readBufferRaw := buffer.Slice()
readBuffer := buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom])
buffer.Resize(frontHeadroom, 0)
buffer.Reserve(rearHeadroom)
var notFirstTime bool
for {
readBuffer.Resize(frontHeadroom, 0)
err = source.ReadBuffer(readBuffer)
err = source.ReadBuffer(buffer)
if err != nil {
if errors.Is(err, io.EOF) {
err = nil
return
}
if !notFirstTime {
err = N.HandshakeFailure(originDestination, err)
}
return
}
dataLen := readBuffer.Len()
buffer.Resize(readBuffer.Start(), dataLen)
err = destination.WriteBuffer(buffer)
if err != nil {
return
}
n += int64(dataLen)
for _, counter := range readCounters {
counter(int64(dataLen))
}
for _, counter := range writeCounters {
counter(int64(dataLen))
}
notFirstTime = true
}
}
func CopyExtendedWithSrcBuffer(originDestination io.Writer, destination N.ExtendedWriter, source N.ThreadSafeReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
var notFirstTime bool
for {
var buffer *buf.Buffer
buffer, err = source.ReadBufferThreadSafe()
if err != nil {
if errors.Is(err, io.EOF) {
err = nil
return
}
if !notFirstTime {
err = N.HandshakeFailure(originDestination, err)
}
return
}
dataLen := buffer.Len()
buffer.OverCap(rearHeadroom)
err = destination.WriteBuffer(buffer)
if err != nil {
buffer.Release()
if !notFirstTime {
err = N.ReportHandshakeFailure(originSource, err)
}
return
}
n += int64(dataLen)
@ -157,7 +113,7 @@ func CopyExtendedWithSrcBuffer(originDestination io.Writer, destination N.Extend
}
}
func CopyExtendedWithPool(originDestination io.Writer, destination N.ExtendedWriter, source N.ExtendedReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
func CopyExtendedWithPool(originSource io.Reader, destination N.ExtendedWriter, source N.ExtendedReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
frontHeadroom := N.CalculateFrontHeadroom(destination)
rearHeadroom := N.CalculateRearHeadroom(destination)
bufferSize := N.CalculateMTU(source, destination)
@ -169,26 +125,25 @@ func CopyExtendedWithPool(originDestination io.Writer, destination N.ExtendedWri
var notFirstTime bool
for {
buffer := buf.NewSize(bufferSize)
readBufferRaw := buffer.Slice()
readBuffer := buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom])
readBuffer.Resize(frontHeadroom, 0)
err = source.ReadBuffer(readBuffer)
buffer.Resize(frontHeadroom, 0)
buffer.Reserve(rearHeadroom)
err = source.ReadBuffer(buffer)
if err != nil {
buffer.Release()
if errors.Is(err, io.EOF) {
err = nil
return
}
if !notFirstTime {
err = N.HandshakeFailure(originDestination, err)
}
return
}
dataLen := readBuffer.Len()
buffer.Resize(readBuffer.Start(), dataLen)
dataLen := buffer.Len()
buffer.OverCap(rearHeadroom)
err = destination.WriteBuffer(buffer)
if err != nil {
buffer.Release()
buffer.Leak()
if !notFirstTime {
err = N.ReportHandshakeFailure(originSource, err)
}
return
}
n += int64(dataLen)
@ -249,6 +204,7 @@ func CopyConnContextList(contextList []context.Context, source net.Conn, destina
func CopyPacket(destinationConn N.PacketWriter, source N.PacketReader) (n int64, err error) {
var readCounters, writeCounters []N.CountFunc
var cachedPackets []*N.PacketBuffer
originSource := source
for {
source, readCounters = N.UnwrapCountPacketReader(source, readCounters)
destinationConn, writeCounters = N.UnwrapCountPacketWriter(destinationConn, writeCounters)
@ -262,113 +218,38 @@ func CopyPacket(destinationConn N.PacketWriter, source N.PacketReader) (n int64,
break
}
if cachedPackets != nil {
n, err = WritePacketWithPool(destinationConn, cachedPackets)
n, err = WritePacketWithPool(originSource, destinationConn, cachedPackets)
if err != nil {
return
}
}
safeSrc := N.IsSafePacketReader(source)
frontHeadroom := N.CalculateFrontHeadroom(destinationConn)
rearHeadroom := N.CalculateRearHeadroom(destinationConn)
headroom := frontHeadroom + rearHeadroom
if safeSrc != nil {
if headroom == 0 {
var copyN int64
copyN, err = CopyPacketWithSrcBuffer(destinationConn, safeSrc, readCounters, writeCounters)
n += copyN
return
}
}
readWaiter, isReadWaiter := CreatePacketReadWaiter(source)
if isReadWaiter {
var (
handled bool
copeN int64
)
handled, copeN, err = copyPacketWaitWithPool(destinationConn, readWaiter, readCounters, writeCounters)
readWaiter, isReadWaiter := CreatePacketReadWaiter(source)
if isReadWaiter {
needCopy := readWaiter.InitializeReadWaiter(N.ReadWaitOptions{
FrontHeadroom: frontHeadroom,
RearHeadroom: rearHeadroom,
MTU: N.CalculateMTU(source, destinationConn),
})
if !needCopy || common.LowMemory {
handled, copeN, err = copyPacketWaitWithPool(originSource, destinationConn, readWaiter, readCounters, writeCounters, n > 0)
if handled {
n += copeN
return
}
}
if N.IsUnsafeWriter(destinationConn) {
return CopyPacketWithPool(destinationConn, source, readCounters, writeCounters)
}
bufferSize := N.CalculateMTU(source, destinationConn)
if bufferSize > 0 {
bufferSize += headroom
} else {
bufferSize = buf.UDPBufferSize
}
_buffer := buf.StackNewSize(bufferSize)
defer common.KeepAlive(_buffer)
buffer := common.Dup(_buffer)
defer buffer.Release()
buffer.IncRef()
defer buffer.DecRef()
var destination M.Socksaddr
var notFirstTime bool
readBufferRaw := buffer.Slice()
readBuffer := buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom])
for {
readBuffer.Resize(frontHeadroom, 0)
destination, err = source.ReadPacket(readBuffer)
if err != nil {
if !notFirstTime {
err = N.HandshakeFailure(destinationConn, err)
}
copeN, err = CopyPacketWithPool(originSource, destinationConn, source, readCounters, writeCounters, n > 0)
n += copeN
return
}
dataLen := readBuffer.Len()
buffer.Resize(readBuffer.Start(), dataLen)
err = destinationConn.WritePacket(buffer, destination)
if err != nil {
return
}
n += int64(dataLen)
for _, counter := range readCounters {
counter(int64(dataLen))
}
for _, counter := range writeCounters {
counter(int64(dataLen))
}
notFirstTime = true
}
}
func CopyPacketWithSrcBuffer(destinationConn N.PacketWriter, source N.ThreadSafePacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
var buffer *buf.Buffer
var destination M.Socksaddr
var notFirstTime bool
for {
buffer, destination, err = source.ReadPacketThreadSafe()
if err != nil {
if !notFirstTime {
err = N.HandshakeFailure(destinationConn, err)
}
return
}
dataLen := buffer.Len()
if dataLen == 0 {
continue
}
err = destinationConn.WritePacket(buffer, destination)
if err != nil {
buffer.Release()
return
}
n += int64(dataLen)
for _, counter := range readCounters {
counter(int64(dataLen))
}
for _, counter := range writeCounters {
counter(int64(dataLen))
}
notFirstTime = true
}
}
func CopyPacketWithPool(destinationConn N.PacketWriter, source N.PacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
func CopyPacketWithPool(originSource N.PacketReader, destinationConn N.PacketWriter, source N.PacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (n int64, err error) {
frontHeadroom := N.CalculateFrontHeadroom(destinationConn)
rearHeadroom := N.CalculateRearHeadroom(destinationConn)
bufferSize := N.CalculateMTU(source, destinationConn)
@ -378,25 +259,23 @@ func CopyPacketWithPool(destinationConn N.PacketWriter, source N.PacketReader, r
bufferSize = buf.UDPBufferSize
}
var destination M.Socksaddr
var notFirstTime bool
for {
buffer := buf.NewSize(bufferSize)
readBufferRaw := buffer.Slice()
readBuffer := buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom])
readBuffer.Resize(frontHeadroom, 0)
destination, err = source.ReadPacket(readBuffer)
buffer.Resize(frontHeadroom, 0)
buffer.Reserve(rearHeadroom)
destination, err = source.ReadPacket(buffer)
if err != nil {
buffer.Release()
if !notFirstTime {
err = N.HandshakeFailure(destinationConn, err)
}
return
}
dataLen := readBuffer.Len()
buffer.Resize(readBuffer.Start(), dataLen)
dataLen := buffer.Len()
buffer.OverCap(rearHeadroom)
err = destinationConn.WritePacket(buffer, destination)
if err != nil {
buffer.Release()
buffer.Leak()
if !notFirstTime {
err = N.ReportHandshakeFailure(originSource, err)
}
return
}
n += int64(dataLen)
@ -410,24 +289,28 @@ func CopyPacketWithPool(destinationConn N.PacketWriter, source N.PacketReader, r
}
}
func WritePacketWithPool(destinationConn N.PacketWriter, packetBuffers []*N.PacketBuffer) (n int64, err error) {
func WritePacketWithPool(originSource N.PacketReader, destinationConn N.PacketWriter, packetBuffers []*N.PacketBuffer) (n int64, err error) {
frontHeadroom := N.CalculateFrontHeadroom(destinationConn)
rearHeadroom := N.CalculateRearHeadroom(destinationConn)
var notFirstTime bool
for _, packetBuffer := range packetBuffers {
buffer := buf.NewPacket()
readBufferRaw := buffer.Slice()
readBuffer := buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom])
readBuffer.Resize(frontHeadroom, 0)
_, err = readBuffer.Write(packetBuffer.Buffer.Bytes())
buffer.Resize(frontHeadroom, 0)
buffer.Reserve(rearHeadroom)
_, err = buffer.Write(packetBuffer.Buffer.Bytes())
packetBuffer.Buffer.Release()
if err != nil {
buffer.Release()
continue
}
dataLen := readBuffer.Len()
buffer.Resize(readBuffer.Start(), dataLen)
dataLen := buffer.Len()
buffer.OverCap(rearHeadroom)
err = destinationConn.WritePacket(buffer, packetBuffer.Destination)
if err != nil {
buffer.Release()
buffer.Leak()
if !notFirstTime {
err = N.ReportHandshakeFailure(originSource, err)
}
return
}
n += int64(dataLen)

View file

@ -1,12 +1,16 @@
package bufio
import (
"errors"
"io"
"syscall"
"github.com/sagernet/sing/common/buf"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
func CopyDirect(source syscall.Conn, destination syscall.Conn, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handed bool, n int64, err error) {
func copyDirect(source syscall.Conn, destination syscall.Conn, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handed bool, n int64, err error) {
rawSource, err := source.SyscallConn()
if err != nil {
return
@ -18,3 +22,69 @@ func CopyDirect(source syscall.Conn, destination syscall.Conn, readCounters []N.
handed, n, err = splice(rawSource, rawDestination, readCounters, writeCounters)
return
}
func copyWaitWithPool(originSource io.Reader, destination N.ExtendedWriter, source N.ReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handled bool, n int64, err error) {
handled = true
var (
buffer *buf.Buffer
notFirstTime bool
)
for {
buffer, err = source.WaitReadBuffer()
if err != nil {
if errors.Is(err, io.EOF) {
err = nil
return
}
return
}
dataLen := buffer.Len()
err = destination.WriteBuffer(buffer)
if err != nil {
buffer.Leak()
if !notFirstTime {
err = N.ReportHandshakeFailure(originSource, err)
}
return
}
n += int64(dataLen)
for _, counter := range readCounters {
counter(int64(dataLen))
}
for _, counter := range writeCounters {
counter(int64(dataLen))
}
notFirstTime = true
}
}
func copyPacketWaitWithPool(originSource N.PacketReader, destinationConn N.PacketWriter, source N.PacketReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (handled bool, n int64, err error) {
handled = true
var (
buffer *buf.Buffer
destination M.Socksaddr
)
for {
buffer, destination, err = source.WaitReadPacket()
if err != nil {
return
}
dataLen := buffer.Len()
err = destinationConn.WritePacket(buffer, destination)
if err != nil {
buffer.Leak()
if !notFirstTime {
err = N.ReportHandshakeFailure(originSource, err)
}
return
}
n += int64(dataLen)
for _, counter := range readCounters {
counter(int64(dataLen))
}
for _, counter := range writeCounters {
counter(int64(dataLen))
}
notFirstTime = true
}
}

View file

@ -3,7 +3,6 @@
package bufio
import (
"errors"
"io"
"net/netip"
"os"
@ -15,115 +14,14 @@ import (
N "github.com/sagernet/sing/common/network"
)
func copyWaitWithPool(originDestination io.Writer, destination N.ExtendedWriter, source N.ReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handled bool, n int64, err error) {
handled = true
frontHeadroom := N.CalculateFrontHeadroom(destination)
rearHeadroom := N.CalculateRearHeadroom(destination)
bufferSize := N.CalculateMTU(source, destination)
if bufferSize > 0 {
bufferSize += frontHeadroom + rearHeadroom
} else {
bufferSize = buf.BufferSize
}
var (
buffer *buf.Buffer
readBuffer *buf.Buffer
notFirstTime bool
)
source.InitializeReadWaiter(func() *buf.Buffer {
buffer = buf.NewSize(bufferSize)
readBufferRaw := buffer.Slice()
readBuffer = buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom])
readBuffer.Resize(frontHeadroom, 0)
return readBuffer
})
defer source.InitializeReadWaiter(nil)
for {
err = source.WaitReadBuffer()
if err != nil {
if errors.Is(err, io.EOF) {
err = nil
return
}
if !notFirstTime {
err = N.HandshakeFailure(originDestination, err)
}
return
}
dataLen := readBuffer.Len()
buffer.Resize(readBuffer.Start(), dataLen)
err = destination.WriteBuffer(buffer)
if err != nil {
buffer.Release()
return
}
n += int64(dataLen)
for _, counter := range readCounters {
counter(int64(dataLen))
}
for _, counter := range writeCounters {
counter(int64(dataLen))
}
notFirstTime = true
}
}
func copyPacketWaitWithPool(destinationConn N.PacketWriter, source N.PacketReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handled bool, n int64, err error) {
handled = true
frontHeadroom := N.CalculateFrontHeadroom(destinationConn)
rearHeadroom := N.CalculateRearHeadroom(destinationConn)
bufferSize := N.CalculateMTU(source, destinationConn)
if bufferSize > 0 {
bufferSize += frontHeadroom + rearHeadroom
} else {
bufferSize = buf.UDPBufferSize
}
var (
buffer *buf.Buffer
readBuffer *buf.Buffer
destination M.Socksaddr
notFirstTime bool
)
source.InitializeReadWaiter(func() *buf.Buffer {
buffer = buf.NewSize(bufferSize)
readBufferRaw := buffer.Slice()
readBuffer = buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom])
readBuffer.Resize(frontHeadroom, 0)
return readBuffer
})
defer source.InitializeReadWaiter(nil)
for {
destination, err = source.WaitReadPacket()
if err != nil {
if !notFirstTime {
err = N.HandshakeFailure(destinationConn, err)
}
return
}
dataLen := readBuffer.Len()
buffer.Resize(readBuffer.Start(), dataLen)
err = destinationConn.WritePacket(buffer, destination)
if err != nil {
buffer.Release()
return
}
n += int64(dataLen)
for _, counter := range readCounters {
counter(int64(dataLen))
}
for _, counter := range writeCounters {
counter(int64(dataLen))
}
notFirstTime = true
}
}
var _ N.ReadWaiter = (*syscallReadWaiter)(nil)
type syscallReadWaiter struct {
rawConn syscall.RawConn
readErr error
readFunc func(fd uintptr) (done bool)
buffer *buf.Buffer
options N.ReadWaitOptions
}
func createSyscallReadWaiter(reader any) (*syscallReadWaiter, bool) {
@ -136,47 +34,48 @@ func createSyscallReadWaiter(reader any) (*syscallReadWaiter, bool) {
return nil, false
}
func (w *syscallReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buffer) {
w.readErr = nil
if newBuffer == nil {
w.readFunc = nil
} else {
func (w *syscallReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
w.options = options
w.readFunc = func(fd uintptr) (done bool) {
buffer := newBuffer()
buffer := w.options.NewBuffer()
var readN int
readN, w.readErr = syscall.Read(int(fd), buffer.FreeBytes())
if readN > 0 {
buffer.Truncate(readN)
w.options.PostReturn(buffer)
w.buffer = buffer
} else {
buffer.Release()
buffer = nil
}
//goland:noinspection GoDirectComparisonOfErrors
if w.readErr == syscall.EAGAIN {
return false
}
if readN == 0 {
if readN == 0 && w.readErr == nil {
w.readErr = io.EOF
}
return true
}
}
return false
}
func (w *syscallReadWaiter) WaitReadBuffer() error {
func (w *syscallReadWaiter) WaitReadBuffer() (buffer *buf.Buffer, err error) {
if w.readFunc == nil {
return os.ErrInvalid
return nil, os.ErrInvalid
}
err := w.rawConn.Read(w.readFunc)
err = w.rawConn.Read(w.readFunc)
if err != nil {
return err
return
}
if w.readErr != nil {
if w.readErr == io.EOF {
return io.EOF
return nil, io.EOF
}
return E.Cause(w.readErr, "raw read")
return nil, E.Cause(w.readErr, "raw read")
}
return nil
buffer = w.buffer
w.buffer = nil
return
}
var _ N.PacketReadWaiter = (*syscallPacketReadWaiter)(nil)
@ -186,6 +85,8 @@ type syscallPacketReadWaiter struct {
readErr error
readFrom M.Socksaddr
readFunc func(fd uintptr) (done bool)
buffer *buf.Buffer
options N.ReadWaitOptions
}
func createSyscallPacketReadWaiter(reader any) (*syscallPacketReadWaiter, bool) {
@ -198,42 +99,37 @@ func createSyscallPacketReadWaiter(reader any) (*syscallPacketReadWaiter, bool)
return nil, false
}
func (w *syscallPacketReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buffer) {
w.readErr = nil
w.readFrom = M.Socksaddr{}
if newBuffer == nil {
w.readFunc = nil
} else {
func (w *syscallPacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
w.options = options
w.readFunc = func(fd uintptr) (done bool) {
buffer := newBuffer()
buffer := w.options.NewPacketBuffer()
var readN int
var from syscall.Sockaddr
readN, _, _, from, w.readErr = syscall.Recvmsg(int(fd), buffer.FreeBytes(), nil, 0)
//goland:noinspection GoDirectComparisonOfErrors
if w.readErr != nil {
buffer.Release()
return w.readErr != syscall.EAGAIN
}
if readN > 0 {
buffer.Truncate(readN)
} else {
buffer.Release()
buffer = nil
}
if w.readErr == syscall.EAGAIN {
return false
}
if from != nil {
w.options.PostReturn(buffer)
w.buffer = buffer
switch fromAddr := from.(type) {
case *syscall.SockaddrInet4:
w.readFrom = M.SocksaddrFrom(netip.AddrFrom4(fromAddr.Addr), uint16(fromAddr.Port))
case *syscall.SockaddrInet6:
w.readFrom = M.SocksaddrFrom(netip.AddrFrom16(fromAddr.Addr), uint16(fromAddr.Port))
}
w.readFrom = M.SocksaddrFrom(netip.AddrFrom16(fromAddr.Addr), uint16(fromAddr.Port)).Unwrap()
}
return true
}
}
return false
}
func (w *syscallPacketReadWaiter) WaitReadPacket() (destination M.Socksaddr, err error) {
func (w *syscallPacketReadWaiter) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
if w.readFunc == nil {
return M.Socksaddr{}, os.ErrInvalid
return nil, M.Socksaddr{}, os.ErrInvalid
}
err = w.rawConn.Read(w.readFunc)
if err != nil {
@ -243,6 +139,8 @@ func (w *syscallPacketReadWaiter) WaitReadPacket() (destination M.Socksaddr, err
err = E.Cause(w.readErr, "raw read")
return
}
buffer = w.buffer
w.buffer = nil
destination = w.readFrom
return
}

View file

@ -0,0 +1,77 @@
package bufio
import (
"net"
"testing"
"github.com/sagernet/sing/common/buf"
N "github.com/sagernet/sing/common/network"
"github.com/stretchr/testify/require"
)
func TestCopyWaitTCP(t *testing.T) {
t.Parallel()
inputConn, outputConn := TCPPipe(t)
readWaiter, created := CreateReadWaiter(outputConn)
require.True(t, created)
require.NotNil(t, readWaiter)
readWaiter.InitializeReadWaiter(N.ReadWaitOptions{})
require.NoError(t, TCPTest(t, inputConn, &readWaitWrapper{
Conn: outputConn,
readWaiter: readWaiter,
}))
}
type readWaitWrapper struct {
net.Conn
readWaiter N.ReadWaiter
buffer *buf.Buffer
}
func (r *readWaitWrapper) Read(p []byte) (n int, err error) {
if r.buffer != nil {
if r.buffer.Len() > 0 {
return r.buffer.Read(p)
}
if r.buffer.IsEmpty() {
r.buffer.Release()
r.buffer = nil
}
}
buffer, err := r.readWaiter.WaitReadBuffer()
if err != nil {
return
}
r.buffer = buffer
return r.buffer.Read(p)
}
func TestCopyWaitUDP(t *testing.T) {
t.Parallel()
inputConn, outputConn, outputAddr := UDPPipe(t)
readWaiter, created := CreatePacketReadWaiter(NewPacketConn(outputConn))
require.True(t, created)
require.NotNil(t, readWaiter)
readWaiter.InitializeReadWaiter(N.ReadWaitOptions{})
require.NoError(t, UDPTest(t, inputConn, &packetReadWaitWrapper{
PacketConn: outputConn,
readWaiter: readWaiter,
}, outputAddr))
}
type packetReadWaitWrapper struct {
net.PacketConn
readWaiter N.PacketReadWaiter
}
func (r *packetReadWaitWrapper) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
buffer, destination, err := r.readWaiter.WaitReadPacket()
if err != nil {
return
}
n = copy(p, buffer.Bytes())
buffer.Release()
addr = destination.UDPAddr()
return
}

View file

@ -2,22 +2,206 @@ package bufio
import (
"io"
"net/netip"
"os"
"syscall"
"unsafe"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"golang.org/x/sys/windows"
)
func copyWaitWithPool(originDestination io.Writer, destination N.ExtendedWriter, source N.ReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handled bool, n int64, err error) {
var modws2_32 = windows.NewLazySystemDLL("ws2_32.dll")
var procrecv = modws2_32.NewProc("recv")
// Do the interface allocations only once for common
// Errno values.
const (
errnoERROR_IO_PENDING = 997
)
var (
errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
errERROR_EINVAL error = syscall.EINVAL
)
// errnoErr returns common boxed Errno values, to prevent
// allocations at runtime.
func errnoErr(e syscall.Errno) error {
switch e {
case 0:
return errERROR_EINVAL
case errnoERROR_IO_PENDING:
return errERROR_IO_PENDING
}
// TODO: add more here, after collecting data on the common
// error values see on Windows. (perhaps when running
// all.bat?)
return e
}
func recv(s windows.Handle, buf []byte, flags int32) (n int32, err error) {
var _p0 *byte
if len(buf) > 0 {
_p0 = &buf[0]
}
r0, _, e1 := syscall.SyscallN(procrecv.Addr(), uintptr(s), uintptr(unsafe.Pointer(_p0)), uintptr(len(buf)), uintptr(flags))
n = int32(r0)
if n == -1 {
err = errnoErr(e1)
}
return
}
func copyPacketWaitWithPool(destinationConn N.PacketWriter, source N.PacketReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handled bool, n int64, err error) {
var _ N.ReadWaiter = (*syscallReadWaiter)(nil)
type syscallReadWaiter struct {
rawConn syscall.RawConn
readErr error
readFunc func(fd uintptr) (done bool)
hasData bool
buffer *buf.Buffer
options N.ReadWaitOptions
}
func createSyscallReadWaiter(reader any) (*syscallReadWaiter, bool) {
if syscallConn, isSyscallConn := reader.(syscall.Conn); isSyscallConn {
rawConn, err := syscallConn.SyscallConn()
if err == nil {
return &syscallReadWaiter{rawConn: rawConn}, true
}
}
return nil, false
}
func (w *syscallReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
w.options = options
w.readFunc = func(fd uintptr) (done bool) {
if !w.hasData {
w.hasData = true
// golang's internal/poll.FD.RawRead will Use a zero-byte read as a way to get notified when this
// socket is readable if we return false. So the `recv` syscall will not block the system thread.
return false
}
buffer := w.options.NewBuffer()
var readN int32
readN, w.readErr = recv(windows.Handle(fd), buffer.FreeBytes(), 0)
if readN > 0 {
buffer.Truncate(int(readN))
w.options.PostReturn(buffer)
w.buffer = buffer
} else {
buffer.Release()
}
if w.readErr == windows.WSAEWOULDBLOCK {
return false
}
if readN == 0 && w.readErr == nil {
w.readErr = io.EOF
}
w.hasData = false
return true
}
return false
}
func (w *syscallReadWaiter) WaitReadBuffer() (buffer *buf.Buffer, err error) {
if w.readFunc == nil {
return nil, os.ErrInvalid
}
err = w.rawConn.Read(w.readFunc)
if err != nil {
return
}
if w.readErr != nil {
if w.readErr == io.EOF {
return nil, io.EOF
}
return nil, E.Cause(w.readErr, "raw read")
}
buffer = w.buffer
w.buffer = nil
return
}
func createSyscallReadWaiter(reader any) (N.ReadWaiter, bool) {
var _ N.PacketReadWaiter = (*syscallPacketReadWaiter)(nil)
type syscallPacketReadWaiter struct {
rawConn syscall.RawConn
readErr error
readFrom M.Socksaddr
readFunc func(fd uintptr) (done bool)
hasData bool
buffer *buf.Buffer
options N.ReadWaitOptions
}
func createSyscallPacketReadWaiter(reader any) (*syscallPacketReadWaiter, bool) {
if syscallConn, isSyscallConn := reader.(syscall.Conn); isSyscallConn {
rawConn, err := syscallConn.SyscallConn()
if err == nil {
return &syscallPacketReadWaiter{rawConn: rawConn}, true
}
}
return nil, false
}
func createSyscallPacketReadWaiter(reader any) (N.PacketReadWaiter, bool) {
return nil, false
func (w *syscallPacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
w.options = options
w.readFunc = func(fd uintptr) (done bool) {
if !w.hasData {
w.hasData = true
// golang's internal/poll.FD.RawRead will Use a zero-byte read as a way to get notified when this
// socket is readable if we return false. So the `recvfrom` syscall will not block the system thread.
return false
}
buffer := w.options.NewPacketBuffer()
var readN int
var from windows.Sockaddr
readN, from, w.readErr = windows.Recvfrom(windows.Handle(fd), buffer.FreeBytes(), 0)
if readN > 0 {
buffer.Truncate(readN)
w.options.PostReturn(buffer)
w.buffer = buffer
} else {
buffer.Release()
}
if w.readErr == windows.WSAEWOULDBLOCK {
return false
}
if from != nil {
switch fromAddr := from.(type) {
case *windows.SockaddrInet4:
w.readFrom = M.SocksaddrFrom(netip.AddrFrom4(fromAddr.Addr), uint16(fromAddr.Port))
case *windows.SockaddrInet6:
w.readFrom = M.SocksaddrFrom(netip.AddrFrom16(fromAddr.Addr), uint16(fromAddr.Port)).Unwrap()
}
}
w.hasData = false
return true
}
return false
}
func (w *syscallPacketReadWaiter) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
if w.readFunc == nil {
return nil, M.Socksaddr{}, os.ErrInvalid
}
err = w.rawConn.Read(w.readFunc)
if err != nil {
return
}
if w.readErr != nil {
err = E.Cause(w.readErr, "raw read")
return
}
buffer = w.buffer
w.buffer = nil
destination = w.readFrom
return
}

View file

@ -14,18 +14,18 @@ type Conn struct {
reader Reader
}
func NewConn(conn net.Conn) *Conn {
func NewConn(conn net.Conn) N.ExtendedConn {
if deadlineConn, isDeadline := conn.(*Conn); isDeadline {
return deadlineConn
}
return &Conn{ExtendedConn: bufio.NewExtendedConn(conn), reader: NewReader(conn)}
return NewSerialConn(&Conn{ExtendedConn: bufio.NewExtendedConn(conn), reader: NewReader(conn)})
}
func NewFallbackConn(conn net.Conn) *Conn {
func NewFallbackConn(conn net.Conn) N.ExtendedConn {
if deadlineConn, isDeadline := conn.(*Conn); isDeadline {
return deadlineConn
}
return &Conn{ExtendedConn: bufio.NewExtendedConn(conn), reader: NewFallbackReader(conn)}
return NewSerialConn(&Conn{ExtendedConn: bufio.NewExtendedConn(conn), reader: NewFallbackReader(conn)})
}
func (c *Conn) Read(p []byte) (n int, err error) {

View file

@ -14,18 +14,18 @@ type PacketConn struct {
reader PacketReader
}
func NewPacketConn(conn N.NetPacketConn) *PacketConn {
func NewPacketConn(conn N.NetPacketConn) N.NetPacketConn {
if deadlineConn, isDeadline := conn.(*PacketConn); isDeadline {
return deadlineConn
}
return &PacketConn{NetPacketConn: conn, reader: NewPacketReader(conn)}
return NewSerialPacketConn(&PacketConn{NetPacketConn: conn, reader: NewPacketReader(conn)})
}
func NewFallbackPacketConn(conn N.NetPacketConn) *PacketConn {
func NewFallbackPacketConn(conn N.NetPacketConn) N.NetPacketConn {
if deadlineConn, isDeadline := conn.(*PacketConn); isDeadline {
return deadlineConn
}
return &PacketConn{NetPacketConn: conn, reader: NewFallbackPacketReader(conn)}
return NewSerialPacketConn(&PacketConn{NetPacketConn: conn, reader: NewFallbackPacketReader(conn)})
}
func (c *PacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {

View file

@ -52,14 +52,13 @@ func (r *packetReader) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
default:
}
select {
case result := <-r.result:
return r.pipeReturnFrom(result, p)
case <-r.pipeDeadline.wait():
return 0, nil, os.ErrDeadlineExceeded
case <-r.done:
go r.pipeReadFrom(len(p))
default:
}
return r.readFrom(p)
}
func (r *packetReader) readFrom(p []byte) (n int, addr net.Addr, err error) {
select {
case result := <-r.result:
return r.pipeReturnFrom(result, p)
@ -106,14 +105,13 @@ func (r *packetReader) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr,
default:
}
select {
case result := <-r.result:
return r.pipeReturnFromBuffer(result, buffer)
case <-r.pipeDeadline.wait():
return M.Socksaddr{}, os.ErrDeadlineExceeded
case <-r.done:
go r.pipeReadFromBuffer(buffer.FreeLen())
default:
go r.pipeReadFrom(buffer.FreeLen())
}
return r.readPacket(buffer)
}
func (r *packetReader) readPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
select {
case result := <-r.result:
return r.pipeReturnFromBuffer(result, buffer)
@ -134,17 +132,6 @@ func (r *packetReader) pipeReturnFromBuffer(result *packetReadResult, buffer *bu
}
}
func (r *packetReader) pipeReadFromBuffer(pLen int) {
buffer := buf.NewSize(pLen)
destination, err := r.TimeoutPacketReader.ReadPacket(buffer)
r.result <- &packetReadResult{
buffer: buffer,
destination: destination,
err: err,
}
r.done <- struct{}{}
}
func (r *packetReader) SetReadDeadline(t time.Time) error {
r.deadline.Store(t)
r.pipeDeadline.set(t)

View file

@ -2,6 +2,7 @@ package deadline
import (
"net"
"os"
"time"
"github.com/sagernet/sing/common/atomic"
@ -25,12 +26,15 @@ func (r *fallbackPacketReader) ReadFrom(p []byte) (n int, addr net.Addr, err err
return r.pipeReturnFrom(result, p)
default:
}
select {
case result := <-r.result:
return r.pipeReturnFrom(result, p)
case <-r.pipeDeadline.wait():
return 0, nil, os.ErrDeadlineExceeded
case <-r.done:
if r.disablePipe.Load() {
return r.TimeoutPacketReader.ReadFrom(p)
}
select {
case <-r.done:
if r.deadline.Load().IsZero() {
} else if r.deadline.Load().IsZero() {
r.done <- struct{}{}
r.inRead.Store(true)
defer r.inRead.Store(false)
@ -38,9 +42,13 @@ func (r *fallbackPacketReader) ReadFrom(p []byte) (n int, addr net.Addr, err err
return
}
go r.pipeReadFrom(len(p))
default:
}
return r.readFrom(p)
select {
case result := <-r.result:
return r.pipeReturnFrom(result, p)
case <-r.pipeDeadline.wait():
return 0, nil, os.ErrDeadlineExceeded
}
}
func (r *fallbackPacketReader) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
@ -49,22 +57,29 @@ func (r *fallbackPacketReader) ReadPacket(buffer *buf.Buffer) (destination M.Soc
return r.pipeReturnFromBuffer(result, buffer)
default:
}
select {
case result := <-r.result:
return r.pipeReturnFromBuffer(result, buffer)
case <-r.pipeDeadline.wait():
return M.Socksaddr{}, os.ErrDeadlineExceeded
case <-r.done:
if r.disablePipe.Load() {
return r.TimeoutPacketReader.ReadPacket(buffer)
}
select {
case <-r.done:
if r.deadline.Load().IsZero() {
} else if r.deadline.Load().IsZero() {
r.done <- struct{}{}
r.inRead.Store(true)
defer r.inRead.Store(false)
destination, err = r.TimeoutPacketReader.ReadPacket(buffer)
return
}
go r.pipeReadFromBuffer(buffer.FreeLen())
default:
go r.pipeReadFrom(buffer.FreeLen())
}
select {
case result := <-r.result:
return r.pipeReturnFromBuffer(result, buffer)
case <-r.pipeDeadline.wait():
return M.Socksaddr{}, os.ErrDeadlineExceeded
}
return r.readPacket(buffer)
}
func (r *fallbackPacketReader) SetReadDeadline(t time.Time) error {

View file

@ -54,14 +54,13 @@ func (r *reader) Read(p []byte) (n int, err error) {
default:
}
select {
case result := <-r.result:
return r.pipeReturn(result, p)
case <-r.pipeDeadline.wait():
return 0, os.ErrDeadlineExceeded
case <-r.done:
go r.pipeRead(len(p))
default:
}
return r.read(p)
}
func (r *reader) read(p []byte) (n int, err error) {
select {
case result := <-r.result:
return r.pipeReturn(result, p)
@ -99,14 +98,13 @@ func (r *reader) ReadBuffer(buffer *buf.Buffer) error {
default:
}
select {
case result := <-r.result:
return r.pipeReturnBuffer(result, buffer)
case <-r.pipeDeadline.wait():
return os.ErrDeadlineExceeded
case <-r.done:
go r.pipeReadBuffer(buffer.FreeLen())
default:
go r.pipeRead(buffer.FreeLen())
}
return r.readBuffer(buffer)
}
func (r *reader) readBuffer(buffer *buf.Buffer) error {
select {
case result := <-r.result:
return r.pipeReturnBuffer(result, buffer)
@ -127,16 +125,6 @@ func (r *reader) pipeReturnBuffer(result *readResult, buffer *buf.Buffer) error
}
}
func (r *reader) pipeReadBuffer(pLen int) {
cacheBuffer := buf.NewSize(pLen)
err := r.ExtendedReader.ReadBuffer(cacheBuffer)
r.result <- &readResult{
buffer: cacheBuffer,
err: err,
}
r.done <- struct{}{}
}
func (r *reader) SetReadDeadline(t time.Time) error {
r.deadline.Store(t)
r.pipeDeadline.set(t)

View file

@ -1,6 +1,7 @@
package deadline
import (
"os"
"time"
"github.com/sagernet/sing/common/atomic"
@ -23,12 +24,15 @@ func (r *fallbackReader) Read(p []byte) (n int, err error) {
return r.pipeReturn(result, p)
default:
}
select {
case result := <-r.result:
return r.pipeReturn(result, p)
case <-r.pipeDeadline.wait():
return 0, os.ErrDeadlineExceeded
case <-r.done:
if r.disablePipe.Load() {
return r.ExtendedReader.Read(p)
}
select {
case <-r.done:
if r.deadline.Load().IsZero() {
} else if r.deadline.Load().IsZero() {
r.done <- struct{}{}
r.inRead.Store(true)
defer r.inRead.Store(false)
@ -36,9 +40,13 @@ func (r *fallbackReader) Read(p []byte) (n int, err error) {
return
}
go r.pipeRead(len(p))
default:
}
return r.reader.read(p)
select {
case result := <-r.result:
return r.pipeReturn(result, p)
case <-r.pipeDeadline.wait():
return 0, os.ErrDeadlineExceeded
}
}
func (r *fallbackReader) ReadBuffer(buffer *buf.Buffer) error {
@ -47,21 +55,28 @@ func (r *fallbackReader) ReadBuffer(buffer *buf.Buffer) error {
return r.pipeReturnBuffer(result, buffer)
default:
}
select {
case result := <-r.result:
return r.pipeReturnBuffer(result, buffer)
case <-r.pipeDeadline.wait():
return os.ErrDeadlineExceeded
case <-r.done:
if r.disablePipe.Load() {
return r.ExtendedReader.ReadBuffer(buffer)
}
select {
case <-r.done:
if r.deadline.Load().IsZero() {
} else if r.deadline.Load().IsZero() {
r.done <- struct{}{}
r.inRead.Store(true)
defer r.inRead.Store(false)
return r.ExtendedReader.ReadBuffer(buffer)
}
go r.pipeReadBuffer(buffer.FreeLen())
default:
go r.pipeRead(buffer.FreeLen())
}
select {
case result := <-r.result:
return r.pipeReturnBuffer(result, buffer)
case <-r.pipeDeadline.wait():
return os.ErrDeadlineExceeded
}
return r.readBuffer(buffer)
}
func (r *fallbackReader) SetReadDeadline(t time.Time) error {

View file

@ -0,0 +1,75 @@
package deadline
import (
"net"
"sync"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/debug"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
type SerialConn struct {
N.ExtendedConn
access sync.Mutex
}
func NewSerialConn(conn N.ExtendedConn) N.ExtendedConn {
if !debug.Enabled {
return conn
}
return &SerialConn{ExtendedConn: conn}
}
func (c *SerialConn) Read(p []byte) (n int, err error) {
if !c.access.TryLock() {
panic("concurrent read on deadline conn")
}
defer c.access.Unlock()
return c.ExtendedConn.Read(p)
}
func (c *SerialConn) ReadBuffer(buffer *buf.Buffer) error {
if !c.access.TryLock() {
panic("concurrent read on deadline conn")
}
defer c.access.Unlock()
return c.ExtendedConn.ReadBuffer(buffer)
}
func (c *SerialConn) Upstream() any {
return c.ExtendedConn
}
type SerialPacketConn struct {
N.NetPacketConn
access sync.Mutex
}
func NewSerialPacketConn(conn N.NetPacketConn) N.NetPacketConn {
if !debug.Enabled {
return conn
}
return &SerialPacketConn{NetPacketConn: conn}
}
func (c *SerialPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
if !c.access.TryLock() {
panic("concurrent read on deadline conn")
}
defer c.access.Unlock()
return c.NetPacketConn.ReadFrom(p)
}
func (c *SerialPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
if !c.access.TryLock() {
panic("concurrent read on deadline conn")
}
defer c.access.Unlock()
return c.NetPacketConn.ReadPacket(buffer)
}
func (c *SerialPacketConn) Upstream() any {
return c.NetPacketConn
}

View file

@ -3,6 +3,7 @@ package bufio
import (
"net"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
@ -12,13 +13,17 @@ var _ N.NetPacketConn = (*FallbackPacketConn)(nil)
type FallbackPacketConn struct {
N.PacketConn
writer N.NetPacketWriter
}
func NewNetPacketConn(conn N.PacketConn) N.NetPacketConn {
if packetConn, loaded := conn.(N.NetPacketConn); loaded {
return packetConn
}
return &FallbackPacketConn{PacketConn: conn}
return &FallbackPacketConn{
PacketConn: conn,
writer: NewNetPacketWriter(conn),
}
}
func (c *FallbackPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
@ -36,11 +41,7 @@ func (c *FallbackPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error
}
func (c *FallbackPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
err = c.WritePacket(buf.As(p), M.SocksaddrFromNet(addr))
if err == nil {
n = len(p)
}
return
return c.writer.WriteTo(p, addr)
}
func (c *FallbackPacketConn) ReaderReplaceable() bool {
@ -54,3 +55,50 @@ func (c *FallbackPacketConn) WriterReplaceable() bool {
func (c *FallbackPacketConn) Upstream() any {
return c.PacketConn
}
func (c *FallbackPacketConn) UpstreamWriter() any {
return c.writer
}
var _ N.NetPacketWriter = (*FallbackPacketWriter)(nil)
type FallbackPacketWriter struct {
N.PacketWriter
frontHeadroom int
rearHeadroom int
}
func NewNetPacketWriter(writer N.PacketWriter) N.NetPacketWriter {
if packetWriter, loaded := writer.(N.NetPacketWriter); loaded {
return packetWriter
}
return &FallbackPacketWriter{
PacketWriter: writer,
frontHeadroom: N.CalculateFrontHeadroom(writer),
rearHeadroom: N.CalculateRearHeadroom(writer),
}
}
func (c *FallbackPacketWriter) WriteTo(p []byte, addr net.Addr) (n int, err error) {
if c.frontHeadroom > 0 || c.rearHeadroom > 0 {
buffer := buf.NewSize(len(p) + c.frontHeadroom + c.rearHeadroom)
buffer.Resize(c.frontHeadroom, 0)
common.Must1(buffer.Write(p))
err = c.PacketWriter.WritePacket(buffer, M.SocksaddrFromNet(addr))
} else {
err = c.PacketWriter.WritePacket(buf.As(p), M.SocksaddrFromNet(addr))
}
if err != nil {
return
}
n = len(p)
return
}
func (c *FallbackPacketWriter) WriterReplaceable() bool {
return true
}
func (c *FallbackPacketWriter) Upstream() any {
return c.PacketWriter
}

View file

@ -37,13 +37,7 @@ func WriteBuffer(writer N.ExtendedWriter, buffer *buf.Buffer) (n int, err error)
frontHeadroom := N.CalculateFrontHeadroom(writer)
rearHeadroom := N.CalculateRearHeadroom(writer)
if frontHeadroom > buffer.Start() || rearHeadroom > buffer.FreeLen() {
bufferSize := N.CalculateMTU(nil, writer)
if bufferSize > 0 {
bufferSize += frontHeadroom + rearHeadroom
} else {
bufferSize = buf.BufferSize
}
newBuffer := buf.NewSize(bufferSize)
newBuffer := buf.NewSize(buffer.Len() + frontHeadroom + rearHeadroom)
newBuffer.Resize(frontHeadroom, 0)
common.Must1(newBuffer.Write(buffer.Bytes()))
buffer.Release()
@ -69,13 +63,7 @@ func WritePacketBuffer(writer N.PacketWriter, buffer *buf.Buffer, destination M.
frontHeadroom := N.CalculateFrontHeadroom(writer)
rearHeadroom := N.CalculateRearHeadroom(writer)
if frontHeadroom > buffer.Start() || rearHeadroom > buffer.FreeLen() {
bufferSize := N.CalculateMTU(nil, writer)
if bufferSize > 0 {
bufferSize += frontHeadroom + rearHeadroom
} else {
bufferSize = buf.BufferSize
}
newBuffer := buf.NewSize(bufferSize)
newBuffer := buf.NewSize(buffer.Len() + frontHeadroom + rearHeadroom)
newBuffer.Resize(frontHeadroom, 0)
common.Must1(newBuffer.Write(buffer.Bytes()))
buffer.Release()

View file

@ -9,54 +9,142 @@ import (
N "github.com/sagernet/sing/common/network"
)
type NATPacketConn struct {
type NATPacketConn interface {
N.NetPacketConn
UpdateDestination(destinationAddress netip.Addr)
}
func NewUnidirectionalNATPacketConn(conn N.NetPacketConn, origin M.Socksaddr, destination M.Socksaddr) NATPacketConn {
return &unidirectionalNATPacketConn{
NetPacketConn: conn,
origin: socksaddrWithoutPort(origin),
destination: socksaddrWithoutPort(destination),
}
}
func NewNATPacketConn(conn N.NetPacketConn, origin M.Socksaddr, destination M.Socksaddr) NATPacketConn {
return &bidirectionalNATPacketConn{
NetPacketConn: conn,
origin: socksaddrWithoutPort(origin),
destination: socksaddrWithoutPort(destination),
}
}
type unidirectionalNATPacketConn struct {
N.NetPacketConn
origin M.Socksaddr
destination M.Socksaddr
}
func NewNATPacketConn(conn N.NetPacketConn, origin M.Socksaddr, destination M.Socksaddr) *NATPacketConn {
return &NATPacketConn{
NetPacketConn: conn,
origin: origin,
destination: destination,
func (c *unidirectionalNATPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
destination := M.SocksaddrFromNet(addr)
if socksaddrWithoutPort(destination) == c.destination {
destination = M.Socksaddr{
Addr: c.origin.Addr,
Fqdn: c.origin.Fqdn,
Port: destination.Port,
}
}
return c.NetPacketConn.WriteTo(p, destination.UDPAddr())
}
func (c *NATPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
n, addr, err = c.NetPacketConn.ReadFrom(p)
if err == nil && M.SocksaddrFromNet(addr) == c.origin {
addr = c.destination.UDPAddr()
func (c *unidirectionalNATPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
if socksaddrWithoutPort(destination) == c.destination {
destination = M.Socksaddr{
Addr: c.origin.Addr,
Fqdn: c.origin.Fqdn,
Port: destination.Port,
}
return
}
func (c *NATPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
if M.SocksaddrFromNet(addr) == c.destination {
addr = c.origin.UDPAddr()
}
return c.NetPacketConn.WriteTo(p, addr)
}
func (c *NATPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
destination, err = c.NetPacketConn.ReadPacket(buffer)
if destination == c.origin {
destination = c.destination
}
return
}
func (c *NATPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
if destination == c.destination {
destination = c.origin
}
return c.NetPacketConn.WritePacket(buffer, destination)
}
func (c *NATPacketConn) UpdateDestination(destinationAddress netip.Addr) {
func (c *unidirectionalNATPacketConn) UpdateDestination(destinationAddress netip.Addr) {
c.destination = M.SocksaddrFrom(destinationAddress, c.destination.Port)
}
func (c *NATPacketConn) Upstream() any {
func (c *unidirectionalNATPacketConn) RemoteAddr() net.Addr {
return c.destination.UDPAddr()
}
func (c *unidirectionalNATPacketConn) Upstream() any {
return c.NetPacketConn
}
type bidirectionalNATPacketConn struct {
N.NetPacketConn
origin M.Socksaddr
destination M.Socksaddr
}
func (c *bidirectionalNATPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
n, addr, err = c.NetPacketConn.ReadFrom(p)
if err != nil {
return
}
destination := M.SocksaddrFromNet(addr)
if socksaddrWithoutPort(destination) == c.origin {
destination = M.Socksaddr{
Addr: c.destination.Addr,
Fqdn: c.destination.Fqdn,
Port: destination.Port,
}
}
addr = destination.UDPAddr()
return
}
func (c *bidirectionalNATPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
destination := M.SocksaddrFromNet(addr)
if socksaddrWithoutPort(destination) == c.destination {
destination = M.Socksaddr{
Addr: c.origin.Addr,
Fqdn: c.origin.Fqdn,
Port: destination.Port,
}
}
return c.NetPacketConn.WriteTo(p, destination.UDPAddr())
}
func (c *bidirectionalNATPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
destination, err = c.NetPacketConn.ReadPacket(buffer)
if err != nil {
return
}
if socksaddrWithoutPort(destination) == c.origin {
destination = M.Socksaddr{
Addr: c.destination.Addr,
Fqdn: c.destination.Fqdn,
Port: destination.Port,
}
}
return
}
func (c *bidirectionalNATPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
if socksaddrWithoutPort(destination) == c.destination {
destination = M.Socksaddr{
Addr: c.origin.Addr,
Fqdn: c.origin.Fqdn,
Port: destination.Port,
}
}
return c.NetPacketConn.WritePacket(buffer, destination)
}
func (c *bidirectionalNATPacketConn) UpdateDestination(destinationAddress netip.Addr) {
c.destination = M.SocksaddrFrom(destinationAddress, c.destination.Port)
}
func (c *bidirectionalNATPacketConn) Upstream() any {
return c.NetPacketConn
}
func (c *bidirectionalNATPacketConn) RemoteAddr() net.Addr {
return c.destination.UDPAddr()
}
func socksaddrWithoutPort(destination M.Socksaddr) M.Socksaddr {
destination.Port = 0
return destination
}

39
common/bufio/nat_wait.go Normal file
View file

@ -0,0 +1,39 @@
package bufio
import (
"github.com/sagernet/sing/common/buf"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
func (c *bidirectionalNATPacketConn) CreatePacketReadWaiter() (N.PacketReadWaiter, bool) {
waiter, created := CreatePacketReadWaiter(c.NetPacketConn)
if !created {
return nil, false
}
return &waitBidirectionalNATPacketConn{c, waiter}, true
}
type waitBidirectionalNATPacketConn struct {
*bidirectionalNATPacketConn
readWaiter N.PacketReadWaiter
}
func (c *waitBidirectionalNATPacketConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
return c.readWaiter.InitializeReadWaiter(options)
}
func (c *waitBidirectionalNATPacketConn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
buffer, destination, err = c.readWaiter.WaitReadPacket()
if err != nil {
return
}
if socksaddrWithoutPort(destination) == c.origin {
destination = M.Socksaddr{
Addr: c.destination.Addr,
Fqdn: c.destination.Fqdn,
Port: destination.Port,
}
}
return
}

277
common/bufio/net_test.go Normal file
View file

@ -0,0 +1,277 @@
package bufio
import (
"context"
"crypto/md5"
"crypto/rand"
"errors"
"io"
"net"
"sync"
"testing"
"time"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/common/task"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TCPPipe(t *testing.T) (net.Conn, net.Conn) {
listener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
var (
group task.Group
serverConn net.Conn
clientConn net.Conn
)
group.Append0(func(ctx context.Context) error {
var serverErr error
serverConn, serverErr = listener.Accept()
return serverErr
})
group.Append0(func(ctx context.Context) error {
var clientErr error
clientConn, clientErr = net.Dial("tcp", listener.Addr().String())
return clientErr
})
err = group.Run()
require.NoError(t, err)
listener.Close()
t.Cleanup(func() {
serverConn.Close()
clientConn.Close()
})
return serverConn, clientConn
}
func UDPPipe(t *testing.T) (net.PacketConn, net.PacketConn, M.Socksaddr) {
serverConn, err := net.ListenPacket("udp", "127.0.0.1:0")
require.NoError(t, err)
clientConn, err := net.ListenPacket("udp", "127.0.0.1:0")
require.NoError(t, err)
return serverConn, clientConn, M.SocksaddrFromNet(clientConn.LocalAddr())
}
func Timeout(t *testing.T) context.CancelFunc {
ctx, cancel := context.WithCancel(context.Background())
go func() {
select {
case <-ctx.Done():
return
case <-time.After(5 * time.Second):
t.Error("timeout")
}
}()
return cancel
}
type hashPair struct {
sendHash map[int][]byte
recvHash map[int][]byte
}
func newLargeDataPair() (chan hashPair, chan hashPair, func(t *testing.T) error) {
pingCh := make(chan hashPair)
pongCh := make(chan hashPair)
test := func(t *testing.T) error {
defer close(pingCh)
defer close(pongCh)
pingOpen := false
pongOpen := false
var serverPair hashPair
var clientPair hashPair
for {
if pingOpen && pongOpen {
break
}
select {
case serverPair, pingOpen = <-pingCh:
assert.True(t, pingOpen)
case clientPair, pongOpen = <-pongCh:
assert.True(t, pongOpen)
case <-time.After(10 * time.Second):
return errors.New("timeout")
}
}
assert.Equal(t, serverPair.recvHash, clientPair.sendHash)
assert.Equal(t, serverPair.sendHash, clientPair.recvHash)
return nil
}
return pingCh, pongCh, test
}
func TCPTest(t *testing.T, inputConn net.Conn, outputConn net.Conn) error {
times := 100
chunkSize := int64(64 * 1024)
pingCh, pongCh, test := newLargeDataPair()
writeRandData := func(conn net.Conn) (map[int][]byte, error) {
buf := make([]byte, chunkSize)
hashMap := map[int][]byte{}
for i := 0; i < times; i++ {
if _, err := rand.Read(buf[1:]); err != nil {
return nil, err
}
buf[0] = byte(i)
hash := md5.Sum(buf)
hashMap[i] = hash[:]
if _, err := conn.Write(buf); err != nil {
return nil, err
}
}
return hashMap, nil
}
go func() {
hashMap := map[int][]byte{}
buf := make([]byte, chunkSize)
for i := 0; i < times; i++ {
_, err := io.ReadFull(outputConn, buf)
if err != nil {
t.Log(err.Error())
return
}
hash := md5.Sum(buf)
hashMap[int(buf[0])] = hash[:]
}
sendHash, err := writeRandData(outputConn)
if err != nil {
t.Log(err.Error())
return
}
pingCh <- hashPair{
sendHash: sendHash,
recvHash: hashMap,
}
}()
go func() {
sendHash, err := writeRandData(inputConn)
if err != nil {
t.Log(err.Error())
return
}
hashMap := map[int][]byte{}
buf := make([]byte, chunkSize)
for i := 0; i < times; i++ {
_, err = io.ReadFull(inputConn, buf)
if err != nil {
t.Log(err.Error())
return
}
hash := md5.Sum(buf)
hashMap[int(buf[0])] = hash[:]
}
pongCh <- hashPair{
sendHash: sendHash,
recvHash: hashMap,
}
}()
return test(t)
}
func UDPTest(t *testing.T, inputConn net.PacketConn, outputConn net.PacketConn, outputAddr M.Socksaddr) error {
rAddr := outputAddr.UDPAddr()
times := 50
chunkSize := 9000
pingCh, pongCh, test := newLargeDataPair()
writeRandData := func(pc net.PacketConn, addr net.Addr) (map[int][]byte, error) {
hashMap := map[int][]byte{}
mux := sync.Mutex{}
for i := 0; i < times; i++ {
buf := make([]byte, chunkSize)
if _, err := rand.Read(buf[1:]); err != nil {
t.Log(err.Error())
continue
}
buf[0] = byte(i)
hash := md5.Sum(buf)
mux.Lock()
hashMap[i] = hash[:]
mux.Unlock()
if _, err := pc.WriteTo(buf, addr); err != nil {
t.Log(err.Error())
}
time.Sleep(10 * time.Millisecond)
}
return hashMap, nil
}
go func() {
var (
lAddr net.Addr
err error
)
hashMap := map[int][]byte{}
buf := make([]byte, 64*1024)
for i := 0; i < times; i++ {
_, lAddr, err = outputConn.ReadFrom(buf)
if err != nil {
t.Log(err.Error())
return
}
hash := md5.Sum(buf[:chunkSize])
hashMap[int(buf[0])] = hash[:]
}
sendHash, err := writeRandData(outputConn, lAddr)
if err != nil {
t.Log(err.Error())
return
}
pingCh <- hashPair{
sendHash: sendHash,
recvHash: hashMap,
}
}()
go func() {
sendHash, err := writeRandData(inputConn, rAddr)
if err != nil {
t.Log(err.Error())
return
}
hashMap := map[int][]byte{}
buf := make([]byte, 64*1024)
for i := 0; i < times; i++ {
_, _, err := inputConn.ReadFrom(buf)
if err != nil {
t.Log(err.Error())
return
}
hash := md5.Sum(buf[:chunkSize])
hashMap[int(buf[0])] = hash[:]
}
pongCh <- hashPair{
sendHash: sendHash,
recvHash: hashMap,
}
}()
return test(t)
}

View file

@ -1,127 +0,0 @@
package bufio
import (
"io"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
N "github.com/sagernet/sing/common/network"
)
func CopyTimes(dst io.Writer, src io.Reader, times int) (n int64, err error) {
return CopyExtendedTimes(NewExtendedWriter(N.UnwrapWriter(dst)), NewExtendedReader(N.UnwrapReader(src)), times)
}
func CopyExtendedTimes(dst N.ExtendedWriter, src N.ExtendedReader, times int) (n int64, err error) {
frontHeadroom := N.CalculateFrontHeadroom(dst)
rearHeadroom := N.CalculateRearHeadroom(dst)
bufferSize := N.CalculateMTU(src, dst)
if bufferSize > 0 {
bufferSize += frontHeadroom + rearHeadroom
} else {
bufferSize = buf.BufferSize
}
dstUnsafe := N.IsUnsafeWriter(dst)
var buffer *buf.Buffer
if !dstUnsafe {
_buffer := buf.StackNewSize(bufferSize)
defer common.KeepAlive(_buffer)
buffer = common.Dup(_buffer)
defer buffer.Release()
buffer.IncRef()
defer buffer.DecRef()
}
notFirstTime := true
for i := 0; i < times; i++ {
if dstUnsafe {
buffer = buf.NewSize(bufferSize)
}
readBufferRaw := buffer.Slice()
readBuffer := buf.With(readBufferRaw[:cap(readBufferRaw)-rearHeadroom])
readBuffer.Resize(frontHeadroom, 0)
err = src.ReadBuffer(readBuffer)
if err != nil {
buffer.Release()
if !notFirstTime {
err = N.HandshakeFailure(dst, err)
}
return
}
dataLen := readBuffer.Len()
buffer.Resize(readBuffer.Start(), dataLen)
err = dst.WriteBuffer(buffer)
if err != nil {
buffer.Release()
return
}
n += int64(dataLen)
notFirstTime = true
}
return
}
type ReadFromWriter interface {
io.ReaderFrom
io.Writer
}
func ReadFrom0(readerFrom ReadFromWriter, reader io.Reader) (n int64, err error) {
n, err = CopyTimes(readerFrom, reader, 1)
if err != nil {
return
}
var rn int64
rn, err = readerFrom.ReadFrom(reader)
if err != nil {
return
}
n += rn
return
}
func ReadFromN(readerFrom ReadFromWriter, reader io.Reader, times int) (n int64, err error) {
n, err = CopyTimes(readerFrom, reader, times)
if err != nil {
return
}
var rn int64
rn, err = readerFrom.ReadFrom(reader)
if err != nil {
return
}
n += rn
return
}
type WriteToReader interface {
io.WriterTo
io.Reader
}
func WriteTo0(writerTo WriteToReader, writer io.Writer) (n int64, err error) {
n, err = CopyTimes(writer, writerTo, 1)
if err != nil {
return
}
var wn int64
wn, err = writerTo.WriteTo(writer)
if err != nil {
return
}
n += wn
return
}
func WriteToN(writerTo WriteToReader, writer io.Writer, times int) (n int64, err error) {
n, err = CopyTimes(writer, writerTo, times)
if err != nil {
return
}
var wn int64
wn, err = writerTo.WriteTo(writer)
if err != nil {
return
}
n += wn
return
}

View file

@ -33,10 +33,10 @@ func CreateVectorisedWriter(writer any) (N.VectorisedWriter, bool) {
case syscall.Conn:
rawConn, err := w.SyscallConn()
if err == nil {
return &SyscallVectorisedWriter{writer, rawConn}, true
return &SyscallVectorisedWriter{upstream: writer, rawConn: rawConn}, true
}
case syscall.RawConn:
return &SyscallVectorisedWriter{writer, w}, true
return &SyscallVectorisedWriter{upstream: writer, rawConn: w}, true
}
return nil, false
}
@ -48,10 +48,10 @@ func CreateVectorisedPacketWriter(writer any) (N.VectorisedPacketWriter, bool) {
case syscall.Conn:
rawConn, err := w.SyscallConn()
if err == nil {
return &SyscallVectorisedPacketWriter{writer, rawConn}, true
return &SyscallVectorisedPacketWriter{upstream: writer, rawConn: rawConn}, true
}
case syscall.RawConn:
return &SyscallVectorisedPacketWriter{writer, w}, true
return &SyscallVectorisedPacketWriter{upstream: writer, rawConn: w}, true
}
return nil, false
}
@ -74,9 +74,7 @@ func (w *BufferedVectorisedWriter) WriteVectorised(buffers []*buf.Buffer) error
if bufferLen > 65535 {
bufferBytes = make([]byte, bufferLen)
} else {
_buffer := buf.StackNewSize(bufferLen)
defer common.KeepAlive(_buffer)
buffer := common.Dup(_buffer)
buffer := buf.NewSize(bufferLen)
defer buffer.Release()
bufferBytes = buffer.FreeBytes()
}
@ -113,6 +111,7 @@ var _ N.VectorisedWriter = (*SyscallVectorisedWriter)(nil)
type SyscallVectorisedWriter struct {
upstream any
rawConn syscall.RawConn
syscallVectorisedWriterFields
}
func (w *SyscallVectorisedWriter) Upstream() any {
@ -128,6 +127,7 @@ var _ N.VectorisedPacketWriter = (*SyscallVectorisedPacketWriter)(nil)
type SyscallVectorisedPacketWriter struct {
upstream any
rawConn syscall.RawConn
syscallVectorisedWriterFields
}
func (w *SyscallVectorisedPacketWriter) Upstream() any {

View file

@ -0,0 +1,60 @@
package bufio
import (
"crypto/rand"
"io"
"testing"
"github.com/stretchr/testify/require"
)
func TestWriteVectorised(t *testing.T) {
t.Parallel()
inputConn, outputConn := TCPPipe(t)
vectorisedWriter, created := CreateVectorisedWriter(inputConn)
require.True(t, created)
require.NotNil(t, vectorisedWriter)
var bufA [1024]byte
var bufB [1024]byte
var bufC [2048]byte
_, err := io.ReadFull(rand.Reader, bufA[:])
require.NoError(t, err)
_, err = io.ReadFull(rand.Reader, bufB[:])
require.NoError(t, err)
copy(bufC[:], bufA[:])
copy(bufC[1024:], bufB[:])
finish := Timeout(t)
_, err = WriteVectorised(vectorisedWriter, [][]byte{bufA[:], bufB[:]})
require.NoError(t, err)
output := make([]byte, 2048)
_, err = io.ReadFull(outputConn, output)
finish()
require.NoError(t, err)
require.Equal(t, bufC[:], output)
}
func TestWriteVectorisedPacket(t *testing.T) {
t.Parallel()
inputConn, outputConn, outputAddr := UDPPipe(t)
vectorisedWriter, created := CreateVectorisedPacketWriter(inputConn)
require.True(t, created)
require.NotNil(t, vectorisedWriter)
var bufA [1024]byte
var bufB [1024]byte
var bufC [2048]byte
_, err := io.ReadFull(rand.Reader, bufA[:])
require.NoError(t, err)
_, err = io.ReadFull(rand.Reader, bufB[:])
require.NoError(t, err)
copy(bufC[:], bufA[:])
copy(bufC[1024:], bufB[:])
finish := Timeout(t)
_, err = WriteVectorisedPacket(vectorisedWriter, [][]byte{bufA[:], bufB[:]}, outputAddr)
require.NoError(t, err)
output := make([]byte, 2048)
n, _, err := outputConn.ReadFrom(output)
finish()
require.NoError(t, err)
require.Equal(t, 2048, n)
require.Equal(t, bufC[:], output)
}

View file

@ -3,6 +3,8 @@
package bufio
import (
"os"
"sync"
"unsafe"
"github.com/sagernet/sing/common/buf"
@ -11,15 +13,28 @@ import (
"golang.org/x/sys/unix"
)
type syscallVectorisedWriterFields struct {
access sync.Mutex
iovecList *[]unix.Iovec
}
func (w *SyscallVectorisedWriter) WriteVectorised(buffers []*buf.Buffer) error {
w.access.Lock()
defer w.access.Unlock()
defer buf.ReleaseMulti(buffers)
iovecList := make([]unix.Iovec, 0, len(buffers))
for _, buffer := range buffers {
var iovec unix.Iovec
iovec.Base = &buffer.Bytes()[0]
iovec.SetLen(buffer.Len())
iovecList = append(iovecList, iovec)
var iovecList []unix.Iovec
if w.iovecList != nil {
iovecList = *w.iovecList
}
iovecList = iovecList[:0]
for index, buffer := range buffers {
iovecList = append(iovecList, unix.Iovec{Base: &buffer.Bytes()[0]})
iovecList[index].SetLen(buffer.Len())
}
if w.iovecList == nil {
w.iovecList = new([]unix.Iovec)
}
*w.iovecList = iovecList // cache
var innerErr unix.Errno
err := w.rawConn.Write(func(fd uintptr) (done bool) {
//nolint:staticcheck
@ -28,32 +43,52 @@ func (w *SyscallVectorisedWriter) WriteVectorised(buffers []*buf.Buffer) error {
return innerErr != unix.EAGAIN && innerErr != unix.EWOULDBLOCK
})
if innerErr != 0 {
err = innerErr
err = os.NewSyscallError("SYS_WRITEV", innerErr)
}
for index := range iovecList {
iovecList[index] = unix.Iovec{}
}
return err
}
func (w *SyscallVectorisedPacketWriter) WriteVectorisedPacket(buffers []*buf.Buffer, destination M.Socksaddr) error {
w.access.Lock()
defer w.access.Unlock()
defer buf.ReleaseMulti(buffers)
var sockaddr unix.Sockaddr
if destination.IsIPv4() {
sockaddr = &unix.SockaddrInet4{
Port: int(destination.Port),
Addr: destination.Addr.As4(),
var iovecList []unix.Iovec
if w.iovecList != nil {
iovecList = *w.iovecList
}
} else {
sockaddr = &unix.SockaddrInet6{
Port: int(destination.Port),
Addr: destination.Addr.As16(),
iovecList = iovecList[:0]
for index, buffer := range buffers {
iovecList = append(iovecList, unix.Iovec{Base: &buffer.Bytes()[0]})
iovecList[index].SetLen(buffer.Len())
}
if w.iovecList == nil {
w.iovecList = new([]unix.Iovec)
}
*w.iovecList = iovecList // cache
var innerErr error
err := w.rawConn.Write(func(fd uintptr) (done bool) {
_, innerErr = unix.SendmsgBuffers(int(fd), buf.ToSliceMulti(buffers), nil, sockaddr, 0)
var msg unix.Msghdr
name, nameLen := ToSockaddr(destination.AddrPort())
msg.Name = (*byte)(name)
msg.Namelen = nameLen
if len(iovecList) > 0 {
msg.Iov = &iovecList[0]
msg.SetIovlen(len(iovecList))
}
_, innerErr = sendmsg(int(fd), &msg, 0)
return innerErr != unix.EAGAIN && innerErr != unix.EWOULDBLOCK
})
if innerErr != nil {
err = innerErr
}
for index := range iovecList {
iovecList[index] = unix.Iovec{}
}
return err
}
//go:linkname sendmsg golang.org/x/sys/unix.sendmsg
func sendmsg(s int, msg *unix.Msghdr, flags int) (n int, err error)

View file

@ -1,62 +1,93 @@
package bufio
import (
"sync"
"github.com/sagernet/sing/common/buf"
M "github.com/sagernet/sing/common/metadata"
"golang.org/x/sys/windows"
)
type syscallVectorisedWriterFields struct {
access sync.Mutex
iovecList *[]windows.WSABuf
}
func (w *SyscallVectorisedWriter) WriteVectorised(buffers []*buf.Buffer) error {
w.access.Lock()
defer w.access.Unlock()
defer buf.ReleaseMulti(buffers)
iovecList := make([]*windows.WSABuf, len(buffers))
for i, buffer := range buffers {
iovecList[i] = &windows.WSABuf{
Len: uint32(buffer.Len()),
var iovecList []windows.WSABuf
if w.iovecList != nil {
iovecList = *w.iovecList
}
iovecList = iovecList[:0]
for _, buffer := range buffers {
iovecList = append(iovecList, windows.WSABuf{
Buf: &buffer.Bytes()[0],
Len: uint32(buffer.Len()),
})
}
if w.iovecList == nil {
w.iovecList = new([]windows.WSABuf)
}
*w.iovecList = iovecList // cache
var n uint32
var innerErr error
err := w.rawConn.Write(func(fd uintptr) (done bool) {
innerErr = windows.WSASend(windows.Handle(fd), iovecList[0], uint32(len(iovecList)), &n, 0, nil, nil)
innerErr = windows.WSASend(windows.Handle(fd), &iovecList[0], uint32(len(iovecList)), &n, 0, nil, nil)
return innerErr != windows.WSAEWOULDBLOCK
})
if innerErr != nil {
err = innerErr
}
for index := range iovecList {
iovecList[index] = windows.WSABuf{}
}
return err
}
func (w *SyscallVectorisedPacketWriter) WriteVectorisedPacket(buffers []*buf.Buffer, destination M.Socksaddr) error {
w.access.Lock()
defer w.access.Unlock()
defer buf.ReleaseMulti(buffers)
iovecList := make([]*windows.WSABuf, len(buffers))
for i, buffer := range buffers {
iovecList[i] = &windows.WSABuf{
Len: uint32(buffer.Len()),
var iovecList []windows.WSABuf
if w.iovecList != nil {
iovecList = *w.iovecList
}
iovecList = iovecList[:0]
for _, buffer := range buffers {
iovecList = append(iovecList, windows.WSABuf{
Buf: &buffer.Bytes()[0],
Len: uint32(buffer.Len()),
})
}
if w.iovecList == nil {
w.iovecList = new([]windows.WSABuf)
}
var sockaddr windows.Sockaddr
if destination.IsIPv4() {
sockaddr = &windows.SockaddrInet4{
Port: int(destination.Port),
Addr: destination.Addr.As4(),
}
} else {
sockaddr = &windows.SockaddrInet6{
Port: int(destination.Port),
Addr: destination.Addr.As16(),
}
}
*w.iovecList = iovecList // cache
var n uint32
var innerErr error
err := w.rawConn.Write(func(fd uintptr) (done bool) {
innerErr = windows.WSASendto(windows.Handle(fd), iovecList[0], uint32(len(iovecList)), &n, 0, sockaddr, nil, nil)
name, nameLen := ToSockaddr(destination.AddrPort())
innerErr = windows.WSASendTo(
windows.Handle(fd),
&iovecList[0],
uint32(len(iovecList)),
&n,
0,
(*windows.RawSockaddrAny)(name),
nameLen,
nil,
nil)
return innerErr != windows.WSAEWOULDBLOCK
})
if innerErr != nil {
err = innerErr
}
for index := range iovecList {
iovecList[index] = windows.WSABuf{}
}
return err
}

View file

@ -258,6 +258,14 @@ func (c *LruCache[K, V]) Delete(key K) {
c.mu.Unlock()
}
// Clear removes every entry from the cache under the lock.
func (c *LruCache[K, V]) Clear() {
	c.mu.Lock()
	defer c.mu.Unlock()
	// Capture the successor before deleting: if deleteElement detaches the
	// element from the list (list.Remove conventionally zeroes the removed
	// element's next/prev pointers), calling Next() on an already-deleted
	// element would terminate the loop after the first entry.
	for element := c.lru.Front(); element != nil; {
		next := element.Next()
		c.deleteElement(element)
		element = next
	}
}
func (c *LruCache[K, V]) maybeDeleteOldest() {
if !c.staleReturn && c.maxAge > 0 {
now := time.Now().Unix()

View file

@ -21,13 +21,13 @@ type TimerPacketConn struct {
instance *Instance
}
func NewPacketConn(ctx context.Context, conn N.PacketConn, timeout time.Duration) (context.Context, PacketConn) {
func NewPacketConn(ctx context.Context, conn N.PacketConn, timeout time.Duration) (context.Context, N.PacketConn) {
if timeoutConn, isTimeoutConn := common.Cast[PacketConn](conn); isTimeoutConn {
oldTimeout := timeoutConn.Timeout()
if timeout < oldTimeout {
timeoutConn.SetTimeout(timeout)
}
return ctx, timeoutConn
return ctx, conn
}
err := conn.SetReadDeadline(time.Time{})
if err == nil {

View file

@ -2,6 +2,7 @@ package canceler
import (
"context"
"net"
"time"
"github.com/sagernet/sing/common"
@ -31,7 +32,7 @@ func (c *TimeoutPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksa
for {
err = c.PacketConn.SetReadDeadline(time.Now().Add(c.timeout))
if err != nil {
return M.Socksaddr{}, err
return
}
destination, err = c.PacketConn.ReadPacket(buffer)
if err == nil {
@ -43,7 +44,7 @@ func (c *TimeoutPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksa
return
}
} else {
return M.Socksaddr{}, err
return
}
}
}
@ -66,6 +67,7 @@ func (c *TimeoutPacketConn) SetTimeout(timeout time.Duration) {
}
func (c *TimeoutPacketConn) Close() error {
c.cancel(net.ErrClosed)
return c.PacketConn.Close()
}

11
common/clear.go Normal file
View file

@ -0,0 +1,11 @@
//go:build go1.21
package common
// ClearArray zeroes every element of the slice in place, using the clear
// builtin available since Go 1.21.
func ClearArray[T ~[]E, E any](s T) {
	clear(s)
}
// ClearMap removes every entry from the map, using the clear builtin
// available since Go 1.21.
func ClearMap[T ~map[K]V, K comparable, V any](m T) {
	clear(m)
}

16
common/clear_compat.go Normal file
View file

@ -0,0 +1,16 @@
//go:build !go1.21
package common
// ClearArray zeroes every element of the slice in place (fallback for the
// clear builtin on toolchains older than Go 1.21).
func ClearArray[T ~[]E, E any](s T) {
	var zero E
	for index := range s {
		s[index] = zero
	}
}
// ClearMap removes every key from the map (fallback for the clear builtin
// on toolchains older than Go 1.21).
func ClearMap[T ~map[K]V, K comparable, V any](m T) {
	for key := range m {
		delete(m, key)
	}
}

View file

@ -159,20 +159,14 @@ func IndexIndexed[T any](arr []T, block func(index int, it T) bool) int {
//go:norace
func Dup[T any](obj T) T {
if UnsafeBuffer {
pointer := uintptr(unsafe.Pointer(&obj))
//nolint:staticcheck
//goland:noinspection GoVetUnsafePointer
return *(*T)(unsafe.Pointer(pointer))
} else {
return obj
}
}
func KeepAlive(obj any) {
if UnsafeBuffer {
runtime.KeepAlive(obj)
}
}
func Uniq[T comparable](arr []T) []T {
@ -342,6 +336,10 @@ func DefaultValue[T any]() T {
return defaultValue
}
// Ptr returns a pointer to a copy of obj, which is convenient for taking
// the address of literals and function results.
func Ptr[T any](obj T) *T {
	pointer := new(T)
	*pointer = obj
	return pointer
}
func Close(closers ...any) error {
var retErr error
for _, closer := range closers {

View file

@ -1,59 +1,35 @@
package control
import (
"os"
"runtime"
"syscall"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
func BindToInterface(finder InterfaceFinder, interfaceName string, interfaceIndex int) Func {
return func(network, address string, conn syscall.RawConn) error {
return BindToInterface0(finder, conn, network, address, interfaceName, interfaceIndex)
return BindToInterface0(finder, conn, network, address, interfaceName, interfaceIndex, false)
}
}
func BindToInterfaceFunc(finder InterfaceFinder, block func(network string, address string) (interfaceName string, interfaceIndex int)) Func {
func BindToInterfaceFunc(finder InterfaceFinder, block func(network string, address string) (interfaceName string, interfaceIndex int, err error)) Func {
return func(network, address string, conn syscall.RawConn) error {
interfaceName, interfaceIndex := block(network, address)
return BindToInterface0(finder, conn, network, address, interfaceName, interfaceIndex)
}
}
const useInterfaceName = runtime.GOOS == "linux" || runtime.GOOS == "android"
func BindToInterface0(finder InterfaceFinder, conn syscall.RawConn, network string, address string, interfaceName string, interfaceIndex int) error {
if addr := M.ParseSocksaddr(address).Addr; addr.IsValid() && N.IsVirtual(addr) {
return nil
}
if interfaceName == "" && interfaceIndex == -1 {
return nil
}
if interfaceName != "" && useInterfaceName || interfaceIndex != -1 && !useInterfaceName {
return bindToInterface(conn, network, address, interfaceName, interfaceIndex)
}
if finder == nil {
return os.ErrInvalid
}
var err error
if useInterfaceName {
interfaceName, err = finder.InterfaceNameByIndex(interfaceIndex)
} else {
interfaceIndex, err = finder.InterfaceIndexByName(interfaceName)
}
interfaceName, interfaceIndex, err := block(network, address)
if err != nil {
return err
}
if useInterfaceName {
if interfaceName == "" {
return nil
return BindToInterface0(finder, conn, network, address, interfaceName, interfaceIndex, false)
}
} else {
if interfaceIndex == -1 {
return nil
}
}
return bindToInterface(conn, network, address, interfaceName, interfaceIndex)
}
func BindToInterface0(finder InterfaceFinder, conn syscall.RawConn, network string, address string, interfaceName string, interfaceIndex int, preferInterfaceName bool) error {
if interfaceName == "" && interfaceIndex == -1 {
return E.New("interface not found: ", interfaceName)
}
if addr := M.ParseSocksaddr(address).Addr; addr.IsValid() && N.IsVirtual(addr) {
return nil
}
return bindToInterface(conn, network, address, finder, interfaceName, interfaceIndex, preferInterfaceName)
}

View file

@ -1,16 +1,24 @@
package control
import (
"os"
"syscall"
"golang.org/x/sys/unix"
)
func bindToInterface(conn syscall.RawConn, network string, address string, interfaceName string, interfaceIndex int) error {
if interfaceIndex == -1 {
return nil
}
func bindToInterface(conn syscall.RawConn, network string, address string, finder InterfaceFinder, interfaceName string, interfaceIndex int, preferInterfaceName bool) error {
return Raw(conn, func(fd uintptr) error {
var err error
if interfaceIndex == -1 {
if finder == nil {
return os.ErrInvalid
}
interfaceIndex, err = finder.InterfaceIndexByName(interfaceName)
if err != nil {
return err
}
}
switch network {
case "tcp6", "udp6":
return unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_BOUND_IF, interfaceIndex)

View file

@ -1,30 +1,21 @@
package control
import "net"
import (
"net"
"net/netip"
)
// InterfaceFinder resolves network interfaces by name, index, or assigned
// address.
type InterfaceFinder interface {
	// Interfaces returns the known interface list.
	Interfaces() []Interface
	// InterfaceIndexByName resolves an interface name to its index.
	InterfaceIndexByName(name string) (int, error)
	// InterfaceNameByIndex resolves an interface index to its name.
	InterfaceNameByIndex(index int) (string, error)
	// InterfaceByAddr finds the interface whose prefixes contain addr.
	InterfaceByAddr(addr netip.Addr) (*Interface, error)
}
func DefaultInterfaceFinder() InterfaceFinder {
return (*netInterfaceFinder)(nil)
}
type netInterfaceFinder struct{}
func (w *netInterfaceFinder) InterfaceIndexByName(name string) (int, error) {
netInterface, err := net.InterfaceByName(name)
if err != nil {
return 0, err
}
return netInterface.Index, nil
}
func (w *netInterfaceFinder) InterfaceNameByIndex(index int) (string, error) {
netInterface, err := net.InterfaceByIndex(index)
if err != nil {
return "", err
}
return netInterface.Name, nil
// Interface is a snapshot of one network interface's identity and
// addressing information.
type Interface struct {
	Index        int              // OS interface index
	MTU          int              // maximum transmission unit
	Name         string           // OS interface name
	Addresses    []netip.Prefix   // assigned address prefixes
	HardwareAddr net.HardwareAddr // link-layer (MAC) address
}

View file

@ -0,0 +1,104 @@
package control
import (
"net"
"net/netip"
_ "unsafe"
"github.com/sagernet/sing/common"
M "github.com/sagernet/sing/common/metadata"
)
var _ InterfaceFinder = (*DefaultInterfaceFinder)(nil)
// DefaultInterfaceFinder implements InterfaceFinder over a cached snapshot
// of the system interface table; call Update or UpdateInterfaces to
// (re)populate the cache.
type DefaultInterfaceFinder struct {
	interfaces []Interface // cached snapshot; no lock guards it in this struct
}
// NewDefaultInterfaceFinder creates an empty DefaultInterfaceFinder; the
// cache is unpopulated until Update or UpdateInterfaces is called.
func NewDefaultInterfaceFinder() *DefaultInterfaceFinder {
	var finder DefaultInterfaceFinder
	return &finder
}
// Update refreshes the cached interface list from the operating system,
// replacing the snapshot only after the full list is collected.
func (f *DefaultInterfaceFinder) Update() error {
	systemInterfaces, err := net.Interfaces()
	if err != nil {
		return err
	}
	collected := make([]Interface, 0, len(systemInterfaces))
	for _, systemInterface := range systemInterfaces {
		interfaceAddrs, addrErr := systemInterface.Addrs()
		if addrErr != nil {
			return addrErr
		}
		collected = append(collected, Interface{
			Index:        systemInterface.Index,
			MTU:          systemInterface.MTU,
			Name:         systemInterface.Name,
			Addresses:    common.Map(interfaceAddrs, M.PrefixFromNet),
			HardwareAddr: systemInterface.HardwareAddr,
		})
	}
	f.interfaces = collected
	return nil
}
// UpdateInterfaces replaces the cached interface list with a
// caller-supplied snapshot.
func (f *DefaultInterfaceFinder) UpdateInterfaces(updated []Interface) {
	f.interfaces = updated
}
// Interfaces returns the cached interface list without refreshing it.
func (f *DefaultInterfaceFinder) Interfaces() []Interface {
	return f.interfaces
}
// InterfaceIndexByName resolves an interface name to its index, consulting
// the cache first and falling back to an OS lookup on a miss.
func (f *DefaultInterfaceFinder) InterfaceIndexByName(name string) (int, error) {
	for _, cached := range f.interfaces {
		if cached.Name == name {
			return cached.Index, nil
		}
	}
	systemInterface, err := net.InterfaceByName(name)
	if err != nil {
		return 0, err
	}
	// Best-effort cache refresh after the miss; the OS lookup already
	// succeeded, so a refresh failure does not affect the result.
	_ = f.Update()
	return systemInterface.Index, nil
}
// InterfaceNameByIndex resolves an interface index to its name, consulting
// the cache first and falling back to an OS lookup on a miss.
func (f *DefaultInterfaceFinder) InterfaceNameByIndex(index int) (string, error) {
	for _, cached := range f.interfaces {
		if cached.Index == index {
			return cached.Name, nil
		}
	}
	systemInterface, err := net.InterfaceByIndex(index)
	if err != nil {
		return "", err
	}
	// Best-effort cache refresh after the miss; the OS lookup already
	// succeeded, so a refresh failure does not affect the result.
	_ = f.Update()
	return systemInterface.Name, nil
}
//go:linkname errNoSuchInterface net.errNoSuchInterface
var errNoSuchInterface error
// InterfaceByAddr finds the interface whose assigned prefixes contain addr,
// refreshing the cache once before giving up. A copy of the cached entry is
// returned, so callers cannot mutate the cache through the pointer.
func (f *DefaultInterfaceFinder) InterfaceByAddr(addr netip.Addr) (*Interface, error) {
	search := func() *Interface {
		for _, candidate := range f.interfaces {
			for _, prefix := range candidate.Addresses {
				if prefix.Contains(addr) {
					matched := candidate
					return &matched
				}
			}
		}
		return nil
	}
	if found := search(); found != nil {
		return found, nil
	}
	if err := f.Update(); err != nil {
		return nil, err
	}
	if found := search(); found != nil {
		return found, nil
	}
	return nil, &net.OpError{Op: "route", Net: "ip+net", Source: nil, Addr: &net.IPAddr{IP: addr.AsSlice()}, Err: errNoSuchInterface}
}

View file

@ -1,13 +1,42 @@
package control
import (
"os"
"syscall"
"github.com/sagernet/sing/common/atomic"
E "github.com/sagernet/sing/common/exceptions"
"golang.org/x/sys/unix"
)
func bindToInterface(conn syscall.RawConn, network string, address string, interfaceName string, interfaceIndex int) error {
var ifIndexDisabled atomic.Bool
func bindToInterface(conn syscall.RawConn, network string, address string, finder InterfaceFinder, interfaceName string, interfaceIndex int, preferInterfaceName bool) error {
return Raw(conn, func(fd uintptr) error {
if !preferInterfaceName && !ifIndexDisabled.Load() {
if interfaceIndex == -1 {
if interfaceName == "" {
return os.ErrInvalid
}
var err error
interfaceIndex, err = finder.InterfaceIndexByName(interfaceName)
if err != nil {
return err
}
}
err := unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_BINDTOIFINDEX, interfaceIndex)
if err == nil {
return nil
} else if E.IsMulti(err, unix.ENOPROTOOPT, unix.EINVAL) {
ifIndexDisabled.Store(true)
} else {
return err
}
}
if interfaceName == "" {
return os.ErrInvalid
}
return unix.BindToDevice(int(fd), interfaceName)
})
}

View file

@ -4,6 +4,6 @@ package control
import "syscall"
func bindToInterface(conn syscall.RawConn, network string, address string, interfaceName string, interfaceIndex int) error {
// bindToInterface is a no-op on platforms without support for binding a
// socket to a specific interface; it reports success without doing anything.
func bindToInterface(conn syscall.RawConn, network string, address string, finder InterfaceFinder, interfaceName string, interfaceIndex int, preferInterfaceName bool) error {
	return nil
}

View file

@ -2,17 +2,28 @@ package control
import (
"encoding/binary"
"os"
"syscall"
"unsafe"
M "github.com/sagernet/sing/common/metadata"
)
func bindToInterface(conn syscall.RawConn, network string, address string, interfaceName string, interfaceIndex int) error {
func bindToInterface(conn syscall.RawConn, network string, address string, finder InterfaceFinder, interfaceName string, interfaceIndex int, preferInterfaceName bool) error {
return Raw(conn, func(fd uintptr) error {
var err error
if interfaceIndex == -1 {
if finder == nil {
return os.ErrInvalid
}
interfaceIndex, err = finder.InterfaceIndexByName(interfaceName)
if err != nil {
return err
}
}
handle := syscall.Handle(fd)
if M.ParseSocksaddr(address).AddrString() == "" {
err := bind4(handle, interfaceIndex)
err = bind4(handle, interfaceIndex)
if err != nil {
return err
}

View file

@ -3,6 +3,7 @@ package control
import (
"syscall"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
)
@ -30,6 +31,14 @@ func Conn(conn syscall.Conn, block func(fd uintptr) error) error {
return Raw(rawConn, block)
}
func Conn0[T any](conn syscall.Conn, block func(fd uintptr) (T, error)) (T, error) {
rawConn, err := conn.SyscallConn()
if err != nil {
return common.DefaultValue[T](), err
}
return Raw0[T](rawConn, block)
}
func Raw(rawConn syscall.RawConn, block func(fd uintptr) error) error {
var innerErr error
err := rawConn.Control(func(fd uintptr) {
@ -37,3 +46,14 @@ func Raw(rawConn syscall.RawConn, block func(fd uintptr) error) error {
})
return E.Errors(innerErr, err)
}
func Raw0[T any](rawConn syscall.RawConn, block func(fd uintptr) (T, error)) (T, error) {
var (
value T
innerErr error
)
err := rawConn.Control(func(fd uintptr) {
value, innerErr = block(fd)
})
return value, E.Errors(innerErr, err)
}

View file

@ -4,10 +4,10 @@ import (
"syscall"
)
func RoutingMark(mark int) Func {
func RoutingMark(mark uint32) Func {
return func(network, address string, conn syscall.RawConn) error {
return Raw(conn, func(fd uintptr) error {
return syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_MARK, mark)
return syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_MARK, int(mark))
})
}
}

View file

@ -2,6 +2,6 @@
package control
func RoutingMark(mark int) Func {
// RoutingMark returns nil (no control function) on platforms without
// SO_MARK support; the mark is silently ignored.
func RoutingMark(mark uint32) Func {
	return nil
}

View file

@ -0,0 +1,58 @@
package control
import (
"encoding/binary"
"net"
"net/netip"
"syscall"
"unsafe"
M "github.com/sagernet/sing/common/metadata"
"golang.org/x/sys/unix"
)
const (
PF_OUT = 0x2
DIOCNATLOOK = 0xc0544417
)
// GetOriginalDestination queries the pf firewall (DIOCNATLOOK on /dev/pf)
// for the original, pre-NAT destination of a redirected TCP connection.
// The caller needs read access to /dev/pf.
func GetOriginalDestination(conn net.Conn) (netip.AddrPort, error) {
	// Fix: open flags belong in the second argument; the original passed
	// O_RDONLY as the permission bits. Behavior is unchanged only because
	// O_RDONLY happens to be 0 — the intent is now explicit.
	pfFd, err := syscall.Open("/dev/pf", syscall.O_RDONLY, 0)
	if err != nil {
		return netip.AddrPort{}, err
	}
	defer syscall.Close(pfFd)
	// In-memory layout of struct pfioc_natlook (see pf(4)).
	nl := struct {
		saddr, daddr, rsaddr, rdaddr       [16]byte
		sxport, dxport, rsxport, rdxport   [4]byte
		af, proto, protoVariant, direction uint8
	}{
		af:        syscall.AF_INET,
		proto:     syscall.IPPROTO_TCP,
		direction: PF_OUT,
	}
	localAddr := M.SocksaddrFromNet(conn.LocalAddr()).Unwrap()
	remoteAddr := M.SocksaddrFromNet(conn.RemoteAddr()).Unwrap()
	// Outbound lookup: source is the remote peer, destination is our local
	// (redirected) address.
	if localAddr.IsIPv4() {
		copy(nl.saddr[:net.IPv4len], remoteAddr.Addr.AsSlice())
		copy(nl.daddr[:net.IPv4len], localAddr.Addr.AsSlice())
		nl.af = syscall.AF_INET
	} else {
		copy(nl.saddr[:], remoteAddr.Addr.AsSlice())
		copy(nl.daddr[:], localAddr.Addr.AsSlice())
		nl.af = syscall.AF_INET6
	}
	binary.BigEndian.PutUint16(nl.sxport[:], remoteAddr.Port)
	binary.BigEndian.PutUint16(nl.dxport[:], localAddr.Port)
	if _, _, errno := unix.Syscall(syscall.SYS_IOCTL, uintptr(pfFd), DIOCNATLOOK, uintptr(unsafe.Pointer(&nl))); errno != 0 {
		return netip.AddrPort{}, errno
	}
	// rdaddr/rdxport hold the redirected (original) destination.
	var address netip.Addr
	if nl.af == unix.AF_INET {
		address = M.AddrFromIP(nl.rdaddr[:net.IPv4len])
	} else {
		address = netip.AddrFrom16(nl.rdaddr)
	}
	return netip.AddrPortFrom(address, binary.BigEndian.Uint16(nl.rdxport[:])), nil
}

View file

@ -0,0 +1,38 @@
package control
import (
"encoding/binary"
"net"
"net/netip"
"os"
"syscall"
"github.com/sagernet/sing/common"
M "github.com/sagernet/sing/common/metadata"
"golang.org/x/sys/unix"
)
// GetOriginalDestination reads the pre-REDIRECT destination of a TCP
// connection from netfilter via the SO_ORIGINAL_DST socket option.
func GetOriginalDestination(conn net.Conn) (netip.AddrPort, error) {
	syscallConn, loaded := common.Cast[syscall.Conn](conn)
	if !loaded {
		return netip.AddrPort{}, os.ErrInvalid
	}
	return Conn0[netip.AddrPort](syscallConn, func(fd uintptr) (netip.AddrPort, error) {
		if M.SocksaddrFromNet(conn.RemoteAddr()).Unwrap().IsIPv4() {
			// IPv4: the kernel returns a sockaddr_in; the IPv6Mreq getter is
			// reused as a conveniently sized buffer. Port is big-endian at
			// bytes 2-3, the address at bytes 4-8.
			raw, err := unix.GetsockoptIPv6Mreq(int(fd), unix.IPPROTO_IP, unix.SO_ORIGINAL_DST)
			if err != nil {
				return netip.AddrPort{}, err
			}
			return netip.AddrPortFrom(M.AddrFromIP(raw.Multiaddr[4:8]), uint16(raw.Multiaddr[2])<<8+uint16(raw.Multiaddr[3])), nil
		} else {
			// IPv6: sockaddr_in6 via the IPv6MTUInfo getter. raw.Addr.Port is
			// read natively from memory holding the network-order port, so the
			// BE-write/LE-read round trip below byte-swaps it into host order.
			// NOTE(review): this appears to assume a little-endian host — confirm.
			raw, err := unix.GetsockoptIPv6MTUInfo(int(fd), unix.IPPROTO_IPV6, unix.SO_ORIGINAL_DST)
			if err != nil {
				return netip.AddrPort{}, err
			}
			var port [2]byte
			binary.BigEndian.PutUint16(port[:], raw.Addr.Port)
			return netip.AddrPortFrom(M.AddrFromIP(raw.Addr.Addr[:]), binary.LittleEndian.Uint16(port[:])), nil
		}
	})
}

View file

@ -0,0 +1,13 @@
//go:build !linux && !darwin
package control
import (
"net"
"net/netip"
"os"
)
// GetOriginalDestination is unsupported on this platform and always fails
// with os.ErrInvalid.
func GetOriginalDestination(conn net.Conn) (netip.AddrPort, error) {
	return netip.AddrPort{}, os.ErrInvalid
}

View file

@ -0,0 +1,30 @@
package control
import (
"syscall"
"time"
_ "unsafe"
E "github.com/sagernet/sing/common/exceptions"
N "github.com/sagernet/sing/common/network"
"golang.org/x/sys/unix"
)
// SetKeepAlivePeriod returns a control function that configures the TCP
// keep-alive idle time and probe interval, each rounded up to whole
// seconds. Non-TCP sockets are left untouched.
func SetKeepAlivePeriod(idle time.Duration, interval time.Duration) Func {
	return func(network, address string, conn syscall.RawConn) error {
		if N.NetworkName(network) != N.NetworkTCP {
			return nil
		}
		return Raw(conn, func(fd uintptr) error {
			idleSeconds := int(roundDurationUp(idle, time.Second))
			intervalSeconds := int(roundDurationUp(interval, time.Second))
			return E.Errors(
				unix.SetsockoptInt(int(fd), unix.IPPROTO_TCP, unix.TCP_KEEPIDLE, idleSeconds),
				unix.SetsockoptInt(int(fd), unix.IPPROTO_TCP, unix.TCP_KEEPINTVL, intervalSeconds),
			)
		})
	}
}
// roundDurationUp divides d by to, rounding any partial unit upward; the
// result is the unit count (e.g. whole seconds) expressed as a Duration.
func roundDurationUp(d time.Duration, to time.Duration) time.Duration {
	// Adding (to - 1) before the integer division rounds up.
	rounded := (d + to - 1) / to
	return rounded
}

View file

@ -0,0 +1,11 @@
//go:build !linux
package control
import (
"time"
)
// SetKeepAlivePeriod returns nil (no control function) on non-Linux
// platforms; TCP_KEEPIDLE/TCP_KEEPINTVL tuning is Linux-specific here.
func SetKeepAlivePeriod(idle time.Duration, interval time.Duration) Func {
	return nil
}

View file

@ -0,0 +1,56 @@
package control
import (
"encoding/binary"
"net/netip"
"syscall"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
"golang.org/x/sys/unix"
)
// TProxy prepares a socket for transparent proxying: address reuse,
// IP_TRANSPARENT (plus the IPv6 variant for AF_INET6 sockets), and
// delivery of the original destination via RECVORIGDSTADDR control
// messages. The first failing option aborts and its error is returned.
func TProxy(fd uintptr, family int) error {
	if err := syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1); err != nil {
		return err
	}
	if err := syscall.SetsockoptInt(int(fd), syscall.SOL_IP, syscall.IP_TRANSPARENT, 1); err != nil {
		return err
	}
	if family == unix.AF_INET6 {
		if err := syscall.SetsockoptInt(int(fd), syscall.SOL_IPV6, unix.IPV6_TRANSPARENT, 1); err != nil {
			return err
		}
	}
	if err := syscall.SetsockoptInt(int(fd), syscall.SOL_IP, syscall.IP_RECVORIGDSTADDR, 1); err != nil {
		return err
	}
	if family == unix.AF_INET6 {
		return syscall.SetsockoptInt(int(fd), syscall.SOL_IPV6, unix.IPV6_RECVORIGDSTADDR, 1)
	}
	return nil
}
// TProxyWriteBack returns a control function that enables IP_TRANSPARENT
// (or IPV6_TRANSPARENT for IPv6 destinations) so replies can be sent from
// the original, non-local source address.
func TProxyWriteBack() Func {
	return func(network, address string, conn syscall.RawConn) error {
		return Raw(conn, func(fd uintptr) error {
			if !M.ParseSocksaddr(address).Addr.Is6() {
				return syscall.SetsockoptInt(int(fd), syscall.SOL_IP, syscall.IP_TRANSPARENT, 1)
			}
			return syscall.SetsockoptInt(int(fd), syscall.SOL_IPV6, unix.IPV6_TRANSPARENT, 1)
		})
	}
}
// GetOriginalDestinationFromOOB extracts the original destination carried
// in an IP_RECVORIGDSTADDR / IPV6_RECVORIGDSTADDR control message.
func GetOriginalDestinationFromOOB(oob []byte) (netip.AddrPort, error) {
	controlMessages, err := unix.ParseSocketControlMessage(oob)
	if err != nil {
		return netip.AddrPort{}, err
	}
	for _, message := range controlMessages {
		level, messageType := message.Header.Level, message.Header.Type
		switch {
		case level == unix.SOL_IP && messageType == unix.IP_RECVORIGDSTADDR:
			// sockaddr_in: big-endian port at bytes 2-3, IPv4 address at 4-8.
			return netip.AddrPortFrom(M.AddrFromIP(message.Data[4:8]), binary.BigEndian.Uint16(message.Data[2:4])), nil
		case level == unix.SOL_IPV6 && messageType == unix.IPV6_RECVORIGDSTADDR:
			// sockaddr_in6: big-endian port at bytes 2-3, IPv6 address at 8-24.
			return netip.AddrPortFrom(M.AddrFromIP(message.Data[8:24]), binary.BigEndian.Uint16(message.Data[2:4])), nil
		}
	}
	return netip.AddrPort{}, E.New("not found")
}

View file

@ -0,0 +1,20 @@
//go:build !linux
package control
import (
"net/netip"
"os"
)
// TProxy is a stub for platforms without transparent-proxy support; it
// always fails with os.ErrInvalid.
// NOTE(review): the second parameter is `isIPv6 bool` here but `family int`
// in the Linux implementation — confirm no cross-platform call site relies
// on a single signature.
func TProxy(fd uintptr, isIPv6 bool) error {
	return os.ErrInvalid
}
// TProxyWriteBack returns nil (no control function) on platforms without
// transparent-proxy support.
func TProxyWriteBack() Func {
	return nil
}
// GetOriginalDestinationFromOOB is unsupported on this platform and always
// fails with os.ErrInvalid.
func GetOriginalDestinationFromOOB(oob []byte) (netip.AddrPort, error) {
	return netip.AddrPort{}, os.ErrInvalid
}

View file

@ -1,8 +1,12 @@
package domain
import (
"encoding/binary"
"io"
"sort"
"unicode/utf8"
"github.com/sagernet/sing/common/rw"
)
type Matcher struct {
@ -10,14 +14,19 @@ type Matcher struct {
}
func NewMatcher(domains []string, domainSuffix []string) *Matcher {
domainList := make([]string, 0, len(domains)+len(domainSuffix))
domainList := make([]string, 0, len(domains)+2*len(domainSuffix))
seen := make(map[string]bool, len(domainList))
for _, domain := range domainSuffix {
if seen[domain] {
continue
}
seen[domain] = true
if domain[0] == '.' {
domainList = append(domainList, reverseDomainSuffix(domain))
} else {
domainList = append(domainList, reverseDomain(domain))
domainList = append(domainList, reverseRootDomainSuffix(domain))
}
}
for _, domain := range domains {
if seen[domain] {
@ -27,15 +36,87 @@ func NewMatcher(domains []string, domainSuffix []string) *Matcher {
domainList = append(domainList, reverseDomain(domain))
}
sort.Strings(domainList)
return &Matcher{
newSuccinctSet(domainList),
return &Matcher{newSuccinctSet(domainList)}
}
// ReadMatcher deserializes a Matcher from the binary format produced by
// (*Matcher).Write: a version byte, then the succinct set's leaves, label
// bitmap, and label bytes, each prefixed by its length as an unsigned
// varint.
func ReadMatcher(reader io.Reader) (*Matcher, error) {
	// Format version; Write emits 1. The value is read but not validated —
	// NOTE(review): consider rejecting unknown versions.
	var version uint8
	err := binary.Read(reader, binary.BigEndian, &version)
	if err != nil {
		return nil, err
	}
	leavesLength, err := rw.ReadUVariant(reader)
	if err != nil {
		return nil, err
	}
	leaves := make([]uint64, leavesLength)
	err = binary.Read(reader, binary.BigEndian, leaves)
	if err != nil {
		return nil, err
	}
	labelBitmapLength, err := rw.ReadUVariant(reader)
	if err != nil {
		return nil, err
	}
	labelBitmap := make([]uint64, labelBitmapLength)
	err = binary.Read(reader, binary.BigEndian, labelBitmap)
	if err != nil {
		return nil, err
	}
	labelsLength, err := rw.ReadUVariant(reader)
	if err != nil {
		return nil, err
	}
	labels := make([]byte, labelsLength)
	_, err = io.ReadFull(reader, labels)
	if err != nil {
		return nil, err
	}
	set := &succinctSet{
		leaves:      leaves,
		labelBitmap: labelBitmap,
		labels:      labels,
	}
	// Rebuild the derived lookup structures that are not serialized.
	set.init()
	return &Matcher{set}, nil
}
// Match reports whether domain is covered by the set. Entries are stored
// reversed, so the query is reversed before the lookup.
func (m *Matcher) Match(domain string) bool {
	return m.set.Has(reverseDomain(domain))
}
// Write serializes the matcher in the format understood by ReadMatcher:
// version byte 1, then each succinct-set section (leaves, label bitmap,
// labels) prefixed by its length as an unsigned varint.
func (m *Matcher) Write(writer io.Writer) error {
	// Format version 1.
	err := binary.Write(writer, binary.BigEndian, byte(1))
	if err != nil {
		return err
	}
	err = rw.WriteUVariant(writer, uint64(len(m.set.leaves)))
	if err != nil {
		return err
	}
	err = binary.Write(writer, binary.BigEndian, m.set.leaves)
	if err != nil {
		return err
	}
	err = rw.WriteUVariant(writer, uint64(len(m.set.labelBitmap)))
	if err != nil {
		return err
	}
	err = binary.Write(writer, binary.BigEndian, m.set.labelBitmap)
	if err != nil {
		return err
	}
	err = rw.WriteUVariant(writer, uint64(len(m.set.labels)))
	if err != nil {
		return err
	}
	_, err = writer.Write(m.set.labels)
	if err != nil {
		return err
	}
	return nil
}
func reverseDomain(domain string) string {
l := len(domain)
b := make([]byte, l)
@ -58,3 +139,16 @@ func reverseDomainSuffix(domain string) string {
b[l] = prefixLabel
return string(b)
}
// reverseRootDomainSuffix reverses domain rune-by-rune and appends '.'
// followed by the prefix label, producing the stored form that also matches
// the bare root domain of a suffix rule.
func reverseRootDomainSuffix(domain string) string {
	length := len(domain)
	out := make([]byte, length+2)
	for read := 0; read < length; {
		r, size := utf8.DecodeRuneInString(domain[read:])
		read += size
		// Write each rune at the mirrored position from the end.
		utf8.EncodeRune(out[length-read:], r)
	}
	out[length] = '.'
	out[length+1] = prefixLabel
	return string(out)
}

View file

@ -6,9 +6,6 @@ type causeError struct {
}
func (e *causeError) Error() string {
if e.cause == nil {
return e.message
}
return e.message + ": " + e.cause.Error()
}

View file

@ -26,14 +26,14 @@ func New(message ...any) error {
func Cause(cause error, message ...any) error {
if cause == nil {
return nil
panic("cause on an nil error")
}
return &causeError{F.ToString(message...), cause}
}
func Extend(cause error, message ...any) error {
if cause == nil {
return nil
panic("extend on an nil error")
}
return &extendedError{F.ToString(message...), cause}
}

View file

@ -23,6 +23,7 @@ func (e *multiError) Unwrap() []error {
func Errors(errors ...error) error {
errors = common.FilterNotNil(errors)
errors = ExpandAll(errors)
errors = common.FilterNotNil(errors)
errors = common.UniqBy(errors, error.Error)
switch len(errors) {
case 0:
@ -36,10 +37,13 @@ func Errors(errors ...error) error {
}
func Expand(err error) []error {
if multiErr, isMultiErr := err.(MultiError); isMultiErr {
return ExpandAll(multiErr.Unwrap())
}
if err == nil {
return nil
} else if multiErr, isMultiErr := err.(MultiError); isMultiErr {
return ExpandAll(common.FilterNotNil(multiErr.Unwrap()))
} else {
return []error{err}
}
}
func ExpandAll(errs []error) []error {

View file

@ -0,0 +1,59 @@
package badjson
import (
"bytes"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/json"
)
type JSONArray []any
// IsEmpty reports whether the array has no elements, or every element
// itself reports empty via the isEmpty interface.
func (a JSONArray) IsEmpty() bool {
	if len(a) == 0 {
		return true
	}
	return common.All(a, func(item any) bool {
		checker, checkable := item.(isEmpty)
		return checkable && checker.IsEmpty()
	})
}
// MarshalJSON encodes the array as a plain JSON array; the conversion to
// []any avoids recursing into this method.
func (a JSONArray) MarshalJSON() ([]byte, error) {
	return json.Marshal([]any(a))
}
// UnmarshalJSON decodes content into the array, requiring a complete JSON
// array (leading '[' and trailing ']'). Decoded elements are appended to
// any existing contents.
func (a *JSONArray) UnmarshalJSON(content []byte) error {
	decoder := json.NewDecoder(bytes.NewReader(content))
	arrayStart, err := decoder.Token()
	if err != nil {
		return err
	} else if arrayStart != json.Delim('[') {
		// Typo fix: "excepted" -> "expected", consistent with JSONObject's
		// error wording.
		return E.New("expected array start, but got ", arrayStart)
	}
	err = a.decodeJSON(decoder)
	if err != nil {
		return err
	}
	arrayEnd, err := decoder.Token()
	if err != nil {
		return err
	} else if arrayEnd != json.Delim(']') {
		return E.New("expected array end, but got ", arrayEnd)
	}
	return nil
}
// decodeJSON appends decoded values to a until the array's closing
// delimiter; the caller is responsible for consuming the surrounding '['
// and ']' tokens.
func (a *JSONArray) decodeJSON(decoder *json.Decoder) error {
	for decoder.More() {
		item, err := decodeJSON(decoder)
		if err != nil {
			return err
		}
		*a = append(*a, item)
	}
	return nil
}

View file

@ -0,0 +1,5 @@
package badjson
// isEmpty is implemented by badjson values that can report themselves as
// empty, allowing containers to omit them during marshaling.
type isEmpty interface {
	IsEmpty() bool
}

View file

@ -0,0 +1,54 @@
package badjson
import (
"bytes"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/json"
)
// Decode parses content into badjson values: objects become *JSONObject,
// arrays become JSONArray, and scalars are returned as the raw token.
func Decode(content []byte) (any, error) {
	decoder := json.NewDecoder(bytes.NewReader(content))
	return decodeJSON(decoder)
}
// decodeJSON reads one JSON value from decoder: objects become *JSONObject,
// arrays become JSONArray, and scalar tokens are returned as-is.
func decodeJSON(decoder *json.Decoder) (any, error) {
	rawToken, err := decoder.Token()
	if err != nil {
		return nil, err
	}
	switch token := rawToken.(type) {
	case json.Delim:
		switch token {
		case '{':
			var object JSONObject
			err = object.decodeJSON(decoder)
			if err != nil {
				return nil, err
			}
			rawToken, err = decoder.Token()
			if err != nil {
				return nil, err
			} else if rawToken != json.Delim('}') {
				// Typo fix: "excepted" -> "expected", consistent with
				// JSONObject's error wording.
				return nil, E.New("expected object end, but got ", rawToken)
			}
			return &object, nil
		case '[':
			var array JSONArray
			err = array.decodeJSON(decoder)
			if err != nil {
				return nil, err
			}
			rawToken, err = decoder.Token()
			if err != nil {
				return nil, err
			} else if rawToken != json.Delim(']') {
				return nil, E.New("expected array end, but got ", rawToken)
			}
			return array, nil
		default:
			return nil, E.New("expected object or array end: ", token)
		}
	}
	return rawToken, nil
}

View file

@ -0,0 +1,139 @@
package badjson
import (
"os"
"reflect"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/json"
)
// Omitempty round-trips value through the badjson representation — whose
// object marshaling skips entries reporting empty — and decodes the result
// back into T, effectively dropping empty members.
func Omitempty[T any](value T) (T, error) {
	objectContent, err := json.Marshal(value)
	if err != nil {
		return common.DefaultValue[T](), E.Cause(err, "marshal object")
	}
	rawNewObject, err := Decode(objectContent)
	if err != nil {
		return common.DefaultValue[T](), err
	}
	newObjectContent, err := json.Marshal(rawNewObject)
	if err != nil {
		return common.DefaultValue[T](), E.Cause(err, "marshal new object")
	}
	var newObject T
	err = json.Unmarshal(newObjectContent, &newObject)
	if err != nil {
		return common.DefaultValue[T](), E.Cause(err, "unmarshal new object")
	}
	return newObject, nil
}
// Merge deep-merges source into destination by marshaling both and
// delegating to MergeFrom (see mergeJSON for the merge rules).
func Merge[T any](source T, destination T) (T, error) {
	rawSource, err := json.Marshal(source)
	if err != nil {
		return common.DefaultValue[T](), E.Cause(err, "marshal source")
	}
	rawDestination, err := json.Marshal(destination)
	if err != nil {
		return common.DefaultValue[T](), E.Cause(err, "marshal destination")
	}
	return MergeFrom[T](rawSource, rawDestination)
}
// MergeFromSource merges an already-marshaled source into destination; a
// nil source returns destination unchanged.
func MergeFromSource[T any](rawSource json.RawMessage, destination T) (T, error) {
	if rawSource == nil {
		return destination, nil
	}
	rawDestination, err := json.Marshal(destination)
	if err != nil {
		return common.DefaultValue[T](), E.Cause(err, "marshal destination")
	}
	return MergeFrom[T](rawSource, rawDestination)
}
// MergeFromDestination merges source into an already-marshaled destination;
// a nil destination returns source unchanged.
func MergeFromDestination[T any](source T, rawDestination json.RawMessage) (T, error) {
	if rawDestination == nil {
		return source, nil
	}
	rawSource, err := json.Marshal(source)
	if err != nil {
		return common.DefaultValue[T](), E.Cause(err, "marshal source")
	}
	return MergeFrom[T](rawSource, rawDestination)
}
// MergeFrom merges two marshaled JSON documents via MergeJSON and
// unmarshals the result into a fresh T.
func MergeFrom[T any](rawSource json.RawMessage, rawDestination json.RawMessage) (T, error) {
	rawMerged, err := MergeJSON(rawSource, rawDestination)
	if err != nil {
		return common.DefaultValue[T](), E.Cause(err, "merge options")
	}
	var merged T
	err = json.Unmarshal(rawMerged, &merged)
	if err != nil {
		return common.DefaultValue[T](), E.Cause(err, "unmarshal merged options")
	}
	return merged, nil
}
// MergeJSON merges two raw JSON documents. If either side is nil the other
// is returned verbatim (both nil is os.ErrInvalid); otherwise both are
// decoded into badjson values, merged with mergeJSON, and re-marshaled.
func MergeJSON(rawSource json.RawMessage, rawDestination json.RawMessage) (json.RawMessage, error) {
	if rawSource == nil && rawDestination == nil {
		return nil, os.ErrInvalid
	} else if rawSource == nil {
		return rawDestination, nil
	} else if rawDestination == nil {
		return rawSource, nil
	}
	source, err := Decode(rawSource)
	if err != nil {
		return nil, E.Cause(err, "decode source")
	}
	destination, err := Decode(rawDestination)
	if err != nil {
		return nil, E.Cause(err, "decode destination")
	}
	// A decoded JSON null leaves the other side as the result.
	if source == nil {
		return json.Marshal(destination)
	} else if destination == nil {
		return json.Marshal(source)
	}
	merged, err := mergeJSON(source, destination)
	if err != nil {
		return nil, err
	}
	return json.Marshal(merged)
}
// mergeJSON merges two decoded badjson values: arrays concatenate
// (destination elements first, then source), objects merge key-by-key with
// colliding values merged recursively, and for any other destination type
// the destination value wins outright.
func mergeJSON(anySource any, anyDestination any) (any, error) {
	switch destination := anyDestination.(type) {
	case JSONArray:
		switch source := anySource.(type) {
		case JSONArray:
			destination = append(destination, source...)
		default:
			// A non-array source is appended as a single trailing element.
			destination = append(destination, source)
		}
		return destination, nil
	case *JSONObject:
		switch source := anySource.(type) {
		case *JSONObject:
			for _, entry := range source.Entries() {
				oldValue, loaded := destination.Get(entry.Key)
				if loaded {
					var err error
					entry.Value, err = mergeJSON(entry.Value, oldValue)
					if err != nil {
						return nil, E.Cause(err, "merge object item ", entry.Key)
					}
				}
				destination.Put(entry.Key, entry.Value)
			}
		default:
			return nil, E.New("cannot merge json object into ", reflect.TypeOf(source))
		}
		return destination, nil
	default:
		// Scalars and other values: destination takes precedence.
		return destination, nil
	}
}

View file

@ -0,0 +1,98 @@
package badjson
import (
"bytes"
"strings"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/json"
"github.com/sagernet/sing/common/x/collections"
"github.com/sagernet/sing/common/x/linkedhashmap"
)
type JSONObject struct {
linkedhashmap.Map[string, any]
}
func (m *JSONObject) IsEmpty() bool {
if m.Size() == 0 {
return true
}
return common.All(m.Entries(), func(it collections.MapEntry[string, any]) bool {
if valueInterface, valueMaybeEmpty := it.Value.(isEmpty); valueMaybeEmpty && valueInterface.IsEmpty() {
return true
}
return false
})
}
// MarshalJSON encodes the object preserving insertion order and skipping
// entries whose values report empty. The hand-built output uses ": " and
// ", " separators, producing non-compact but valid JSON.
func (m *JSONObject) MarshalJSON() ([]byte, error) {
	buffer := new(bytes.Buffer)
	buffer.WriteString("{")
	// Drop entries that implement isEmpty and report empty.
	items := common.Filter(m.Entries(), func(it collections.MapEntry[string, any]) bool {
		if valueInterface, valueMaybeEmpty := it.Value.(isEmpty); valueMaybeEmpty && valueInterface.IsEmpty() {
			return false
		}
		return true
	})
	iLen := len(items)
	for i, entry := range items {
		keyContent, err := json.Marshal(entry.Key)
		if err != nil {
			return nil, err
		}
		buffer.WriteString(strings.TrimSpace(string(keyContent)))
		buffer.WriteString(": ")
		valueContent, err := json.Marshal(entry.Value)
		if err != nil {
			return nil, err
		}
		buffer.WriteString(strings.TrimSpace(string(valueContent)))
		if i < iLen-1 {
			buffer.WriteString(", ")
		}
	}
	buffer.WriteString("}")
	return buffer.Bytes(), nil
}
// UnmarshalJSON decodes content into the object, clearing any existing
// entries first and requiring a complete JSON object ('{' through '}').
func (m *JSONObject) UnmarshalJSON(content []byte) error {
	decoder := json.NewDecoder(bytes.NewReader(content))
	m.Clear()
	objectStart, err := decoder.Token()
	if err != nil {
		return err
	} else if objectStart != json.Delim('{') {
		return E.New("expected json object start, but starts with ", objectStart)
	}
	err = m.decodeJSON(decoder)
	if err != nil {
		return E.Cause(err, "decode json object content")
	}
	objectEnd, err := decoder.Token()
	if err != nil {
		return err
	} else if objectEnd != json.Delim('}') {
		return E.New("expected json object end, but ends with ", objectEnd)
	}
	return nil
}
// decodeJSON reads key/value pairs until the object's closing delimiter;
// the caller consumes the surrounding '{' and '}' tokens.
func (m *JSONObject) decodeJSON(decoder *json.Decoder) error {
	for decoder.More() {
		var entryKey string
		keyToken, err := decoder.Token()
		if err != nil {
			return err
		}
		// Object keys from Decoder.Token are always strings for valid JSON,
		// so this assertion does not fail on well-formed input.
		entryKey = keyToken.(string)
		var entryValue any
		entryValue, err = decodeJSON(decoder)
		if err != nil {
			return E.Cause(err, "decode value for ", entryKey)
		}
		m.Put(entryKey, entryValue)
	}
	return nil
}

View file

@ -0,0 +1,86 @@
package badjson
import (
"bytes"
"strings"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/json"
"github.com/sagernet/sing/common/x/linkedhashmap"
)
// TypedMap is an insertion-ordered map with typed keys and values that
// marshals to and from a JSON object while preserving entry order.
type TypedMap[K comparable, V any] struct {
	linkedhashmap.Map[K, V]
}
// MarshalJSON encodes the map as a JSON object, writing entries in
// insertion order.
func (m TypedMap[K, V]) MarshalJSON() ([]byte, error) {
	var buffer bytes.Buffer
	buffer.WriteString("{")
	for index, entry := range m.Entries() {
		// Separator before every entry but the first, which is equivalent to
		// the usual "after every entry but the last".
		if index > 0 {
			buffer.WriteString(", ")
		}
		keyContent, err := json.Marshal(entry.Key)
		if err != nil {
			return nil, err
		}
		buffer.WriteString(strings.TrimSpace(string(keyContent)))
		buffer.WriteString(": ")
		valueContent, err := json.Marshal(entry.Value)
		if err != nil {
			return nil, err
		}
		buffer.WriteString(strings.TrimSpace(string(valueContent)))
	}
	buffer.WriteString("}")
	return buffer.Bytes(), nil
}
// UnmarshalJSON replaces the map's content with the entries of the given
// JSON object document, preserving their order of appearance.
func (m *TypedMap[K, V]) UnmarshalJSON(content []byte) error {
	m.Clear()
	decoder := json.NewDecoder(bytes.NewReader(content))
	start, err := decoder.Token()
	if err != nil {
		return err
	}
	if start != json.Delim('{') {
		return E.New("expected json object start, but starts with ", start)
	}
	if err = m.decodeJSON(decoder); err != nil {
		return E.Cause(err, "decode json object content")
	}
	end, err := decoder.Token()
	if err != nil {
		return err
	}
	if end != json.Delim('}') {
		return E.New("expected json object end, but ends with ", end)
	}
	return nil
}
// decodeJSON reads key/value pairs from decoder (positioned just inside an
// object) until the closing brace, storing each pair in insertion order.
func (m *TypedMap[K, V]) decodeJSON(decoder *json.Decoder) error {
	for decoder.More() {
		keyToken, err := decoder.Token()
		if err != nil {
			return err
		}
		// Re-encode the key token and decode it into the typed key K, so K
		// can be any type able to unmarshal from a JSON object key token.
		keyContent, err := json.Marshal(keyToken)
		if err != nil {
			return err
		}
		var entryKey K
		err = json.Unmarshal(keyContent, &entryKey)
		if err != nil {
			return err
		}
		var entryValue V
		err = decoder.Decode(&entryValue)
		if err != nil {
			return err
		}
		m.Put(entryKey, entryValue)
	}
	return nil
}

128
common/json/comment.go Normal file
View file

@ -0,0 +1,128 @@
package json
import (
"bufio"
"io"
)
// kanged from v2ray
type commentFilterState = byte
const (
commentFilterStateContent commentFilterState = iota
commentFilterStateEscape
commentFilterStateDoubleQuote
commentFilterStateDoubleQuoteEscape
commentFilterStateSingleQuote
commentFilterStateSingleQuoteEscape
commentFilterStateComment
commentFilterStateSlash
commentFilterStateMultilineComment
commentFilterStateMultilineCommentStar
)
type CommentFilter struct {
br *bufio.Reader
state commentFilterState
}
func NewCommentFilter(reader io.Reader) io.Reader {
return &CommentFilter{br: bufio.NewReader(reader)}
}
func (v *CommentFilter) Read(b []byte) (int, error) {
p := b[:0]
for len(p) < len(b)-2 {
x, err := v.br.ReadByte()
if err != nil {
if len(p) == 0 {
return 0, err
}
return len(p), nil
}
switch v.state {
case commentFilterStateContent:
switch x {
case '"':
v.state = commentFilterStateDoubleQuote
p = append(p, x)
case '\'':
v.state = commentFilterStateSingleQuote
p = append(p, x)
case '\\':
v.state = commentFilterStateEscape
case '#':
v.state = commentFilterStateComment
case '/':
v.state = commentFilterStateSlash
default:
p = append(p, x)
}
case commentFilterStateEscape:
p = append(p, '\\', x)
v.state = commentFilterStateContent
case commentFilterStateDoubleQuote:
switch x {
case '"':
v.state = commentFilterStateContent
p = append(p, x)
case '\\':
v.state = commentFilterStateDoubleQuoteEscape
default:
p = append(p, x)
}
case commentFilterStateDoubleQuoteEscape:
p = append(p, '\\', x)
v.state = commentFilterStateDoubleQuote
case commentFilterStateSingleQuote:
switch x {
case '\'':
v.state = commentFilterStateContent
p = append(p, x)
case '\\':
v.state = commentFilterStateSingleQuoteEscape
default:
p = append(p, x)
}
case commentFilterStateSingleQuoteEscape:
p = append(p, '\\', x)
v.state = commentFilterStateSingleQuote
case commentFilterStateComment:
if x == '\n' {
v.state = commentFilterStateContent
p = append(p, '\n')
}
case commentFilterStateSlash:
switch x {
case '/':
v.state = commentFilterStateComment
case '*':
v.state = commentFilterStateMultilineComment
default:
p = append(p, '/', x)
}
case commentFilterStateMultilineComment:
switch x {
case '*':
v.state = commentFilterStateMultilineCommentStar
case '\n':
p = append(p, '\n')
}
case commentFilterStateMultilineCommentStar:
switch x {
case '/':
v.state = commentFilterStateContent
case '*':
// Stay
case '\n':
p = append(p, '\n')
default:
v.state = commentFilterStateMultilineComment
}
default:
panic("Unknown state.")
}
}
return len(p), nil
}

23
common/json/context.go Normal file
View file

@ -0,0 +1,23 @@
//go:build go1.20 && !without_contextjson

package json

import (
	"github.com/sagernet/sing/common/json/internal/contextjson"
)

// On Go 1.20+ (and unless built with the without_contextjson tag), delegate
// to the vendored encoding/json fork, which decorates decode errors with the
// JSON path at which they occurred.
var (
	Marshal    = json.Marshal
	Unmarshal  = json.Unmarshal
	NewEncoder = json.NewEncoder
	NewDecoder = json.NewDecoder
)

// Aliases of the fork's exported types, so callers can use this package as a
// drop-in replacement for encoding/json.
type (
	Encoder     = json.Encoder
	Decoder     = json.Decoder
	Token       = json.Token
	Delim       = json.Delim
	SyntaxError = json.SyntaxError
	RawMessage  = json.RawMessage
)

View file

@ -0,0 +1,3 @@
# contextjson
Modified copy of the `encoding/json` package from Go 1.21.4, extended to report the JSON path at which decode errors occur.

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,49 @@
package json
import "strconv"
// decodeContext is one node of a linked list recording where the decoder
// currently is within the input value: key is set for an object member,
// otherwise index holds the array position; parent points one level up.
type decodeContext struct {
	parent *decodeContext
	index  int
	key    string
}
// formatContext renders the decoder's current position as a dotted path such
// as "outer.inner[3].key", walking the context chain from the innermost
// element up to the root and prepending each segment.
func (d *decodeState) formatContext() string {
	var path string
	needDot := false
	for node := d.context; node != nil; node = node.parent {
		// A dot separates this segment from the accumulated path only when
		// the accumulated path starts with an object key.
		if needDot {
			path = "." + path
		}
		if node.key == "" {
			path = "[" + strconv.Itoa(node.index) + "]" + path
			needDot = false
		} else {
			path = node.key + path
			needDot = true
		}
	}
	return path
}
// contextError wraps an error with the decode-path segment at which it
// occurred; chained contextErrors render as a dotted path ending with the
// root cause.
type contextError struct {
	parent  error
	context string
	index   bool
}

// Unwrap exposes the wrapped error to errors.Is and errors.As.
func (c *contextError) Unwrap() error {
	return c.parent
}

// Error joins this segment to its parent with "." while the parent is still
// part of the path (another contextError), and with ": " at the root cause.
func (c *contextError) Error() string {
	//goland:noinspection GoTypeAssertionOnErrors
	if _, isContext := c.parent.(*contextError); isContext {
		return c.context + "." + c.parent.Error()
	}
	return c.context + ": " + c.parent.Error()
}

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,48 @@
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package json
import (
"unicode"
"unicode/utf8"
)
// foldName returns a folded string such that foldName(x) == foldName(y)
// is identical to bytes.EqualFold(x, y).
func foldName(in []byte) []byte {
	// Inlinable wrapper: the fixed-size array keeps the fold buffer off the
	// heap for typical names ("function outlining").
	var scratch [32]byte // large enough for most JSON names
	return appendFoldedName(scratch[:0], in)
}

// appendFoldedName appends the case-folded form of in to out and returns out.
func appendFoldedName(out, in []byte) []byte {
	i := 0
	for i < len(in) {
		c := in[i]
		// Fast path: single-byte ASCII folds to upper case.
		if c < utf8.RuneSelf {
			if 'a' <= c && c <= 'z' {
				c -= 'a' - 'A'
			}
			out = append(out, c)
			i++
			continue
		}
		// Slow path: decode a multi-byte rune and fold it as a rune.
		r, size := utf8.DecodeRune(in[i:])
		out = utf8.AppendRune(out, foldRune(r))
		i += size
	}
	return out
}

// foldRune returns the smallest rune in the same simple-fold set as r.
func foldRune(r rune) rune {
	for {
		next := unicode.SimpleFold(r)
		// SimpleFold cycles upward then wraps; the wrap-around (or a fixed
		// point) is the minimum of the set.
		if next <= r {
			return next
		}
		r = next
	}
}

View file

@ -0,0 +1,179 @@
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package json
import "bytes"
// TODO(https://go.dev/issue/53685): Use bytes.Buffer.AvailableBuffer instead.
// availableBuffer returns the unused capacity of b as a zero-length slice
// aliasing b's backing array, so append-style helpers can build output there
// and hand the result straight back to b.Write.
func availableBuffer(b *bytes.Buffer) []byte {
	return b.Bytes()[b.Len():]
}
// HTMLEscape appends to dst the JSON-encoded src with <, >, &, U+2028 and U+2029
// characters inside string literals changed to \u003c, \u003e, \u0026, \u2028, \u2029
// so that the JSON will be safe to embed inside HTML <script> tags.
// For historical reasons, web browsers don't honor standard HTML
// escaping within <script> tags, so an alternative JSON encoding must be used.
func HTMLEscape(dst *bytes.Buffer, src []byte) {
	dst.Grow(len(src))
	dst.Write(appendHTMLEscape(availableBuffer(dst), src))
}

// appendHTMLEscape appends to dst a copy of src in which the HTML-sensitive
// bytes and line-separator runes are replaced by \uXXXX escapes.
func appendHTMLEscape(dst, src []byte) []byte {
	// The characters can only appear in string literals,
	// so just scan the string one byte at a time.
	start := 0
	for i, c := range src {
		if c == '<' || c == '>' || c == '&' {
			dst = append(dst, src[start:i]...)
			// hex is the hex-digit lookup table declared elsewhere in this package.
			dst = append(dst, '\\', 'u', '0', '0', hex[c>>4], hex[c&0xF])
			start = i + 1
		}
		// Convert U+2028 and U+2029 (E2 80 A8 and E2 80 A9).
		// The &^1 clears the low bit so one comparison matches both 0xA8 and 0xA9.
		if c == 0xE2 && i+2 < len(src) && src[i+1] == 0x80 && src[i+2]&^1 == 0xA8 {
			dst = append(dst, src[start:i]...)
			dst = append(dst, '\\', 'u', '2', '0', '2', hex[src[i+2]&0xF])
			start = i + len("\u2029")
		}
	}
	return append(dst, src[start:]...)
}
// Compact appends to dst the JSON-encoded src with
// insignificant space characters elided.
func Compact(dst *bytes.Buffer, src []byte) error {
	dst.Grow(len(src))
	b := availableBuffer(dst)
	b, err := appendCompact(b, src, false)
	dst.Write(b)
	return err
}

// appendCompact appends the compacted form of src to dst; when escape is true
// it additionally applies the HTML-safe \uXXXX replacements (see HTMLEscape).
// On a syntax error, dst is returned truncated back to its original length.
func appendCompact(dst, src []byte, escape bool) ([]byte, error) {
	origLen := len(dst)
	scan := newScanner()
	defer freeScanner(scan)
	start := 0
	for i, c := range src {
		if escape && (c == '<' || c == '>' || c == '&') {
			dst = append(dst, src[start:i]...)
			dst = append(dst, '\\', 'u', '0', '0', hex[c>>4], hex[c&0xF])
			start = i + 1
		}
		// Convert U+2028 and U+2029 (E2 80 A8 and E2 80 A9).
		if escape && c == 0xE2 && i+2 < len(src) && src[i+1] == 0x80 && src[i+2]&^1 == 0xA8 {
			dst = append(dst, src[start:i]...)
			dst = append(dst, '\\', 'u', '2', '0', '2', hex[src[i+2]&0xF])
			start = i + len("\u2029")
		}
		v := scan.step(scan, c)
		if v >= scanSkipSpace {
			if v == scanError {
				break
			}
			// Skippable whitespace byte: flush the content before it and
			// advance start past it so it is dropped from the output.
			dst = append(dst, src[start:i]...)
			start = i + 1
		}
	}
	if scan.eof() == scanError {
		return dst[:origLen], scan.err
	}
	dst = append(dst, src[start:]...)
	return dst, nil
}
// appendNewline appends a newline followed by the prefix and depth copies of
// indent, positioning the next byte at the current indentation level.
func appendNewline(dst []byte, prefix, indent string, depth int) []byte {
	dst = append(dst, '\n')
	dst = append(dst, prefix...)
	for remaining := depth; remaining > 0; remaining-- {
		dst = append(dst, indent...)
	}
	return dst
}
// indentGrowthFactor specifies the growth factor of indenting JSON input.
// Empirically, the growth factor was measured to be between 1.4x to 1.8x
// for some set of compacted JSON with the indent being a single tab.
// Specify a growth factor slightly larger than what is observed
// to reduce probability of allocation in appendIndent.
// A factor no higher than 2 ensures that wasted space never exceeds 50%.
const indentGrowthFactor = 2

// Indent appends to dst an indented form of the JSON-encoded src.
// Each element in a JSON object or array begins on a new,
// indented line beginning with prefix followed by one or more
// copies of indent according to the indentation nesting.
// The data appended to dst does not begin with the prefix nor
// any indentation, to make it easier to embed inside other formatted JSON data.
// Although leading space characters (space, tab, carriage return, newline)
// at the beginning of src are dropped, trailing space characters
// at the end of src are preserved and copied to dst.
// For example, if src has no trailing spaces, neither will dst;
// if src ends in a trailing newline, so will dst.
func Indent(dst *bytes.Buffer, src []byte, prefix, indent string) error {
	// Pre-grow for the observed expansion so appendIndent rarely reallocates.
	dst.Grow(indentGrowthFactor * len(src))
	b := availableBuffer(dst)
	b, err := appendIndent(b, src, prefix, indent)
	dst.Write(b)
	return err
}
func appendIndent(dst, src []byte, prefix, indent string) ([]byte, error) {
origLen := len(dst)
scan := newScanner()
defer freeScanner(scan)
needIndent := false
depth := 0
for _, c := range src {
scan.bytes++
v := scan.step(scan, c)
if v == scanSkipSpace {
continue
}
if v == scanError {
break
}
if needIndent && v != scanEndObject && v != scanEndArray {
needIndent = false
depth++
dst = appendNewline(dst, prefix, indent, depth)
}
// Emit semantically uninteresting bytes
// (in particular, punctuation in strings) unmodified.
if v == scanContinue {
dst = append(dst, c)
continue
}
// Add spacing around real punctuation.
switch c {
case '{', '[':
// delay indent so that empty object and array are formatted as {} and [].
needIndent = true
dst = append(dst, c)
case ',':
dst = append(dst, c)
dst = appendNewline(dst, prefix, indent, depth)
case ':':
dst = append(dst, c, ' ')
case '}', ']':
if needIndent {
// suppress indent in empty object/array
needIndent = false
} else {
depth--
dst = appendNewline(dst, prefix, indent, depth)
}
dst = append(dst, c)
default:
dst = append(dst, c)
}
}
if scan.eof() == scanError {
return dst[:origLen], scan.err
}
return dst, nil
}

View file

@ -0,0 +1,610 @@
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package json
// JSON value parser state machine.
// Just about at the limit of what is reasonable to write by hand.
// Some parts are a bit tedious, but overall it nicely factors out the
// otherwise common code from the multiple scanning functions
// in this package (Compact, Indent, checkValid, etc).
//
// This file starts with two simple examples using the scanner
// before diving into the scanner itself.
import (
"strconv"
"sync"
)
// Valid reports whether data is a valid JSON encoding.
func Valid(data []byte) bool {
	scan := newScanner()
	defer freeScanner(scan)
	return checkValid(data, scan) == nil
}

// checkValid verifies that data is valid JSON-encoded data.
// scan is passed in for use by checkValid to avoid an allocation.
// checkValid returns nil or a SyntaxError.
func checkValid(data []byte, scan *scanner) error {
	scan.reset()
	for _, c := range data {
		// Track the byte offset so syntax errors can report their position.
		scan.bytes++
		if scan.step(scan, c) == scanError {
			return scan.err
		}
	}
	// Feed EOF to the state machine to catch truncated values like `{"a":`.
	if scan.eof() == scanError {
		return scan.err
	}
	return nil
}
// A SyntaxError is a description of a JSON syntax error.
// Unmarshal will return a SyntaxError if the JSON can't be parsed.
type SyntaxError struct {
msg string // description of error
Offset int64 // error occurred after reading Offset bytes
}
func (e *SyntaxError) Error() string { return e.msg }
// A scanner is a JSON scanning state machine.
// Callers call scan.reset and then pass bytes in one at a time
// by calling scan.step(&scan, c) for each byte.
// The return value, referred to as an opcode, tells the
// caller about significant parsing events like beginning
// and ending literals, objects, and arrays, so that the
// caller can follow along if it wishes.
// The return value scanEnd indicates that a single top-level
// JSON value has been completed, *before* the byte that
// just got passed in. (The indication must be delayed in order
// to recognize the end of numbers: is 123 a whole value or
// the beginning of 12345e+6?).
type scanner struct {
// The step is a func to be called to execute the next transition.
// Also tried using an integer constant and a single func
// with a switch, but using the func directly was 10% faster
// on a 64-bit Mac Mini, and it's nicer to read.
step func(*scanner, byte) int
// Reached end of top-level value.
endTop bool
// Stack of what we're in the middle of - array values, object keys, object values.
parseState []int
// Error that happened, if any.
err error
// total bytes consumed, updated by decoder.Decode (and deliberately
// not set to zero by scan.reset)
bytes int64
}
var scannerPool = sync.Pool{
New: func() any {
return &scanner{}
},
}
func newScanner() *scanner {
scan := scannerPool.Get().(*scanner)
// scan.reset by design doesn't set bytes to zero
scan.bytes = 0
scan.reset()
return scan
}
func freeScanner(scan *scanner) {
// Avoid hanging on to too much memory in extreme cases.
if len(scan.parseState) > 1024 {
scan.parseState = nil
}
scannerPool.Put(scan)
}
// These values are returned by the state transition functions
// assigned to scanner.state and the method scanner.eof.
// They give details about the current state of the scan that
// callers might be interested to know about.
// It is okay to ignore the return value of any particular
// call to scanner.state: if one call returns scanError,
// every subsequent call will return scanError too.
const (
// Continue.
scanContinue = iota // uninteresting byte
scanBeginLiteral // end implied by next result != scanContinue
scanBeginObject // begin object
scanObjectKey // just finished object key (string)
scanObjectValue // just finished non-last object value
scanEndObject // end object (implies scanObjectValue if possible)
scanBeginArray // begin array
scanArrayValue // just finished array value
scanEndArray // end array (implies scanArrayValue if possible)
scanSkipSpace // space byte; can skip; known to be last "continue" result
// Stop.
scanEnd // top-level value ended *before* this byte; known to be first "stop" result
scanError // hit an error, scanner.err.
)
// These values are stored in the parseState stack.
// They give the current state of a composite value
// being scanned. If the parser is inside a nested value
// the parseState describes the nested state, outermost at entry 0.
const (
parseObjectKey = iota // parsing object key (before colon)
parseObjectValue // parsing object value (after colon)
parseArrayValue // parsing array value
)
// This limits the max nesting depth to prevent stack overflow.
// This is permitted by https://tools.ietf.org/html/rfc7159#section-9
const maxNestingDepth = 10000
// reset prepares the scanner for use.
// It must be called before calling s.step.
func (s *scanner) reset() {
s.step = stateBeginValue
s.parseState = s.parseState[0:0]
s.err = nil
s.endTop = false
}
// eof tells the scanner that the end of input has been reached.
// It returns a scan status just as s.step does.
func (s *scanner) eof() int {
if s.err != nil {
return scanError
}
if s.endTop {
return scanEnd
}
s.step(s, ' ')
if s.endTop {
return scanEnd
}
if s.err == nil {
s.err = &SyntaxError{"unexpected end of JSON input", s.bytes}
}
return scanError
}
// pushParseState pushes a new parse state p onto the parse stack.
// an error state is returned if maxNestingDepth was exceeded, otherwise successState is returned.
// c is used only to format the error message.
func (s *scanner) pushParseState(c byte, newParseState int, successState int) int {
	s.parseState = append(s.parseState, newParseState)
	if len(s.parseState) <= maxNestingDepth {
		return successState
	}
	return s.error(c, "exceeded max depth")
}

// popParseState pops a parse state (already obtained) off the stack
// and updates s.step accordingly.
// Popping the last state means the top-level value is complete.
func (s *scanner) popParseState() {
	n := len(s.parseState) - 1
	s.parseState = s.parseState[0:n]
	if n == 0 {
		s.step = stateEndTop
		s.endTop = true
	} else {
		s.step = stateEndValue
	}
}
// isSpace reports whether c is a JSON whitespace byte (space, tab, CR, LF).
func isSpace(c byte) bool {
	switch c {
	case ' ', '\t', '\r', '\n':
		return true
	}
	return false
}
// stateBeginValueOrEmpty is the state after reading `[`.
func stateBeginValueOrEmpty(s *scanner, c byte) int {
if isSpace(c) {
return scanSkipSpace
}
if c == ']' {
return stateEndValue(s, c)
}
return stateBeginValue(s, c)
}
// stateBeginValue is the state at the beginning of the input.
func stateBeginValue(s *scanner, c byte) int {
if isSpace(c) {
return scanSkipSpace
}
switch c {
case '{':
s.step = stateBeginStringOrEmpty
return s.pushParseState(c, parseObjectKey, scanBeginObject)
case '[':
s.step = stateBeginValueOrEmpty
return s.pushParseState(c, parseArrayValue, scanBeginArray)
case '"':
s.step = stateInString
return scanBeginLiteral
case '-':
s.step = stateNeg
return scanBeginLiteral
case '0': // beginning of 0.123
s.step = state0
return scanBeginLiteral
case 't': // beginning of true
s.step = stateT
return scanBeginLiteral
case 'f': // beginning of false
s.step = stateF
return scanBeginLiteral
case 'n': // beginning of null
s.step = stateN
return scanBeginLiteral
}
if '1' <= c && c <= '9' { // beginning of 1234.5
s.step = state1
return scanBeginLiteral
}
return s.error(c, "looking for beginning of value")
}
// stateBeginStringOrEmpty is the state after reading `{`.
func stateBeginStringOrEmpty(s *scanner, c byte) int {
if isSpace(c) {
return scanSkipSpace
}
if c == '}' {
n := len(s.parseState)
s.parseState[n-1] = parseObjectValue
return stateEndValue(s, c)
}
return stateBeginString(s, c)
}
// stateBeginString is the state after reading `{"key": value,`.
func stateBeginString(s *scanner, c byte) int {
if isSpace(c) {
return scanSkipSpace
}
if c == '"' {
s.step = stateInString
return scanBeginLiteral
}
return s.error(c, "looking for beginning of object key string")
}
// stateEndValue is the state after completing a value,
// such as after reading `{}` or `true` or `["x"`.
func stateEndValue(s *scanner, c byte) int {
n := len(s.parseState)
if n == 0 {
// Completed top-level before the current byte.
s.step = stateEndTop
s.endTop = true
return stateEndTop(s, c)
}
if isSpace(c) {
s.step = stateEndValue
return scanSkipSpace
}
ps := s.parseState[n-1]
switch ps {
case parseObjectKey:
if c == ':' {
s.parseState[n-1] = parseObjectValue
s.step = stateBeginValue
return scanObjectKey
}
return s.error(c, "after object key")
case parseObjectValue:
if c == ',' {
s.parseState[n-1] = parseObjectKey
s.step = stateBeginStringOrEmpty
return scanObjectValue
}
if c == '}' {
s.popParseState()
return scanEndObject
}
return s.error(c, "after object key:value pair")
case parseArrayValue:
if c == ',' {
s.step = stateBeginValueOrEmpty
return scanArrayValue
}
if c == ']' {
s.popParseState()
return scanEndArray
}
return s.error(c, "after array element")
}
return s.error(c, "")
}
// stateEndTop is the state after finishing the top-level value,
// such as after reading `{}` or `[1,2,3]`.
// Only space characters should be seen now.
func stateEndTop(s *scanner, c byte) int {
if !isSpace(c) {
// Complain about non-space byte on next call.
s.error(c, "after top-level value")
}
return scanEnd
}
// stateInString is the state after reading `"`.
func stateInString(s *scanner, c byte) int {
if c == '"' {
s.step = stateEndValue
return scanContinue
}
if c == '\\' {
s.step = stateInStringEsc
return scanContinue
}
if c < 0x20 {
return s.error(c, "in string literal")
}
return scanContinue
}
// stateInStringEsc is the state after reading `"\` during a quoted string.
func stateInStringEsc(s *scanner, c byte) int {
switch c {
case 'b', 'f', 'n', 'r', 't', '\\', '/', '"':
s.step = stateInString
return scanContinue
case 'u':
s.step = stateInStringEscU
return scanContinue
}
return s.error(c, "in string escape code")
}
// stateInStringEscU is the state after reading `"\u` during a quoted string.
func stateInStringEscU(s *scanner, c byte) int {
if '0' <= c && c <= '9' || 'a' <= c && c <= 'f' || 'A' <= c && c <= 'F' {
s.step = stateInStringEscU1
return scanContinue
}
// numbers
return s.error(c, "in \\u hexadecimal character escape")
}
// stateInStringEscU1 is the state after reading `"\u1` during a quoted string.
func stateInStringEscU1(s *scanner, c byte) int {
if '0' <= c && c <= '9' || 'a' <= c && c <= 'f' || 'A' <= c && c <= 'F' {
s.step = stateInStringEscU12
return scanContinue
}
// numbers
return s.error(c, "in \\u hexadecimal character escape")
}
// stateInStringEscU12 is the state after reading `"\u12` during a quoted string.
func stateInStringEscU12(s *scanner, c byte) int {
if '0' <= c && c <= '9' || 'a' <= c && c <= 'f' || 'A' <= c && c <= 'F' {
s.step = stateInStringEscU123
return scanContinue
}
// numbers
return s.error(c, "in \\u hexadecimal character escape")
}
// stateInStringEscU123 is the state after reading `"\u123` during a quoted string.
func stateInStringEscU123(s *scanner, c byte) int {
if '0' <= c && c <= '9' || 'a' <= c && c <= 'f' || 'A' <= c && c <= 'F' {
s.step = stateInString
return scanContinue
}
// numbers
return s.error(c, "in \\u hexadecimal character escape")
}
// stateNeg is the state after reading `-` during a number.
func stateNeg(s *scanner, c byte) int {
if c == '0' {
s.step = state0
return scanContinue
}
if '1' <= c && c <= '9' {
s.step = state1
return scanContinue
}
return s.error(c, "in numeric literal")
}
// state1 is the state after reading a non-zero integer during a number,
// such as after reading `1` or `100` but not `0`.
func state1(s *scanner, c byte) int {
if '0' <= c && c <= '9' {
s.step = state1
return scanContinue
}
return state0(s, c)
}
// state0 is the state after reading `0` during a number.
func state0(s *scanner, c byte) int {
if c == '.' {
s.step = stateDot
return scanContinue
}
if c == 'e' || c == 'E' {
s.step = stateE
return scanContinue
}
return stateEndValue(s, c)
}
// stateDot is the state after reading the integer and decimal point in a number,
// such as after reading `1.`.
func stateDot(s *scanner, c byte) int {
if '0' <= c && c <= '9' {
s.step = stateDot0
return scanContinue
}
return s.error(c, "after decimal point in numeric literal")
}
// stateDot0 is the state after reading the integer, decimal point, and subsequent
// digits of a number, such as after reading `3.14`.
func stateDot0(s *scanner, c byte) int {
if '0' <= c && c <= '9' {
return scanContinue
}
if c == 'e' || c == 'E' {
s.step = stateE
return scanContinue
}
return stateEndValue(s, c)
}
// stateE is the state after reading the mantissa and e in a number,
// such as after reading `314e` or `0.314e`.
func stateE(s *scanner, c byte) int {
if c == '+' || c == '-' {
s.step = stateESign
return scanContinue
}
return stateESign(s, c)
}
// stateESign is the state after reading the mantissa, e, and sign in a number,
// such as after reading `314e-` or `0.314e+`.
func stateESign(s *scanner, c byte) int {
if '0' <= c && c <= '9' {
s.step = stateE0
return scanContinue
}
return s.error(c, "in exponent of numeric literal")
}
// stateE0 is the state after reading the mantissa, e, optional sign,
// and at least one digit of the exponent in a number,
// such as after reading `314e-2` or `0.314e+1` or `3.14e0`.
func stateE0(s *scanner, c byte) int {
if '0' <= c && c <= '9' {
return scanContinue
}
return stateEndValue(s, c)
}
// stateT is the state after reading `t`.
func stateT(s *scanner, c byte) int {
if c == 'r' {
s.step = stateTr
return scanContinue
}
return s.error(c, "in literal true (expecting 'r')")
}
// stateTr is the state after reading `tr`.
func stateTr(s *scanner, c byte) int {
if c == 'u' {
s.step = stateTru
return scanContinue
}
return s.error(c, "in literal true (expecting 'u')")
}
// stateTru is the state after reading `tru`.
func stateTru(s *scanner, c byte) int {
if c == 'e' {
s.step = stateEndValue
return scanContinue
}
return s.error(c, "in literal true (expecting 'e')")
}
// stateF is the state after reading `f`.
func stateF(s *scanner, c byte) int {
if c == 'a' {
s.step = stateFa
return scanContinue
}
return s.error(c, "in literal false (expecting 'a')")
}
// stateFa is the state after reading `fa`.
func stateFa(s *scanner, c byte) int {
if c == 'l' {
s.step = stateFal
return scanContinue
}
return s.error(c, "in literal false (expecting 'l')")
}
// stateFal is the state after reading `fal`.
func stateFal(s *scanner, c byte) int {
if c == 's' {
s.step = stateFals
return scanContinue
}
return s.error(c, "in literal false (expecting 's')")
}
// stateFals is the state after reading `fals`.
func stateFals(s *scanner, c byte) int {
if c == 'e' {
s.step = stateEndValue
return scanContinue
}
return s.error(c, "in literal false (expecting 'e')")
}
// stateN is the state after reading `n`.
func stateN(s *scanner, c byte) int {
if c == 'u' {
s.step = stateNu
return scanContinue
}
return s.error(c, "in literal null (expecting 'u')")
}
// stateNu is the state after reading `nu`.
func stateNu(s *scanner, c byte) int {
if c == 'l' {
s.step = stateNul
return scanContinue
}
return s.error(c, "in literal null (expecting 'l')")
}
// stateNul is the state after reading `nul`.
func stateNul(s *scanner, c byte) int {
if c == 'l' {
s.step = stateEndValue
return scanContinue
}
return s.error(c, "in literal null (expecting 'l')")
}
// stateError is the state after reaching a syntax error,
// such as after reading `[1}` or `5.1.2`.
func stateError(s *scanner, c byte) int {
return scanError
}
// error records an error and switches to the error state.
func (s *scanner) error(c byte, context string) int {
s.step = stateError
s.err = &SyntaxError{"invalid character " + quoteChar(c) + " " + context, s.bytes}
return scanError
}
// quoteChar formats c as a single-quoted character literal for error messages.
func quoteChar(c byte) string {
	// special cases - different from quoted strings
	switch c {
	case '\'':
		return `'\''`
	case '"':
		return `'"'`
	}
	// Reuse strconv's double-quote escaping, then swap the quotation marks.
	quoted := strconv.Quote(string(c))
	return "'" + quoted[1:len(quoted)-1] + "'"
}

View file

@ -0,0 +1,554 @@
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package json
import (
"bytes"
"errors"
"io"
)
// A Decoder reads and decodes JSON values from an input stream.
type Decoder struct {
r io.Reader
buf []byte
d decodeState
scanp int // start of unread data in buf
scanned int64 // amount of data already scanned
scan scanner
err error
tokenState int
tokenStack []int
}
// NewDecoder returns a new decoder that reads from r.
//
// The decoder introduces its own buffering and may
// read data from r beyond the JSON values requested.
func NewDecoder(r io.Reader) *Decoder {
return &Decoder{r: r}
}
// UseNumber causes the Decoder to unmarshal a number into an interface{} as a
// Number instead of as a float64.
func (dec *Decoder) UseNumber() { dec.d.useNumber = true }
// DisallowUnknownFields causes the Decoder to return an error when the destination
// is a struct and the input contains object keys which do not match any
// non-ignored, exported fields in the destination.
func (dec *Decoder) DisallowUnknownFields() { dec.d.disallowUnknownFields = true }
// Decode reads the next JSON-encoded value from its
// input and stores it in the value pointed to by v.
//
// See the documentation for Unmarshal for details about
// the conversion of JSON into a Go value.
func (dec *Decoder) Decode(v any) error {
if dec.err != nil {
return dec.err
}
if err := dec.tokenPrepareForDecode(); err != nil {
return err
}
if !dec.tokenValueAllowed() {
return &SyntaxError{msg: "not at beginning of value", Offset: dec.InputOffset()}
}
// Read whole value into buffer.
n, err := dec.readValue()
if err != nil {
return err
}
dec.d.init(dec.buf[dec.scanp : dec.scanp+n])
dec.scanp += n
// Don't save err from unmarshal into dec.err:
// the connection is still usable since we read a complete JSON
// object from it before the error happened.
err = dec.d.unmarshal(v)
// fixup token streaming state
dec.tokenValueEnd()
return err
}
// Buffered returns a reader of the data remaining in the Decoder's
// buffer. The reader is valid until the next call to Decode.
func (dec *Decoder) Buffered() io.Reader {
return bytes.NewReader(dec.buf[dec.scanp:])
}
// readValue reads a JSON value into dec.buf.
// It returns the length of the encoding.
func (dec *Decoder) readValue() (int, error) {
dec.scan.reset()
scanp := dec.scanp
var err error
Input:
// help the compiler see that scanp is never negative, so it can remove
// some bounds checks below.
for scanp >= 0 {
// Look in the buffer for a new value.
for ; scanp < len(dec.buf); scanp++ {
c := dec.buf[scanp]
dec.scan.bytes++
switch dec.scan.step(&dec.scan, c) {
case scanEnd:
// scanEnd is delayed one byte so we decrement
// the scanner bytes count by 1 to ensure that
// this value is correct in the next call of Decode.
dec.scan.bytes--
break Input
case scanEndObject, scanEndArray:
// scanEnd is delayed one byte.
// We might block trying to get that byte from src,
// so instead invent a space byte.
if stateEndValue(&dec.scan, ' ') == scanEnd {
scanp++
break Input
}
case scanError:
dec.err = dec.scan.err
return 0, dec.scan.err
}
}
// Did the last read have an error?
// Delayed until now to allow buffer scan.
if err != nil {
if err == io.EOF {
if dec.scan.step(&dec.scan, ' ') == scanEnd {
break Input
}
if nonSpace(dec.buf) {
err = io.ErrUnexpectedEOF
}
}
dec.err = err
return 0, err
}
n := scanp - dec.scanp
err = dec.refill()
scanp = dec.scanp + n
}
return scanp - dec.scanp, nil
}
// refill makes room in dec.buf and reads more data from dec.r.
// Bytes already consumed (before dec.scanp) are discarded first, with the
// discarded count accumulated into dec.scanned for InputOffset.
func (dec *Decoder) refill() error {
	if dec.scanp == 0 {
		// Nothing consumed yet; just grow and read.
		return dec.refill0()
	}
	// Slide the unconsumed tail down to the front of the buffer.
	dec.scanned += int64(dec.scanp)
	remaining := copy(dec.buf, dec.buf[dec.scanp:])
	dec.buf = dec.buf[:remaining]
	dec.scanp = 0
	return dec.refill0()
}
// refill0 grows dec.buf if needed and performs a single Read from dec.r,
// appending to the buffer. Unlike refill it never discards consumed data,
// so existing buffer offsets remain valid (peekNoRefill depends on this).
// Any read error is returned for the caller to report only after the
// buffered bytes have been scanned.
func (dec *Decoder) refill0() error {
	// Grow buffer if not large enough.
	const minRead = 512
	if cap(dec.buf)-len(dec.buf) < minRead {
		newBuf := make([]byte, len(dec.buf), 2*cap(dec.buf)+minRead)
		copy(newBuf, dec.buf)
		dec.buf = newBuf
	}
	// Read. Delay error for next iteration (after scan).
	n, err := dec.r.Read(dec.buf[len(dec.buf):cap(dec.buf)])
	dec.buf = dec.buf[0 : len(dec.buf)+n]
	return err
}
// nonSpace reports whether b contains at least one byte that is not
// JSON whitespace.
func nonSpace(b []byte) bool {
	for i := 0; i < len(b); i++ {
		if isSpace(b[i]) {
			continue
		}
		return true
	}
	return false
}
// An Encoder writes JSON values to an output stream.
type Encoder struct {
	w          io.Writer // destination stream
	err        error     // sticky write error; once set, Encode fails fast
	escapeHTML bool      // escape &, <, > inside strings (on by default)

	indentBuf    []byte // scratch buffer reused when re-indenting output
	indentPrefix string // SetIndent prefix; "" together with indentValue "" disables
	indentValue  string // SetIndent per-level indent
}
// NewEncoder returns a new encoder that writes to w.
// HTML escaping is enabled by default; see SetEscapeHTML.
func NewEncoder(w io.Writer) *Encoder {
	enc := &Encoder{w: w}
	enc.escapeHTML = true
	return enc
}
// Encode writes the JSON encoding of v to the stream,
// followed by a newline character.
//
// See the documentation for Marshal for details about the
// conversion of Go values to JSON.
func (enc *Encoder) Encode(v any) error {
	// A previous write error is sticky: fail fast forever after.
	if enc.err != nil {
		return enc.err
	}
	// Borrow a pooled encode state and return it when done.
	e := newEncodeState()
	defer encodeStatePool.Put(e)
	err := e.marshal(v, encOpts{escapeHTML: enc.escapeHTML})
	if err != nil {
		return err
	}
	// Terminate each value with a newline.
	// This makes the output look a little nicer
	// when debugging, and some kind of space
	// is required if the encoded value was a number,
	// so that the reader knows there aren't more
	// digits coming.
	e.WriteByte('\n')
	b := e.Bytes()
	// Re-indent the output if SetIndent was called, reusing indentBuf
	// as scratch space across calls.
	if enc.indentPrefix != "" || enc.indentValue != "" {
		enc.indentBuf, err = appendIndent(enc.indentBuf[:0], b, enc.indentPrefix, enc.indentValue)
		if err != nil {
			return err
		}
		b = enc.indentBuf
	}
	// Only write errors are recorded in enc.err; marshal errors above
	// leave the encoder usable for further values.
	if _, err = enc.w.Write(b); err != nil {
		enc.err = err
	}
	return err
}
// SetIndent instructs the encoder to format each subsequent encoded
// value as if indented by the package-level function Indent(dst, src, prefix, indent).
// Calling SetIndent("", "") disables indentation.
func (enc *Encoder) SetIndent(prefix, indent string) {
	enc.indentPrefix, enc.indentValue = prefix, indent
}
// SetEscapeHTML specifies whether problematic HTML characters
// should be escaped inside JSON quoted strings.
// The default behavior is to escape &, <, and > to \u0026, \u003c, and \u003e
// to avoid certain safety problems that can arise when embedding JSON in HTML.
//
// In non-HTML settings where the escaping interferes with the readability
// of the output, SetEscapeHTML(false) disables this behavior.
func (enc *Encoder) SetEscapeHTML(on bool) {
	// Takes effect only for subsequent Encode calls.
	enc.escapeHTML = on
}
// RawMessage is a raw encoded JSON value.
// It implements Marshaler and Unmarshaler and can
// be used to delay JSON decoding or precompute a JSON encoding.
type RawMessage []byte

// MarshalJSON returns m as the JSON encoding of m.
// A nil RawMessage encodes as the JSON literal null.
func (m RawMessage) MarshalJSON() ([]byte, error) {
	if m != nil {
		return m, nil
	}
	return []byte("null"), nil
}
// UnmarshalJSON sets *m to a copy of data.
// It reports an error when invoked on a nil *RawMessage.
func (m *RawMessage) UnmarshalJSON(data []byte) error {
	if m == nil {
		return errors.New("json.RawMessage: UnmarshalJSON on nil pointer")
	}
	// Reuse the existing backing array when it has capacity.
	buf := (*m)[:0]
	*m = append(buf, data...)
	return nil
}
// Compile-time checks that RawMessage satisfies both interfaces.
var (
	_ Marshaler   = (*RawMessage)(nil)
	_ Unmarshaler = (*RawMessage)(nil)
)
// A Token holds a value of one of these types:
//
//	Delim, for the four JSON delimiters [ ] { }
//	bool, for JSON booleans
//	float64, for JSON numbers
//	Number, for JSON numbers
//	string, for JSON string literals
//	nil, for JSON null
//
// Tokens are produced by Decoder.Token.
type Token any
// Token-streaming states, tracked in dec.tokenState and pushed onto
// dec.tokenStack at each '[' or '{'.
const (
	tokenTopValue    = iota // at top level, expecting any value
	tokenArrayStart         // just read '[', expecting first element or ']'
	tokenArrayValue         // expecting an array element
	tokenArrayComma         // after an element, expecting ',' or ']'
	tokenObjectStart        // just read '{', expecting first key or '}'
	tokenObjectKey          // expecting an object key string
	tokenObjectColon        // after a key, expecting ':'
	tokenObjectValue        // expecting an object value
	tokenObjectComma        // after a value, expecting ',' or '}'
)
// tokenPrepareForDecode advances the token state from a separator state
// to a value state by consuming the pending ',' or ':', so Decode can be
// called in the middle of a Token stream. It returns a SyntaxError when
// the expected separator is missing.
func (dec *Decoder) tokenPrepareForDecode() error {
	// Note: Not calling peek before switch, to avoid
	// putting peek into the standard Decode path.
	// peek is only called when using the Token API.
	switch dec.tokenState {
	case tokenArrayComma:
		c, err := dec.peek()
		if err != nil {
			return err
		}
		if c != ',' {
			return &SyntaxError{"expected comma after array element", dec.InputOffset()}
		}
		dec.scanp++ // consume the ','
		dec.tokenState = tokenArrayValue
	case tokenObjectColon:
		c, err := dec.peek()
		if err != nil {
			return err
		}
		if c != ':' {
			return &SyntaxError{"expected colon after object key", dec.InputOffset()}
		}
		dec.scanp++ // consume the ':'
		dec.tokenState = tokenObjectValue
	}
	return nil
}
// tokenValueAllowed reports whether the token stream is positioned where
// a JSON value may legally begin.
func (dec *Decoder) tokenValueAllowed() bool {
	s := dec.tokenState
	return s == tokenTopValue || s == tokenArrayStart || s == tokenArrayValue || s == tokenObjectValue
}
// tokenValueEnd advances the token state after a value has been consumed
// inside an array or object; top-level state is left unchanged.
func (dec *Decoder) tokenValueEnd() {
	if s := dec.tokenState; s == tokenArrayStart || s == tokenArrayValue {
		dec.tokenState = tokenArrayComma
	} else if s == tokenObjectValue {
		dec.tokenState = tokenObjectComma
	}
}
// A Delim is a JSON array or object delimiter, one of [ ] { or }.
type Delim rune

// String returns the delimiter as a one-character string.
func (d Delim) String() string {
	return string(rune(d))
}
// Token returns the next JSON token in the input stream.
// At the end of the input stream, Token returns nil, io.EOF.
//
// Token guarantees that the delimiters [ ] { } it returns are
// properly nested and matched: if Token encounters an unexpected
// delimiter in the input, it will return an error.
//
// The input stream consists of basic JSON values—bool, string,
// number, and null—along with delimiters [ ] { } of type Delim
// to mark the start and end of arrays and objects.
// Commas and colons are elided.
func (dec *Decoder) Token() (Token, error) {
	for {
		c, err := dec.peek()
		if err != nil {
			return nil, err
		}
		switch c {
		case '[':
			if !dec.tokenValueAllowed() {
				return dec.tokenError(c)
			}
			dec.scanp++
			// Push the current state so the matching ']' can restore it.
			dec.tokenStack = append(dec.tokenStack, dec.tokenState)
			dec.tokenState = tokenArrayStart
			return Delim('['), nil
		case ']':
			if dec.tokenState != tokenArrayStart && dec.tokenState != tokenArrayComma {
				return dec.tokenError(c)
			}
			dec.scanp++
			// Pop the state saved at the matching '['.
			dec.tokenState = dec.tokenStack[len(dec.tokenStack)-1]
			dec.tokenStack = dec.tokenStack[:len(dec.tokenStack)-1]
			dec.tokenValueEnd()
			return Delim(']'), nil
		case '{':
			if !dec.tokenValueAllowed() {
				return dec.tokenError(c)
			}
			dec.scanp++
			dec.tokenStack = append(dec.tokenStack, dec.tokenState)
			dec.tokenState = tokenObjectStart
			return Delim('{'), nil
		case '}':
			// NOTE(review): tokenObjectKey is also accepted here (the
			// standard library accepts only Start/Comma), which permits
			// a '}' right after a comma — i.e. a trailing comma.
			if dec.tokenState != tokenObjectStart && dec.tokenState != tokenObjectComma && dec.tokenState != tokenObjectKey {
				return dec.tokenError(c)
			}
			dec.scanp++
			dec.tokenState = dec.tokenStack[len(dec.tokenStack)-1]
			dec.tokenStack = dec.tokenStack[:len(dec.tokenStack)-1]
			dec.tokenValueEnd()
			return Delim('}'), nil
		case ':':
			// Colons are elided from the token stream.
			if dec.tokenState != tokenObjectColon {
				return dec.tokenError(c)
			}
			dec.scanp++
			dec.tokenState = tokenObjectValue
			continue
		case ',':
			// Commas are elided from the token stream.
			if dec.tokenState == tokenArrayComma {
				dec.scanp++
				dec.tokenState = tokenArrayValue
				continue
			}
			if dec.tokenState == tokenObjectComma {
				dec.scanp++
				dec.tokenState = tokenObjectKey
				continue
			}
			return dec.tokenError(c)
		case '"':
			if dec.tokenState == tokenObjectStart || dec.tokenState == tokenObjectKey {
				// Object key: decode the string via Decode while
				// temporarily pretending to be at top level.
				var x string
				old := dec.tokenState
				dec.tokenState = tokenTopValue
				err := dec.Decode(&x)
				dec.tokenState = old
				if err != nil {
					return nil, err
				}
				dec.tokenState = tokenObjectColon
				return x, nil
			}
			// A string in value position is handled by the generic
			// value case below.
			fallthrough
		default:
			if !dec.tokenValueAllowed() {
				return dec.tokenError(c)
			}
			var x any
			if err := dec.Decode(&x); err != nil {
				return nil, err
			}
			return x, nil
		}
	}
}
// tokenError builds the SyntaxError returned when byte c is not valid in
// the current token state, with a human-readable context suffix.
func (dec *Decoder) tokenError(c byte) (Token, error) {
	var context string
	switch dec.tokenState {
	case tokenTopValue, tokenArrayStart, tokenArrayValue, tokenObjectValue:
		context = " looking for beginning of value"
	case tokenArrayComma:
		context = " after array element"
	case tokenObjectKey:
		context = " looking for beginning of object key string"
	case tokenObjectColon:
		context = " after object key"
	case tokenObjectComma:
		context = " after object key:value pair"
	}
	return nil, &SyntaxError{"invalid character " + quoteChar(c) + context, dec.InputOffset()}
}
// More reports whether there is another element in the
// current array or object being parsed.
//
// Unlike the standard library, a ',' that is immediately followed by
// ']' or '}' (a trailing comma) also reports false, so extended JSON
// with trailing commas iterates cleanly.
func (dec *Decoder) More() bool {
	c, err := dec.peek()
	// return err == nil && c != ']' && c != '}'
	if err != nil {
		return false
	}
	if c == ']' || c == '}' {
		return false
	}
	if c == ',' {
		// Look past the comma without invalidating buffer offsets
		// (peekNoRefill never slides the buffer), then restore scanp.
		scanp := dec.scanp
		dec.scanp++
		c, err = dec.peekNoRefill()
		dec.scanp = scanp
		if err != nil {
			return false
		}
		if c == ']' || c == '}' {
			return false
		}
	}
	return true
}
// peek returns the next non-whitespace byte in the input without
// consuming it, reading more data as needed. dec.scanp is advanced past
// any skipped whitespace to the position of the returned byte.
func (dec *Decoder) peek() (byte, error) {
	var err error
	for {
		for i := dec.scanp; i < len(dec.buf); i++ {
			c := dec.buf[i]
			if isSpace(c) {
				continue
			}
			dec.scanp = i
			return c, nil
		}
		// buffer has been scanned, now report any error
		// left over from the previous refill.
		if err != nil {
			return 0, err
		}
		err = dec.refill()
	}
}
// peekNoRefill is like peek, but reads more input with refill0, which
// never discards the consumed prefix of the buffer. Callers such as More
// rely on this: they temporarily bump dec.scanp and restore it after the
// call, which is only safe if buffer offsets stay valid.
func (dec *Decoder) peekNoRefill() (byte, error) {
	var err error
	for {
		for i := dec.scanp; i < len(dec.buf); i++ {
			c := dec.buf[i]
			if isSpace(c) {
				continue
			}
			dec.scanp = i
			return c, nil
		}
		// buffer has been scanned, now report any error
		// left over from the previous read.
		if err != nil {
			return 0, err
		}
		err = dec.refill0()
	}
}
// InputOffset returns the input stream byte offset of the current decoder position.
// The offset gives the location of the end of the most recently returned token
// and the beginning of the next token.
func (dec *Decoder) InputOffset() int64 {
	// dec.scanned counts bytes already discarded from the buffer by
	// refill; dec.scanp is the position within the current buffer.
	return dec.scanned + int64(dec.scanp)
}

View file

@ -0,0 +1,218 @@
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package json
import "unicode/utf8"
// safeSet holds the value true if the ASCII character with the given array
// position can be represented inside a JSON string without any further
// escaping.
//
// All values are true except for the ASCII control characters (0-31), the
// double quote ("), and the backslash character ("\").
// Entries are listed in ASCII order; control characters default to false.
var safeSet = [utf8.RuneSelf]bool{
	' ': true, '!': true, '"': false, '#': true, '$': true, '%': true, '&': true, '\'': true,
	'(': true, ')': true, '*': true, '+': true, ',': true, '-': true, '.': true, '/': true,
	'0': true, '1': true, '2': true, '3': true, '4': true, '5': true, '6': true, '7': true,
	'8': true, '9': true, ':': true, ';': true, '<': true, '=': true, '>': true, '?': true,
	'@': true, 'A': true, 'B': true, 'C': true, 'D': true, 'E': true, 'F': true, 'G': true,
	'H': true, 'I': true, 'J': true, 'K': true, 'L': true, 'M': true, 'N': true, 'O': true,
	'P': true, 'Q': true, 'R': true, 'S': true, 'T': true, 'U': true, 'V': true, 'W': true,
	'X': true, 'Y': true, 'Z': true, '[': true, '\\': false, ']': true, '^': true, '_': true,
	'`': true, 'a': true, 'b': true, 'c': true, 'd': true, 'e': true, 'f': true, 'g': true,
	'h': true, 'i': true, 'j': true, 'k': true, 'l': true, 'm': true, 'n': true, 'o': true,
	'p': true, 'q': true, 'r': true, 's': true, 't': true, 'u': true, 'v': true, 'w': true,
	'x': true, 'y': true, 'z': true, '{': true, '|': true, '}': true, '~': true, '\u007f': true,
}
// htmlSafeSet holds the value true if the ASCII character with the given
// array position can be safely represented inside a JSON string, embedded
// inside of HTML <script> tags, without any additional escaping.
//
// All values are true except for the ASCII control characters (0-31), the
// double quote ("), the backslash character ("\"), HTML opening and closing
// tags ("<" and ">"), and the ampersand ("&").
// Entries are listed in ASCII order; control characters default to false.
var htmlSafeSet = [utf8.RuneSelf]bool{
	' ': true, '!': true, '"': false, '#': true, '$': true, '%': true, '&': false, '\'': true,
	'(': true, ')': true, '*': true, '+': true, ',': true, '-': true, '.': true, '/': true,
	'0': true, '1': true, '2': true, '3': true, '4': true, '5': true, '6': true, '7': true,
	'8': true, '9': true, ':': true, ';': true, '<': false, '=': true, '>': false, '?': true,
	'@': true, 'A': true, 'B': true, 'C': true, 'D': true, 'E': true, 'F': true, 'G': true,
	'H': true, 'I': true, 'J': true, 'K': true, 'L': true, 'M': true, 'N': true, 'O': true,
	'P': true, 'Q': true, 'R': true, 'S': true, 'T': true, 'U': true, 'V': true, 'W': true,
	'X': true, 'Y': true, 'Z': true, '[': true, '\\': false, ']': true, '^': true, '_': true,
	'`': true, 'a': true, 'b': true, 'c': true, 'd': true, 'e': true, 'f': true, 'g': true,
	'h': true, 'i': true, 'j': true, 'k': true, 'l': true, 'm': true, 'n': true, 'o': true,
	'p': true, 'q': true, 'r': true, 's': true, 't': true, 'u': true, 'v': true, 'w': true,
	'x': true, 'y': true, 'z': true, '{': true, '|': true, '}': true, '~': true, '\u007f': true,
}

View file

@ -0,0 +1,38 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package json
import (
"strings"
)
// tagOptions is the string following a comma in a struct field's "json"
// tag, or the empty string. It does not include the leading comma.
type tagOptions string

// parseTag splits a struct field's json tag into its name and
// comma-separated options.
func parseTag(tag string) (string, tagOptions) {
	name, opt, _ := strings.Cut(tag, ",")
	return name, tagOptions(opt)
}
// Contains reports whether a comma-separated list of options
// contains a particular substr flag. substr must be surrounded by a
// string boundary or commas.
func (o tagOptions) Contains(optionName string) bool {
	if len(o) == 0 {
		return false
	}
	for _, name := range strings.Split(string(o), ",") {
		if name == optionName {
			return true
		}
	}
	return false
}

21
common/json/std.go Normal file
View file

@ -0,0 +1,21 @@
//go:build !go1.20 || without_contextjson
package json
import "encoding/json"
// Function aliases re-exporting encoding/json, used when the package's
// own decoder is excluded by the build tags above.
var (
	Marshal    = json.Marshal
	Unmarshal  = json.Unmarshal
	NewEncoder = json.NewEncoder
	NewDecoder = json.NewDecoder
)
// Type aliases re-exporting the standard library's streaming JSON types.
type (
	Encoder     = json.Encoder
	Decoder     = json.Decoder
	Token       = json.Token
	Delim       = json.Delim
	SyntaxError = json.SyntaxError
	RawMessage  = json.RawMessage
)

25
common/json/unmarshal.go Normal file
View file

@ -0,0 +1,25 @@
package json
import (
"bytes"
"strings"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
)
// UnmarshalExtended decodes content (filtered through NewCommentFilter,
// so comment syntax is tolerated) into a value of type T. On a syntax
// error, the returned error is extended with the 1-based row and the
// column of the offending byte.
func UnmarshalExtended[T any](content []byte) (T, error) {
	var value T
	decoder := NewDecoder(NewCommentFilter(bytes.NewReader(content)))
	err := decoder.Decode(&value)
	if err == nil {
		return value, nil
	}
	syntaxError, isSyntaxError := err.(*SyntaxError)
	if !isSyntaxError {
		return common.DefaultValue[T](), err
	}
	// Translate the byte offset into a human-readable row/column pair.
	prefix := string(content[:syntaxError.Offset])
	row := strings.Count(prefix, "\n") + 1
	column := len(prefix) - strings.LastIndex(prefix, "\n") - 1
	return common.DefaultValue[T](), E.Extend(syntaxError, "row ", row, ", column ", column)
}

16
common/memory/memory.go Normal file
View file

@ -0,0 +1,16 @@
package memory
import "runtime"
// Total returns the process memory usage, preferring the platform's
// native accounting when compiled in and falling back to Go runtime
// statistics otherwise.
func Total() uint64 {
	if !nativeAvailable {
		return Inuse()
	}
	return usageNative()
}
// Inuse returns memory held by the Go runtime, computed from
// runtime.MemStats as stack in use, plus heap in use, plus heap idle
// that has not been released back to the OS.
func Inuse() uint64 {
	var stats runtime.MemStats
	runtime.ReadMemStats(&stats)
	return stats.StackInuse + stats.HeapInuse + stats.HeapIdle - stats.HeapReleased
}

View file

@ -0,0 +1,18 @@
package memory
// #include <mach/mach.h>
import "C"
import "unsafe"
// nativeAvailable reports that the cgo mach task_info probe is compiled in.
const nativeAvailable = true
// usageNative returns the process physical footprint in bytes, obtained
// from the Mach task_info(TASK_VM_INFO) call. It returns 0 if the kernel
// call does not succeed.
func usageNative() uint64 {
	var memoryUsageInByte uint64
	var vmInfo C.task_vm_info_data_t
	var count C.mach_msg_type_number_t = C.TASK_VM_INFO_COUNT
	var kernelReturn C.kern_return_t = C.task_info(C.vm_map_t(C.mach_task_self_), C.TASK_VM_INFO, (*C.integer_t)(unsafe.Pointer(&vmInfo)), &count)
	if kernelReturn == C.KERN_SUCCESS {
		memoryUsageInByte = uint64(vmInfo.phys_footprint)
	}
	return memoryUsageInByte
}

View file

@ -0,0 +1,9 @@
//go:build (darwin && !cgo) || !darwin
package memory
const nativeAvailable = false
func usageNative() uint64 {
return 0
}

Some files were not shown because too many files have changed in this diff Show more