diff --git a/.github/renovate.json b/.github/renovate.json index 10d5ad5..3515733 100644 --- a/.github/renovate.json +++ b/.github/renovate.json @@ -1,6 +1,7 @@ { "$schema": "https://docs.renovatebot.com/renovate-schema.json", "commitMessagePrefix": "[dependencies]", + "branchName": "main", "extends": [ "config:base", ":disableRateLimiting" diff --git a/.github/workflows/debug.yml b/.github/workflows/debug.yml index 93d7659..7ae30fa 100644 --- a/.github/workflows/debug.yml +++ b/.github/workflows/debug.yml @@ -3,6 +3,7 @@ name: Debug build on: push: branches: + - main - dev paths-ignore: - '**.md' @@ -10,34 +11,86 @@ on: - '!.github/workflows/debug.yml' pull_request: branches: + - main - dev jobs: build: - name: Debug build + name: Linux Debug build runs-on: ubuntu-latest steps: - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: fetch-depth: 0 - - name: Get latest go version - id: version - run: | - echo ::set-output name=go_version::$(curl -s https://raw.githubusercontent.com/actions/go-versions/main/versions-manifest.json | grep -oE '"version": "[0-9]{1}.[0-9]{1,}(.[0-9]{1,})?"' | head -1 | cut -d':' -f2 | sed 's/ //g; s/"//g') - name: Setup Go uses: actions/setup-go@v4 with: - go-version: ${{ steps.version.outputs.go_version }} - - name: Add cache to Go proxy + go-version: ^1.22 + - name: Build run: | - version=`git rev-parse HEAD` - mkdir build - pushd build - go mod init build - go get -v github.com/sagernet/sing@$version - popd + make test + build_go120: + name: Linux Debug build (Go 1.20) + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v4 + with: + fetch-depth: 0 + - name: Setup Go + uses: actions/setup-go@v4 + with: + go-version: ~1.20 continue-on-error: true - name: Build run: | make test + build_go121: + name: Linux Debug build (Go 1.21) + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v4 + with: + fetch-depth: 0 + - name: Setup Go + uses: actions/setup-go@v4 + with: + go-version: ~1.21 + continue-on-error: true + - name: Build + run: | + make test + build__windows: + name: Windows Debug build + runs-on: windows-latest + steps: + - name: Checkout + uses: actions/checkout@v4 + with: + fetch-depth: 0 + - name: Setup Go + uses: actions/setup-go@v4 + with: + go-version: ^1.22 + continue-on-error: true + - name: Build + run: | + make test + build_darwin: + name: macOS Debug build + runs-on: macos-latest + steps: + - name: Checkout + uses: actions/checkout@v4 + with: + fetch-depth: 0 + - name: Setup Go + uses: actions/setup-go@v4 + with: + go-version: ^1.22 + continue-on-error: true + - name: Build + run: | + make test \ No newline at end of file diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 0a1c8f2..2947450 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -3,6 +3,7 @@ name: Lint on: push: branches: + - main - dev paths-ignore: - '**.md' @@ -10,6 +11,7 @@ on: - '!.github/workflows/lint.yml' pull_request: branches: + - main - dev jobs: @@ -18,17 +20,13 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: fetch-depth: 0 - - name: Get latest go version - id: version - run: | - echo ::set-output name=go_version::$(curl -s https://raw.githubusercontent.com/actions/go-versions/main/versions-manifest.json | grep -oE '"version": "[0-9]{1}.[0-9]{1,}(.[0-9]{1,})?"' | head -1 | cut -d':' -f2 | sed 's/ //g; s/"//g') - name: Setup Go uses: actions/setup-go@v4 with: - go-version: ${{ 
steps.version.outputs.go_version }} + go-version: ^1.22 - name: Cache go module uses: actions/cache@v3 with: diff --git a/.gitignore b/.gitignore index f7f8ac3..f1298ea 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ /.idea/ /vendor/ +.DS_Store diff --git a/Makefile b/Makefile index e47456e..91ac97b 100644 --- a/Makefile +++ b/Makefile @@ -18,4 +18,4 @@ lint_install: go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest test: - go test -v ./... \ No newline at end of file + go test $(shell go list ./... | grep -v /internal/) diff --git a/common/abx/reader.go b/common/abx/reader.go index 1629572..16f839c 100644 --- a/common/abx/reader.go +++ b/common/abx/reader.go @@ -18,7 +18,6 @@ var _ xml.TokenReader = (*Reader)(nil) type Reader struct { reader *bytes.Reader stringRefs []string - attrs []xml.Attr } func NewReader(content []byte) (xml.TokenReader, bool) { @@ -47,7 +46,7 @@ func (r *Reader) Token() (token xml.Token, err error) { return } var attrs []xml.Attr - attrs, err = r.pullAttributes() + attrs, err = r.readAttributes() if err != nil { return } @@ -93,35 +92,41 @@ func (r *Reader) Token() (token xml.Token, err error) { _, err = r.readUTF() return case ATTRIBUTE: - return nil, E.New("unexpected attribute") + _, err = r.readAttribute() + return } return nil, E.New("unknown token type ", tokenType, " with type ", eventType) } -func (r *Reader) pullAttributes() ([]xml.Attr, error) { - err := r.pullAttribute() - if err != nil { - return nil, err +func (r *Reader) readAttributes() ([]xml.Attr, error) { + var attrs []xml.Attr + for { + attr, err := r.readAttribute() + if err == io.EOF { + break + } + attrs = append(attrs, attr) } - attrs := r.attrs - r.attrs = nil return attrs, nil } -func (r *Reader) pullAttribute() error { +func (r *Reader) readAttribute() (xml.Attr, error) { event, err := r.reader.ReadByte() if err != nil { - return nil + return xml.Attr{}, nil } tokenType := event & 0x0f eventType := event & 0xf0 if tokenType != ATTRIBUTE { - return r.reader.UnreadByte() + err = r.reader.UnreadByte() + if err != nil { + return xml.Attr{}, nil + } + return xml.Attr{}, io.EOF } - var name string - name, err = r.readInternedUTF() + name, err := r.readInternedUTF() if err != nil { - return err + return xml.Attr{}, err } var value string switch eventType { @@ -134,74 +139,73 @@ func (r *Reader) pullAttribute() error { case TypeString: value, err = r.readUTF() if err != nil { - return err + return xml.Attr{}, err } case TypeStringInterned: value, err = r.readInternedUTF() if err != nil { - return err + return xml.Attr{}, err } case TypeBytesHex: var data []byte data, err = r.readBytes() if err != nil { - return err + return xml.Attr{}, err } value = hex.EncodeToString(data) case TypeBytesBase64: var data []byte data, err = r.readBytes() if err != nil { - return err + return xml.Attr{}, err } value = base64.StdEncoding.EncodeToString(data) case TypeInt: var data int32 err = binary.Read(r.reader, binary.BigEndian, &data) if err != nil { - return err + return xml.Attr{}, err } value = strconv.FormatInt(int64(data), 10) case TypeIntHex: var data int32 err = binary.Read(r.reader, binary.BigEndian, &data) if err != nil { - return err + return xml.Attr{}, err } value = "0x" + strconv.FormatInt(int64(data), 16) case TypeLong: var data int64 err = binary.Read(r.reader, binary.BigEndian, &data) if err != nil { - return err + return xml.Attr{}, err } value = strconv.FormatInt(data, 10) case TypeLongHex: var data int64 err = binary.Read(r.reader, binary.BigEndian, &data) if err != 
nil { - return err + return xml.Attr{}, err } value = "0x" + strconv.FormatInt(data, 16) case TypeFloat: var data float32 err = binary.Read(r.reader, binary.BigEndian, &data) if err != nil { - return err + return xml.Attr{}, err } value = strconv.FormatFloat(float64(data), 'g', -1, 32) case TypeDouble: var data float64 err = binary.Read(r.reader, binary.BigEndian, &data) if err != nil { - return err + return xml.Attr{}, err } value = strconv.FormatFloat(data, 'g', -1, 64) default: - return E.New("unexpected attribute type, ", eventType) + return xml.Attr{}, E.New("unexpected attribute type, ", eventType) } - r.attrs = append(r.attrs, xml.Attr{Name: xml.Name{Local: name}, Value: value}) - return r.pullAttribute() + return xml.Attr{Name: xml.Name{Local: name}, Value: value}, nil } func (r *Reader) readUnsignedShort() (uint16, error) { diff --git a/common/atomic/typed.go b/common/atomic/typed.go index 23f77fb..19934f8 100644 --- a/common/atomic/typed.go +++ b/common/atomic/typed.go @@ -10,26 +10,37 @@ type TypedValue[T any] struct { value atomic.Value } +// typedValue is a struct with determined type to resolve atomic.Value usages with interface types +// https://github.com/golang/go/issues/22550 +// +// The intention to have an atomic value store for errors. However, running this code panics: +// panic: sync/atomic: store of inconsistently typed value into Value +// This is because atomic.Value requires that the underlying concrete type be the same (which is a reasonable expectation for its implementation). +// When going through the atomic.Value.Store method call, the fact that both these are of the error interface is lost. +type typedValue[T any] struct { + value T +} + func (t *TypedValue[T]) Load() T { value := t.value.Load() if value == nil { return common.DefaultValue[T]() } - return value.(T) + return value.(typedValue[T]).value } func (t *TypedValue[T]) Store(value T) { - t.value.Store(value) + t.value.Store(typedValue[T]{value}) } func (t *TypedValue[T]) Swap(new T) T { - old := t.value.Swap(new) + old := t.value.Swap(typedValue[T]{new}) if old == nil { return common.DefaultValue[T]() } - return old.(T) + return old.(typedValue[T]).value } func (t *TypedValue[T]) CompareAndSwap(old, new T) bool { - return t.value.CompareAndSwap(old, new) + return t.value.CompareAndSwap(typedValue[T]{old}, typedValue[T]{new}) } diff --git a/common/auth/auth.go b/common/auth/auth.go index 024b381..b1be60d 100644 --- a/common/auth/auth.go +++ b/common/auth/auth.go @@ -1,38 +1,30 @@ package auth -type Authenticator interface { - Verify(user string, pass string) bool - Users() []string -} +import "github.com/sagernet/sing/common" type User struct { - Username string `json:"username"` - Password string `json:"password"` + Username string + Password string } -type inMemoryAuthenticator struct { - storage map[string]string - usernames []string +type Authenticator struct { + userMap map[string][]string } -func (au *inMemoryAuthenticator) Verify(username string, password string) bool { - realPass, ok := au.storage[username] - return ok && realPass == password -} - -func (au *inMemoryAuthenticator) Users() []string { return au.usernames } - -func NewAuthenticator(users []User) Authenticator { +func NewAuthenticator(users []User) *Authenticator { if len(users) == 0 { return nil } - au := &inMemoryAuthenticator{ - storage: make(map[string]string), - usernames: make([]string, 0, len(users)), + au := &Authenticator{ + userMap: make(map[string][]string), } for _, user := range users { - au.storage[user.Username] = 
user.Password - au.usernames = append(au.usernames, user.Username) + au.userMap[user.Username] = append(au.userMap[user.Username], user.Password) } return au } + +func (au *Authenticator) Verify(username string, password string) bool { + passwordList, ok := au.userMap[username] + return ok && common.Contains(passwordList, password) +} diff --git a/common/baderror/baderror.go b/common/baderror/baderror.go index 74e37a2..c5ab530 100644 --- a/common/baderror/baderror.go +++ b/common/baderror/baderror.go @@ -55,7 +55,10 @@ func WrapQUIC(err error) error { if err == nil { return nil } - if Contains(err, "canceled with error code 0") { + if Contains(err, + "canceled by remote with error code 0", + "canceled by local with error code 0", + ) { return net.ErrClosed } return err diff --git a/common/binary/README.md b/common/binary/README.md new file mode 100644 index 0000000..4b82a6a --- /dev/null +++ b/common/binary/README.md @@ -0,0 +1,3 @@ +# binary + +mod from go 1.22.3 \ No newline at end of file diff --git a/common/binary/binary.go b/common/binary/binary.go new file mode 100644 index 0000000..41558ab --- /dev/null +++ b/common/binary/binary.go @@ -0,0 +1,817 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package binary implements simple translation between numbers and byte +// sequences and encoding and decoding of varints. +// +// Numbers are translated by reading and writing fixed-size values. +// A fixed-size value is either a fixed-size arithmetic +// type (bool, int8, uint8, int16, float32, complex64, ...) +// or an array or struct containing only fixed-size values. +// +// The varint functions encode and decode single integer values using +// a variable-length encoding; smaller values require fewer bytes. +// For a specification, see +// https://developers.google.com/protocol-buffers/docs/encoding. +// +// This package favors simplicity over efficiency. Clients that require +// high-performance serialization, especially for large data structures, +// should look at more advanced solutions such as the [encoding/gob] +// package or [google.golang.org/protobuf] for protocol buffers. +package binary + +import ( + "errors" + "io" + "math" + "reflect" + "sync" +) + +// A ByteOrder specifies how to convert byte slices into +// 16-, 32-, or 64-bit unsigned integers. +// +// It is implemented by [LittleEndian], [BigEndian], and [NativeEndian]. +type ByteOrder interface { + Uint16([]byte) uint16 + Uint32([]byte) uint32 + Uint64([]byte) uint64 + PutUint16([]byte, uint16) + PutUint32([]byte, uint32) + PutUint64([]byte, uint64) + String() string +} + +// AppendByteOrder specifies how to append 16-, 32-, or 64-bit unsigned integers +// into a byte slice. +// +// It is implemented by [LittleEndian], [BigEndian], and [NativeEndian]. +type AppendByteOrder interface { + AppendUint16([]byte, uint16) []byte + AppendUint32([]byte, uint32) []byte + AppendUint64([]byte, uint64) []byte + String() string +} + +// LittleEndian is the little-endian implementation of [ByteOrder] and [AppendByteOrder]. +var LittleEndian littleEndian + +// BigEndian is the big-endian implementation of [ByteOrder] and [AppendByteOrder]. 
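// Editor's illustration of the ByteOrder / AppendByteOrder interfaces defined above;
// a standalone sketch, not lines of the patch. It assumes the import path
// github.com/sagernet/sing/common/binary, matching the file location in this diff.
package main

import (
	"fmt"

	"github.com/sagernet/sing/common/binary"
)

func main() {
	// Fixed-size conversion through a ByteOrder.
	b := make([]byte, 2)
	binary.BigEndian.PutUint16(b, 0x1234)
	fmt.Printf("% x\n", b)                             // 12 34
	fmt.Printf("%#x\n", binary.LittleEndian.Uint16(b)) // 0x3412

	// Allocation-friendly encoding through an AppendByteOrder.
	out := binary.BigEndian.AppendUint32(nil, 0xdeadbeef)
	fmt.Printf("% x\n", out) // de ad be ef
}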
+var BigEndian bigEndian + +type littleEndian struct{} + +func (littleEndian) Uint16(b []byte) uint16 { + _ = b[1] // bounds check hint to compiler; see golang.org/issue/14808 + return uint16(b[0]) | uint16(b[1])<<8 +} + +func (littleEndian) PutUint16(b []byte, v uint16) { + _ = b[1] // early bounds check to guarantee safety of writes below + b[0] = byte(v) + b[1] = byte(v >> 8) +} + +func (littleEndian) AppendUint16(b []byte, v uint16) []byte { + return append(b, + byte(v), + byte(v>>8), + ) +} + +func (littleEndian) Uint32(b []byte) uint32 { + _ = b[3] // bounds check hint to compiler; see golang.org/issue/14808 + return uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24 +} + +func (littleEndian) PutUint32(b []byte, v uint32) { + _ = b[3] // early bounds check to guarantee safety of writes below + b[0] = byte(v) + b[1] = byte(v >> 8) + b[2] = byte(v >> 16) + b[3] = byte(v >> 24) +} + +func (littleEndian) AppendUint32(b []byte, v uint32) []byte { + return append(b, + byte(v), + byte(v>>8), + byte(v>>16), + byte(v>>24), + ) +} + +func (littleEndian) Uint64(b []byte) uint64 { + _ = b[7] // bounds check hint to compiler; see golang.org/issue/14808 + return uint64(b[0]) | uint64(b[1])<<8 | uint64(b[2])<<16 | uint64(b[3])<<24 | + uint64(b[4])<<32 | uint64(b[5])<<40 | uint64(b[6])<<48 | uint64(b[7])<<56 +} + +func (littleEndian) PutUint64(b []byte, v uint64) { + _ = b[7] // early bounds check to guarantee safety of writes below + b[0] = byte(v) + b[1] = byte(v >> 8) + b[2] = byte(v >> 16) + b[3] = byte(v >> 24) + b[4] = byte(v >> 32) + b[5] = byte(v >> 40) + b[6] = byte(v >> 48) + b[7] = byte(v >> 56) +} + +func (littleEndian) AppendUint64(b []byte, v uint64) []byte { + return append(b, + byte(v), + byte(v>>8), + byte(v>>16), + byte(v>>24), + byte(v>>32), + byte(v>>40), + byte(v>>48), + byte(v>>56), + ) +} + +func (littleEndian) String() string { return "LittleEndian" } + +func (littleEndian) GoString() string { return "binary.LittleEndian" } + +type bigEndian struct{} + +func (bigEndian) Uint16(b []byte) uint16 { + _ = b[1] // bounds check hint to compiler; see golang.org/issue/14808 + return uint16(b[1]) | uint16(b[0])<<8 +} + +func (bigEndian) PutUint16(b []byte, v uint16) { + _ = b[1] // early bounds check to guarantee safety of writes below + b[0] = byte(v >> 8) + b[1] = byte(v) +} + +func (bigEndian) AppendUint16(b []byte, v uint16) []byte { + return append(b, + byte(v>>8), + byte(v), + ) +} + +func (bigEndian) Uint32(b []byte) uint32 { + _ = b[3] // bounds check hint to compiler; see golang.org/issue/14808 + return uint32(b[3]) | uint32(b[2])<<8 | uint32(b[1])<<16 | uint32(b[0])<<24 +} + +func (bigEndian) PutUint32(b []byte, v uint32) { + _ = b[3] // early bounds check to guarantee safety of writes below + b[0] = byte(v >> 24) + b[1] = byte(v >> 16) + b[2] = byte(v >> 8) + b[3] = byte(v) +} + +func (bigEndian) AppendUint32(b []byte, v uint32) []byte { + return append(b, + byte(v>>24), + byte(v>>16), + byte(v>>8), + byte(v), + ) +} + +func (bigEndian) Uint64(b []byte) uint64 { + _ = b[7] // bounds check hint to compiler; see golang.org/issue/14808 + return uint64(b[7]) | uint64(b[6])<<8 | uint64(b[5])<<16 | uint64(b[4])<<24 | + uint64(b[3])<<32 | uint64(b[2])<<40 | uint64(b[1])<<48 | uint64(b[0])<<56 +} + +func (bigEndian) PutUint64(b []byte, v uint64) { + _ = b[7] // early bounds check to guarantee safety of writes below + b[0] = byte(v >> 56) + b[1] = byte(v >> 48) + b[2] = byte(v >> 40) + b[3] = byte(v >> 32) + b[4] = byte(v >> 24) + b[5] = byte(v >> 16) + b[6] = 
byte(v >> 8) + b[7] = byte(v) +} + +func (bigEndian) AppendUint64(b []byte, v uint64) []byte { + return append(b, + byte(v>>56), + byte(v>>48), + byte(v>>40), + byte(v>>32), + byte(v>>24), + byte(v>>16), + byte(v>>8), + byte(v), + ) +} + +func (bigEndian) String() string { return "BigEndian" } + +func (bigEndian) GoString() string { return "binary.BigEndian" } + +func (nativeEndian) String() string { return "NativeEndian" } + +func (nativeEndian) GoString() string { return "binary.NativeEndian" } + +// Read reads structured binary data from r into data. +// Data must be a pointer to a fixed-size value or a slice +// of fixed-size values. +// Bytes read from r are decoded using the specified byte order +// and written to successive fields of the data. +// When decoding boolean values, a zero byte is decoded as false, and +// any other non-zero byte is decoded as true. +// When reading into structs, the field data for fields with +// blank (_) field names is skipped; i.e., blank field names +// may be used for padding. +// When reading into a struct, all non-blank fields must be exported +// or Read may panic. +// +// The error is [io.EOF] only if no bytes were read. +// If an [io.EOF] happens after reading some but not all the bytes, +// Read returns [io.ErrUnexpectedEOF]. +func Read(r io.Reader, order ByteOrder, data any) error { + // Fast path for basic types and slices. + if n := intDataSize(data); n != 0 { + bs := make([]byte, n) + if _, err := io.ReadFull(r, bs); err != nil { + return err + } + switch data := data.(type) { + case *bool: + *data = bs[0] != 0 + case *int8: + *data = int8(bs[0]) + case *uint8: + *data = bs[0] + case *int16: + *data = int16(order.Uint16(bs)) + case *uint16: + *data = order.Uint16(bs) + case *int32: + *data = int32(order.Uint32(bs)) + case *uint32: + *data = order.Uint32(bs) + case *int64: + *data = int64(order.Uint64(bs)) + case *uint64: + *data = order.Uint64(bs) + case *float32: + *data = math.Float32frombits(order.Uint32(bs)) + case *float64: + *data = math.Float64frombits(order.Uint64(bs)) + case []bool: + for i, x := range bs { // Easier to loop over the input for 8-bit values. + data[i] = x != 0 + } + case []int8: + for i, x := range bs { + data[i] = int8(x) + } + case []uint8: + copy(data, bs) + case []int16: + for i := range data { + data[i] = int16(order.Uint16(bs[2*i:])) + } + case []uint16: + for i := range data { + data[i] = order.Uint16(bs[2*i:]) + } + case []int32: + for i := range data { + data[i] = int32(order.Uint32(bs[4*i:])) + } + case []uint32: + for i := range data { + data[i] = order.Uint32(bs[4*i:]) + } + case []int64: + for i := range data { + data[i] = int64(order.Uint64(bs[8*i:])) + } + case []uint64: + for i := range data { + data[i] = order.Uint64(bs[8*i:]) + } + case []float32: + for i := range data { + data[i] = math.Float32frombits(order.Uint32(bs[4*i:])) + } + case []float64: + for i := range data { + data[i] = math.Float64frombits(order.Uint64(bs[8*i:])) + } + default: + n = 0 // fast path doesn't apply + } + if n != 0 { + return nil + } + } + + // Fallback to reflect-based decoding. 
+ v := reflect.ValueOf(data) + size := -1 + switch v.Kind() { + case reflect.Pointer: + v = v.Elem() + size = dataSize(v) + case reflect.Slice: + size = dataSize(v) + } + if size < 0 { + return errors.New("binary.Read: invalid type " + reflect.TypeOf(data).String()) + } + d := &decoder{order: order, buf: make([]byte, size)} + if _, err := io.ReadFull(r, d.buf); err != nil { + return err + } + d.value(v) + return nil +} + +// Write writes the binary representation of data into w. +// Data must be a fixed-size value or a slice of fixed-size +// values, or a pointer to such data. +// Boolean values encode as one byte: 1 for true, and 0 for false. +// Bytes written to w are encoded using the specified byte order +// and read from successive fields of the data. +// When writing structs, zero values are written for fields +// with blank (_) field names. +func Write(w io.Writer, order ByteOrder, data any) error { + // Fast path for basic types and slices. + if n := intDataSize(data); n != 0 { + bs := make([]byte, n) + switch v := data.(type) { + case *bool: + if *v { + bs[0] = 1 + } else { + bs[0] = 0 + } + case bool: + if v { + bs[0] = 1 + } else { + bs[0] = 0 + } + case []bool: + for i, x := range v { + if x { + bs[i] = 1 + } else { + bs[i] = 0 + } + } + case *int8: + bs[0] = byte(*v) + case int8: + bs[0] = byte(v) + case []int8: + for i, x := range v { + bs[i] = byte(x) + } + case *uint8: + bs[0] = *v + case uint8: + bs[0] = v + case []uint8: + bs = v + case *int16: + order.PutUint16(bs, uint16(*v)) + case int16: + order.PutUint16(bs, uint16(v)) + case []int16: + for i, x := range v { + order.PutUint16(bs[2*i:], uint16(x)) + } + case *uint16: + order.PutUint16(bs, *v) + case uint16: + order.PutUint16(bs, v) + case []uint16: + for i, x := range v { + order.PutUint16(bs[2*i:], x) + } + case *int32: + order.PutUint32(bs, uint32(*v)) + case int32: + order.PutUint32(bs, uint32(v)) + case []int32: + for i, x := range v { + order.PutUint32(bs[4*i:], uint32(x)) + } + case *uint32: + order.PutUint32(bs, *v) + case uint32: + order.PutUint32(bs, v) + case []uint32: + for i, x := range v { + order.PutUint32(bs[4*i:], x) + } + case *int64: + order.PutUint64(bs, uint64(*v)) + case int64: + order.PutUint64(bs, uint64(v)) + case []int64: + for i, x := range v { + order.PutUint64(bs[8*i:], uint64(x)) + } + case *uint64: + order.PutUint64(bs, *v) + case uint64: + order.PutUint64(bs, v) + case []uint64: + for i, x := range v { + order.PutUint64(bs[8*i:], x) + } + case *float32: + order.PutUint32(bs, math.Float32bits(*v)) + case float32: + order.PutUint32(bs, math.Float32bits(v)) + case []float32: + for i, x := range v { + order.PutUint32(bs[4*i:], math.Float32bits(x)) + } + case *float64: + order.PutUint64(bs, math.Float64bits(*v)) + case float64: + order.PutUint64(bs, math.Float64bits(v)) + case []float64: + for i, x := range v { + order.PutUint64(bs[8*i:], math.Float64bits(x)) + } + } + _, err := w.Write(bs) + return err + } + + // Fallback to reflect-based encoding. + v := reflect.Indirect(reflect.ValueOf(data)) + size := dataSize(v) + if size < 0 { + return errors.New("binary.Write: some values are not fixed-sized in type " + reflect.TypeOf(data).String()) + } + buf := make([]byte, size) + e := &encoder{order: order, buf: buf} + e.value(v) + _, err := w.Write(buf) + return err +} + +// Size returns how many bytes [Write] would generate to encode the value v, which +// must be a fixed-size value or a slice of fixed-size values, or a pointer to such data. +// If v is neither of these, Size returns -1. 
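// A quick round trip through the Write / Read fast path and the Size helper
// documented above; an editor's sketch, not part of the patch (import path
// assumed from the file location common/binary in this diff).
package main

import (
	"bytes"
	"fmt"

	"github.com/sagernet/sing/common/binary"
)

func main() {
	samples := []uint16{1, 256, 65535}
	fmt.Println(binary.Size(samples)) // 6

	var w bytes.Buffer
	if err := binary.Write(&w, binary.LittleEndian, samples); err != nil {
		panic(err)
	}
	fmt.Printf("% x\n", w.Bytes()) // 01 00 00 01 ff ff

	decoded := make([]uint16, 3)
	if err := binary.Read(&w, binary.LittleEndian, decoded); err != nil {
		panic(err)
	}
	fmt.Println(decoded) // [1 256 65535]
}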
+func Size(v any) int { + return dataSize(reflect.Indirect(reflect.ValueOf(v))) +} + +var structSize sync.Map // map[reflect.Type]int + +// dataSize returns the number of bytes the actual data represented by v occupies in memory. +// For compound structures, it sums the sizes of the elements. Thus, for instance, for a slice +// it returns the length of the slice times the element size and does not count the memory +// occupied by the header. If the type of v is not acceptable, dataSize returns -1. +func dataSize(v reflect.Value) int { + switch v.Kind() { + case reflect.Slice: + if s := sizeof(v.Type().Elem()); s >= 0 { + return s * v.Len() + } + + case reflect.Struct: + t := v.Type() + if size, ok := structSize.Load(t); ok { + return size.(int) + } + size := sizeof(t) + structSize.Store(t, size) + return size + + default: + if v.IsValid() { + return sizeof(v.Type()) + } + } + + return -1 +} + +// sizeof returns the size >= 0 of variables for the given type or -1 if the type is not acceptable. +func sizeof(t reflect.Type) int { + switch t.Kind() { + case reflect.Array: + if s := sizeof(t.Elem()); s >= 0 { + return s * t.Len() + } + + case reflect.Struct: + sum := 0 + for i, n := 0, t.NumField(); i < n; i++ { + s := sizeof(t.Field(i).Type) + if s < 0 { + return -1 + } + sum += s + } + return sum + + case reflect.Bool, + reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, + reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128: + return int(t.Size()) + } + + return -1 +} + +type coder struct { + order ByteOrder + buf []byte + offset int +} + +type ( + decoder coder + encoder coder +) + +func (d *decoder) bool() bool { + x := d.buf[d.offset] + d.offset++ + return x != 0 +} + +func (e *encoder) bool(x bool) { + if x { + e.buf[e.offset] = 1 + } else { + e.buf[e.offset] = 0 + } + e.offset++ +} + +func (d *decoder) uint8() uint8 { + x := d.buf[d.offset] + d.offset++ + return x +} + +func (e *encoder) uint8(x uint8) { + e.buf[e.offset] = x + e.offset++ +} + +func (d *decoder) uint16() uint16 { + x := d.order.Uint16(d.buf[d.offset : d.offset+2]) + d.offset += 2 + return x +} + +func (e *encoder) uint16(x uint16) { + e.order.PutUint16(e.buf[e.offset:e.offset+2], x) + e.offset += 2 +} + +func (d *decoder) uint32() uint32 { + x := d.order.Uint32(d.buf[d.offset : d.offset+4]) + d.offset += 4 + return x +} + +func (e *encoder) uint32(x uint32) { + e.order.PutUint32(e.buf[e.offset:e.offset+4], x) + e.offset += 4 +} + +func (d *decoder) uint64() uint64 { + x := d.order.Uint64(d.buf[d.offset : d.offset+8]) + d.offset += 8 + return x +} + +func (e *encoder) uint64(x uint64) { + e.order.PutUint64(e.buf[e.offset:e.offset+8], x) + e.offset += 8 +} + +func (d *decoder) int8() int8 { return int8(d.uint8()) } + +func (e *encoder) int8(x int8) { e.uint8(uint8(x)) } + +func (d *decoder) int16() int16 { return int16(d.uint16()) } + +func (e *encoder) int16(x int16) { e.uint16(uint16(x)) } + +func (d *decoder) int32() int32 { return int32(d.uint32()) } + +func (e *encoder) int32(x int32) { e.uint32(uint32(x)) } + +func (d *decoder) int64() int64 { return int64(d.uint64()) } + +func (e *encoder) int64(x int64) { e.uint64(uint64(x)) } + +func (d *decoder) value(v reflect.Value) { + switch v.Kind() { + case reflect.Array: + l := v.Len() + for i := 0; i < l; i++ { + d.value(v.Index(i)) + } + + case reflect.Struct: + t := v.Type() + l := v.NumField() + for i := 0; i < l; i++ { + // Note: Calling v.CanSet() below is an optimization. 
+ // It would be sufficient to check the field name, + // but creating the StructField info for each field is + // costly (run "go test -bench=ReadStruct" and compare + // results when making changes to this code). + if v := v.Field(i); v.CanSet() || t.Field(i).Name != "_" { + d.value(v) + } else { + d.skip(v) + } + } + + case reflect.Slice: + l := v.Len() + for i := 0; i < l; i++ { + d.value(v.Index(i)) + } + + case reflect.Bool: + v.SetBool(d.bool()) + + case reflect.Int8: + v.SetInt(int64(d.int8())) + case reflect.Int16: + v.SetInt(int64(d.int16())) + case reflect.Int32: + v.SetInt(int64(d.int32())) + case reflect.Int64: + v.SetInt(d.int64()) + + case reflect.Uint8: + v.SetUint(uint64(d.uint8())) + case reflect.Uint16: + v.SetUint(uint64(d.uint16())) + case reflect.Uint32: + v.SetUint(uint64(d.uint32())) + case reflect.Uint64: + v.SetUint(d.uint64()) + + case reflect.Float32: + v.SetFloat(float64(math.Float32frombits(d.uint32()))) + case reflect.Float64: + v.SetFloat(math.Float64frombits(d.uint64())) + + case reflect.Complex64: + v.SetComplex(complex( + float64(math.Float32frombits(d.uint32())), + float64(math.Float32frombits(d.uint32())), + )) + case reflect.Complex128: + v.SetComplex(complex( + math.Float64frombits(d.uint64()), + math.Float64frombits(d.uint64()), + )) + } +} + +func (e *encoder) value(v reflect.Value) { + switch v.Kind() { + case reflect.Array: + l := v.Len() + for i := 0; i < l; i++ { + e.value(v.Index(i)) + } + + case reflect.Struct: + t := v.Type() + l := v.NumField() + for i := 0; i < l; i++ { + // see comment for corresponding code in decoder.value() + if v := v.Field(i); v.CanSet() || t.Field(i).Name != "_" { + e.value(v) + } else { + e.skip(v) + } + } + + case reflect.Slice: + l := v.Len() + for i := 0; i < l; i++ { + e.value(v.Index(i)) + } + + case reflect.Bool: + e.bool(v.Bool()) + + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + switch v.Type().Kind() { + case reflect.Int8: + e.int8(int8(v.Int())) + case reflect.Int16: + e.int16(int16(v.Int())) + case reflect.Int32: + e.int32(int32(v.Int())) + case reflect.Int64: + e.int64(v.Int()) + } + + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + switch v.Type().Kind() { + case reflect.Uint8: + e.uint8(uint8(v.Uint())) + case reflect.Uint16: + e.uint16(uint16(v.Uint())) + case reflect.Uint32: + e.uint32(uint32(v.Uint())) + case reflect.Uint64: + e.uint64(v.Uint()) + } + + case reflect.Float32, reflect.Float64: + switch v.Type().Kind() { + case reflect.Float32: + e.uint32(math.Float32bits(float32(v.Float()))) + case reflect.Float64: + e.uint64(math.Float64bits(v.Float())) + } + + case reflect.Complex64, reflect.Complex128: + switch v.Type().Kind() { + case reflect.Complex64: + x := v.Complex() + e.uint32(math.Float32bits(float32(real(x)))) + e.uint32(math.Float32bits(float32(imag(x)))) + case reflect.Complex128: + x := v.Complex() + e.uint64(math.Float64bits(real(x))) + e.uint64(math.Float64bits(imag(x))) + } + } +} + +func (d *decoder) skip(v reflect.Value) { + d.offset += dataSize(v) +} + +func (e *encoder) skip(v reflect.Value) { + n := dataSize(v) + zero := e.buf[e.offset : e.offset+n] + for i := range zero { + zero[i] = 0 + } + e.offset += n +} + +// intDataSize returns the size of the data required to represent the data when encoded. +// It returns zero if the type cannot be implemented by the fast path in Read or Write. 
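// Editor's sketch of the blank-field rule from the Read / Write documentation
// above (not part of the patch): "_" fields are encoded as zero bytes and
// skipped when decoding, so they work as explicit padding.
package main

import (
	"bytes"
	"fmt"

	"github.com/sagernet/sing/common/binary"
)

type record struct {
	Type uint8
	_    [3]byte // padding: written as three zero bytes, skipped on decode
	Seq  uint32
}

func main() {
	var w bytes.Buffer
	if err := binary.Write(&w, binary.BigEndian, record{Type: 7, Seq: 258}); err != nil {
		panic(err)
	}
	fmt.Printf("% x\n", w.Bytes()) // 07 00 00 00 00 00 01 02

	var in record
	if err := binary.Read(&w, binary.BigEndian, &in); err != nil {
		panic(err)
	}
	fmt.Println(in.Type, in.Seq) // 7 258
}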
+func intDataSize(data any) int { + switch data := data.(type) { + case bool, int8, uint8, *bool, *int8, *uint8: + return 1 + case []bool: + return len(data) + case []int8: + return len(data) + case []uint8: + return len(data) + case int16, uint16, *int16, *uint16: + return 2 + case []int16: + return 2 * len(data) + case []uint16: + return 2 * len(data) + case int32, uint32, *int32, *uint32: + return 4 + case []int32: + return 4 * len(data) + case []uint32: + return 4 * len(data) + case int64, uint64, *int64, *uint64: + return 8 + case []int64: + return 8 * len(data) + case []uint64: + return 8 * len(data) + case float32, *float32: + return 4 + case float64, *float64: + return 8 + case []float32: + return 4 * len(data) + case []float64: + return 8 * len(data) + } + return 0 +} diff --git a/common/binary/native_endian_big.go b/common/binary/native_endian_big.go new file mode 100644 index 0000000..bcc8e30 --- /dev/null +++ b/common/binary/native_endian_big.go @@ -0,0 +1,14 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build armbe || arm64be || m68k || mips || mips64 || mips64p32 || ppc || ppc64 || s390 || s390x || shbe || sparc || sparc64 + +package binary + +type nativeEndian struct { + bigEndian +} + +// NativeEndian is the native-endian implementation of [ByteOrder] and [AppendByteOrder]. +var NativeEndian nativeEndian diff --git a/common/binary/native_endian_little.go b/common/binary/native_endian_little.go new file mode 100644 index 0000000..38d3e9b --- /dev/null +++ b/common/binary/native_endian_little.go @@ -0,0 +1,14 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build 386 || amd64 || amd64p32 || alpha || arm || arm64 || loong64 || mipsle || mips64le || mips64p32le || nios2 || ppc64le || riscv || riscv64 || sh || wasm + +package binary + +type nativeEndian struct { + littleEndian +} + +// NativeEndian is the native-endian implementation of [ByteOrder] and [AppendByteOrder]. 
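// Editor's note on NativeEndian, which the build-tagged files here select at
// compile time; a small sketch, not part of the patch.
package main

import (
	"fmt"

	"github.com/sagernet/sing/common/binary"
)

func main() {
	b := binary.NativeEndian.AppendUint32(nil, 1)
	// Prints [1 0 0 0] on little-endian hosts (amd64, arm64, ...) and
	// [0 0 0 1] on big-endian ones; decoding with the same order round-trips.
	fmt.Println(b)
	fmt.Println(binary.NativeEndian.Uint32(b) == 1) // true
	fmt.Println(binary.NativeEndian.String())       // NativeEndian
}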
+var NativeEndian nativeEndian diff --git a/common/binary/variant_data.go b/common/binary/variant_data.go new file mode 100644 index 0000000..eb2b128 --- /dev/null +++ b/common/binary/variant_data.go @@ -0,0 +1,305 @@ +package binary + +import ( + "bufio" + "errors" + "io" + "reflect" + + E "github.com/sagernet/sing/common/exceptions" +) + +func ReadDataSlice(r *bufio.Reader, order ByteOrder, data ...any) error { + for index, item := range data { + err := ReadData(r, order, item) + if err != nil { + return E.Cause(err, "[", index, "]") + } + } + return nil +} + +func ReadData(r *bufio.Reader, order ByteOrder, data any) error { + switch dataPtr := data.(type) { + case *[]uint8: + bytesLen, err := ReadUvarint(r) + if err != nil { + return E.Cause(err, "bytes length") + } + newBytes := make([]uint8, bytesLen) + _, err = io.ReadFull(r, newBytes) + if err != nil { + return E.Cause(err, "bytes value") + } + *dataPtr = newBytes + default: + if intBaseDataSize(data) != 0 { + return Read(r, order, data) + } + } + dataValue := reflect.ValueOf(data) + if dataValue.Kind() == reflect.Pointer { + dataValue = dataValue.Elem() + } + return readData(r, order, dataValue) +} + +func readData(r *bufio.Reader, order ByteOrder, data reflect.Value) error { + switch data.Kind() { + case reflect.Pointer: + pointerValue, err := r.ReadByte() + if err != nil { + return err + } + if pointerValue == 0 { + data.SetZero() + return nil + } + if data.IsNil() { + data.Set(reflect.New(data.Type().Elem())) + } + return readData(r, order, data.Elem()) + case reflect.String: + stringLength, err := ReadUvarint(r) + if err != nil { + return E.Cause(err, "string length") + } + if stringLength == 0 { + data.SetZero() + } else { + stringData := make([]byte, stringLength) + _, err = io.ReadFull(r, stringData) + if err != nil { + return E.Cause(err, "string value") + } + data.SetString(string(stringData)) + } + case reflect.Array: + arrayLen := data.Len() + for i := 0; i < arrayLen; i++ { + err := readData(r, order, data.Index(i)) + if err != nil { + return E.Cause(err, "[", i, "]") + } + } + case reflect.Slice: + sliceLength, err := ReadUvarint(r) + if err != nil { + return E.Cause(err, "slice length") + } + if !data.IsNil() && data.Cap() >= int(sliceLength) { + data.SetLen(int(sliceLength)) + } else if sliceLength > 0 { + data.Set(reflect.MakeSlice(data.Type(), int(sliceLength), int(sliceLength))) + } + if sliceLength > 0 { + if data.Type().Elem().Kind() == reflect.Uint8 { + _, err = io.ReadFull(r, data.Bytes()) + if err != nil { + return E.Cause(err, "bytes value") + } + } else { + for index := 0; index < int(sliceLength); index++ { + err = readData(r, order, data.Index(index)) + if err != nil { + return E.Cause(err, "[", index, "]") + } + } + } + } + case reflect.Map: + mapLength, err := ReadUvarint(r) + if err != nil { + return E.Cause(err, "map length") + } + data.Set(reflect.MakeMap(data.Type())) + for index := 0; index < int(mapLength); index++ { + key := reflect.New(data.Type().Key()).Elem() + err = readData(r, order, key) + if err != nil { + return E.Cause(err, "[", index, "].key") + } + value := reflect.New(data.Type().Elem()).Elem() + err = readData(r, order, value) + if err != nil { + return E.Cause(err, "[", index, "].value") + } + data.SetMapIndex(key, value) + } + case reflect.Struct: + fieldType := data.Type() + fieldLen := data.NumField() + for i := 0; i < fieldLen; i++ { + field := data.Field(i) + fieldName := fieldType.Field(i).Name + if field.CanSet() || fieldName != "_" { + err := readData(r, order, field) + if 
err != nil { + return E.Cause(err, fieldName) + } + } + } + default: + size := dataSize(data) + if size < 0 { + return errors.New("invalid type " + reflect.TypeOf(data).String()) + } + d := &decoder{order: order, buf: make([]byte, size)} + _, err := io.ReadFull(r, d.buf) + if err != nil { + return err + } + d.value(data) + } + return nil +} + +func WriteDataSlice(writer *bufio.Writer, order ByteOrder, data ...any) error { + for index, item := range data { + err := WriteData(writer, order, item) + if err != nil { + return E.Cause(err, "[", index, "]") + } + } + return nil +} + +func WriteData(writer *bufio.Writer, order ByteOrder, data any) error { + switch dataPtr := data.(type) { + case []uint8: + _, err := writer.Write(AppendUvarint(writer.AvailableBuffer(), uint64(len(dataPtr)))) + if err != nil { + return E.Cause(err, "bytes length") + } + _, err = writer.Write(dataPtr) + if err != nil { + return E.Cause(err, "bytes value") + } + default: + if intBaseDataSize(data) != 0 { + return Write(writer, order, data) + } + } + return writeData(writer, order, reflect.Indirect(reflect.ValueOf(data))) +} + +func writeData(writer *bufio.Writer, order ByteOrder, data reflect.Value) error { + switch data.Kind() { + case reflect.Pointer: + if data.IsNil() { + err := writer.WriteByte(0) + if err != nil { + return err + } + } else { + err := writer.WriteByte(1) + if err != nil { + return err + } + return writeData(writer, order, data.Elem()) + } + case reflect.String: + stringValue := data.String() + _, err := writer.Write(AppendUvarint(writer.AvailableBuffer(), uint64(len(stringValue)))) + if err != nil { + return E.Cause(err, "string length") + } + if stringValue != "" { + _, err = writer.WriteString(stringValue) + if err != nil { + return E.Cause(err, "string value") + } + } + case reflect.Array: + dataLen := data.Len() + for i := 0; i < dataLen; i++ { + err := writeData(writer, order, data.Index(i)) + if err != nil { + return E.Cause(err, "[", i, "]") + } + } + case reflect.Slice: + dataLen := data.Len() + _, err := writer.Write(AppendUvarint(writer.AvailableBuffer(), uint64(dataLen))) + if err != nil { + return E.Cause(err, "slice length") + } + if dataLen > 0 { + if data.Type().Elem().Kind() == reflect.Uint8 { + _, err = writer.Write(data.Bytes()) + if err != nil { + return E.Cause(err, "bytes value") + } + } else { + for i := 0; i < dataLen; i++ { + err = writeData(writer, order, data.Index(i)) + if err != nil { + return E.Cause(err, "[", i, "]") + } + } + } + } + case reflect.Map: + dataLen := data.Len() + _, err := writer.Write(AppendUvarint(writer.AvailableBuffer(), uint64(dataLen))) + if err != nil { + return E.Cause(err, "map length") + } + if dataLen > 0 { + for index, key := range data.MapKeys() { + err = writeData(writer, order, key) + if err != nil { + return E.Cause(err, "[", index, "].key") + } + err = writeData(writer, order, data.MapIndex(key)) + if err != nil { + return E.Cause(err, "[", index, "].value") + } + } + } + case reflect.Struct: + fieldType := data.Type() + fieldLen := data.NumField() + for i := 0; i < fieldLen; i++ { + field := data.Field(i) + fieldName := fieldType.Field(i).Name + if field.CanSet() || fieldName != "_" { + err := writeData(writer, order, field) + if err != nil { + return E.Cause(err, fieldName) + } + } + } + default: + size := dataSize(data) + if size < 0 { + return errors.New("binary.Write: some values are not fixed-sized in type " + data.Type().String()) + } + buf := make([]byte, size) + e := &encoder{order: order, buf: buf} + e.value(data) + _, err := 
writer.Write(buf) + if err != nil { + return E.Cause(err, reflect.TypeOf(data).String()) + } + } + return nil +} + +func intBaseDataSize(data any) int { + switch data.(type) { + case bool, int8, uint8: + return 1 + case int16, uint16: + return 2 + case int32, uint32: + return 4 + case int64, uint64: + return 8 + case float32: + return 4 + case float64: + return 8 + } + return 0 +} diff --git a/common/binary/varint.go b/common/binary/varint.go new file mode 100644 index 0000000..64dd9d6 --- /dev/null +++ b/common/binary/varint.go @@ -0,0 +1,166 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package binary + +// This file implements "varint" encoding of 64-bit integers. +// The encoding is: +// - unsigned integers are serialized 7 bits at a time, starting with the +// least significant bits +// - the most significant bit (msb) in each output byte indicates if there +// is a continuation byte (msb = 1) +// - signed integers are mapped to unsigned integers using "zig-zag" +// encoding: Positive values x are written as 2*x + 0, negative values +// are written as 2*(^x) + 1; that is, negative numbers are complemented +// and whether to complement is encoded in bit 0. +// +// Design note: +// At most 10 bytes are needed for 64-bit values. The encoding could +// be more dense: a full 64-bit value needs an extra byte just to hold bit 63. +// Instead, the msb of the previous byte could be used to hold bit 63 since we +// know there can't be more than 64 bits. This is a trivial improvement and +// would reduce the maximum encoding length to 9 bytes. However, it breaks the +// invariant that the msb is always the "continuation bit" and thus makes the +// format incompatible with a varint encoding for larger numbers (say 128-bit). + +import ( + "errors" + "io" +) + +// MaxVarintLenN is the maximum length of a varint-encoded N-bit integer. +const ( + MaxVarintLen16 = 3 + MaxVarintLen32 = 5 + MaxVarintLen64 = 10 +) + +// AppendUvarint appends the varint-encoded form of x, +// as generated by [PutUvarint], to buf and returns the extended buffer. +func AppendUvarint(buf []byte, x uint64) []byte { + for x >= 0x80 { + buf = append(buf, byte(x)|0x80) + x >>= 7 + } + return append(buf, byte(x)) +} + +// PutUvarint encodes a uint64 into buf and returns the number of bytes written. +// If the buffer is too small, PutUvarint will panic. +func PutUvarint(buf []byte, x uint64) int { + i := 0 + for x >= 0x80 { + buf[i] = byte(x) | 0x80 + x >>= 7 + i++ + } + buf[i] = byte(x) + return i + 1 +} + +// Uvarint decodes a uint64 from buf and returns that value and the +// number of bytes read (> 0). If an error occurred, the value is 0 +// and the number of bytes n is <= 0 meaning: +// +// n == 0: buf too small +// n < 0: value larger than 64 bits (overflow) +// and -n is the number of bytes read +func Uvarint(buf []byte) (uint64, int) { + var x uint64 + var s uint + for i, b := range buf { + if i == MaxVarintLen64 { + // Catch byte reads past MaxVarintLen64. + // See issue https://golang.org/issues/41185 + return 0, -(i + 1) // overflow + } + if b < 0x80 { + if i == MaxVarintLen64-1 && b > 1 { + return 0, -(i + 1) // overflow + } + return x | uint64(b)<<s, i + 1 + } + x |= uint64(b&0x7f) << s + s += 7 + } + return 0, 0 +} + +// AppendVarint appends the varint-encoded form of x, +// as generated by [PutVarint], to buf and returns the extended buffer. +func AppendVarint(buf []byte, x int64) []byte { + ux := uint64(x) << 1 + if x < 0 { + ux = ^ux + } + return AppendUvarint(buf, ux) +} + +// PutVarint encodes an int64 into buf and returns the number of bytes written. +// If the buffer is too small, PutVarint will panic. +func PutVarint(buf []byte, x int64) int { + ux := uint64(x) << 1 + if x < 0 { + ux = ^ux + } + return PutUvarint(buf, ux) +} + +// Varint decodes an int64 from buf and returns that value and the +// number of bytes read (> 0).
If an error occurred, the value is 0 +// and the number of bytes n is <= 0 with the following meaning: +// +// n == 0: buf too small +// n < 0: value larger than 64 bits (overflow) +// and -n is the number of bytes read +func Varint(buf []byte) (int64, int) { + ux, n := Uvarint(buf) // ok to continue in presence of error + x := int64(ux >> 1) + if ux&1 != 0 { + x = ^x + } + return x, n +} + +var errOverflow = errors.New("binary: varint overflows a 64-bit integer") + +// ReadUvarint reads an encoded unsigned integer from r and returns it as a uint64. +// The error is [io.EOF] only if no bytes were read. +// If an [io.EOF] happens after reading some but not all the bytes, +// ReadUvarint returns [io.ErrUnexpectedEOF]. +func ReadUvarint(r io.ByteReader) (uint64, error) { + var x uint64 + var s uint + for i := 0; i < MaxVarintLen64; i++ { + b, err := r.ReadByte() + if err != nil { + if i > 0 && err == io.EOF { + err = io.ErrUnexpectedEOF + } + return x, err + } + if b < 0x80 { + if i == MaxVarintLen64-1 && b > 1 { + return x, errOverflow + } + return x | uint64(b)<<s, nil + } + x |= uint64(b&0x7f) << s + s += 7 + } + return x, errOverflow +} + +// ReadVarint reads an encoded signed integer from r and returns it as an int64. +// The error is [io.EOF] only if no bytes were read. +// If an [io.EOF] happens after reading some but not all the bytes, +// ReadVarint returns [io.ErrUnexpectedEOF]. +func ReadVarint(r io.ByteReader) (int64, error) { + ux, err := ReadUvarint(r) // ok to continue in presence of error + x := int64(ux >> 1) + if ux&1 != 0 { + x = ^x + } + return x, err +} diff --git a/common/buf/alloc.go b/common/buf/alloc.go index 5d0b248..b556d93 100644 --- a/common/buf/alloc.go +++ b/common/buf/alloc.go @@ -8,7 +8,7 @@ import ( "sync" ) -var DefaultAllocator = newDefaultAllocer() +var DefaultAllocator = newDefaultAllocator() type Allocator interface { Get(size int) []byte @@ -17,22 +17,28 @@ type Allocator interface { // defaultAllocator for incoming frames, optimized to prevent overwriting after zeroing type defaultAllocator struct { - buffers []sync.Pool + buffers [11]sync.Pool } // NewAllocator initiates a []byte allocator for frames less than 65536 bytes, // the waste(memory fragmentation) of space allocation is guaranteed to be // no more than 50%.
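// Editor's usage sketch for the pooled allocator described above (not part of
// the patch): Get hands out a slice backed by the size class with the most
// appropriate capacity, and Put returns the backing array for reuse.
package main

import (
	"github.com/sagernet/sing/common/buf"
)

func main() {
	packet := buf.DefaultAllocator.Get(1500) // 1500-byte slice from a pooled size class
	copy(packet, "example payload")
	// ... use packet ...
	if err := buf.DefaultAllocator.Put(packet); err != nil { // hand it back to the pool
		panic(err)
	}
}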
-func newDefaultAllocer() Allocator { - alloc := new(defaultAllocator) - alloc.buffers = make([]sync.Pool, 17) // 1B -> 64K - for k := range alloc.buffers { - i := k - alloc.buffers[k].New = func() any { - return make([]byte, 1< 64K + {New: func() any { return new([1 << 6]byte) }}, + {New: func() any { return new([1 << 7]byte) }}, + {New: func() any { return new([1 << 8]byte) }}, + {New: func() any { return new([1 << 9]byte) }}, + {New: func() any { return new([1 << 10]byte) }}, + {New: func() any { return new([1 << 11]byte) }}, + {New: func() any { return new([1 << 12]byte) }}, + {New: func() any { return new([1 << 13]byte) }}, + {New: func() any { return new([1 << 14]byte) }}, + {New: func() any { return new([1 << 15]byte) }}, + {New: func() any { return new([1 << 16]byte) }}, + }, } - return alloc } // Get a []byte from pool with most appropriate cap @@ -41,12 +47,42 @@ func (alloc *defaultAllocator) Get(size int) []byte { return nil } - bits := msb(size) - if size == 1< 64 { + index = msb(size) + if size != 1< 65536 || cap(buf) != 1< 65535 { return &Buffer{ - data: make([]byte, size), + data: make([]byte, size), + capacity: size, } } return &Buffer{ - data: Get(size), - managed: true, - } -} - -func StackNew() *Buffer { - if common.UnsafeBuffer { - return &Buffer{ - data: make([]byte, BufferSize), - start: ReversedHeader, - end: ReversedHeader, - } - } else { - return New() - } -} - -func StackNewPacket() *Buffer { - if common.UnsafeBuffer { - return &Buffer{ - data: make([]byte, UDPBufferSize), - start: ReversedHeader, - end: ReversedHeader, - } - } else { - return NewPacket() - } -} - -func StackNewSize(size int) *Buffer { - if size == 0 { - return &Buffer{} - } - if common.UnsafeBuffer { - return &Buffer{ - data: Make(size), - } - } else { - return NewSize(size) + data: Get(size), + capacity: size, + managed: true, } } func As(data []byte) *Buffer { return &Buffer{ - data: data, - end: len(data), + data: data, + end: len(data), + capacity: len(data), } } func With(data []byte) *Buffer { return &Buffer{ - data: data, + data: data, + capacity: len(data), } } @@ -114,8 +78,8 @@ func (b *Buffer) SetByte(index int, value byte) { func (b *Buffer) Extend(n int) []byte { end := b.end + n - if end > cap(b.data) { - panic("buffer overflow: cap " + strconv.Itoa(cap(b.data)) + ",end " + strconv.Itoa(b.end) + ", need " + strconv.Itoa(n)) + if end > b.capacity { + panic(F.ToString("buffer overflow: capacity ", b.capacity, ",end ", b.end, ", need ", n)) } ext := b.data[b.end:end] b.end = end @@ -137,14 +101,14 @@ func (b *Buffer) Write(data []byte) (n int, err error) { if b.IsFull() { return 0, io.ErrShortBuffer } - n = copy(b.data[b.end:], data) + n = copy(b.data[b.end:b.capacity], data) b.end += n return } func (b *Buffer) ExtendHeader(n int) []byte { if b.start < n { - panic("buffer overflow: cap " + strconv.Itoa(cap(b.data)) + ",start " + strconv.Itoa(b.start) + ", need " + strconv.Itoa(n)) + panic(F.ToString("buffer overflow: capacity ", b.capacity, ",start ", b.start, ", need ", n)) } b.start -= n return b.data[b.start : b.start+n] @@ -197,7 +161,7 @@ func (b *Buffer) ReadAtLeastFrom(r io.Reader, min int) (int64, error) { } func (b *Buffer) ReadFullFrom(r io.Reader, size int) (n int, err error) { - if b.end+size > b.Cap() { + if b.end+size > b.capacity { return 0, io.ErrShortBuffer } n, err = io.ReadFull(r, b.data[b.end:b.end+size]) @@ -234,7 +198,7 @@ func (b *Buffer) WriteString(s string) (n int, err error) { if b.IsFull() { return 0, io.ErrShortBuffer } - n = copy(b.data[b.end:], s) + n = 
copy(b.data[b.end:b.capacity], s) b.end += n return } @@ -249,13 +213,10 @@ func (b *Buffer) WriteZero() error { } func (b *Buffer) WriteZeroN(n int) error { - if b.end+n > b.Cap() { + if b.end+n > b.capacity { return io.ErrShortBuffer } - for i := b.end; i <= b.end+n; i++ { - b.data[i] = 0 - } - b.end += n + common.ClearArray(b.Extend(n)) return nil } @@ -298,40 +259,63 @@ func (b *Buffer) Resize(start, end int) { b.end = b.start + end } -func (b *Buffer) Reset() { - b.start = ReversedHeader - b.end = ReversedHeader +func (b *Buffer) Reserve(n int) { + if n > b.capacity { + panic(F.ToString("buffer overflow: capacity ", b.capacity, ", need ", n)) + } + b.capacity -= n } -func (b *Buffer) FullReset() { +func (b *Buffer) OverCap(n int) { + if b.capacity+n > len(b.data) { + panic(F.ToString("buffer overflow: capacity ", len(b.data), ", need ", b.capacity+n)) + } + b.capacity += n +} + +func (b *Buffer) Reset() { b.start = 0 b.end = 0 + b.capacity = len(b.data) +} + +// Deprecated: use Reset instead. +func (b *Buffer) FullReset() { + b.Reset() } func (b *Buffer) IncRef() { - atomic.AddInt32(&b.refs, 1) + b.refs.Add(1) } func (b *Buffer) DecRef() { - atomic.AddInt32(&b.refs, -1) + b.refs.Add(-1) } func (b *Buffer) Release() { - if b == nil || b.closed || !b.managed { + if b == nil || !b.managed { return } - if atomic.LoadInt32(&b.refs) > 0 { + if b.refs.Load() > 0 { return } common.Must(Put(b.data)) - *b = Buffer{closed: true} + *b = Buffer{} } -func (b *Buffer) Cut(start int, end int) *Buffer { - b.start += start - b.end = len(b.data) - end - return &Buffer{ - data: b.data[b.start:b.end], +func (b *Buffer) Leak() { + if debug.Enabled { + if b == nil || !b.managed { + return + } + refs := b.refs.Load() + if refs == 0 { + panic("leaking buffer") + } else { + panic(F.ToString("leaking buffer with ", refs, " references")) + } + } else { + b.Release() } } @@ -344,6 +328,10 @@ func (b *Buffer) Len() int { } func (b *Buffer) Cap() int { + return b.capacity +} + +func (b *Buffer) RawCap() int { return len(b.data) } @@ -351,10 +339,6 @@ func (b *Buffer) Bytes() []byte { return b.data[b.start:b.end] } -func (b *Buffer) Slice() []byte { - return b.data -} - func (b *Buffer) From(n int) []byte { return b.data[b.start+n : b.end] } @@ -372,11 +356,11 @@ func (b *Buffer) Index(start int) []byte { } func (b *Buffer) FreeLen() int { - return b.Cap() - b.end + return b.capacity - b.end } func (b *Buffer) FreeBytes() []byte { - return b.data[b.end:b.Cap()] + return b.data[b.end:b.capacity] } func (b *Buffer) IsEmpty() bool { @@ -384,7 +368,7 @@ func (b *Buffer) IsEmpty() bool { } func (b *Buffer) IsFull() bool { - return b.end == b.Cap() + return b.end == b.capacity } func (b *Buffer) ToOwned() *Buffer { @@ -392,5 +376,6 @@ func (b *Buffer) ToOwned() *Buffer { copy(n.data[b.start:b.end], b.data[b.start:b.end]) n.start = b.start n.end = b.end + n.capacity = b.capacity return n } diff --git a/common/buf/hex.go b/common/buf/hex.go deleted file mode 100644 index ca54f67..0000000 --- a/common/buf/hex.go +++ /dev/null @@ -1,9 +0,0 @@ -package buf - -import "encoding/hex" - -func EncodeHexString(src []byte) string { - dst := Make(hex.EncodedLen(len(src))) - hex.Encode(dst, src) - return string(dst) -} diff --git a/common/buf/pool.go b/common/buf/pool.go index a729989..37f1232 100644 --- a/common/buf/pool.go +++ b/common/buf/pool.go @@ -11,46 +11,7 @@ func Put(buf []byte) error { return DefaultAllocator.Put(buf) } +// Deprecated: use array instead. 
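// Editor's sketch of the reworked Buffer API from buffer.go above (not part of
// the patch): the usable capacity is now tracked separately from the backing
// slice, and managed buffers are handed back to the pool with Release.
package main

import (
	"fmt"

	"github.com/sagernet/sing/common/buf"
)

func main() {
	buffer := buf.NewSize(1024) // pooled, managed buffer
	defer buffer.Release()      // return the space to the allocator

	_, _ = buffer.WriteString("hello")
	fmt.Println(buffer.Len(), buffer.Cap(), buffer.IsFull()) // 5 1024 false
	fmt.Println(string(buffer.Bytes()))                      // hello
}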
func Make(size int) []byte { - if size == 0 { - return nil - } - var buffer []byte - switch { - case size <= 2: - buffer = make([]byte, 2) - case size <= 4: - buffer = make([]byte, 4) - case size <= 8: - buffer = make([]byte, 8) - case size <= 16: - buffer = make([]byte, 16) - case size <= 32: - buffer = make([]byte, 32) - case size <= 64: - buffer = make([]byte, 64) - case size <= 128: - buffer = make([]byte, 128) - case size <= 256: - buffer = make([]byte, 256) - case size <= 512: - buffer = make([]byte, 512) - case size <= 1024: - buffer = make([]byte, 1024) - case size <= 2048: - buffer = make([]byte, 2048) - case size <= 4096: - buffer = make([]byte, 4096) - case size <= 8192: - buffer = make([]byte, 8192) - case size <= 16384: - buffer = make([]byte, 16384) - case size <= 32768: - buffer = make([]byte, 32768) - case size <= 65535: - buffer = make([]byte, 65535) - default: - return make([]byte, size) - } - return buffer[:size] + return make([]byte, size) } diff --git a/common/buf/ptr.go b/common/buf/ptr.go deleted file mode 100644 index 901c9e3..0000000 --- a/common/buf/ptr.go +++ /dev/null @@ -1,34 +0,0 @@ -//go:build !disable_unsafe - -package buf - -import ( - "unsafe" - - "github.com/sagernet/sing/common" -) - -type dbgVar struct { - name string - value *int32 -} - -//go:linkname dbgvars runtime.dbgvars -var dbgvars any - -// go.info.runtime.dbgvars: relocation target go.info.[]github.com/sagernet/sing/common/buf.dbgVar not defined -// var dbgvars []dbgVar - -func init() { - if !common.UnsafeBuffer { - return - } - debugVars := *(*[]dbgVar)(unsafe.Pointer(&dbgvars)) - for _, v := range debugVars { - if v.name == "invalidptr" { - *v.value = 0 - return - } - } - panic("can't disable invalidptr") -} diff --git a/common/bufio/addr_bsd.go b/common/bufio/addr_bsd.go new file mode 100644 index 0000000..7c51fd0 --- /dev/null +++ b/common/bufio/addr_bsd.go @@ -0,0 +1,34 @@ +//go:build darwin || dragonfly || freebsd || netbsd || openbsd + +package bufio + +import ( + "encoding/binary" + "net/netip" + "unsafe" + + "golang.org/x/sys/unix" +) + +func ToSockaddr(destination netip.AddrPort) (name unsafe.Pointer, nameLen uint32) { + if destination.Addr().Is4() { + sa := unix.RawSockaddrInet4{ + Len: unix.SizeofSockaddrInet4, + Family: unix.AF_INET, + Addr: destination.Addr().As4(), + } + binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], destination.Port()) + name = unsafe.Pointer(&sa) + nameLen = unix.SizeofSockaddrInet4 + } else { + sa := unix.RawSockaddrInet6{ + Len: unix.SizeofSockaddrInet6, + Family: unix.AF_INET6, + Addr: destination.Addr().As16(), + } + binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], destination.Port()) + name = unsafe.Pointer(&sa) + nameLen = unix.SizeofSockaddrInet6 + } + return +} diff --git a/common/bufio/addr.go b/common/bufio/addr_conn.go similarity index 100% rename from common/bufio/addr.go rename to common/bufio/addr_conn.go diff --git a/common/bufio/addr_linux.go b/common/bufio/addr_linux.go new file mode 100644 index 0000000..f0baef7 --- /dev/null +++ b/common/bufio/addr_linux.go @@ -0,0 +1,30 @@ +package bufio + +import ( + "encoding/binary" + "net/netip" + "unsafe" + + "golang.org/x/sys/unix" +) + +func ToSockaddr(destination netip.AddrPort) (name unsafe.Pointer, nameLen uint32) { + if destination.Addr().Is4() { + sa := unix.RawSockaddrInet4{ + Family: unix.AF_INET, + Addr: destination.Addr().As4(), + } + binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], destination.Port()) + name = unsafe.Pointer(&sa) + 
nameLen = unix.SizeofSockaddrInet4 + } else { + sa := unix.RawSockaddrInet6{ + Family: unix.AF_INET6, + Addr: destination.Addr().As16(), + } + binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], destination.Port()) + name = unsafe.Pointer(&sa) + nameLen = unix.SizeofSockaddrInet6 + } + return +} diff --git a/common/bufio/addr_windows.go b/common/bufio/addr_windows.go new file mode 100644 index 0000000..b3a5b9e --- /dev/null +++ b/common/bufio/addr_windows.go @@ -0,0 +1,30 @@ +package bufio + +import ( + "encoding/binary" + "net/netip" + "unsafe" + + "golang.org/x/sys/windows" +) + +func ToSockaddr(destination netip.AddrPort) (name unsafe.Pointer, nameLen int32) { + if destination.Addr().Is4() { + sa := windows.RawSockaddrInet4{ + Family: windows.AF_INET, + Addr: destination.Addr().As4(), + } + binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], destination.Port()) + name = unsafe.Pointer(&sa) + nameLen = int32(unsafe.Sizeof(sa)) + } else { + sa := windows.RawSockaddrInet6{ + Family: windows.AF_INET6, + Addr: destination.Addr().As16(), + } + binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], destination.Port()) + name = unsafe.Pointer(&sa) + nameLen = int32(unsafe.Sizeof(sa)) + } + return +} diff --git a/common/bufio/bind.go b/common/bufio/bind.go index 4c84320..9788b4d 100644 --- a/common/bufio/bind.go +++ b/common/bufio/bind.go @@ -8,51 +8,76 @@ import ( N "github.com/sagernet/sing/common/network" ) -type BindPacketConn struct { +type BindPacketConn interface { N.NetPacketConn - Addr net.Addr + net.Conn } -func NewBindPacketConn(conn net.PacketConn, addr net.Addr) *BindPacketConn { - return &BindPacketConn{ +type bindPacketConn struct { + N.NetPacketConn + addr net.Addr +} + +func NewBindPacketConn(conn net.PacketConn, addr net.Addr) BindPacketConn { + return &bindPacketConn{ NewPacketConn(conn), addr, } } -func (c *BindPacketConn) Read(b []byte) (n int, err error) { +func (c *bindPacketConn) Read(b []byte) (n int, err error) { n, _, err = c.ReadFrom(b) return } -func (c *BindPacketConn) Write(b []byte) (n int, err error) { - return c.WriteTo(b, c.Addr) +func (c *bindPacketConn) Write(b []byte) (n int, err error) { + return c.WriteTo(b, c.addr) } -func (c *BindPacketConn) RemoteAddr() net.Addr { - return c.Addr +func (c *bindPacketConn) CreateReadWaiter() (N.ReadWaiter, bool) { + readWaiter, isReadWaiter := CreatePacketReadWaiter(c.NetPacketConn) + if !isReadWaiter { + return nil, false + } + return &bindPacketReadWaiter{readWaiter}, true } -func (c *BindPacketConn) Upstream() any { +func (c *bindPacketConn) RemoteAddr() net.Addr { + return c.addr +} + +func (c *bindPacketConn) Upstream() any { return c.NetPacketConn } +var ( + _ N.NetPacketConn = (*UnbindPacketConn)(nil) + _ N.PacketReadWaitCreator = (*UnbindPacketConn)(nil) +) + type UnbindPacketConn struct { N.ExtendedConn - Addr M.Socksaddr + addr M.Socksaddr } -func NewUnbindPacketConn(conn net.Conn) *UnbindPacketConn { +func NewUnbindPacketConn(conn net.Conn) N.NetPacketConn { return &UnbindPacketConn{ NewExtendedConn(conn), M.SocksaddrFromNet(conn.RemoteAddr()), } } +func NewUnbindPacketConnWithAddr(conn net.Conn, addr M.Socksaddr) N.NetPacketConn { + return &UnbindPacketConn{ + NewExtendedConn(conn), + addr, + } +} + func (c *UnbindPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { n, err = c.ExtendedConn.Read(p) if err == nil { - addr = c.Addr.UDPAddr() + addr = c.addr.UDPAddr() } return } @@ -66,7 +91,7 @@ func (c *UnbindPacketConn) ReadPacket(buffer *buf.Buffer) 
(destination M.Socksad if err != nil { return } - destination = c.Addr + destination = c.addr return } @@ -74,6 +99,67 @@ func (c *UnbindPacketConn) WritePacket(buffer *buf.Buffer, _ M.Socksaddr) error return c.ExtendedConn.WriteBuffer(buffer) } +func (c *UnbindPacketConn) CreateReadWaiter() (N.PacketReadWaiter, bool) { + readWaiter, isReadWaiter := CreateReadWaiter(c.ExtendedConn) + if !isReadWaiter { + return nil, false + } + return &unbindPacketReadWaiter{readWaiter, c.addr}, true +} + func (c *UnbindPacketConn) Upstream() any { return c.ExtendedConn } + +func NewServerPacketConn(conn net.PacketConn) N.ExtendedConn { + return &serverPacketConn{ + NetPacketConn: NewPacketConn(conn), + } +} + +type serverPacketConn struct { + N.NetPacketConn + remoteAddr M.Socksaddr +} + +func (c *serverPacketConn) Read(p []byte) (n int, err error) { + n, addr, err := c.NetPacketConn.ReadFrom(p) + if err != nil { + return + } + c.remoteAddr = M.SocksaddrFromNet(addr) + return +} + +func (c *serverPacketConn) ReadBuffer(buffer *buf.Buffer) error { + destination, err := c.NetPacketConn.ReadPacket(buffer) + if err != nil { + return err + } + c.remoteAddr = destination + return nil +} + +func (c *serverPacketConn) Write(p []byte) (n int, err error) { + return c.NetPacketConn.WriteTo(p, c.remoteAddr.UDPAddr()) +} + +func (c *serverPacketConn) WriteBuffer(buffer *buf.Buffer) error { + return c.NetPacketConn.WritePacket(buffer, c.remoteAddr) +} + +func (c *serverPacketConn) RemoteAddr() net.Addr { + return c.remoteAddr +} + +func (c *serverPacketConn) Upstream() any { + return c.NetPacketConn +} + +func (c *serverPacketConn) CreateReadWaiter() (N.ReadWaiter, bool) { + readWaiter, isReadWaiter := CreatePacketReadWaiter(c.NetPacketConn) + if !isReadWaiter { + return nil, false + } + return &serverPacketReadWaiter{c, readWaiter}, true +} diff --git a/common/bufio/bind_wait.go b/common/bufio/bind_wait.go new file mode 100644 index 0000000..779474c --- /dev/null +++ b/common/bufio/bind_wait.go @@ -0,0 +1,62 @@ +package bufio + +import ( + "github.com/sagernet/sing/common/buf" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" +) + +var _ N.ReadWaiter = (*bindPacketReadWaiter)(nil) + +type bindPacketReadWaiter struct { + readWaiter N.PacketReadWaiter +} + +func (w *bindPacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) { + return w.readWaiter.InitializeReadWaiter(options) +} + +func (w *bindPacketReadWaiter) WaitReadBuffer() (buffer *buf.Buffer, err error) { + buffer, _, err = w.readWaiter.WaitReadPacket() + return +} + +var _ N.PacketReadWaiter = (*unbindPacketReadWaiter)(nil) + +type unbindPacketReadWaiter struct { + readWaiter N.ReadWaiter + addr M.Socksaddr +} + +func (w *unbindPacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) { + return w.readWaiter.InitializeReadWaiter(options) +} + +func (w *unbindPacketReadWaiter) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) { + buffer, err = w.readWaiter.WaitReadBuffer() + if err != nil { + return + } + destination = w.addr + return +} + +var _ N.ReadWaiter = (*serverPacketReadWaiter)(nil) + +type serverPacketReadWaiter struct { + *serverPacketConn + readWaiter N.PacketReadWaiter +} + +func (w *serverPacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) { + return w.readWaiter.InitializeReadWaiter(options) +} + +func (w *serverPacketReadWaiter) WaitReadBuffer() (buffer *buf.Buffer, err error) { + buffer, destination, 
err := w.readWaiter.WaitReadPacket() + if err != nil { + return + } + w.remoteAddr = destination + return +} diff --git a/common/bufio/buffer.go b/common/bufio/buffer.go index 47c35a9..cdd2896 100644 --- a/common/bufio/buffer.go +++ b/common/bufio/buffer.go @@ -37,7 +37,7 @@ func (w *BufferedWriter) Write(p []byte) (n int, err error) { if err != nil { return } - w.buffer.FullReset() + w.buffer.Reset() } } diff --git a/common/bufio/chunk.go b/common/bufio/chunk.go index 56a733a..cd11f63 100644 --- a/common/bufio/chunk.go +++ b/common/bufio/chunk.go @@ -30,7 +30,7 @@ func (c *ChunkReader) ReadBuffer(buffer *buf.Buffer) error { } else if !c.cache.IsEmpty() { return common.Error(buffer.ReadFrom(c.cache)) } - c.cache.FullReset() + c.cache.Reset() err := c.upstream.ReadBuffer(c.cache) if err != nil { c.cache.Release() @@ -46,7 +46,7 @@ func (c *ChunkReader) Read(p []byte) (n int, err error) { } else if !c.cache.IsEmpty() { return c.cache.Read(p) } - c.cache.FullReset() + c.cache.Reset() err = c.upstream.ReadBuffer(c.cache) if err != nil { c.cache.Release() @@ -70,7 +70,7 @@ func (c *ChunkReader) ReadChunk() (*buf.Buffer, error) { } else if !c.cache.IsEmpty() { return c.cache, nil } - c.cache.FullReset() + c.cache.Reset() err := c.upstream.ReadBuffer(c.cache) if err != nil { c.cache.Release() diff --git a/common/bufio/copy.go b/common/bufio/copy.go index c0ff6dd..f8e63cd 100644 --- a/common/bufio/copy.go +++ b/common/bufio/copy.go @@ -22,7 +22,7 @@ func Copy(destination io.Writer, source io.Reader) (n int64, err error) { } else if destination == nil { return 0, E.New("nil writer") } - originDestination := destination + originSource := source var readCounters, writeCounters []N.CountFunc for { source, readCounters = N.UnwrapCountReader(source, readCounters) @@ -45,105 +45,61 @@ func Copy(destination io.Writer, source io.Reader) (n int64, err error) { dstSyscallConn, dstIsSyscall := destination.(syscall.Conn) if srcIsSyscall && dstIsSyscall { var handled bool - handled, n, err = CopyDirect(srcSyscallConn, dstSyscallConn, readCounters, writeCounters) + handled, n, err = copyDirect(srcSyscallConn, dstSyscallConn, readCounters, writeCounters) if handled { return } } break } - return CopyExtended(originDestination, NewExtendedWriter(destination), NewExtendedReader(source), readCounters, writeCounters) + return CopyExtended(originSource, NewExtendedWriter(destination), NewExtendedReader(source), readCounters, writeCounters) } -func CopyExtended(originDestination io.Writer, destination N.ExtendedWriter, source N.ExtendedReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) { - safeSrc := N.IsSafeReader(source) - headroom := N.CalculateFrontHeadroom(destination) + N.CalculateRearHeadroom(destination) - if safeSrc != nil { - if headroom == 0 { - return CopyExtendedWithSrcBuffer(originDestination, destination, safeSrc, readCounters, writeCounters) - } - } +func CopyExtended(originSource io.Reader, destination N.ExtendedWriter, source N.ExtendedReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) { + frontHeadroom := N.CalculateFrontHeadroom(destination) + rearHeadroom := N.CalculateRearHeadroom(destination) readWaiter, isReadWaiter := CreateReadWaiter(source) if isReadWaiter { - var handled bool - handled, n, err = copyWaitWithPool(originDestination, destination, readWaiter, readCounters, writeCounters) - if handled { - return + needCopy := readWaiter.InitializeReadWaiter(N.ReadWaitOptions{ + FrontHeadroom: frontHeadroom, + RearHeadroom: 
rearHeadroom, + MTU: N.CalculateMTU(source, destination), + }) + if !needCopy || common.LowMemory { + var handled bool + handled, n, err = copyWaitWithPool(originSource, destination, readWaiter, readCounters, writeCounters) + if handled { + return + } } } - if !common.UnsafeBuffer || N.IsUnsafeWriter(destination) { - return CopyExtendedWithPool(originDestination, destination, source, readCounters, writeCounters) - } - bufferSize := N.CalculateMTU(source, destination) - if bufferSize > 0 { - bufferSize += headroom - } else { - bufferSize = buf.BufferSize - } - _buffer := buf.StackNewSize(bufferSize) - defer common.KeepAlive(_buffer) - buffer := common.Dup(_buffer) - defer buffer.Release() - return CopyExtendedBuffer(originDestination, destination, source, buffer, readCounters, writeCounters) + return CopyExtendedWithPool(originSource, destination, source, readCounters, writeCounters) } -func CopyExtendedBuffer(originDestination io.Writer, destination N.ExtendedWriter, source N.ExtendedReader, buffer *buf.Buffer, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) { +func CopyExtendedBuffer(originSource io.Writer, destination N.ExtendedWriter, source N.ExtendedReader, buffer *buf.Buffer, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) { buffer.IncRef() defer buffer.DecRef() frontHeadroom := N.CalculateFrontHeadroom(destination) rearHeadroom := N.CalculateRearHeadroom(destination) - readBufferRaw := buffer.Slice() - readBuffer := buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom]) + buffer.Resize(frontHeadroom, 0) + buffer.Reserve(rearHeadroom) var notFirstTime bool for { - readBuffer.Resize(frontHeadroom, 0) - err = source.ReadBuffer(readBuffer) + err = source.ReadBuffer(buffer) if err != nil { if errors.Is(err, io.EOF) { err = nil return } - if !notFirstTime { - err = N.HandshakeFailure(originDestination, err) - } - return - } - dataLen := readBuffer.Len() - buffer.Resize(readBuffer.Start(), dataLen) - err = destination.WriteBuffer(buffer) - if err != nil { - return - } - n += int64(dataLen) - for _, counter := range readCounters { - counter(int64(dataLen)) - } - for _, counter := range writeCounters { - counter(int64(dataLen)) - } - notFirstTime = true - } -} - -func CopyExtendedWithSrcBuffer(originDestination io.Writer, destination N.ExtendedWriter, source N.ThreadSafeReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) { - var notFirstTime bool - for { - var buffer *buf.Buffer - buffer, err = source.ReadBufferThreadSafe() - if err != nil { - if errors.Is(err, io.EOF) { - err = nil - return - } - if !notFirstTime { - err = N.HandshakeFailure(originDestination, err) - } return } dataLen := buffer.Len() + buffer.OverCap(rearHeadroom) err = destination.WriteBuffer(buffer) if err != nil { - buffer.Release() + if !notFirstTime { + err = N.ReportHandshakeFailure(originSource, err) + } return } n += int64(dataLen) @@ -157,7 +113,7 @@ func CopyExtendedWithSrcBuffer(originDestination io.Writer, destination N.Extend } } -func CopyExtendedWithPool(originDestination io.Writer, destination N.ExtendedWriter, source N.ExtendedReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) { +func CopyExtendedWithPool(originSource io.Reader, destination N.ExtendedWriter, source N.ExtendedReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) { frontHeadroom := N.CalculateFrontHeadroom(destination) rearHeadroom := N.CalculateRearHeadroom(destination) 
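// A condensed sketch (editor's illustration, not a function in the patch) of
// the buffer lifecycle the rewritten copy loops share, using only the
// buf.Buffer helpers the patch itself calls (Resize, Reserve, OverCap, Leak).
func copyOnce(destination N.ExtendedWriter, source N.ExtendedReader, bufferSize, frontHeadroom, rearHeadroom int) (int, error) {
	buffer := buf.NewSize(bufferSize)
	buffer.Resize(frontHeadroom, 0) // start the readable region after the front headroom
	buffer.Reserve(rearHeadroom)    // hide the rear headroom from the reader
	if err := source.ReadBuffer(buffer); err != nil {
		buffer.Release()
		return 0, err
	}
	dataLen := buffer.Len()
	buffer.OverCap(rearHeadroom) // expose the reserved tail so the writer can append its trailer
	if err := destination.WriteBuffer(buffer); err != nil {
		buffer.Leak() // ownership has moved to the writer, mirroring the patch's error paths
		return 0, err
	}
	return dataLen, nil
}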
bufferSize := N.CalculateMTU(source, destination) @@ -169,26 +125,25 @@ func CopyExtendedWithPool(originDestination io.Writer, destination N.ExtendedWri var notFirstTime bool for { buffer := buf.NewSize(bufferSize) - readBufferRaw := buffer.Slice() - readBuffer := buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom]) - readBuffer.Resize(frontHeadroom, 0) - err = source.ReadBuffer(readBuffer) + buffer.Resize(frontHeadroom, 0) + buffer.Reserve(rearHeadroom) + err = source.ReadBuffer(buffer) if err != nil { buffer.Release() if errors.Is(err, io.EOF) { err = nil return } - if !notFirstTime { - err = N.HandshakeFailure(originDestination, err) - } return } - dataLen := readBuffer.Len() - buffer.Resize(readBuffer.Start(), dataLen) + dataLen := buffer.Len() + buffer.OverCap(rearHeadroom) err = destination.WriteBuffer(buffer) if err != nil { - buffer.Release() + buffer.Leak() + if !notFirstTime { + err = N.ReportHandshakeFailure(originSource, err) + } return } n += int64(dataLen) @@ -249,6 +204,7 @@ func CopyConnContextList(contextList []context.Context, source net.Conn, destina func CopyPacket(destinationConn N.PacketWriter, source N.PacketReader) (n int64, err error) { var readCounters, writeCounters []N.CountFunc var cachedPackets []*N.PacketBuffer + originSource := source for { source, readCounters = N.UnwrapCountPacketReader(source, readCounters) destinationConn, writeCounters = N.UnwrapCountPacketWriter(destinationConn, writeCounters) @@ -262,113 +218,38 @@ func CopyPacket(destinationConn N.PacketWriter, source N.PacketReader) (n int64, break } if cachedPackets != nil { - n, err = WritePacketWithPool(destinationConn, cachedPackets) + n, err = WritePacketWithPool(originSource, destinationConn, cachedPackets) if err != nil { return } } - safeSrc := N.IsSafePacketReader(source) frontHeadroom := N.CalculateFrontHeadroom(destinationConn) rearHeadroom := N.CalculateRearHeadroom(destinationConn) - headroom := frontHeadroom + rearHeadroom - if safeSrc != nil { - if headroom == 0 { - var copyN int64 - copyN, err = CopyPacketWithSrcBuffer(destinationConn, safeSrc, readCounters, writeCounters) - n += copyN - return - } - } + var ( + handled bool + copeN int64 + ) readWaiter, isReadWaiter := CreatePacketReadWaiter(source) if isReadWaiter { - var ( - handled bool - copeN int64 - ) - handled, copeN, err = copyPacketWaitWithPool(destinationConn, readWaiter, readCounters, writeCounters) - if handled { - n += copeN - return - } - } - if N.IsUnsafeWriter(destinationConn) { - return CopyPacketWithPool(destinationConn, source, readCounters, writeCounters) - } - bufferSize := N.CalculateMTU(source, destinationConn) - if bufferSize > 0 { - bufferSize += headroom - } else { - bufferSize = buf.UDPBufferSize - } - _buffer := buf.StackNewSize(bufferSize) - defer common.KeepAlive(_buffer) - buffer := common.Dup(_buffer) - defer buffer.Release() - buffer.IncRef() - defer buffer.DecRef() - var destination M.Socksaddr - var notFirstTime bool - readBufferRaw := buffer.Slice() - readBuffer := buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom]) - for { - readBuffer.Resize(frontHeadroom, 0) - destination, err = source.ReadPacket(readBuffer) - if err != nil { - if !notFirstTime { - err = N.HandshakeFailure(destinationConn, err) + needCopy := readWaiter.InitializeReadWaiter(N.ReadWaitOptions{ + FrontHeadroom: frontHeadroom, + RearHeadroom: rearHeadroom, + MTU: N.CalculateMTU(source, destinationConn), + }) + if !needCopy || common.LowMemory { + handled, copeN, err = copyPacketWaitWithPool(originSource, destinationConn, 
readWaiter, readCounters, writeCounters, n > 0) + if handled { + n += copeN + return } - return } - dataLen := readBuffer.Len() - buffer.Resize(readBuffer.Start(), dataLen) - err = destinationConn.WritePacket(buffer, destination) - if err != nil { - return - } - n += int64(dataLen) - for _, counter := range readCounters { - counter(int64(dataLen)) - } - for _, counter := range writeCounters { - counter(int64(dataLen)) - } - notFirstTime = true } + copeN, err = CopyPacketWithPool(originSource, destinationConn, source, readCounters, writeCounters, n > 0) + n += copeN + return } -func CopyPacketWithSrcBuffer(destinationConn N.PacketWriter, source N.ThreadSafePacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) { - var buffer *buf.Buffer - var destination M.Socksaddr - var notFirstTime bool - for { - buffer, destination, err = source.ReadPacketThreadSafe() - if err != nil { - if !notFirstTime { - err = N.HandshakeFailure(destinationConn, err) - } - return - } - dataLen := buffer.Len() - if dataLen == 0 { - continue - } - err = destinationConn.WritePacket(buffer, destination) - if err != nil { - buffer.Release() - return - } - n += int64(dataLen) - for _, counter := range readCounters { - counter(int64(dataLen)) - } - for _, counter := range writeCounters { - counter(int64(dataLen)) - } - notFirstTime = true - } -} - -func CopyPacketWithPool(destinationConn N.PacketWriter, source N.PacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) { +func CopyPacketWithPool(originSource N.PacketReader, destinationConn N.PacketWriter, source N.PacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (n int64, err error) { frontHeadroom := N.CalculateFrontHeadroom(destinationConn) rearHeadroom := N.CalculateRearHeadroom(destinationConn) bufferSize := N.CalculateMTU(source, destinationConn) @@ -378,25 +259,23 @@ func CopyPacketWithPool(destinationConn N.PacketWriter, source N.PacketReader, r bufferSize = buf.UDPBufferSize } var destination M.Socksaddr - var notFirstTime bool for { buffer := buf.NewSize(bufferSize) - readBufferRaw := buffer.Slice() - readBuffer := buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom]) - readBuffer.Resize(frontHeadroom, 0) - destination, err = source.ReadPacket(readBuffer) + buffer.Resize(frontHeadroom, 0) + buffer.Reserve(rearHeadroom) + destination, err = source.ReadPacket(buffer) if err != nil { buffer.Release() - if !notFirstTime { - err = N.HandshakeFailure(destinationConn, err) - } return } - dataLen := readBuffer.Len() - buffer.Resize(readBuffer.Start(), dataLen) + dataLen := buffer.Len() + buffer.OverCap(rearHeadroom) err = destinationConn.WritePacket(buffer, destination) if err != nil { - buffer.Release() + buffer.Leak() + if !notFirstTime { + err = N.ReportHandshakeFailure(originSource, err) + } return } n += int64(dataLen) @@ -410,24 +289,28 @@ func CopyPacketWithPool(destinationConn N.PacketWriter, source N.PacketReader, r } } -func WritePacketWithPool(destinationConn N.PacketWriter, packetBuffers []*N.PacketBuffer) (n int64, err error) { +func WritePacketWithPool(originSource N.PacketReader, destinationConn N.PacketWriter, packetBuffers []*N.PacketBuffer) (n int64, err error) { frontHeadroom := N.CalculateFrontHeadroom(destinationConn) rearHeadroom := N.CalculateRearHeadroom(destinationConn) + var notFirstTime bool for _, packetBuffer := range packetBuffers { buffer := buf.NewPacket() - readBufferRaw := buffer.Slice() - readBuffer := 
buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom]) - readBuffer.Resize(frontHeadroom, 0) - _, err = readBuffer.Write(packetBuffer.Buffer.Bytes()) + buffer.Resize(frontHeadroom, 0) + buffer.Reserve(rearHeadroom) + _, err = buffer.Write(packetBuffer.Buffer.Bytes()) packetBuffer.Buffer.Release() if err != nil { + buffer.Release() continue } - dataLen := readBuffer.Len() - buffer.Resize(readBuffer.Start(), dataLen) + dataLen := buffer.Len() + buffer.OverCap(rearHeadroom) err = destinationConn.WritePacket(buffer, packetBuffer.Destination) if err != nil { - buffer.Release() + buffer.Leak() + if !notFirstTime { + err = N.ReportHandshakeFailure(originSource, err) + } return } n += int64(dataLen) diff --git a/common/bufio/copy_direct.go b/common/bufio/copy_direct.go index 1648c03..f34d384 100644 --- a/common/bufio/copy_direct.go +++ b/common/bufio/copy_direct.go @@ -1,12 +1,16 @@ package bufio import ( + "errors" + "io" "syscall" + "github.com/sagernet/sing/common/buf" + M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" ) -func CopyDirect(source syscall.Conn, destination syscall.Conn, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handed bool, n int64, err error) { +func copyDirect(source syscall.Conn, destination syscall.Conn, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handed bool, n int64, err error) { rawSource, err := source.SyscallConn() if err != nil { return @@ -18,3 +22,69 @@ func CopyDirect(source syscall.Conn, destination syscall.Conn, readCounters []N. handed, n, err = splice(rawSource, rawDestination, readCounters, writeCounters) return } + +func copyWaitWithPool(originSource io.Reader, destination N.ExtendedWriter, source N.ReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handled bool, n int64, err error) { + handled = true + var ( + buffer *buf.Buffer + notFirstTime bool + ) + for { + buffer, err = source.WaitReadBuffer() + if err != nil { + if errors.Is(err, io.EOF) { + err = nil + return + } + return + } + dataLen := buffer.Len() + err = destination.WriteBuffer(buffer) + if err != nil { + buffer.Leak() + if !notFirstTime { + err = N.ReportHandshakeFailure(originSource, err) + } + return + } + n += int64(dataLen) + for _, counter := range readCounters { + counter(int64(dataLen)) + } + for _, counter := range writeCounters { + counter(int64(dataLen)) + } + notFirstTime = true + } +} + +func copyPacketWaitWithPool(originSource N.PacketReader, destinationConn N.PacketWriter, source N.PacketReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (handled bool, n int64, err error) { + handled = true + var ( + buffer *buf.Buffer + destination M.Socksaddr + ) + for { + buffer, destination, err = source.WaitReadPacket() + if err != nil { + return + } + dataLen := buffer.Len() + err = destinationConn.WritePacket(buffer, destination) + if err != nil { + buffer.Leak() + if !notFirstTime { + err = N.ReportHandshakeFailure(originSource, err) + } + return + } + n += int64(dataLen) + for _, counter := range readCounters { + counter(int64(dataLen)) + } + for _, counter := range writeCounters { + counter(int64(dataLen)) + } + notFirstTime = true + } +} diff --git a/common/bufio/copy_direct_posix.go b/common/bufio/copy_direct_posix.go index 4479988..956b818 100644 --- a/common/bufio/copy_direct_posix.go +++ b/common/bufio/copy_direct_posix.go @@ -3,7 +3,6 @@ package bufio import ( - "errors" "io" "net/netip" "os" @@ -15,115 +14,14 @@ import ( N 
"github.com/sagernet/sing/common/network" ) -func copyWaitWithPool(originDestination io.Writer, destination N.ExtendedWriter, source N.ReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handled bool, n int64, err error) { - handled = true - frontHeadroom := N.CalculateFrontHeadroom(destination) - rearHeadroom := N.CalculateRearHeadroom(destination) - bufferSize := N.CalculateMTU(source, destination) - if bufferSize > 0 { - bufferSize += frontHeadroom + rearHeadroom - } else { - bufferSize = buf.BufferSize - } - var ( - buffer *buf.Buffer - readBuffer *buf.Buffer - notFirstTime bool - ) - source.InitializeReadWaiter(func() *buf.Buffer { - buffer = buf.NewSize(bufferSize) - readBufferRaw := buffer.Slice() - readBuffer = buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom]) - readBuffer.Resize(frontHeadroom, 0) - return readBuffer - }) - defer source.InitializeReadWaiter(nil) - for { - err = source.WaitReadBuffer() - if err != nil { - if errors.Is(err, io.EOF) { - err = nil - return - } - if !notFirstTime { - err = N.HandshakeFailure(originDestination, err) - } - return - } - dataLen := readBuffer.Len() - buffer.Resize(readBuffer.Start(), dataLen) - err = destination.WriteBuffer(buffer) - if err != nil { - buffer.Release() - return - } - n += int64(dataLen) - for _, counter := range readCounters { - counter(int64(dataLen)) - } - for _, counter := range writeCounters { - counter(int64(dataLen)) - } - notFirstTime = true - } -} - -func copyPacketWaitWithPool(destinationConn N.PacketWriter, source N.PacketReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handled bool, n int64, err error) { - handled = true - frontHeadroom := N.CalculateFrontHeadroom(destinationConn) - rearHeadroom := N.CalculateRearHeadroom(destinationConn) - bufferSize := N.CalculateMTU(source, destinationConn) - if bufferSize > 0 { - bufferSize += frontHeadroom + rearHeadroom - } else { - bufferSize = buf.UDPBufferSize - } - var ( - buffer *buf.Buffer - readBuffer *buf.Buffer - destination M.Socksaddr - notFirstTime bool - ) - source.InitializeReadWaiter(func() *buf.Buffer { - buffer = buf.NewSize(bufferSize) - readBufferRaw := buffer.Slice() - readBuffer = buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom]) - readBuffer.Resize(frontHeadroom, 0) - return readBuffer - }) - defer source.InitializeReadWaiter(nil) - for { - destination, err = source.WaitReadPacket() - if err != nil { - if !notFirstTime { - err = N.HandshakeFailure(destinationConn, err) - } - return - } - dataLen := readBuffer.Len() - buffer.Resize(readBuffer.Start(), dataLen) - err = destinationConn.WritePacket(buffer, destination) - if err != nil { - buffer.Release() - return - } - n += int64(dataLen) - for _, counter := range readCounters { - counter(int64(dataLen)) - } - for _, counter := range writeCounters { - counter(int64(dataLen)) - } - notFirstTime = true - } -} - var _ N.ReadWaiter = (*syscallReadWaiter)(nil) type syscallReadWaiter struct { rawConn syscall.RawConn readErr error readFunc func(fd uintptr) (done bool) + buffer *buf.Buffer + options N.ReadWaitOptions } func createSyscallReadWaiter(reader any) (*syscallReadWaiter, bool) { @@ -136,47 +34,48 @@ func createSyscallReadWaiter(reader any) (*syscallReadWaiter, bool) { return nil, false } -func (w *syscallReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buffer) { - w.readErr = nil - if newBuffer == nil { - w.readFunc = nil - } else { - w.readFunc = func(fd uintptr) (done bool) { - buffer := newBuffer() - var readN int - readN, w.readErr = 
syscall.Read(int(fd), buffer.FreeBytes()) - if readN > 0 { - buffer.Truncate(readN) - } else { - buffer.Release() - buffer = nil - } - if w.readErr == syscall.EAGAIN { - return false - } - if readN == 0 { - w.readErr = io.EOF - } - return true +func (w *syscallReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) { + w.options = options + w.readFunc = func(fd uintptr) (done bool) { + buffer := w.options.NewBuffer() + var readN int + readN, w.readErr = syscall.Read(int(fd), buffer.FreeBytes()) + if readN > 0 { + buffer.Truncate(readN) + w.options.PostReturn(buffer) + w.buffer = buffer + } else { + buffer.Release() } + //goland:noinspection GoDirectComparisonOfErrors + if w.readErr == syscall.EAGAIN { + return false + } + if readN == 0 && w.readErr == nil { + w.readErr = io.EOF + } + return true } + return false } -func (w *syscallReadWaiter) WaitReadBuffer() error { +func (w *syscallReadWaiter) WaitReadBuffer() (buffer *buf.Buffer, err error) { if w.readFunc == nil { - return os.ErrInvalid + return nil, os.ErrInvalid } - err := w.rawConn.Read(w.readFunc) + err = w.rawConn.Read(w.readFunc) if err != nil { - return err + return } if w.readErr != nil { if w.readErr == io.EOF { - return io.EOF + return nil, io.EOF } - return E.Cause(w.readErr, "raw read") + return nil, E.Cause(w.readErr, "raw read") } - return nil + buffer = w.buffer + w.buffer = nil + return } var _ N.PacketReadWaiter = (*syscallPacketReadWaiter)(nil) @@ -186,6 +85,8 @@ type syscallPacketReadWaiter struct { readErr error readFrom M.Socksaddr readFunc func(fd uintptr) (done bool) + buffer *buf.Buffer + options N.ReadWaitOptions } func createSyscallPacketReadWaiter(reader any) (*syscallPacketReadWaiter, bool) { @@ -198,42 +99,37 @@ func createSyscallPacketReadWaiter(reader any) (*syscallPacketReadWaiter, bool) return nil, false } -func (w *syscallPacketReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buffer) { - w.readErr = nil - w.readFrom = M.Socksaddr{} - if newBuffer == nil { - w.readFunc = nil - } else { - w.readFunc = func(fd uintptr) (done bool) { - buffer := newBuffer() - var readN int - var from syscall.Sockaddr - readN, _, _, from, w.readErr = syscall.Recvmsg(int(fd), buffer.FreeBytes(), nil, 0) - if readN > 0 { - buffer.Truncate(readN) - } else { - buffer.Release() - buffer = nil - } - if w.readErr == syscall.EAGAIN { - return false - } - if from != nil { - switch fromAddr := from.(type) { - case *syscall.SockaddrInet4: - w.readFrom = M.SocksaddrFrom(netip.AddrFrom4(fromAddr.Addr), uint16(fromAddr.Port)) - case *syscall.SockaddrInet6: - w.readFrom = M.SocksaddrFrom(netip.AddrFrom16(fromAddr.Addr), uint16(fromAddr.Port)) - } - } - return true +func (w *syscallPacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) { + w.options = options + w.readFunc = func(fd uintptr) (done bool) { + buffer := w.options.NewPacketBuffer() + var readN int + var from syscall.Sockaddr + readN, _, _, from, w.readErr = syscall.Recvmsg(int(fd), buffer.FreeBytes(), nil, 0) + //goland:noinspection GoDirectComparisonOfErrors + if w.readErr != nil { + buffer.Release() + return w.readErr != syscall.EAGAIN } + if readN > 0 { + buffer.Truncate(readN) + } + w.options.PostReturn(buffer) + w.buffer = buffer + switch fromAddr := from.(type) { + case *syscall.SockaddrInet4: + w.readFrom = M.SocksaddrFrom(netip.AddrFrom4(fromAddr.Addr), uint16(fromAddr.Port)) + case *syscall.SockaddrInet6: + w.readFrom = M.SocksaddrFrom(netip.AddrFrom16(fromAddr.Addr), uint16(fromAddr.Port)).Unwrap() + } + return true 
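// The readFunc above is driven by syscall.RawConn.Read: the runtime calls it
// with the descriptor and, whenever it returns false (the EAGAIN case), parks
// the goroutine on the poller until the descriptor is readable, then calls it
// again. A self-contained POSIX sketch of that contract (readOnce is the
// editor's name, not part of the patch; it only needs the "syscall" import
// already used by this file):
func readOnce(conn syscall.Conn, p []byte) (int, error) {
	rawConn, err := conn.SyscallConn()
	if err != nil {
		return 0, err
	}
	var (
		readN   int
		readErr error
	)
	err = rawConn.Read(func(fd uintptr) bool {
		readN, readErr = syscall.Read(int(fd), p)
		// EAGAIN means "not readable yet": returning false yields to the
		// runtime poller instead of spinning on the raw syscall.
		return readErr != syscall.EAGAIN
	})
	if err != nil {
		return 0, err
	}
	return readN, readErr
}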
} + return false } -func (w *syscallPacketReadWaiter) WaitReadPacket() (destination M.Socksaddr, err error) { +func (w *syscallPacketReadWaiter) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) { if w.readFunc == nil { - return M.Socksaddr{}, os.ErrInvalid + return nil, M.Socksaddr{}, os.ErrInvalid } err = w.rawConn.Read(w.readFunc) if err != nil { @@ -243,6 +139,8 @@ func (w *syscallPacketReadWaiter) WaitReadPacket() (destination M.Socksaddr, err err = E.Cause(w.readErr, "raw read") return } + buffer = w.buffer + w.buffer = nil destination = w.readFrom return } diff --git a/common/bufio/copy_direct_test.go b/common/bufio/copy_direct_test.go new file mode 100644 index 0000000..41fed63 --- /dev/null +++ b/common/bufio/copy_direct_test.go @@ -0,0 +1,77 @@ +package bufio + +import ( + "net" + "testing" + + "github.com/sagernet/sing/common/buf" + N "github.com/sagernet/sing/common/network" + + "github.com/stretchr/testify/require" +) + +func TestCopyWaitTCP(t *testing.T) { + t.Parallel() + inputConn, outputConn := TCPPipe(t) + readWaiter, created := CreateReadWaiter(outputConn) + require.True(t, created) + require.NotNil(t, readWaiter) + readWaiter.InitializeReadWaiter(N.ReadWaitOptions{}) + require.NoError(t, TCPTest(t, inputConn, &readWaitWrapper{ + Conn: outputConn, + readWaiter: readWaiter, + })) +} + +type readWaitWrapper struct { + net.Conn + readWaiter N.ReadWaiter + buffer *buf.Buffer +} + +func (r *readWaitWrapper) Read(p []byte) (n int, err error) { + if r.buffer != nil { + if r.buffer.Len() > 0 { + return r.buffer.Read(p) + } + if r.buffer.IsEmpty() { + r.buffer.Release() + r.buffer = nil + } + } + buffer, err := r.readWaiter.WaitReadBuffer() + if err != nil { + return + } + r.buffer = buffer + return r.buffer.Read(p) +} + +func TestCopyWaitUDP(t *testing.T) { + t.Parallel() + inputConn, outputConn, outputAddr := UDPPipe(t) + readWaiter, created := CreatePacketReadWaiter(NewPacketConn(outputConn)) + require.True(t, created) + require.NotNil(t, readWaiter) + readWaiter.InitializeReadWaiter(N.ReadWaitOptions{}) + require.NoError(t, UDPTest(t, inputConn, &packetReadWaitWrapper{ + PacketConn: outputConn, + readWaiter: readWaiter, + }, outputAddr)) +} + +type packetReadWaitWrapper struct { + net.PacketConn + readWaiter N.PacketReadWaiter +} + +func (r *packetReadWaitWrapper) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + buffer, destination, err := r.readWaiter.WaitReadPacket() + if err != nil { + return + } + n = copy(p, buffer.Bytes()) + buffer.Release() + addr = destination.UDPAddr() + return +} diff --git a/common/bufio/copy_direct_windows.go b/common/bufio/copy_direct_windows.go index 9c0743f..ee20caf 100644 --- a/common/bufio/copy_direct_windows.go +++ b/common/bufio/copy_direct_windows.go @@ -2,22 +2,206 @@ package bufio import ( "io" + "net/netip" + "os" + "syscall" + "unsafe" + "github.com/sagernet/sing/common/buf" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" + + "golang.org/x/sys/windows" ) -func copyWaitWithPool(originDestination io.Writer, destination N.ExtendedWriter, source N.ReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handled bool, n int64, err error) { +var modws2_32 = windows.NewLazySystemDLL("ws2_32.dll") + +var procrecv = modws2_32.NewProc("recv") + +// Do the interface allocations only once for common +// Errno values. 
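// The modws2_32 and procrecv variables above follow the lazy Win32 binding
// pattern: the DLL is loaded and the export resolved on first use. A
// hypothetical extra binding, shown only to illustrate the pattern (ntohs is
// the editor's example and is not used by the patch):
var procNtohs = modws2_32.NewProc("ntohs")

// ntohs converts a 16-bit value from network to host byte order via Winsock.
func ntohs(netshort uint16) uint16 {
	r0, _, _ := syscall.SyscallN(procNtohs.Addr(), uintptr(netshort))
	return uint16(r0)
}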
+const ( + errnoERROR_IO_PENDING = 997 +) + +var ( + errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING) + errERROR_EINVAL error = syscall.EINVAL +) + +// errnoErr returns common boxed Errno values, to prevent +// allocations at runtime. +func errnoErr(e syscall.Errno) error { + switch e { + case 0: + return errERROR_EINVAL + case errnoERROR_IO_PENDING: + return errERROR_IO_PENDING + } + // TODO: add more here, after collecting data on the common + // error values see on Windows. (perhaps when running + // all.bat?) + return e +} + +func recv(s windows.Handle, buf []byte, flags int32) (n int32, err error) { + var _p0 *byte + if len(buf) > 0 { + _p0 = &buf[0] + } + r0, _, e1 := syscall.SyscallN(procrecv.Addr(), uintptr(s), uintptr(unsafe.Pointer(_p0)), uintptr(len(buf)), uintptr(flags)) + n = int32(r0) + if n == -1 { + err = errnoErr(e1) + } return } -func copyPacketWaitWithPool(destinationConn N.PacketWriter, source N.PacketReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handled bool, n int64, err error) { +var _ N.ReadWaiter = (*syscallReadWaiter)(nil) + +type syscallReadWaiter struct { + rawConn syscall.RawConn + readErr error + readFunc func(fd uintptr) (done bool) + hasData bool + buffer *buf.Buffer + options N.ReadWaitOptions +} + +func createSyscallReadWaiter(reader any) (*syscallReadWaiter, bool) { + if syscallConn, isSyscallConn := reader.(syscall.Conn); isSyscallConn { + rawConn, err := syscallConn.SyscallConn() + if err == nil { + return &syscallReadWaiter{rawConn: rawConn}, true + } + } + return nil, false +} + +func (w *syscallReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) { + w.options = options + w.readFunc = func(fd uintptr) (done bool) { + if !w.hasData { + w.hasData = true + // golang's internal/poll.FD.RawRead will Use a zero-byte read as a way to get notified when this + // socket is readable if we return false. So the `recv` syscall will not block the system thread. 
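// (Elaboration based on Go's internal/poll behavior rather than on this
// patch: when false is returned here, FD.RawRead arms a zero-byte overlapped
// read and parks the goroutine until the socket becomes readable; the next
// invocation finds hasData true, skips this early return, and issues the
// recv below against data that is already queued, so it returns immediately.)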
+ return false + } + buffer := w.options.NewBuffer() + var readN int32 + readN, w.readErr = recv(windows.Handle(fd), buffer.FreeBytes(), 0) + if readN > 0 { + buffer.Truncate(int(readN)) + w.options.PostReturn(buffer) + w.buffer = buffer + } else { + buffer.Release() + } + if w.readErr == windows.WSAEWOULDBLOCK { + return false + } + if readN == 0 && w.readErr == nil { + w.readErr = io.EOF + } + w.hasData = false + return true + } + return false +} + +func (w *syscallReadWaiter) WaitReadBuffer() (buffer *buf.Buffer, err error) { + if w.readFunc == nil { + return nil, os.ErrInvalid + } + err = w.rawConn.Read(w.readFunc) + if err != nil { + return + } + if w.readErr != nil { + if w.readErr == io.EOF { + return nil, io.EOF + } + return nil, E.Cause(w.readErr, "raw read") + } + buffer = w.buffer + w.buffer = nil return } -func createSyscallReadWaiter(reader any) (N.ReadWaiter, bool) { +var _ N.PacketReadWaiter = (*syscallPacketReadWaiter)(nil) + +type syscallPacketReadWaiter struct { + rawConn syscall.RawConn + readErr error + readFrom M.Socksaddr + readFunc func(fd uintptr) (done bool) + hasData bool + buffer *buf.Buffer + options N.ReadWaitOptions +} + +func createSyscallPacketReadWaiter(reader any) (*syscallPacketReadWaiter, bool) { + if syscallConn, isSyscallConn := reader.(syscall.Conn); isSyscallConn { + rawConn, err := syscallConn.SyscallConn() + if err == nil { + return &syscallPacketReadWaiter{rawConn: rawConn}, true + } + } return nil, false } -func createSyscallPacketReadWaiter(reader any) (N.PacketReadWaiter, bool) { - return nil, false +func (w *syscallPacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) { + w.options = options + w.readFunc = func(fd uintptr) (done bool) { + if !w.hasData { + w.hasData = true + // golang's internal/poll.FD.RawRead will Use a zero-byte read as a way to get notified when this + // socket is readable if we return false. So the `recvfrom` syscall will not block the system thread. 
+ return false + } + buffer := w.options.NewPacketBuffer() + var readN int + var from windows.Sockaddr + readN, from, w.readErr = windows.Recvfrom(windows.Handle(fd), buffer.FreeBytes(), 0) + if readN > 0 { + buffer.Truncate(readN) + w.options.PostReturn(buffer) + w.buffer = buffer + } else { + buffer.Release() + } + if w.readErr == windows.WSAEWOULDBLOCK { + return false + } + if from != nil { + switch fromAddr := from.(type) { + case *windows.SockaddrInet4: + w.readFrom = M.SocksaddrFrom(netip.AddrFrom4(fromAddr.Addr), uint16(fromAddr.Port)) + case *windows.SockaddrInet6: + w.readFrom = M.SocksaddrFrom(netip.AddrFrom16(fromAddr.Addr), uint16(fromAddr.Port)).Unwrap() + } + } + w.hasData = false + return true + } + return false +} + +func (w *syscallPacketReadWaiter) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) { + if w.readFunc == nil { + return nil, M.Socksaddr{}, os.ErrInvalid + } + err = w.rawConn.Read(w.readFunc) + if err != nil { + return + } + if w.readErr != nil { + err = E.Cause(w.readErr, "raw read") + return + } + buffer = w.buffer + w.buffer = nil + destination = w.readFrom + return } diff --git a/common/bufio/deadline/conn.go b/common/bufio/deadline/conn.go index 7ad1a9e..484d297 100644 --- a/common/bufio/deadline/conn.go +++ b/common/bufio/deadline/conn.go @@ -14,18 +14,18 @@ type Conn struct { reader Reader } -func NewConn(conn net.Conn) *Conn { +func NewConn(conn net.Conn) N.ExtendedConn { if deadlineConn, isDeadline := conn.(*Conn); isDeadline { return deadlineConn } - return &Conn{ExtendedConn: bufio.NewExtendedConn(conn), reader: NewReader(conn)} + return NewSerialConn(&Conn{ExtendedConn: bufio.NewExtendedConn(conn), reader: NewReader(conn)}) } -func NewFallbackConn(conn net.Conn) *Conn { +func NewFallbackConn(conn net.Conn) N.ExtendedConn { if deadlineConn, isDeadline := conn.(*Conn); isDeadline { return deadlineConn } - return &Conn{ExtendedConn: bufio.NewExtendedConn(conn), reader: NewFallbackReader(conn)} + return NewSerialConn(&Conn{ExtendedConn: bufio.NewExtendedConn(conn), reader: NewFallbackReader(conn)}) } func (c *Conn) Read(p []byte) (n int, err error) { diff --git a/common/bufio/deadline/packet_conn.go b/common/bufio/deadline/packet_conn.go index 7c92845..a0e9808 100644 --- a/common/bufio/deadline/packet_conn.go +++ b/common/bufio/deadline/packet_conn.go @@ -14,18 +14,18 @@ type PacketConn struct { reader PacketReader } -func NewPacketConn(conn N.NetPacketConn) *PacketConn { +func NewPacketConn(conn N.NetPacketConn) N.NetPacketConn { if deadlineConn, isDeadline := conn.(*PacketConn); isDeadline { return deadlineConn } - return &PacketConn{NetPacketConn: conn, reader: NewPacketReader(conn)} + return NewSerialPacketConn(&PacketConn{NetPacketConn: conn, reader: NewPacketReader(conn)}) } -func NewFallbackPacketConn(conn N.NetPacketConn) *PacketConn { +func NewFallbackPacketConn(conn N.NetPacketConn) N.NetPacketConn { if deadlineConn, isDeadline := conn.(*PacketConn); isDeadline { return deadlineConn } - return &PacketConn{NetPacketConn: conn, reader: NewFallbackPacketReader(conn)} + return NewSerialPacketConn(&PacketConn{NetPacketConn: conn, reader: NewFallbackPacketReader(conn)}) } func (c *PacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { diff --git a/common/bufio/deadline/packet_reader.go b/common/bufio/deadline/packet_reader.go index 36b4e87..088a811 100644 --- a/common/bufio/deadline/packet_reader.go +++ b/common/bufio/deadline/packet_reader.go @@ -52,14 +52,13 @@ func (r *packetReader) ReadFrom(p []byte) (n int, 
addr net.Addr, err error) { default: } select { + case result := <-r.result: + return r.pipeReturnFrom(result, p) + case <-r.pipeDeadline.wait(): + return 0, nil, os.ErrDeadlineExceeded case <-r.done: go r.pipeReadFrom(len(p)) - default: } - return r.readFrom(p) -} - -func (r *packetReader) readFrom(p []byte) (n int, addr net.Addr, err error) { select { case result := <-r.result: return r.pipeReturnFrom(result, p) @@ -106,14 +105,13 @@ func (r *packetReader) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, default: } select { + case result := <-r.result: + return r.pipeReturnFromBuffer(result, buffer) + case <-r.pipeDeadline.wait(): + return M.Socksaddr{}, os.ErrDeadlineExceeded case <-r.done: - go r.pipeReadFromBuffer(buffer.FreeLen()) - default: + go r.pipeReadFrom(buffer.FreeLen()) } - return r.readPacket(buffer) -} - -func (r *packetReader) readPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { select { case result := <-r.result: return r.pipeReturnFromBuffer(result, buffer) @@ -134,17 +132,6 @@ func (r *packetReader) pipeReturnFromBuffer(result *packetReadResult, buffer *bu } } -func (r *packetReader) pipeReadFromBuffer(pLen int) { - buffer := buf.NewSize(pLen) - destination, err := r.TimeoutPacketReader.ReadPacket(buffer) - r.result <- &packetReadResult{ - buffer: buffer, - destination: destination, - err: err, - } - r.done <- struct{}{} -} - func (r *packetReader) SetReadDeadline(t time.Time) error { r.deadline.Store(t) r.pipeDeadline.set(t) diff --git a/common/bufio/deadline/packet_reader_fallback.go b/common/bufio/deadline/packet_reader_fallback.go index 276b784..c20f568 100644 --- a/common/bufio/deadline/packet_reader_fallback.go +++ b/common/bufio/deadline/packet_reader_fallback.go @@ -2,6 +2,7 @@ package deadline import ( "net" + "os" "time" "github.com/sagernet/sing/common/atomic" @@ -25,12 +26,15 @@ func (r *fallbackPacketReader) ReadFrom(p []byte) (n int, addr net.Addr, err err return r.pipeReturnFrom(result, p) default: } - if r.disablePipe.Load() { - return r.TimeoutPacketReader.ReadFrom(p) - } select { + case result := <-r.result: + return r.pipeReturnFrom(result, p) + case <-r.pipeDeadline.wait(): + return 0, nil, os.ErrDeadlineExceeded case <-r.done: - if r.deadline.Load().IsZero() { + if r.disablePipe.Load() { + return r.TimeoutPacketReader.ReadFrom(p) + } else if r.deadline.Load().IsZero() { r.done <- struct{}{} r.inRead.Store(true) defer r.inRead.Store(false) @@ -38,9 +42,13 @@ func (r *fallbackPacketReader) ReadFrom(p []byte) (n int, addr net.Addr, err err return } go r.pipeReadFrom(len(p)) - default: } - return r.readFrom(p) + select { + case result := <-r.result: + return r.pipeReturnFrom(result, p) + case <-r.pipeDeadline.wait(): + return 0, nil, os.ErrDeadlineExceeded + } } func (r *fallbackPacketReader) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { @@ -49,22 +57,29 @@ func (r *fallbackPacketReader) ReadPacket(buffer *buf.Buffer) (destination M.Soc return r.pipeReturnFromBuffer(result, buffer) default: } - if r.disablePipe.Load() { - return r.TimeoutPacketReader.ReadPacket(buffer) - } select { + case result := <-r.result: + return r.pipeReturnFromBuffer(result, buffer) + case <-r.pipeDeadline.wait(): + return M.Socksaddr{}, os.ErrDeadlineExceeded case <-r.done: - if r.deadline.Load().IsZero() { + if r.disablePipe.Load() { + return r.TimeoutPacketReader.ReadPacket(buffer) + } else if r.deadline.Load().IsZero() { r.done <- struct{}{} r.inRead.Store(true) defer r.inRead.Store(false) destination, err = 
r.TimeoutPacketReader.ReadPacket(buffer) return } - go r.pipeReadFromBuffer(buffer.FreeLen()) - default: + go r.pipeReadFrom(buffer.FreeLen()) + } + select { + case result := <-r.result: + return r.pipeReturnFromBuffer(result, buffer) + case <-r.pipeDeadline.wait(): + return M.Socksaddr{}, os.ErrDeadlineExceeded } - return r.readPacket(buffer) } func (r *fallbackPacketReader) SetReadDeadline(t time.Time) error { diff --git a/common/bufio/deadline/reader.go b/common/bufio/deadline/reader.go index b6d3c7d..a7a6252 100644 --- a/common/bufio/deadline/reader.go +++ b/common/bufio/deadline/reader.go @@ -54,14 +54,13 @@ func (r *reader) Read(p []byte) (n int, err error) { default: } select { + case result := <-r.result: + return r.pipeReturn(result, p) + case <-r.pipeDeadline.wait(): + return 0, os.ErrDeadlineExceeded case <-r.done: go r.pipeRead(len(p)) - default: } - return r.read(p) -} - -func (r *reader) read(p []byte) (n int, err error) { select { case result := <-r.result: return r.pipeReturn(result, p) @@ -99,14 +98,13 @@ func (r *reader) ReadBuffer(buffer *buf.Buffer) error { default: } select { + case result := <-r.result: + return r.pipeReturnBuffer(result, buffer) + case <-r.pipeDeadline.wait(): + return os.ErrDeadlineExceeded case <-r.done: - go r.pipeReadBuffer(buffer.FreeLen()) - default: + go r.pipeRead(buffer.FreeLen()) } - return r.readBuffer(buffer) -} - -func (r *reader) readBuffer(buffer *buf.Buffer) error { select { case result := <-r.result: return r.pipeReturnBuffer(result, buffer) @@ -127,16 +125,6 @@ func (r *reader) pipeReturnBuffer(result *readResult, buffer *buf.Buffer) error } } -func (r *reader) pipeReadBuffer(pLen int) { - cacheBuffer := buf.NewSize(pLen) - err := r.ExtendedReader.ReadBuffer(cacheBuffer) - r.result <- &readResult{ - buffer: cacheBuffer, - err: err, - } - r.done <- struct{}{} -} - func (r *reader) SetReadDeadline(t time.Time) error { r.deadline.Store(t) r.pipeDeadline.set(t) diff --git a/common/bufio/deadline/reader_fallback.go b/common/bufio/deadline/reader_fallback.go index 182ab40..a28b315 100644 --- a/common/bufio/deadline/reader_fallback.go +++ b/common/bufio/deadline/reader_fallback.go @@ -1,6 +1,7 @@ package deadline import ( + "os" "time" "github.com/sagernet/sing/common/atomic" @@ -23,12 +24,15 @@ func (r *fallbackReader) Read(p []byte) (n int, err error) { return r.pipeReturn(result, p) default: } - if r.disablePipe.Load() { - return r.ExtendedReader.Read(p) - } select { + case result := <-r.result: + return r.pipeReturn(result, p) + case <-r.pipeDeadline.wait(): + return 0, os.ErrDeadlineExceeded case <-r.done: - if r.deadline.Load().IsZero() { + if r.disablePipe.Load() { + return r.ExtendedReader.Read(p) + } else if r.deadline.Load().IsZero() { r.done <- struct{}{} r.inRead.Store(true) defer r.inRead.Store(false) @@ -36,9 +40,13 @@ func (r *fallbackReader) Read(p []byte) (n int, err error) { return } go r.pipeRead(len(p)) - default: } - return r.reader.read(p) + select { + case result := <-r.result: + return r.pipeReturn(result, p) + case <-r.pipeDeadline.wait(): + return 0, os.ErrDeadlineExceeded + } } func (r *fallbackReader) ReadBuffer(buffer *buf.Buffer) error { @@ -47,21 +55,28 @@ func (r *fallbackReader) ReadBuffer(buffer *buf.Buffer) error { return r.pipeReturnBuffer(result, buffer) default: } - if r.disablePipe.Load() { - return r.ExtendedReader.ReadBuffer(buffer) - } select { + case result := <-r.result: + return r.pipeReturnBuffer(result, buffer) + case <-r.pipeDeadline.wait(): + return os.ErrDeadlineExceeded case <-r.done: - if 
r.deadline.Load().IsZero() { + if r.disablePipe.Load() { + return r.ExtendedReader.ReadBuffer(buffer) + } else if r.deadline.Load().IsZero() { r.done <- struct{}{} r.inRead.Store(true) defer r.inRead.Store(false) return r.ExtendedReader.ReadBuffer(buffer) } - go r.pipeReadBuffer(buffer.FreeLen()) - default: + go r.pipeRead(buffer.FreeLen()) + } + select { + case result := <-r.result: + return r.pipeReturnBuffer(result, buffer) + case <-r.pipeDeadline.wait(): + return os.ErrDeadlineExceeded } - return r.readBuffer(buffer) } func (r *fallbackReader) SetReadDeadline(t time.Time) error { diff --git a/common/bufio/deadline/serial.go b/common/bufio/deadline/serial.go new file mode 100644 index 0000000..951fd7e --- /dev/null +++ b/common/bufio/deadline/serial.go @@ -0,0 +1,75 @@ +package deadline + +import ( + "net" + "sync" + + "github.com/sagernet/sing/common/buf" + "github.com/sagernet/sing/common/debug" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" +) + +type SerialConn struct { + N.ExtendedConn + access sync.Mutex +} + +func NewSerialConn(conn N.ExtendedConn) N.ExtendedConn { + if !debug.Enabled { + return conn + } + return &SerialConn{ExtendedConn: conn} +} + +func (c *SerialConn) Read(p []byte) (n int, err error) { + if !c.access.TryLock() { + panic("concurrent read on deadline conn") + } + defer c.access.Unlock() + return c.ExtendedConn.Read(p) +} + +func (c *SerialConn) ReadBuffer(buffer *buf.Buffer) error { + if !c.access.TryLock() { + panic("concurrent read on deadline conn") + } + defer c.access.Unlock() + return c.ExtendedConn.ReadBuffer(buffer) +} + +func (c *SerialConn) Upstream() any { + return c.ExtendedConn +} + +type SerialPacketConn struct { + N.NetPacketConn + access sync.Mutex +} + +func NewSerialPacketConn(conn N.NetPacketConn) N.NetPacketConn { + if !debug.Enabled { + return conn + } + return &SerialPacketConn{NetPacketConn: conn} +} + +func (c *SerialPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + if !c.access.TryLock() { + panic("concurrent read on deadline conn") + } + defer c.access.Unlock() + return c.NetPacketConn.ReadFrom(p) +} + +func (c *SerialPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { + if !c.access.TryLock() { + panic("concurrent read on deadline conn") + } + defer c.access.Unlock() + return c.NetPacketConn.ReadPacket(buffer) +} + +func (c *SerialPacketConn) Upstream() any { + return c.NetPacketConn +} diff --git a/common/bufio/fallback.go b/common/bufio/fallback.go index 4ea87cf..bd4ab46 100644 --- a/common/bufio/fallback.go +++ b/common/bufio/fallback.go @@ -3,6 +3,7 @@ package bufio import ( "net" + "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" @@ -12,13 +13,17 @@ var _ N.NetPacketConn = (*FallbackPacketConn)(nil) type FallbackPacketConn struct { N.PacketConn + writer N.NetPacketWriter } func NewNetPacketConn(conn N.PacketConn) N.NetPacketConn { if packetConn, loaded := conn.(N.NetPacketConn); loaded { return packetConn } - return &FallbackPacketConn{PacketConn: conn} + return &FallbackPacketConn{ + PacketConn: conn, + writer: NewNetPacketWriter(conn), + } } func (c *FallbackPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { @@ -36,11 +41,7 @@ func (c *FallbackPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error } func (c *FallbackPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { - err = 
c.WritePacket(buf.As(p), M.SocksaddrFromNet(addr)) - if err == nil { - n = len(p) - } - return + return c.writer.WriteTo(p, addr) } func (c *FallbackPacketConn) ReaderReplaceable() bool { @@ -54,3 +55,50 @@ func (c *FallbackPacketConn) WriterReplaceable() bool { func (c *FallbackPacketConn) Upstream() any { return c.PacketConn } + +func (c *FallbackPacketConn) UpstreamWriter() any { + return c.writer +} + +var _ N.NetPacketWriter = (*FallbackPacketWriter)(nil) + +type FallbackPacketWriter struct { + N.PacketWriter + frontHeadroom int + rearHeadroom int +} + +func NewNetPacketWriter(writer N.PacketWriter) N.NetPacketWriter { + if packetWriter, loaded := writer.(N.NetPacketWriter); loaded { + return packetWriter + } + return &FallbackPacketWriter{ + PacketWriter: writer, + frontHeadroom: N.CalculateFrontHeadroom(writer), + rearHeadroom: N.CalculateRearHeadroom(writer), + } +} + +func (c *FallbackPacketWriter) WriteTo(p []byte, addr net.Addr) (n int, err error) { + if c.frontHeadroom > 0 || c.rearHeadroom > 0 { + buffer := buf.NewSize(len(p) + c.frontHeadroom + c.rearHeadroom) + buffer.Resize(c.frontHeadroom, 0) + common.Must1(buffer.Write(p)) + err = c.PacketWriter.WritePacket(buffer, M.SocksaddrFromNet(addr)) + } else { + err = c.PacketWriter.WritePacket(buf.As(p), M.SocksaddrFromNet(addr)) + } + if err != nil { + return + } + n = len(p) + return +} + +func (c *FallbackPacketWriter) WriterReplaceable() bool { + return true +} + +func (c *FallbackPacketWriter) Upstream() any { + return c.PacketWriter +} diff --git a/common/bufio/io.go b/common/bufio/io.go index 1e5d89b..a25a7cc 100644 --- a/common/bufio/io.go +++ b/common/bufio/io.go @@ -37,13 +37,7 @@ func WriteBuffer(writer N.ExtendedWriter, buffer *buf.Buffer) (n int, err error) frontHeadroom := N.CalculateFrontHeadroom(writer) rearHeadroom := N.CalculateRearHeadroom(writer) if frontHeadroom > buffer.Start() || rearHeadroom > buffer.FreeLen() { - bufferSize := N.CalculateMTU(nil, writer) - if bufferSize > 0 { - bufferSize += frontHeadroom + rearHeadroom - } else { - bufferSize = buf.BufferSize - } - newBuffer := buf.NewSize(bufferSize) + newBuffer := buf.NewSize(buffer.Len() + frontHeadroom + rearHeadroom) newBuffer.Resize(frontHeadroom, 0) common.Must1(newBuffer.Write(buffer.Bytes())) buffer.Release() @@ -69,13 +63,7 @@ func WritePacketBuffer(writer N.PacketWriter, buffer *buf.Buffer, destination M. 
frontHeadroom := N.CalculateFrontHeadroom(writer) rearHeadroom := N.CalculateRearHeadroom(writer) if frontHeadroom > buffer.Start() || rearHeadroom > buffer.FreeLen() { - bufferSize := N.CalculateMTU(nil, writer) - if bufferSize > 0 { - bufferSize += frontHeadroom + rearHeadroom - } else { - bufferSize = buf.BufferSize - } - newBuffer := buf.NewSize(bufferSize) + newBuffer := buf.NewSize(buffer.Len() + frontHeadroom + rearHeadroom) newBuffer.Resize(frontHeadroom, 0) common.Must1(newBuffer.Write(buffer.Bytes())) buffer.Release() diff --git a/common/bufio/nat.go b/common/bufio/nat.go index d652094..cafeb06 100644 --- a/common/bufio/nat.go +++ b/common/bufio/nat.go @@ -9,54 +9,142 @@ import ( N "github.com/sagernet/sing/common/network" ) -type NATPacketConn struct { +type NATPacketConn interface { + N.NetPacketConn + UpdateDestination(destinationAddress netip.Addr) +} + +func NewUnidirectionalNATPacketConn(conn N.NetPacketConn, origin M.Socksaddr, destination M.Socksaddr) NATPacketConn { + return &unidirectionalNATPacketConn{ + NetPacketConn: conn, + origin: socksaddrWithoutPort(origin), + destination: socksaddrWithoutPort(destination), + } +} + +func NewNATPacketConn(conn N.NetPacketConn, origin M.Socksaddr, destination M.Socksaddr) NATPacketConn { + return &bidirectionalNATPacketConn{ + NetPacketConn: conn, + origin: socksaddrWithoutPort(origin), + destination: socksaddrWithoutPort(destination), + } +} + +type unidirectionalNATPacketConn struct { N.NetPacketConn origin M.Socksaddr destination M.Socksaddr } -func NewNATPacketConn(conn N.NetPacketConn, origin M.Socksaddr, destination M.Socksaddr) *NATPacketConn { - return &NATPacketConn{ - NetPacketConn: conn, - origin: origin, - destination: destination, +func (c *unidirectionalNATPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { + destination := M.SocksaddrFromNet(addr) + if socksaddrWithoutPort(destination) == c.destination { + destination = M.Socksaddr{ + Addr: c.origin.Addr, + Fqdn: c.origin.Fqdn, + Port: destination.Port, + } } + return c.NetPacketConn.WriteTo(p, destination.UDPAddr()) } -func (c *NATPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { - n, addr, err = c.NetPacketConn.ReadFrom(p) - if err == nil && M.SocksaddrFromNet(addr) == c.origin { - addr = c.destination.UDPAddr() - } - return -} - -func (c *NATPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { - if M.SocksaddrFromNet(addr) == c.destination { - addr = c.origin.UDPAddr() - } - return c.NetPacketConn.WriteTo(p, addr) -} - -func (c *NATPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { - destination, err = c.NetPacketConn.ReadPacket(buffer) - if destination == c.origin { - destination = c.destination - } - return -} - -func (c *NATPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { - if destination == c.destination { - destination = c.origin +func (c *unidirectionalNATPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { + if socksaddrWithoutPort(destination) == c.destination { + destination = M.Socksaddr{ + Addr: c.origin.Addr, + Fqdn: c.origin.Fqdn, + Port: destination.Port, + } } return c.NetPacketConn.WritePacket(buffer, destination) } -func (c *NATPacketConn) UpdateDestination(destinationAddress netip.Addr) { +func (c *unidirectionalNATPacketConn) UpdateDestination(destinationAddress netip.Addr) { c.destination = M.SocksaddrFrom(destinationAddress, c.destination.Port) } -func (c *NATPacketConn) Upstream() any { +func (c 
*unidirectionalNATPacketConn) RemoteAddr() net.Addr { + return c.destination.UDPAddr() +} + +func (c *unidirectionalNATPacketConn) Upstream() any { return c.NetPacketConn } + +type bidirectionalNATPacketConn struct { + N.NetPacketConn + origin M.Socksaddr + destination M.Socksaddr +} + +func (c *bidirectionalNATPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + n, addr, err = c.NetPacketConn.ReadFrom(p) + if err != nil { + return + } + destination := M.SocksaddrFromNet(addr) + if socksaddrWithoutPort(destination) == c.origin { + destination = M.Socksaddr{ + Addr: c.destination.Addr, + Fqdn: c.destination.Fqdn, + Port: destination.Port, + } + } + addr = destination.UDPAddr() + return +} + +func (c *bidirectionalNATPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { + destination := M.SocksaddrFromNet(addr) + if socksaddrWithoutPort(destination) == c.destination { + destination = M.Socksaddr{ + Addr: c.origin.Addr, + Fqdn: c.origin.Fqdn, + Port: destination.Port, + } + } + return c.NetPacketConn.WriteTo(p, destination.UDPAddr()) +} + +func (c *bidirectionalNATPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { + destination, err = c.NetPacketConn.ReadPacket(buffer) + if err != nil { + return + } + if socksaddrWithoutPort(destination) == c.origin { + destination = M.Socksaddr{ + Addr: c.destination.Addr, + Fqdn: c.destination.Fqdn, + Port: destination.Port, + } + } + return +} + +func (c *bidirectionalNATPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { + if socksaddrWithoutPort(destination) == c.destination { + destination = M.Socksaddr{ + Addr: c.origin.Addr, + Fqdn: c.origin.Fqdn, + Port: destination.Port, + } + } + return c.NetPacketConn.WritePacket(buffer, destination) +} + +func (c *bidirectionalNATPacketConn) UpdateDestination(destinationAddress netip.Addr) { + c.destination = M.SocksaddrFrom(destinationAddress, c.destination.Port) +} + +func (c *bidirectionalNATPacketConn) Upstream() any { + return c.NetPacketConn +} + +func (c *bidirectionalNATPacketConn) RemoteAddr() net.Addr { + return c.destination.UDPAddr() +} + +func socksaddrWithoutPort(destination M.Socksaddr) M.Socksaddr { + destination.Port = 0 + return destination +} diff --git a/common/bufio/nat_wait.go b/common/bufio/nat_wait.go new file mode 100644 index 0000000..dbb370a --- /dev/null +++ b/common/bufio/nat_wait.go @@ -0,0 +1,39 @@ +package bufio + +import ( + "github.com/sagernet/sing/common/buf" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" +) + +func (c *bidirectionalNATPacketConn) CreatePacketReadWaiter() (N.PacketReadWaiter, bool) { + waiter, created := CreatePacketReadWaiter(c.NetPacketConn) + if !created { + return nil, false + } + return &waitBidirectionalNATPacketConn{c, waiter}, true +} + +type waitBidirectionalNATPacketConn struct { + *bidirectionalNATPacketConn + readWaiter N.PacketReadWaiter +} + +func (c *waitBidirectionalNATPacketConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) { + return c.readWaiter.InitializeReadWaiter(options) +} + +func (c *waitBidirectionalNATPacketConn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) { + buffer, destination, err = c.readWaiter.WaitReadPacket() + if err != nil { + return + } + if socksaddrWithoutPort(destination) == c.origin { + destination = M.Socksaddr{ + Addr: c.destination.Addr, + Fqdn: c.destination.Fqdn, + Port: destination.Port, + } + } + return +} diff --git 
a/common/bufio/net_test.go b/common/bufio/net_test.go new file mode 100644 index 0000000..8642572 --- /dev/null +++ b/common/bufio/net_test.go @@ -0,0 +1,277 @@ +package bufio + +import ( + "context" + "crypto/md5" + "crypto/rand" + "errors" + "io" + "net" + "sync" + "testing" + "time" + + M "github.com/sagernet/sing/common/metadata" + "github.com/sagernet/sing/common/task" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TCPPipe(t *testing.T) (net.Conn, net.Conn) { + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + var ( + group task.Group + serverConn net.Conn + clientConn net.Conn + ) + group.Append0(func(ctx context.Context) error { + var serverErr error + serverConn, serverErr = listener.Accept() + return serverErr + }) + group.Append0(func(ctx context.Context) error { + var clientErr error + clientConn, clientErr = net.Dial("tcp", listener.Addr().String()) + return clientErr + }) + err = group.Run() + require.NoError(t, err) + listener.Close() + t.Cleanup(func() { + serverConn.Close() + clientConn.Close() + }) + return serverConn, clientConn +} + +func UDPPipe(t *testing.T) (net.PacketConn, net.PacketConn, M.Socksaddr) { + serverConn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + clientConn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + return serverConn, clientConn, M.SocksaddrFromNet(clientConn.LocalAddr()) +} + +func Timeout(t *testing.T) context.CancelFunc { + ctx, cancel := context.WithCancel(context.Background()) + go func() { + select { + case <-ctx.Done(): + return + case <-time.After(5 * time.Second): + t.Error("timeout") + } + }() + return cancel +} + +type hashPair struct { + sendHash map[int][]byte + recvHash map[int][]byte +} + +func newLargeDataPair() (chan hashPair, chan hashPair, func(t *testing.T) error) { + pingCh := make(chan hashPair) + pongCh := make(chan hashPair) + test := func(t *testing.T) error { + defer close(pingCh) + defer close(pongCh) + pingOpen := false + pongOpen := false + var serverPair hashPair + var clientPair hashPair + + for { + if pingOpen && pongOpen { + break + } + + select { + case serverPair, pingOpen = <-pingCh: + assert.True(t, pingOpen) + case clientPair, pongOpen = <-pongCh: + assert.True(t, pongOpen) + case <-time.After(10 * time.Second): + return errors.New("timeout") + } + } + + assert.Equal(t, serverPair.recvHash, clientPair.sendHash) + assert.Equal(t, serverPair.sendHash, clientPair.recvHash) + + return nil + } + + return pingCh, pongCh, test +} + +func TCPTest(t *testing.T, inputConn net.Conn, outputConn net.Conn) error { + times := 100 + chunkSize := int64(64 * 1024) + + pingCh, pongCh, test := newLargeDataPair() + writeRandData := func(conn net.Conn) (map[int][]byte, error) { + buf := make([]byte, chunkSize) + hashMap := map[int][]byte{} + for i := 0; i < times; i++ { + if _, err := rand.Read(buf[1:]); err != nil { + return nil, err + } + buf[0] = byte(i) + + hash := md5.Sum(buf) + hashMap[i] = hash[:] + + if _, err := conn.Write(buf); err != nil { + return nil, err + } + } + + return hashMap, nil + } + go func() { + hashMap := map[int][]byte{} + buf := make([]byte, chunkSize) + + for i := 0; i < times; i++ { + _, err := io.ReadFull(outputConn, buf) + if err != nil { + t.Log(err.Error()) + return + } + + hash := md5.Sum(buf) + hashMap[int(buf[0])] = hash[:] + } + + sendHash, err := writeRandData(outputConn) + if err != nil { + t.Log(err.Error()) + return + } + + pingCh <- hashPair{ + sendHash: sendHash, 
+ recvHash: hashMap, + } + }() + + go func() { + sendHash, err := writeRandData(inputConn) + if err != nil { + t.Log(err.Error()) + return + } + + hashMap := map[int][]byte{} + buf := make([]byte, chunkSize) + + for i := 0; i < times; i++ { + _, err = io.ReadFull(inputConn, buf) + if err != nil { + t.Log(err.Error()) + return + } + + hash := md5.Sum(buf) + hashMap[int(buf[0])] = hash[:] + } + + pongCh <- hashPair{ + sendHash: sendHash, + recvHash: hashMap, + } + }() + return test(t) +} + +func UDPTest(t *testing.T, inputConn net.PacketConn, outputConn net.PacketConn, outputAddr M.Socksaddr) error { + rAddr := outputAddr.UDPAddr() + times := 50 + chunkSize := 9000 + pingCh, pongCh, test := newLargeDataPair() + writeRandData := func(pc net.PacketConn, addr net.Addr) (map[int][]byte, error) { + hashMap := map[int][]byte{} + mux := sync.Mutex{} + for i := 0; i < times; i++ { + buf := make([]byte, chunkSize) + if _, err := rand.Read(buf[1:]); err != nil { + t.Log(err.Error()) + continue + } + buf[0] = byte(i) + + hash := md5.Sum(buf) + mux.Lock() + hashMap[i] = hash[:] + mux.Unlock() + + if _, err := pc.WriteTo(buf, addr); err != nil { + t.Log(err.Error()) + } + + time.Sleep(10 * time.Millisecond) + } + + return hashMap, nil + } + go func() { + var ( + lAddr net.Addr + err error + ) + hashMap := map[int][]byte{} + buf := make([]byte, 64*1024) + + for i := 0; i < times; i++ { + _, lAddr, err = outputConn.ReadFrom(buf) + if err != nil { + t.Log(err.Error()) + return + } + hash := md5.Sum(buf[:chunkSize]) + hashMap[int(buf[0])] = hash[:] + } + sendHash, err := writeRandData(outputConn, lAddr) + if err != nil { + t.Log(err.Error()) + return + } + + pingCh <- hashPair{ + sendHash: sendHash, + recvHash: hashMap, + } + }() + + go func() { + sendHash, err := writeRandData(inputConn, rAddr) + if err != nil { + t.Log(err.Error()) + return + } + + hashMap := map[int][]byte{} + buf := make([]byte, 64*1024) + + for i := 0; i < times; i++ { + _, _, err := inputConn.ReadFrom(buf) + if err != nil { + t.Log(err.Error()) + return + } + + hash := md5.Sum(buf[:chunkSize]) + hashMap[int(buf[0])] = hash[:] + } + + pongCh <- hashPair{ + sendHash: sendHash, + recvHash: hashMap, + } + }() + + return test(t) +} diff --git a/common/bufio/once.go b/common/bufio/once.go deleted file mode 100644 index 5bfd0aa..0000000 --- a/common/bufio/once.go +++ /dev/null @@ -1,127 +0,0 @@ -package bufio - -import ( - "io" - - "github.com/sagernet/sing/common" - "github.com/sagernet/sing/common/buf" - N "github.com/sagernet/sing/common/network" -) - -func CopyTimes(dst io.Writer, src io.Reader, times int) (n int64, err error) { - return CopyExtendedTimes(NewExtendedWriter(N.UnwrapWriter(dst)), NewExtendedReader(N.UnwrapReader(src)), times) -} - -func CopyExtendedTimes(dst N.ExtendedWriter, src N.ExtendedReader, times int) (n int64, err error) { - frontHeadroom := N.CalculateFrontHeadroom(dst) - rearHeadroom := N.CalculateRearHeadroom(dst) - bufferSize := N.CalculateMTU(src, dst) - if bufferSize > 0 { - bufferSize += frontHeadroom + rearHeadroom - } else { - bufferSize = buf.BufferSize - } - dstUnsafe := N.IsUnsafeWriter(dst) - var buffer *buf.Buffer - if !dstUnsafe { - _buffer := buf.StackNewSize(bufferSize) - defer common.KeepAlive(_buffer) - buffer = common.Dup(_buffer) - defer buffer.Release() - buffer.IncRef() - defer buffer.DecRef() - } - notFirstTime := true - for i := 0; i < times; i++ { - if dstUnsafe { - buffer = buf.NewSize(bufferSize) - } - readBufferRaw := buffer.Slice() - readBuffer := 
buf.With(readBufferRaw[:cap(readBufferRaw)-rearHeadroom]) - readBuffer.Resize(frontHeadroom, 0) - err = src.ReadBuffer(readBuffer) - if err != nil { - buffer.Release() - if !notFirstTime { - err = N.HandshakeFailure(dst, err) - } - return - } - dataLen := readBuffer.Len() - buffer.Resize(readBuffer.Start(), dataLen) - err = dst.WriteBuffer(buffer) - if err != nil { - buffer.Release() - return - } - n += int64(dataLen) - notFirstTime = true - } - return -} - -type ReadFromWriter interface { - io.ReaderFrom - io.Writer -} - -func ReadFrom0(readerFrom ReadFromWriter, reader io.Reader) (n int64, err error) { - n, err = CopyTimes(readerFrom, reader, 1) - if err != nil { - return - } - var rn int64 - rn, err = readerFrom.ReadFrom(reader) - if err != nil { - return - } - n += rn - return -} - -func ReadFromN(readerFrom ReadFromWriter, reader io.Reader, times int) (n int64, err error) { - n, err = CopyTimes(readerFrom, reader, times) - if err != nil { - return - } - var rn int64 - rn, err = readerFrom.ReadFrom(reader) - if err != nil { - return - } - n += rn - return -} - -type WriteToReader interface { - io.WriterTo - io.Reader -} - -func WriteTo0(writerTo WriteToReader, writer io.Writer) (n int64, err error) { - n, err = CopyTimes(writer, writerTo, 1) - if err != nil { - return - } - var wn int64 - wn, err = writerTo.WriteTo(writer) - if err != nil { - return - } - n += wn - return -} - -func WriteToN(writerTo WriteToReader, writer io.Writer, times int) (n int64, err error) { - n, err = CopyTimes(writer, writerTo, times) - if err != nil { - return - } - var wn int64 - wn, err = writerTo.WriteTo(writer) - if err != nil { - return - } - n += wn - return -} diff --git a/common/bufio/vectorised.go b/common/bufio/vectorised.go index ef875fd..0fea211 100644 --- a/common/bufio/vectorised.go +++ b/common/bufio/vectorised.go @@ -33,10 +33,10 @@ func CreateVectorisedWriter(writer any) (N.VectorisedWriter, bool) { case syscall.Conn: rawConn, err := w.SyscallConn() if err == nil { - return &SyscallVectorisedWriter{writer, rawConn}, true + return &SyscallVectorisedWriter{upstream: writer, rawConn: rawConn}, true } case syscall.RawConn: - return &SyscallVectorisedWriter{writer, w}, true + return &SyscallVectorisedWriter{upstream: writer, rawConn: w}, true } return nil, false } @@ -48,10 +48,10 @@ func CreateVectorisedPacketWriter(writer any) (N.VectorisedPacketWriter, bool) { case syscall.Conn: rawConn, err := w.SyscallConn() if err == nil { - return &SyscallVectorisedPacketWriter{writer, rawConn}, true + return &SyscallVectorisedPacketWriter{upstream: writer, rawConn: rawConn}, true } case syscall.RawConn: - return &SyscallVectorisedPacketWriter{writer, w}, true + return &SyscallVectorisedPacketWriter{upstream: writer, rawConn: w}, true } return nil, false } @@ -74,9 +74,7 @@ func (w *BufferedVectorisedWriter) WriteVectorised(buffers []*buf.Buffer) error if bufferLen > 65535 { bufferBytes = make([]byte, bufferLen) } else { - _buffer := buf.StackNewSize(bufferLen) - defer common.KeepAlive(_buffer) - buffer := common.Dup(_buffer) + buffer := buf.NewSize(bufferLen) defer buffer.Release() bufferBytes = buffer.FreeBytes() } @@ -113,6 +111,7 @@ var _ N.VectorisedWriter = (*SyscallVectorisedWriter)(nil) type SyscallVectorisedWriter struct { upstream any rawConn syscall.RawConn + syscallVectorisedWriterFields } func (w *SyscallVectorisedWriter) Upstream() any { @@ -128,6 +127,7 @@ var _ N.VectorisedPacketWriter = (*SyscallVectorisedPacketWriter)(nil) type SyscallVectorisedPacketWriter struct { upstream any rawConn 
syscall.RawConn + syscallVectorisedWriterFields } func (w *SyscallVectorisedPacketWriter) Upstream() any { diff --git a/common/bufio/vectorised_test.go b/common/bufio/vectorised_test.go new file mode 100644 index 0000000..7d2e42d --- /dev/null +++ b/common/bufio/vectorised_test.go @@ -0,0 +1,60 @@ +package bufio + +import ( + "crypto/rand" + "io" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestWriteVectorised(t *testing.T) { + t.Parallel() + inputConn, outputConn := TCPPipe(t) + vectorisedWriter, created := CreateVectorisedWriter(inputConn) + require.True(t, created) + require.NotNil(t, vectorisedWriter) + var bufA [1024]byte + var bufB [1024]byte + var bufC [2048]byte + _, err := io.ReadFull(rand.Reader, bufA[:]) + require.NoError(t, err) + _, err = io.ReadFull(rand.Reader, bufB[:]) + require.NoError(t, err) + copy(bufC[:], bufA[:]) + copy(bufC[1024:], bufB[:]) + finish := Timeout(t) + _, err = WriteVectorised(vectorisedWriter, [][]byte{bufA[:], bufB[:]}) + require.NoError(t, err) + output := make([]byte, 2048) + _, err = io.ReadFull(outputConn, output) + finish() + require.NoError(t, err) + require.Equal(t, bufC[:], output) +} + +func TestWriteVectorisedPacket(t *testing.T) { + t.Parallel() + inputConn, outputConn, outputAddr := UDPPipe(t) + vectorisedWriter, created := CreateVectorisedPacketWriter(inputConn) + require.True(t, created) + require.NotNil(t, vectorisedWriter) + var bufA [1024]byte + var bufB [1024]byte + var bufC [2048]byte + _, err := io.ReadFull(rand.Reader, bufA[:]) + require.NoError(t, err) + _, err = io.ReadFull(rand.Reader, bufB[:]) + require.NoError(t, err) + copy(bufC[:], bufA[:]) + copy(bufC[1024:], bufB[:]) + finish := Timeout(t) + _, err = WriteVectorisedPacket(vectorisedWriter, [][]byte{bufA[:], bufB[:]}, outputAddr) + require.NoError(t, err) + output := make([]byte, 2048) + n, _, err := outputConn.ReadFrom(output) + finish() + require.NoError(t, err) + require.Equal(t, 2048, n) + require.Equal(t, bufC[:], output) +} diff --git a/common/bufio/vectorised_unix.go b/common/bufio/vectorised_unix.go index b64ae3c..6bb5d7d 100644 --- a/common/bufio/vectorised_unix.go +++ b/common/bufio/vectorised_unix.go @@ -3,6 +3,8 @@ package bufio import ( + "os" + "sync" "unsafe" "github.com/sagernet/sing/common/buf" @@ -11,15 +13,28 @@ import ( "golang.org/x/sys/unix" ) +type syscallVectorisedWriterFields struct { + access sync.Mutex + iovecList *[]unix.Iovec +} + func (w *SyscallVectorisedWriter) WriteVectorised(buffers []*buf.Buffer) error { + w.access.Lock() + defer w.access.Unlock() defer buf.ReleaseMulti(buffers) - iovecList := make([]unix.Iovec, 0, len(buffers)) - for _, buffer := range buffers { - var iovec unix.Iovec - iovec.Base = &buffer.Bytes()[0] - iovec.SetLen(buffer.Len()) - iovecList = append(iovecList, iovec) + var iovecList []unix.Iovec + if w.iovecList != nil { + iovecList = *w.iovecList } + iovecList = iovecList[:0] + for index, buffer := range buffers { + iovecList = append(iovecList, unix.Iovec{Base: &buffer.Bytes()[0]}) + iovecList[index].SetLen(buffer.Len()) + } + if w.iovecList == nil { + w.iovecList = new([]unix.Iovec) + } + *w.iovecList = iovecList // cache var innerErr unix.Errno err := w.rawConn.Write(func(fd uintptr) (done bool) { //nolint:staticcheck @@ -28,32 +43,52 @@ func (w *SyscallVectorisedWriter) WriteVectorised(buffers []*buf.Buffer) error { return innerErr != unix.EAGAIN && innerErr != unix.EWOULDBLOCK }) if innerErr != 0 { - err = innerErr + err = os.NewSyscallError("SYS_WRITEV", innerErr) + } + for index := range 
iovecList { + iovecList[index] = unix.Iovec{} } return err } func (w *SyscallVectorisedPacketWriter) WriteVectorisedPacket(buffers []*buf.Buffer, destination M.Socksaddr) error { + w.access.Lock() + defer w.access.Unlock() defer buf.ReleaseMulti(buffers) - var sockaddr unix.Sockaddr - if destination.IsIPv4() { - sockaddr = &unix.SockaddrInet4{ - Port: int(destination.Port), - Addr: destination.Addr.As4(), - } - } else { - sockaddr = &unix.SockaddrInet6{ - Port: int(destination.Port), - Addr: destination.Addr.As16(), - } + var iovecList []unix.Iovec + if w.iovecList != nil { + iovecList = *w.iovecList } + iovecList = iovecList[:0] + for index, buffer := range buffers { + iovecList = append(iovecList, unix.Iovec{Base: &buffer.Bytes()[0]}) + iovecList[index].SetLen(buffer.Len()) + } + if w.iovecList == nil { + w.iovecList = new([]unix.Iovec) + } + *w.iovecList = iovecList // cache var innerErr error err := w.rawConn.Write(func(fd uintptr) (done bool) { - _, innerErr = unix.SendmsgBuffers(int(fd), buf.ToSliceMulti(buffers), nil, sockaddr, 0) + var msg unix.Msghdr + name, nameLen := ToSockaddr(destination.AddrPort()) + msg.Name = (*byte)(name) + msg.Namelen = nameLen + if len(iovecList) > 0 { + msg.Iov = &iovecList[0] + msg.SetIovlen(len(iovecList)) + } + _, innerErr = sendmsg(int(fd), &msg, 0) return innerErr != unix.EAGAIN && innerErr != unix.EWOULDBLOCK }) if innerErr != nil { err = innerErr } + for index := range iovecList { + iovecList[index] = unix.Iovec{} + } return err } + +//go:linkname sendmsg golang.org/x/sys/unix.sendmsg +func sendmsg(s int, msg *unix.Msghdr, flags int) (n int, err error) diff --git a/common/bufio/vectorised_windows.go b/common/bufio/vectorised_windows.go index 3223052..c94617f 100644 --- a/common/bufio/vectorised_windows.go +++ b/common/bufio/vectorised_windows.go @@ -1,62 +1,93 @@ package bufio import ( + "sync" + "github.com/sagernet/sing/common/buf" M "github.com/sagernet/sing/common/metadata" "golang.org/x/sys/windows" ) +type syscallVectorisedWriterFields struct { + access sync.Mutex + iovecList *[]windows.WSABuf +} + func (w *SyscallVectorisedWriter) WriteVectorised(buffers []*buf.Buffer) error { + w.access.Lock() + defer w.access.Unlock() defer buf.ReleaseMulti(buffers) - iovecList := make([]*windows.WSABuf, len(buffers)) - for i, buffer := range buffers { - iovecList[i] = &windows.WSABuf{ - Len: uint32(buffer.Len()), - Buf: &buffer.Bytes()[0], - } + var iovecList []windows.WSABuf + if w.iovecList != nil { + iovecList = *w.iovecList } + iovecList = iovecList[:0] + for _, buffer := range buffers { + iovecList = append(iovecList, windows.WSABuf{ + Buf: &buffer.Bytes()[0], + Len: uint32(buffer.Len()), + }) + } + if w.iovecList == nil { + w.iovecList = new([]windows.WSABuf) + } + *w.iovecList = iovecList // cache var n uint32 var innerErr error err := w.rawConn.Write(func(fd uintptr) (done bool) { - innerErr = windows.WSASend(windows.Handle(fd), iovecList[0], uint32(len(iovecList)), &n, 0, nil, nil) + innerErr = windows.WSASend(windows.Handle(fd), &iovecList[0], uint32(len(iovecList)), &n, 0, nil, nil) return innerErr != windows.WSAEWOULDBLOCK }) if innerErr != nil { err = innerErr } + for index := range iovecList { + iovecList[index] = windows.WSABuf{} + } return err } func (w *SyscallVectorisedPacketWriter) WriteVectorisedPacket(buffers []*buf.Buffer, destination M.Socksaddr) error { + w.access.Lock() + defer w.access.Unlock() defer buf.ReleaseMulti(buffers) - iovecList := make([]*windows.WSABuf, len(buffers)) - for i, buffer := range buffers { - iovecList[i] 
= &windows.WSABuf{ - Len: uint32(buffer.Len()), + var iovecList []windows.WSABuf + if w.iovecList != nil { + iovecList = *w.iovecList + } + iovecList = iovecList[:0] + for _, buffer := range buffers { + iovecList = append(iovecList, windows.WSABuf{ Buf: &buffer.Bytes()[0], - } + Len: uint32(buffer.Len()), + }) } - var sockaddr windows.Sockaddr - if destination.IsIPv4() { - sockaddr = &windows.SockaddrInet4{ - Port: int(destination.Port), - Addr: destination.Addr.As4(), - } - } else { - sockaddr = &windows.SockaddrInet6{ - Port: int(destination.Port), - Addr: destination.Addr.As16(), - } + if w.iovecList == nil { + w.iovecList = new([]windows.WSABuf) } + *w.iovecList = iovecList // cache var n uint32 var innerErr error err := w.rawConn.Write(func(fd uintptr) (done bool) { - innerErr = windows.WSASendto(windows.Handle(fd), iovecList[0], uint32(len(iovecList)), &n, 0, sockaddr, nil, nil) + name, nameLen := ToSockaddr(destination.AddrPort()) + innerErr = windows.WSASendTo( + windows.Handle(fd), + &iovecList[0], + uint32(len(iovecList)), + &n, + 0, + (*windows.RawSockaddrAny)(name), + nameLen, + nil, + nil) return innerErr != windows.WSAEWOULDBLOCK }) if innerErr != nil { err = innerErr } + for index := range iovecList { + iovecList[index] = windows.WSABuf{} + } return err } diff --git a/common/cache/lrucache.go b/common/cache/lrucache.go index 37fa26b..a8912e8 100644 --- a/common/cache/lrucache.go +++ b/common/cache/lrucache.go @@ -258,6 +258,14 @@ func (c *LruCache[K, V]) Delete(key K) { c.mu.Unlock() } +func (c *LruCache[K, V]) Clear() { + c.mu.Lock() + defer c.mu.Unlock() + for element := c.lru.Front(); element != nil; element = element.Next() { + c.deleteElement(element) + } +} + func (c *LruCache[K, V]) maybeDeleteOldest() { if !c.staleReturn && c.maxAge > 0 { now := time.Now().Unix() diff --git a/common/canceler/packet.go b/common/canceler/packet.go index ecc2006..fb4ad84 100644 --- a/common/canceler/packet.go +++ b/common/canceler/packet.go @@ -21,13 +21,13 @@ type TimerPacketConn struct { instance *Instance } -func NewPacketConn(ctx context.Context, conn N.PacketConn, timeout time.Duration) (context.Context, PacketConn) { +func NewPacketConn(ctx context.Context, conn N.PacketConn, timeout time.Duration) (context.Context, N.PacketConn) { if timeoutConn, isTimeoutConn := common.Cast[PacketConn](conn); isTimeoutConn { oldTimeout := timeoutConn.Timeout() if timeout < oldTimeout { timeoutConn.SetTimeout(timeout) } - return ctx, timeoutConn + return ctx, conn } err := conn.SetReadDeadline(time.Time{}) if err == nil { diff --git a/common/canceler/packet_timeout.go b/common/canceler/packet_timeout.go index 561f212..ab5c760 100644 --- a/common/canceler/packet_timeout.go +++ b/common/canceler/packet_timeout.go @@ -2,6 +2,7 @@ package canceler import ( "context" + "net" "time" "github.com/sagernet/sing/common" @@ -31,7 +32,7 @@ func (c *TimeoutPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksa for { err = c.PacketConn.SetReadDeadline(time.Now().Add(c.timeout)) if err != nil { - return M.Socksaddr{}, err + return } destination, err = c.PacketConn.ReadPacket(buffer) if err == nil { @@ -43,7 +44,7 @@ func (c *TimeoutPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksa return } } else { - return M.Socksaddr{}, err + return } } } @@ -66,6 +67,7 @@ func (c *TimeoutPacketConn) SetTimeout(timeout time.Duration) { } func (c *TimeoutPacketConn) Close() error { + c.cancel(net.ErrClosed) return c.PacketConn.Close() } diff --git a/common/clear.go b/common/clear.go new file mode 
100644 index 0000000..9768e9c --- /dev/null +++ b/common/clear.go @@ -0,0 +1,11 @@ +//go:build go1.21 + +package common + +func ClearArray[T ~[]E, E any](t T) { + clear(t) +} + +func ClearMap[T ~map[K]V, K comparable, V any](t T) { + clear(t) +} diff --git a/common/clear_compat.go b/common/clear_compat.go new file mode 100644 index 0000000..4e7e9cd --- /dev/null +++ b/common/clear_compat.go @@ -0,0 +1,16 @@ +//go:build !go1.21 + +package common + +func ClearArray[T ~[]E, E any](t T) { + var defaultValue E + for i := range t { + t[i] = defaultValue + } +} + +func ClearMap[T ~map[K]V, K comparable, V any](t T) { + for k := range t { + delete(t, k) + } +} diff --git a/common/cond.go b/common/cond.go index 24458a5..a4c66a7 100644 --- a/common/cond.go +++ b/common/cond.go @@ -159,20 +159,14 @@ func IndexIndexed[T any](arr []T, block func(index int, it T) bool) int { //go:norace func Dup[T any](obj T) T { - if UnsafeBuffer { - pointer := uintptr(unsafe.Pointer(&obj)) - //nolint:staticcheck - //goland:noinspection GoVetUnsafePointer - return *(*T)(unsafe.Pointer(pointer)) - } else { - return obj - } + pointer := uintptr(unsafe.Pointer(&obj)) + //nolint:staticcheck + //goland:noinspection GoVetUnsafePointer + return *(*T)(unsafe.Pointer(pointer)) } func KeepAlive(obj any) { - if UnsafeBuffer { - runtime.KeepAlive(obj) - } + runtime.KeepAlive(obj) } func Uniq[T comparable](arr []T) []T { @@ -342,6 +336,10 @@ func DefaultValue[T any]() T { return defaultValue } +func Ptr[T any](obj T) *T { + return &obj +} + func Close(closers ...any) error { var retErr error for _, closer := range closers { diff --git a/common/control/bind.go b/common/control/bind.go index 94c621c..0dd2107 100644 --- a/common/control/bind.go +++ b/common/control/bind.go @@ -1,59 +1,35 @@ package control import ( - "os" - "runtime" "syscall" + E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" ) func BindToInterface(finder InterfaceFinder, interfaceName string, interfaceIndex int) Func { return func(network, address string, conn syscall.RawConn) error { - return BindToInterface0(finder, conn, network, address, interfaceName, interfaceIndex) + return BindToInterface0(finder, conn, network, address, interfaceName, interfaceIndex, false) } } -func BindToInterfaceFunc(finder InterfaceFinder, block func(network string, address string) (interfaceName string, interfaceIndex int)) Func { +func BindToInterfaceFunc(finder InterfaceFinder, block func(network string, address string) (interfaceName string, interfaceIndex int, err error)) Func { return func(network, address string, conn syscall.RawConn) error { - interfaceName, interfaceIndex := block(network, address) - return BindToInterface0(finder, conn, network, address, interfaceName, interfaceIndex) + interfaceName, interfaceIndex, err := block(network, address) + if err != nil { + return err + } + return BindToInterface0(finder, conn, network, address, interfaceName, interfaceIndex, false) } } -const useInterfaceName = runtime.GOOS == "linux" || runtime.GOOS == "android" - -func BindToInterface0(finder InterfaceFinder, conn syscall.RawConn, network string, address string, interfaceName string, interfaceIndex int) error { +func BindToInterface0(finder InterfaceFinder, conn syscall.RawConn, network string, address string, interfaceName string, interfaceIndex int, preferInterfaceName bool) error { + if interfaceName == "" && interfaceIndex == -1 { + return E.New("interface not found: ", interfaceName) + } if 
addr := M.ParseSocksaddr(address).Addr; addr.IsValid() && N.IsVirtual(addr) { return nil } - if interfaceName == "" && interfaceIndex == -1 { - return nil - } - if interfaceName != "" && useInterfaceName || interfaceIndex != -1 && !useInterfaceName { - return bindToInterface(conn, network, address, interfaceName, interfaceIndex) - } - if finder == nil { - return os.ErrInvalid - } - var err error - if useInterfaceName { - interfaceName, err = finder.InterfaceNameByIndex(interfaceIndex) - } else { - interfaceIndex, err = finder.InterfaceIndexByName(interfaceName) - } - if err != nil { - return err - } - if useInterfaceName { - if interfaceName == "" { - return nil - } - } else { - if interfaceIndex == -1 { - return nil - } - } - return bindToInterface(conn, network, address, interfaceName, interfaceIndex) + return bindToInterface(conn, network, address, finder, interfaceName, interfaceIndex, preferInterfaceName) } diff --git a/common/control/bind_darwin.go b/common/control/bind_darwin.go index 8262ac7..bff6c29 100644 --- a/common/control/bind_darwin.go +++ b/common/control/bind_darwin.go @@ -1,16 +1,24 @@ package control import ( + "os" "syscall" "golang.org/x/sys/unix" ) -func bindToInterface(conn syscall.RawConn, network string, address string, interfaceName string, interfaceIndex int) error { - if interfaceIndex == -1 { - return nil - } +func bindToInterface(conn syscall.RawConn, network string, address string, finder InterfaceFinder, interfaceName string, interfaceIndex int, preferInterfaceName bool) error { return Raw(conn, func(fd uintptr) error { + var err error + if interfaceIndex == -1 { + if finder == nil { + return os.ErrInvalid + } + interfaceIndex, err = finder.InterfaceIndexByName(interfaceName) + if err != nil { + return err + } + } switch network { case "tcp6", "udp6": return unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_BOUND_IF, interfaceIndex) diff --git a/common/control/bind_finder.go b/common/control/bind_finder.go index 21820fb..2383217 100644 --- a/common/control/bind_finder.go +++ b/common/control/bind_finder.go @@ -1,30 +1,21 @@ package control -import "net" +import ( + "net" + "net/netip" +) type InterfaceFinder interface { + Interfaces() []Interface InterfaceIndexByName(name string) (int, error) InterfaceNameByIndex(index int) (string, error) + InterfaceByAddr(addr netip.Addr) (*Interface, error) } -func DefaultInterfaceFinder() InterfaceFinder { - return (*netInterfaceFinder)(nil) -} - -type netInterfaceFinder struct{} - -func (w *netInterfaceFinder) InterfaceIndexByName(name string) (int, error) { - netInterface, err := net.InterfaceByName(name) - if err != nil { - return 0, err - } - return netInterface.Index, nil -} - -func (w *netInterfaceFinder) InterfaceNameByIndex(index int) (string, error) { - netInterface, err := net.InterfaceByIndex(index) - if err != nil { - return "", err - } - return netInterface.Name, nil +type Interface struct { + Index int + MTU int + Name string + Addresses []netip.Prefix + HardwareAddr net.HardwareAddr } diff --git a/common/control/bind_finder_default.go b/common/control/bind_finder_default.go new file mode 100644 index 0000000..9d9230e --- /dev/null +++ b/common/control/bind_finder_default.go @@ -0,0 +1,104 @@ +package control + +import ( + "net" + "net/netip" + _ "unsafe" + + "github.com/sagernet/sing/common" + M "github.com/sagernet/sing/common/metadata" +) + +var _ InterfaceFinder = (*DefaultInterfaceFinder)(nil) + +type DefaultInterfaceFinder struct { + interfaces []Interface +} + +func NewDefaultInterfaceFinder() 
*DefaultInterfaceFinder { + return &DefaultInterfaceFinder{} +} + +func (f *DefaultInterfaceFinder) Update() error { + netIfs, err := net.Interfaces() + if err != nil { + return err + } + interfaces := make([]Interface, 0, len(netIfs)) + for _, netIf := range netIfs { + ifAddrs, err := netIf.Addrs() + if err != nil { + return err + } + interfaces = append(interfaces, Interface{ + Index: netIf.Index, + MTU: netIf.MTU, + Name: netIf.Name, + Addresses: common.Map(ifAddrs, M.PrefixFromNet), + HardwareAddr: netIf.HardwareAddr, + }) + } + f.interfaces = interfaces + return nil +} + +func (f *DefaultInterfaceFinder) UpdateInterfaces(interfaces []Interface) { + f.interfaces = interfaces +} + +func (f *DefaultInterfaceFinder) Interfaces() []Interface { + return f.interfaces +} + +func (f *DefaultInterfaceFinder) InterfaceIndexByName(name string) (int, error) { + for _, netInterface := range f.interfaces { + if netInterface.Name == name { + return netInterface.Index, nil + } + } + netInterface, err := net.InterfaceByName(name) + if err != nil { + return 0, err + } + f.Update() + return netInterface.Index, nil +} + +func (f *DefaultInterfaceFinder) InterfaceNameByIndex(index int) (string, error) { + for _, netInterface := range f.interfaces { + if netInterface.Index == index { + return netInterface.Name, nil + } + } + netInterface, err := net.InterfaceByIndex(index) + if err != nil { + return "", err + } + f.Update() + return netInterface.Name, nil +} + +//go:linkname errNoSuchInterface net.errNoSuchInterface +var errNoSuchInterface error + +func (f *DefaultInterfaceFinder) InterfaceByAddr(addr netip.Addr) (*Interface, error) { + for _, netInterface := range f.interfaces { + for _, prefix := range netInterface.Addresses { + if prefix.Contains(addr) { + return &netInterface, nil + } + } + } + err := f.Update() + if err != nil { + return nil, err + } + for _, netInterface := range f.interfaces { + for _, prefix := range netInterface.Addresses { + if prefix.Contains(addr) { + return &netInterface, nil + } + } + } + return nil, &net.OpError{Op: "route", Net: "ip+net", Source: nil, Addr: &net.IPAddr{IP: addr.AsSlice()}, Err: errNoSuchInterface} +} diff --git a/common/control/bind_linux.go b/common/control/bind_linux.go index 6ebca49..c92bf6b 100644 --- a/common/control/bind_linux.go +++ b/common/control/bind_linux.go @@ -1,13 +1,42 @@ package control import ( + "os" "syscall" + "github.com/sagernet/sing/common/atomic" + E "github.com/sagernet/sing/common/exceptions" + "golang.org/x/sys/unix" ) -func bindToInterface(conn syscall.RawConn, network string, address string, interfaceName string, interfaceIndex int) error { +var ifIndexDisabled atomic.Bool + +func bindToInterface(conn syscall.RawConn, network string, address string, finder InterfaceFinder, interfaceName string, interfaceIndex int, preferInterfaceName bool) error { return Raw(conn, func(fd uintptr) error { + if !preferInterfaceName && !ifIndexDisabled.Load() { + if interfaceIndex == -1 { + if interfaceName == "" { + return os.ErrInvalid + } + var err error + interfaceIndex, err = finder.InterfaceIndexByName(interfaceName) + if err != nil { + return err + } + } + err := unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_BINDTOIFINDEX, interfaceIndex) + if err == nil { + return nil + } else if E.IsMulti(err, unix.ENOPROTOOPT, unix.EINVAL) { + ifIndexDisabled.Store(true) + } else { + return err + } + } + if interfaceName == "" { + return os.ErrInvalid + } return unix.BindToDevice(int(fd), interfaceName) }) } diff --git 
a/common/control/bind_other.go b/common/control/bind_other.go index 27d0497..23a884f 100644 --- a/common/control/bind_other.go +++ b/common/control/bind_other.go @@ -4,6 +4,6 @@ package control import "syscall" -func bindToInterface(conn syscall.RawConn, network string, address string, interfaceName string, interfaceIndex int) error { +func bindToInterface(conn syscall.RawConn, network string, address string, finder InterfaceFinder, interfaceName string, interfaceIndex int, preferInterfaceName bool) error { return nil } diff --git a/common/control/bind_windows.go b/common/control/bind_windows.go index 5e23bf1..a499556 100644 --- a/common/control/bind_windows.go +++ b/common/control/bind_windows.go @@ -2,17 +2,28 @@ package control import ( "encoding/binary" + "os" "syscall" "unsafe" M "github.com/sagernet/sing/common/metadata" ) -func bindToInterface(conn syscall.RawConn, network string, address string, interfaceName string, interfaceIndex int) error { +func bindToInterface(conn syscall.RawConn, network string, address string, finder InterfaceFinder, interfaceName string, interfaceIndex int, preferInterfaceName bool) error { return Raw(conn, func(fd uintptr) error { + var err error + if interfaceIndex == -1 { + if finder == nil { + return os.ErrInvalid + } + interfaceIndex, err = finder.InterfaceIndexByName(interfaceName) + if err != nil { + return err + } + } handle := syscall.Handle(fd) if M.ParseSocksaddr(address).AddrString() == "" { - err := bind4(handle, interfaceIndex) + err = bind4(handle, interfaceIndex) if err != nil { return err } diff --git a/common/control/interface.go b/common/control/interface.go index f778a4b..01f07b4 100644 --- a/common/control/interface.go +++ b/common/control/interface.go @@ -3,6 +3,7 @@ package control import ( "syscall" + "github.com/sagernet/sing/common" E "github.com/sagernet/sing/common/exceptions" ) @@ -30,6 +31,14 @@ func Conn(conn syscall.Conn, block func(fd uintptr) error) error { return Raw(rawConn, block) } +func Conn0[T any](conn syscall.Conn, block func(fd uintptr) (T, error)) (T, error) { + rawConn, err := conn.SyscallConn() + if err != nil { + return common.DefaultValue[T](), err + } + return Raw0[T](rawConn, block) +} + func Raw(rawConn syscall.RawConn, block func(fd uintptr) error) error { var innerErr error err := rawConn.Control(func(fd uintptr) { @@ -37,3 +46,14 @@ func Raw(rawConn syscall.RawConn, block func(fd uintptr) error) error { }) return E.Errors(innerErr, err) } + +func Raw0[T any](rawConn syscall.RawConn, block func(fd uintptr) (T, error)) (T, error) { + var ( + value T + innerErr error + ) + err := rawConn.Control(func(fd uintptr) { + value, innerErr = block(fd) + }) + return value, E.Errors(innerErr, err) +} diff --git a/common/control/mark_linux.go b/common/control/mark_linux.go index b89c1ef..da52f02 100644 --- a/common/control/mark_linux.go +++ b/common/control/mark_linux.go @@ -4,10 +4,10 @@ import ( "syscall" ) -func RoutingMark(mark int) Func { +func RoutingMark(mark uint32) Func { return func(network, address string, conn syscall.RawConn) error { return Raw(conn, func(fd uintptr) error { - return syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_MARK, mark) + return syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_MARK, int(mark)) }) } } diff --git a/common/control/mark_other.go b/common/control/mark_other.go index eb76c0b..a4298f3 100644 --- a/common/control/mark_other.go +++ b/common/control/mark_other.go @@ -2,6 +2,6 @@ package control -func RoutingMark(mark int) Func { +func 
RoutingMark(mark uint32) Func { return nil } diff --git a/common/control/redirect_darwin.go b/common/control/redirect_darwin.go new file mode 100644 index 0000000..50db3d8 --- /dev/null +++ b/common/control/redirect_darwin.go @@ -0,0 +1,58 @@ +package control + +import ( + "encoding/binary" + "net" + "net/netip" + "syscall" + "unsafe" + + M "github.com/sagernet/sing/common/metadata" + + "golang.org/x/sys/unix" +) + +const ( + PF_OUT = 0x2 + DIOCNATLOOK = 0xc0544417 +) + +func GetOriginalDestination(conn net.Conn) (netip.AddrPort, error) { + pfFd, err := syscall.Open("/dev/pf", 0, syscall.O_RDONLY) + if err != nil { + return netip.AddrPort{}, err + } + defer syscall.Close(pfFd) + nl := struct { + saddr, daddr, rsaddr, rdaddr [16]byte + sxport, dxport, rsxport, rdxport [4]byte + af, proto, protoVariant, direction uint8 + }{ + af: syscall.AF_INET, + proto: syscall.IPPROTO_TCP, + direction: PF_OUT, + } + localAddr := M.SocksaddrFromNet(conn.LocalAddr()).Unwrap() + removeAddr := M.SocksaddrFromNet(conn.RemoteAddr()).Unwrap() + if localAddr.IsIPv4() { + copy(nl.saddr[:net.IPv4len], removeAddr.Addr.AsSlice()) + copy(nl.daddr[:net.IPv4len], localAddr.Addr.AsSlice()) + nl.af = syscall.AF_INET + } else { + copy(nl.saddr[:], removeAddr.Addr.AsSlice()) + copy(nl.daddr[:], localAddr.Addr.AsSlice()) + nl.af = syscall.AF_INET6 + } + binary.BigEndian.PutUint16(nl.sxport[:], removeAddr.Port) + binary.BigEndian.PutUint16(nl.dxport[:], localAddr.Port) + if _, _, errno := unix.Syscall(syscall.SYS_IOCTL, uintptr(pfFd), DIOCNATLOOK, uintptr(unsafe.Pointer(&nl))); errno != 0 { + return netip.AddrPort{}, errno + } + var address netip.Addr + if nl.af == unix.AF_INET { + address = M.AddrFromIP(nl.rdaddr[:net.IPv4len]) + } else { + address = netip.AddrFrom16(nl.rdaddr) + } + return netip.AddrPortFrom(address, binary.BigEndian.Uint16(nl.rdxport[:])), nil +} diff --git a/common/control/redirect_linux.go b/common/control/redirect_linux.go new file mode 100644 index 0000000..82ab233 --- /dev/null +++ b/common/control/redirect_linux.go @@ -0,0 +1,38 @@ +package control + +import ( + "encoding/binary" + "net" + "net/netip" + "os" + "syscall" + + "github.com/sagernet/sing/common" + M "github.com/sagernet/sing/common/metadata" + + "golang.org/x/sys/unix" +) + +func GetOriginalDestination(conn net.Conn) (netip.AddrPort, error) { + syscallConn, loaded := common.Cast[syscall.Conn](conn) + if !loaded { + return netip.AddrPort{}, os.ErrInvalid + } + return Conn0[netip.AddrPort](syscallConn, func(fd uintptr) (netip.AddrPort, error) { + if M.SocksaddrFromNet(conn.RemoteAddr()).Unwrap().IsIPv4() { + raw, err := unix.GetsockoptIPv6Mreq(int(fd), unix.IPPROTO_IP, unix.SO_ORIGINAL_DST) + if err != nil { + return netip.AddrPort{}, err + } + return netip.AddrPortFrom(M.AddrFromIP(raw.Multiaddr[4:8]), uint16(raw.Multiaddr[2])<<8+uint16(raw.Multiaddr[3])), nil + } else { + raw, err := unix.GetsockoptIPv6MTUInfo(int(fd), unix.IPPROTO_IPV6, unix.SO_ORIGINAL_DST) + if err != nil { + return netip.AddrPort{}, err + } + var port [2]byte + binary.BigEndian.PutUint16(port[:], raw.Addr.Port) + return netip.AddrPortFrom(M.AddrFromIP(raw.Addr.Addr[:]), binary.LittleEndian.Uint16(port[:])), nil + } + }) +} diff --git a/common/control/redirect_other.go b/common/control/redirect_other.go new file mode 100644 index 0000000..b0f3297 --- /dev/null +++ b/common/control/redirect_other.go @@ -0,0 +1,13 @@ +//go:build !linux && !darwin + +package control + +import ( + "net" + "net/netip" + "os" +) + +func GetOriginalDestination(conn net.Conn) (netip.AddrPort, 
error) { + return netip.AddrPort{}, os.ErrInvalid +} diff --git a/common/control/tcp_keep_alive_linux.go b/common/control/tcp_keep_alive_linux.go new file mode 100644 index 0000000..bede11a --- /dev/null +++ b/common/control/tcp_keep_alive_linux.go @@ -0,0 +1,30 @@ +package control + +import ( + "syscall" + "time" + _ "unsafe" + + E "github.com/sagernet/sing/common/exceptions" + N "github.com/sagernet/sing/common/network" + + "golang.org/x/sys/unix" +) + +func SetKeepAlivePeriod(idle time.Duration, interval time.Duration) Func { + return func(network, address string, conn syscall.RawConn) error { + if N.NetworkName(network) != N.NetworkTCP { + return nil + } + return Raw(conn, func(fd uintptr) error { + return E.Errors( + unix.SetsockoptInt(int(fd), unix.IPPROTO_TCP, unix.TCP_KEEPIDLE, int(roundDurationUp(idle, time.Second))), + unix.SetsockoptInt(int(fd), unix.IPPROTO_TCP, unix.TCP_KEEPINTVL, int(roundDurationUp(interval, time.Second))), + ) + }) + } +} + +func roundDurationUp(d time.Duration, to time.Duration) time.Duration { + return (d + to - 1) / to +} diff --git a/common/control/tcp_keep_alive_stub.go b/common/control/tcp_keep_alive_stub.go new file mode 100644 index 0000000..180d8d3 --- /dev/null +++ b/common/control/tcp_keep_alive_stub.go @@ -0,0 +1,11 @@ +//go:build !linux + +package control + +import ( + "time" +) + +func SetKeepAlivePeriod(idle time.Duration, interval time.Duration) Func { + return nil +} diff --git a/common/control/tproxy_linux.go b/common/control/tproxy_linux.go new file mode 100644 index 0000000..b296b98 --- /dev/null +++ b/common/control/tproxy_linux.go @@ -0,0 +1,56 @@ +package control + +import ( + "encoding/binary" + "net/netip" + "syscall" + + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + + "golang.org/x/sys/unix" +) + +func TProxy(fd uintptr, family int) error { + err := syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1) + if err == nil { + err = syscall.SetsockoptInt(int(fd), syscall.SOL_IP, syscall.IP_TRANSPARENT, 1) + } + if err == nil && family == unix.AF_INET6 { + err = syscall.SetsockoptInt(int(fd), syscall.SOL_IPV6, unix.IPV6_TRANSPARENT, 1) + } + if err == nil { + err = syscall.SetsockoptInt(int(fd), syscall.SOL_IP, syscall.IP_RECVORIGDSTADDR, 1) + } + if err == nil && family == unix.AF_INET6 { + err = syscall.SetsockoptInt(int(fd), syscall.SOL_IPV6, unix.IPV6_RECVORIGDSTADDR, 1) + } + return err +} + +func TProxyWriteBack() Func { + return func(network, address string, conn syscall.RawConn) error { + return Raw(conn, func(fd uintptr) error { + if M.ParseSocksaddr(address).Addr.Is6() { + return syscall.SetsockoptInt(int(fd), syscall.SOL_IPV6, unix.IPV6_TRANSPARENT, 1) + } else { + return syscall.SetsockoptInt(int(fd), syscall.SOL_IP, syscall.IP_TRANSPARENT, 1) + } + }) + } +} + +func GetOriginalDestinationFromOOB(oob []byte) (netip.AddrPort, error) { + controlMessages, err := unix.ParseSocketControlMessage(oob) + if err != nil { + return netip.AddrPort{}, err + } + for _, message := range controlMessages { + if message.Header.Level == unix.SOL_IP && message.Header.Type == unix.IP_RECVORIGDSTADDR { + return netip.AddrPortFrom(M.AddrFromIP(message.Data[4:8]), binary.BigEndian.Uint16(message.Data[2:4])), nil + } else if message.Header.Level == unix.SOL_IPV6 && message.Header.Type == unix.IPV6_RECVORIGDSTADDR { + return netip.AddrPortFrom(M.AddrFromIP(message.Data[8:24]), binary.BigEndian.Uint16(message.Data[2:4])), nil + } + } + return netip.AddrPort{}, E.New("not found") +} 
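
The TProxy helpers added above are meant to be used together: TProxy() marks the listening socket as transparent (IP_TRANSPARENT / IPV6_TRANSPARENT) and enables IP_RECVORIGDSTADDR / IPV6_RECVORIGDSTADDR, while GetOriginalDestinationFromOOB() then recovers the pre-redirect destination from the UDP control messages. A minimal, Linux-only usage sketch follows; it is not part of this diff, and the listen port, the matching iptables TPROXY rule, and the CAP_NET_ADMIN privileges it would need are illustrative assumptions.

package main

import (
	"context"
	"fmt"
	"net"
	"syscall"

	"github.com/sagernet/sing/common/control"

	"golang.org/x/sys/unix"
)

func main() {
	// Set IP_TRANSPARENT and IP_RECVORIGDSTADDR on the socket before bind,
	// so the kernel delivers TPROXY-redirected datagrams and attaches the
	// original destination as a control message.
	listenConfig := net.ListenConfig{
		Control: func(network, address string, conn syscall.RawConn) error {
			return control.Raw(conn, func(fd uintptr) error {
				return control.TProxy(fd, unix.AF_INET)
			})
		},
	}
	// Port 7893 is an arbitrary example; a matching iptables/nftables TPROXY
	// rule pointing at this port is required for traffic to arrive here.
	packetConn, err := listenConfig.ListenPacket(context.Background(), "udp4", ":7893")
	if err != nil {
		panic(err)
	}
	defer packetConn.Close()
	udpConn := packetConn.(*net.UDPConn)

	payload := make([]byte, 65535)
	oob := make([]byte, 1024)
	for {
		n, oobN, _, source, err := udpConn.ReadMsgUDP(payload, oob)
		if err != nil {
			panic(err)
		}
		// Recover the pre-redirect destination from the IP_RECVORIGDSTADDR
		// control message carried in the out-of-band data.
		destination, err := control.GetOriginalDestinationFromOOB(oob[:oobN])
		if err != nil {
			panic(err)
		}
		fmt.Println(source.String(), "->", destination.String(), ":", n, "bytes")
	}
}

The returned netip.AddrPort can then be wrapped into an M.Socksaddr and used, for example, as the destination for the NATPacketConn helpers introduced earlier in this diff.
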
diff --git a/common/control/tproxy_other.go b/common/control/tproxy_other.go new file mode 100644 index 0000000..cad1808 --- /dev/null +++ b/common/control/tproxy_other.go @@ -0,0 +1,20 @@ +//go:build !linux + +package control + +import ( + "net/netip" + "os" +) + +func TProxy(fd uintptr, isIPv6 bool) error { + return os.ErrInvalid +} + +func TProxyWriteBack() Func { + return nil +} + +func GetOriginalDestinationFromOOB(oob []byte) (netip.AddrPort, error) { + return netip.AddrPort{}, os.ErrInvalid +} diff --git a/common/domain/matcher.go b/common/domain/matcher.go index 95dc0dc..83cbd48 100644 --- a/common/domain/matcher.go +++ b/common/domain/matcher.go @@ -1,8 +1,12 @@ package domain import ( + "encoding/binary" + "io" "sort" "unicode/utf8" + + "github.com/sagernet/sing/common/rw" ) type Matcher struct { @@ -10,14 +14,19 @@ type Matcher struct { } func NewMatcher(domains []string, domainSuffix []string) *Matcher { - domainList := make([]string, 0, len(domains)+len(domainSuffix)) + domainList := make([]string, 0, len(domains)+2*len(domainSuffix)) seen := make(map[string]bool, len(domainList)) for _, domain := range domainSuffix { if seen[domain] { continue } seen[domain] = true - domainList = append(domainList, reverseDomainSuffix(domain)) + if domain[0] == '.' { + domainList = append(domainList, reverseDomainSuffix(domain)) + } else { + domainList = append(domainList, reverseDomain(domain)) + domainList = append(domainList, reverseRootDomainSuffix(domain)) + } } for _, domain := range domains { if seen[domain] { @@ -27,15 +36,87 @@ func NewMatcher(domains []string, domainSuffix []string) *Matcher { domainList = append(domainList, reverseDomain(domain)) } sort.Strings(domainList) - return &Matcher{ - newSuccinctSet(domainList), + return &Matcher{newSuccinctSet(domainList)} +} + +func ReadMatcher(reader io.Reader) (*Matcher, error) { + var version uint8 + err := binary.Read(reader, binary.BigEndian, &version) + if err != nil { + return nil, err } + leavesLength, err := rw.ReadUVariant(reader) + if err != nil { + return nil, err + } + leaves := make([]uint64, leavesLength) + err = binary.Read(reader, binary.BigEndian, leaves) + if err != nil { + return nil, err + } + labelBitmapLength, err := rw.ReadUVariant(reader) + if err != nil { + return nil, err + } + labelBitmap := make([]uint64, labelBitmapLength) + err = binary.Read(reader, binary.BigEndian, labelBitmap) + if err != nil { + return nil, err + } + labelsLength, err := rw.ReadUVariant(reader) + if err != nil { + return nil, err + } + labels := make([]byte, labelsLength) + _, err = io.ReadFull(reader, labels) + if err != nil { + return nil, err + } + set := &succinctSet{ + leaves: leaves, + labelBitmap: labelBitmap, + labels: labels, + } + set.init() + return &Matcher{set}, nil } func (m *Matcher) Match(domain string) bool { return m.set.Has(reverseDomain(domain)) } +func (m *Matcher) Write(writer io.Writer) error { + err := binary.Write(writer, binary.BigEndian, byte(1)) + if err != nil { + return err + } + err = rw.WriteUVariant(writer, uint64(len(m.set.leaves))) + if err != nil { + return err + } + err = binary.Write(writer, binary.BigEndian, m.set.leaves) + if err != nil { + return err + } + err = rw.WriteUVariant(writer, uint64(len(m.set.labelBitmap))) + if err != nil { + return err + } + err = binary.Write(writer, binary.BigEndian, m.set.labelBitmap) + if err != nil { + return err + } + err = rw.WriteUVariant(writer, uint64(len(m.set.labels))) + if err != nil { + return err + } + _, err = writer.Write(m.set.labels) + if err != 
nil { + return err + } + return nil +} + func reverseDomain(domain string) string { l := len(domain) b := make([]byte, l) @@ -58,3 +139,16 @@ func reverseDomainSuffix(domain string) string { b[l] = prefixLabel return string(b) } + +func reverseRootDomainSuffix(domain string) string { + l := len(domain) + b := make([]byte, l+2) + for i := 0; i < l; { + r, n := utf8.DecodeRuneInString(domain[i:]) + i += n + utf8.EncodeRune(b[l-i:], r) + } + b[l] = '.' + b[l+1] = prefixLabel + return string(b) +} diff --git a/common/exceptions/cause.go b/common/exceptions/cause.go index 27211f2..fe7adf3 100644 --- a/common/exceptions/cause.go +++ b/common/exceptions/cause.go @@ -6,9 +6,6 @@ type causeError struct { } func (e *causeError) Error() string { - if e.cause == nil { - return e.message - } return e.message + ": " + e.cause.Error() } diff --git a/common/exceptions/error.go b/common/exceptions/error.go index cf5a3da..5d056e6 100644 --- a/common/exceptions/error.go +++ b/common/exceptions/error.go @@ -26,14 +26,14 @@ func New(message ...any) error { func Cause(cause error, message ...any) error { if cause == nil { - return nil + panic("cause on an nil error") } return &causeError{F.ToString(message...), cause} } func Extend(cause error, message ...any) error { if cause == nil { - return nil + panic("extend on an nil error") } return &extendedError{F.ToString(message...), cause} } diff --git a/common/exceptions/multi.go b/common/exceptions/multi.go index a42f00c..2cdec05 100644 --- a/common/exceptions/multi.go +++ b/common/exceptions/multi.go @@ -23,6 +23,7 @@ func (e *multiError) Unwrap() []error { func Errors(errors ...error) error { errors = common.FilterNotNil(errors) errors = ExpandAll(errors) + errors = common.FilterNotNil(errors) errors = common.UniqBy(errors, error.Error) switch len(errors) { case 0: @@ -36,10 +37,13 @@ func Errors(errors ...error) error { } func Expand(err error) []error { - if multiErr, isMultiErr := err.(MultiError); isMultiErr { - return ExpandAll(multiErr.Unwrap()) + if err == nil { + return nil + } else if multiErr, isMultiErr := err.(MultiError); isMultiErr { + return ExpandAll(common.FilterNotNil(multiErr.Unwrap())) + } else { + return []error{err} } - return []error{err} } func ExpandAll(errs []error) []error { diff --git a/common/json/badjson/array.go b/common/json/badjson/array.go new file mode 100644 index 0000000..a7d5f70 --- /dev/null +++ b/common/json/badjson/array.go @@ -0,0 +1,59 @@ +package badjson + +import ( + "bytes" + + "github.com/sagernet/sing/common" + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/json" +) + +type JSONArray []any + +func (a JSONArray) IsEmpty() bool { + if len(a) == 0 { + return true + } + return common.All(a, func(it any) bool { + if valueInterface, valueMaybeEmpty := it.(isEmpty); valueMaybeEmpty && valueInterface.IsEmpty() { + return true + } + return false + }) +} + +func (a JSONArray) MarshalJSON() ([]byte, error) { + return json.Marshal([]any(a)) +} + +func (a *JSONArray) UnmarshalJSON(content []byte) error { + decoder := json.NewDecoder(bytes.NewReader(content)) + arrayStart, err := decoder.Token() + if err != nil { + return err + } else if arrayStart != json.Delim('[') { + return E.New("excepted array start, but got ", arrayStart) + } + err = a.decodeJSON(decoder) + if err != nil { + return err + } + arrayEnd, err := decoder.Token() + if err != nil { + return err + } else if arrayEnd != json.Delim(']') { + return E.New("excepted array end, but got ", arrayEnd) + } + return nil +} + +func (a 
*JSONArray) decodeJSON(decoder *json.Decoder) error { + for decoder.More() { + item, err := decodeJSON(decoder) + if err != nil { + return err + } + *a = append(*a, item) + } + return nil +} diff --git a/common/json/badjson/empty.go b/common/json/badjson/empty.go new file mode 100644 index 0000000..da8bedf --- /dev/null +++ b/common/json/badjson/empty.go @@ -0,0 +1,5 @@ +package badjson + +type isEmpty interface { + IsEmpty() bool +} diff --git a/common/json/badjson/json.go b/common/json/badjson/json.go new file mode 100644 index 0000000..04dba1e --- /dev/null +++ b/common/json/badjson/json.go @@ -0,0 +1,54 @@ +package badjson + +import ( + "bytes" + + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/json" +) + +func Decode(content []byte) (any, error) { + decoder := json.NewDecoder(bytes.NewReader(content)) + return decodeJSON(decoder) +} + +func decodeJSON(decoder *json.Decoder) (any, error) { + rawToken, err := decoder.Token() + if err != nil { + return nil, err + } + switch token := rawToken.(type) { + case json.Delim: + switch token { + case '{': + var object JSONObject + err = object.decodeJSON(decoder) + if err != nil { + return nil, err + } + rawToken, err = decoder.Token() + if err != nil { + return nil, err + } else if rawToken != json.Delim('}') { + return nil, E.New("excepted object end, but got ", rawToken) + } + return &object, nil + case '[': + var array JSONArray + err = array.decodeJSON(decoder) + if err != nil { + return nil, err + } + rawToken, err = decoder.Token() + if err != nil { + return nil, err + } else if rawToken != json.Delim(']') { + return nil, E.New("excepted array end, but got ", rawToken) + } + return array, nil + default: + return nil, E.New("excepted object or array end: ", token) + } + } + return rawToken, nil +} diff --git a/common/json/badjson/merge.go b/common/json/badjson/merge.go new file mode 100644 index 0000000..35e9494 --- /dev/null +++ b/common/json/badjson/merge.go @@ -0,0 +1,139 @@ +package badjson + +import ( + "os" + "reflect" + + "github.com/sagernet/sing/common" + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/json" +) + +func Omitempty[T any](value T) (T, error) { + objectContent, err := json.Marshal(value) + if err != nil { + return common.DefaultValue[T](), E.Cause(err, "marshal object") + } + rawNewObject, err := Decode(objectContent) + if err != nil { + return common.DefaultValue[T](), err + } + newObjectContent, err := json.Marshal(rawNewObject) + if err != nil { + return common.DefaultValue[T](), E.Cause(err, "marshal new object") + } + var newObject T + err = json.Unmarshal(newObjectContent, &newObject) + if err != nil { + return common.DefaultValue[T](), E.Cause(err, "unmarshal new object") + } + return newObject, nil +} + +func Merge[T any](source T, destination T) (T, error) { + rawSource, err := json.Marshal(source) + if err != nil { + return common.DefaultValue[T](), E.Cause(err, "marshal source") + } + rawDestination, err := json.Marshal(destination) + if err != nil { + return common.DefaultValue[T](), E.Cause(err, "marshal destination") + } + return MergeFrom[T](rawSource, rawDestination) +} + +func MergeFromSource[T any](rawSource json.RawMessage, destination T) (T, error) { + if rawSource == nil { + return destination, nil + } + rawDestination, err := json.Marshal(destination) + if err != nil { + return common.DefaultValue[T](), E.Cause(err, "marshal destination") + } + return MergeFrom[T](rawSource, rawDestination) +} + +func MergeFromDestination[T 
any](source T, rawDestination json.RawMessage) (T, error) { + if rawDestination == nil { + return source, nil + } + rawSource, err := json.Marshal(source) + if err != nil { + return common.DefaultValue[T](), E.Cause(err, "marshal source") + } + return MergeFrom[T](rawSource, rawDestination) +} + +func MergeFrom[T any](rawSource json.RawMessage, rawDestination json.RawMessage) (T, error) { + rawMerged, err := MergeJSON(rawSource, rawDestination) + if err != nil { + return common.DefaultValue[T](), E.Cause(err, "merge options") + } + var merged T + err = json.Unmarshal(rawMerged, &merged) + if err != nil { + return common.DefaultValue[T](), E.Cause(err, "unmarshal merged options") + } + return merged, nil +} + +func MergeJSON(rawSource json.RawMessage, rawDestination json.RawMessage) (json.RawMessage, error) { + if rawSource == nil && rawDestination == nil { + return nil, os.ErrInvalid + } else if rawSource == nil { + return rawDestination, nil + } else if rawDestination == nil { + return rawSource, nil + } + source, err := Decode(rawSource) + if err != nil { + return nil, E.Cause(err, "decode source") + } + destination, err := Decode(rawDestination) + if err != nil { + return nil, E.Cause(err, "decode destination") + } + if source == nil { + return json.Marshal(destination) + } else if destination == nil { + return json.Marshal(source) + } + merged, err := mergeJSON(source, destination) + if err != nil { + return nil, err + } + return json.Marshal(merged) +} + +func mergeJSON(anySource any, anyDestination any) (any, error) { + switch destination := anyDestination.(type) { + case JSONArray: + switch source := anySource.(type) { + case JSONArray: + destination = append(destination, source...) + default: + destination = append(destination, source) + } + return destination, nil + case *JSONObject: + switch source := anySource.(type) { + case *JSONObject: + for _, entry := range source.Entries() { + oldValue, loaded := destination.Get(entry.Key) + if loaded { + var err error + entry.Value, err = mergeJSON(entry.Value, oldValue) + if err != nil { + return nil, E.Cause(err, "merge object item ", entry.Key) + } + } + destination.Put(entry.Key, entry.Value) + } + default: + return nil, E.New("cannot merge json object into ", reflect.TypeOf(source)) + } + return destination, nil + default: + return destination, nil + } +} diff --git a/common/json/badjson/object.go b/common/json/badjson/object.go new file mode 100644 index 0000000..61d5862 --- /dev/null +++ b/common/json/badjson/object.go @@ -0,0 +1,98 @@ +package badjson + +import ( + "bytes" + "strings" + + "github.com/sagernet/sing/common" + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/json" + "github.com/sagernet/sing/common/x/collections" + "github.com/sagernet/sing/common/x/linkedhashmap" +) + +type JSONObject struct { + linkedhashmap.Map[string, any] +} + +func (m *JSONObject) IsEmpty() bool { + if m.Size() == 0 { + return true + } + return common.All(m.Entries(), func(it collections.MapEntry[string, any]) bool { + if valueInterface, valueMaybeEmpty := it.Value.(isEmpty); valueMaybeEmpty && valueInterface.IsEmpty() { + return true + } + return false + }) +} + +func (m *JSONObject) MarshalJSON() ([]byte, error) { + buffer := new(bytes.Buffer) + buffer.WriteString("{") + items := common.Filter(m.Entries(), func(it collections.MapEntry[string, any]) bool { + if valueInterface, valueMaybeEmpty := it.Value.(isEmpty); valueMaybeEmpty && valueInterface.IsEmpty() { + return false + } + return true + }) + iLen := 
len(items) + for i, entry := range items { + keyContent, err := json.Marshal(entry.Key) + if err != nil { + return nil, err + } + buffer.WriteString(strings.TrimSpace(string(keyContent))) + buffer.WriteString(": ") + valueContent, err := json.Marshal(entry.Value) + if err != nil { + return nil, err + } + buffer.WriteString(strings.TrimSpace(string(valueContent))) + if i < iLen-1 { + buffer.WriteString(", ") + } + } + buffer.WriteString("}") + return buffer.Bytes(), nil +} + +func (m *JSONObject) UnmarshalJSON(content []byte) error { + decoder := json.NewDecoder(bytes.NewReader(content)) + m.Clear() + objectStart, err := decoder.Token() + if err != nil { + return err + } else if objectStart != json.Delim('{') { + return E.New("expected json object start, but starts with ", objectStart) + } + err = m.decodeJSON(decoder) + if err != nil { + return E.Cause(err, "decode json object content") + } + objectEnd, err := decoder.Token() + if err != nil { + return err + } else if objectEnd != json.Delim('}') { + return E.New("expected json object end, but ends with ", objectEnd) + } + return nil +} + +func (m *JSONObject) decodeJSON(decoder *json.Decoder) error { + for decoder.More() { + var entryKey string + keyToken, err := decoder.Token() + if err != nil { + return err + } + entryKey = keyToken.(string) + var entryValue any + entryValue, err = decodeJSON(decoder) + if err != nil { + return E.Cause(err, "decode value for ", entryKey) + } + m.Put(entryKey, entryValue) + } + return nil +} diff --git a/common/json/badjson/typed.go b/common/json/badjson/typed.go new file mode 100644 index 0000000..66f41a6 --- /dev/null +++ b/common/json/badjson/typed.go @@ -0,0 +1,86 @@ +package badjson + +import ( + "bytes" + "strings" + + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/json" + "github.com/sagernet/sing/common/x/linkedhashmap" +) + +type TypedMap[K comparable, V any] struct { + linkedhashmap.Map[K, V] +} + +func (m TypedMap[K, V]) MarshalJSON() ([]byte, error) { + buffer := new(bytes.Buffer) + buffer.WriteString("{") + items := m.Entries() + iLen := len(items) + for i, entry := range items { + keyContent, err := json.Marshal(entry.Key) + if err != nil { + return nil, err + } + buffer.WriteString(strings.TrimSpace(string(keyContent))) + buffer.WriteString(": ") + valueContent, err := json.Marshal(entry.Value) + if err != nil { + return nil, err + } + buffer.WriteString(strings.TrimSpace(string(valueContent))) + if i < iLen-1 { + buffer.WriteString(", ") + } + } + buffer.WriteString("}") + return buffer.Bytes(), nil +} + +func (m *TypedMap[K, V]) UnmarshalJSON(content []byte) error { + decoder := json.NewDecoder(bytes.NewReader(content)) + m.Clear() + objectStart, err := decoder.Token() + if err != nil { + return err + } else if objectStart != json.Delim('{') { + return E.New("expected json object start, but starts with ", objectStart) + } + err = m.decodeJSON(decoder) + if err != nil { + return E.Cause(err, "decode json object content") + } + objectEnd, err := decoder.Token() + if err != nil { + return err + } else if objectEnd != json.Delim('}') { + return E.New("expected json object end, but ends with ", objectEnd) + } + return nil +} + +func (m *TypedMap[K, V]) decodeJSON(decoder *json.Decoder) error { + for decoder.More() { + keyToken, err := decoder.Token() + if err != nil { + return err + } + keyContent, err := json.Marshal(keyToken) + if err != nil { + return err + } + var entryKey K + err = json.Unmarshal(keyContent, &entryKey) + if err != nil { + return err + } 
+ var entryValue V + err = decoder.Decode(&entryValue) + if err != nil { + return err + } + m.Put(entryKey, entryValue) + } + return nil +} diff --git a/common/json/comment.go b/common/json/comment.go new file mode 100644 index 0000000..6f3be26 --- /dev/null +++ b/common/json/comment.go @@ -0,0 +1,128 @@ +package json + +import ( + "bufio" + "io" +) + +// kanged from v2ray + +type commentFilterState = byte + +const ( + commentFilterStateContent commentFilterState = iota + commentFilterStateEscape + commentFilterStateDoubleQuote + commentFilterStateDoubleQuoteEscape + commentFilterStateSingleQuote + commentFilterStateSingleQuoteEscape + commentFilterStateComment + commentFilterStateSlash + commentFilterStateMultilineComment + commentFilterStateMultilineCommentStar +) + +type CommentFilter struct { + br *bufio.Reader + state commentFilterState +} + +func NewCommentFilter(reader io.Reader) io.Reader { + return &CommentFilter{br: bufio.NewReader(reader)} +} + +func (v *CommentFilter) Read(b []byte) (int, error) { + p := b[:0] + for len(p) < len(b)-2 { + x, err := v.br.ReadByte() + if err != nil { + if len(p) == 0 { + return 0, err + } + return len(p), nil + } + switch v.state { + case commentFilterStateContent: + switch x { + case '"': + v.state = commentFilterStateDoubleQuote + p = append(p, x) + case '\'': + v.state = commentFilterStateSingleQuote + p = append(p, x) + case '\\': + v.state = commentFilterStateEscape + case '#': + v.state = commentFilterStateComment + case '/': + v.state = commentFilterStateSlash + default: + p = append(p, x) + } + case commentFilterStateEscape: + p = append(p, '\\', x) + v.state = commentFilterStateContent + case commentFilterStateDoubleQuote: + switch x { + case '"': + v.state = commentFilterStateContent + p = append(p, x) + case '\\': + v.state = commentFilterStateDoubleQuoteEscape + default: + p = append(p, x) + } + case commentFilterStateDoubleQuoteEscape: + p = append(p, '\\', x) + v.state = commentFilterStateDoubleQuote + case commentFilterStateSingleQuote: + switch x { + case '\'': + v.state = commentFilterStateContent + p = append(p, x) + case '\\': + v.state = commentFilterStateSingleQuoteEscape + default: + p = append(p, x) + } + case commentFilterStateSingleQuoteEscape: + p = append(p, '\\', x) + v.state = commentFilterStateSingleQuote + case commentFilterStateComment: + if x == '\n' { + v.state = commentFilterStateContent + p = append(p, '\n') + } + case commentFilterStateSlash: + switch x { + case '/': + v.state = commentFilterStateComment + case '*': + v.state = commentFilterStateMultilineComment + default: + p = append(p, '/', x) + } + case commentFilterStateMultilineComment: + switch x { + case '*': + v.state = commentFilterStateMultilineCommentStar + case '\n': + p = append(p, '\n') + } + case commentFilterStateMultilineCommentStar: + switch x { + case '/': + v.state = commentFilterStateContent + case '*': + // Stay + case '\n': + p = append(p, '\n') + default: + v.state = commentFilterStateMultilineComment + } + default: + panic("Unknown state.") + } + } + return len(p), nil +} diff --git a/common/json/context.go b/common/json/context.go new file mode 100644 index 0000000..7a49070 --- /dev/null +++ b/common/json/context.go @@ -0,0 +1,23 @@ +//go:build go1.20 && !without_contextjson + +package json + +import ( + "github.com/sagernet/sing/common/json/internal/contextjson" +) + +var ( + Marshal = json.Marshal + Unmarshal = json.Unmarshal + NewEncoder = json.NewEncoder + NewDecoder = json.NewDecoder +) + +type ( + Encoder = json.Encoder + 
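	// Editorial note (not part of the original diff): the CommentFilter added in
	// common/json/comment.go is meant to be layered in front of these decoder
	// types, for example NewDecoder(NewCommentFilter(r)), so that #, // and /* */
	// comments are stripped from a config stream before JSON parsing. That wiring
	// is an assumed usage, not something shown in this change.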
Decoder = json.Decoder + Token = json.Token + Delim = json.Delim + SyntaxError = json.SyntaxError + RawMessage = json.RawMessage +) diff --git a/common/json/internal/contextjson/README.md b/common/json/internal/contextjson/README.md new file mode 100644 index 0000000..da656b7 --- /dev/null +++ b/common/json/internal/contextjson/README.md @@ -0,0 +1,3 @@ +# contextjson + +mod from go1.21.4 diff --git a/common/json/internal/contextjson/decode.go b/common/json/internal/contextjson/decode.go new file mode 100644 index 0000000..8457171 --- /dev/null +++ b/common/json/internal/contextjson/decode.go @@ -0,0 +1,1325 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Represents JSON data structure using native Go types: booleans, floats, +// strings, arrays, and maps. + +package json + +import ( + "encoding" + "encoding/base64" + "fmt" + "reflect" + "strconv" + "strings" + "unicode" + "unicode/utf16" + "unicode/utf8" +) + +// Unmarshal parses the JSON-encoded data and stores the result +// in the value pointed to by v. If v is nil or not a pointer, +// Unmarshal returns an InvalidUnmarshalError. +// +// Unmarshal uses the inverse of the encodings that +// Marshal uses, allocating maps, slices, and pointers as necessary, +// with the following additional rules: +// +// To unmarshal JSON into a pointer, Unmarshal first handles the case of +// the JSON being the JSON literal null. In that case, Unmarshal sets +// the pointer to nil. Otherwise, Unmarshal unmarshals the JSON into +// the value pointed at by the pointer. If the pointer is nil, Unmarshal +// allocates a new value for it to point to. +// +// To unmarshal JSON into a value implementing the Unmarshaler interface, +// Unmarshal calls that value's UnmarshalJSON method, including +// when the input is a JSON null. +// Otherwise, if the value implements encoding.TextUnmarshaler +// and the input is a JSON quoted string, Unmarshal calls that value's +// UnmarshalText method with the unquoted form of the string. +// +// To unmarshal JSON into a struct, Unmarshal matches incoming object +// keys to the keys used by Marshal (either the struct field name or its tag), +// preferring an exact match but also accepting a case-insensitive match. By +// default, object keys which don't have a corresponding struct field are +// ignored (see Decoder.DisallowUnknownFields for an alternative). +// +// To unmarshal JSON into an interface value, +// Unmarshal stores one of these in the interface value: +// +// bool, for JSON booleans +// float64, for JSON numbers +// string, for JSON strings +// []interface{}, for JSON arrays +// map[string]interface{}, for JSON objects +// nil for JSON null +// +// To unmarshal a JSON array into a slice, Unmarshal resets the slice length +// to zero and then appends each element to the slice. +// As a special case, to unmarshal an empty JSON array into a slice, +// Unmarshal replaces the slice with a new empty slice. +// +// To unmarshal a JSON array into a Go array, Unmarshal decodes +// JSON array elements into corresponding Go array elements. +// If the Go array is smaller than the JSON array, +// the additional JSON array elements are discarded. +// If the JSON array is smaller than the Go array, +// the additional Go array elements are set to zero values. +// +// To unmarshal a JSON object into a map, Unmarshal first establishes a map to +// use. If the map is nil, Unmarshal allocates a new map. 
Otherwise Unmarshal +// reuses the existing map, keeping existing entries. Unmarshal then stores +// key-value pairs from the JSON object into the map. The map's key type must +// either be any string type, an integer, implement json.Unmarshaler, or +// implement encoding.TextUnmarshaler. +// +// If the JSON-encoded data contain a syntax error, Unmarshal returns a SyntaxError. +// +// If a JSON value is not appropriate for a given target type, +// or if a JSON number overflows the target type, Unmarshal +// skips that field and completes the unmarshaling as best it can. +// If no more serious errors are encountered, Unmarshal returns +// an UnmarshalTypeError describing the earliest such error. In any +// case, it's not guaranteed that all the remaining fields following +// the problematic one will be unmarshaled into the target object. +// +// The JSON null value unmarshals into an interface, map, pointer, or slice +// by setting that Go value to nil. Because null is often used in JSON to mean +// “not present,” unmarshaling a JSON null into any other Go type has no effect +// on the value and produces no error. +// +// When unmarshaling quoted strings, invalid UTF-8 or +// invalid UTF-16 surrogate pairs are not treated as an error. +// Instead, they are replaced by the Unicode replacement +// character U+FFFD. +func Unmarshal(data []byte, v any) error { + // Check for well-formedness. + // Avoids filling out half a data structure + // before discovering a JSON syntax error. + var d decodeState + err := checkValid(data, &d.scan) + if err != nil { + return err + } + + d.init(data) + return d.unmarshal(v) +} + +// Unmarshaler is the interface implemented by types +// that can unmarshal a JSON description of themselves. +// The input can be assumed to be a valid encoding of +// a JSON value. UnmarshalJSON must copy the JSON data +// if it wishes to retain the data after returning. +// +// By convention, to approximate the behavior of Unmarshal itself, +// Unmarshalers implement UnmarshalJSON([]byte("null")) as a no-op. +type Unmarshaler interface { + UnmarshalJSON([]byte) error +} + +// An UnmarshalTypeError describes a JSON value that was +// not appropriate for a value of a specific Go type. +type UnmarshalTypeError struct { + Value string // description of JSON value - "bool", "array", "number -5" + Type reflect.Type // type of Go value it could not be assigned to + Offset int64 // error occurred after reading Offset bytes + Struct string // name of the struct type containing the field + Field string // the full path from root node to the field +} + +func (e *UnmarshalTypeError) Error() string { + if e.Struct != "" || e.Field != "" { + return "json: cannot unmarshal " + e.Value + " into Go struct field " + e.Struct + "." + e.Field + " of type " + e.Type.String() + } + return "json: cannot unmarshal " + e.Value + " into Go value of type " + e.Type.String() +} + +// An UnmarshalFieldError describes a JSON object key that +// led to an unexported (and therefore unwritable) struct field. +// +// Deprecated: No longer used; kept for compatibility. +type UnmarshalFieldError struct { + Key string + Type reflect.Type + Field reflect.StructField +} + +func (e *UnmarshalFieldError) Error() string { + return "json: cannot unmarshal object key " + strconv.Quote(e.Key) + " into unexported field " + e.Field.Name + " of type " + e.Type.String() +} + +// An InvalidUnmarshalError describes an invalid argument passed to Unmarshal. +// (The argument to Unmarshal must be a non-nil pointer.) 
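// Editorial sketch (not part of the original diff), assuming this fork keeps the
// encoding/json calling convention: Unmarshal requires a non-nil pointer, and a
// plain value is rejected with *InvalidUnmarshalError. The names below are
// hypothetical.
func exampleUnmarshalPointerRequirement() {
	type listenConfig struct {
		Listen string `json:"listen"`
	}
	var cfg listenConfig
	_ = Unmarshal([]byte(`{"listen": "0.0.0.0"}`), &cfg) // ok: non-nil pointer
	err := Unmarshal([]byte(`{"listen": "::"}`), cfg)    // value, not a pointer
	_, rejected := err.(*InvalidUnmarshalError)
	_ = rejected // true
}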
+type InvalidUnmarshalError struct { + Type reflect.Type +} + +func (e *InvalidUnmarshalError) Error() string { + if e.Type == nil { + return "json: Unmarshal(nil)" + } + + if e.Type.Kind() != reflect.Pointer { + return "json: Unmarshal(non-pointer " + e.Type.String() + ")" + } + return "json: Unmarshal(nil " + e.Type.String() + ")" +} + +func (d *decodeState) unmarshal(v any) error { + rv := reflect.ValueOf(v) + if rv.Kind() != reflect.Pointer || rv.IsNil() { + return &InvalidUnmarshalError{reflect.TypeOf(v)} + } + + d.scan.reset() + d.scanWhile(scanSkipSpace) + // We decode rv not rv.Elem because the Unmarshaler interface + // test must be applied at the top level of the value. + err := d.value(rv) + if err != nil { + return d.addErrorContext(err) + } + return d.savedError +} + +// A Number represents a JSON number literal. +type Number string + +// String returns the literal text of the number. +func (n Number) String() string { return string(n) } + +// Float64 returns the number as a float64. +func (n Number) Float64() (float64, error) { + return strconv.ParseFloat(string(n), 64) +} + +// Int64 returns the number as an int64. +func (n Number) Int64() (int64, error) { + return strconv.ParseInt(string(n), 10, 64) +} + +// An errorContext provides context for type errors during decoding. +type errorContext struct { + Struct reflect.Type + FieldStack []string +} + +// decodeState represents the state while decoding a JSON value. +type decodeState struct { + data []byte + off int // next read offset in data + opcode int // last read result + scan scanner + errorContext *errorContext + savedError error + useNumber bool + disallowUnknownFields bool + context *decodeContext +} + +// readIndex returns the position of the last byte read. +func (d *decodeState) readIndex() int { + return d.off - 1 +} + +// phasePanicMsg is used as a panic message when we end up with something that +// shouldn't happen. It can indicate a bug in the JSON decoder, or that +// something is editing the data slice while the decoder executes. +const phasePanicMsg = "JSON decoder out of sync - data changing underfoot?" + +func (d *decodeState) init(data []byte) *decodeState { + d.data = data + d.off = 0 + d.savedError = nil + if d.errorContext != nil { + d.errorContext.Struct = nil + // Reuse the allocated space for the FieldStack slice. + d.errorContext.FieldStack = d.errorContext.FieldStack[:0] + } + return d +} + +// saveError saves the first err it is called with, +// for reporting at the end of the unmarshal. +func (d *decodeState) saveError(err error) { + if d.savedError == nil { + if d.context != nil { + d.savedError = d.addErrorContext(&contextError{err, d.formatContext(), d.context.key == ""}) + } else { + d.savedError = d.addErrorContext(err) + } + } +} + +// addErrorContext returns a new error enhanced with information from d.errorContext +func (d *decodeState) addErrorContext(err error) error { + if d.errorContext != nil && (d.errorContext.Struct != nil || len(d.errorContext.FieldStack) > 0) { + switch err := err.(type) { + case *UnmarshalTypeError: + err.Struct = d.errorContext.Struct.Name() + err.Field = strings.Join(d.errorContext.FieldStack, ".") + } + } + return err +} + +// skip scans to the end of what was started. +func (d *decodeState) skip() { + s, data, i := &d.scan, d.data, d.off + depth := len(s.parseState) + for { + op := s.step(s, data[i]) + i++ + if len(s.parseState) < depth { + d.off = i + d.opcode = op + return + } + } +} + +// scanNext processes the byte at d.data[d.off]. 
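// Editorial sketch (not part of the original diff): Number, defined above, keeps
// the raw JSON literal and only converts on request, so integers beyond float64
// precision survive until the caller chooses Int64.
func exampleNumberPrecision() {
	n := Number("9007199254740993") // 2^53 + 1
	i, _ := n.Int64()               // 9007199254740993, exact
	f, _ := n.Float64()             // 9.007199254740992e+15, rounded
	_, _ = i, f
}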
+func (d *decodeState) scanNext() { + if d.off < len(d.data) { + d.opcode = d.scan.step(&d.scan, d.data[d.off]) + d.off++ + } else { + d.opcode = d.scan.eof() + d.off = len(d.data) + 1 // mark processed EOF with len+1 + } +} + +// scanWhile processes bytes in d.data[d.off:] until it +// receives a scan code not equal to op. +func (d *decodeState) scanWhile(op int) { + s, data, i := &d.scan, d.data, d.off + for i < len(data) { + newOp := s.step(s, data[i]) + i++ + if newOp != op { + d.opcode = newOp + d.off = i + return + } + } + + d.off = len(data) + 1 // mark processed EOF with len+1 + d.opcode = d.scan.eof() +} + +// rescanLiteral is similar to scanWhile(scanContinue), but it specialises the +// common case where we're decoding a literal. The decoder scans the input +// twice, once for syntax errors and to check the length of the value, and the +// second to perform the decoding. +// +// Only in the second step do we use decodeState to tokenize literals, so we +// know there aren't any syntax errors. We can take advantage of that knowledge, +// and scan a literal's bytes much more quickly. +func (d *decodeState) rescanLiteral() { + data, i := d.data, d.off +Switch: + switch data[i-1] { + case '"': // string + for ; i < len(data); i++ { + switch data[i] { + case '\\': + i++ // escaped char + case '"': + i++ // tokenize the closing quote too + break Switch + } + } + case '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '-': // number + for ; i < len(data); i++ { + switch data[i] { + case '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', + '.', 'e', 'E', '+', '-': + default: + break Switch + } + } + case 't': // true + i += len("rue") + case 'f': // false + i += len("alse") + case 'n': // null + i += len("ull") + } + if i < len(data) { + d.opcode = stateEndValue(&d.scan, data[i]) + } else { + d.opcode = scanEnd + } + d.off = i + 1 +} + +// value consumes a JSON value from d.data[d.off-1:], decoding into v, and +// reads the following byte ahead. If v is invalid, the value is discarded. +// The first byte of the value has been read already. +func (d *decodeState) value(v reflect.Value) error { + switch d.opcode { + default: + panic(phasePanicMsg) + + case scanBeginArray: + if v.IsValid() { + if err := d.array(v); err != nil { + return err + } + } else { + d.skip() + } + d.scanNext() + + case scanBeginObject: + if v.IsValid() { + if err := d.object(v); err != nil { + return err + } + } else { + d.skip() + } + d.scanNext() + + case scanBeginLiteral: + // All bytes inside literal return scanContinue op code. + start := d.readIndex() + d.rescanLiteral() + + if v.IsValid() { + if err := d.literalStore(d.data[start:d.readIndex()], v, false); err != nil { + return err + } + } + } + return nil +} + +type unquotedValue struct{} + +// valueQuoted is like value but decodes a +// quoted string literal or literal null into an interface value. +// If it finds anything other than a quoted string literal or null, +// valueQuoted returns unquotedValue{}. +func (d *decodeState) valueQuoted() any { + switch d.opcode { + default: + panic(phasePanicMsg) + + case scanBeginArray, scanBeginObject: + d.skip() + d.scanNext() + + case scanBeginLiteral: + v := d.literalInterface() + switch v.(type) { + case nil, string: + return v + } + } + return unquotedValue{} +} + +// indirect walks down v allocating pointers as needed, +// until it gets to a non-pointer. +// If it encounters an Unmarshaler, indirect stops and returns that. 
+// If decodingNull is true, indirect stops at the first settable pointer so it +// can be set to nil. +func indirect(v reflect.Value, decodingNull bool) (Unmarshaler, encoding.TextUnmarshaler, reflect.Value) { + // Issue #24153 indicates that it is generally not a guaranteed property + // that you may round-trip a reflect.Value by calling Value.Addr().Elem() + // and expect the value to still be settable for values derived from + // unexported embedded struct fields. + // + // The logic below effectively does this when it first addresses the value + // (to satisfy possible pointer methods) and continues to dereference + // subsequent pointers as necessary. + // + // After the first round-trip, we set v back to the original value to + // preserve the original RW flags contained in reflect.Value. + v0 := v + haveAddr := false + + // If v is a named type and is addressable, + // start with its address, so that if the type has pointer methods, + // we find them. + if v.Kind() != reflect.Pointer && v.Type().Name() != "" && v.CanAddr() { + haveAddr = true + v = v.Addr() + } + for { + // Load value from interface, but only if the result will be + // usefully addressable. + if v.Kind() == reflect.Interface && !v.IsNil() { + e := v.Elem() + if e.Kind() == reflect.Pointer && !e.IsNil() && (!decodingNull || e.Elem().Kind() == reflect.Pointer) { + haveAddr = false + v = e + continue + } + } + + if v.Kind() != reflect.Pointer { + break + } + + if decodingNull && v.CanSet() { + break + } + + // Prevent infinite loop if v is an interface pointing to its own address: + // var v interface{} + // v = &v + if v.Elem().Kind() == reflect.Interface && v.Elem().Elem() == v { + v = v.Elem() + break + } + if v.IsNil() { + v.Set(reflect.New(v.Type().Elem())) + } + if v.Type().NumMethod() > 0 && v.CanInterface() { + if u, ok := v.Interface().(Unmarshaler); ok { + return u, nil, reflect.Value{} + } + if !decodingNull { + if u, ok := v.Interface().(encoding.TextUnmarshaler); ok { + return nil, u, reflect.Value{} + } + } + } + + if haveAddr { + v = v0 // restore original value after round-trip Value.Addr().Elem() + haveAddr = false + } else { + v = v.Elem() + } + } + return nil, nil, v +} + +// array consumes an array from d.data[d.off-1:], decoding into v. +// The first byte of the array ('[') has been read already. +func (d *decodeState) array(v reflect.Value) error { + // Check for unmarshaler. + u, ut, pv := indirect(v, false) + if u != nil { + start := d.readIndex() + d.skip() + err := u.UnmarshalJSON(d.data[start:d.off]) + if err != nil { + d.saveError(err) + } + return nil + } + if ut != nil { + d.saveError(&UnmarshalTypeError{Value: "array", Type: v.Type(), Offset: int64(d.off)}) + d.skip() + return nil + } + v = pv + + // Check type of target. + switch v.Kind() { + case reflect.Interface: + if v.NumMethod() == 0 { + // Decoding into nil interface? Switch to non-reflect code. + ai := d.arrayInterface() + v.Set(reflect.ValueOf(ai)) + return nil + } + // Otherwise it's invalid. + fallthrough + default: + d.saveError(&UnmarshalTypeError{Value: "array", Type: v.Type(), Offset: int64(d.off)}) + d.skip() + return nil + case reflect.Array, reflect.Slice: + break + } + + i := 0 + d.context = &decodeContext{parent: d.context} + for { + // Look ahead for ] - can only happen on first iteration. + d.scanWhile(scanSkipSpace) + if d.opcode == scanEndArray { + break + } + + // Expand slice length, growing the slice if necessary. 
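+			// (reflect.Value.Grow was added in Go 1.20, in line with the go1.20 build
+			// constraint on common/json/context.go.)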
+ if v.Kind() == reflect.Slice { + if i >= v.Cap() { + v.Grow(1) + } + if i >= v.Len() { + v.SetLen(i + 1) + } + } + + if i < v.Len() { + // Decode into element. + if err := d.value(v.Index(i)); err != nil { + return err + } + } else { + // Ran out of fixed array: skip. + if err := d.value(reflect.Value{}); err != nil { + return err + } + } + i++ + + // Next token must be , or ]. + if d.opcode == scanSkipSpace { + d.scanWhile(scanSkipSpace) + } + if d.opcode == scanEndArray { + break + } + if d.opcode != scanArrayValue { + panic(phasePanicMsg) + } + d.context.index++ + } + + d.context = d.context.parent + + if i < v.Len() { + if v.Kind() == reflect.Array { + for ; i < v.Len(); i++ { + v.Index(i).SetZero() // zero remainder of array + } + } else { + v.SetLen(i) // truncate the slice + } + } + if i == 0 && v.Kind() == reflect.Slice { + v.Set(reflect.MakeSlice(v.Type(), 0, 0)) + } + return nil +} + +var ( + nullLiteral = []byte("null") + textUnmarshalerType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem() +) + +// object consumes an object from d.data[d.off-1:], decoding into v. +// The first byte ('{') of the object has been read already. +func (d *decodeState) object(v reflect.Value) error { + // Check for unmarshaler. + u, ut, pv := indirect(v, false) + if u != nil { + start := d.readIndex() + d.skip() + err := u.UnmarshalJSON(d.data[start:d.off]) + if err != nil { + d.saveError(err) + } + return nil + } + if ut != nil { + d.saveError(&UnmarshalTypeError{Value: "object", Type: v.Type(), Offset: int64(d.off)}) + d.skip() + return nil + } + v = pv + t := v.Type() + + // Decoding into nil interface? Switch to non-reflect code. + if v.Kind() == reflect.Interface && v.NumMethod() == 0 { + oi := d.objectInterface() + v.Set(reflect.ValueOf(oi)) + return nil + } + + var fields structFields + + // Check type of target: + // struct or + // map[T1]T2 where T1 is string, an integer type, + // or an encoding.TextUnmarshaler + switch v.Kind() { + case reflect.Map: + // Map key must either have string kind, have an integer kind, + // or be an encoding.TextUnmarshaler. + switch t.Key().Kind() { + case reflect.String, + reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + default: + if !reflect.PointerTo(t.Key()).Implements(textUnmarshalerType) { + d.saveError(&UnmarshalTypeError{Value: "object", Type: t, Offset: int64(d.off)}) + d.skip() + return nil + } + } + if v.IsNil() { + v.Set(reflect.MakeMap(t)) + } + case reflect.Struct: + fields = cachedTypeFields(t) + // ok + default: + d.saveError(&UnmarshalTypeError{Value: "object", Type: t, Offset: int64(d.off)}) + d.skip() + return nil + } + + var mapElem reflect.Value + var origErrorContext errorContext + if d.errorContext != nil { + origErrorContext = *d.errorContext + } + + d.context = &decodeContext{parent: d.context} + for { + // Read opening " of string key or closing }. + d.scanWhile(scanSkipSpace) + if d.opcode == scanEndObject { + // closing } - can only happen on first iteration. + break + } + if d.opcode != scanBeginLiteral { + panic(phasePanicMsg) + } + + // Read key. + start := d.readIndex() + d.rescanLiteral() + item := d.data[start:d.readIndex()] + key, ok := unquoteBytes(item) + if !ok { + panic(phasePanicMsg) + } + d.context.key = string(key) + + // Figure out field corresponding to key. 
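+		// (d.context.key above is one of this fork's additions: it records the current
+		// member name so that decode errors can be reported with a JSON path; see
+		// decode_context.go.)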
+ var subv reflect.Value + destring := false // whether the value is wrapped in a string to be decoded first + + if v.Kind() == reflect.Map { + elemType := t.Elem() + if !mapElem.IsValid() { + mapElem = reflect.New(elemType).Elem() + } else { + mapElem.SetZero() + } + subv = mapElem + } else { + f := fields.byExactName[string(key)] + if f == nil { + f = fields.byFoldedName[string(foldName(key))] + } + if f != nil { + subv = v + destring = f.quoted + for _, i := range f.index { + if subv.Kind() == reflect.Pointer { + if subv.IsNil() { + // If a struct embeds a pointer to an unexported type, + // it is not possible to set a newly allocated value + // since the field is unexported. + // + // See https://golang.org/issue/21357 + if !subv.CanSet() { + d.saveError(fmt.Errorf("json: cannot set embedded pointer to unexported struct: %v", subv.Type().Elem())) + // Invalidate subv to ensure d.value(subv) skips over + // the JSON value without assigning it to subv. + subv = reflect.Value{} + destring = false + break + } + subv.Set(reflect.New(subv.Type().Elem())) + } + subv = subv.Elem() + } + subv = subv.Field(i) + } + if d.errorContext == nil { + d.errorContext = new(errorContext) + } + d.errorContext.FieldStack = append(d.errorContext.FieldStack, f.name) + d.errorContext.Struct = t + } else if d.disallowUnknownFields { + d.saveError(fmt.Errorf("json: unknown field %q", key)) + } + d.context.index++ + } + + // Read : before value. + if d.opcode == scanSkipSpace { + d.scanWhile(scanSkipSpace) + } + if d.opcode != scanObjectKey { + panic(phasePanicMsg) + } + d.scanWhile(scanSkipSpace) + + if destring { + switch qv := d.valueQuoted().(type) { + case nil: + if err := d.literalStore(nullLiteral, subv, false); err != nil { + return err + } + case string: + if err := d.literalStore([]byte(qv), subv, true); err != nil { + return err + } + default: + d.saveError(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal unquoted value into %v", subv.Type())) + } + } else { + if err := d.value(subv); err != nil { + return err + } + } + + // Write value back to map; + // if using struct, subv points into struct already. + if v.Kind() == reflect.Map { + kt := t.Key() + var kv reflect.Value + switch { + case reflect.PointerTo(kt).Implements(textUnmarshalerType): + kv = reflect.New(kt) + if err := d.literalStore(item, kv, true); err != nil { + return err + } + kv = kv.Elem() + case kt.Kind() == reflect.String: + kv = reflect.ValueOf(key).Convert(kt) + default: + switch kt.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + s := string(key) + n, err := strconv.ParseInt(s, 10, 64) + if err != nil || reflect.Zero(kt).OverflowInt(n) { + d.saveError(&UnmarshalTypeError{Value: "number " + s, Type: kt, Offset: int64(start + 1)}) + break + } + kv = reflect.ValueOf(n).Convert(kt) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + s := string(key) + n, err := strconv.ParseUint(s, 10, 64) + if err != nil || reflect.Zero(kt).OverflowUint(n) { + d.saveError(&UnmarshalTypeError{Value: "number " + s, Type: kt, Offset: int64(start + 1)}) + break + } + kv = reflect.ValueOf(n).Convert(kt) + default: + panic("json: Unexpected key type") // should never occur + } + } + if kv.IsValid() { + v.SetMapIndex(kv, subv) + } + } + + // Next token must be , or }. + if d.opcode == scanSkipSpace { + d.scanWhile(scanSkipSpace) + } + if d.errorContext != nil { + // Reset errorContext to its original state. 
+ // Keep the same underlying array for FieldStack, to reuse the + // space and avoid unnecessary allocs. + d.errorContext.FieldStack = d.errorContext.FieldStack[:len(origErrorContext.FieldStack)] + d.errorContext.Struct = origErrorContext.Struct + } + if d.opcode == scanEndObject { + break + } + if d.opcode != scanObjectValue { + panic(phasePanicMsg) + } + } + d.context = d.context.parent + return nil +} + +// convertNumber converts the number literal s to a float64 or a Number +// depending on the setting of d.useNumber. +func (d *decodeState) convertNumber(s string) (any, error) { + if d.useNumber { + return Number(s), nil + } + f, err := strconv.ParseFloat(s, 64) + if err != nil { + return nil, &UnmarshalTypeError{Value: "number " + s, Type: reflect.TypeOf(0.0), Offset: int64(d.off)} + } + return f, nil +} + +var numberType = reflect.TypeOf(Number("")) + +// literalStore decodes a literal stored in item into v. +// +// fromQuoted indicates whether this literal came from unwrapping a +// string from the ",string" struct tag option. this is used only to +// produce more helpful error messages. +func (d *decodeState) literalStore(item []byte, v reflect.Value, fromQuoted bool) error { + // Check for unmarshaler. + if len(item) == 0 { + // Empty string given + d.saveError(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type())) + return nil + } + isNull := item[0] == 'n' // null + u, ut, pv := indirect(v, isNull) + if u != nil { + err := u.UnmarshalJSON(item) + if err != nil { + d.saveError(err) + } + return nil + } + if ut != nil { + if item[0] != '"' { + if fromQuoted { + d.saveError(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type())) + return nil + } + val := "number" + switch item[0] { + case 'n': + val = "null" + case 't', 'f': + val = "bool" + } + d.saveError(&UnmarshalTypeError{Value: val, Type: v.Type(), Offset: int64(d.readIndex())}) + return nil + } + s, ok := unquoteBytes(item) + if !ok { + if fromQuoted { + return fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type()) + } + panic(phasePanicMsg) + } + return ut.UnmarshalText(s) + } + + v = pv + + switch c := item[0]; c { + case 'n': // null + // The main parser checks that only true and false can reach here, + // but if this was a quoted string input, it could be anything. + if fromQuoted && string(item) != "null" { + d.saveError(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type())) + break + } + switch v.Kind() { + case reflect.Interface, reflect.Pointer, reflect.Map, reflect.Slice: + v.SetZero() + // otherwise, ignore null for primitives/string + } + case 't', 'f': // true, false + value := item[0] == 't' + // The main parser checks that only true and false can reach here, + // but if this was a quoted string input, it could be anything. 
+ if fromQuoted && string(item) != "true" && string(item) != "false" { + d.saveError(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type())) + break + } + switch v.Kind() { + default: + if fromQuoted { + d.saveError(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type())) + } else { + d.saveError(&UnmarshalTypeError{Value: "bool", Type: v.Type(), Offset: int64(d.readIndex())}) + } + case reflect.Bool: + v.SetBool(value) + case reflect.Interface: + if v.NumMethod() == 0 { + v.Set(reflect.ValueOf(value)) + } else { + d.saveError(&UnmarshalTypeError{Value: "bool", Type: v.Type(), Offset: int64(d.readIndex())}) + } + } + + case '"': // string + s, ok := unquoteBytes(item) + if !ok { + if fromQuoted { + return fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type()) + } + panic(phasePanicMsg) + } + switch v.Kind() { + default: + d.saveError(&UnmarshalTypeError{Value: "string", Type: v.Type(), Offset: int64(d.readIndex())}) + case reflect.Slice: + if v.Type().Elem().Kind() != reflect.Uint8 { + d.saveError(&UnmarshalTypeError{Value: "string", Type: v.Type(), Offset: int64(d.readIndex())}) + break + } + b := make([]byte, base64.StdEncoding.DecodedLen(len(s))) + n, err := base64.StdEncoding.Decode(b, s) + if err != nil { + d.saveError(err) + break + } + v.SetBytes(b[:n]) + case reflect.String: + if v.Type() == numberType && !isValidNumber(string(s)) { + return fmt.Errorf("json: invalid number literal, trying to unmarshal %q into Number", item) + } + v.SetString(string(s)) + case reflect.Interface: + if v.NumMethod() == 0 { + v.Set(reflect.ValueOf(string(s))) + } else { + d.saveError(&UnmarshalTypeError{Value: "string", Type: v.Type(), Offset: int64(d.readIndex())}) + } + } + + default: // number + if c != '-' && (c < '0' || c > '9') { + if fromQuoted { + return fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type()) + } + panic(phasePanicMsg) + } + s := string(item) + switch v.Kind() { + default: + if v.Kind() == reflect.String && v.Type() == numberType { + // s must be a valid number, because it's + // already been tokenized. 
+ v.SetString(s) + break + } + if fromQuoted { + return fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type()) + } + d.saveError(&UnmarshalTypeError{Value: "number", Type: v.Type(), Offset: int64(d.readIndex())}) + case reflect.Interface: + n, err := d.convertNumber(s) + if err != nil { + d.saveError(err) + break + } + if v.NumMethod() != 0 { + d.saveError(&UnmarshalTypeError{Value: "number", Type: v.Type(), Offset: int64(d.readIndex())}) + break + } + v.Set(reflect.ValueOf(n)) + + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + n, err := strconv.ParseInt(s, 10, 64) + if err != nil || v.OverflowInt(n) { + d.saveError(&UnmarshalTypeError{Value: "number " + s, Type: v.Type(), Offset: int64(d.readIndex())}) + break + } + v.SetInt(n) + + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + n, err := strconv.ParseUint(s, 10, 64) + if err != nil || v.OverflowUint(n) { + d.saveError(&UnmarshalTypeError{Value: "number " + s, Type: v.Type(), Offset: int64(d.readIndex())}) + break + } + v.SetUint(n) + + case reflect.Float32, reflect.Float64: + n, err := strconv.ParseFloat(s, v.Type().Bits()) + if err != nil || v.OverflowFloat(n) { + d.saveError(&UnmarshalTypeError{Value: "number " + s, Type: v.Type(), Offset: int64(d.readIndex())}) + break + } + v.SetFloat(n) + } + } + return nil +} + +// The xxxInterface routines build up a value to be stored +// in an empty interface. They are not strictly necessary, +// but they avoid the weight of reflection in this common case. + +// valueInterface is like value but returns interface{} +func (d *decodeState) valueInterface() (val any) { + switch d.opcode { + default: + panic(phasePanicMsg) + case scanBeginArray: + val = d.arrayInterface() + d.scanNext() + case scanBeginObject: + val = d.objectInterface() + d.scanNext() + case scanBeginLiteral: + val = d.literalInterface() + } + return +} + +// arrayInterface is like array but returns []interface{}. +func (d *decodeState) arrayInterface() []any { + v := make([]any, 0) + d.context = &decodeContext{parent: d.context} + for { + // Look ahead for ] - can only happen on first iteration. + d.scanWhile(scanSkipSpace) + if d.opcode == scanEndArray { + break + } + + v = append(v, d.valueInterface()) + + // Next token must be , or ]. + if d.opcode == scanSkipSpace { + d.scanWhile(scanSkipSpace) + } + if d.opcode == scanEndArray { + break + } + if d.opcode != scanArrayValue { + panic(phasePanicMsg) + } + d.context.index++ + } + d.context = d.context.parent + return v +} + +// objectInterface is like object but returns map[string]interface{}. +func (d *decodeState) objectInterface() map[string]any { + m := make(map[string]any) + d.context = &decodeContext{parent: d.context} + for { + // Read opening " of string key or closing }. + d.scanWhile(scanSkipSpace) + if d.opcode == scanEndObject { + // closing } - can only happen on first iteration. + break + } + if d.opcode != scanBeginLiteral { + panic(phasePanicMsg) + } + + // Read string key. + start := d.readIndex() + d.rescanLiteral() + item := d.data[start:d.readIndex()] + key, ok := unquote(item) + if !ok { + panic(phasePanicMsg) + } + d.context.key = key + + // Read : before value. + if d.opcode == scanSkipSpace { + d.scanWhile(scanSkipSpace) + } + if d.opcode != scanObjectKey { + panic(phasePanicMsg) + } + d.scanWhile(scanSkipSpace) + + // Read value. + m[key] = d.valueInterface() + + // Next token must be , or }. 
+ if d.opcode == scanSkipSpace { + d.scanWhile(scanSkipSpace) + } + if d.opcode == scanEndObject { + break + } + if d.opcode != scanObjectValue { + panic(phasePanicMsg) + } + d.context.index++ + } + d.context = d.context.parent + return m +} + +// literalInterface consumes and returns a literal from d.data[d.off-1:] and +// it reads the following byte ahead. The first byte of the literal has been +// read already (that's how the caller knows it's a literal). +func (d *decodeState) literalInterface() any { + // All bytes inside literal return scanContinue op code. + start := d.readIndex() + d.rescanLiteral() + + item := d.data[start:d.readIndex()] + + switch c := item[0]; c { + case 'n': // null + return nil + + case 't', 'f': // true, false + return c == 't' + + case '"': // string + s, ok := unquote(item) + if !ok { + panic(phasePanicMsg) + } + return s + + default: // number + if c != '-' && (c < '0' || c > '9') { + panic(phasePanicMsg) + } + n, err := d.convertNumber(string(item)) + if err != nil { + d.saveError(err) + } + return n + } +} + +// getu4 decodes \uXXXX from the beginning of s, returning the hex value, +// or it returns -1. +func getu4(s []byte) rune { + if len(s) < 6 || s[0] != '\\' || s[1] != 'u' { + return -1 + } + var r rune + for _, c := range s[2:6] { + switch { + case '0' <= c && c <= '9': + c = c - '0' + case 'a' <= c && c <= 'f': + c = c - 'a' + 10 + case 'A' <= c && c <= 'F': + c = c - 'A' + 10 + default: + return -1 + } + r = r*16 + rune(c) + } + return r +} + +// unquote converts a quoted JSON string literal s into an actual string t. +// The rules are different than for Go, so cannot use strconv.Unquote. +func unquote(s []byte) (t string, ok bool) { + s, ok = unquoteBytes(s) + t = string(s) + return +} + +func unquoteBytes(s []byte) (t []byte, ok bool) { + if len(s) < 2 || s[0] != '"' || s[len(s)-1] != '"' { + return + } + s = s[1 : len(s)-1] + + // Check for unusual characters. If there are none, + // then no unquoting is needed, so return a slice of the + // original bytes. + r := 0 + for r < len(s) { + c := s[r] + if c == '\\' || c == '"' || c < ' ' { + break + } + if c < utf8.RuneSelf { + r++ + continue + } + rr, size := utf8.DecodeRune(s[r:]) + if rr == utf8.RuneError && size == 1 { + break + } + r += size + } + if r == len(s) { + return s, true + } + + b := make([]byte, len(s)+2*utf8.UTFMax) + w := copy(b, s[0:r]) + for r < len(s) { + // Out of room? Can only happen if s is full of + // malformed UTF-8 and we're replacing each + // byte with RuneError. + if w >= len(b)-2*utf8.UTFMax { + nb := make([]byte, (len(b)+utf8.UTFMax)*2) + copy(nb, b[0:w]) + b = nb + } + switch c := s[r]; { + case c == '\\': + r++ + if r >= len(s) { + return + } + switch s[r] { + default: + return + case '"', '\\', '/', '\'': + b[w] = s[r] + r++ + w++ + case 'b': + b[w] = '\b' + r++ + w++ + case 'f': + b[w] = '\f' + r++ + w++ + case 'n': + b[w] = '\n' + r++ + w++ + case 'r': + b[w] = '\r' + r++ + w++ + case 't': + b[w] = '\t' + r++ + w++ + case 'u': + r-- + rr := getu4(s[r:]) + if rr < 0 { + return + } + r += 6 + if utf16.IsSurrogate(rr) { + rr1 := getu4(s[r:]) + if dec := utf16.DecodeRune(rr, rr1); dec != unicode.ReplacementChar { + // A valid pair; consume. + r += 6 + w += utf8.EncodeRune(b[w:], dec) + break + } + // Invalid surrogate; fall back to replacement rune. + rr = unicode.ReplacementChar + } + w += utf8.EncodeRune(b[w:], rr) + } + + // Quote, control characters are invalid. 
+ case c == '"', c < ' ': + return + + // ASCII + case c < utf8.RuneSelf: + b[w] = c + r++ + w++ + + // Coerce to well-formed UTF-8. + default: + rr, size := utf8.DecodeRune(s[r:]) + r += size + w += utf8.EncodeRune(b[w:], rr) + } + } + return b[0:w], true +} diff --git a/common/json/internal/contextjson/decode_context.go b/common/json/internal/contextjson/decode_context.go new file mode 100644 index 0000000..3000926 --- /dev/null +++ b/common/json/internal/contextjson/decode_context.go @@ -0,0 +1,49 @@ +package json + +import "strconv" + +type decodeContext struct { + parent *decodeContext + index int + key string +} + +func (d *decodeState) formatContext() string { + var description string + context := d.context + var appendDot bool + for context != nil { + if appendDot { + description = "." + description + } + if context.key != "" { + description = context.key + description + appendDot = true + } else { + description = "[" + strconv.Itoa(context.index) + "]" + description + appendDot = false + } + context = context.parent + } + return description +} + +type contextError struct { + parent error + context string + index bool +} + +func (c *contextError) Unwrap() error { + return c.parent +} + +func (c *contextError) Error() string { + //goland:noinspection GoTypeAssertionOnErrors + switch c.parent.(type) { + case *contextError: + return c.context + "." + c.parent.Error() + default: + return c.context + ": " + c.parent.Error() + } +} diff --git a/common/json/internal/contextjson/encode.go b/common/json/internal/contextjson/encode.go new file mode 100644 index 0000000..296177a --- /dev/null +++ b/common/json/internal/contextjson/encode.go @@ -0,0 +1,1283 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package json implements encoding and decoding of JSON as defined in +// RFC 7159. The mapping between JSON and Go values is described +// in the documentation for the Marshal and Unmarshal functions. +// +// See "JSON and Go" for an introduction to this package: +// https://golang.org/doc/articles/json_and_go.html +package json + +import ( + "bytes" + "encoding" + "encoding/base64" + "fmt" + "math" + "reflect" + "sort" + "strconv" + "strings" + "sync" + "unicode" + "unicode/utf8" +) + +// Marshal returns the JSON encoding of v. +// +// Marshal traverses the value v recursively. +// If an encountered value implements the Marshaler interface +// and is not a nil pointer, Marshal calls its MarshalJSON method +// to produce JSON. If no MarshalJSON method is present but the +// value implements encoding.TextMarshaler instead, Marshal calls +// its MarshalText method and encodes the result as a JSON string. +// The nil pointer exception is not strictly necessary +// but mimics a similar, necessary exception in the behavior of +// UnmarshalJSON. +// +// Otherwise, Marshal uses the following type-dependent default encodings: +// +// Boolean values encode as JSON booleans. +// +// Floating point, integer, and Number values encode as JSON numbers. +// NaN and +/-Inf values will return an [UnsupportedValueError]. +// +// String values encode as JSON strings coerced to valid UTF-8, +// replacing invalid bytes with the Unicode replacement rune. +// So that the JSON will be safe to embed inside HTML
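// Editorial sketch (not part of the original diff): what the decodeContext and
// contextError machinery in decode_context.go buys. Each nesting level records
// the member key (or array index), and saveError wraps the first decode error,
// so a type mismatch deep inside a document surfaces with a dotted-path prefix
// (roughly "users[1].age: json: cannot unmarshal string into ...") rather than
// only a byte offset. The struct below and the quoted message are illustrative,
// not taken from the library or its tests.
func exampleContextualDecodeError() {
	type user struct {
		Name string `json:"name"`
		Age  int    `json:"age"`
	}
	var v struct {
		Users []user `json:"users"`
	}
	err := Unmarshal([]byte(`{"users": [{"age": 1}, {"age": "x"}]}`), &v)
	_ = err // non-nil; its message carries the path to the offending member
}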