Compare commits

..

No commits in common. "dev" and "v0.1.8" have entirely different histories.
dev ... v0.1.8

271 changed files with 1867 additions and 19623 deletions

View file

@ -5,9 +5,6 @@
"config:base",
":disableRateLimiting"
],
"baseBranches": [
"dev"
],
"packageRules": [
{
"matchManagers": [
@ -17,9 +14,9 @@
},
{
"matchManagers": [
"dockerfile"
"gomod"
],
"groupName": "Dockerfile"
"groupName": "gomod"
}
]
}

43
.github/workflows/debug.yml vendored Normal file
View file

@ -0,0 +1,43 @@
name: Debug build
on:
push:
branches:
- dev
paths-ignore:
- '**.md'
- '.github/**'
- '!.github/workflows/debug.yml'
pull_request:
branches:
- dev
jobs:
build:
name: Debug build
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v3
with:
fetch-depth: 0
- name: Get latest go version
id: version
run: |
echo ::set-output name=go_version::$(curl -s https://raw.githubusercontent.com/actions/go-versions/main/versions-manifest.json | grep -oE '"version": "[0-9]{1}.[0-9]{1,}(.[0-9]{1,})?"' | head -1 | cut -d':' -f2 | sed 's/ //g; s/"//g')
- name: Setup Go
uses: actions/setup-go@v3
with:
go-version: ${{ steps.version.outputs.go_version }}
- name: Add cache to Go proxy
run: |
version=`git rev-parse HEAD`
mkdir build
pushd build
go mod init build
go get -v github.com/sagernet/sing@$version
popd
continue-on-error: true
- name: Build
run: |
make test

View file

@ -1,9 +1,8 @@
name: lint
name: Lint
on:
push:
branches:
- main
- dev
paths-ignore:
- '**.md'
@ -11,7 +10,6 @@ on:
- '!.github/workflows/lint.yml'
pull_request:
branches:
- main
- dev
jobs:
@ -20,20 +18,24 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
uses: actions/checkout@v3
with:
fetch-depth: 0
- name: Get latest go version
id: version
run: |
echo ::set-output name=go_version::$(curl -s https://raw.githubusercontent.com/actions/go-versions/main/versions-manifest.json | grep -oE '"version": "[0-9]{1}.[0-9]{1,}(.[0-9]{1,})?"' | head -1 | cut -d':' -f2 | sed 's/ //g; s/"//g')
- name: Setup Go
uses: actions/setup-go@v5
uses: actions/setup-go@v3
with:
go-version: ^1.23
go-version: ${{ steps.version.outputs.go_version }}
- name: Cache go module
uses: actions/cache@v4
uses: actions/cache@v3
with:
path: |
~/go/pkg/mod
key: go-${{ hashFiles('**/go.sum') }}
- name: golangci-lint
uses: golangci/golangci-lint-action@v6
uses: golangci/golangci-lint-action@v3
with:
version: latest

View file

@ -1,112 +0,0 @@
name: test
on:
push:
branches:
- main
- dev
paths-ignore:
- '**.md'
- '.github/**'
- '!.github/workflows/debug.yml'
pull_request:
branches:
- main
- dev
jobs:
build:
name: Linux
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Setup Go
uses: actions/setup-go@v5
with:
go-version: ^1.23
- name: Build
run: |
make test
build_go120:
name: Linux (Go 1.20)
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Setup Go
uses: actions/setup-go@v5
with:
go-version: ~1.20
continue-on-error: true
- name: Build
run: |
make test
build_go121:
name: Linux (Go 1.21)
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Setup Go
uses: actions/setup-go@v5
with:
go-version: ~1.21
continue-on-error: true
- name: Build
run: |
make test
build_go122:
name: Linux (Go 1.22)
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Setup Go
uses: actions/setup-go@v5
with:
go-version: ~1.22
continue-on-error: true
- name: Build
run: |
make test
build_windows:
name: Windows
runs-on: windows-latest
steps:
- name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Setup Go
uses: actions/setup-go@v5
with:
go-version: ^1.23
continue-on-error: true
- name: Build
run: |
make test
build_darwin:
name: macOS
runs-on: macos-latest
steps:
- name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Setup Go
uses: actions/setup-go@v5
with:
go-version: ^1.23
continue-on-error: true
- name: Build
run: |
make test

1
.gitignore vendored
View file

@ -1,3 +1,2 @@
/.idea/
/vendor/
.DS_Store

View file

@ -5,8 +5,6 @@ linters:
- govet
- gci
- staticcheck
- paralleltest
- ineffassign
linters-settings:
gci:
@ -16,9 +14,4 @@ linters-settings:
- prefix(github.com/sagernet/)
- default
staticcheck:
checks:
- all
- -SA1003
run:
go: "1.23"
go: '1.20'

View file

@ -1,21 +1,21 @@
fmt:
@gofumpt -l -w .
@gofmt -s -w .
@gci write --custom-order -s standard -s "prefix(github.com/sagernet/)" -s "default" .
@gci write --custom-order -s "standard,prefix(github.com/sagernet/),default" .
fmt_install:
go install -v mvdan.cc/gofumpt@latest
go install -v github.com/daixiang0/gci@latest
lint:
GOOS=linux golangci-lint run
GOOS=android golangci-lint run
GOOS=windows golangci-lint run
GOOS=darwin golangci-lint run
GOOS=freebsd golangci-lint run
GOOS=linux golangci-lint run ./...
GOOS=android golangci-lint run ./...
GOOS=windows golangci-lint run ./...
GOOS=darwin golangci-lint run ./...
GOOS=freebsd golangci-lint run ./...
lint_install:
go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest
test:
go test ./...
go test -v ./...

View file

@ -1,6 +1,3 @@
# sing
![test](https://github.com/sagernet/sing/actions/workflows/test.yml/badge.svg)
![lint](https://github.com/sagernet/sing/actions/workflows/lint.yml/badge.svg)
Do you hear the people sing?

View file

@ -18,6 +18,7 @@ var _ xml.TokenReader = (*Reader)(nil)
type Reader struct {
reader *bytes.Reader
stringRefs []string
attrs []xml.Attr
}
func NewReader(content []byte) (xml.TokenReader, bool) {
@ -46,7 +47,7 @@ func (r *Reader) Token() (token xml.Token, err error) {
return
}
var attrs []xml.Attr
attrs, err = r.readAttributes()
attrs, err = r.pullAttributes()
if err != nil {
return
}
@ -92,41 +93,35 @@ func (r *Reader) Token() (token xml.Token, err error) {
_, err = r.readUTF()
return
case ATTRIBUTE:
_, err = r.readAttribute()
return
return nil, E.New("unexpected attribute")
}
return nil, E.New("unknown token type ", tokenType, " with type ", eventType)
}
func (r *Reader) readAttributes() ([]xml.Attr, error) {
var attrs []xml.Attr
for {
attr, err := r.readAttribute()
if err == io.EOF {
break
}
attrs = append(attrs, attr)
func (r *Reader) pullAttributes() ([]xml.Attr, error) {
err := r.pullAttribute()
if err != nil {
return nil, err
}
attrs := r.attrs
r.attrs = nil
return attrs, nil
}
func (r *Reader) readAttribute() (xml.Attr, error) {
func (r *Reader) pullAttribute() error {
event, err := r.reader.ReadByte()
if err != nil {
return xml.Attr{}, nil
return nil
}
tokenType := event & 0x0f
eventType := event & 0xf0
if tokenType != ATTRIBUTE {
err = r.reader.UnreadByte()
if err != nil {
return xml.Attr{}, nil
}
return xml.Attr{}, io.EOF
return r.reader.UnreadByte()
}
name, err := r.readInternedUTF()
var name string
name, err = r.readInternedUTF()
if err != nil {
return xml.Attr{}, err
return err
}
var value string
switch eventType {
@ -139,73 +134,74 @@ func (r *Reader) readAttribute() (xml.Attr, error) {
case TypeString:
value, err = r.readUTF()
if err != nil {
return xml.Attr{}, err
return err
}
case TypeStringInterned:
value, err = r.readInternedUTF()
if err != nil {
return xml.Attr{}, err
return err
}
case TypeBytesHex:
var data []byte
data, err = r.readBytes()
if err != nil {
return xml.Attr{}, err
return err
}
value = hex.EncodeToString(data)
case TypeBytesBase64:
var data []byte
data, err = r.readBytes()
if err != nil {
return xml.Attr{}, err
return err
}
value = base64.StdEncoding.EncodeToString(data)
case TypeInt:
var data int32
err = binary.Read(r.reader, binary.BigEndian, &data)
if err != nil {
return xml.Attr{}, err
return err
}
value = strconv.FormatInt(int64(data), 10)
case TypeIntHex:
var data int32
err = binary.Read(r.reader, binary.BigEndian, &data)
if err != nil {
return xml.Attr{}, err
return err
}
value = "0x" + strconv.FormatInt(int64(data), 16)
case TypeLong:
var data int64
err = binary.Read(r.reader, binary.BigEndian, &data)
if err != nil {
return xml.Attr{}, err
return err
}
value = strconv.FormatInt(data, 10)
case TypeLongHex:
var data int64
err = binary.Read(r.reader, binary.BigEndian, &data)
if err != nil {
return xml.Attr{}, err
return err
}
value = "0x" + strconv.FormatInt(data, 16)
case TypeFloat:
var data float32
err = binary.Read(r.reader, binary.BigEndian, &data)
if err != nil {
return xml.Attr{}, err
return err
}
value = strconv.FormatFloat(float64(data), 'g', -1, 32)
case TypeDouble:
var data float64
err = binary.Read(r.reader, binary.BigEndian, &data)
if err != nil {
return xml.Attr{}, err
return err
}
value = strconv.FormatFloat(data, 'g', -1, 64)
default:
return xml.Attr{}, E.New("unexpected attribute type, ", eventType)
return E.New("unexpected attribute type, ", eventType)
}
return xml.Attr{Name: xml.Name{Local: name}, Value: value}, nil
r.attrs = append(r.attrs, xml.Attr{Name: xml.Name{Local: name}, Value: value})
return r.pullAttribute()
}
func (r *Reader) readUnsignedShort() (uint16, error) {

View file

@ -1,46 +0,0 @@
package atomic
import (
"sync/atomic"
"github.com/sagernet/sing/common"
)
type TypedValue[T any] struct {
value atomic.Value
}
// typedValue is a struct with determined type to resolve atomic.Value usages with interface types
// https://github.com/golang/go/issues/22550
//
// The intention to have an atomic value store for errors. However, running this code panics:
// panic: sync/atomic: store of inconsistently typed value into Value
// This is because atomic.Value requires that the underlying concrete type be the same (which is a reasonable expectation for its implementation).
// When going through the atomic.Value.Store method call, the fact that both these are of the error interface is lost.
type typedValue[T any] struct {
value T
}
func (t *TypedValue[T]) Load() T {
value := t.value.Load()
if value == nil {
return common.DefaultValue[T]()
}
return value.(typedValue[T]).value
}
func (t *TypedValue[T]) Store(value T) {
t.value.Store(typedValue[T]{value})
}
func (t *TypedValue[T]) Swap(new T) T {
old := t.value.Swap(typedValue[T]{new})
if old == nil {
return common.DefaultValue[T]()
}
return old.(typedValue[T]).value
}
func (t *TypedValue[T]) CompareAndSwap(old, new T) bool {
return t.value.CompareAndSwap(typedValue[T]{old}, typedValue[T]{new})
}

View file

@ -1,19 +0,0 @@
//go:build go1.19
package atomic
import "sync/atomic"
type (
Bool = atomic.Bool
Int32 = atomic.Int32
Int64 = atomic.Int64
Uint32 = atomic.Uint32
Uint64 = atomic.Uint64
Uintptr = atomic.Uintptr
Value = atomic.Value
)
type Pointer[T any] struct {
atomic.Pointer[T]
}

View file

@ -1,198 +0,0 @@
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !go1.19
package atomic
import (
"sync/atomic"
"unsafe"
)
// A Bool is an atomic boolean value.
// The zero value is false.
type Bool struct {
_ noCopy
v uint32
}
// Load atomically loads and returns the value stored in x.
func (x *Bool) Load() bool { return atomic.LoadUint32(&x.v) != 0 }
// Store atomically stores val into x.
func (x *Bool) Store(val bool) { atomic.StoreUint32(&x.v, b32(val)) }
// Swap atomically stores new into x and returns the previous value.
func (x *Bool) Swap(new bool) (old bool) { return atomic.SwapUint32(&x.v, b32(new)) != 0 }
// CompareAndSwap executes the compare-and-swap operation for the boolean value x.
func (x *Bool) CompareAndSwap(old, new bool) (swapped bool) {
return atomic.CompareAndSwapUint32(&x.v, b32(old), b32(new))
}
// b32 returns a uint32 0 or 1 representing b.
func b32(b bool) uint32 {
if b {
return 1
}
return 0
}
// A Pointer is an atomic pointer of type *T. The zero value is a nil *T.
type Pointer[T any] struct {
// Mention *T in a field to disallow conversion between Pointer types.
// See go.dev/issue/56603 for more details.
// Use *T, not T, to avoid spurious recursive type definition errors.
_ [0]*T
_ noCopy
v unsafe.Pointer
}
// Load atomically loads and returns the value stored in x.
func (x *Pointer[T]) Load() *T { return (*T)(atomic.LoadPointer(&x.v)) }
// Store atomically stores val into x.
func (x *Pointer[T]) Store(val *T) { atomic.StorePointer(&x.v, unsafe.Pointer(val)) }
// Swap atomically stores new into x and returns the previous value.
func (x *Pointer[T]) Swap(new *T) (old *T) {
return (*T)(atomic.SwapPointer(&x.v, unsafe.Pointer(new)))
}
// CompareAndSwap executes the compare-and-swap operation for x.
func (x *Pointer[T]) CompareAndSwap(old, new *T) (swapped bool) {
return atomic.CompareAndSwapPointer(&x.v, unsafe.Pointer(old), unsafe.Pointer(new))
}
// An Int32 is an atomic int32. The zero value is zero.
type Int32 struct {
_ noCopy
v int32
}
// Load atomically loads and returns the value stored in x.
func (x *Int32) Load() int32 { return atomic.LoadInt32(&x.v) }
// Store atomically stores val into x.
func (x *Int32) Store(val int32) { atomic.StoreInt32(&x.v, val) }
// Swap atomically stores new into x and returns the previous value.
func (x *Int32) Swap(new int32) (old int32) { return atomic.SwapInt32(&x.v, new) }
// CompareAndSwap executes the compare-and-swap operation for x.
func (x *Int32) CompareAndSwap(old, new int32) (swapped bool) {
return atomic.CompareAndSwapInt32(&x.v, old, new)
}
// Add atomically adds delta to x and returns the new value.
func (x *Int32) Add(delta int32) (new int32) { return atomic.AddInt32(&x.v, delta) }
// An Int64 is an atomic int64. The zero value is zero.
type Int64 struct {
_ noCopy
v int64
}
// Load atomically loads and returns the value stored in x.
func (x *Int64) Load() int64 { return atomic.LoadInt64(&x.v) }
// Store atomically stores val into x.
func (x *Int64) Store(val int64) { atomic.StoreInt64(&x.v, val) }
// Swap atomically stores new into x and returns the previous value.
func (x *Int64) Swap(new int64) (old int64) { return atomic.SwapInt64(&x.v, new) }
// CompareAndSwap executes the compare-and-swap operation for x.
func (x *Int64) CompareAndSwap(old, new int64) (swapped bool) {
return atomic.CompareAndSwapInt64(&x.v, old, new)
}
// Add atomically adds delta to x and returns the new value.
func (x *Int64) Add(delta int64) (new int64) { return atomic.AddInt64(&x.v, delta) }
// An Uint32 is an atomic uint32. The zero value is zero.
type Uint32 struct {
_ noCopy
v uint32
}
// Load atomically loads and returns the value stored in x.
func (x *Uint32) Load() uint32 { return atomic.LoadUint32(&x.v) }
// Store atomically stores val into x.
func (x *Uint32) Store(val uint32) { atomic.StoreUint32(&x.v, val) }
// Swap atomically stores new into x and returns the previous value.
func (x *Uint32) Swap(new uint32) (old uint32) { return atomic.SwapUint32(&x.v, new) }
// CompareAndSwap executes the compare-and-swap operation for x.
func (x *Uint32) CompareAndSwap(old, new uint32) (swapped bool) {
return atomic.CompareAndSwapUint32(&x.v, old, new)
}
// Add atomically adds delta to x and returns the new value.
func (x *Uint32) Add(delta uint32) (new uint32) { return atomic.AddUint32(&x.v, delta) }
// An Uint64 is an atomic uint64. The zero value is zero.
type Uint64 struct {
_ noCopy
v uint64
}
// Load atomically loads and returns the value stored in x.
func (x *Uint64) Load() uint64 { return atomic.LoadUint64(&x.v) }
// Store atomically stores val into x.
func (x *Uint64) Store(val uint64) { atomic.StoreUint64(&x.v, val) }
// Swap atomically stores new into x and returns the previous value.
func (x *Uint64) Swap(new uint64) (old uint64) { return atomic.SwapUint64(&x.v, new) }
// CompareAndSwap executes the compare-and-swap operation for x.
func (x *Uint64) CompareAndSwap(old, new uint64) (swapped bool) {
return atomic.CompareAndSwapUint64(&x.v, old, new)
}
// Add atomically adds delta to x and returns the new value.
func (x *Uint64) Add(delta uint64) (new uint64) { return atomic.AddUint64(&x.v, delta) }
// An Uintptr is an atomic uintptr. The zero value is zero.
type Uintptr struct {
_ noCopy
v uintptr
}
// Load atomically loads and returns the value stored in x.
func (x *Uintptr) Load() uintptr { return atomic.LoadUintptr(&x.v) }
// Store atomically stores val into x.
func (x *Uintptr) Store(val uintptr) { atomic.StoreUintptr(&x.v, val) }
// Swap atomically stores new into x and returns the previous value.
func (x *Uintptr) Swap(new uintptr) (old uintptr) { return atomic.SwapUintptr(&x.v, new) }
// CompareAndSwap executes the compare-and-swap operation for x.
func (x *Uintptr) CompareAndSwap(old, new uintptr) (swapped bool) {
return atomic.CompareAndSwapUintptr(&x.v, old, new)
}
// Add atomically adds delta to x and returns the new value.
func (x *Uintptr) Add(delta uintptr) (new uintptr) { return atomic.AddUintptr(&x.v, delta) }
// noCopy may be added to structs which must not be copied
// after the first use.
//
// See https://golang.org/issues/8005#issuecomment-190753527
// for details.
//
// Note that it must not be embedded, due to the Lock and Unlock methods.
type noCopy struct{}
// Lock is a no-op used by -copylocks checker from `go vet`.
func (*noCopy) Lock() {}
func (*noCopy) Unlock() {}
type Value = atomic.Value

View file

@ -1,30 +1,38 @@
package auth
import "github.com/sagernet/sing/common"
type Authenticator interface {
Verify(user string, pass string) bool
Users() []string
}
type User struct {
Username string
Password string
Username string `json:"username"`
Password string `json:"password"`
}
type Authenticator struct {
userMap map[string][]string
type inMemoryAuthenticator struct {
storage map[string]string
usernames []string
}
func NewAuthenticator(users []User) *Authenticator {
func (au *inMemoryAuthenticator) Verify(username string, password string) bool {
realPass, ok := au.storage[username]
return ok && realPass == password
}
func (au *inMemoryAuthenticator) Users() []string { return au.usernames }
func NewAuthenticator(users []User) Authenticator {
if len(users) == 0 {
return nil
}
au := &Authenticator{
userMap: make(map[string][]string),
au := &inMemoryAuthenticator{
storage: make(map[string]string),
usernames: make([]string, 0, len(users)),
}
for _, user := range users {
au.userMap[user.Username] = append(au.userMap[user.Username], user.Password)
au.storage[user.Username] = user.Password
au.usernames = append(au.usernames, user.Username)
}
return au
}
func (au *Authenticator) Verify(username string, password string) bool {
passwordList, ok := au.userMap[username]
return ok && common.Contains(passwordList, password)
}

View file

@ -1,63 +0,0 @@
package baderror
import (
"context"
"errors"
"io"
"net"
"strings"
)
func Contains(err error, msgList ...string) bool {
for _, msg := range msgList {
if strings.Contains(err.Error(), msg) {
return true
}
}
return false
}
func WrapH2(err error) error {
if err == nil {
return nil
}
if errors.Is(err, io.ErrUnexpectedEOF) {
return io.EOF
}
if Contains(err, "client disconnected", "body closed by handler", "response body closed", "; CANCEL") {
return net.ErrClosed
}
return err
}
func WrapGRPC(err error) error {
// grpc uses stupid internal error types
if err == nil {
return nil
}
if Contains(err, "EOF") {
return io.EOF
}
if Contains(err, "Canceled") {
return context.Canceled
}
if Contains(err,
"the client connection is closing",
"server closed the stream without sending trailers") {
return net.ErrClosed
}
return err
}
func WrapQUIC(err error) error {
if err == nil {
return nil
}
if Contains(err,
"canceled by remote with error code 0",
"canceled by local with error code 0",
) {
return net.ErrClosed
}
return err
}

View file

@ -3,9 +3,6 @@ package batch
import (
"context"
"sync"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
)
type Option[T any] func(b *Batch[T])
@ -20,10 +17,6 @@ type Error struct {
Err error
}
func (e *Error) Error() string {
return E.Cause(e.Err, e.Key).Error()
}
func WithConcurrencyNum[T any](n int) Option[T] {
return func(b *Batch[T]) {
q := make(chan struct{}, n)
@ -42,7 +35,7 @@ type Batch[T any] struct {
mux sync.Mutex
err *Error
once sync.Once
cancel common.ContextCancelCauseFunc
cancel func()
}
func (b *Batch[T]) Go(key string, fn func() (T, error)) {
@ -61,7 +54,7 @@ func (b *Batch[T]) Go(key string, fn func() (T, error)) {
b.once.Do(func() {
b.err = &Error{key, err}
if b.cancel != nil {
b.cancel(b.err)
b.cancel()
}
})
}
@ -76,7 +69,7 @@ func (b *Batch[T]) Go(key string, fn func() (T, error)) {
func (b *Batch[T]) Wait() *Error {
b.wg.Wait()
if b.cancel != nil {
b.cancel(nil)
b.cancel()
}
return b.err
}
@ -97,7 +90,7 @@ func (b *Batch[T]) Result() map[string]Result[T] {
}
func New[T any](ctx context.Context, opts ...Option[T]) (*Batch[T], context.Context) {
ctx, cancel := common.ContextWithCancelCause(ctx)
ctx, cancel := context.WithCancel(ctx)
b := &Batch[T]{
result: map[string]Result[T]{},

View file

@ -1,3 +0,0 @@
# binary
mod from go 1.22.3

View file

@ -1,817 +0,0 @@
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package binary implements simple translation between numbers and byte
// sequences and encoding and decoding of varints.
//
// Numbers are translated by reading and writing fixed-size values.
// A fixed-size value is either a fixed-size arithmetic
// type (bool, int8, uint8, int16, float32, complex64, ...)
// or an array or struct containing only fixed-size values.
//
// The varint functions encode and decode single integer values using
// a variable-length encoding; smaller values require fewer bytes.
// For a specification, see
// https://developers.google.com/protocol-buffers/docs/encoding.
//
// This package favors simplicity over efficiency. Clients that require
// high-performance serialization, especially for large data structures,
// should look at more advanced solutions such as the [encoding/gob]
// package or [google.golang.org/protobuf] for protocol buffers.
package binary
import (
"errors"
"io"
"math"
"reflect"
"sync"
)
// A ByteOrder specifies how to convert byte slices into
// 16-, 32-, or 64-bit unsigned integers.
//
// It is implemented by [LittleEndian], [BigEndian], and [NativeEndian].
type ByteOrder interface {
Uint16([]byte) uint16
Uint32([]byte) uint32
Uint64([]byte) uint64
PutUint16([]byte, uint16)
PutUint32([]byte, uint32)
PutUint64([]byte, uint64)
String() string
}
// AppendByteOrder specifies how to append 16-, 32-, or 64-bit unsigned integers
// into a byte slice.
//
// It is implemented by [LittleEndian], [BigEndian], and [NativeEndian].
type AppendByteOrder interface {
AppendUint16([]byte, uint16) []byte
AppendUint32([]byte, uint32) []byte
AppendUint64([]byte, uint64) []byte
String() string
}
// LittleEndian is the little-endian implementation of [ByteOrder] and [AppendByteOrder].
var LittleEndian littleEndian
// BigEndian is the big-endian implementation of [ByteOrder] and [AppendByteOrder].
var BigEndian bigEndian
type littleEndian struct{}
func (littleEndian) Uint16(b []byte) uint16 {
_ = b[1] // bounds check hint to compiler; see golang.org/issue/14808
return uint16(b[0]) | uint16(b[1])<<8
}
func (littleEndian) PutUint16(b []byte, v uint16) {
_ = b[1] // early bounds check to guarantee safety of writes below
b[0] = byte(v)
b[1] = byte(v >> 8)
}
func (littleEndian) AppendUint16(b []byte, v uint16) []byte {
return append(b,
byte(v),
byte(v>>8),
)
}
func (littleEndian) Uint32(b []byte) uint32 {
_ = b[3] // bounds check hint to compiler; see golang.org/issue/14808
return uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24
}
func (littleEndian) PutUint32(b []byte, v uint32) {
_ = b[3] // early bounds check to guarantee safety of writes below
b[0] = byte(v)
b[1] = byte(v >> 8)
b[2] = byte(v >> 16)
b[3] = byte(v >> 24)
}
func (littleEndian) AppendUint32(b []byte, v uint32) []byte {
return append(b,
byte(v),
byte(v>>8),
byte(v>>16),
byte(v>>24),
)
}
func (littleEndian) Uint64(b []byte) uint64 {
_ = b[7] // bounds check hint to compiler; see golang.org/issue/14808
return uint64(b[0]) | uint64(b[1])<<8 | uint64(b[2])<<16 | uint64(b[3])<<24 |
uint64(b[4])<<32 | uint64(b[5])<<40 | uint64(b[6])<<48 | uint64(b[7])<<56
}
func (littleEndian) PutUint64(b []byte, v uint64) {
_ = b[7] // early bounds check to guarantee safety of writes below
b[0] = byte(v)
b[1] = byte(v >> 8)
b[2] = byte(v >> 16)
b[3] = byte(v >> 24)
b[4] = byte(v >> 32)
b[5] = byte(v >> 40)
b[6] = byte(v >> 48)
b[7] = byte(v >> 56)
}
func (littleEndian) AppendUint64(b []byte, v uint64) []byte {
return append(b,
byte(v),
byte(v>>8),
byte(v>>16),
byte(v>>24),
byte(v>>32),
byte(v>>40),
byte(v>>48),
byte(v>>56),
)
}
func (littleEndian) String() string { return "LittleEndian" }
func (littleEndian) GoString() string { return "binary.LittleEndian" }
type bigEndian struct{}
func (bigEndian) Uint16(b []byte) uint16 {
_ = b[1] // bounds check hint to compiler; see golang.org/issue/14808
return uint16(b[1]) | uint16(b[0])<<8
}
func (bigEndian) PutUint16(b []byte, v uint16) {
_ = b[1] // early bounds check to guarantee safety of writes below
b[0] = byte(v >> 8)
b[1] = byte(v)
}
func (bigEndian) AppendUint16(b []byte, v uint16) []byte {
return append(b,
byte(v>>8),
byte(v),
)
}
func (bigEndian) Uint32(b []byte) uint32 {
_ = b[3] // bounds check hint to compiler; see golang.org/issue/14808
return uint32(b[3]) | uint32(b[2])<<8 | uint32(b[1])<<16 | uint32(b[0])<<24
}
func (bigEndian) PutUint32(b []byte, v uint32) {
_ = b[3] // early bounds check to guarantee safety of writes below
b[0] = byte(v >> 24)
b[1] = byte(v >> 16)
b[2] = byte(v >> 8)
b[3] = byte(v)
}
func (bigEndian) AppendUint32(b []byte, v uint32) []byte {
return append(b,
byte(v>>24),
byte(v>>16),
byte(v>>8),
byte(v),
)
}
func (bigEndian) Uint64(b []byte) uint64 {
_ = b[7] // bounds check hint to compiler; see golang.org/issue/14808
return uint64(b[7]) | uint64(b[6])<<8 | uint64(b[5])<<16 | uint64(b[4])<<24 |
uint64(b[3])<<32 | uint64(b[2])<<40 | uint64(b[1])<<48 | uint64(b[0])<<56
}
func (bigEndian) PutUint64(b []byte, v uint64) {
_ = b[7] // early bounds check to guarantee safety of writes below
b[0] = byte(v >> 56)
b[1] = byte(v >> 48)
b[2] = byte(v >> 40)
b[3] = byte(v >> 32)
b[4] = byte(v >> 24)
b[5] = byte(v >> 16)
b[6] = byte(v >> 8)
b[7] = byte(v)
}
func (bigEndian) AppendUint64(b []byte, v uint64) []byte {
return append(b,
byte(v>>56),
byte(v>>48),
byte(v>>40),
byte(v>>32),
byte(v>>24),
byte(v>>16),
byte(v>>8),
byte(v),
)
}
func (bigEndian) String() string { return "BigEndian" }
func (bigEndian) GoString() string { return "binary.BigEndian" }
func (nativeEndian) String() string { return "NativeEndian" }
func (nativeEndian) GoString() string { return "binary.NativeEndian" }
// Read reads structured binary data from r into data.
// Data must be a pointer to a fixed-size value or a slice
// of fixed-size values.
// Bytes read from r are decoded using the specified byte order
// and written to successive fields of the data.
// When decoding boolean values, a zero byte is decoded as false, and
// any other non-zero byte is decoded as true.
// When reading into structs, the field data for fields with
// blank (_) field names is skipped; i.e., blank field names
// may be used for padding.
// When reading into a struct, all non-blank fields must be exported
// or Read may panic.
//
// The error is [io.EOF] only if no bytes were read.
// If an [io.EOF] happens after reading some but not all the bytes,
// Read returns [io.ErrUnexpectedEOF].
func Read(r io.Reader, order ByteOrder, data any) error {
// Fast path for basic types and slices.
if n := intDataSize(data); n != 0 {
bs := make([]byte, n)
if _, err := io.ReadFull(r, bs); err != nil {
return err
}
switch data := data.(type) {
case *bool:
*data = bs[0] != 0
case *int8:
*data = int8(bs[0])
case *uint8:
*data = bs[0]
case *int16:
*data = int16(order.Uint16(bs))
case *uint16:
*data = order.Uint16(bs)
case *int32:
*data = int32(order.Uint32(bs))
case *uint32:
*data = order.Uint32(bs)
case *int64:
*data = int64(order.Uint64(bs))
case *uint64:
*data = order.Uint64(bs)
case *float32:
*data = math.Float32frombits(order.Uint32(bs))
case *float64:
*data = math.Float64frombits(order.Uint64(bs))
case []bool:
for i, x := range bs { // Easier to loop over the input for 8-bit values.
data[i] = x != 0
}
case []int8:
for i, x := range bs {
data[i] = int8(x)
}
case []uint8:
copy(data, bs)
case []int16:
for i := range data {
data[i] = int16(order.Uint16(bs[2*i:]))
}
case []uint16:
for i := range data {
data[i] = order.Uint16(bs[2*i:])
}
case []int32:
for i := range data {
data[i] = int32(order.Uint32(bs[4*i:]))
}
case []uint32:
for i := range data {
data[i] = order.Uint32(bs[4*i:])
}
case []int64:
for i := range data {
data[i] = int64(order.Uint64(bs[8*i:]))
}
case []uint64:
for i := range data {
data[i] = order.Uint64(bs[8*i:])
}
case []float32:
for i := range data {
data[i] = math.Float32frombits(order.Uint32(bs[4*i:]))
}
case []float64:
for i := range data {
data[i] = math.Float64frombits(order.Uint64(bs[8*i:]))
}
default:
n = 0 // fast path doesn't apply
}
if n != 0 {
return nil
}
}
// Fallback to reflect-based decoding.
v := reflect.ValueOf(data)
size := -1
switch v.Kind() {
case reflect.Pointer:
v = v.Elem()
size = dataSize(v)
case reflect.Slice:
size = dataSize(v)
}
if size < 0 {
return errors.New("binary.Read: invalid type " + reflect.TypeOf(data).String())
}
d := &decoder{order: order, buf: make([]byte, size)}
if _, err := io.ReadFull(r, d.buf); err != nil {
return err
}
d.value(v)
return nil
}
// Write writes the binary representation of data into w.
// Data must be a fixed-size value or a slice of fixed-size
// values, or a pointer to such data.
// Boolean values encode as one byte: 1 for true, and 0 for false.
// Bytes written to w are encoded using the specified byte order
// and read from successive fields of the data.
// When writing structs, zero values are written for fields
// with blank (_) field names.
func Write(w io.Writer, order ByteOrder, data any) error {
// Fast path for basic types and slices.
if n := intDataSize(data); n != 0 {
bs := make([]byte, n)
switch v := data.(type) {
case *bool:
if *v {
bs[0] = 1
} else {
bs[0] = 0
}
case bool:
if v {
bs[0] = 1
} else {
bs[0] = 0
}
case []bool:
for i, x := range v {
if x {
bs[i] = 1
} else {
bs[i] = 0
}
}
case *int8:
bs[0] = byte(*v)
case int8:
bs[0] = byte(v)
case []int8:
for i, x := range v {
bs[i] = byte(x)
}
case *uint8:
bs[0] = *v
case uint8:
bs[0] = v
case []uint8:
bs = v
case *int16:
order.PutUint16(bs, uint16(*v))
case int16:
order.PutUint16(bs, uint16(v))
case []int16:
for i, x := range v {
order.PutUint16(bs[2*i:], uint16(x))
}
case *uint16:
order.PutUint16(bs, *v)
case uint16:
order.PutUint16(bs, v)
case []uint16:
for i, x := range v {
order.PutUint16(bs[2*i:], x)
}
case *int32:
order.PutUint32(bs, uint32(*v))
case int32:
order.PutUint32(bs, uint32(v))
case []int32:
for i, x := range v {
order.PutUint32(bs[4*i:], uint32(x))
}
case *uint32:
order.PutUint32(bs, *v)
case uint32:
order.PutUint32(bs, v)
case []uint32:
for i, x := range v {
order.PutUint32(bs[4*i:], x)
}
case *int64:
order.PutUint64(bs, uint64(*v))
case int64:
order.PutUint64(bs, uint64(v))
case []int64:
for i, x := range v {
order.PutUint64(bs[8*i:], uint64(x))
}
case *uint64:
order.PutUint64(bs, *v)
case uint64:
order.PutUint64(bs, v)
case []uint64:
for i, x := range v {
order.PutUint64(bs[8*i:], x)
}
case *float32:
order.PutUint32(bs, math.Float32bits(*v))
case float32:
order.PutUint32(bs, math.Float32bits(v))
case []float32:
for i, x := range v {
order.PutUint32(bs[4*i:], math.Float32bits(x))
}
case *float64:
order.PutUint64(bs, math.Float64bits(*v))
case float64:
order.PutUint64(bs, math.Float64bits(v))
case []float64:
for i, x := range v {
order.PutUint64(bs[8*i:], math.Float64bits(x))
}
}
_, err := w.Write(bs)
return err
}
// Fallback to reflect-based encoding.
v := reflect.Indirect(reflect.ValueOf(data))
size := dataSize(v)
if size < 0 {
return errors.New("binary.Write: some values are not fixed-sized in type " + reflect.TypeOf(data).String())
}
buf := make([]byte, size)
e := &encoder{order: order, buf: buf}
e.value(v)
_, err := w.Write(buf)
return err
}
// Size returns how many bytes [Write] would generate to encode the value v, which
// must be a fixed-size value or a slice of fixed-size values, or a pointer to such data.
// If v is neither of these, Size returns -1.
func Size(v any) int {
return dataSize(reflect.Indirect(reflect.ValueOf(v)))
}
var structSize sync.Map // map[reflect.Type]int
// dataSize returns the number of bytes the actual data represented by v occupies in memory.
// For compound structures, it sums the sizes of the elements. Thus, for instance, for a slice
// it returns the length of the slice times the element size and does not count the memory
// occupied by the header. If the type of v is not acceptable, dataSize returns -1.
func dataSize(v reflect.Value) int {
switch v.Kind() {
case reflect.Slice:
if s := sizeof(v.Type().Elem()); s >= 0 {
return s * v.Len()
}
case reflect.Struct:
t := v.Type()
if size, ok := structSize.Load(t); ok {
return size.(int)
}
size := sizeof(t)
structSize.Store(t, size)
return size
default:
if v.IsValid() {
return sizeof(v.Type())
}
}
return -1
}
// sizeof returns the size >= 0 of variables for the given type or -1 if the type is not acceptable.
func sizeof(t reflect.Type) int {
switch t.Kind() {
case reflect.Array:
if s := sizeof(t.Elem()); s >= 0 {
return s * t.Len()
}
case reflect.Struct:
sum := 0
for i, n := 0, t.NumField(); i < n; i++ {
s := sizeof(t.Field(i).Type)
if s < 0 {
return -1
}
sum += s
}
return sum
case reflect.Bool,
reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128:
return int(t.Size())
}
return -1
}
type coder struct {
order ByteOrder
buf []byte
offset int
}
type (
decoder coder
encoder coder
)
func (d *decoder) bool() bool {
x := d.buf[d.offset]
d.offset++
return x != 0
}
func (e *encoder) bool(x bool) {
if x {
e.buf[e.offset] = 1
} else {
e.buf[e.offset] = 0
}
e.offset++
}
func (d *decoder) uint8() uint8 {
x := d.buf[d.offset]
d.offset++
return x
}
func (e *encoder) uint8(x uint8) {
e.buf[e.offset] = x
e.offset++
}
func (d *decoder) uint16() uint16 {
x := d.order.Uint16(d.buf[d.offset : d.offset+2])
d.offset += 2
return x
}
func (e *encoder) uint16(x uint16) {
e.order.PutUint16(e.buf[e.offset:e.offset+2], x)
e.offset += 2
}
func (d *decoder) uint32() uint32 {
x := d.order.Uint32(d.buf[d.offset : d.offset+4])
d.offset += 4
return x
}
func (e *encoder) uint32(x uint32) {
e.order.PutUint32(e.buf[e.offset:e.offset+4], x)
e.offset += 4
}
func (d *decoder) uint64() uint64 {
x := d.order.Uint64(d.buf[d.offset : d.offset+8])
d.offset += 8
return x
}
func (e *encoder) uint64(x uint64) {
e.order.PutUint64(e.buf[e.offset:e.offset+8], x)
e.offset += 8
}
func (d *decoder) int8() int8 { return int8(d.uint8()) }
func (e *encoder) int8(x int8) { e.uint8(uint8(x)) }
func (d *decoder) int16() int16 { return int16(d.uint16()) }
func (e *encoder) int16(x int16) { e.uint16(uint16(x)) }
func (d *decoder) int32() int32 { return int32(d.uint32()) }
func (e *encoder) int32(x int32) { e.uint32(uint32(x)) }
func (d *decoder) int64() int64 { return int64(d.uint64()) }
func (e *encoder) int64(x int64) { e.uint64(uint64(x)) }
func (d *decoder) value(v reflect.Value) {
switch v.Kind() {
case reflect.Array:
l := v.Len()
for i := 0; i < l; i++ {
d.value(v.Index(i))
}
case reflect.Struct:
t := v.Type()
l := v.NumField()
for i := 0; i < l; i++ {
// Note: Calling v.CanSet() below is an optimization.
// It would be sufficient to check the field name,
// but creating the StructField info for each field is
// costly (run "go test -bench=ReadStruct" and compare
// results when making changes to this code).
if v := v.Field(i); v.CanSet() || t.Field(i).Name != "_" {
d.value(v)
} else {
d.skip(v)
}
}
case reflect.Slice:
l := v.Len()
for i := 0; i < l; i++ {
d.value(v.Index(i))
}
case reflect.Bool:
v.SetBool(d.bool())
case reflect.Int8:
v.SetInt(int64(d.int8()))
case reflect.Int16:
v.SetInt(int64(d.int16()))
case reflect.Int32:
v.SetInt(int64(d.int32()))
case reflect.Int64:
v.SetInt(d.int64())
case reflect.Uint8:
v.SetUint(uint64(d.uint8()))
case reflect.Uint16:
v.SetUint(uint64(d.uint16()))
case reflect.Uint32:
v.SetUint(uint64(d.uint32()))
case reflect.Uint64:
v.SetUint(d.uint64())
case reflect.Float32:
v.SetFloat(float64(math.Float32frombits(d.uint32())))
case reflect.Float64:
v.SetFloat(math.Float64frombits(d.uint64()))
case reflect.Complex64:
v.SetComplex(complex(
float64(math.Float32frombits(d.uint32())),
float64(math.Float32frombits(d.uint32())),
))
case reflect.Complex128:
v.SetComplex(complex(
math.Float64frombits(d.uint64()),
math.Float64frombits(d.uint64()),
))
}
}
func (e *encoder) value(v reflect.Value) {
switch v.Kind() {
case reflect.Array:
l := v.Len()
for i := 0; i < l; i++ {
e.value(v.Index(i))
}
case reflect.Struct:
t := v.Type()
l := v.NumField()
for i := 0; i < l; i++ {
// see comment for corresponding code in decoder.value()
if v := v.Field(i); v.CanSet() || t.Field(i).Name != "_" {
e.value(v)
} else {
e.skip(v)
}
}
case reflect.Slice:
l := v.Len()
for i := 0; i < l; i++ {
e.value(v.Index(i))
}
case reflect.Bool:
e.bool(v.Bool())
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
switch v.Type().Kind() {
case reflect.Int8:
e.int8(int8(v.Int()))
case reflect.Int16:
e.int16(int16(v.Int()))
case reflect.Int32:
e.int32(int32(v.Int()))
case reflect.Int64:
e.int64(v.Int())
}
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
switch v.Type().Kind() {
case reflect.Uint8:
e.uint8(uint8(v.Uint()))
case reflect.Uint16:
e.uint16(uint16(v.Uint()))
case reflect.Uint32:
e.uint32(uint32(v.Uint()))
case reflect.Uint64:
e.uint64(v.Uint())
}
case reflect.Float32, reflect.Float64:
switch v.Type().Kind() {
case reflect.Float32:
e.uint32(math.Float32bits(float32(v.Float())))
case reflect.Float64:
e.uint64(math.Float64bits(v.Float()))
}
case reflect.Complex64, reflect.Complex128:
switch v.Type().Kind() {
case reflect.Complex64:
x := v.Complex()
e.uint32(math.Float32bits(float32(real(x))))
e.uint32(math.Float32bits(float32(imag(x))))
case reflect.Complex128:
x := v.Complex()
e.uint64(math.Float64bits(real(x)))
e.uint64(math.Float64bits(imag(x)))
}
}
}
func (d *decoder) skip(v reflect.Value) {
d.offset += dataSize(v)
}
func (e *encoder) skip(v reflect.Value) {
n := dataSize(v)
zero := e.buf[e.offset : e.offset+n]
for i := range zero {
zero[i] = 0
}
e.offset += n
}
// intDataSize returns the size of the data required to represent the data when encoded.
// It returns zero if the type cannot be implemented by the fast path in Read or Write.
func intDataSize(data any) int {
switch data := data.(type) {
case bool, int8, uint8, *bool, *int8, *uint8:
return 1
case []bool:
return len(data)
case []int8:
return len(data)
case []uint8:
return len(data)
case int16, uint16, *int16, *uint16:
return 2
case []int16:
return 2 * len(data)
case []uint16:
return 2 * len(data)
case int32, uint32, *int32, *uint32:
return 4
case []int32:
return 4 * len(data)
case []uint32:
return 4 * len(data)
case int64, uint64, *int64, *uint64:
return 8
case []int64:
return 8 * len(data)
case []uint64:
return 8 * len(data)
case float32, *float32:
return 4
case float64, *float64:
return 8
case []float32:
return 4 * len(data)
case []float64:
return 8 * len(data)
}
return 0
}

View file

@ -1,18 +0,0 @@
package binary
import (
"encoding/binary"
"reflect"
)
func DataSize(t reflect.Value) int {
return dataSize(t)
}
func EncodeValue(order binary.ByteOrder, buf []byte, v reflect.Value) {
(&encoder{order: order, buf: buf}).value(v)
}
func DecodeValue(order binary.ByteOrder, buf []byte, v reflect.Value) {
(&decoder{order: order, buf: buf}).value(v)
}

View file

@ -1,14 +0,0 @@
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build armbe || arm64be || m68k || mips || mips64 || mips64p32 || ppc || ppc64 || s390 || s390x || shbe || sparc || sparc64
package binary
type nativeEndian struct {
bigEndian
}
// NativeEndian is the native-endian implementation of [ByteOrder] and [AppendByteOrder].
var NativeEndian nativeEndian

View file

@ -1,14 +0,0 @@
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build 386 || amd64 || amd64p32 || alpha || arm || arm64 || loong64 || mipsle || mips64le || mips64p32le || nios2 || ppc64le || riscv || riscv64 || sh || wasm
package binary
type nativeEndian struct {
littleEndian
}
// NativeEndian is the native-endian implementation of [ByteOrder] and [AppendByteOrder].
var NativeEndian nativeEndian

View file

@ -1,166 +0,0 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package binary
// This file implements "varint" encoding of 64-bit integers.
// The encoding is:
// - unsigned integers are serialized 7 bits at a time, starting with the
// least significant bits
// - the most significant bit (msb) in each output byte indicates if there
// is a continuation byte (msb = 1)
// - signed integers are mapped to unsigned integers using "zig-zag"
// encoding: Positive values x are written as 2*x + 0, negative values
// are written as 2*(^x) + 1; that is, negative numbers are complemented
// and whether to complement is encoded in bit 0.
//
// Design note:
// At most 10 bytes are needed for 64-bit values. The encoding could
// be more dense: a full 64-bit value needs an extra byte just to hold bit 63.
// Instead, the msb of the previous byte could be used to hold bit 63 since we
// know there can't be more than 64 bits. This is a trivial improvement and
// would reduce the maximum encoding length to 9 bytes. However, it breaks the
// invariant that the msb is always the "continuation bit" and thus makes the
// format incompatible with a varint encoding for larger numbers (say 128-bit).
import (
"errors"
"io"
)
// MaxVarintLenN is the maximum length of a varint-encoded N-bit integer.
const (
MaxVarintLen16 = 3
MaxVarintLen32 = 5
MaxVarintLen64 = 10
)
// AppendUvarint appends the varint-encoded form of x,
// as generated by [PutUvarint], to buf and returns the extended buffer.
func AppendUvarint(buf []byte, x uint64) []byte {
for x >= 0x80 {
buf = append(buf, byte(x)|0x80)
x >>= 7
}
return append(buf, byte(x))
}
// PutUvarint encodes a uint64 into buf and returns the number of bytes written.
// If the buffer is too small, PutUvarint will panic.
func PutUvarint(buf []byte, x uint64) int {
i := 0
for x >= 0x80 {
buf[i] = byte(x) | 0x80
x >>= 7
i++
}
buf[i] = byte(x)
return i + 1
}
// Uvarint decodes a uint64 from buf and returns that value and the
// number of bytes read (> 0). If an error occurred, the value is 0
// and the number of bytes n is <= 0 meaning:
//
// n == 0: buf too small
// n < 0: value larger than 64 bits (overflow)
// and -n is the number of bytes read
func Uvarint(buf []byte) (uint64, int) {
var x uint64
var s uint
for i, b := range buf {
if i == MaxVarintLen64 {
// Catch byte reads past MaxVarintLen64.
// See issue https://golang.org/issues/41185
return 0, -(i + 1) // overflow
}
if b < 0x80 {
if i == MaxVarintLen64-1 && b > 1 {
return 0, -(i + 1) // overflow
}
return x | uint64(b)<<s, i + 1
}
x |= uint64(b&0x7f) << s
s += 7
}
return 0, 0
}
// AppendVarint appends the varint-encoded form of x,
// as generated by [PutVarint], to buf and returns the extended buffer.
func AppendVarint(buf []byte, x int64) []byte {
ux := uint64(x) << 1
if x < 0 {
ux = ^ux
}
return AppendUvarint(buf, ux)
}
// PutVarint encodes an int64 into buf and returns the number of bytes written.
// If the buffer is too small, PutVarint will panic.
func PutVarint(buf []byte, x int64) int {
ux := uint64(x) << 1
if x < 0 {
ux = ^ux
}
return PutUvarint(buf, ux)
}
// Varint decodes an int64 from buf and returns that value and the
// number of bytes read (> 0). If an error occurred, the value is 0
// and the number of bytes n is <= 0 with the following meaning:
//
// n == 0: buf too small
// n < 0: value larger than 64 bits (overflow)
// and -n is the number of bytes read
func Varint(buf []byte) (int64, int) {
ux, n := Uvarint(buf) // ok to continue in presence of error
x := int64(ux >> 1)
if ux&1 != 0 {
x = ^x
}
return x, n
}
var errOverflow = errors.New("binary: varint overflows a 64-bit integer")
// ReadUvarint reads an encoded unsigned integer from r and returns it as a uint64.
// The error is [io.EOF] only if no bytes were read.
// If an [io.EOF] happens after reading some but not all the bytes,
// ReadUvarint returns [io.ErrUnexpectedEOF].
func ReadUvarint(r io.ByteReader) (uint64, error) {
var x uint64
var s uint
for i := 0; i < MaxVarintLen64; i++ {
b, err := r.ReadByte()
if err != nil {
if i > 0 && err == io.EOF {
err = io.ErrUnexpectedEOF
}
return x, err
}
if b < 0x80 {
if i == MaxVarintLen64-1 && b > 1 {
return x, errOverflow
}
return x | uint64(b)<<s, nil
}
x |= uint64(b&0x7f) << s
s += 7
}
return x, errOverflow
}
// ReadVarint reads an encoded signed integer from r and returns it as an int64.
// The error is [io.EOF] only if no bytes were read.
// If an [io.EOF] happens after reading some but not all the bytes,
// ReadVarint returns [io.ErrUnexpectedEOF].
func ReadVarint(r io.ByteReader) (int64, error) {
ux, err := ReadUvarint(r) // ok to continue in presence of error
x := int64(ux >> 1)
if ux&1 != 0 {
x = ^x
}
return x, err
}

View file

@ -5,10 +5,11 @@ package buf
import (
"errors"
"math/bits"
"strconv"
"sync"
)
var DefaultAllocator = newDefaultAllocator()
var DefaultAllocator = newDefaultAllocer()
type Allocator interface {
Get(size int) []byte
@ -17,72 +18,36 @@ type Allocator interface {
// defaultAllocator for incoming frames, optimized to prevent overwriting after zeroing
type defaultAllocator struct {
buffers [11]sync.Pool
buffers []sync.Pool
}
// NewAllocator initiates a []byte allocator for frames less than 65536 bytes,
// the waste(memory fragmentation) of space allocation is guaranteed to be
// no more than 50%.
func newDefaultAllocator() Allocator {
return &defaultAllocator{
buffers: [...]sync.Pool{ // 64B -> 64K
{New: func() any { return new([1 << 6]byte) }},
{New: func() any { return new([1 << 7]byte) }},
{New: func() any { return new([1 << 8]byte) }},
{New: func() any { return new([1 << 9]byte) }},
{New: func() any { return new([1 << 10]byte) }},
{New: func() any { return new([1 << 11]byte) }},
{New: func() any { return new([1 << 12]byte) }},
{New: func() any { return new([1 << 13]byte) }},
{New: func() any { return new([1 << 14]byte) }},
{New: func() any { return new([1 << 15]byte) }},
{New: func() any { return new([1 << 16]byte) }},
},
func newDefaultAllocer() Allocator {
alloc := new(defaultAllocator)
alloc.buffers = make([]sync.Pool, 17) // 1B -> 64K
for k := range alloc.buffers {
i := k
alloc.buffers[k].New = func() any {
return make([]byte, 1<<uint32(i))
}
}
return alloc
}
// Get a []byte from pool with most appropriate cap
func (alloc *defaultAllocator) Get(size int) []byte {
if size <= 0 || size > 65536 {
return nil
panic("alloc bad size: " + strconv.Itoa(size))
}
var index uint16
if size > 64 {
index = msb(size)
if size != 1<<index {
index += 1
}
index -= 6
bits := msb(size)
if size == 1<<bits {
return alloc.buffers[bits].Get().([]byte)[:size]
}
buffer := alloc.buffers[index].Get()
switch index {
case 0:
return buffer.(*[1 << 6]byte)[:size]
case 1:
return buffer.(*[1 << 7]byte)[:size]
case 2:
return buffer.(*[1 << 8]byte)[:size]
case 3:
return buffer.(*[1 << 9]byte)[:size]
case 4:
return buffer.(*[1 << 10]byte)[:size]
case 5:
return buffer.(*[1 << 11]byte)[:size]
case 6:
return buffer.(*[1 << 12]byte)[:size]
case 7:
return buffer.(*[1 << 13]byte)[:size]
case 8:
return buffer.(*[1 << 14]byte)[:size]
case 9:
return buffer.(*[1 << 15]byte)[:size]
case 10:
return buffer.(*[1 << 16]byte)[:size]
default:
panic("invalid pool index")
}
return alloc.buffers[bits+1].Get().([]byte)[:size]
}
// Put returns a []byte to pool for future use,
@ -92,37 +57,10 @@ func (alloc *defaultAllocator) Put(buf []byte) error {
if cap(buf) == 0 || cap(buf) > 65536 || cap(buf) != 1<<bits {
return errors.New("allocator Put() incorrect buffer size")
}
bits -= 6
buf = buf[:cap(buf)]
//nolint
//lint:ignore SA6002 ignore temporarily
switch bits {
case 0:
alloc.buffers[bits].Put((*[1 << 6]byte)(buf))
case 1:
alloc.buffers[bits].Put((*[1 << 7]byte)(buf))
case 2:
alloc.buffers[bits].Put((*[1 << 8]byte)(buf))
case 3:
alloc.buffers[bits].Put((*[1 << 9]byte)(buf))
case 4:
alloc.buffers[bits].Put((*[1 << 10]byte)(buf))
case 5:
alloc.buffers[bits].Put((*[1 << 11]byte)(buf))
case 6:
alloc.buffers[bits].Put((*[1 << 12]byte)(buf))
case 7:
alloc.buffers[bits].Put((*[1 << 13]byte)(buf))
case 8:
alloc.buffers[bits].Put((*[1 << 14]byte)(buf))
case 9:
alloc.buffers[bits].Put((*[1 << 15]byte)(buf))
case 10:
alloc.buffers[bits].Put((*[1 << 16]byte)(buf))
default:
panic("invalid pool index")
}
alloc.buffers[bits].Put(buf)
return nil
}

View file

@ -4,70 +4,109 @@ import (
"crypto/rand"
"io"
"net"
"strconv"
"sync/atomic"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/atomic"
"github.com/sagernet/sing/common/debug"
E "github.com/sagernet/sing/common/exceptions"
F "github.com/sagernet/sing/common/format"
)
const (
ReversedHeader = 1024
BufferSize = 32 * 1024
UDPBufferSize = 16 * 1024
)
type Buffer struct {
data []byte
start int
end int
capacity int
refs atomic.Int32
managed bool
data []byte
start int
end int
refs int32
managed bool
closed bool
}
func New() *Buffer {
return &Buffer{
data: Get(BufferSize),
capacity: BufferSize,
managed: true,
data: Get(BufferSize),
start: ReversedHeader,
end: ReversedHeader,
managed: true,
}
}
func NewPacket() *Buffer {
return &Buffer{
data: Get(UDPBufferSize),
capacity: UDPBufferSize,
managed: true,
data: Get(UDPBufferSize),
start: ReversedHeader,
end: ReversedHeader,
managed: true,
}
}
func NewSize(size int) *Buffer {
if size == 0 {
return &Buffer{}
} else if size > 65535 {
if size > 65535 {
return &Buffer{
data: make([]byte, size),
capacity: size,
data: make([]byte, size),
}
}
return &Buffer{
data: Get(size),
capacity: size,
managed: true,
data: Get(size),
managed: true,
}
}
func StackNew() *Buffer {
if common.UnsafeBuffer {
return &Buffer{
data: make([]byte, BufferSize),
start: ReversedHeader,
end: ReversedHeader,
}
} else {
return New()
}
}
func StackNewPacket() *Buffer {
if common.UnsafeBuffer {
return &Buffer{
data: make([]byte, UDPBufferSize),
start: ReversedHeader,
end: ReversedHeader,
}
} else {
return NewPacket()
}
}
func StackNewSize(size int) *Buffer {
if common.UnsafeBuffer {
return &Buffer{
data: Make(size),
}
} else {
return NewSize(size)
}
}
func As(data []byte) *Buffer {
return &Buffer{
data: data,
end: len(data),
capacity: len(data),
data: data,
end: len(data),
}
}
func With(data []byte) *Buffer {
return &Buffer{
data: data,
capacity: len(data),
data: data,
}
}
func (b *Buffer) Closed() bool {
return b.closed
}
func (b *Buffer) Byte(index int) byte {
return b.data[b.start+index]
}
@ -78,8 +117,8 @@ func (b *Buffer) SetByte(index int, value byte) {
func (b *Buffer) Extend(n int) []byte {
end := b.end + n
if end > b.capacity {
panic(F.ToString("buffer overflow: capacity ", b.capacity, ",end ", b.end, ", need ", n))
if end > cap(b.data) {
panic("buffer overflow: cap " + strconv.Itoa(cap(b.data)) + ",end " + strconv.Itoa(b.end) + ", need " + strconv.Itoa(n))
}
ext := b.data[b.end:end]
b.end = end
@ -101,14 +140,14 @@ func (b *Buffer) Write(data []byte) (n int, err error) {
if b.IsFull() {
return 0, io.ErrShortBuffer
}
n = copy(b.data[b.end:b.capacity], data)
n = copy(b.data[b.end:], data)
b.end += n
return
}
func (b *Buffer) ExtendHeader(n int) []byte {
if b.start < n {
panic(F.ToString("buffer overflow: capacity ", b.capacity, ",start ", b.start, ", need ", n))
panic("buffer overflow: cap " + strconv.Itoa(cap(b.data)) + ",start " + strconv.Itoa(b.start) + ", need " + strconv.Itoa(n))
}
b.start -= n
return b.data[b.start : b.start+n]
@ -129,13 +168,13 @@ func (b *Buffer) WriteByte(d byte) error {
return nil
}
func (b *Buffer) ReadOnceFrom(r io.Reader) (int, error) {
func (b *Buffer) ReadOnceFrom(r io.Reader) (int64, error) {
if b.IsFull() {
return 0, io.ErrShortBuffer
}
n, err := r.Read(b.FreeBytes())
b.end += n
return n, err
return int64(n), err
}
func (b *Buffer) ReadPacketFrom(r net.PacketConn) (int64, net.Addr, error) {
@ -149,8 +188,7 @@ func (b *Buffer) ReadPacketFrom(r net.PacketConn) (int64, net.Addr, error) {
func (b *Buffer) ReadAtLeastFrom(r io.Reader, min int) (int64, error) {
if min <= 0 {
n, err := b.ReadOnceFrom(r)
return int64(n), err
return b.ReadOnceFrom(r)
}
if b.IsFull() {
return 0, io.ErrShortBuffer
@ -161,7 +199,7 @@ func (b *Buffer) ReadAtLeastFrom(r io.Reader, min int) (int64, error) {
}
func (b *Buffer) ReadFullFrom(r io.Reader, size int) (n int, err error) {
if b.end+size > b.capacity {
if b.end+size > b.Cap() {
return 0, io.ErrShortBuffer
}
n, err = io.ReadFull(r, b.data[b.end:b.end+size])
@ -198,7 +236,7 @@ func (b *Buffer) WriteString(s string) (n int, err error) {
if b.IsFull() {
return 0, io.ErrShortBuffer
}
n = copy(b.data[b.end:b.capacity], s)
n = copy(b.data[b.end:], s)
b.end += n
return
}
@ -213,10 +251,13 @@ func (b *Buffer) WriteZero() error {
}
func (b *Buffer) WriteZeroN(n int) error {
if b.end+n > b.capacity {
if b.end+n > b.Cap() {
return io.ErrShortBuffer
}
common.ClearArray(b.Extend(n))
for i := b.end; i <= b.end+n; i++ {
b.data[i] = 0
}
b.end += n
return nil
}
@ -259,63 +300,40 @@ func (b *Buffer) Resize(start, end int) {
b.end = b.start + end
}
func (b *Buffer) Reserve(n int) {
if n > b.capacity {
panic(F.ToString("buffer overflow: capacity ", b.capacity, ", need ", n))
}
b.capacity -= n
}
func (b *Buffer) OverCap(n int) {
if b.capacity+n > len(b.data) {
panic(F.ToString("buffer overflow: capacity ", len(b.data), ", need ", b.capacity+n))
}
b.capacity += n
}
func (b *Buffer) Reset() {
b.start = ReversedHeader
b.end = ReversedHeader
}
func (b *Buffer) FullReset() {
b.start = 0
b.end = 0
b.capacity = len(b.data)
}
// Deprecated: use Reset instead.
func (b *Buffer) FullReset() {
b.Reset()
}
func (b *Buffer) IncRef() {
b.refs.Add(1)
atomic.AddInt32(&b.refs, 1)
}
func (b *Buffer) DecRef() {
b.refs.Add(-1)
atomic.AddInt32(&b.refs, -1)
}
func (b *Buffer) Release() {
if b == nil || !b.managed {
if b == nil || b.closed || !b.managed {
return
}
if b.refs.Load() > 0 {
if atomic.LoadInt32(&b.refs) > 0 {
return
}
common.Must(Put(b.data))
*b = Buffer{}
*b = Buffer{closed: true}
}
func (b *Buffer) Leak() {
if debug.Enabled {
if b == nil || !b.managed {
return
}
refs := b.refs.Load()
if refs == 0 {
panic("leaking buffer")
} else {
panic(F.ToString("leaking buffer with ", refs, " references"))
}
} else {
b.Release()
func (b *Buffer) Cut(start int, end int) *Buffer {
b.start += start
b.end = len(b.data) - end
return &Buffer{
data: b.data[b.start:b.end],
}
}
@ -328,10 +346,6 @@ func (b *Buffer) Len() int {
}
func (b *Buffer) Cap() int {
return b.capacity
}
func (b *Buffer) RawCap() int {
return len(b.data)
}
@ -339,6 +353,10 @@ func (b *Buffer) Bytes() []byte {
return b.data[b.start:b.end]
}
func (b *Buffer) Slice() []byte {
return b.data
}
func (b *Buffer) From(n int) []byte {
return b.data[b.start+n : b.end]
}
@ -356,11 +374,11 @@ func (b *Buffer) Index(start int) []byte {
}
func (b *Buffer) FreeLen() int {
return b.capacity - b.end
return b.Cap() - b.end
}
func (b *Buffer) FreeBytes() []byte {
return b.data[b.end:b.capacity]
return b.data[b.end:b.Cap()]
}
func (b *Buffer) IsEmpty() bool {
@ -368,7 +386,7 @@ func (b *Buffer) IsEmpty() bool {
}
func (b *Buffer) IsFull() bool {
return b.end == b.capacity
return b.end == b.Cap()
}
func (b *Buffer) ToOwned() *Buffer {
@ -376,6 +394,5 @@ func (b *Buffer) ToOwned() *Buffer {
copy(n.data[b.start:b.end], b.data[b.start:b.end])
n.start = b.start
n.end = b.end
n.capacity = b.capacity
return n
}

View file

@ -1,8 +0,0 @@
//go:build with_low_memory
package buf
const (
BufferSize = 16 * 1024
UDPBufferSize = 8 * 1024
)

View file

@ -1,8 +0,0 @@
//go:build !with_low_memory
package buf
const (
BufferSize = 32 * 1024
UDPBufferSize = 16 * 1024
)

9
common/buf/hex.go Normal file
View file

@ -0,0 +1,9 @@
package buf
import "encoding/hex"
func EncodeHexString(src []byte) string {
dst := Make(hex.EncodedLen(len(src)))
hex.Encode(dst, src)
return string(dst)
}

View file

@ -1,9 +1,6 @@
package buf
func Get(size int) []byte {
if size == 0 {
return nil
}
return DefaultAllocator.Get(size)
}
@ -11,7 +8,43 @@ func Put(buf []byte) error {
return DefaultAllocator.Put(buf)
}
// Deprecated: use array instead.
func Make(size int) []byte {
return make([]byte, size)
var buffer []byte
switch {
case size <= 2:
buffer = make([]byte, 2)
case size <= 4:
buffer = make([]byte, 4)
case size <= 8:
buffer = make([]byte, 8)
case size <= 16:
buffer = make([]byte, 16)
case size <= 32:
buffer = make([]byte, 32)
case size <= 64:
buffer = make([]byte, 64)
case size <= 128:
buffer = make([]byte, 128)
case size <= 256:
buffer = make([]byte, 256)
case size <= 512:
buffer = make([]byte, 512)
case size <= 1024:
buffer = make([]byte, 1024)
case size <= 2048:
buffer = make([]byte, 2048)
case size <= 4096:
buffer = make([]byte, 4096)
case size <= 8192:
buffer = make([]byte, 8192)
case size <= 16384:
buffer = make([]byte, 16384)
case size <= 32768:
buffer = make([]byte, 32768)
case size <= 65535:
buffer = make([]byte, 65535)
default:
return make([]byte, size)
}
return buffer[:size]
}

34
common/buf/ptr.go Normal file
View file

@ -0,0 +1,34 @@
//go:build !disable_unsafe
package buf
import (
"unsafe"
"github.com/sagernet/sing/common"
)
type dbgVar struct {
name string
value *int32
}
//go:linkname dbgvars runtime.dbgvars
var dbgvars any
// go.info.runtime.dbgvars: relocation target go.info.[]github.com/sagernet/sing/common/buf.dbgVar not defined
// var dbgvars []dbgVar
func init() {
if !common.UnsafeBuffer {
return
}
debugVars := *(*[]dbgVar)(unsafe.Pointer(&dbgvars))
for _, v := range debugVars {
if v.name == "invalidptr" {
*v.value = 0
return
}
}
panic("can't disable invalidptr")
}

View file

@ -9,20 +9,19 @@ import (
type AddrConn struct {
net.Conn
Source M.Socksaddr
Destination M.Socksaddr
M.Metadata
}
func (c *AddrConn) LocalAddr() net.Addr {
if c.Destination.IsValid() {
return c.Destination.TCPAddr()
if c.Metadata.Destination.IsValid() {
return c.Metadata.Destination.TCPAddr()
}
return c.Conn.LocalAddr()
}
func (c *AddrConn) RemoteAddr() net.Addr {
if c.Source.IsValid() {
return c.Source.TCPAddr()
if c.Metadata.Source.IsValid() {
return c.Metadata.Source.TCPAddr()
}
return c.Conn.RemoteAddr()
}

View file

@ -1,34 +0,0 @@
//go:build darwin || dragonfly || freebsd || netbsd || openbsd
package bufio
import (
"encoding/binary"
"net/netip"
"unsafe"
"golang.org/x/sys/unix"
)
func ToSockaddr(destination netip.AddrPort) (name unsafe.Pointer, nameLen uint32) {
if destination.Addr().Is4() {
sa := unix.RawSockaddrInet4{
Len: unix.SizeofSockaddrInet4,
Family: unix.AF_INET,
Addr: destination.Addr().As4(),
}
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], destination.Port())
name = unsafe.Pointer(&sa)
nameLen = unix.SizeofSockaddrInet4
} else {
sa := unix.RawSockaddrInet6{
Len: unix.SizeofSockaddrInet6,
Family: unix.AF_INET6,
Addr: destination.Addr().As16(),
}
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], destination.Port())
name = unsafe.Pointer(&sa)
nameLen = unix.SizeofSockaddrInet6
}
return
}

View file

@ -1,30 +0,0 @@
package bufio
import (
"encoding/binary"
"net/netip"
"unsafe"
"golang.org/x/sys/unix"
)
func ToSockaddr(destination netip.AddrPort) (name unsafe.Pointer, nameLen uint32) {
if destination.Addr().Is4() {
sa := unix.RawSockaddrInet4{
Family: unix.AF_INET,
Addr: destination.Addr().As4(),
}
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], destination.Port())
name = unsafe.Pointer(&sa)
nameLen = unix.SizeofSockaddrInet4
} else {
sa := unix.RawSockaddrInet6{
Family: unix.AF_INET6,
Addr: destination.Addr().As16(),
}
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], destination.Port())
name = unsafe.Pointer(&sa)
nameLen = unix.SizeofSockaddrInet6
}
return
}

View file

@ -1,30 +0,0 @@
package bufio
import (
"encoding/binary"
"net/netip"
"unsafe"
"golang.org/x/sys/windows"
)
func ToSockaddr(destination netip.AddrPort) (name unsafe.Pointer, nameLen int32) {
if destination.Addr().Is4() {
sa := windows.RawSockaddrInet4{
Family: windows.AF_INET,
Addr: destination.Addr().As4(),
}
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], destination.Port())
name = unsafe.Pointer(&sa)
nameLen = int32(unsafe.Sizeof(sa))
} else {
sa := windows.RawSockaddrInet6{
Family: windows.AF_INET6,
Addr: destination.Addr().As16(),
}
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], destination.Port())
name = unsafe.Pointer(&sa)
nameLen = int32(unsafe.Sizeof(sa))
}
return
}

View file

@ -1,165 +0,0 @@
package bufio
import (
"net"
"github.com/sagernet/sing/common/buf"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
type BindPacketConn interface {
N.NetPacketConn
net.Conn
}
type bindPacketConn struct {
N.NetPacketConn
addr net.Addr
}
func NewBindPacketConn(conn net.PacketConn, addr net.Addr) BindPacketConn {
return &bindPacketConn{
NewPacketConn(conn),
addr,
}
}
func (c *bindPacketConn) Read(b []byte) (n int, err error) {
n, _, err = c.ReadFrom(b)
return
}
func (c *bindPacketConn) Write(b []byte) (n int, err error) {
return c.WriteTo(b, c.addr)
}
func (c *bindPacketConn) CreateReadWaiter() (N.ReadWaiter, bool) {
readWaiter, isReadWaiter := CreatePacketReadWaiter(c.NetPacketConn)
if !isReadWaiter {
return nil, false
}
return &bindPacketReadWaiter{readWaiter}, true
}
func (c *bindPacketConn) RemoteAddr() net.Addr {
return c.addr
}
func (c *bindPacketConn) Upstream() any {
return c.NetPacketConn
}
var (
_ N.NetPacketConn = (*UnbindPacketConn)(nil)
_ N.PacketReadWaitCreator = (*UnbindPacketConn)(nil)
)
type UnbindPacketConn struct {
N.ExtendedConn
addr M.Socksaddr
}
func NewUnbindPacketConn(conn net.Conn) N.NetPacketConn {
return &UnbindPacketConn{
NewExtendedConn(conn),
M.SocksaddrFromNet(conn.RemoteAddr()),
}
}
func NewUnbindPacketConnWithAddr(conn net.Conn, addr M.Socksaddr) N.NetPacketConn {
return &UnbindPacketConn{
NewExtendedConn(conn),
addr,
}
}
func (c *UnbindPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
n, err = c.ExtendedConn.Read(p)
if err == nil {
addr = c.addr.UDPAddr()
}
return
}
func (c *UnbindPacketConn) WriteTo(p []byte, _ net.Addr) (n int, err error) {
return c.ExtendedConn.Write(p)
}
func (c *UnbindPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
err = c.ExtendedConn.ReadBuffer(buffer)
if err != nil {
return
}
destination = c.addr
return
}
func (c *UnbindPacketConn) WritePacket(buffer *buf.Buffer, _ M.Socksaddr) error {
return c.ExtendedConn.WriteBuffer(buffer)
}
func (c *UnbindPacketConn) CreateReadWaiter() (N.PacketReadWaiter, bool) {
readWaiter, isReadWaiter := CreateReadWaiter(c.ExtendedConn)
if !isReadWaiter {
return nil, false
}
return &unbindPacketReadWaiter{readWaiter, c.addr}, true
}
func (c *UnbindPacketConn) Upstream() any {
return c.ExtendedConn
}
func NewServerPacketConn(conn net.PacketConn) N.ExtendedConn {
return &serverPacketConn{
NetPacketConn: NewPacketConn(conn),
}
}
type serverPacketConn struct {
N.NetPacketConn
remoteAddr M.Socksaddr
}
func (c *serverPacketConn) Read(p []byte) (n int, err error) {
n, addr, err := c.NetPacketConn.ReadFrom(p)
if err != nil {
return
}
c.remoteAddr = M.SocksaddrFromNet(addr)
return
}
func (c *serverPacketConn) ReadBuffer(buffer *buf.Buffer) error {
destination, err := c.NetPacketConn.ReadPacket(buffer)
if err != nil {
return err
}
c.remoteAddr = destination
return nil
}
func (c *serverPacketConn) Write(p []byte) (n int, err error) {
return c.NetPacketConn.WriteTo(p, c.remoteAddr.UDPAddr())
}
func (c *serverPacketConn) WriteBuffer(buffer *buf.Buffer) error {
return c.NetPacketConn.WritePacket(buffer, c.remoteAddr)
}
func (c *serverPacketConn) RemoteAddr() net.Addr {
return c.remoteAddr
}
func (c *serverPacketConn) Upstream() any {
return c.NetPacketConn
}
func (c *serverPacketConn) CreateReadWaiter() (N.ReadWaiter, bool) {
readWaiter, isReadWaiter := CreatePacketReadWaiter(c.NetPacketConn)
if !isReadWaiter {
return nil, false
}
return &serverPacketReadWaiter{c, readWaiter}, true
}

View file

@ -1,62 +0,0 @@
package bufio
import (
"github.com/sagernet/sing/common/buf"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
var _ N.ReadWaiter = (*bindPacketReadWaiter)(nil)
type bindPacketReadWaiter struct {
readWaiter N.PacketReadWaiter
}
func (w *bindPacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
return w.readWaiter.InitializeReadWaiter(options)
}
func (w *bindPacketReadWaiter) WaitReadBuffer() (buffer *buf.Buffer, err error) {
buffer, _, err = w.readWaiter.WaitReadPacket()
return
}
var _ N.PacketReadWaiter = (*unbindPacketReadWaiter)(nil)
type unbindPacketReadWaiter struct {
readWaiter N.ReadWaiter
addr M.Socksaddr
}
func (w *unbindPacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
return w.readWaiter.InitializeReadWaiter(options)
}
func (w *unbindPacketReadWaiter) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
buffer, err = w.readWaiter.WaitReadBuffer()
if err != nil {
return
}
destination = w.addr
return
}
var _ N.ReadWaiter = (*serverPacketReadWaiter)(nil)
type serverPacketReadWaiter struct {
*serverPacketConn
readWaiter N.PacketReadWaiter
}
func (w *serverPacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
return w.readWaiter.InitializeReadWaiter(options)
}
func (w *serverPacketReadWaiter) WaitReadBuffer() (buffer *buf.Buffer, err error) {
buffer, destination, err := w.readWaiter.WaitReadPacket()
if err != nil {
return
}
w.remoteAddr = destination
return
}

View file

@ -2,12 +2,77 @@ package bufio
import (
"io"
"os"
"sync"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
N "github.com/sagernet/sing/common/network"
)
type BufferedReader struct {
upstream N.ExtendedReader
buffer *buf.Buffer
}
func NewBufferedReader(upstream io.Reader, buffer *buf.Buffer) *BufferedReader {
return &BufferedReader{
upstream: NewExtendedReader(upstream),
buffer: buffer,
}
}
func (r *BufferedReader) Read(p []byte) (n int, err error) {
if r.buffer.Closed() {
return 0, os.ErrClosed
}
if r.buffer.IsEmpty() {
r.buffer.Reset()
err = r.upstream.ReadBuffer(r.buffer)
if err != nil {
r.buffer.Release()
return
}
}
return r.buffer.Read(p)
}
func (r *BufferedReader) ReadBuffer(buffer *buf.Buffer) error {
if r.buffer.Closed() {
return os.ErrClosed
}
var err error
if r.buffer.IsEmpty() {
r.buffer.Reset()
err = r.upstream.ReadBuffer(r.buffer)
if err != nil {
r.buffer.Release()
return err
}
}
if r.buffer.Len() > buffer.FreeLen() {
err = common.Error(buffer.ReadFullFrom(r.buffer, buffer.FreeLen()))
} else {
err = common.Error(buffer.ReadFullFrom(r.buffer, r.buffer.Len()))
}
if err != nil {
r.buffer.Release()
}
return err
}
func (r *BufferedReader) WriteTo(w io.Writer) (n int64, err error) {
if r.buffer.Closed() {
return 0, os.ErrClosed
}
defer r.buffer.Release()
return CopyExtendedBuffer(NewExtendedWriter(w), NewExtendedReader(r.upstream), r.buffer)
}
func (r *BufferedReader) Upstream() any {
return r.upstream
}
type BufferedWriter struct {
upstream io.Writer
buffer *buf.Buffer
@ -38,26 +103,7 @@ func (w *BufferedWriter) Write(p []byte) (n int, err error) {
if err != nil {
return
}
w.buffer.Reset()
}
}
func (w *BufferedWriter) WriteByte(c byte) error {
w.access.Lock()
defer w.access.Unlock()
if w.buffer == nil {
return common.Error(w.upstream.Write([]byte{c}))
}
for {
err := w.buffer.WriteByte(c)
if err == nil {
return nil
}
_, err = w.upstream.Write(w.buffer.Bytes())
if err != nil {
return err
}
w.buffer.Reset()
w.buffer.FullReset()
}
}
@ -78,6 +124,13 @@ func (w *BufferedWriter) Fallthrough() error {
return nil
}
func (w *BufferedWriter) ReadFrom(r io.Reader) (n int64, err error) {
if w.buffer == nil {
return Copy(w.upstream, r)
}
return CopyExtendedBuffer(NewExtendedWriter(w), NewExtendedReader(r), w.buffer)
}
func (w *BufferedWriter) WriterReplaceable() bool {
return w.buffer == nil
}

View file

@ -3,6 +3,7 @@ package bufio
import (
"io"
"net"
"time"
"github.com/sagernet/sing/common/buf"
M "github.com/sagernet/sing/common/metadata"
@ -59,6 +60,13 @@ func (c *CachedConn) WriteTo(w io.Writer) (n int64, err error) {
return
}
func (c *CachedConn) SetReadDeadline(t time.Time) error {
if c.buffer != nil && !c.buffer.IsEmpty() {
return nil
}
return c.Conn.SetReadDeadline(t)
}
func (c *CachedConn) ReadFrom(r io.Reader) (n int64, err error) {
return Copy(c.Conn, r)
}
@ -178,18 +186,13 @@ func (c *CachedPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksad
return c.PacketConn.ReadPacket(buffer)
}
func (c *CachedPacketConn) ReadCachedPacket() *N.PacketBuffer {
buffer := c.buffer
func (c *CachedPacketConn) ReadCachedPacket() (destination M.Socksaddr, buffer *buf.Buffer) {
buffer = c.buffer
c.buffer = nil
if buffer != nil {
buffer.DecRef()
}
packet := N.NewPacketBuffer()
*packet = N.PacketBuffer{
Buffer: buffer,
Destination: c.destination,
}
return packet
return c.destination, buffer
}
func (c *CachedPacketConn) Upstream() any {

View file

@ -30,7 +30,7 @@ func (c *ChunkReader) ReadBuffer(buffer *buf.Buffer) error {
} else if !c.cache.IsEmpty() {
return common.Error(buffer.ReadFrom(c.cache))
}
c.cache.Reset()
c.cache.FullReset()
err := c.upstream.ReadBuffer(c.cache)
if err != nil {
c.cache.Release()
@ -46,7 +46,7 @@ func (c *ChunkReader) Read(p []byte) (n int, err error) {
} else if !c.cache.IsEmpty() {
return c.cache.Read(p)
}
c.cache.Reset()
c.cache.FullReset()
err = c.upstream.ReadBuffer(c.cache)
if err != nil {
c.cache.Release()
@ -56,21 +56,13 @@ func (c *ChunkReader) Read(p []byte) (n int, err error) {
return c.cache.Read(p)
}
func (c *ChunkReader) ReadByte() (byte, error) {
buffer, err := c.ReadChunk()
if err != nil {
return 0, err
}
return buffer.ReadByte()
}
func (c *ChunkReader) ReadChunk() (*buf.Buffer, error) {
if c.cache == nil {
c.cache = buf.NewSize(c.maxChunkSize)
} else if !c.cache.IsEmpty() {
return c.cache, nil
}
c.cache.Reset()
c.cache.FullReset()
err := c.upstream.ReadBuffer(c.cache)
if err != nil {
c.cache.Release()

View file

@ -35,7 +35,14 @@ func (w *ExtendedUDPConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
func (w *ExtendedUDPConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
defer buffer.Release()
return common.Error(w.UDPConn.WriteToUDPAddrPort(buffer.Bytes(), destination.AddrPort()))
if destination.IsFqdn() {
udpAddr, err := net.ResolveUDPAddr("udp", destination.String())
if err != nil {
return err
}
return common.Error(w.UDPConn.WriteTo(buffer.Bytes(), udpAddr))
}
return common.Error(w.UDPConn.WriteToUDP(buffer.Bytes(), destination.UDPAddr()))
}
func (w *ExtendedUDPConn) Upstream() any {
@ -63,6 +70,69 @@ func (w *ExtendedPacketConn) Upstream() any {
return w.PacketConn
}
type BindPacketConn struct {
net.PacketConn
Addr net.Addr
}
func (c *BindPacketConn) Read(b []byte) (n int, err error) {
n, _, err = c.ReadFrom(b)
return
}
func (c *BindPacketConn) Write(b []byte) (n int, err error) {
return c.WriteTo(b, c.Addr)
}
func (c *BindPacketConn) RemoteAddr() net.Addr {
return c.Addr
}
func (c *BindPacketConn) Upstream() any {
return c.PacketConn
}
type UnbindPacketConn struct {
N.ExtendedConn
Addr M.Socksaddr
}
func (c *UnbindPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
n, err = c.ExtendedConn.Read(p)
if err == nil {
addr = c.Addr.UDPAddr()
}
return
}
func (c *UnbindPacketConn) WriteTo(p []byte, _ net.Addr) (n int, err error) {
return c.ExtendedConn.Write(p)
}
func (c *UnbindPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
err = c.ExtendedConn.ReadBuffer(buffer)
if err != nil {
return
}
destination = c.Addr
return
}
func (c *UnbindPacketConn) WritePacket(buffer *buf.Buffer, _ M.Socksaddr) error {
return c.ExtendedConn.WriteBuffer(buffer)
}
func (c *UnbindPacketConn) Upstream() any {
return c.ExtendedConn
}
func NewUnbindPacketConn(conn net.Conn) *UnbindPacketConn {
return &UnbindPacketConn{
NewExtendedConn(conn),
M.SocksaddrFromNet(conn.RemoteAddr()),
}
}
type ExtendedReaderWrapper struct {
io.Reader
}
@ -118,7 +188,7 @@ func (w *ExtendedWriterWrapper) Upstream() any {
return w.Writer
}
func (w *ExtendedWriterWrapper) WriterReplaceable() bool {
func (w *ExtendedReaderWrapper) WriterReplaceable() bool {
return true
}

View file

@ -2,316 +2,348 @@ package bufio
import (
"context"
"errors"
"io"
"net"
"syscall"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/rw"
"github.com/sagernet/sing/common/task"
)
func Copy(destination io.Writer, source io.Reader) (n int64, err error) {
if source == nil {
type readOnlyReader struct {
io.Reader
}
func (r *readOnlyReader) WriteTo(w io.Writer) (n int64, err error) {
return Copy(w, r.Reader)
}
func (r *readOnlyReader) Upstream() any {
return r.Reader
}
func (r *readOnlyReader) ReaderReplaceable() bool {
return true
}
type writeOnlyWriter struct {
io.Writer
}
func (w *writeOnlyWriter) ReadFrom(r io.Reader) (n int64, err error) {
return Copy(w.Writer, r)
}
func (w *writeOnlyWriter) Upstream() any {
return w.Writer
}
func (w *writeOnlyWriter) WriterReplaceable() bool {
return true
}
func needWrapper(src, dst any) bool {
_, srcTCPConn := src.(*net.TCPConn)
_, dstTCPConn := dst.(*net.TCPConn)
return (srcTCPConn || dstTCPConn) && !(srcTCPConn && dstTCPConn)
}
func Copy(dst io.Writer, src io.Reader) (n int64, err error) {
if src == nil {
return 0, E.New("nil reader")
} else if destination == nil {
} else if dst == nil {
return 0, E.New("nil writer")
}
originSource := source
var readCounters, writeCounters []N.CountFunc
for {
source, readCounters = N.UnwrapCountReader(source, readCounters)
destination, writeCounters = N.UnwrapCountWriter(destination, writeCounters)
if cachedSrc, isCached := source.(N.CachedReader); isCached {
cachedBuffer := cachedSrc.ReadCached()
if cachedBuffer != nil {
dataLen := cachedBuffer.Len()
_, err = destination.Write(cachedBuffer.Bytes())
cachedBuffer.Release()
if err != nil {
return
}
for _, counter := range readCounters {
counter(int64(dataLen))
}
for _, counter := range writeCounters {
counter(int64(dataLen))
}
continue
}
src = N.UnwrapReader(src)
dst = N.UnwrapWriter(dst)
if wt, ok := src.(io.WriterTo); ok {
if needWrapper(dst, src) {
dst = &writeOnlyWriter{dst}
}
break
return wt.WriteTo(dst)
}
return CopyWithCounters(destination, source, originSource, readCounters, writeCounters)
if rt, ok := dst.(io.ReaderFrom); ok {
if needWrapper(rt, src) {
src = &readOnlyReader{src}
}
return rt.ReadFrom(src)
}
return CopyExtended(NewExtendedWriter(dst), NewExtendedReader(src))
}
func CopyWithCounters(destination io.Writer, source io.Reader, originSource io.Reader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
srcSyscallConn, srcIsSyscall := source.(syscall.Conn)
dstSyscallConn, dstIsSyscall := destination.(syscall.Conn)
if srcIsSyscall && dstIsSyscall {
var handled bool
handled, n, err = copyDirect(srcSyscallConn, dstSyscallConn, readCounters, writeCounters)
if handled {
return
func CopyExtended(dst N.ExtendedWriter, src N.ExtendedReader) (n int64, err error) {
safeSrc := N.IsSafeReader(src)
headroom := N.CalculateFrontHeadroom(dst) + N.CalculateRearHeadroom(dst)
if safeSrc != nil {
if headroom == 0 {
return CopyExtendedWithSrcBuffer(dst, safeSrc)
}
}
return CopyExtended(originSource, NewExtendedWriter(destination), NewExtendedReader(source), readCounters, writeCounters)
}
func CopyExtended(originSource io.Reader, destination N.ExtendedWriter, source N.ExtendedReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
frontHeadroom := N.CalculateFrontHeadroom(destination)
rearHeadroom := N.CalculateRearHeadroom(destination)
readWaiter, isReadWaiter := CreateReadWaiter(source)
if isReadWaiter {
needCopy := readWaiter.InitializeReadWaiter(N.ReadWaitOptions{
FrontHeadroom: frontHeadroom,
RearHeadroom: rearHeadroom,
MTU: N.CalculateMTU(source, destination),
})
if !needCopy || common.LowMemory {
var handled bool
handled, n, err = copyWaitWithPool(originSource, destination, readWaiter, readCounters, writeCounters)
if handled {
return
}
}
if N.IsUnsafeWriter(dst) {
return CopyExtendedWithPool(dst, src)
}
return CopyExtendedWithPool(originSource, destination, source, readCounters, writeCounters)
bufferSize := N.CalculateMTU(src, dst)
if bufferSize > 0 {
bufferSize += headroom
} else {
bufferSize = buf.BufferSize
}
_buffer := buf.StackNewSize(bufferSize)
defer common.KeepAlive(_buffer)
buffer := common.Dup(_buffer)
defer buffer.Release()
return CopyExtendedBuffer(dst, src, buffer)
}
// Deprecated: not used
func CopyExtendedBuffer(originSource io.Writer, destination N.ExtendedWriter, source N.ExtendedReader, buffer *buf.Buffer, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
func CopyExtendedBuffer(dst N.ExtendedWriter, src N.ExtendedReader, buffer *buf.Buffer) (n int64, err error) {
buffer.IncRef()
defer buffer.DecRef()
frontHeadroom := N.CalculateFrontHeadroom(destination)
rearHeadroom := N.CalculateRearHeadroom(destination)
buffer.Resize(frontHeadroom, 0)
buffer.Reserve(rearHeadroom)
frontHeadroom := N.CalculateFrontHeadroom(dst)
rearHeadroom := N.CalculateRearHeadroom(dst)
readBufferRaw := buffer.Slice()
readBuffer := buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom])
var notFirstTime bool
for {
err = source.ReadBuffer(buffer)
readBuffer.Resize(frontHeadroom, 0)
err = src.ReadBuffer(readBuffer)
if err != nil {
if errors.Is(err, io.EOF) {
err = nil
return
if !notFirstTime {
err = N.HandshakeFailure(dst, err)
}
return
}
dataLen := buffer.Len()
buffer.OverCap(rearHeadroom)
err = destination.WriteBuffer(buffer)
dataLen := readBuffer.Len()
buffer.Resize(readBuffer.Start(), dataLen)
err = dst.WriteBuffer(buffer)
if err != nil {
if !notFirstTime {
err = N.ReportHandshakeFailure(originSource, err)
}
return
}
n += int64(dataLen)
for _, counter := range readCounters {
counter(int64(dataLen))
}
for _, counter := range writeCounters {
counter(int64(dataLen))
}
notFirstTime = true
}
}
func CopyExtendedWithPool(originSource io.Reader, destination N.ExtendedWriter, source N.ExtendedReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
options := N.NewReadWaitOptions(source, destination)
func CopyExtendedWithSrcBuffer(dst N.ExtendedWriter, src N.ThreadSafeReader) (n int64, err error) {
var notFirstTime bool
for {
buffer := options.NewBuffer()
err = source.ReadBuffer(buffer)
var buffer *buf.Buffer
buffer, err = src.ReadBufferThreadSafe()
if err != nil {
if !notFirstTime {
err = N.HandshakeFailure(dst, err)
}
return
}
dataLen := buffer.Len()
err = dst.WriteBuffer(buffer)
if err != nil {
buffer.Release()
if errors.Is(err, io.EOF) {
err = nil
return
}
return
}
dataLen := buffer.Len()
options.PostReturn(buffer)
err = destination.WriteBuffer(buffer)
if err != nil {
buffer.Leak()
if !notFirstTime {
err = N.ReportHandshakeFailure(originSource, err)
}
return
}
n += int64(dataLen)
for _, counter := range readCounters {
counter(int64(dataLen))
}
for _, counter := range writeCounters {
counter(int64(dataLen))
}
notFirstTime = true
}
}
func CopyConn(ctx context.Context, source net.Conn, destination net.Conn) error {
func CopyExtendedWithPool(dst N.ExtendedWriter, src N.ExtendedReader) (n int64, err error) {
frontHeadroom := N.CalculateFrontHeadroom(dst)
rearHeadroom := N.CalculateRearHeadroom(dst)
bufferSize := N.CalculateMTU(src, dst)
if bufferSize > 0 {
bufferSize += frontHeadroom + rearHeadroom
} else {
bufferSize = buf.BufferSize
}
var notFirstTime bool
for {
buffer := buf.NewSize(bufferSize)
readBufferRaw := buffer.Slice()
readBuffer := buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom])
readBuffer.Resize(frontHeadroom, 0)
err = src.ReadBuffer(readBuffer)
if err != nil {
buffer.Release()
if !notFirstTime {
err = N.HandshakeFailure(dst, err)
}
return
}
dataLen := readBuffer.Len()
buffer.Resize(readBuffer.Start(), dataLen)
err = dst.WriteBuffer(buffer)
if err != nil {
buffer.Release()
return
}
n += int64(dataLen)
notFirstTime = true
}
}
func CopyConn(ctx context.Context, conn net.Conn, dest net.Conn) error {
var group task.Group
if _, dstDuplex := common.Cast[N.WriteCloser](destination); dstDuplex {
if _, dstDuplex := common.Cast[rw.WriteCloser](dest); dstDuplex {
group.Append("upload", func(ctx context.Context) error {
err := common.Error(Copy(destination, source))
if err == nil {
N.CloseWrite(destination)
err := common.Error(Copy(dest, conn))
if E.IsMulti(err, io.EOF) {
rw.CloseWrite(dest)
} else {
common.Close(destination)
common.Close(dest)
}
return err
})
} else {
group.Append("upload", func(ctx context.Context) error {
defer common.Close(destination)
return common.Error(Copy(destination, source))
defer common.Close(dest)
return common.Error(Copy(dest, conn))
})
}
if _, srcDuplex := common.Cast[N.WriteCloser](source); srcDuplex {
if _, srcDuplex := common.Cast[rw.WriteCloser](conn); srcDuplex {
group.Append("download", func(ctx context.Context) error {
err := common.Error(Copy(source, destination))
if err == nil {
N.CloseWrite(source)
err := common.Error(Copy(conn, dest))
if E.IsMulti(err, io.EOF) {
rw.CloseWrite(conn)
} else {
common.Close(source)
common.Close(conn)
}
return err
})
} else {
group.Append("download", func(ctx context.Context) error {
defer common.Close(source)
return common.Error(Copy(source, destination))
defer common.Close(conn)
return common.Error(Copy(conn, dest))
})
}
group.Cleanup(func() {
common.Close(source, destination)
common.Close(conn, dest)
})
return group.Run(ctx)
}
func CopyPacket(destinationConn N.PacketWriter, source N.PacketReader) (n int64, err error) {
var readCounters, writeCounters []N.CountFunc
var cachedPackets []*N.PacketBuffer
originSource := source
for {
source, readCounters = N.UnwrapCountPacketReader(source, readCounters)
destinationConn, writeCounters = N.UnwrapCountPacketWriter(destinationConn, writeCounters)
if cachedReader, isCached := source.(N.CachedPacketReader); isCached {
packet := cachedReader.ReadCachedPacket()
if packet != nil {
cachedPackets = append(cachedPackets, packet)
continue
}
func CopyPacket(dst N.PacketWriter, src N.PacketReader) (n int64, err error) {
src = N.UnwrapPacketReader(src)
dst = N.UnwrapPacketWriter(dst)
safeSrc := N.IsSafePacketReader(src)
frontHeadroom := N.CalculateFrontHeadroom(dst)
rearHeadroom := N.CalculateRearHeadroom(dst)
headroom := frontHeadroom + rearHeadroom
if safeSrc != nil {
if headroom == 0 {
return CopyPacketWithSrcBuffer(dst, safeSrc)
}
break
}
if cachedPackets != nil {
n, err = WritePacketWithPool(originSource, destinationConn, cachedPackets, readCounters, writeCounters)
if N.IsUnsafeWriter(dst) {
return CopyPacketWithPool(dst, src)
}
bufferSize := N.CalculateMTU(src, dst)
if bufferSize > 0 {
bufferSize += headroom
} else {
bufferSize = buf.UDPBufferSize
}
_buffer := buf.StackNewSize(bufferSize)
defer common.KeepAlive(_buffer)
buffer := common.Dup(_buffer)
defer buffer.Release()
buffer.IncRef()
defer buffer.DecRef()
var destination M.Socksaddr
var notFirstTime bool
readBufferRaw := buffer.Slice()
readBuffer := buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom])
for {
readBuffer.Resize(frontHeadroom, 0)
destination, err = src.ReadPacket(readBuffer)
if err != nil {
if !notFirstTime {
err = N.HandshakeFailure(dst, err)
}
return
}
dataLen := readBuffer.Len()
buffer.Resize(readBuffer.Start(), dataLen)
err = dst.WritePacket(buffer, destination)
if err != nil {
return
}
n += int64(dataLen)
notFirstTime = true
}
copeN, err := CopyPacketWithCounters(destinationConn, source, originSource, readCounters, writeCounters)
n += copeN
return
}
func CopyPacketWithCounters(destinationConn N.PacketWriter, source N.PacketReader, originSource N.PacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
var (
handled bool
copeN int64
)
readWaiter, isReadWaiter := CreatePacketReadWaiter(source)
if isReadWaiter {
needCopy := readWaiter.InitializeReadWaiter(N.NewReadWaitOptions(source, destinationConn))
if !needCopy || common.LowMemory {
handled, copeN, err = copyPacketWaitWithPool(originSource, destinationConn, readWaiter, readCounters, writeCounters, n > 0)
if handled {
n += copeN
return
}
}
}
copeN, err = CopyPacketWithPool(originSource, destinationConn, source, readCounters, writeCounters, n > 0)
n += copeN
return
}
func CopyPacketWithPool(originSource N.PacketReader, destination N.PacketWriter, source N.PacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (n int64, err error) {
options := N.NewReadWaitOptions(source, destination)
var destinationAddress M.Socksaddr
func CopyPacketWithSrcBuffer(dst N.PacketWriter, src N.ThreadSafePacketReader) (n int64, err error) {
var buffer *buf.Buffer
var destination M.Socksaddr
var notFirstTime bool
for {
buffer := options.NewPacketBuffer()
destinationAddress, err = source.ReadPacket(buffer)
buffer, destination, err = src.ReadPacketThreadSafe()
if err != nil {
if !notFirstTime {
err = N.HandshakeFailure(dst, err)
}
return
}
dataLen := buffer.Len()
err = dst.WritePacket(buffer, destination)
if err != nil {
buffer.Release()
return
}
dataLen := buffer.Len()
options.PostReturn(buffer)
err = destination.WritePacket(buffer, destinationAddress)
if err != nil {
buffer.Leak()
if !notFirstTime {
err = N.ReportHandshakeFailure(originSource, err)
}
return
}
for _, counter := range readCounters {
counter(int64(dataLen))
}
for _, counter := range writeCounters {
counter(int64(dataLen))
}
n += int64(dataLen)
notFirstTime = true
}
}
func WritePacketWithPool(originSource N.PacketReader, destination N.PacketWriter, packetBuffers []*N.PacketBuffer, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
options := N.NewReadWaitOptions(nil, destination)
func CopyPacketWithPool(dst N.PacketWriter, src N.PacketReader) (n int64, err error) {
frontHeadroom := N.CalculateFrontHeadroom(dst)
rearHeadroom := N.CalculateRearHeadroom(dst)
bufferSize := N.CalculateMTU(src, dst)
if bufferSize > 0 {
bufferSize += frontHeadroom + rearHeadroom
} else {
bufferSize = buf.UDPBufferSize
}
var destination M.Socksaddr
var notFirstTime bool
for _, packetBuffer := range packetBuffers {
buffer := options.Copy(packetBuffer.Buffer)
dataLen := buffer.Len()
err = destination.WritePacket(buffer, packetBuffer.Destination)
N.PutPacketBuffer(packetBuffer)
for {
buffer := buf.NewSize(bufferSize)
readBufferRaw := buffer.Slice()
readBuffer := buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom])
readBuffer.Resize(frontHeadroom, 0)
destination, err = src.ReadPacket(readBuffer)
if err != nil {
buffer.Leak()
buffer.Release()
if !notFirstTime {
err = N.ReportHandshakeFailure(originSource, err)
err = N.HandshakeFailure(dst, err)
}
return
}
for _, counter := range readCounters {
counter(int64(dataLen))
}
for _, counter := range writeCounters {
counter(int64(dataLen))
dataLen := readBuffer.Len()
buffer.Resize(readBuffer.Start(), dataLen)
err = dst.WritePacket(buffer, destination)
if err != nil {
buffer.Release()
return
}
n += int64(dataLen)
notFirstTime = true
}
return
}
func CopyPacketConn(ctx context.Context, source N.PacketConn, destination N.PacketConn) error {
func CopyPacketConn(ctx context.Context, conn N.PacketConn, dest N.PacketConn) error {
var group task.Group
group.Append("upload", func(ctx context.Context) error {
return common.Error(CopyPacket(destination, source))
return common.Error(CopyPacket(dest, conn))
})
group.Append("download", func(ctx context.Context) error {
return common.Error(CopyPacket(source, destination))
return common.Error(CopyPacket(conn, dest))
})
group.Cleanup(func() {
common.Close(source, destination)
common.Close(conn, dest)
})
group.FastFail()
return group.Run(ctx)

View file

@ -1,90 +0,0 @@
package bufio
import (
"errors"
"io"
"syscall"
"github.com/sagernet/sing/common/buf"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
func copyDirect(source syscall.Conn, destination syscall.Conn, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handed bool, n int64, err error) {
rawSource, err := source.SyscallConn()
if err != nil {
return
}
rawDestination, err := destination.SyscallConn()
if err != nil {
return
}
handed, n, err = splice(rawSource, rawDestination, readCounters, writeCounters)
return
}
func copyWaitWithPool(originSource io.Reader, destination N.ExtendedWriter, source N.ReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handled bool, n int64, err error) {
handled = true
var (
buffer *buf.Buffer
notFirstTime bool
)
for {
buffer, err = source.WaitReadBuffer()
if err != nil {
if errors.Is(err, io.EOF) {
err = nil
return
}
return
}
dataLen := buffer.Len()
err = destination.WriteBuffer(buffer)
if err != nil {
buffer.Leak()
if !notFirstTime {
err = N.ReportHandshakeFailure(originSource, err)
}
return
}
n += int64(dataLen)
for _, counter := range readCounters {
counter(int64(dataLen))
}
for _, counter := range writeCounters {
counter(int64(dataLen))
}
notFirstTime = true
}
}
func copyPacketWaitWithPool(originSource N.PacketReader, destinationConn N.PacketWriter, source N.PacketReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (handled bool, n int64, err error) {
handled = true
var (
buffer *buf.Buffer
destination M.Socksaddr
)
for {
buffer, destination, err = source.WaitReadPacket()
if err != nil {
return
}
dataLen := buffer.Len()
err = destinationConn.WritePacket(buffer, destination)
if err != nil {
buffer.Leak()
if !notFirstTime {
err = N.ReportHandshakeFailure(originSource, err)
}
return
}
n += int64(dataLen)
for _, counter := range readCounters {
counter(int64(dataLen))
}
for _, counter := range writeCounters {
counter(int64(dataLen))
}
notFirstTime = true
}
}

View file

@ -1,146 +0,0 @@
//go:build !windows
package bufio
import (
"io"
"net/netip"
"os"
"syscall"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
var _ N.ReadWaiter = (*syscallReadWaiter)(nil)
type syscallReadWaiter struct {
rawConn syscall.RawConn
readErr error
readFunc func(fd uintptr) (done bool)
buffer *buf.Buffer
options N.ReadWaitOptions
}
func createSyscallReadWaiter(reader any) (*syscallReadWaiter, bool) {
if syscallConn, isSyscallConn := reader.(syscall.Conn); isSyscallConn {
rawConn, err := syscallConn.SyscallConn()
if err == nil {
return &syscallReadWaiter{rawConn: rawConn}, true
}
}
return nil, false
}
func (w *syscallReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
w.options = options
w.readFunc = func(fd uintptr) (done bool) {
buffer := w.options.NewBuffer()
var readN int
readN, w.readErr = syscall.Read(int(fd), buffer.FreeBytes())
if readN > 0 {
buffer.Truncate(readN)
w.options.PostReturn(buffer)
w.buffer = buffer
} else {
buffer.Release()
}
//goland:noinspection GoDirectComparisonOfErrors
if w.readErr == syscall.EAGAIN {
return false
}
if readN == 0 && w.readErr == nil {
w.readErr = io.EOF
}
return true
}
return false
}
func (w *syscallReadWaiter) WaitReadBuffer() (buffer *buf.Buffer, err error) {
if w.readFunc == nil {
return nil, os.ErrInvalid
}
err = w.rawConn.Read(w.readFunc)
if err != nil {
return
}
if w.readErr != nil {
if w.readErr == io.EOF {
return nil, io.EOF
}
return nil, E.Cause(w.readErr, "raw read")
}
buffer = w.buffer
w.buffer = nil
return
}
var _ N.PacketReadWaiter = (*syscallPacketReadWaiter)(nil)
type syscallPacketReadWaiter struct {
rawConn syscall.RawConn
readErr error
readFrom M.Socksaddr
readFunc func(fd uintptr) (done bool)
buffer *buf.Buffer
options N.ReadWaitOptions
}
func createSyscallPacketReadWaiter(reader any) (*syscallPacketReadWaiter, bool) {
if syscallConn, isSyscallConn := reader.(syscall.Conn); isSyscallConn {
rawConn, err := syscallConn.SyscallConn()
if err == nil {
return &syscallPacketReadWaiter{rawConn: rawConn}, true
}
}
return nil, false
}
func (w *syscallPacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
w.options = options
w.readFunc = func(fd uintptr) (done bool) {
buffer := w.options.NewPacketBuffer()
var readN int
var from syscall.Sockaddr
readN, _, _, from, w.readErr = syscall.Recvmsg(int(fd), buffer.FreeBytes(), nil, 0)
//goland:noinspection GoDirectComparisonOfErrors
if w.readErr != nil {
buffer.Release()
return w.readErr != syscall.EAGAIN
}
if readN > 0 {
buffer.Truncate(readN)
}
w.options.PostReturn(buffer)
w.buffer = buffer
switch fromAddr := from.(type) {
case *syscall.SockaddrInet4:
w.readFrom = M.SocksaddrFrom(netip.AddrFrom4(fromAddr.Addr), uint16(fromAddr.Port))
case *syscall.SockaddrInet6:
w.readFrom = M.SocksaddrFrom(netip.AddrFrom16(fromAddr.Addr), uint16(fromAddr.Port)).Unwrap()
}
return true
}
return false
}
func (w *syscallPacketReadWaiter) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
if w.readFunc == nil {
return nil, M.Socksaddr{}, os.ErrInvalid
}
err = w.rawConn.Read(w.readFunc)
if err != nil {
return
}
if w.readErr != nil {
err = E.Cause(w.readErr, "raw read")
return
}
buffer = w.buffer
w.buffer = nil
destination = w.readFrom
return
}

View file

@ -1,77 +0,0 @@
package bufio
import (
"net"
"testing"
"github.com/sagernet/sing/common/buf"
N "github.com/sagernet/sing/common/network"
"github.com/stretchr/testify/require"
)
func TestCopyWaitTCP(t *testing.T) {
t.Parallel()
inputConn, outputConn := TCPPipe(t)
readWaiter, created := CreateReadWaiter(outputConn)
require.True(t, created)
require.NotNil(t, readWaiter)
readWaiter.InitializeReadWaiter(N.ReadWaitOptions{})
require.NoError(t, TCPTest(t, inputConn, &readWaitWrapper{
Conn: outputConn,
readWaiter: readWaiter,
}))
}
type readWaitWrapper struct {
net.Conn
readWaiter N.ReadWaiter
buffer *buf.Buffer
}
func (r *readWaitWrapper) Read(p []byte) (n int, err error) {
if r.buffer != nil {
if r.buffer.Len() > 0 {
return r.buffer.Read(p)
}
if r.buffer.IsEmpty() {
r.buffer.Release()
r.buffer = nil
}
}
buffer, err := r.readWaiter.WaitReadBuffer()
if err != nil {
return
}
r.buffer = buffer
return r.buffer.Read(p)
}
func TestCopyWaitUDP(t *testing.T) {
t.Parallel()
inputConn, outputConn, outputAddr := UDPPipe(t)
readWaiter, created := CreatePacketReadWaiter(NewPacketConn(outputConn))
require.True(t, created)
require.NotNil(t, readWaiter)
readWaiter.InitializeReadWaiter(N.ReadWaitOptions{})
require.NoError(t, UDPTest(t, inputConn, &packetReadWaitWrapper{
PacketConn: outputConn,
readWaiter: readWaiter,
}, outputAddr))
}
type packetReadWaitWrapper struct {
net.PacketConn
readWaiter N.PacketReadWaiter
}
func (r *packetReadWaitWrapper) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
buffer, destination, err := r.readWaiter.WaitReadPacket()
if err != nil {
return
}
n = copy(p, buffer.Bytes())
buffer.Release()
addr = destination.UDPAddr()
return
}

View file

@ -1,163 +0,0 @@
package bufio
import (
"io"
"net/netip"
"os"
"syscall"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"golang.org/x/sys/windows"
)
var _ N.ReadWaiter = (*syscallReadWaiter)(nil)
type syscallReadWaiter struct {
rawConn syscall.RawConn
readErr error
readFunc func(fd uintptr) (done bool)
hasData bool
buffer *buf.Buffer
options N.ReadWaitOptions
}
func createSyscallReadWaiter(reader any) (*syscallReadWaiter, bool) {
if syscallConn, isSyscallConn := reader.(syscall.Conn); isSyscallConn {
rawConn, err := syscallConn.SyscallConn()
if err == nil {
return &syscallReadWaiter{rawConn: rawConn}, true
}
}
return nil, false
}
func (w *syscallReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
w.options = options
w.readFunc = func(fd uintptr) (done bool) {
if !w.hasData {
w.hasData = true
// golang's internal/poll.FD.RawRead will Use a zero-byte read as a way to get notified when this
// socket is readable if we return false. So the `recv` syscall will not block the system thread.
return false
}
buffer := w.options.NewBuffer()
var readN int32
readN, w.readErr = recv(windows.Handle(fd), buffer.FreeBytes(), 0)
if readN > 0 {
buffer.Truncate(int(readN))
w.options.PostReturn(buffer)
w.buffer = buffer
} else {
buffer.Release()
}
if w.readErr == windows.WSAEWOULDBLOCK {
return false
}
if readN == 0 && w.readErr == nil {
w.readErr = io.EOF
}
w.hasData = false
return true
}
return false
}
func (w *syscallReadWaiter) WaitReadBuffer() (buffer *buf.Buffer, err error) {
if w.readFunc == nil {
return nil, os.ErrInvalid
}
err = w.rawConn.Read(w.readFunc)
if err != nil {
return
}
if w.readErr != nil {
if w.readErr == io.EOF {
return nil, io.EOF
}
return nil, E.Cause(w.readErr, "raw read")
}
buffer = w.buffer
w.buffer = nil
return
}
var _ N.PacketReadWaiter = (*syscallPacketReadWaiter)(nil)
type syscallPacketReadWaiter struct {
rawConn syscall.RawConn
readErr error
readFrom M.Socksaddr
readFunc func(fd uintptr) (done bool)
hasData bool
buffer *buf.Buffer
options N.ReadWaitOptions
}
func createSyscallPacketReadWaiter(reader any) (*syscallPacketReadWaiter, bool) {
if syscallConn, isSyscallConn := reader.(syscall.Conn); isSyscallConn {
rawConn, err := syscallConn.SyscallConn()
if err == nil {
return &syscallPacketReadWaiter{rawConn: rawConn}, true
}
}
return nil, false
}
func (w *syscallPacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
w.options = options
w.readFunc = func(fd uintptr) (done bool) {
if !w.hasData {
w.hasData = true
// golang's internal/poll.FD.RawRead will Use a zero-byte read as a way to get notified when this
// socket is readable if we return false. So the `recvfrom` syscall will not block the system thread.
return false
}
buffer := w.options.NewPacketBuffer()
var readN int
var from windows.Sockaddr
readN, from, w.readErr = windows.Recvfrom(windows.Handle(fd), buffer.FreeBytes(), 0)
//goland:noinspection GoDirectComparisonOfErrors
if w.readErr != nil {
buffer.Release()
return w.readErr != windows.WSAEWOULDBLOCK
}
if readN > 0 {
buffer.Truncate(readN)
}
w.options.PostReturn(buffer)
w.buffer = buffer
if from != nil {
switch fromAddr := from.(type) {
case *windows.SockaddrInet4:
w.readFrom = M.SocksaddrFrom(netip.AddrFrom4(fromAddr.Addr), uint16(fromAddr.Port))
case *windows.SockaddrInet6:
w.readFrom = M.SocksaddrFrom(netip.AddrFrom16(fromAddr.Addr), uint16(fromAddr.Port)).Unwrap()
}
}
w.hasData = false
return true
}
return false
}
func (w *syscallPacketReadWaiter) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
if w.readFunc == nil {
return nil, M.Socksaddr{}, os.ErrInvalid
}
err = w.rawConn.Read(w.readFunc)
if err != nil {
return
}
if w.readErr != nil {
err = E.Cause(w.readErr, "raw read")
return
}
buffer = w.buffer
w.buffer = nil
destination = w.readFrom
return
}

View file

@ -1,96 +0,0 @@
package bufio
import (
"io"
"net"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/atomic"
"github.com/sagernet/sing/common/buf"
N "github.com/sagernet/sing/common/network"
)
func NewInt64CounterConn(conn net.Conn, readCounter []*atomic.Int64, writeCounter []*atomic.Int64) *CounterConn {
return &CounterConn{
NewExtendedConn(conn),
common.Map(readCounter, func(it *atomic.Int64) N.CountFunc {
return func(n int64) {
it.Add(n)
}
}),
common.Map(writeCounter, func(it *atomic.Int64) N.CountFunc {
return func(n int64) {
it.Add(n)
}
}),
}
}
func NewCounterConn(conn net.Conn, readCounter []N.CountFunc, writeCounter []N.CountFunc) *CounterConn {
return &CounterConn{NewExtendedConn(conn), readCounter, writeCounter}
}
type CounterConn struct {
N.ExtendedConn
readCounter []N.CountFunc
writeCounter []N.CountFunc
}
func (c *CounterConn) Read(p []byte) (n int, err error) {
n, err = c.ExtendedConn.Read(p)
if n > 0 {
for _, counter := range c.readCounter {
counter(int64(n))
}
}
return n, err
}
func (c *CounterConn) ReadBuffer(buffer *buf.Buffer) error {
err := c.ExtendedConn.ReadBuffer(buffer)
if err != nil {
return err
}
if buffer.Len() > 0 {
for _, counter := range c.readCounter {
counter(int64(buffer.Len()))
}
}
return nil
}
func (c *CounterConn) Write(p []byte) (n int, err error) {
n, err = c.ExtendedConn.Write(p)
if n > 0 {
for _, counter := range c.writeCounter {
counter(int64(n))
}
}
return n, err
}
func (c *CounterConn) WriteBuffer(buffer *buf.Buffer) error {
dataLen := int64(buffer.Len())
err := c.ExtendedConn.WriteBuffer(buffer)
if err != nil {
return err
}
if dataLen > 0 {
for _, counter := range c.writeCounter {
counter(dataLen)
}
}
return nil
}
func (c *CounterConn) UnwrapReader() (io.Reader, []N.CountFunc) {
return c.ExtendedConn, c.readCounter
}
func (c *CounterConn) UnwrapWriter() (io.Writer, []N.CountFunc) {
return c.ExtendedConn, c.writeCounter
}
func (c *CounterConn) Upstream() any {
return c.ExtendedConn
}

View file

@ -1,73 +0,0 @@
package bufio
import (
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/atomic"
"github.com/sagernet/sing/common/buf"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
type CounterPacketConn struct {
N.PacketConn
readCounter []N.CountFunc
writeCounter []N.CountFunc
}
func NewInt64CounterPacketConn(conn N.PacketConn, readCounter []*atomic.Int64, writeCounter []*atomic.Int64) *CounterPacketConn {
return &CounterPacketConn{
conn,
common.Map(readCounter, func(it *atomic.Int64) N.CountFunc {
return func(n int64) {
it.Add(n)
}
}),
common.Map(writeCounter, func(it *atomic.Int64) N.CountFunc {
return func(n int64) {
it.Add(n)
}
}),
}
}
func NewCounterPacketConn(conn N.PacketConn, readCounter []N.CountFunc, writeCounter []N.CountFunc) *CounterPacketConn {
return &CounterPacketConn{conn, readCounter, writeCounter}
}
func (c *CounterPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
destination, err = c.PacketConn.ReadPacket(buffer)
if err == nil {
if buffer.Len() > 0 {
for _, counter := range c.readCounter {
counter(int64(buffer.Len()))
}
}
}
return
}
func (c *CounterPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
dataLen := int64(buffer.Len())
err := c.PacketConn.WritePacket(buffer, destination)
if err != nil {
return err
}
if dataLen > 0 {
for _, counter := range c.writeCounter {
counter(dataLen)
}
}
return nil
}
func (c *CounterPacketConn) UnwrapPacketReader() (N.PacketReader, []N.CountFunc) {
return c.PacketConn, c.readCounter
}
func (c *CounterPacketConn) UnwrapPacketWriter() (N.PacketWriter, []N.CountFunc) {
return c.PacketConn, c.writeCounter
}
func (c *CounterPacketConn) Upstream() any {
return c.PacketConn
}

View file

@ -1,23 +0,0 @@
package deadline
import (
"github.com/sagernet/sing/common"
N "github.com/sagernet/sing/common/network"
)
type WithoutReadDeadline interface {
NeedAdditionalReadDeadline() bool
}
func NeedAdditionalReadDeadline(rawReader any) bool {
if deadlineReader, loaded := rawReader.(WithoutReadDeadline); loaded {
return deadlineReader.NeedAdditionalReadDeadline()
}
if upstream, hasUpstream := rawReader.(N.WithUpstreamReader); hasUpstream {
return NeedAdditionalReadDeadline(upstream.UpstreamReader())
}
if upstream, hasUpstream := rawReader.(common.WithUpstream); hasUpstream {
return NeedAdditionalReadDeadline(upstream.Upstream())
}
return false
}

View file

@ -1,61 +0,0 @@
package deadline
import (
"net"
"time"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio"
N "github.com/sagernet/sing/common/network"
)
type Conn struct {
N.ExtendedConn
reader Reader
}
func NewConn(conn net.Conn) N.ExtendedConn {
if deadlineConn, isDeadline := conn.(*Conn); isDeadline {
return deadlineConn
}
return NewSerialConn(&Conn{ExtendedConn: bufio.NewExtendedConn(conn), reader: NewReader(conn)})
}
func NewFallbackConn(conn net.Conn) N.ExtendedConn {
if deadlineConn, isDeadline := conn.(*Conn); isDeadline {
return deadlineConn
}
return NewSerialConn(&Conn{ExtendedConn: bufio.NewExtendedConn(conn), reader: NewFallbackReader(conn)})
}
func (c *Conn) Read(p []byte) (n int, err error) {
return c.reader.Read(p)
}
func (c *Conn) ReadBuffer(buffer *buf.Buffer) error {
return c.reader.ReadBuffer(buffer)
}
func (c *Conn) SetReadDeadline(t time.Time) error {
return c.reader.SetReadDeadline(t)
}
func (c *Conn) ReaderReplaceable() bool {
return c.reader.ReaderReplaceable()
}
func (c *Conn) UpstreamReader() any {
return c.reader.UpstreamReader()
}
func (c *Conn) WriterReplaceable() bool {
return true
}
func (c *Conn) Upstream() any {
return c.ExtendedConn
}
func (c *Conn) NeedAdditionalReadDeadline() bool {
return false
}

View file

@ -1,57 +0,0 @@
package deadline
import (
"net"
"time"
"github.com/sagernet/sing/common/buf"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
type PacketConn struct {
N.NetPacketConn
reader PacketReader
}
func NewPacketConn(conn N.NetPacketConn) N.NetPacketConn {
if deadlineConn, isDeadline := conn.(*PacketConn); isDeadline {
return deadlineConn
}
return NewSerialPacketConn(&PacketConn{NetPacketConn: conn, reader: NewPacketReader(conn)})
}
func NewFallbackPacketConn(conn N.NetPacketConn) N.NetPacketConn {
if deadlineConn, isDeadline := conn.(*PacketConn); isDeadline {
return deadlineConn
}
return NewSerialPacketConn(&PacketConn{NetPacketConn: conn, reader: NewFallbackPacketReader(conn)})
}
func (c *PacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
return c.reader.ReadFrom(p)
}
func (c *PacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
return c.reader.ReadPacket(buffer)
}
func (c *PacketConn) SetReadDeadline(t time.Time) error {
return c.reader.SetReadDeadline(t)
}
func (c *PacketConn) ReaderReplaceable() bool {
return c.reader.ReaderReplaceable()
}
func (c *PacketConn) WriterReplaceable() bool {
return true
}
func (c *PacketConn) Upstream() any {
return c.NetPacketConn
}
func (c *PacketConn) NeedAdditionalReadDeadline() bool {
return false
}

View file

@ -1,159 +0,0 @@
package deadline
import (
"net"
"os"
"time"
"github.com/sagernet/sing/common/atomic"
"github.com/sagernet/sing/common/buf"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
type TimeoutPacketReader interface {
N.NetPacketReader
SetReadDeadline(t time.Time) error
}
type PacketReader interface {
TimeoutPacketReader
N.WithUpstreamReader
N.ReaderWithUpstream
}
type packetReader struct {
TimeoutPacketReader
deadline atomic.TypedValue[time.Time]
pipeDeadline pipeDeadline
result chan *packetReadResult
done chan struct{}
}
type packetReadResult struct {
buffer *buf.Buffer
destination M.Socksaddr
err error
}
func NewPacketReader(timeoutReader TimeoutPacketReader) PacketReader {
return &packetReader{
TimeoutPacketReader: timeoutReader,
pipeDeadline: makePipeDeadline(),
result: make(chan *packetReadResult, 1),
done: makeFilledChan(),
}
}
func (r *packetReader) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
select {
case result := <-r.result:
return r.pipeReturnFrom(result, p)
default:
}
select {
case result := <-r.result:
return r.pipeReturnFrom(result, p)
case <-r.pipeDeadline.wait():
return 0, nil, os.ErrDeadlineExceeded
case <-r.done:
go r.pipeReadFrom(len(p))
}
select {
case result := <-r.result:
return r.pipeReturnFrom(result, p)
case <-r.pipeDeadline.wait():
return 0, nil, os.ErrDeadlineExceeded
}
}
func (r *packetReader) pipeReadFrom(pLen int) {
buffer := buf.NewSize(pLen)
n, addr, err := r.TimeoutPacketReader.ReadFrom(buffer.FreeBytes())
buffer.Truncate(n)
r.result <- &packetReadResult{
buffer: buffer,
destination: M.SocksaddrFromNet(addr),
err: err,
}
r.done <- struct{}{}
}
func (r *packetReader) pipeReturnFrom(result *packetReadResult, p []byte) (n int, addr net.Addr, err error) {
n = copy(p, result.buffer.Bytes())
if result.destination.IsValid() {
if result.destination.IsFqdn() {
addr = result.destination
} else {
addr = result.destination.UDPAddr()
}
}
result.buffer.Advance(n)
if result.buffer.IsEmpty() {
result.buffer.Release()
err = result.err
} else {
r.result <- result
}
return
}
func (r *packetReader) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
select {
case result := <-r.result:
return r.pipeReturnFromBuffer(result, buffer)
default:
}
select {
case result := <-r.result:
return r.pipeReturnFromBuffer(result, buffer)
case <-r.pipeDeadline.wait():
return M.Socksaddr{}, os.ErrDeadlineExceeded
case <-r.done:
go r.pipeReadFrom(buffer.FreeLen())
}
select {
case result := <-r.result:
return r.pipeReturnFromBuffer(result, buffer)
case <-r.pipeDeadline.wait():
return M.Socksaddr{}, os.ErrDeadlineExceeded
}
}
func (r *packetReader) pipeReturnFromBuffer(result *packetReadResult, buffer *buf.Buffer) (M.Socksaddr, error) {
n, _ := buffer.Write(result.buffer.Bytes())
result.buffer.Advance(n)
if !result.buffer.IsEmpty() {
r.result <- result
return result.destination, nil
} else {
result.buffer.Release()
return result.destination, result.err
}
}
func (r *packetReader) SetReadDeadline(t time.Time) error {
r.deadline.Store(t)
r.pipeDeadline.set(t)
return nil
}
func (r *packetReader) ReaderReplaceable() bool {
select {
case <-r.done:
r.done <- struct{}{}
default:
return false
}
select {
case result := <-r.result:
r.result <- result
return false
default:
}
return r.deadline.Load().IsZero()
}
func (r *packetReader) UpstreamReader() any {
return r.TimeoutPacketReader
}

View file

@ -1,101 +0,0 @@
package deadline
import (
"net"
"os"
"time"
"github.com/sagernet/sing/common/atomic"
"github.com/sagernet/sing/common/buf"
M "github.com/sagernet/sing/common/metadata"
)
type fallbackPacketReader struct {
*packetReader
disablePipe atomic.Bool
inRead atomic.Bool
}
func NewFallbackPacketReader(timeoutReader TimeoutPacketReader) PacketReader {
return &fallbackPacketReader{packetReader: NewPacketReader(timeoutReader).(*packetReader)}
}
func (r *fallbackPacketReader) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
select {
case result := <-r.result:
return r.pipeReturnFrom(result, p)
default:
}
select {
case result := <-r.result:
return r.pipeReturnFrom(result, p)
case <-r.pipeDeadline.wait():
return 0, nil, os.ErrDeadlineExceeded
case <-r.done:
if r.disablePipe.Load() {
return r.TimeoutPacketReader.ReadFrom(p)
} else if r.deadline.Load().IsZero() {
r.done <- struct{}{}
r.inRead.Store(true)
defer r.inRead.Store(false)
n, addr, err = r.TimeoutPacketReader.ReadFrom(p)
return
}
go r.pipeReadFrom(len(p))
}
select {
case result := <-r.result:
return r.pipeReturnFrom(result, p)
case <-r.pipeDeadline.wait():
return 0, nil, os.ErrDeadlineExceeded
}
}
func (r *fallbackPacketReader) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
select {
case result := <-r.result:
return r.pipeReturnFromBuffer(result, buffer)
default:
}
select {
case result := <-r.result:
return r.pipeReturnFromBuffer(result, buffer)
case <-r.pipeDeadline.wait():
return M.Socksaddr{}, os.ErrDeadlineExceeded
case <-r.done:
if r.disablePipe.Load() {
return r.TimeoutPacketReader.ReadPacket(buffer)
} else if r.deadline.Load().IsZero() {
r.done <- struct{}{}
r.inRead.Store(true)
defer r.inRead.Store(false)
destination, err = r.TimeoutPacketReader.ReadPacket(buffer)
return
}
go r.pipeReadFrom(buffer.FreeLen())
}
select {
case result := <-r.result:
return r.pipeReturnFromBuffer(result, buffer)
case <-r.pipeDeadline.wait():
return M.Socksaddr{}, os.ErrDeadlineExceeded
}
}
func (r *fallbackPacketReader) SetReadDeadline(t time.Time) error {
if r.disablePipe.Load() {
return r.TimeoutPacketReader.SetReadDeadline(t)
} else if r.inRead.Load() {
r.disablePipe.Store(true)
return r.TimeoutPacketReader.SetReadDeadline(t)
}
return r.packetReader.SetReadDeadline(t)
}
func (r *fallbackPacketReader) ReaderReplaceable() bool {
return r.disablePipe.Load() || r.packetReader.ReaderReplaceable()
}
func (r *fallbackPacketReader) UpstreamReader() any {
return r.packetReader.UpstreamReader()
}

View file

@ -1,84 +0,0 @@
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package deadline
import (
"sync"
"time"
)
// pipeDeadline is an abstraction for handling timeouts.
type pipeDeadline struct {
mu sync.Mutex // Guards timer and cancel
timer *time.Timer
cancel chan struct{} // Must be non-nil
}
func makePipeDeadline() pipeDeadline {
return pipeDeadline{cancel: make(chan struct{})}
}
// set sets the point in time when the deadline will time out.
// A timeout event is signaled by closing the channel returned by waiter.
// Once a timeout has occurred, the deadline can be refreshed by specifying a
// t value in the future.
//
// A zero value for t prevents timeout.
func (d *pipeDeadline) set(t time.Time) {
d.mu.Lock()
defer d.mu.Unlock()
if d.timer != nil && !d.timer.Stop() {
<-d.cancel // Wait for the timer callback to finish and close cancel
}
d.timer = nil
// Time is zero, then there is no deadline.
closed := isClosedChan(d.cancel)
if t.IsZero() {
if closed {
d.cancel = make(chan struct{})
}
return
}
// Time in the future, setup a timer to cancel in the future.
if dur := time.Until(t); dur > 0 {
if closed {
d.cancel = make(chan struct{})
}
d.timer = time.AfterFunc(dur, func() {
close(d.cancel)
})
return
}
// Time in the past, so close immediately.
if !closed {
close(d.cancel)
}
}
// wait returns a channel that is closed when the deadline is exceeded.
func (d *pipeDeadline) wait() chan struct{} {
d.mu.Lock()
defer d.mu.Unlock()
return d.cancel
}
func isClosedChan(c <-chan struct{}) bool {
select {
case <-c:
return true
default:
return false
}
}
func makeFilledChan() chan struct{} {
ch := make(chan struct{}, 1)
ch <- struct{}{}
return ch
}

View file

@ -1,152 +0,0 @@
package deadline
import (
"io"
"os"
"time"
"github.com/sagernet/sing/common/atomic"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio"
N "github.com/sagernet/sing/common/network"
)
type TimeoutReader interface {
io.Reader
SetReadDeadline(t time.Time) error
}
type Reader interface {
N.ExtendedReader
TimeoutReader
N.WithUpstreamReader
N.ReaderWithUpstream
}
type reader struct {
N.ExtendedReader
timeoutReader TimeoutReader
deadline atomic.TypedValue[time.Time]
pipeDeadline pipeDeadline
result chan *readResult
done chan struct{}
}
type readResult struct {
buffer *buf.Buffer
err error
}
func NewReader(timeoutReader TimeoutReader) Reader {
return &reader{
ExtendedReader: bufio.NewExtendedReader(timeoutReader),
timeoutReader: timeoutReader,
pipeDeadline: makePipeDeadline(),
result: make(chan *readResult, 1),
done: makeFilledChan(),
}
}
func (r *reader) Read(p []byte) (n int, err error) {
select {
case result := <-r.result:
return r.pipeReturn(result, p)
default:
}
select {
case result := <-r.result:
return r.pipeReturn(result, p)
case <-r.pipeDeadline.wait():
return 0, os.ErrDeadlineExceeded
case <-r.done:
go r.pipeRead(len(p))
}
select {
case result := <-r.result:
return r.pipeReturn(result, p)
case <-r.pipeDeadline.wait():
return 0, os.ErrDeadlineExceeded
}
}
func (r *reader) pipeReturn(result *readResult, p []byte) (n int, err error) {
n = copy(p, result.buffer.Bytes())
result.buffer.Advance(n)
if result.buffer.IsEmpty() {
result.buffer.Release()
err = result.err
} else {
r.result <- result
}
return
}
func (r *reader) pipeRead(pLen int) {
buffer := buf.NewSize(pLen)
_, err := buffer.ReadOnceFrom(r.ExtendedReader)
r.result <- &readResult{
buffer: buffer,
err: err,
}
r.done <- struct{}{}
}
func (r *reader) ReadBuffer(buffer *buf.Buffer) error {
select {
case result := <-r.result:
return r.pipeReturnBuffer(result, buffer)
default:
}
select {
case result := <-r.result:
return r.pipeReturnBuffer(result, buffer)
case <-r.pipeDeadline.wait():
return os.ErrDeadlineExceeded
case <-r.done:
go r.pipeRead(buffer.FreeLen())
}
select {
case result := <-r.result:
return r.pipeReturnBuffer(result, buffer)
case <-r.pipeDeadline.wait():
return os.ErrDeadlineExceeded
}
}
func (r *reader) pipeReturnBuffer(result *readResult, buffer *buf.Buffer) error {
n, _ := buffer.Write(result.buffer.Bytes())
result.buffer.Advance(n)
if !result.buffer.IsEmpty() {
r.result <- result
return nil
} else {
result.buffer.Release()
return result.err
}
}
func (r *reader) SetReadDeadline(t time.Time) error {
r.deadline.Store(t)
r.pipeDeadline.set(t)
return nil
}
func (r *reader) ReaderReplaceable() bool {
select {
case <-r.done:
r.done <- struct{}{}
default:
return false
}
select {
case result := <-r.result:
r.result <- result
return false
default:
}
return r.deadline.Load().IsZero()
}
func (r *reader) UpstreamReader() any {
return r.ExtendedReader
}

View file

@ -1,98 +0,0 @@
package deadline
import (
"os"
"time"
"github.com/sagernet/sing/common/atomic"
"github.com/sagernet/sing/common/buf"
)
type fallbackReader struct {
*reader
disablePipe atomic.Bool
inRead atomic.Bool
}
func NewFallbackReader(timeoutReader TimeoutReader) Reader {
return &fallbackReader{reader: NewReader(timeoutReader).(*reader)}
}
func (r *fallbackReader) Read(p []byte) (n int, err error) {
select {
case result := <-r.result:
return r.pipeReturn(result, p)
default:
}
select {
case result := <-r.result:
return r.pipeReturn(result, p)
case <-r.pipeDeadline.wait():
return 0, os.ErrDeadlineExceeded
case <-r.done:
if r.disablePipe.Load() {
return r.ExtendedReader.Read(p)
} else if r.deadline.Load().IsZero() {
r.done <- struct{}{}
r.inRead.Store(true)
defer r.inRead.Store(false)
n, err = r.ExtendedReader.Read(p)
return
}
go r.pipeRead(len(p))
}
select {
case result := <-r.result:
return r.pipeReturn(result, p)
case <-r.pipeDeadline.wait():
return 0, os.ErrDeadlineExceeded
}
}
func (r *fallbackReader) ReadBuffer(buffer *buf.Buffer) error {
select {
case result := <-r.result:
return r.pipeReturnBuffer(result, buffer)
default:
}
select {
case result := <-r.result:
return r.pipeReturnBuffer(result, buffer)
case <-r.pipeDeadline.wait():
return os.ErrDeadlineExceeded
case <-r.done:
if r.disablePipe.Load() {
return r.ExtendedReader.ReadBuffer(buffer)
} else if r.deadline.Load().IsZero() {
r.done <- struct{}{}
r.inRead.Store(true)
defer r.inRead.Store(false)
return r.ExtendedReader.ReadBuffer(buffer)
}
go r.pipeRead(buffer.FreeLen())
}
select {
case result := <-r.result:
return r.pipeReturnBuffer(result, buffer)
case <-r.pipeDeadline.wait():
return os.ErrDeadlineExceeded
}
}
func (r *fallbackReader) SetReadDeadline(t time.Time) error {
if r.disablePipe.Load() {
return r.timeoutReader.SetReadDeadline(t)
} else if r.inRead.Load() {
r.disablePipe.Store(true)
return r.timeoutReader.SetReadDeadline(t)
}
return r.reader.SetReadDeadline(t)
}
func (r *fallbackReader) ReaderReplaceable() bool {
return r.disablePipe.Load() || r.reader.ReaderReplaceable()
}
func (r *fallbackReader) UpstreamReader() any {
return r.reader.UpstreamReader()
}

View file

@ -1,75 +0,0 @@
package deadline
import (
"net"
"sync"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/debug"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
type SerialConn struct {
N.ExtendedConn
access sync.Mutex
}
func NewSerialConn(conn N.ExtendedConn) N.ExtendedConn {
if !debug.Enabled {
return conn
}
return &SerialConn{ExtendedConn: conn}
}
func (c *SerialConn) Read(p []byte) (n int, err error) {
if !c.access.TryLock() {
panic("concurrent read on deadline conn")
}
defer c.access.Unlock()
return c.ExtendedConn.Read(p)
}
func (c *SerialConn) ReadBuffer(buffer *buf.Buffer) error {
if !c.access.TryLock() {
panic("concurrent read on deadline conn")
}
defer c.access.Unlock()
return c.ExtendedConn.ReadBuffer(buffer)
}
func (c *SerialConn) Upstream() any {
return c.ExtendedConn
}
type SerialPacketConn struct {
N.NetPacketConn
access sync.Mutex
}
func NewSerialPacketConn(conn N.NetPacketConn) N.NetPacketConn {
if !debug.Enabled {
return conn
}
return &SerialPacketConn{NetPacketConn: conn}
}
func (c *SerialPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
if !c.access.TryLock() {
panic("concurrent read on deadline conn")
}
defer c.access.Unlock()
return c.NetPacketConn.ReadFrom(p)
}
func (c *SerialPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
if !c.access.TryLock() {
panic("concurrent read on deadline conn")
}
defer c.access.Unlock()
return c.NetPacketConn.ReadPacket(buffer)
}
func (c *SerialPacketConn) Upstream() any {
return c.NetPacketConn
}

View file

@ -1,104 +0,0 @@
package bufio
import (
"net"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
var _ N.NetPacketConn = (*FallbackPacketConn)(nil)
type FallbackPacketConn struct {
N.PacketConn
writer N.NetPacketWriter
}
func NewNetPacketConn(conn N.PacketConn) N.NetPacketConn {
if packetConn, loaded := conn.(N.NetPacketConn); loaded {
return packetConn
}
return &FallbackPacketConn{
PacketConn: conn,
writer: NewNetPacketWriter(conn),
}
}
func (c *FallbackPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
buffer := buf.With(p)
destination, err := c.ReadPacket(buffer)
if err != nil {
return
}
n = buffer.Len()
if buffer.Start() > 0 {
copy(p, buffer.Bytes())
}
addr = destination.UDPAddr()
return
}
func (c *FallbackPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
return c.writer.WriteTo(p, addr)
}
func (c *FallbackPacketConn) ReaderReplaceable() bool {
return true
}
func (c *FallbackPacketConn) WriterReplaceable() bool {
return true
}
func (c *FallbackPacketConn) Upstream() any {
return c.PacketConn
}
func (c *FallbackPacketConn) UpstreamWriter() any {
return c.writer
}
var _ N.NetPacketWriter = (*FallbackPacketWriter)(nil)
type FallbackPacketWriter struct {
N.PacketWriter
frontHeadroom int
rearHeadroom int
}
func NewNetPacketWriter(writer N.PacketWriter) N.NetPacketWriter {
if packetWriter, loaded := writer.(N.NetPacketWriter); loaded {
return packetWriter
}
return &FallbackPacketWriter{
PacketWriter: writer,
frontHeadroom: N.CalculateFrontHeadroom(writer),
rearHeadroom: N.CalculateRearHeadroom(writer),
}
}
func (c *FallbackPacketWriter) WriteTo(p []byte, addr net.Addr) (n int, err error) {
if c.frontHeadroom > 0 || c.rearHeadroom > 0 {
buffer := buf.NewSize(len(p) + c.frontHeadroom + c.rearHeadroom)
buffer.Resize(c.frontHeadroom, 0)
common.Must1(buffer.Write(p))
err = c.PacketWriter.WritePacket(buffer, M.SocksaddrFromNet(addr))
} else {
err = c.PacketWriter.WritePacket(buf.As(p), M.SocksaddrFromNet(addr))
}
if err != nil {
return
}
n = len(p)
return
}
func (c *FallbackPacketWriter) WriterReplaceable() bool {
return true
}
func (c *FallbackPacketWriter) Upstream() any {
return c.PacketWriter
}

View file

@ -10,14 +10,12 @@ import (
N "github.com/sagernet/sing/common/network"
)
// Deprecated: bad usage
func ReadBuffer(reader N.ExtendedReader, buffer *buf.Buffer) (n int, err error) {
n, err = reader.Read(buffer.FreeBytes())
buffer.Truncate(n)
return
}
// Deprecated: bad usage
func ReadPacket(reader N.PacketReader, buffer *buf.Buffer) (n int, addr net.Addr, err error) {
startLen := buffer.Len()
addr, err = reader.ReadPacket(buffer)
@ -25,45 +23,6 @@ func ReadPacket(reader N.PacketReader, buffer *buf.Buffer) (n int, addr net.Addr
return
}
func ReadBufferSize(reader io.Reader, bufferSize int) (buffer *buf.Buffer, err error) {
readWaiter, isReadWaiter := CreateReadWaiter(reader)
if isReadWaiter {
readWaiter.InitializeReadWaiter(N.ReadWaitOptions{
MTU: bufferSize,
})
return readWaiter.WaitReadBuffer()
}
buffer = buf.NewSize(bufferSize)
if extendedReader, isExtendedReader := reader.(N.ExtendedReader); isExtendedReader {
err = extendedReader.ReadBuffer(buffer)
} else {
_, err = buffer.ReadOnceFrom(reader)
}
if err != nil {
buffer.Release()
buffer = nil
}
return
}
func ReadPacketSize(reader N.PacketReader, packetSize int) (buffer *buf.Buffer, destination M.Socksaddr, err error) {
readWaiter, isReadWaiter := CreatePacketReadWaiter(reader)
if isReadWaiter {
readWaiter.InitializeReadWaiter(N.ReadWaitOptions{
MTU: packetSize,
})
buffer, destination, err = readWaiter.WaitReadPacket()
return
}
buffer = buf.NewSize(packetSize)
destination, err = reader.ReadPacket(buffer)
if err != nil {
buffer.Release()
buffer = nil
}
return
}
func Write(writer io.Writer, data []byte) (n int, err error) {
if extendedWriter, isExtended := writer.(N.ExtendedWriter); isExtended {
return WriteBuffer(extendedWriter, buf.As(data))
@ -76,7 +35,13 @@ func WriteBuffer(writer N.ExtendedWriter, buffer *buf.Buffer) (n int, err error)
frontHeadroom := N.CalculateFrontHeadroom(writer)
rearHeadroom := N.CalculateRearHeadroom(writer)
if frontHeadroom > buffer.Start() || rearHeadroom > buffer.FreeLen() {
newBuffer := buf.NewSize(buffer.Len() + frontHeadroom + rearHeadroom)
bufferSize := N.CalculateMTU(nil, writer)
if bufferSize > 0 {
bufferSize += frontHeadroom + rearHeadroom
} else {
bufferSize = buf.BufferSize
}
newBuffer := buf.NewSize(bufferSize)
newBuffer.Resize(frontHeadroom, 0)
common.Must1(newBuffer.Write(buffer.Bytes()))
buffer.Release()
@ -102,7 +67,13 @@ func WritePacketBuffer(writer N.PacketWriter, buffer *buf.Buffer, destination M.
frontHeadroom := N.CalculateFrontHeadroom(writer)
rearHeadroom := N.CalculateRearHeadroom(writer)
if frontHeadroom > buffer.Start() || rearHeadroom > buffer.FreeLen() {
newBuffer := buf.NewSize(buffer.Len() + frontHeadroom + rearHeadroom)
bufferSize := N.CalculateMTU(nil, writer)
if bufferSize > 0 {
bufferSize += frontHeadroom + rearHeadroom
} else {
bufferSize = buf.BufferSize
}
newBuffer := buf.NewSize(bufferSize)
newBuffer.Resize(frontHeadroom, 0)
common.Must1(newBuffer.Write(buffer.Bytes()))
buffer.Release()

View file

@ -1,212 +0,0 @@
package bufio
import (
"net"
"net/netip"
"github.com/sagernet/sing/common/buf"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
type NATPacketConn interface {
N.NetPacketConn
UpdateDestination(destinationAddress netip.Addr)
}
func NewUnidirectionalNATPacketConn(conn N.NetPacketConn, origin M.Socksaddr, destination M.Socksaddr) NATPacketConn {
return &unidirectionalNATPacketConn{
NetPacketConn: conn,
origin: socksaddrWithoutPort(origin),
destination: socksaddrWithoutPort(destination),
}
}
func NewNATPacketConn(conn N.NetPacketConn, origin M.Socksaddr, destination M.Socksaddr) NATPacketConn {
return &bidirectionalNATPacketConn{
NetPacketConn: conn,
origin: socksaddrWithoutPort(origin),
destination: socksaddrWithoutPort(destination),
}
}
func NewDestinationNATPacketConn(conn N.NetPacketConn, origin M.Socksaddr, destination M.Socksaddr) NATPacketConn {
return &destinationNATPacketConn{
NetPacketConn: conn,
origin: origin,
destination: destination,
}
}
type unidirectionalNATPacketConn struct {
N.NetPacketConn
origin M.Socksaddr
destination M.Socksaddr
}
func (c *unidirectionalNATPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
destination := M.SocksaddrFromNet(addr)
if socksaddrWithoutPort(destination) == c.destination {
destination = M.Socksaddr{
Addr: c.origin.Addr,
Fqdn: c.origin.Fqdn,
Port: destination.Port,
}
}
return c.NetPacketConn.WriteTo(p, destination.UDPAddr())
}
func (c *unidirectionalNATPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
if socksaddrWithoutPort(destination) == c.destination {
destination = M.Socksaddr{
Addr: c.origin.Addr,
Fqdn: c.origin.Fqdn,
Port: destination.Port,
}
}
return c.NetPacketConn.WritePacket(buffer, destination)
}
func (c *unidirectionalNATPacketConn) UpdateDestination(destinationAddress netip.Addr) {
c.destination = M.SocksaddrFrom(destinationAddress, c.destination.Port)
}
func (c *unidirectionalNATPacketConn) RemoteAddr() net.Addr {
return c.destination.UDPAddr()
}
func (c *unidirectionalNATPacketConn) Upstream() any {
return c.NetPacketConn
}
type bidirectionalNATPacketConn struct {
N.NetPacketConn
origin M.Socksaddr
destination M.Socksaddr
}
func (c *bidirectionalNATPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
n, addr, err = c.NetPacketConn.ReadFrom(p)
if err != nil {
return
}
destination := M.SocksaddrFromNet(addr)
if socksaddrWithoutPort(destination) == c.origin {
destination = M.Socksaddr{
Addr: c.destination.Addr,
Fqdn: c.destination.Fqdn,
Port: destination.Port,
}
}
addr = destination.UDPAddr()
return
}
func (c *bidirectionalNATPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
destination := M.SocksaddrFromNet(addr)
if socksaddrWithoutPort(destination) == c.destination {
destination = M.Socksaddr{
Addr: c.origin.Addr,
Fqdn: c.origin.Fqdn,
Port: destination.Port,
}
}
return c.NetPacketConn.WriteTo(p, destination.UDPAddr())
}
func (c *bidirectionalNATPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
destination, err = c.NetPacketConn.ReadPacket(buffer)
if err != nil {
return
}
if socksaddrWithoutPort(destination) == c.origin {
destination = M.Socksaddr{
Addr: c.destination.Addr,
Fqdn: c.destination.Fqdn,
Port: destination.Port,
}
}
return
}
func (c *bidirectionalNATPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
if socksaddrWithoutPort(destination) == c.destination {
destination = M.Socksaddr{
Addr: c.origin.Addr,
Fqdn: c.origin.Fqdn,
Port: destination.Port,
}
}
return c.NetPacketConn.WritePacket(buffer, destination)
}
func (c *bidirectionalNATPacketConn) UpdateDestination(destinationAddress netip.Addr) {
c.destination = M.SocksaddrFrom(destinationAddress, c.destination.Port)
}
func (c *bidirectionalNATPacketConn) Upstream() any {
return c.NetPacketConn
}
func (c *bidirectionalNATPacketConn) RemoteAddr() net.Addr {
return c.destination.UDPAddr()
}
type destinationNATPacketConn struct {
N.NetPacketConn
origin M.Socksaddr
destination M.Socksaddr
}
func (c *destinationNATPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
n, addr, err = c.NetPacketConn.ReadFrom(p)
if err != nil {
return
}
if M.SocksaddrFromNet(addr) == c.origin {
addr = c.destination.UDPAddr()
}
return
}
func (c *destinationNATPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
if M.SocksaddrFromNet(addr) == c.destination {
addr = c.origin.UDPAddr()
}
return c.NetPacketConn.WriteTo(p, addr)
}
func (c *destinationNATPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
destination, err = c.NetPacketConn.ReadPacket(buffer)
if err != nil {
return
}
if destination == c.origin {
destination = c.destination
}
return
}
func (c *destinationNATPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
if destination == c.destination {
destination = c.origin
}
return c.NetPacketConn.WritePacket(buffer, destination)
}
func (c *destinationNATPacketConn) UpdateDestination(destinationAddress netip.Addr) {
c.destination = M.SocksaddrFrom(destinationAddress, c.destination.Port)
}
func (c *destinationNATPacketConn) Upstream() any {
return c.NetPacketConn
}
func (c *destinationNATPacketConn) RemoteAddr() net.Addr {
return c.destination.UDPAddr()
}
func socksaddrWithoutPort(destination M.Socksaddr) M.Socksaddr {
destination.Port = 0
return destination
}

View file

@ -1,39 +0,0 @@
package bufio
import (
"github.com/sagernet/sing/common/buf"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
func (c *bidirectionalNATPacketConn) CreatePacketReadWaiter() (N.PacketReadWaiter, bool) {
waiter, created := CreatePacketReadWaiter(c.NetPacketConn)
if !created {
return nil, false
}
return &waitBidirectionalNATPacketConn{c, waiter}, true
}
type waitBidirectionalNATPacketConn struct {
*bidirectionalNATPacketConn
readWaiter N.PacketReadWaiter
}
func (c *waitBidirectionalNATPacketConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
return c.readWaiter.InitializeReadWaiter(options)
}
func (c *waitBidirectionalNATPacketConn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
buffer, destination, err = c.readWaiter.WaitReadPacket()
if err != nil {
return
}
if socksaddrWithoutPort(destination) == c.origin {
destination = M.Socksaddr{
Addr: c.destination.Addr,
Fqdn: c.destination.Fqdn,
Port: destination.Port,
}
}
return
}

View file

@ -1,277 +0,0 @@
package bufio
import (
"context"
"crypto/md5"
"crypto/rand"
"errors"
"io"
"net"
"sync"
"testing"
"time"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/common/task"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TCPPipe(t *testing.T) (net.Conn, net.Conn) {
listener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
var (
group task.Group
serverConn net.Conn
clientConn net.Conn
)
group.Append0(func(ctx context.Context) error {
var serverErr error
serverConn, serverErr = listener.Accept()
return serverErr
})
group.Append0(func(ctx context.Context) error {
var clientErr error
clientConn, clientErr = net.Dial("tcp", listener.Addr().String())
return clientErr
})
err = group.Run(context.Background())
require.NoError(t, err)
listener.Close()
t.Cleanup(func() {
serverConn.Close()
clientConn.Close()
})
return serverConn, clientConn
}
func UDPPipe(t *testing.T) (net.PacketConn, net.PacketConn, M.Socksaddr) {
serverConn, err := net.ListenPacket("udp", "127.0.0.1:0")
require.NoError(t, err)
clientConn, err := net.ListenPacket("udp", "127.0.0.1:0")
require.NoError(t, err)
return serverConn, clientConn, M.SocksaddrFromNet(clientConn.LocalAddr())
}
func Timeout(t *testing.T) context.CancelFunc {
ctx, cancel := context.WithCancel(context.Background())
go func() {
select {
case <-ctx.Done():
return
case <-time.After(5 * time.Second):
t.Error("timeout")
}
}()
return cancel
}
type hashPair struct {
sendHash map[int][]byte
recvHash map[int][]byte
}
func newLargeDataPair() (chan hashPair, chan hashPair, func(t *testing.T) error) {
pingCh := make(chan hashPair)
pongCh := make(chan hashPair)
test := func(t *testing.T) error {
defer close(pingCh)
defer close(pongCh)
pingOpen := false
pongOpen := false
var serverPair hashPair
var clientPair hashPair
for {
if pingOpen && pongOpen {
break
}
select {
case serverPair, pingOpen = <-pingCh:
assert.True(t, pingOpen)
case clientPair, pongOpen = <-pongCh:
assert.True(t, pongOpen)
case <-time.After(10 * time.Second):
return errors.New("timeout")
}
}
assert.Equal(t, serverPair.recvHash, clientPair.sendHash)
assert.Equal(t, serverPair.sendHash, clientPair.recvHash)
return nil
}
return pingCh, pongCh, test
}
func TCPTest(t *testing.T, inputConn net.Conn, outputConn net.Conn) error {
times := 100
chunkSize := int64(64 * 1024)
pingCh, pongCh, test := newLargeDataPair()
writeRandData := func(conn net.Conn) (map[int][]byte, error) {
buf := make([]byte, chunkSize)
hashMap := map[int][]byte{}
for i := 0; i < times; i++ {
if _, err := rand.Read(buf[1:]); err != nil {
return nil, err
}
buf[0] = byte(i)
hash := md5.Sum(buf)
hashMap[i] = hash[:]
if _, err := conn.Write(buf); err != nil {
return nil, err
}
}
return hashMap, nil
}
go func() {
hashMap := map[int][]byte{}
buf := make([]byte, chunkSize)
for i := 0; i < times; i++ {
_, err := io.ReadFull(outputConn, buf)
if err != nil {
t.Log(err.Error())
return
}
hash := md5.Sum(buf)
hashMap[int(buf[0])] = hash[:]
}
sendHash, err := writeRandData(outputConn)
if err != nil {
t.Log(err.Error())
return
}
pingCh <- hashPair{
sendHash: sendHash,
recvHash: hashMap,
}
}()
go func() {
sendHash, err := writeRandData(inputConn)
if err != nil {
t.Log(err.Error())
return
}
hashMap := map[int][]byte{}
buf := make([]byte, chunkSize)
for i := 0; i < times; i++ {
_, err = io.ReadFull(inputConn, buf)
if err != nil {
t.Log(err.Error())
return
}
hash := md5.Sum(buf)
hashMap[int(buf[0])] = hash[:]
}
pongCh <- hashPair{
sendHash: sendHash,
recvHash: hashMap,
}
}()
return test(t)
}
func UDPTest(t *testing.T, inputConn net.PacketConn, outputConn net.PacketConn, outputAddr M.Socksaddr) error {
rAddr := outputAddr.UDPAddr()
times := 50
chunkSize := 9000
pingCh, pongCh, test := newLargeDataPair()
writeRandData := func(pc net.PacketConn, addr net.Addr) (map[int][]byte, error) {
hashMap := map[int][]byte{}
mux := sync.Mutex{}
for i := 0; i < times; i++ {
buf := make([]byte, chunkSize)
if _, err := rand.Read(buf[1:]); err != nil {
t.Log(err.Error())
continue
}
buf[0] = byte(i)
hash := md5.Sum(buf)
mux.Lock()
hashMap[i] = hash[:]
mux.Unlock()
if _, err := pc.WriteTo(buf, addr); err != nil {
t.Log(err.Error())
}
time.Sleep(10 * time.Millisecond)
}
return hashMap, nil
}
go func() {
var (
lAddr net.Addr
err error
)
hashMap := map[int][]byte{}
buf := make([]byte, 64*1024)
for i := 0; i < times; i++ {
_, lAddr, err = outputConn.ReadFrom(buf)
if err != nil {
t.Log(err.Error())
return
}
hash := md5.Sum(buf[:chunkSize])
hashMap[int(buf[0])] = hash[:]
}
sendHash, err := writeRandData(outputConn, lAddr)
if err != nil {
t.Log(err.Error())
return
}
pingCh <- hashPair{
sendHash: sendHash,
recvHash: hashMap,
}
}()
go func() {
sendHash, err := writeRandData(inputConn, rAddr)
if err != nil {
t.Log(err.Error())
return
}
hashMap := map[int][]byte{}
buf := make([]byte, 64*1024)
for i := 0; i < times; i++ {
_, _, err := inputConn.ReadFrom(buf)
if err != nil {
t.Log(err.Error())
return
}
hash := md5.Sum(buf[:chunkSize])
hashMap[int(buf[0])] = hash[:]
}
pongCh <- hashPair{
sendHash: sendHash,
recvHash: hashMap,
}
}()
return test(t)
}

127
common/bufio/once.go Normal file
View file

@ -0,0 +1,127 @@
package bufio
import (
"io"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
N "github.com/sagernet/sing/common/network"
)
func CopyTimes(dst io.Writer, src io.Reader, times int) (n int64, err error) {
return CopyExtendedTimes(NewExtendedWriter(N.UnwrapWriter(dst)), NewExtendedReader(N.UnwrapReader(src)), times)
}
func CopyExtendedTimes(dst N.ExtendedWriter, src N.ExtendedReader, times int) (n int64, err error) {
frontHeadroom := N.CalculateFrontHeadroom(dst)
rearHeadroom := N.CalculateRearHeadroom(dst)
bufferSize := N.CalculateMTU(src, dst)
if bufferSize > 0 {
bufferSize += frontHeadroom + rearHeadroom
} else {
bufferSize = buf.BufferSize
}
dstUnsafe := N.IsUnsafeWriter(dst)
var buffer *buf.Buffer
if !dstUnsafe {
_buffer := buf.StackNewSize(bufferSize)
defer common.KeepAlive(_buffer)
buffer = common.Dup(_buffer)
defer buffer.Release()
buffer.IncRef()
defer buffer.DecRef()
}
notFirstTime := true
for i := 0; i < times; i++ {
if dstUnsafe {
buffer = buf.NewSize(bufferSize)
}
readBufferRaw := buffer.Slice()
readBuffer := buf.With(readBufferRaw[:cap(readBufferRaw)-rearHeadroom])
readBuffer.Resize(frontHeadroom, 0)
err = src.ReadBuffer(readBuffer)
if err != nil {
buffer.Release()
if !notFirstTime {
err = N.HandshakeFailure(dst, err)
}
return
}
dataLen := readBuffer.Len()
buffer.Resize(readBuffer.Start(), dataLen)
err = dst.WriteBuffer(buffer)
if err != nil {
buffer.Release()
return
}
n += int64(dataLen)
notFirstTime = true
}
return
}
type ReadFromWriter interface {
io.ReaderFrom
io.Writer
}
func ReadFrom0(readerFrom ReadFromWriter, reader io.Reader) (n int64, err error) {
n, err = CopyTimes(readerFrom, reader, 1)
if err != nil {
return
}
var rn int64
rn, err = readerFrom.ReadFrom(reader)
if err != nil {
return
}
n += rn
return
}
func ReadFromN(readerFrom ReadFromWriter, reader io.Reader, times int) (n int64, err error) {
n, err = CopyTimes(readerFrom, reader, times)
if err != nil {
return
}
var rn int64
rn, err = readerFrom.ReadFrom(reader)
if err != nil {
return
}
n += rn
return
}
type WriteToReader interface {
io.WriterTo
io.Reader
}
func WriteTo0(writerTo WriteToReader, writer io.Writer) (n int64, err error) {
n, err = CopyTimes(writer, writerTo, 1)
if err != nil {
return
}
var wn int64
wn, err = writerTo.WriteTo(writer)
if err != nil {
return
}
n += wn
return
}
func WriteToN(writerTo WriteToReader, writer io.Writer, times int) (n int64, err error) {
n, err = CopyTimes(writer, writerTo, times)
if err != nil {
return
}
var wn int64
wn, err = writerTo.WriteTo(writer)
if err != nil {
return
}
n += wn
return
}

View file

@ -1,79 +0,0 @@
package bufio
import (
"syscall"
E "github.com/sagernet/sing/common/exceptions"
N "github.com/sagernet/sing/common/network"
"golang.org/x/sys/unix"
)
const maxSpliceSize = 1 << 20
func splice(source syscall.RawConn, destination syscall.RawConn, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handed bool, n int64, err error) {
handed = true
var pipeFDs [2]int
err = unix.Pipe2(pipeFDs[:], syscall.O_CLOEXEC|syscall.O_NONBLOCK)
if err != nil {
return
}
defer unix.Close(pipeFDs[0])
defer unix.Close(pipeFDs[1])
_, _ = unix.FcntlInt(uintptr(pipeFDs[0]), unix.F_SETPIPE_SZ, maxSpliceSize)
var readN int
var readErr error
var writeSize int
var writeErr error
readFunc := func(fd uintptr) (done bool) {
p0, p1 := unix.Splice(int(fd), nil, pipeFDs[1], nil, maxSpliceSize, unix.SPLICE_F_NONBLOCK)
readN = int(p0)
readErr = p1
return readErr != unix.EAGAIN
}
writeFunc := func(fd uintptr) (done bool) {
for writeSize > 0 {
p0, p1 := unix.Splice(pipeFDs[0], nil, int(fd), nil, writeSize, unix.SPLICE_F_NONBLOCK|unix.SPLICE_F_MOVE)
writeN := int(p0)
writeErr = p1
if writeErr != nil {
return writeErr != unix.EAGAIN
}
writeSize -= writeN
}
return true
}
for {
err = source.Read(readFunc)
if err != nil {
readErr = err
}
if readErr != nil {
if readErr == unix.EINVAL || readErr == unix.ENOSYS {
handed = false
return
}
err = E.Cause(readErr, "splice read")
return
}
if readN == 0 {
return
}
writeSize = readN
err = destination.Write(writeFunc)
if err != nil {
writeErr = err
}
if writeErr != nil {
err = E.Cause(writeErr, "splice write")
return
}
for _, readCounter := range readCounters {
readCounter(int64(readN))
}
for _, writeCounter := range writeCounters {
writeCounter(int64(readN))
}
}
}

View file

@ -1,13 +0,0 @@
//go:build !linux
package bufio
import (
"syscall"
N "github.com/sagernet/sing/common/network"
)
func splice(source syscall.RawConn, destination syscall.RawConn, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handed bool, n int64, err error) {
return
}

View file

@ -1,5 +0,0 @@
package bufio
//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go syscall_windows.go
//sys recv(s windows.Handle, buf []byte, flags int32) (n int32, err error) [failretval == -1] = ws2_32.recv

View file

@ -12,7 +12,7 @@ import (
)
func NewVectorisedWriter(writer io.Writer) N.VectorisedWriter {
if vectorisedWriter, ok := CreateVectorisedWriter(N.UnwrapWriter(writer)); ok {
if vectorisedWriter, ok := CreateVectorisedWriter(writer); ok {
return vectorisedWriter
}
return &BufferedVectorisedWriter{upstream: writer}
@ -33,10 +33,10 @@ func CreateVectorisedWriter(writer any) (N.VectorisedWriter, bool) {
case syscall.Conn:
rawConn, err := w.SyscallConn()
if err == nil {
return &SyscallVectorisedWriter{upstream: writer, rawConn: rawConn}, true
return &SyscallVectorisedWriter{writer, rawConn}, true
}
case syscall.RawConn:
return &SyscallVectorisedWriter{upstream: writer, rawConn: w}, true
return &SyscallVectorisedWriter{writer, w}, true
}
return nil, false
}
@ -48,10 +48,10 @@ func CreateVectorisedPacketWriter(writer any) (N.VectorisedPacketWriter, bool) {
case syscall.Conn:
rawConn, err := w.SyscallConn()
if err == nil {
return &SyscallVectorisedPacketWriter{upstream: writer, rawConn: rawConn}, true
return &SyscallVectorisedPacketWriter{writer, rawConn}, true
}
case syscall.RawConn:
return &SyscallVectorisedPacketWriter{upstream: writer, rawConn: w}, true
return &SyscallVectorisedPacketWriter{writer, w}, true
}
return nil, false
}
@ -65,16 +65,13 @@ type BufferedVectorisedWriter struct {
func (w *BufferedVectorisedWriter) WriteVectorised(buffers []*buf.Buffer) error {
defer buf.ReleaseMulti(buffers)
bufferLen := buf.LenMulti(buffers)
if bufferLen == 0 {
return common.Error(w.upstream.Write(nil))
} else if len(buffers) == 1 {
return common.Error(w.upstream.Write(buffers[0].Bytes()))
}
var bufferBytes []byte
if bufferLen > 65535 {
bufferBytes = make([]byte, bufferLen)
} else {
buffer := buf.NewSize(bufferLen)
_buffer := buf.StackNewSize(bufferLen)
defer common.KeepAlive(_buffer)
buffer := common.Dup(_buffer)
defer buffer.Release()
bufferBytes = buffer.FreeBytes()
}
@ -111,7 +108,6 @@ var _ N.VectorisedWriter = (*SyscallVectorisedWriter)(nil)
type SyscallVectorisedWriter struct {
upstream any
rawConn syscall.RawConn
syscallVectorisedWriterFields
}
func (w *SyscallVectorisedWriter) Upstream() any {
@ -127,7 +123,6 @@ var _ N.VectorisedPacketWriter = (*SyscallVectorisedPacketWriter)(nil)
type SyscallVectorisedPacketWriter struct {
upstream any
rawConn syscall.RawConn
syscallVectorisedWriterFields
}
func (w *SyscallVectorisedPacketWriter) Upstream() any {

View file

@ -1,60 +0,0 @@
package bufio
import (
"crypto/rand"
"io"
"testing"
"github.com/stretchr/testify/require"
)
func TestWriteVectorised(t *testing.T) {
t.Parallel()
inputConn, outputConn := TCPPipe(t)
vectorisedWriter, created := CreateVectorisedWriter(inputConn)
require.True(t, created)
require.NotNil(t, vectorisedWriter)
var bufA [1024]byte
var bufB [1024]byte
var bufC [2048]byte
_, err := io.ReadFull(rand.Reader, bufA[:])
require.NoError(t, err)
_, err = io.ReadFull(rand.Reader, bufB[:])
require.NoError(t, err)
copy(bufC[:], bufA[:])
copy(bufC[1024:], bufB[:])
finish := Timeout(t)
_, err = WriteVectorised(vectorisedWriter, [][]byte{bufA[:], bufB[:]})
require.NoError(t, err)
output := make([]byte, 2048)
_, err = io.ReadFull(outputConn, output)
finish()
require.NoError(t, err)
require.Equal(t, bufC[:], output)
}
func TestWriteVectorisedPacket(t *testing.T) {
t.Parallel()
inputConn, outputConn, outputAddr := UDPPipe(t)
vectorisedWriter, created := CreateVectorisedPacketWriter(inputConn)
require.True(t, created)
require.NotNil(t, vectorisedWriter)
var bufA [1024]byte
var bufB [1024]byte
var bufC [2048]byte
_, err := io.ReadFull(rand.Reader, bufA[:])
require.NoError(t, err)
_, err = io.ReadFull(rand.Reader, bufB[:])
require.NoError(t, err)
copy(bufC[:], bufA[:])
copy(bufC[1024:], bufB[:])
finish := Timeout(t)
_, err = WriteVectorisedPacket(vectorisedWriter, [][]byte{bufA[:], bufB[:]}, outputAddr)
require.NoError(t, err)
output := make([]byte, 2048)
n, _, err := outputConn.ReadFrom(output)
finish()
require.NoError(t, err)
require.Equal(t, 2048, n)
require.Equal(t, bufC[:], output)
}

View file

@ -3,8 +3,6 @@
package bufio
import (
"os"
"sync"
"unsafe"
"github.com/sagernet/sing/common/buf"
@ -13,81 +11,49 @@ import (
"golang.org/x/sys/unix"
)
type syscallVectorisedWriterFields struct {
access sync.Mutex
iovecList *[]unix.Iovec
}
func (w *SyscallVectorisedWriter) WriteVectorised(buffers []*buf.Buffer) error {
w.access.Lock()
defer w.access.Unlock()
defer buf.ReleaseMulti(buffers)
var iovecList []unix.Iovec
if w.iovecList != nil {
iovecList = *w.iovecList
iovecList := make([]unix.Iovec, 0, len(buffers))
for _, buffer := range buffers {
var iovec unix.Iovec
iovec.Base = &buffer.Bytes()[0]
iovec.SetLen(buffer.Len())
iovecList = append(iovecList, iovec)
}
iovecList = iovecList[:0]
for index, buffer := range buffers {
iovecList = append(iovecList, unix.Iovec{Base: &buffer.Bytes()[0]})
iovecList[index].SetLen(buffer.Len())
}
if w.iovecList == nil {
w.iovecList = new([]unix.Iovec)
}
*w.iovecList = iovecList // cache
var innerErr unix.Errno
err := w.rawConn.Write(func(fd uintptr) (done bool) {
//nolint:staticcheck
//goland:noinspection GoDeprecation
_, _, innerErr = unix.Syscall(unix.SYS_WRITEV, fd, uintptr(unsafe.Pointer(&iovecList[0])), uintptr(len(iovecList)))
return innerErr != unix.EAGAIN && innerErr != unix.EWOULDBLOCK
})
if innerErr != 0 {
err = os.NewSyscallError("SYS_WRITEV", innerErr)
}
for index := range iovecList {
iovecList[index] = unix.Iovec{}
err = innerErr
}
return err
}
func (w *SyscallVectorisedPacketWriter) WriteVectorisedPacket(buffers []*buf.Buffer, destination M.Socksaddr) error {
w.access.Lock()
defer w.access.Unlock()
defer buf.ReleaseMulti(buffers)
var iovecList []unix.Iovec
if w.iovecList != nil {
iovecList = *w.iovecList
var sockaddr unix.Sockaddr
if destination.IsIPv4() {
sockaddr = &unix.SockaddrInet4{
Port: int(destination.Port),
Addr: destination.Addr.As4(),
}
} else {
sockaddr = &unix.SockaddrInet6{
Port: int(destination.Port),
Addr: destination.Addr.As16(),
}
}
iovecList = iovecList[:0]
for index, buffer := range buffers {
iovecList = append(iovecList, unix.Iovec{Base: &buffer.Bytes()[0]})
iovecList[index].SetLen(buffer.Len())
}
if w.iovecList == nil {
w.iovecList = new([]unix.Iovec)
}
*w.iovecList = iovecList // cache
var innerErr error
err := w.rawConn.Write(func(fd uintptr) (done bool) {
var msg unix.Msghdr
name, nameLen := ToSockaddr(destination.AddrPort())
msg.Name = (*byte)(name)
msg.Namelen = nameLen
if len(iovecList) > 0 {
msg.Iov = &iovecList[0]
msg.SetIovlen(len(iovecList))
}
_, innerErr = sendmsg(int(fd), &msg, 0)
_, innerErr = unix.SendmsgBuffers(int(fd), buf.ToSliceMulti(buffers), nil, sockaddr, 0)
return innerErr != unix.EAGAIN && innerErr != unix.EWOULDBLOCK
})
if innerErr != nil {
err = innerErr
}
for index := range iovecList {
iovecList[index] = unix.Iovec{}
}
return err
}
//go:linkname sendmsg golang.org/x/sys/unix.sendmsg
func sendmsg(s int, msg *unix.Msghdr, flags int) (n int, err error)

View file

@ -1,93 +1,62 @@
package bufio
import (
"sync"
"github.com/sagernet/sing/common/buf"
M "github.com/sagernet/sing/common/metadata"
"golang.org/x/sys/windows"
)
type syscallVectorisedWriterFields struct {
access sync.Mutex
iovecList *[]windows.WSABuf
}
func (w *SyscallVectorisedWriter) WriteVectorised(buffers []*buf.Buffer) error {
w.access.Lock()
defer w.access.Unlock()
defer buf.ReleaseMulti(buffers)
var iovecList []windows.WSABuf
if w.iovecList != nil {
iovecList = *w.iovecList
}
iovecList = iovecList[:0]
for _, buffer := range buffers {
iovecList = append(iovecList, windows.WSABuf{
Buf: &buffer.Bytes()[0],
iovecList := make([]*windows.WSABuf, len(buffers))
for i, buffer := range buffers {
iovecList[i] = &windows.WSABuf{
Len: uint32(buffer.Len()),
})
Buf: &buffer.Bytes()[0],
}
}
if w.iovecList == nil {
w.iovecList = new([]windows.WSABuf)
}
*w.iovecList = iovecList // cache
var n uint32
var innerErr error
err := w.rawConn.Write(func(fd uintptr) (done bool) {
innerErr = windows.WSASend(windows.Handle(fd), &iovecList[0], uint32(len(iovecList)), &n, 0, nil, nil)
innerErr = windows.WSASend(windows.Handle(fd), iovecList[0], uint32(len(iovecList)), &n, 0, nil, nil)
return innerErr != windows.WSAEWOULDBLOCK
})
if innerErr != nil {
err = innerErr
}
for index := range iovecList {
iovecList[index] = windows.WSABuf{}
}
return err
}
func (w *SyscallVectorisedPacketWriter) WriteVectorisedPacket(buffers []*buf.Buffer, destination M.Socksaddr) error {
w.access.Lock()
defer w.access.Unlock()
defer buf.ReleaseMulti(buffers)
var iovecList []windows.WSABuf
if w.iovecList != nil {
iovecList = *w.iovecList
}
iovecList = iovecList[:0]
for _, buffer := range buffers {
iovecList = append(iovecList, windows.WSABuf{
Buf: &buffer.Bytes()[0],
iovecList := make([]*windows.WSABuf, len(buffers))
for i, buffer := range buffers {
iovecList[i] = &windows.WSABuf{
Len: uint32(buffer.Len()),
})
Buf: &buffer.Bytes()[0],
}
}
if w.iovecList == nil {
w.iovecList = new([]windows.WSABuf)
var sockaddr windows.Sockaddr
if destination.IsIPv4() {
sockaddr = &windows.SockaddrInet4{
Port: int(destination.Port),
Addr: destination.Addr.As4(),
}
} else {
sockaddr = &windows.SockaddrInet6{
Port: int(destination.Port),
Addr: destination.Addr.As16(),
}
}
*w.iovecList = iovecList // cache
var n uint32
var innerErr error
err := w.rawConn.Write(func(fd uintptr) (done bool) {
name, nameLen := ToSockaddr(destination.AddrPort())
innerErr = windows.WSASendTo(
windows.Handle(fd),
&iovecList[0],
uint32(len(iovecList)),
&n,
0,
(*windows.RawSockaddrAny)(name),
nameLen,
nil,
nil)
innerErr = windows.WSASendto(windows.Handle(fd), iovecList[0], uint32(len(iovecList)), &n, 0, sockaddr, nil, nil)
return innerErr != windows.WSAEWOULDBLOCK
})
if innerErr != nil {
err = innerErr
}
for index := range iovecList {
iovecList[index] = windows.WSABuf{}
}
return err
}

View file

@ -1,35 +0,0 @@
package bufio
import (
"io"
N "github.com/sagernet/sing/common/network"
)
func CreateReadWaiter(reader io.Reader) (N.ReadWaiter, bool) {
reader = N.UnwrapReader(reader)
if readWaiter, isReadWaiter := reader.(N.ReadWaiter); isReadWaiter {
return readWaiter, true
}
if readWaitCreator, isCreator := reader.(N.ReadWaitCreator); isCreator {
return readWaitCreator.CreateReadWaiter()
}
if readWaiter, created := createSyscallReadWaiter(reader); created {
return readWaiter, true
}
return nil, false
}
func CreatePacketReadWaiter(reader N.PacketReader) (N.PacketReadWaiter, bool) {
reader = N.UnwrapPacketReader(reader)
if readWaiter, isReadWaiter := reader.(N.PacketReadWaiter); isReadWaiter {
return readWaiter, true
}
if readWaitCreator, isCreator := reader.(N.PacketReadWaitCreator); isCreator {
return readWaitCreator.CreateReadWaiter()
}
if readWaiter, created := createSyscallPacketReadWaiter(reader); created {
return readWaiter, true
}
return nil, false
}

View file

@ -1,57 +0,0 @@
// Code generated by 'go generate'; DO NOT EDIT.
package bufio
import (
"syscall"
"unsafe"
"golang.org/x/sys/windows"
)
var _ unsafe.Pointer
// Do the interface allocations only once for common
// Errno values.
const (
errnoERROR_IO_PENDING = 997
)
var (
errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
errERROR_EINVAL error = syscall.EINVAL
)
// errnoErr returns common boxed Errno values, to prevent
// allocations at runtime.
func errnoErr(e syscall.Errno) error {
switch e {
case 0:
return errERROR_EINVAL
case errnoERROR_IO_PENDING:
return errERROR_IO_PENDING
}
// TODO: add more here, after collecting data on the common
// error values see on Windows. (perhaps when running
// all.bat?)
return e
}
var (
modws2_32 = windows.NewLazySystemDLL("ws2_32.dll")
procrecv = modws2_32.NewProc("recv")
)
func recv(s windows.Handle, buf []byte, flags int32) (n int32, err error) {
var _p0 *byte
if len(buf) > 0 {
_p0 = &buf[0]
}
r0, _, e1 := syscall.Syscall6(procrecv.Addr(), 4, uintptr(s), uintptr(unsafe.Pointer(_p0)), uintptr(len(buf)), uintptr(flags), 0, 0)
n = int32(r0)
if n == -1 {
err = errnoErr(e1)
}
return
}

View file

@ -258,14 +258,6 @@ func (c *LruCache[K, V]) Delete(key K) {
c.mu.Unlock()
}
func (c *LruCache[K, V]) Clear() {
c.mu.Lock()
defer c.mu.Unlock()
for element := c.lru.Front(); element != nil; element = element.Next() {
c.deleteElement(element)
}
}
func (c *LruCache[K, V]) maybeDeleteOldest() {
if !c.staleReturn && c.maxAge > 0 {
now := time.Now().Unix()

View file

@ -1,65 +0,0 @@
package canceler
import (
"context"
"net"
"os"
"time"
"github.com/sagernet/sing/common"
)
type Instance struct {
ctx context.Context
cancelFunc common.ContextCancelCauseFunc
timer *time.Timer
timeout time.Duration
}
func New(ctx context.Context, cancelFunc common.ContextCancelCauseFunc, timeout time.Duration) *Instance {
instance := &Instance{
ctx,
cancelFunc,
time.NewTimer(timeout),
timeout,
}
go instance.wait()
return instance
}
func (i *Instance) Update() bool {
if !i.timer.Stop() {
return false
}
if !i.timer.Reset(i.timeout) {
return false
}
return true
}
func (i *Instance) Timeout() time.Duration {
return i.timeout
}
func (i *Instance) SetTimeout(timeout time.Duration) bool {
i.timeout = timeout
return i.Update()
}
func (i *Instance) wait() {
select {
case <-i.timer.C:
case <-i.ctx.Done():
}
i.CloseWithError(os.ErrDeadlineExceeded)
}
func (i *Instance) Close() error {
i.CloseWithError(net.ErrClosed)
return nil
}
func (i *Instance) CloseWithError(err error) {
i.timer.Stop()
i.cancelFunc(err)
}

View file

@ -1,76 +0,0 @@
package canceler
import (
"context"
"time"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
type PacketConn interface {
N.PacketConn
Timeout() time.Duration
SetTimeout(timeout time.Duration) bool
}
type TimerPacketConn struct {
N.PacketConn
instance *Instance
}
func NewPacketConn(ctx context.Context, conn N.PacketConn, timeout time.Duration) (context.Context, N.PacketConn) {
if timeoutConn, isTimeoutConn := common.Cast[PacketConn](conn); isTimeoutConn {
oldTimeout := timeoutConn.Timeout()
if oldTimeout > 0 && timeout >= oldTimeout {
return ctx, conn
}
if timeoutConn.SetTimeout(timeout) {
return ctx, conn
}
}
err := conn.SetReadDeadline(time.Time{})
if err == nil {
return NewTimeoutPacketConn(ctx, conn, timeout)
}
ctx, cancel := common.ContextWithCancelCause(ctx)
instance := New(ctx, cancel, timeout)
return ctx, &TimerPacketConn{conn, instance}
}
func (c *TimerPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
destination, err = c.PacketConn.ReadPacket(buffer)
if err == nil {
c.instance.Update()
}
return
}
func (c *TimerPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
err := c.PacketConn.WritePacket(buffer, destination)
if err == nil {
c.instance.Update()
}
return err
}
func (c *TimerPacketConn) Timeout() time.Duration {
return c.instance.Timeout()
}
func (c *TimerPacketConn) SetTimeout(timeout time.Duration) bool {
return c.instance.SetTimeout(timeout)
}
func (c *TimerPacketConn) Close() error {
return common.Close(
c.PacketConn,
c.instance,
)
}
func (c *TimerPacketConn) Upstream() any {
return c.PacketConn
}

View file

@ -1,76 +0,0 @@
package canceler
import (
"context"
"net"
"time"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
type TimeoutPacketConn struct {
N.PacketConn
timeout time.Duration
cancel common.ContextCancelCauseFunc
active time.Time
}
func NewTimeoutPacketConn(ctx context.Context, conn N.PacketConn, timeout time.Duration) (context.Context, PacketConn) {
ctx, cancel := common.ContextWithCancelCause(ctx)
return ctx, &TimeoutPacketConn{
PacketConn: conn,
timeout: timeout,
cancel: cancel,
}
}
func (c *TimeoutPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
for {
err = c.PacketConn.SetReadDeadline(time.Now().Add(c.timeout))
if err != nil {
return
}
destination, err = c.PacketConn.ReadPacket(buffer)
if err == nil {
c.active = time.Now()
return
} else if E.IsTimeout(err) {
if time.Since(c.active) > c.timeout {
c.cancel(err)
return
}
} else {
return
}
}
}
func (c *TimeoutPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
err := c.PacketConn.WritePacket(buffer, destination)
if err == nil {
c.active = time.Now()
}
return err
}
func (c *TimeoutPacketConn) Timeout() time.Duration {
return c.timeout
}
func (c *TimeoutPacketConn) SetTimeout(timeout time.Duration) bool {
c.timeout = timeout
return c.PacketConn.SetReadDeadline(time.Now()) == nil
}
func (c *TimeoutPacketConn) Close() error {
c.cancel(net.ErrClosed)
return c.PacketConn.Close()
}
func (c *TimeoutPacketConn) Upstream() any {
return c.PacketConn
}

View file

@ -1,11 +0,0 @@
//go:build go1.21
package common
func ClearArray[T ~[]E, E any](t T) {
clear(t)
}
func ClearMap[T ~map[K]V, K comparable, V any](t T) {
clear(t)
}

View file

@ -1,16 +0,0 @@
//go:build !go1.21
package common
func ClearArray[T ~[]E, E any](t T) {
var defaultValue E
for i := range t {
t[i] = defaultValue
}
}
func ClearMap[T ~map[K]V, K comparable, V any](t T) {
for k := range t {
delete(t, k)
}
}

View file

@ -20,8 +20,8 @@ func Any[T any](array []T, block func(it T) bool) bool {
}
func AnyIndexed[T any](array []T, block func(index int, it T) bool) bool {
for index, it := range array {
if block(index, it) {
for i, it := range array {
if block(i, it) {
return true
}
}
@ -38,8 +38,8 @@ func All[T any](array []T, block func(it T) bool) bool {
}
func AllIndexed[T any](array []T, block func(index int, it T) bool) bool {
for index, it := range array {
if !block(index, it) {
for i, it := range array {
if !block(i, it) {
return false
}
}
@ -47,8 +47,8 @@ func AllIndexed[T any](array []T, block func(index int, it T) bool) bool {
}
func Contains[T comparable](arr []T, target T) bool {
for index := range arr {
if target == arr[index] {
for i := range arr {
if target == arr[i] {
return true
}
}
@ -81,8 +81,8 @@ func FlatMap[T any, N any](arr []T, block func(it T) []N) []N {
func FlatMapIndexed[T any, N any](arr []T, block func(index int, it T) []N) []N {
var retAddr []N
for index, item := range arr {
retAddr = append(retAddr, block(index, item)...)
for i, item := range arr {
retAddr = append(retAddr, block(i, item)...)
}
return retAddr
}
@ -113,8 +113,8 @@ func FilterNotDefault[T comparable](arr []T) []T {
func FilterIndexed[T any](arr []T, block func(index int, it T) bool) []T {
var retArr []T
for index, it := range arr {
if block(index, it) {
for i, it := range arr {
if block(i, it) {
retArr = append(retArr, it)
}
}
@ -130,55 +130,22 @@ func Find[T any](arr []T, block func(it T) bool) T {
return DefaultValue[T]()
}
func FindIndexed[T any](arr []T, block func(index int, it T) bool) T {
for index, it := range arr {
if block(index, it) {
return it
}
}
return DefaultValue[T]()
}
func Index[T any](arr []T, block func(it T) bool) int {
for index, it := range arr {
if block(it) {
return index
}
}
return -1
}
func IndexIndexed[T any](arr []T, block func(index int, it T) bool) int {
for index, it := range arr {
if block(index, it) {
return index
}
}
return -1
}
func Equal[S ~[]E, E comparable](s1, s2 S) bool {
if len(s1) != len(s2) {
return false
}
for i := range s1 {
if s1[i] != s2[i] {
return false
}
}
return true
}
//go:norace
func Dup[T any](obj T) T {
pointer := uintptr(unsafe.Pointer(&obj))
//nolint:staticcheck
//goland:noinspection GoVetUnsafePointer
return *(*T)(unsafe.Pointer(pointer))
if UnsafeBuffer {
pointer := uintptr(unsafe.Pointer(&obj))
//nolint:staticcheck
//goland:noinspection GoVetUnsafePointer
return *(*T)(unsafe.Pointer(pointer))
} else {
return obj
}
}
func KeepAlive(obj any) {
runtime.KeepAlive(obj)
if UnsafeBuffer {
runtime.KeepAlive(obj)
}
}
func Uniq[T comparable](arr []T) []T {
@ -280,14 +247,6 @@ func Reverse[T any](arr []T) []T {
return arr
}
func ReverseMap[K comparable, V comparable](m map[K]V) map[V]K {
ret := make(map[V]K, len(m))
for k, v := range m {
ret[v] = k
}
return ret
}
func Done(ctx context.Context) bool {
select {
case <-ctx.Done():
@ -309,18 +268,16 @@ func Must(errs ...error) {
}
}
func Must1[T any](result T, err error) T {
func Must1(_ any, err error) {
if err != nil {
panic(err)
}
return result
}
func Must2[T any, T2 any](result T, result2 T2, err error) (T, T2) {
func Must2(_, _ any, err error) {
if err != nil {
panic(err)
}
return result, result2
}
// Deprecated: use E.Errors
@ -356,10 +313,6 @@ func DefaultValue[T any]() T {
return defaultValue
}
func Ptr[T any](obj T) *T {
return &obj
}
func Close(closers ...any) error {
var retErr error
for _, closer := range closers {
@ -382,3 +335,22 @@ func Close(closers ...any) error {
}
return retErr
}
type Starter interface {
Start() error
}
func Start(starters ...any) error {
for _, rawStarter := range starters {
if rawStarter == nil {
continue
}
if starter, isStarter := rawStarter.(Starter); isStarter {
err := starter.Start()
if err != nil {
return err
}
}
}
return nil
}

View file

@ -1,23 +0,0 @@
package common
import (
"context"
"reflect"
)
// Deprecated: not used
func SelectContext(contextList []context.Context) (int, error) {
if len(contextList) == 1 {
<-contextList[0].Done()
return 0, contextList[0].Err()
}
chosen, _, _ := reflect.Select(Map(Filter(contextList, func(it context.Context) bool {
return it.Done() != nil
}), func(it context.Context) reflect.SelectCase {
return reflect.SelectCase{
Dir: reflect.SelectRecv,
Chan: reflect.ValueOf(it.Done()),
}
}))
return chosen, contextList[chosen].Err()
}

View file

@ -1,14 +0,0 @@
//go:build go1.20
package common
import "context"
type (
ContextCancelCauseFunc = context.CancelCauseFunc
)
var (
ContextWithCancelCause = context.WithCancelCause
ContextCause = context.Cause
)

View file

@ -1,16 +0,0 @@
//go:build !go1.20
package common
import "context"
type ContextCancelCauseFunc func(cause error)
func ContextWithCancelCause(parentContext context.Context) (context.Context, ContextCancelCauseFunc) {
ctx, cancel := context.WithCancel(parentContext)
return ctx, func(_ error) { cancel() }
}
func ContextCause(context context.Context) error {
return context.Err()
}

View file

@ -1,35 +1,59 @@
package control
import (
"os"
"runtime"
"syscall"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
func BindToInterface(finder InterfaceFinder, interfaceName string, interfaceIndex int) Func {
return func(network, address string, conn syscall.RawConn) error {
return BindToInterface0(finder, conn, network, address, interfaceName, interfaceIndex, false)
return BindToInterface0(finder, conn, network, address, interfaceName, interfaceIndex)
}
}
func BindToInterfaceFunc(finder InterfaceFinder, block func(network string, address string) (interfaceName string, interfaceIndex int, err error)) Func {
func BindToInterfaceFunc(finder InterfaceFinder, block func(network string, address string) (interfaceName string, interfaceIndex int)) Func {
return func(network, address string, conn syscall.RawConn) error {
interfaceName, interfaceIndex, err := block(network, address)
if err != nil {
return err
}
return BindToInterface0(finder, conn, network, address, interfaceName, interfaceIndex, false)
interfaceName, interfaceIndex := block(network, address)
return BindToInterface0(finder, conn, network, address, interfaceName, interfaceIndex)
}
}
func BindToInterface0(finder InterfaceFinder, conn syscall.RawConn, network string, address string, interfaceName string, interfaceIndex int, preferInterfaceName bool) error {
if interfaceName == "" && interfaceIndex == -1 {
return E.New("interface not found: ", interfaceName)
}
const useInterfaceName = runtime.GOOS == "linux" || runtime.GOOS == "android"
func BindToInterface0(finder InterfaceFinder, conn syscall.RawConn, network string, address string, interfaceName string, interfaceIndex int) error {
if addr := M.ParseSocksaddr(address).Addr; addr.IsValid() && N.IsVirtual(addr) {
return nil
}
return bindToInterface(conn, network, address, finder, interfaceName, interfaceIndex, preferInterfaceName)
if interfaceName == "" && interfaceIndex == -1 {
return nil
}
if interfaceName != "" && useInterfaceName || interfaceIndex != -1 && !useInterfaceName {
return bindToInterface(conn, network, address, interfaceName, interfaceIndex)
}
if finder == nil {
return os.ErrInvalid
}
var err error
if useInterfaceName {
interfaceName, err = finder.InterfaceNameByIndex(interfaceIndex)
} else {
interfaceIndex, err = finder.InterfaceIndexByName(interfaceName)
}
if err != nil {
return err
}
if useInterfaceName {
if interfaceName == "" {
return nil
}
} else {
if interfaceIndex == -1 {
return nil
}
}
return bindToInterface(conn, network, address, interfaceName, interfaceIndex)
}

View file

@ -1,24 +1,16 @@
package control
import (
"os"
"syscall"
"golang.org/x/sys/unix"
)
func bindToInterface(conn syscall.RawConn, network string, address string, finder InterfaceFinder, interfaceName string, interfaceIndex int, preferInterfaceName bool) error {
func bindToInterface(conn syscall.RawConn, network string, address string, interfaceName string, interfaceIndex int) error {
if interfaceIndex == -1 {
return nil
}
return Raw(conn, func(fd uintptr) error {
if interfaceIndex == -1 {
if finder == nil {
return os.ErrInvalid
}
iif, err := finder.ByName(interfaceName)
if err != nil {
return err
}
interfaceIndex = iif.Index
}
switch network {
case "tcp6", "udp6":
return unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_BOUND_IF, interfaceIndex)

View file

@ -1,59 +1,30 @@
package control
import (
"net"
"net/netip"
"unsafe"
"github.com/sagernet/sing/common"
M "github.com/sagernet/sing/common/metadata"
)
import "net"
type InterfaceFinder interface {
Update() error
Interfaces() []Interface
ByName(name string) (*Interface, error)
ByIndex(index int) (*Interface, error)
ByAddr(addr netip.Addr) (*Interface, error)
InterfaceIndexByName(name string) (int, error)
InterfaceNameByIndex(index int) (string, error)
}
type Interface struct {
Index int
MTU int
Name string
HardwareAddr net.HardwareAddr
Flags net.Flags
Addresses []netip.Prefix
func DefaultInterfaceFinder() InterfaceFinder {
return (*netInterfaceFinder)(nil)
}
func (i Interface) Equals(other Interface) bool {
return i.Index == other.Index &&
i.MTU == other.MTU &&
i.Name == other.Name &&
common.Equal(i.HardwareAddr, other.HardwareAddr) &&
i.Flags == other.Flags &&
common.Equal(i.Addresses, other.Addresses)
}
type netInterfaceFinder struct{}
func (i Interface) NetInterface() net.Interface {
return *(*net.Interface)(unsafe.Pointer(&i))
}
func InterfaceFromNet(iif net.Interface) (Interface, error) {
ifAddrs, err := iif.Addrs()
func (w *netInterfaceFinder) InterfaceIndexByName(name string) (int, error) {
netInterface, err := net.InterfaceByName(name)
if err != nil {
return Interface{}, err
return 0, err
}
return InterfaceFromNetAddrs(iif, common.Map(ifAddrs, M.PrefixFromNet)), nil
return netInterface.Index, nil
}
func InterfaceFromNetAddrs(iif net.Interface, addresses []netip.Prefix) Interface {
return Interface{
Index: iif.Index,
MTU: iif.MTU,
Name: iif.Name,
HardwareAddr: iif.HardwareAddr,
Flags: iif.Flags,
Addresses: addresses,
func (w *netInterfaceFinder) InterfaceNameByIndex(index int) (string, error) {
netInterface, err := net.InterfaceByIndex(index)
if err != nil {
return "", err
}
return netInterface.Name, nil
}

View file

@ -1,89 +0,0 @@
package control
import (
"net"
"net/netip"
E "github.com/sagernet/sing/common/exceptions"
)
var _ InterfaceFinder = (*DefaultInterfaceFinder)(nil)
type DefaultInterfaceFinder struct {
interfaces []Interface
}
func NewDefaultInterfaceFinder() *DefaultInterfaceFinder {
return &DefaultInterfaceFinder{}
}
func (f *DefaultInterfaceFinder) Update() error {
netIfs, err := net.Interfaces()
if err != nil {
return err
}
interfaces := make([]Interface, 0, len(netIfs))
for _, netIf := range netIfs {
var iif Interface
iif, err = InterfaceFromNet(netIf)
if err != nil {
return err
}
interfaces = append(interfaces, iif)
}
f.interfaces = interfaces
return nil
}
func (f *DefaultInterfaceFinder) UpdateInterfaces(interfaces []Interface) {
f.interfaces = interfaces
}
func (f *DefaultInterfaceFinder) Interfaces() []Interface {
return f.interfaces
}
func (f *DefaultInterfaceFinder) ByName(name string) (*Interface, error) {
for _, netInterface := range f.interfaces {
if netInterface.Name == name {
return &netInterface, nil
}
}
_, err := net.InterfaceByName(name)
if err == nil {
err = f.Update()
if err != nil {
return nil, err
}
return f.ByName(name)
}
return nil, &net.OpError{Op: "route", Net: "ip+net", Source: nil, Addr: &net.IPAddr{IP: nil}, Err: E.New("no such network interface")}
}
func (f *DefaultInterfaceFinder) ByIndex(index int) (*Interface, error) {
for _, netInterface := range f.interfaces {
if netInterface.Index == index {
return &netInterface, nil
}
}
_, err := net.InterfaceByIndex(index)
if err == nil {
err = f.Update()
if err != nil {
return nil, err
}
return f.ByIndex(index)
}
return nil, &net.OpError{Op: "route", Net: "ip+net", Source: nil, Addr: &net.IPAddr{IP: nil}, Err: E.New("no such network interface")}
}
func (f *DefaultInterfaceFinder) ByAddr(addr netip.Addr) (*Interface, error) {
for _, netInterface := range f.interfaces {
for _, prefix := range netInterface.Addresses {
if prefix.Contains(addr) {
return &netInterface, nil
}
}
}
return nil, &net.OpError{Op: "route", Net: "ip+net", Source: nil, Addr: &net.IPAddr{IP: addr.AsSlice()}, Err: E.New("no such network interface")}
}

View file

@ -1,42 +1,13 @@
package control
import (
"os"
"syscall"
"github.com/sagernet/sing/common/atomic"
E "github.com/sagernet/sing/common/exceptions"
"golang.org/x/sys/unix"
)
var ifIndexDisabled atomic.Bool
func bindToInterface(conn syscall.RawConn, network string, address string, finder InterfaceFinder, interfaceName string, interfaceIndex int, preferInterfaceName bool) error {
func bindToInterface(conn syscall.RawConn, network string, address string, interfaceName string, interfaceIndex int) error {
return Raw(conn, func(fd uintptr) error {
if !preferInterfaceName && !ifIndexDisabled.Load() {
if interfaceIndex == -1 {
if interfaceName == "" {
return os.ErrInvalid
}
iif, err := finder.ByName(interfaceName)
if err != nil {
return err
}
interfaceIndex = iif.Index
}
err := unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_BINDTOIFINDEX, interfaceIndex)
if err == nil {
return nil
} else if E.IsMulti(err, unix.ENOPROTOOPT, unix.EINVAL) {
ifIndexDisabled.Store(true)
} else {
return err
}
}
if interfaceName == "" {
return os.ErrInvalid
}
return unix.BindToDevice(int(fd), interfaceName)
})
}

View file

@ -4,6 +4,6 @@ package control
import "syscall"
func bindToInterface(conn syscall.RawConn, network string, address string, finder InterfaceFinder, interfaceName string, interfaceIndex int, preferInterfaceName bool) error {
func bindToInterface(conn syscall.RawConn, network string, address string, interfaceName string, interfaceIndex int) error {
return nil
}

View file

@ -2,25 +2,14 @@ package control
import (
"encoding/binary"
"os"
"syscall"
"unsafe"
M "github.com/sagernet/sing/common/metadata"
)
func bindToInterface(conn syscall.RawConn, network string, address string, finder InterfaceFinder, interfaceName string, interfaceIndex int, preferInterfaceName bool) error {
func bindToInterface(conn syscall.RawConn, network string, address string, interfaceName string, interfaceIndex int) error {
return Raw(conn, func(fd uintptr) error {
if interfaceIndex == -1 {
if finder == nil {
return os.ErrInvalid
}
iif, err := finder.ByName(interfaceName)
if err != nil {
return err
}
interfaceIndex = iif.Index
}
handle := syscall.Handle(fd)
if M.ParseSocksaddr(address).AddrString() == "" {
err := bind4(handle, interfaceIndex)

View file

@ -1,33 +0,0 @@
package control
import (
"os"
"syscall"
N "github.com/sagernet/sing/common/network"
"golang.org/x/sys/unix"
)
func DisableUDPFragment() Func {
return func(network, address string, conn syscall.RawConn) error {
if N.NetworkName(network) != N.NetworkUDP {
return nil
}
return Raw(conn, func(fd uintptr) error {
if network == "udp" || network == "udp4" {
err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_DONTFRAG, 1)
if err != nil {
return os.NewSyscallError("SETSOCKOPT IP_DONTFRAG", err)
}
}
if network == "udp" || network == "udp6" {
err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_DONTFRAG, 1)
if err != nil {
return os.NewSyscallError("SETSOCKOPT IPV6_DONTFRAG", err)
}
}
return nil
})
}
}

View file

@ -11,19 +11,17 @@ import (
func DisableUDPFragment() Func {
return func(network, address string, conn syscall.RawConn) error {
if N.NetworkName(network) != N.NetworkUDP {
switch N.NetworkName(network) {
case N.NetworkUDP:
default:
return nil
}
return Raw(conn, func(fd uintptr) error {
if network == "udp" || network == "udp4" {
err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_MTU_DISCOVER, unix.IP_PMTUDISC_DO)
if err != nil {
return os.NewSyscallError("SETSOCKOPT IP_MTU_DISCOVER IP_PMTUDISC_DO", err)
}
if err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_MTU_DISCOVER, unix.IP_PMTUDISC_DO); err != nil {
return os.NewSyscallError("SETSOCKOPT IP_MTU_DISCOVER IP_PMTUDISC_DO", err)
}
if network == "udp" || network == "udp6" {
err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_MTU_DISCOVER, unix.IP_PMTUDISC_DO)
if err != nil {
if network == "udp6" {
if err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_MTU_DISCOVER, unix.IP_PMTUDISC_DO); err != nil {
return os.NewSyscallError("SETSOCKOPT IPV6_MTU_DISCOVER IP_PMTUDISC_DO", err)
}
}

View file

@ -1,4 +1,4 @@
//go:build !(linux || windows || darwin)
//go:build !((go1.19 && unix) || (!go1.19 && (linux || darwin)) || windows)
package control

View file

@ -0,0 +1,28 @@
//go:build (go1.19 && unix && !linux) || (!go1.19 && darwin)
package control
import (
"os"
"syscall"
"golang.org/x/sys/unix"
)
func DisableUDPFragment() Func {
return func(network, address string, conn syscall.RawConn) error {
return Raw(conn, func(fd uintptr) error {
switch network {
case "udp4":
if err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_DONTFRAG, 1); err != nil {
return os.NewSyscallError("SETSOCKOPT IP_DONTFRAG", err)
}
case "udp6":
if err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_DONTFRAG, 1); err != nil {
return os.NewSyscallError("SETSOCKOPT IPV6_DONTFRAG", err)
}
}
return nil
})
}
}

View file

@ -25,19 +25,17 @@ const (
func DisableUDPFragment() Func {
return func(network, address string, conn syscall.RawConn) error {
if N.NetworkName(network) != N.NetworkUDP {
switch N.NetworkName(network) {
case N.NetworkUDP:
default:
return nil
}
return Raw(conn, func(fd uintptr) error {
if network == "udp" || network == "udp4" {
err := windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IP, IP_MTU_DISCOVER, IP_PMTUDISC_DO)
if err != nil {
return os.NewSyscallError("SETSOCKOPT IP_MTU_DISCOVER IP_PMTUDISC_DO", err)
}
if err := windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IP, IP_MTU_DISCOVER, IP_PMTUDISC_DO); err != nil {
return os.NewSyscallError("SETSOCKOPT IP_MTU_DISCOVER IP_PMTUDISC_DO", err)
}
if network == "udp" || network == "udp6" {
err := windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IPV6, IPV6_MTU_DISCOVER, IP_PMTUDISC_DO)
if err != nil {
if network == "udp6" {
if err := windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IPV6, IPV6_MTU_DISCOVER, IP_PMTUDISC_DO); err != nil {
return os.NewSyscallError("SETSOCKOPT IPV6_MTU_DISCOVER IP_PMTUDISC_DO", err)
}
}

View file

@ -3,7 +3,6 @@ package control
import (
"syscall"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
)
@ -31,14 +30,6 @@ func Conn(conn syscall.Conn, block func(fd uintptr) error) error {
return Raw(rawConn, block)
}
func Conn0[T any](conn syscall.Conn, block func(fd uintptr) (T, error)) (T, error) {
rawConn, err := conn.SyscallConn()
if err != nil {
return common.DefaultValue[T](), err
}
return Raw0[T](rawConn, block)
}
func Raw(rawConn syscall.RawConn, block func(fd uintptr) error) error {
var innerErr error
err := rawConn.Control(func(fd uintptr) {
@ -46,14 +37,3 @@ func Raw(rawConn syscall.RawConn, block func(fd uintptr) error) error {
})
return E.Errors(innerErr, err)
}
func Raw0[T any](rawConn syscall.RawConn, block func(fd uintptr) (T, error)) (T, error) {
var (
value T
innerErr error
)
err := rawConn.Control(func(fd uintptr) {
value, innerErr = block(fd)
})
return value, E.Errors(innerErr, err)
}

View file

@ -4,10 +4,10 @@ import (
"syscall"
)
func RoutingMark(mark uint32) Func {
func RoutingMark(mark int) Func {
return func(network, address string, conn syscall.RawConn) error {
return Raw(conn, func(fd uintptr) error {
return syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_MARK, int(mark))
return syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_MARK, mark)
})
}
}

View file

@ -2,6 +2,6 @@
package control
func RoutingMark(mark uint32) Func {
func RoutingMark(mark int) Func {
return nil
}

View file

@ -1,58 +0,0 @@
package control
import (
"encoding/binary"
"net"
"net/netip"
"syscall"
"unsafe"
M "github.com/sagernet/sing/common/metadata"
"golang.org/x/sys/unix"
)
const (
PF_OUT = 0x2
DIOCNATLOOK = 0xc0544417
)
func GetOriginalDestination(conn net.Conn) (netip.AddrPort, error) {
pfFd, err := syscall.Open("/dev/pf", 0, syscall.O_RDONLY)
if err != nil {
return netip.AddrPort{}, err
}
defer syscall.Close(pfFd)
nl := struct {
saddr, daddr, rsaddr, rdaddr [16]byte
sxport, dxport, rsxport, rdxport [4]byte
af, proto, protoVariant, direction uint8
}{
af: syscall.AF_INET,
proto: syscall.IPPROTO_TCP,
direction: PF_OUT,
}
localAddr := M.SocksaddrFromNet(conn.LocalAddr()).Unwrap()
removeAddr := M.SocksaddrFromNet(conn.RemoteAddr()).Unwrap()
if localAddr.IsIPv4() {
copy(nl.saddr[:net.IPv4len], removeAddr.Addr.AsSlice())
copy(nl.daddr[:net.IPv4len], localAddr.Addr.AsSlice())
nl.af = syscall.AF_INET
} else {
copy(nl.saddr[:], removeAddr.Addr.AsSlice())
copy(nl.daddr[:], localAddr.Addr.AsSlice())
nl.af = syscall.AF_INET6
}
binary.BigEndian.PutUint16(nl.sxport[:], removeAddr.Port)
binary.BigEndian.PutUint16(nl.dxport[:], localAddr.Port)
if _, _, errno := unix.Syscall(syscall.SYS_IOCTL, uintptr(pfFd), DIOCNATLOOK, uintptr(unsafe.Pointer(&nl))); errno != 0 {
return netip.AddrPort{}, errno
}
var address netip.Addr
if nl.af == unix.AF_INET {
address = M.AddrFromIP(nl.rdaddr[:net.IPv4len])
} else {
address = netip.AddrFrom16(nl.rdaddr)
}
return netip.AddrPortFrom(address, binary.BigEndian.Uint16(nl.rdxport[:])), nil
}

View file

@ -1,38 +0,0 @@
package control
import (
"encoding/binary"
"net"
"net/netip"
"os"
"syscall"
"github.com/sagernet/sing/common"
M "github.com/sagernet/sing/common/metadata"
"golang.org/x/sys/unix"
)
func GetOriginalDestination(conn net.Conn) (netip.AddrPort, error) {
syscallConn, loaded := common.Cast[syscall.Conn](conn)
if !loaded {
return netip.AddrPort{}, os.ErrInvalid
}
return Conn0[netip.AddrPort](syscallConn, func(fd uintptr) (netip.AddrPort, error) {
if M.SocksaddrFromNet(conn.RemoteAddr()).Unwrap().IsIPv4() {
raw, err := unix.GetsockoptIPv6Mreq(int(fd), unix.IPPROTO_IP, unix.SO_ORIGINAL_DST)
if err != nil {
return netip.AddrPort{}, err
}
return netip.AddrPortFrom(M.AddrFromIP(raw.Multiaddr[4:8]), uint16(raw.Multiaddr[2])<<8+uint16(raw.Multiaddr[3])), nil
} else {
raw, err := unix.GetsockoptIPv6MTUInfo(int(fd), unix.IPPROTO_IPV6, unix.SO_ORIGINAL_DST)
if err != nil {
return netip.AddrPort{}, err
}
var port [2]byte
binary.BigEndian.PutUint16(port[:], raw.Addr.Port)
return netip.AddrPortFrom(M.AddrFromIP(raw.Addr.Addr[:]), binary.LittleEndian.Uint16(port[:])), nil
}
})
}

View file

@ -1,13 +0,0 @@
//go:build !linux && !darwin
package control
import (
"net"
"net/netip"
"os"
)
func GetOriginalDestination(conn net.Conn) (netip.AddrPort, error) {
return netip.AddrPort{}, os.ErrInvalid
}

View file

@ -1,30 +0,0 @@
package control
import (
"syscall"
"time"
_ "unsafe"
E "github.com/sagernet/sing/common/exceptions"
N "github.com/sagernet/sing/common/network"
"golang.org/x/sys/unix"
)
func SetKeepAlivePeriod(idle time.Duration, interval time.Duration) Func {
return func(network, address string, conn syscall.RawConn) error {
if N.NetworkName(network) != N.NetworkTCP {
return nil
}
return Raw(conn, func(fd uintptr) error {
return E.Errors(
unix.SetsockoptInt(int(fd), unix.IPPROTO_TCP, unix.TCP_KEEPIDLE, int(roundDurationUp(idle, time.Second))),
unix.SetsockoptInt(int(fd), unix.IPPROTO_TCP, unix.TCP_KEEPINTVL, int(roundDurationUp(interval, time.Second))),
)
})
}
}
func roundDurationUp(d time.Duration, to time.Duration) time.Duration {
return (d + to - 1) / to
}

View file

@ -1,11 +0,0 @@
//go:build !linux
package control
import (
"time"
)
func SetKeepAlivePeriod(idle time.Duration, interval time.Duration) Func {
return nil
}

View file

@ -1,56 +0,0 @@
package control
import (
"encoding/binary"
"net/netip"
"syscall"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
"golang.org/x/sys/unix"
)
func TProxy(fd uintptr, family int) error {
err := syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1)
if err == nil {
err = syscall.SetsockoptInt(int(fd), syscall.SOL_IP, syscall.IP_TRANSPARENT, 1)
}
if err == nil && family == unix.AF_INET6 {
err = syscall.SetsockoptInt(int(fd), syscall.SOL_IPV6, unix.IPV6_TRANSPARENT, 1)
}
if err == nil {
err = syscall.SetsockoptInt(int(fd), syscall.SOL_IP, syscall.IP_RECVORIGDSTADDR, 1)
}
if err == nil && family == unix.AF_INET6 {
err = syscall.SetsockoptInt(int(fd), syscall.SOL_IPV6, unix.IPV6_RECVORIGDSTADDR, 1)
}
return err
}
func TProxyWriteBack() Func {
return func(network, address string, conn syscall.RawConn) error {
return Raw(conn, func(fd uintptr) error {
if M.ParseSocksaddr(address).Addr.Is6() {
return syscall.SetsockoptInt(int(fd), syscall.SOL_IPV6, unix.IPV6_TRANSPARENT, 1)
} else {
return syscall.SetsockoptInt(int(fd), syscall.SOL_IP, syscall.IP_TRANSPARENT, 1)
}
})
}
}
func GetOriginalDestinationFromOOB(oob []byte) (netip.AddrPort, error) {
controlMessages, err := unix.ParseSocketControlMessage(oob)
if err != nil {
return netip.AddrPort{}, err
}
for _, message := range controlMessages {
if message.Header.Level == unix.SOL_IP && message.Header.Type == unix.IP_RECVORIGDSTADDR {
return netip.AddrPortFrom(M.AddrFromIP(message.Data[4:8]), binary.BigEndian.Uint16(message.Data[2:4])), nil
} else if message.Header.Level == unix.SOL_IPV6 && message.Header.Type == unix.IPV6_RECVORIGDSTADDR {
return netip.AddrPortFrom(M.AddrFromIP(message.Data[8:24]), binary.BigEndian.Uint16(message.Data[2:4])), nil
}
}
return netip.AddrPort{}, E.New("not found")
}

Some files were not shown because too many files have changed in this diff Show more