Mirror of https://github.com/SagerNet/sing.git (synced 2025-04-04 20:37:40 +03:00)

Compare commits: 106 commits, v0.5.0-alp...dev
Commits (abbreviated SHA1):
159e489fc3, d39c2c2fdd, ea82ac275f, ea0ac932ae, 2b41455f5a, 23b0180a1b, ce1b4851a4, 2238a05966, b55d1c78b3, d54716612c, 9eafc7fc62, d8153df67f, d9f6eb136d, 4dabb9be97, be9840c70f, aa7d2543a3, 33beacc053, 442cceb9fa, 3374a45475, 73776cf797, 957166799e, 809d8eca13, 9f69e7f9f7, 478265cd45, 3f30aaf25e, 39040e06dc, 6edd2ce0ea, 0a2e2a3eaf, 4ba1eb123c, c44912a861, a8f5bf4eb0, 30e9d91b57, 7fd3517e4d, a8285e06a5, 3613ead480, c8f251c668, fa5355e99e, 30fbafd954, fdca9b3f8e, e52e04f721, 7f621fdd78, ae139d9ee1, c432befd02, cc7e630923, 0998999911, 72ff654ee0, 11ffb962ae, fcb19641e6, 524a6bd0d1, b5f9e70ffd, c80c8f907c, a4eb7fa900, 7ec09d6045, 0641c71805, e7ec021b81, 0f2447a95b, 72db784fc7, d59ac57aaa, c63546470b, 55908bea36, 6567829958, c324d4143d, 0acb36c118, 26511a251f, afd8993773, 96bef0733f, ec1df651e8, e33b1d67d5, ed6cde73f7, 73cc65605e, 6c19e0736d, 08e8c02fb1, 7beca62e4f, e422e3d048, fa81eabc29, 4498e57839, f97054e917, a2f9fef936, 7893a74f75, 332e470075, 2bf9cc7253, bf8fc103a4, 774893928c, 7ceaf63d41, 8806e421f2, c37f988a4f, 0b4c0a1283, 4745c34b4c, e0196407a3, d8ec9c46cc, a33349366d, 3155c16990, caa4340dc9, a31dba8ad2, 9571124cf4, 1c495c9b07, 0f95dfe0e3, f3380c8dfe, ab4353dd13, e0e490af7b, ad4d59e2ed, aca2a85545, 589c7eb4df, 2873799b6d, 47cc308abf, d9f2559214
171 changed files with 8298 additions and 5556 deletions
.github/renovate.json (vendored, 8 changed lines)

```diff
@@ -1,11 +1,13 @@
{
  "$schema": "https://docs.renovatebot.com/renovate-schema.json",
  "commitMessagePrefix": "[dependencies]",
  "branchName": "main",
  "extends": [
    "config:base",
    ":disableRateLimiting"
  ],
  "baseBranches": [
    "dev"
  ],
  "packageRules": [
    {
      "matchManagers": [
@@ -15,9 +17,9 @@
    },
    {
      "matchManagers": [
-        "gomod"
+        "dockerfile"
      ],
-      "groupName": "gomod"
+      "groupName": "Dockerfile"
    }
  ]
}
```
.github/workflows/lint.yml (vendored, 10 changed lines)

```diff
@@ -1,4 +1,4 @@
-name: Lint
+name: lint

on:
  push:
@@ -24,16 +24,16 @@ jobs:
      with:
        fetch-depth: 0
      - name: Setup Go
-        uses: actions/setup-go@v4
+        uses: actions/setup-go@v5
        with:
-          go-version: ^1.22
+          go-version: ^1.23
      - name: Cache go module
-        uses: actions/cache@v3
+        uses: actions/cache@v4
        with:
          path: |
            ~/go/pkg/mod
          key: go-${{ hashFiles('**/go.sum') }}
      - name: golangci-lint
-        uses: golangci/golangci-lint-action@v3
+        uses: golangci/golangci-lint-action@v6
        with:
          version: latest
```
```diff
@@ -1,4 +1,4 @@
-name: Debug build
+name: test

on:
  push:
@@ -16,7 +16,7 @@ on:

jobs:
  build:
-    name: Linux Debug build
+    name: Linux
    runs-on: ubuntu-latest
    steps:
      - name: Checkout
@@ -24,46 +24,14 @@ jobs:
        with:
          fetch-depth: 0
      - name: Setup Go
-        uses: actions/setup-go@v4
+        uses: actions/setup-go@v5
        with:
          go-version: ^1.22
      - name: Build
        run: |
          make test
  build_go118:
    name: Linux Debug build (Go 1.18)
    runs-on: ubuntu-latest
    steps:
      - name: Checkout
        uses: actions/checkout@v4
        with:
          fetch-depth: 0
      - name: Setup Go
        uses: actions/setup-go@v4
        with:
          go-version: ~1.18
        continue-on-error: true
      - name: Build
        run: |
          make test
  build_go119:
    name: Linux Debug build (Go 1.19)
    runs-on: ubuntu-latest
    steps:
      - name: Checkout
        uses: actions/checkout@v4
        with:
          fetch-depth: 0
      - name: Setup Go
        uses: actions/setup-go@v4
        with:
          go-version: ~1.19
        continue-on-error: true
          go-version: ^1.23
      - name: Build
        run: |
          make test
  build_go120:
-    name: Linux Debug build (Go 1.20)
+    name: Linux (Go 1.20)
    runs-on: ubuntu-latest
    steps:
      - name: Checkout
@@ -71,7 +39,7 @@ jobs:
        with:
          fetch-depth: 0
      - name: Setup Go
-        uses: actions/setup-go@v4
+        uses: actions/setup-go@v5
        with:
          go-version: ~1.20
        continue-on-error: true
@@ -79,7 +47,7 @@ jobs:
        run: |
          make test
  build_go121:
-    name: Linux Debug build (Go 1.21)
+    name: Linux (Go 1.21)
    runs-on: ubuntu-latest
    steps:
      - name: Checkout
@@ -87,15 +55,31 @@ jobs:
        with:
          fetch-depth: 0
      - name: Setup Go
-        uses: actions/setup-go@v4
+        uses: actions/setup-go@v5
        with:
          go-version: ~1.21
        continue-on-error: true
      - name: Build
        run: |
          make test
-  build__windows:
-    name: Windows Debug build
+  build_go122:
+    name: Linux (Go 1.22)
    runs-on: ubuntu-latest
    steps:
      - name: Checkout
        uses: actions/checkout@v4
        with:
          fetch-depth: 0
      - name: Setup Go
        uses: actions/setup-go@v5
        with:
          go-version: ~1.22
        continue-on-error: true
      - name: Build
        run: |
          make test
  build_windows:
    name: Windows
    runs-on: windows-latest
    steps:
      - name: Checkout
@@ -103,15 +87,15 @@ jobs:
        with:
          fetch-depth: 0
      - name: Setup Go
-        uses: actions/setup-go@v4
+        uses: actions/setup-go@v5
        with:
-          go-version: ^1.22
+          go-version: ^1.23
        continue-on-error: true
      - name: Build
        run: |
          make test
  build_darwin:
-    name: macOS Debug build
+    name: macOS
    runs-on: macos-latest
    steps:
      - name: Checkout
@@ -119,9 +103,9 @@ jobs:
        with:
          fetch-depth: 0
      - name: Setup Go
-        uses: actions/setup-go@v4
+        uses: actions/setup-go@v5
        with:
-          go-version: ^1.22
+          go-version: ^1.23
        continue-on-error: true
      - name: Build
        run: |
```
```diff
@@ -5,6 +5,8 @@ linters:
    - govet
    - gci
    - staticcheck
    - paralleltest
    - ineffassign

linters-settings:
  gci:
@@ -14,4 +16,9 @@ linters-settings:
      - prefix(github.com/sagernet/)
      - default
  staticcheck:
    go: '1.20'
    checks:
      - all
      - -SA1003

run:
  go: "1.23"
```
Makefile (12 changed lines)

```diff
@@ -8,14 +8,14 @@ fmt_install:
	go install -v github.com/daixiang0/gci@latest

lint:
-	GOOS=linux golangci-lint run ./...
-	GOOS=android golangci-lint run ./...
-	GOOS=windows golangci-lint run ./...
-	GOOS=darwin golangci-lint run ./...
-	GOOS=freebsd golangci-lint run ./...
+	GOOS=linux golangci-lint run
+	GOOS=android golangci-lint run
+	GOOS=windows golangci-lint run
+	GOOS=darwin golangci-lint run
+	GOOS=freebsd golangci-lint run

lint_install:
	go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest

test:
-	go test $(shell go list ./... | grep -v /internal/)
+	go test ./...
```
```diff
@@ -1,3 +1,6 @@
# sing

+
+

Do you hear the people sing?
```
```diff
@@ -2,11 +2,10 @@ package baderror

import (
	"context"
	"errors"
	"io"
	"net"
	"strings"

	E "github.com/sagernet/sing/common/exceptions"
)

func Contains(err error, msgList ...string) bool {
@@ -22,8 +21,7 @@ func WrapH2(err error) error {
	if err == nil {
		return nil
	}
-	err = E.Unwrap(err)
-	if err == io.ErrUnexpectedEOF {
+	if errors.Is(err, io.ErrUnexpectedEOF) {
		return io.EOF
	}
	if Contains(err, "client disconnected", "body closed by handler", "response body closed", "; CANCEL") {
```
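The switch from a direct comparison to errors.Is matters once errors arrive wrapped (for example via fmt.Errorf with %w), which is also why the explicit E.Unwrap call above becomes unnecessary. A minimal, self-contained illustration of the difference, using only the standard library rather than sing code:

```go
package main

import (
	"errors"
	"fmt"
	"io"
)

func main() {
	// A wrapped io.ErrUnexpectedEOF, as many readers return it.
	err := fmt.Errorf("read header: %w", io.ErrUnexpectedEOF)

	fmt.Println(err == io.ErrUnexpectedEOF)          // false: direct comparison misses the wrapped error
	fmt.Println(errors.Is(err, io.ErrUnexpectedEOF)) // true: errors.Is walks the wrap chain
}
```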
common/binary/README.md (new file, 3 lines)

```markdown
# binary

mod from go 1.22.3
```
common/binary/binary.go (new file, 817 lines)

```go
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// Package binary implements simple translation between numbers and byte
// sequences and encoding and decoding of varints.
//
// Numbers are translated by reading and writing fixed-size values.
// A fixed-size value is either a fixed-size arithmetic
// type (bool, int8, uint8, int16, float32, complex64, ...)
// or an array or struct containing only fixed-size values.
//
// The varint functions encode and decode single integer values using
// a variable-length encoding; smaller values require fewer bytes.
// For a specification, see
// https://developers.google.com/protocol-buffers/docs/encoding.
//
// This package favors simplicity over efficiency. Clients that require
// high-performance serialization, especially for large data structures,
// should look at more advanced solutions such as the [encoding/gob]
// package or [google.golang.org/protobuf] for protocol buffers.
package binary

import (
	"errors"
	"io"
	"math"
	"reflect"
	"sync"
)

// A ByteOrder specifies how to convert byte slices into
// 16-, 32-, or 64-bit unsigned integers.
//
// It is implemented by [LittleEndian], [BigEndian], and [NativeEndian].
type ByteOrder interface {
	Uint16([]byte) uint16
	Uint32([]byte) uint32
	Uint64([]byte) uint64
	PutUint16([]byte, uint16)
	PutUint32([]byte, uint32)
	PutUint64([]byte, uint64)
	String() string
}

// AppendByteOrder specifies how to append 16-, 32-, or 64-bit unsigned integers
// into a byte slice.
//
// It is implemented by [LittleEndian], [BigEndian], and [NativeEndian].
type AppendByteOrder interface {
	AppendUint16([]byte, uint16) []byte
	AppendUint32([]byte, uint32) []byte
	AppendUint64([]byte, uint64) []byte
	String() string
}

// LittleEndian is the little-endian implementation of [ByteOrder] and [AppendByteOrder].
var LittleEndian littleEndian

// BigEndian is the big-endian implementation of [ByteOrder] and [AppendByteOrder].
var BigEndian bigEndian

type littleEndian struct{}

func (littleEndian) Uint16(b []byte) uint16 {
	_ = b[1] // bounds check hint to compiler; see golang.org/issue/14808
	return uint16(b[0]) | uint16(b[1])<<8
}

func (littleEndian) PutUint16(b []byte, v uint16) {
	_ = b[1] // early bounds check to guarantee safety of writes below
	b[0] = byte(v)
	b[1] = byte(v >> 8)
}

func (littleEndian) AppendUint16(b []byte, v uint16) []byte {
	return append(b,
		byte(v),
		byte(v>>8),
	)
}

func (littleEndian) Uint32(b []byte) uint32 {
	_ = b[3] // bounds check hint to compiler; see golang.org/issue/14808
	return uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24
}

func (littleEndian) PutUint32(b []byte, v uint32) {
	_ = b[3] // early bounds check to guarantee safety of writes below
	b[0] = byte(v)
	b[1] = byte(v >> 8)
	b[2] = byte(v >> 16)
	b[3] = byte(v >> 24)
}

func (littleEndian) AppendUint32(b []byte, v uint32) []byte {
	return append(b,
		byte(v),
		byte(v>>8),
		byte(v>>16),
		byte(v>>24),
	)
}

func (littleEndian) Uint64(b []byte) uint64 {
	_ = b[7] // bounds check hint to compiler; see golang.org/issue/14808
	return uint64(b[0]) | uint64(b[1])<<8 | uint64(b[2])<<16 | uint64(b[3])<<24 |
		uint64(b[4])<<32 | uint64(b[5])<<40 | uint64(b[6])<<48 | uint64(b[7])<<56
}

func (littleEndian) PutUint64(b []byte, v uint64) {
	_ = b[7] // early bounds check to guarantee safety of writes below
	b[0] = byte(v)
	b[1] = byte(v >> 8)
	b[2] = byte(v >> 16)
	b[3] = byte(v >> 24)
	b[4] = byte(v >> 32)
	b[5] = byte(v >> 40)
	b[6] = byte(v >> 48)
	b[7] = byte(v >> 56)
}

func (littleEndian) AppendUint64(b []byte, v uint64) []byte {
	return append(b,
		byte(v),
		byte(v>>8),
		byte(v>>16),
		byte(v>>24),
		byte(v>>32),
		byte(v>>40),
		byte(v>>48),
		byte(v>>56),
	)
}

func (littleEndian) String() string { return "LittleEndian" }

func (littleEndian) GoString() string { return "binary.LittleEndian" }

type bigEndian struct{}

func (bigEndian) Uint16(b []byte) uint16 {
	_ = b[1] // bounds check hint to compiler; see golang.org/issue/14808
	return uint16(b[1]) | uint16(b[0])<<8
}

func (bigEndian) PutUint16(b []byte, v uint16) {
	_ = b[1] // early bounds check to guarantee safety of writes below
	b[0] = byte(v >> 8)
	b[1] = byte(v)
}

func (bigEndian) AppendUint16(b []byte, v uint16) []byte {
	return append(b,
		byte(v>>8),
		byte(v),
	)
}

func (bigEndian) Uint32(b []byte) uint32 {
	_ = b[3] // bounds check hint to compiler; see golang.org/issue/14808
	return uint32(b[3]) | uint32(b[2])<<8 | uint32(b[1])<<16 | uint32(b[0])<<24
}

func (bigEndian) PutUint32(b []byte, v uint32) {
	_ = b[3] // early bounds check to guarantee safety of writes below
	b[0] = byte(v >> 24)
	b[1] = byte(v >> 16)
	b[2] = byte(v >> 8)
	b[3] = byte(v)
}

func (bigEndian) AppendUint32(b []byte, v uint32) []byte {
	return append(b,
		byte(v>>24),
		byte(v>>16),
		byte(v>>8),
		byte(v),
	)
}

func (bigEndian) Uint64(b []byte) uint64 {
	_ = b[7] // bounds check hint to compiler; see golang.org/issue/14808
	return uint64(b[7]) | uint64(b[6])<<8 | uint64(b[5])<<16 | uint64(b[4])<<24 |
		uint64(b[3])<<32 | uint64(b[2])<<40 | uint64(b[1])<<48 | uint64(b[0])<<56
}

func (bigEndian) PutUint64(b []byte, v uint64) {
	_ = b[7] // early bounds check to guarantee safety of writes below
	b[0] = byte(v >> 56)
	b[1] = byte(v >> 48)
	b[2] = byte(v >> 40)
	b[3] = byte(v >> 32)
	b[4] = byte(v >> 24)
	b[5] = byte(v >> 16)
	b[6] = byte(v >> 8)
	b[7] = byte(v)
}

func (bigEndian) AppendUint64(b []byte, v uint64) []byte {
	return append(b,
		byte(v>>56),
		byte(v>>48),
		byte(v>>40),
		byte(v>>32),
		byte(v>>24),
		byte(v>>16),
		byte(v>>8),
		byte(v),
	)
}

func (bigEndian) String() string { return "BigEndian" }

func (bigEndian) GoString() string { return "binary.BigEndian" }

func (nativeEndian) String() string { return "NativeEndian" }

func (nativeEndian) GoString() string { return "binary.NativeEndian" }

// Read reads structured binary data from r into data.
// Data must be a pointer to a fixed-size value or a slice
// of fixed-size values.
// Bytes read from r are decoded using the specified byte order
// and written to successive fields of the data.
// When decoding boolean values, a zero byte is decoded as false, and
// any other non-zero byte is decoded as true.
// When reading into structs, the field data for fields with
// blank (_) field names is skipped; i.e., blank field names
// may be used for padding.
// When reading into a struct, all non-blank fields must be exported
// or Read may panic.
//
// The error is [io.EOF] only if no bytes were read.
// If an [io.EOF] happens after reading some but not all the bytes,
// Read returns [io.ErrUnexpectedEOF].
func Read(r io.Reader, order ByteOrder, data any) error {
	// Fast path for basic types and slices.
	if n := intDataSize(data); n != 0 {
		bs := make([]byte, n)
		if _, err := io.ReadFull(r, bs); err != nil {
			return err
		}
		switch data := data.(type) {
		case *bool:
			*data = bs[0] != 0
		case *int8:
			*data = int8(bs[0])
		case *uint8:
			*data = bs[0]
		case *int16:
			*data = int16(order.Uint16(bs))
		case *uint16:
			*data = order.Uint16(bs)
		case *int32:
			*data = int32(order.Uint32(bs))
		case *uint32:
			*data = order.Uint32(bs)
		case *int64:
			*data = int64(order.Uint64(bs))
		case *uint64:
			*data = order.Uint64(bs)
		case *float32:
			*data = math.Float32frombits(order.Uint32(bs))
		case *float64:
			*data = math.Float64frombits(order.Uint64(bs))
		case []bool:
			for i, x := range bs { // Easier to loop over the input for 8-bit values.
				data[i] = x != 0
			}
		case []int8:
			for i, x := range bs {
				data[i] = int8(x)
			}
		case []uint8:
			copy(data, bs)
		case []int16:
			for i := range data {
				data[i] = int16(order.Uint16(bs[2*i:]))
			}
		case []uint16:
			for i := range data {
				data[i] = order.Uint16(bs[2*i:])
			}
		case []int32:
			for i := range data {
				data[i] = int32(order.Uint32(bs[4*i:]))
			}
		case []uint32:
			for i := range data {
				data[i] = order.Uint32(bs[4*i:])
			}
		case []int64:
			for i := range data {
				data[i] = int64(order.Uint64(bs[8*i:]))
			}
		case []uint64:
			for i := range data {
				data[i] = order.Uint64(bs[8*i:])
			}
		case []float32:
			for i := range data {
				data[i] = math.Float32frombits(order.Uint32(bs[4*i:]))
			}
		case []float64:
			for i := range data {
				data[i] = math.Float64frombits(order.Uint64(bs[8*i:]))
			}
		default:
			n = 0 // fast path doesn't apply
		}
		if n != 0 {
			return nil
		}
	}

	// Fallback to reflect-based decoding.
	v := reflect.ValueOf(data)
	size := -1
	switch v.Kind() {
	case reflect.Pointer:
		v = v.Elem()
		size = dataSize(v)
	case reflect.Slice:
		size = dataSize(v)
	}
	if size < 0 {
		return errors.New("binary.Read: invalid type " + reflect.TypeOf(data).String())
	}
	d := &decoder{order: order, buf: make([]byte, size)}
	if _, err := io.ReadFull(r, d.buf); err != nil {
		return err
	}
	d.value(v)
	return nil
}

// Write writes the binary representation of data into w.
// Data must be a fixed-size value or a slice of fixed-size
// values, or a pointer to such data.
// Boolean values encode as one byte: 1 for true, and 0 for false.
// Bytes written to w are encoded using the specified byte order
// and read from successive fields of the data.
// When writing structs, zero values are written for fields
// with blank (_) field names.
func Write(w io.Writer, order ByteOrder, data any) error {
	// Fast path for basic types and slices.
	if n := intDataSize(data); n != 0 {
		bs := make([]byte, n)
		switch v := data.(type) {
		case *bool:
			if *v {
				bs[0] = 1
			} else {
				bs[0] = 0
			}
		case bool:
			if v {
				bs[0] = 1
			} else {
				bs[0] = 0
			}
		case []bool:
			for i, x := range v {
				if x {
					bs[i] = 1
				} else {
					bs[i] = 0
				}
			}
		case *int8:
			bs[0] = byte(*v)
		case int8:
			bs[0] = byte(v)
		case []int8:
			for i, x := range v {
				bs[i] = byte(x)
			}
		case *uint8:
			bs[0] = *v
		case uint8:
			bs[0] = v
		case []uint8:
			bs = v
		case *int16:
			order.PutUint16(bs, uint16(*v))
		case int16:
			order.PutUint16(bs, uint16(v))
		case []int16:
			for i, x := range v {
				order.PutUint16(bs[2*i:], uint16(x))
			}
		case *uint16:
			order.PutUint16(bs, *v)
		case uint16:
			order.PutUint16(bs, v)
		case []uint16:
			for i, x := range v {
				order.PutUint16(bs[2*i:], x)
			}
		case *int32:
			order.PutUint32(bs, uint32(*v))
		case int32:
			order.PutUint32(bs, uint32(v))
		case []int32:
			for i, x := range v {
				order.PutUint32(bs[4*i:], uint32(x))
			}
		case *uint32:
			order.PutUint32(bs, *v)
		case uint32:
			order.PutUint32(bs, v)
		case []uint32:
			for i, x := range v {
				order.PutUint32(bs[4*i:], x)
			}
		case *int64:
			order.PutUint64(bs, uint64(*v))
		case int64:
			order.PutUint64(bs, uint64(v))
		case []int64:
			for i, x := range v {
				order.PutUint64(bs[8*i:], uint64(x))
			}
		case *uint64:
			order.PutUint64(bs, *v)
		case uint64:
			order.PutUint64(bs, v)
		case []uint64:
			for i, x := range v {
				order.PutUint64(bs[8*i:], x)
			}
		case *float32:
			order.PutUint32(bs, math.Float32bits(*v))
		case float32:
			order.PutUint32(bs, math.Float32bits(v))
		case []float32:
			for i, x := range v {
				order.PutUint32(bs[4*i:], math.Float32bits(x))
			}
		case *float64:
			order.PutUint64(bs, math.Float64bits(*v))
		case float64:
			order.PutUint64(bs, math.Float64bits(v))
		case []float64:
			for i, x := range v {
				order.PutUint64(bs[8*i:], math.Float64bits(x))
			}
		}
		_, err := w.Write(bs)
		return err
	}

	// Fallback to reflect-based encoding.
	v := reflect.Indirect(reflect.ValueOf(data))
	size := dataSize(v)
	if size < 0 {
		return errors.New("binary.Write: some values are not fixed-sized in type " + reflect.TypeOf(data).String())
	}
	buf := make([]byte, size)
	e := &encoder{order: order, buf: buf}
	e.value(v)
	_, err := w.Write(buf)
	return err
}

// Size returns how many bytes [Write] would generate to encode the value v, which
// must be a fixed-size value or a slice of fixed-size values, or a pointer to such data.
// If v is neither of these, Size returns -1.
func Size(v any) int {
	return dataSize(reflect.Indirect(reflect.ValueOf(v)))
}

var structSize sync.Map // map[reflect.Type]int

// dataSize returns the number of bytes the actual data represented by v occupies in memory.
// For compound structures, it sums the sizes of the elements. Thus, for instance, for a slice
// it returns the length of the slice times the element size and does not count the memory
// occupied by the header. If the type of v is not acceptable, dataSize returns -1.
func dataSize(v reflect.Value) int {
	switch v.Kind() {
	case reflect.Slice:
		if s := sizeof(v.Type().Elem()); s >= 0 {
			return s * v.Len()
		}

	case reflect.Struct:
		t := v.Type()
		if size, ok := structSize.Load(t); ok {
			return size.(int)
		}
		size := sizeof(t)
		structSize.Store(t, size)
		return size

	default:
		if v.IsValid() {
			return sizeof(v.Type())
		}
	}

	return -1
}

// sizeof returns the size >= 0 of variables for the given type or -1 if the type is not acceptable.
func sizeof(t reflect.Type) int {
	switch t.Kind() {
	case reflect.Array:
		if s := sizeof(t.Elem()); s >= 0 {
			return s * t.Len()
		}

	case reflect.Struct:
		sum := 0
		for i, n := 0, t.NumField(); i < n; i++ {
			s := sizeof(t.Field(i).Type)
			if s < 0 {
				return -1
			}
			sum += s
		}
		return sum

	case reflect.Bool,
		reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
		reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
		reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128:
		return int(t.Size())
	}

	return -1
}

type coder struct {
	order  ByteOrder
	buf    []byte
	offset int
}

type (
	decoder coder
	encoder coder
)

func (d *decoder) bool() bool {
	x := d.buf[d.offset]
	d.offset++
	return x != 0
}

func (e *encoder) bool(x bool) {
	if x {
		e.buf[e.offset] = 1
	} else {
		e.buf[e.offset] = 0
	}
	e.offset++
}

func (d *decoder) uint8() uint8 {
	x := d.buf[d.offset]
	d.offset++
	return x
}

func (e *encoder) uint8(x uint8) {
	e.buf[e.offset] = x
	e.offset++
}

func (d *decoder) uint16() uint16 {
	x := d.order.Uint16(d.buf[d.offset : d.offset+2])
	d.offset += 2
	return x
}

func (e *encoder) uint16(x uint16) {
	e.order.PutUint16(e.buf[e.offset:e.offset+2], x)
	e.offset += 2
}

func (d *decoder) uint32() uint32 {
	x := d.order.Uint32(d.buf[d.offset : d.offset+4])
	d.offset += 4
	return x
}

func (e *encoder) uint32(x uint32) {
	e.order.PutUint32(e.buf[e.offset:e.offset+4], x)
	e.offset += 4
}

func (d *decoder) uint64() uint64 {
	x := d.order.Uint64(d.buf[d.offset : d.offset+8])
	d.offset += 8
	return x
}

func (e *encoder) uint64(x uint64) {
	e.order.PutUint64(e.buf[e.offset:e.offset+8], x)
	e.offset += 8
}

func (d *decoder) int8() int8 { return int8(d.uint8()) }

func (e *encoder) int8(x int8) { e.uint8(uint8(x)) }

func (d *decoder) int16() int16 { return int16(d.uint16()) }

func (e *encoder) int16(x int16) { e.uint16(uint16(x)) }

func (d *decoder) int32() int32 { return int32(d.uint32()) }

func (e *encoder) int32(x int32) { e.uint32(uint32(x)) }

func (d *decoder) int64() int64 { return int64(d.uint64()) }

func (e *encoder) int64(x int64) { e.uint64(uint64(x)) }

func (d *decoder) value(v reflect.Value) {
	switch v.Kind() {
	case reflect.Array:
		l := v.Len()
		for i := 0; i < l; i++ {
			d.value(v.Index(i))
		}

	case reflect.Struct:
		t := v.Type()
		l := v.NumField()
		for i := 0; i < l; i++ {
			// Note: Calling v.CanSet() below is an optimization.
			// It would be sufficient to check the field name,
			// but creating the StructField info for each field is
			// costly (run "go test -bench=ReadStruct" and compare
			// results when making changes to this code).
			if v := v.Field(i); v.CanSet() || t.Field(i).Name != "_" {
				d.value(v)
			} else {
				d.skip(v)
			}
		}

	case reflect.Slice:
		l := v.Len()
		for i := 0; i < l; i++ {
			d.value(v.Index(i))
		}

	case reflect.Bool:
		v.SetBool(d.bool())

	case reflect.Int8:
		v.SetInt(int64(d.int8()))
	case reflect.Int16:
		v.SetInt(int64(d.int16()))
	case reflect.Int32:
		v.SetInt(int64(d.int32()))
	case reflect.Int64:
		v.SetInt(d.int64())

	case reflect.Uint8:
		v.SetUint(uint64(d.uint8()))
	case reflect.Uint16:
		v.SetUint(uint64(d.uint16()))
	case reflect.Uint32:
		v.SetUint(uint64(d.uint32()))
	case reflect.Uint64:
		v.SetUint(d.uint64())

	case reflect.Float32:
		v.SetFloat(float64(math.Float32frombits(d.uint32())))
	case reflect.Float64:
		v.SetFloat(math.Float64frombits(d.uint64()))

	case reflect.Complex64:
		v.SetComplex(complex(
			float64(math.Float32frombits(d.uint32())),
			float64(math.Float32frombits(d.uint32())),
		))
	case reflect.Complex128:
		v.SetComplex(complex(
			math.Float64frombits(d.uint64()),
			math.Float64frombits(d.uint64()),
		))
	}
}

func (e *encoder) value(v reflect.Value) {
	switch v.Kind() {
	case reflect.Array:
		l := v.Len()
		for i := 0; i < l; i++ {
			e.value(v.Index(i))
		}

	case reflect.Struct:
		t := v.Type()
		l := v.NumField()
		for i := 0; i < l; i++ {
			// see comment for corresponding code in decoder.value()
			if v := v.Field(i); v.CanSet() || t.Field(i).Name != "_" {
				e.value(v)
			} else {
				e.skip(v)
			}
		}

	case reflect.Slice:
		l := v.Len()
		for i := 0; i < l; i++ {
			e.value(v.Index(i))
		}

	case reflect.Bool:
		e.bool(v.Bool())

	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		switch v.Type().Kind() {
		case reflect.Int8:
			e.int8(int8(v.Int()))
		case reflect.Int16:
			e.int16(int16(v.Int()))
		case reflect.Int32:
			e.int32(int32(v.Int()))
		case reflect.Int64:
			e.int64(v.Int())
		}

	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		switch v.Type().Kind() {
		case reflect.Uint8:
			e.uint8(uint8(v.Uint()))
		case reflect.Uint16:
			e.uint16(uint16(v.Uint()))
		case reflect.Uint32:
			e.uint32(uint32(v.Uint()))
		case reflect.Uint64:
			e.uint64(v.Uint())
		}

	case reflect.Float32, reflect.Float64:
		switch v.Type().Kind() {
		case reflect.Float32:
			e.uint32(math.Float32bits(float32(v.Float())))
		case reflect.Float64:
			e.uint64(math.Float64bits(v.Float()))
		}

	case reflect.Complex64, reflect.Complex128:
		switch v.Type().Kind() {
		case reflect.Complex64:
			x := v.Complex()
			e.uint32(math.Float32bits(float32(real(x))))
			e.uint32(math.Float32bits(float32(imag(x))))
		case reflect.Complex128:
			x := v.Complex()
			e.uint64(math.Float64bits(real(x)))
			e.uint64(math.Float64bits(imag(x)))
		}
	}
}

func (d *decoder) skip(v reflect.Value) {
	d.offset += dataSize(v)
}

func (e *encoder) skip(v reflect.Value) {
	n := dataSize(v)
	zero := e.buf[e.offset : e.offset+n]
	for i := range zero {
		zero[i] = 0
	}
	e.offset += n
}

// intDataSize returns the size of the data required to represent the data when encoded.
// It returns zero if the type cannot be implemented by the fast path in Read or Write.
func intDataSize(data any) int {
	switch data := data.(type) {
	case bool, int8, uint8, *bool, *int8, *uint8:
		return 1
	case []bool:
		return len(data)
	case []int8:
		return len(data)
	case []uint8:
		return len(data)
	case int16, uint16, *int16, *uint16:
		return 2
	case []int16:
		return 2 * len(data)
	case []uint16:
		return 2 * len(data)
	case int32, uint32, *int32, *uint32:
		return 4
	case []int32:
		return 4 * len(data)
	case []uint32:
		return 4 * len(data)
	case int64, uint64, *int64, *uint64:
		return 8
	case []int64:
		return 8 * len(data)
	case []uint64:
		return 8 * len(data)
	case float32, *float32:
		return 4
	case float64, *float64:
		return 8
	case []float32:
		return 4 * len(data)
	case []float64:
		return 8 * len(data)
	}
	return 0
}
```
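A short round-trip sketch of the vendored package's exported surface (Write, Read, the byte-order values, and the NativeEndian stringer), assuming the import path github.com/sagernet/sing/common/binary that the file layout implies:

```go
package main

import (
	"bytes"
	"fmt"

	"github.com/sagernet/sing/common/binary"
)

// Header contains only fixed-size fields, so Write and Read can encode it.
type Header struct {
	Version uint8
	Flags   uint16
	Length  uint32
}

func main() {
	var buf bytes.Buffer
	in := Header{Version: 1, Flags: 0x00ff, Length: 1024}
	if err := binary.Write(&buf, binary.BigEndian, &in); err != nil {
		panic(err)
	}
	var out Header
	if err := binary.Read(&buf, binary.BigEndian, &out); err != nil {
		panic(err)
	}
	fmt.Println(out, binary.NativeEndian.String())
}
```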
common/binary/export.go (new file, 18 lines)

```go
package binary

import (
	"encoding/binary"
	"reflect"
)

func DataSize(t reflect.Value) int {
	return dataSize(t)
}

func EncodeValue(order binary.ByteOrder, buf []byte, v reflect.Value) {
	(&encoder{order: order, buf: buf}).value(v)
}

func DecodeValue(order binary.ByteOrder, buf []byte, v reflect.Value) {
	(&decoder{order: order, buf: buf}).value(v)
}
```
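The export file exposes the internal reflect-based coders with a standard encoding/binary byte order. A hedged sketch of how they could be used together with DataSize to size the buffer; this is an assumption about intended use, not documented behavior:

```go
package main

import (
	"encoding/binary"
	"fmt"
	"reflect"

	sbinary "github.com/sagernet/sing/common/binary"
)

type point struct {
	X int32
	Y int32
}

func main() {
	in := point{X: 3, Y: -7}
	v := reflect.ValueOf(in)

	// DataSize reports the encoded size (8 bytes here), so the buffer fits exactly.
	buf := make([]byte, sbinary.DataSize(v))
	sbinary.EncodeValue(binary.LittleEndian, buf, v)

	var out point
	// Decoding needs a settable value, hence the pointer round-trip through Elem().
	sbinary.DecodeValue(binary.LittleEndian, buf, reflect.ValueOf(&out).Elem())
	fmt.Println(out)
}
```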
common/binary/native_endian_big.go (new file, 14 lines)

```go
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

//go:build armbe || arm64be || m68k || mips || mips64 || mips64p32 || ppc || ppc64 || s390 || s390x || shbe || sparc || sparc64

package binary

type nativeEndian struct {
	bigEndian
}

// NativeEndian is the native-endian implementation of [ByteOrder] and [AppendByteOrder].
var NativeEndian nativeEndian
```
common/binary/native_endian_little.go (new file, 14 lines)

```go
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

//go:build 386 || amd64 || amd64p32 || alpha || arm || arm64 || loong64 || mipsle || mips64le || mips64p32le || nios2 || ppc64le || riscv || riscv64 || sh || wasm

package binary

type nativeEndian struct {
	littleEndian
}

// NativeEndian is the native-endian implementation of [ByteOrder] and [AppendByteOrder].
var NativeEndian nativeEndian
```
common/binary/varint.go (new file, 166 lines)

```go
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package binary

// This file implements "varint" encoding of 64-bit integers.
// The encoding is:
// - unsigned integers are serialized 7 bits at a time, starting with the
//   least significant bits
// - the most significant bit (msb) in each output byte indicates if there
//   is a continuation byte (msb = 1)
// - signed integers are mapped to unsigned integers using "zig-zag"
//   encoding: Positive values x are written as 2*x + 0, negative values
//   are written as 2*(^x) + 1; that is, negative numbers are complemented
//   and whether to complement is encoded in bit 0.
//
// Design note:
// At most 10 bytes are needed for 64-bit values. The encoding could
// be more dense: a full 64-bit value needs an extra byte just to hold bit 63.
// Instead, the msb of the previous byte could be used to hold bit 63 since we
// know there can't be more than 64 bits. This is a trivial improvement and
// would reduce the maximum encoding length to 9 bytes. However, it breaks the
// invariant that the msb is always the "continuation bit" and thus makes the
// format incompatible with a varint encoding for larger numbers (say 128-bit).

import (
	"errors"
	"io"
)

// MaxVarintLenN is the maximum length of a varint-encoded N-bit integer.
const (
	MaxVarintLen16 = 3
	MaxVarintLen32 = 5
	MaxVarintLen64 = 10
)

// AppendUvarint appends the varint-encoded form of x,
// as generated by [PutUvarint], to buf and returns the extended buffer.
func AppendUvarint(buf []byte, x uint64) []byte {
	for x >= 0x80 {
		buf = append(buf, byte(x)|0x80)
		x >>= 7
	}
	return append(buf, byte(x))
}

// PutUvarint encodes a uint64 into buf and returns the number of bytes written.
// If the buffer is too small, PutUvarint will panic.
func PutUvarint(buf []byte, x uint64) int {
	i := 0
	for x >= 0x80 {
		buf[i] = byte(x) | 0x80
		x >>= 7
		i++
	}
	buf[i] = byte(x)
	return i + 1
}

// Uvarint decodes a uint64 from buf and returns that value and the
// number of bytes read (> 0). If an error occurred, the value is 0
// and the number of bytes n is <= 0 meaning:
//
//	n == 0: buf too small
//	n  < 0: value larger than 64 bits (overflow)
//	        and -n is the number of bytes read
func Uvarint(buf []byte) (uint64, int) {
	var x uint64
	var s uint
	for i, b := range buf {
		if i == MaxVarintLen64 {
			// Catch byte reads past MaxVarintLen64.
			// See issue https://golang.org/issues/41185
			return 0, -(i + 1) // overflow
		}
		if b < 0x80 {
			if i == MaxVarintLen64-1 && b > 1 {
				return 0, -(i + 1) // overflow
			}
			return x | uint64(b)<<s, i + 1
		}
		x |= uint64(b&0x7f) << s
		s += 7
	}
	return 0, 0
}

// AppendVarint appends the varint-encoded form of x,
// as generated by [PutVarint], to buf and returns the extended buffer.
func AppendVarint(buf []byte, x int64) []byte {
	ux := uint64(x) << 1
	if x < 0 {
		ux = ^ux
	}
	return AppendUvarint(buf, ux)
}

// PutVarint encodes an int64 into buf and returns the number of bytes written.
// If the buffer is too small, PutVarint will panic.
func PutVarint(buf []byte, x int64) int {
	ux := uint64(x) << 1
	if x < 0 {
		ux = ^ux
	}
	return PutUvarint(buf, ux)
}

// Varint decodes an int64 from buf and returns that value and the
// number of bytes read (> 0). If an error occurred, the value is 0
// and the number of bytes n is <= 0 with the following meaning:
//
//	n == 0: buf too small
//	n  < 0: value larger than 64 bits (overflow)
//	        and -n is the number of bytes read
func Varint(buf []byte) (int64, int) {
	ux, n := Uvarint(buf) // ok to continue in presence of error
	x := int64(ux >> 1)
	if ux&1 != 0 {
		x = ^x
	}
	return x, n
}

var errOverflow = errors.New("binary: varint overflows a 64-bit integer")

// ReadUvarint reads an encoded unsigned integer from r and returns it as a uint64.
// The error is [io.EOF] only if no bytes were read.
// If an [io.EOF] happens after reading some but not all the bytes,
// ReadUvarint returns [io.ErrUnexpectedEOF].
func ReadUvarint(r io.ByteReader) (uint64, error) {
	var x uint64
	var s uint
	for i := 0; i < MaxVarintLen64; i++ {
		b, err := r.ReadByte()
		if err != nil {
			if i > 0 && err == io.EOF {
				err = io.ErrUnexpectedEOF
			}
			return x, err
		}
		if b < 0x80 {
			if i == MaxVarintLen64-1 && b > 1 {
				return x, errOverflow
			}
			return x | uint64(b)<<s, nil
		}
		x |= uint64(b&0x7f) << s
		s += 7
	}
	return x, errOverflow
}

// ReadVarint reads an encoded signed integer from r and returns it as an int64.
// The error is [io.EOF] only if no bytes were read.
// If an [io.EOF] happens after reading some but not all the bytes,
// ReadVarint returns [io.ErrUnexpectedEOF].
func ReadVarint(r io.ByteReader) (int64, error) {
	ux, err := ReadUvarint(r) // ok to continue in presence of error
	x := int64(ux >> 1)
	if ux&1 != 0 {
		x = ^x
	}
	return x, err
}
```
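A small sketch of the varint helpers, again assuming the github.com/sagernet/sing/common/binary import path; it shows the plain unsigned encoding and the zig-zag mapping for signed values described in the comments above:

```go
package main

import (
	"fmt"

	"github.com/sagernet/sing/common/binary"
)

func main() {
	// Unsigned: 300 needs two bytes (0xAC 0x02).
	buf := binary.AppendUvarint(nil, 300)
	u, n := binary.Uvarint(buf)
	fmt.Println(buf, u, n) // [172 2] 300 2

	// Signed: zig-zag maps -1 to 1, so it still fits in one byte.
	sbuf := binary.AppendVarint(nil, -1)
	s, m := binary.Varint(sbuf)
	fmt.Println(sbuf, s, m) // [1] -1 1
}
```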
```diff
@@ -9,19 +9,20 @@ import (

type AddrConn struct {
	net.Conn
-	M.Metadata
+	Source      M.Socksaddr
+	Destination M.Socksaddr
}

func (c *AddrConn) LocalAddr() net.Addr {
-	if c.Metadata.Destination.IsValid() {
-		return c.Metadata.Destination.TCPAddr()
+	if c.Destination.IsValid() {
+		return c.Destination.TCPAddr()
	}
	return c.Conn.LocalAddr()
}

func (c *AddrConn) RemoteAddr() net.Addr {
-	if c.Metadata.Source.IsValid() {
-		return c.Metadata.Source.TCPAddr()
+	if c.Source.IsValid() {
+		return c.Source.TCPAddr()
	}
	return c.Conn.RemoteAddr()
}
```
```diff
@@ -4,6 +4,7 @@ import (
	"io"
	"sync"

+	"github.com/sagernet/sing/common"
	"github.com/sagernet/sing/common/buf"
)

@@ -41,6 +42,25 @@ func (w *BufferedWriter) Write(p []byte) (n int, err error) {
	}
}

+func (w *BufferedWriter) WriteByte(c byte) error {
+	w.access.Lock()
+	defer w.access.Unlock()
+	if w.buffer == nil {
+		return common.Error(w.upstream.Write([]byte{c}))
+	}
+	for {
+		err := w.buffer.WriteByte(c)
+		if err == nil {
+			return nil
+		}
+		_, err = w.upstream.Write(w.buffer.Bytes())
+		if err != nil {
+			return err
+		}
+		w.buffer.Reset()
+	}
+}
+
func (w *BufferedWriter) Fallthrough() error {
	w.access.Lock()
	defer w.access.Unlock()
```
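The added WriteByte method gives BufferedWriter the io.ByteWriter method set, so callers can probe for byte-at-a-time writing with a plain interface assertion. A compile-time check of that property, assuming the type lives in github.com/sagernet/sing/common/bufio as the surrounding hunks suggest:

```go
package main

import (
	"io"

	"github.com/sagernet/sing/common/bufio"
)

// This assertion fails to compile if *BufferedWriter stops satisfying io.ByteWriter.
var _ io.ByteWriter = (*bufio.BufferedWriter)(nil)

func main() {}
```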
```diff
@@ -3,7 +3,6 @@ package bufio
import (
	"io"
	"net"
-	"time"

	"github.com/sagernet/sing/common/buf"
	M "github.com/sagernet/sing/common/metadata"
@@ -60,13 +59,6 @@ func (c *CachedConn) WriteTo(w io.Writer) (n int64, err error) {
	return
}

-func (c *CachedConn) SetReadDeadline(t time.Time) error {
-	if c.buffer != nil && !c.buffer.IsEmpty() {
-		return nil
-	}
-	return c.Conn.SetReadDeadline(t)
-}
-
func (c *CachedConn) ReadFrom(r io.Reader) (n int64, err error) {
	return Copy(c.Conn, r)
}
@@ -192,10 +184,12 @@ func (c *CachedPacketConn) ReadCachedPacket() *N.PacketBuffer {
	if buffer != nil {
		buffer.DecRef()
	}
-	return &N.PacketBuffer{
+	packet := N.NewPacketBuffer()
+	*packet = N.PacketBuffer{
		Buffer:      buffer,
		Destination: c.destination,
	}
+	return packet
}

func (c *CachedPacketConn) Upstream() any {
```
```diff
@@ -35,14 +35,7 @@ func (w *ExtendedUDPConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {

func (w *ExtendedUDPConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
	defer buffer.Release()
-	if destination.IsFqdn() {
-		udpAddr, err := net.ResolveUDPAddr("udp", destination.String())
-		if err != nil {
-			return err
-		}
-		return common.Error(w.UDPConn.WriteTo(buffer.Bytes(), udpAddr))
-	}
-	return common.Error(w.UDPConn.WriteToUDP(buffer.Bytes(), destination.UDPAddr()))
+	return common.Error(w.UDPConn.WriteToUDPAddrPort(buffer.Bytes(), destination.AddrPort()))
}

func (w *ExtendedUDPConn) Upstream() any {
```
```diff
@@ -12,7 +12,6 @@ import (
	E "github.com/sagernet/sing/common/exceptions"
	M "github.com/sagernet/sing/common/metadata"
	N "github.com/sagernet/sing/common/network"
	"github.com/sagernet/sing/common/rw"
	"github.com/sagernet/sing/common/task"
)

@@ -30,28 +29,36 @@ func Copy(destination io.Writer, source io.Reader) (n int64, err error) {
	if cachedSrc, isCached := source.(N.CachedReader); isCached {
		cachedBuffer := cachedSrc.ReadCached()
		if cachedBuffer != nil {
			if !cachedBuffer.IsEmpty() {
				_, err = destination.Write(cachedBuffer.Bytes())
				if err != nil {
					cachedBuffer.Release()
					return
				}
			}
			dataLen := cachedBuffer.Len()
			_, err = destination.Write(cachedBuffer.Bytes())
			cachedBuffer.Release()
			if err != nil {
				return
			}
			for _, counter := range readCounters {
				counter(int64(dataLen))
			}
			for _, counter := range writeCounters {
				counter(int64(dataLen))
			}
			continue
		}
	}
	srcSyscallConn, srcIsSyscall := source.(syscall.Conn)
	dstSyscallConn, dstIsSyscall := destination.(syscall.Conn)
	if srcIsSyscall && dstIsSyscall {
		var handled bool
		handled, n, err = copyDirect(srcSyscallConn, dstSyscallConn, readCounters, writeCounters)
		if handled {
			return
		}
	}
	break
	}
	return CopyWithCounters(destination, source, originSource, readCounters, writeCounters)
}

func CopyWithCounters(destination io.Writer, source io.Reader, originSource io.Reader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
	srcSyscallConn, srcIsSyscall := source.(syscall.Conn)
	dstSyscallConn, dstIsSyscall := destination.(syscall.Conn)
	if srcIsSyscall && dstIsSyscall {
		var handled bool
		handled, n, err = copyDirect(srcSyscallConn, dstSyscallConn, readCounters, writeCounters)
		if handled {
			return
		}
	}
	return CopyExtended(originSource, NewExtendedWriter(destination), NewExtendedReader(source), readCounters, writeCounters)
}

@@ -76,6 +83,7 @@ func CopyExtended(originSource io.Reader, destination N.ExtendedWriter, source N
	return CopyExtendedWithPool(originSource, destination, source, readCounters, writeCounters)
}

// Deprecated: not used
func CopyExtendedBuffer(originSource io.Writer, destination N.ExtendedWriter, source N.ExtendedReader, buffer *buf.Buffer, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
	buffer.IncRef()
	defer buffer.DecRef()
@@ -114,19 +122,10 @@ func CopyExtendedBuffer(originSource io.Writer, destination N.ExtendedWriter, so
}

func CopyExtendedWithPool(originSource io.Reader, destination N.ExtendedWriter, source N.ExtendedReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
	frontHeadroom := N.CalculateFrontHeadroom(destination)
	rearHeadroom := N.CalculateRearHeadroom(destination)
	bufferSize := N.CalculateMTU(source, destination)
	if bufferSize > 0 {
		bufferSize += frontHeadroom + rearHeadroom
	} else {
		bufferSize = buf.BufferSize
	}
	options := N.NewReadWaitOptions(source, destination)
	var notFirstTime bool
	for {
		buffer := buf.NewSize(bufferSize)
		buffer.Resize(frontHeadroom, 0)
		buffer.Reserve(rearHeadroom)
		buffer := options.NewBuffer()
		err = source.ReadBuffer(buffer)
		if err != nil {
			buffer.Release()
@@ -137,7 +136,7 @@ func CopyExtendedWithPool(originSource io.Reader, destination N.ExtendedWriter,
			return
		}
		dataLen := buffer.Len()
-		buffer.OverCap(rearHeadroom)
+		options.PostReturn(buffer)
		err = destination.WriteBuffer(buffer)
		if err != nil {
			buffer.Leak()
@@ -158,16 +157,12 @@ func CopyExtendedWithPool(originSource io.Reader, destination N.ExtendedWriter,
}

func CopyConn(ctx context.Context, source net.Conn, destination net.Conn) error {
	return CopyConnContextList([]context.Context{ctx}, source, destination)
}

func CopyConnContextList(contextList []context.Context, source net.Conn, destination net.Conn) error {
	var group task.Group
-	if _, dstDuplex := common.Cast[rw.WriteCloser](destination); dstDuplex {
+	if _, dstDuplex := common.Cast[N.WriteCloser](destination); dstDuplex {
		group.Append("upload", func(ctx context.Context) error {
			err := common.Error(Copy(destination, source))
			if err == nil {
-				rw.CloseWrite(destination)
+				N.CloseWrite(destination)
			} else {
				common.Close(destination)
			}
@@ -179,11 +174,11 @@ func CopyConnContextList(contextList []context.Context, source net.Conn, destina
			return common.Error(Copy(destination, source))
		})
	}
-	if _, srcDuplex := common.Cast[rw.WriteCloser](source); srcDuplex {
+	if _, srcDuplex := common.Cast[N.WriteCloser](source); srcDuplex {
		group.Append("download", func(ctx context.Context) error {
			err := common.Error(Copy(source, destination))
			if err == nil {
-				rw.CloseWrite(source)
+				N.CloseWrite(source)
			} else {
				common.Close(source)
			}
@@ -198,7 +193,7 @@ func CopyConnContextList(contextList []context.Context, source net.Conn, destina
	group.Cleanup(func() {
		common.Close(source, destination)
	})
-	return group.RunContextList(contextList)
+	return group.Run(ctx)
}

func CopyPacket(destinationConn N.PacketWriter, source N.PacketReader) (n int64, err error) {
@@ -218,24 +213,24 @@ func CopyPacket(destinationConn N.PacketWriter, source N.PacketReader) (n int64,
		break
	}
	if cachedPackets != nil {
-		n, err = WritePacketWithPool(originSource, destinationConn, cachedPackets)
+		n, err = WritePacketWithPool(originSource, destinationConn, cachedPackets, readCounters, writeCounters)
		if err != nil {
			return
		}
	}
	frontHeadroom := N.CalculateFrontHeadroom(destinationConn)
	rearHeadroom := N.CalculateRearHeadroom(destinationConn)
	copeN, err := CopyPacketWithCounters(destinationConn, source, originSource, readCounters, writeCounters)
	n += copeN
	return
}

func CopyPacketWithCounters(destinationConn N.PacketWriter, source N.PacketReader, originSource N.PacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
	var (
		handled bool
		copeN   int64
	)
	readWaiter, isReadWaiter := CreatePacketReadWaiter(source)
	if isReadWaiter {
		needCopy := readWaiter.InitializeReadWaiter(N.ReadWaitOptions{
			FrontHeadroom: frontHeadroom,
			RearHeadroom:  rearHeadroom,
			MTU:           N.CalculateMTU(source, destinationConn),
		})
		needCopy := readWaiter.InitializeReadWaiter(N.NewReadWaitOptions(source, destinationConn))
		if !needCopy || common.LowMemory {
			handled, copeN, err = copyPacketWaitWithPool(originSource, destinationConn, readWaiter, readCounters, writeCounters, n > 0)
			if handled {
@@ -249,28 +244,19 @@ func CopyPacket(destinationConn N.PacketWriter, source N.PacketReader) (n int64,
		return
	}

func CopyPacketWithPool(originSource N.PacketReader, destinationConn N.PacketWriter, source N.PacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (n int64, err error) {
	frontHeadroom := N.CalculateFrontHeadroom(destinationConn)
	rearHeadroom := N.CalculateRearHeadroom(destinationConn)
	bufferSize := N.CalculateMTU(source, destinationConn)
	if bufferSize > 0 {
		bufferSize += frontHeadroom + rearHeadroom
	} else {
		bufferSize = buf.UDPBufferSize
	}
	var destination M.Socksaddr
func CopyPacketWithPool(originSource N.PacketReader, destination N.PacketWriter, source N.PacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (n int64, err error) {
	options := N.NewReadWaitOptions(source, destination)
	var destinationAddress M.Socksaddr
	for {
		buffer := buf.NewSize(bufferSize)
		buffer.Resize(frontHeadroom, 0)
		buffer.Reserve(rearHeadroom)
		destination, err = source.ReadPacket(buffer)
		buffer := options.NewPacketBuffer()
		destinationAddress, err = source.ReadPacket(buffer)
		if err != nil {
			buffer.Release()
			return
		}
		dataLen := buffer.Len()
		buffer.OverCap(rearHeadroom)
		err = destinationConn.WritePacket(buffer, destination)
		options.PostReturn(buffer)
		err = destination.WritePacket(buffer, destinationAddress)
		if err != nil {
			buffer.Leak()
			if !notFirstTime {
@@ -278,34 +264,25 @@ func CopyPacketWithPool(originSource N.PacketReader, destinationConn N.PacketWri
			}
			return
		}
		n += int64(dataLen)
		for _, counter := range readCounters {
			counter(int64(dataLen))
		}
		for _, counter := range writeCounters {
			counter(int64(dataLen))
		}
		n += int64(dataLen)
		notFirstTime = true
	}
}

func WritePacketWithPool(originSource N.PacketReader, destinationConn N.PacketWriter, packetBuffers []*N.PacketBuffer) (n int64, err error) {
	frontHeadroom := N.CalculateFrontHeadroom(destinationConn)
	rearHeadroom := N.CalculateRearHeadroom(destinationConn)
func WritePacketWithPool(originSource N.PacketReader, destination N.PacketWriter, packetBuffers []*N.PacketBuffer, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
	options := N.NewReadWaitOptions(nil, destination)
	var notFirstTime bool
	for _, packetBuffer := range packetBuffers {
		buffer := buf.NewPacket()
		buffer.Resize(frontHeadroom, 0)
		buffer.Reserve(rearHeadroom)
		_, err = buffer.Write(packetBuffer.Buffer.Bytes())
		packetBuffer.Buffer.Release()
		if err != nil {
			buffer.Release()
			continue
		}
		buffer := options.Copy(packetBuffer.Buffer)
		dataLen := buffer.Len()
		buffer.OverCap(rearHeadroom)
-		err = destinationConn.WritePacket(buffer, packetBuffer.Destination)
+		err = destination.WritePacket(buffer, packetBuffer.Destination)
		N.PutPacketBuffer(packetBuffer)
		if err != nil {
			buffer.Leak()
			if !notFirstTime {
@@ -313,16 +290,19 @@ func WritePacketWithPool(originSource N.PacketReader, destinationConn N.PacketWr
			}
			return
		}
		for _, counter := range readCounters {
			counter(int64(dataLen))
		}
		for _, counter := range writeCounters {
			counter(int64(dataLen))
		}
		n += int64(dataLen)
		notFirstTime = true
	}
	return
}

func CopyPacketConn(ctx context.Context, source N.PacketConn, destination N.PacketConn) error {
	return CopyPacketConnContextList([]context.Context{ctx}, source, destination)
}

func CopyPacketConnContextList(contextList []context.Context, source N.PacketConn, destination N.PacketConn) error {
	var group task.Group
	group.Append("upload", func(ctx context.Context) error {
		return common.Error(CopyPacket(destination, source))
@@ -334,5 +314,5 @@ func CopyPacketConnContextList(contextList []context.Context, source N.PacketCon
		common.Close(source, destination)
	})
	group.FastFail()
-	return group.RunContextList(contextList)
+	return group.Run(ctx)
}
```
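The recurring edit in these hunks replaces the hand-rolled headroom and MTU arithmetic with the N.ReadWaitOptions helper introduced in this change set. A hedged sketch of the new pattern, using only calls that appear in the diff (NewReadWaitOptions, NewBuffer, PostReturn, ReadBuffer, WriteBuffer); treat it as an illustration of the shape, not as the package's documented API:

```go
package sketch

import (
	N "github.com/sagernet/sing/common/network"
)

// copyLoopSketch mirrors the rewritten CopyExtendedWithPool body: the
// ReadWaitOptions value owns the headroom and sizing decisions that the old
// code computed inline with CalculateFrontHeadroom, CalculateRearHeadroom,
// and CalculateMTU.
func copyLoopSketch(destination N.ExtendedWriter, source N.ExtendedReader) error {
	options := N.NewReadWaitOptions(source, destination)
	for {
		buffer := options.NewBuffer() // replaces buf.NewSize + Resize + Reserve
		if err := source.ReadBuffer(buffer); err != nil {
			buffer.Release()
			return err
		}
		options.PostReturn(buffer) // replaces buffer.OverCap(rearHeadroom)
		if err := destination.WriteBuffer(buffer); err != nil {
			buffer.Leak()
			return err
		}
	}
}
```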
```diff
@@ -5,7 +5,6 @@ import (
	"net/netip"
	"os"
	"syscall"
	"unsafe"

	"github.com/sagernet/sing/common/buf"
	E "github.com/sagernet/sing/common/exceptions"
@@ -15,49 +14,6 @@ import (
	"golang.org/x/sys/windows"
)

var modws2_32 = windows.NewLazySystemDLL("ws2_32.dll")

var procrecv = modws2_32.NewProc("recv")

// Do the interface allocations only once for common
// Errno values.
const (
	errnoERROR_IO_PENDING = 997
)

var (
	errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
	errERROR_EINVAL     error = syscall.EINVAL
)

// errnoErr returns common boxed Errno values, to prevent
// allocations at runtime.
func errnoErr(e syscall.Errno) error {
	switch e {
	case 0:
		return errERROR_EINVAL
	case errnoERROR_IO_PENDING:
		return errERROR_IO_PENDING
	}
	// TODO: add more here, after collecting data on the common
	// error values see on Windows. (perhaps when running
	// all.bat?)
	return e
}

func recv(s windows.Handle, buf []byte, flags int32) (n int32, err error) {
	var _p0 *byte
	if len(buf) > 0 {
		_p0 = &buf[0]
	}
	r0, _, e1 := syscall.SyscallN(procrecv.Addr(), uintptr(s), uintptr(unsafe.Pointer(_p0)), uintptr(len(buf)), uintptr(flags))
	n = int32(r0)
	if n == -1 {
		err = errnoErr(e1)
	}
	return
}

var _ N.ReadWaiter = (*syscallReadWaiter)(nil)

type syscallReadWaiter struct {
@@ -164,16 +120,16 @@ func (w *syscallPacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions
		var readN int
		var from windows.Sockaddr
		readN, from, w.readErr = windows.Recvfrom(windows.Handle(fd), buffer.FreeBytes(), 0)
		//goland:noinspection GoDirectComparisonOfErrors
		if w.readErr != nil {
			buffer.Release()
			return w.readErr != windows.WSAEWOULDBLOCK
		}
		if readN > 0 {
			buffer.Truncate(readN)
			w.options.PostReturn(buffer)
			w.buffer = buffer
		} else {
			buffer.Release()
		}
		if w.readErr == windows.WSAEWOULDBLOCK {
			return false
		}
		w.options.PostReturn(buffer)
		w.buffer = buffer
		if from != nil {
			switch fromAddr := from.(type) {
			case *windows.SockaddrInet4:
```
|
@ -25,6 +25,45 @@ func ReadPacket(reader N.PacketReader, buffer *buf.Buffer) (n int, addr net.Addr
|
|||
return
|
||||
}
|
||||
|
||||
func ReadBufferSize(reader io.Reader, bufferSize int) (buffer *buf.Buffer, err error) {
|
||||
readWaiter, isReadWaiter := CreateReadWaiter(reader)
|
||||
if isReadWaiter {
|
||||
readWaiter.InitializeReadWaiter(N.ReadWaitOptions{
|
||||
MTU: bufferSize,
|
||||
})
|
||||
return readWaiter.WaitReadBuffer()
|
||||
}
|
||||
buffer = buf.NewSize(bufferSize)
|
||||
if extendedReader, isExtendedReader := reader.(N.ExtendedReader); isExtendedReader {
|
||||
err = extendedReader.ReadBuffer(buffer)
|
||||
} else {
|
||||
_, err = buffer.ReadOnceFrom(reader)
|
||||
}
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
buffer = nil
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func ReadPacketSize(reader N.PacketReader, packetSize int) (buffer *buf.Buffer, destination M.Socksaddr, err error) {
|
||||
readWaiter, isReadWaiter := CreatePacketReadWaiter(reader)
|
||||
if isReadWaiter {
|
||||
readWaiter.InitializeReadWaiter(N.ReadWaitOptions{
|
||||
MTU: packetSize,
|
||||
})
|
||||
buffer, destination, err = readWaiter.WaitReadPacket()
|
||||
return
|
||||
}
|
||||
buffer = buf.NewSize(packetSize)
|
||||
destination, err = reader.ReadPacket(buffer)
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
buffer = nil
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func Write(writer io.Writer, data []byte) (n int, err error) {
|
||||
if extendedWriter, isExtended := writer.(N.ExtendedWriter); isExtended {
|
||||
return WriteBuffer(extendedWriter, buf.As(data))
|
||||
|
|
|
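The ReadBufferSize and ReadPacketSize helpers added in this hunk prefer the read-waiter fast path when the reader provides one, and otherwise fall back to a plain sized read into a pooled buffer. A minimal usage sketch (the connection values and the 4096/65535 sizes are illustrative, not part of this change):

package example

import (
	"net"

	"github.com/sagernet/sing/common/bufio"
	M "github.com/sagernet/sing/common/metadata"
	N "github.com/sagernet/sing/common/network"
)

// readOnce reads a single sized chunk from a stream connection and a single
// datagram from a packet connection using the pooled helpers shown above.
func readOnce(conn net.Conn, packetConn N.PacketConn) (M.Socksaddr, error) {
	buffer, err := bufio.ReadBufferSize(conn, 4096)
	if err != nil {
		return M.Socksaddr{}, err
	}
	buffer.Release() // return the pooled buffer once consumed

	packet, source, err := bufio.ReadPacketSize(packetConn, 65535)
	if err != nil {
		return M.Socksaddr{}, err
	}
	packet.Release()
	return source, nil
}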
@ -30,6 +30,14 @@ func NewNATPacketConn(conn N.NetPacketConn, origin M.Socksaddr, destination M.So
|
|||
}
|
||||
}
|
||||
|
||||
func NewDestinationNATPacketConn(conn N.NetPacketConn, origin M.Socksaddr, destination M.Socksaddr) NATPacketConn {
|
||||
return &destinationNATPacketConn{
|
||||
NetPacketConn: conn,
|
||||
origin: origin,
|
||||
destination: destination,
|
||||
}
|
||||
}
|
||||
|
||||
type unidirectionalNATPacketConn struct {
|
||||
N.NetPacketConn
|
||||
origin M.Socksaddr
|
||||
|
@ -144,6 +152,60 @@ func (c *bidirectionalNATPacketConn) RemoteAddr() net.Addr {
|
|||
return c.destination.UDPAddr()
|
||||
}
|
||||
|
||||
type destinationNATPacketConn struct {
|
||||
N.NetPacketConn
|
||||
origin M.Socksaddr
|
||||
destination M.Socksaddr
|
||||
}
|
||||
|
||||
func (c *destinationNATPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
||||
n, addr, err = c.NetPacketConn.ReadFrom(p)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if M.SocksaddrFromNet(addr) == c.origin {
|
||||
addr = c.destination.UDPAddr()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (c *destinationNATPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
||||
if M.SocksaddrFromNet(addr) == c.destination {
|
||||
addr = c.origin.UDPAddr()
|
||||
}
|
||||
return c.NetPacketConn.WriteTo(p, addr)
|
||||
}
|
||||
|
||||
func (c *destinationNATPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
|
||||
destination, err = c.NetPacketConn.ReadPacket(buffer)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if destination == c.origin {
|
||||
destination = c.destination
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (c *destinationNATPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
|
||||
if destination == c.destination {
|
||||
destination = c.origin
|
||||
}
|
||||
return c.NetPacketConn.WritePacket(buffer, destination)
|
||||
}
|
||||
|
||||
func (c *destinationNATPacketConn) UpdateDestination(destinationAddress netip.Addr) {
|
||||
c.destination = M.SocksaddrFrom(destinationAddress, c.destination.Port)
|
||||
}
|
||||
|
||||
func (c *destinationNATPacketConn) Upstream() any {
|
||||
return c.NetPacketConn
|
||||
}
|
||||
|
||||
func (c *destinationNATPacketConn) RemoteAddr() net.Addr {
|
||||
return c.destination.UDPAddr()
|
||||
}
|
||||
|
||||
func socksaddrWithoutPort(destination M.Socksaddr) M.Socksaddr {
|
||||
destination.Port = 0
|
||||
return destination
|
||||
|
|
|
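The destinationNATPacketConn introduced above rewrites only the far side of the packet flow: reads whose source equals origin are reported as coming from destination, and writes addressed to destination are redirected back to origin. A rough sketch of wrapping an upstream connection (the addresses are placeholders, and the assumption that these constructors live in common/bufio follows from the surrounding NAT helpers):

package example

import (
	"github.com/sagernet/sing/common/bufio" // assumption: the NAT packet conns above are defined here
	M "github.com/sagernet/sing/common/metadata"
	N "github.com/sagernet/sing/common/network"
)

// wrapDestinationNAT makes packets from origin appear to come from
// destination, and rewrites replies sent to destination back to origin.
func wrapDestinationNAT(upstream N.NetPacketConn) bufio.NATPacketConn {
	origin := M.ParseSocksaddr("10.0.0.1:53")     // placeholder: the address packets actually come from
	destination := M.ParseSocksaddr("1.1.1.1:53") // placeholder: the address the caller should see
	return bufio.NewDestinationNATPacketConn(upstream, origin, destination)
}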
@ -36,7 +36,7 @@ func TCPPipe(t *testing.T) (net.Conn, net.Conn) {
|
|||
clientConn, clientErr = net.Dial("tcp", listener.Addr().String())
|
||||
return clientErr
|
||||
})
|
||||
err = group.Run()
|
||||
err = group.Run(context.Background())
|
||||
require.NoError(t, err)
|
||||
listener.Close()
|
||||
t.Cleanup(func() {
|
||||
|
|
5
common/bufio/syscall_windows.go
Normal file
@@ -0,0 +1,5 @@
package bufio

//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go syscall_windows.go

//sys recv(s windows.Handle, buf []byte, flags int32) (n int32, err error) [failretval == -1] = ws2_32.recv
|
@ -38,7 +38,6 @@ func (w *SyscallVectorisedWriter) WriteVectorised(buffers []*buf.Buffer) error {
|
|||
var innerErr unix.Errno
|
||||
err := w.rawConn.Write(func(fd uintptr) (done bool) {
|
||||
//nolint:staticcheck
|
||||
//goland:noinspection GoDeprecation
|
||||
_, _, innerErr = unix.Syscall(unix.SYS_WRITEV, fd, uintptr(unsafe.Pointer(&iovecList[0])), uintptr(len(iovecList)))
|
||||
return innerErr != unix.EAGAIN && innerErr != unix.EWOULDBLOCK
|
||||
})
|
||||
|
|
57
common/bufio/zsyscall_windows.go
Normal file
|
@ -0,0 +1,57 @@
|
|||
// Code generated by 'go generate'; DO NOT EDIT.
|
||||
|
||||
package bufio
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
var _ unsafe.Pointer
|
||||
|
||||
// Do the interface allocations only once for common
|
||||
// Errno values.
|
||||
const (
|
||||
errnoERROR_IO_PENDING = 997
|
||||
)
|
||||
|
||||
var (
|
||||
errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
|
||||
errERROR_EINVAL error = syscall.EINVAL
|
||||
)
|
||||
|
||||
// errnoErr returns common boxed Errno values, to prevent
|
||||
// allocations at runtime.
|
||||
func errnoErr(e syscall.Errno) error {
|
||||
switch e {
|
||||
case 0:
|
||||
return errERROR_EINVAL
|
||||
case errnoERROR_IO_PENDING:
|
||||
return errERROR_IO_PENDING
|
||||
}
|
||||
// TODO: add more here, after collecting data on the common
|
||||
// error values see on Windows. (perhaps when running
|
||||
// all.bat?)
|
||||
return e
|
||||
}
|
||||
|
||||
var (
|
||||
modws2_32 = windows.NewLazySystemDLL("ws2_32.dll")
|
||||
|
||||
procrecv = modws2_32.NewProc("recv")
|
||||
)
|
||||
|
||||
func recv(s windows.Handle, buf []byte, flags int32) (n int32, err error) {
|
||||
var _p0 *byte
|
||||
if len(buf) > 0 {
|
||||
_p0 = &buf[0]
|
||||
}
|
||||
r0, _, e1 := syscall.Syscall6(procrecv.Addr(), 4, uintptr(s), uintptr(unsafe.Pointer(_p0)), uintptr(len(buf)), uintptr(flags), 0, 0)
|
||||
n = int32(r0)
|
||||
if n == -1 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
|
@ -41,9 +41,9 @@ func (i *Instance) Timeout() time.Duration {
|
|||
return i.timeout
|
||||
}
|
||||
|
||||
func (i *Instance) SetTimeout(timeout time.Duration) {
|
||||
func (i *Instance) SetTimeout(timeout time.Duration) bool {
|
||||
i.timeout = timeout
|
||||
i.Update()
|
||||
return i.Update()
|
||||
}
|
||||
|
||||
func (i *Instance) wait() {
|
||||
|
|
|
@ -13,7 +13,7 @@ import (
|
|||
type PacketConn interface {
|
||||
N.PacketConn
|
||||
Timeout() time.Duration
|
||||
SetTimeout(timeout time.Duration)
|
||||
SetTimeout(timeout time.Duration) bool
|
||||
}
|
||||
|
||||
type TimerPacketConn struct {
|
||||
|
@ -24,10 +24,12 @@ type TimerPacketConn struct {
|
|||
func NewPacketConn(ctx context.Context, conn N.PacketConn, timeout time.Duration) (context.Context, N.PacketConn) {
|
||||
if timeoutConn, isTimeoutConn := common.Cast[PacketConn](conn); isTimeoutConn {
|
||||
oldTimeout := timeoutConn.Timeout()
|
||||
if timeout < oldTimeout {
|
||||
timeoutConn.SetTimeout(timeout)
|
||||
if oldTimeout > 0 && timeout >= oldTimeout {
|
||||
return ctx, conn
|
||||
}
|
||||
if timeoutConn.SetTimeout(timeout) {
|
||||
return ctx, conn
|
||||
}
|
||||
return ctx, conn
|
||||
}
|
||||
err := conn.SetReadDeadline(time.Time{})
|
||||
if err == nil {
|
||||
|
@ -58,8 +60,8 @@ func (c *TimerPacketConn) Timeout() time.Duration {
|
|||
return c.instance.Timeout()
|
||||
}
|
||||
|
||||
func (c *TimerPacketConn) SetTimeout(timeout time.Duration) {
|
||||
c.instance.SetTimeout(timeout)
|
||||
func (c *TimerPacketConn) SetTimeout(timeout time.Duration) bool {
|
||||
return c.instance.SetTimeout(timeout)
|
||||
}
|
||||
|
||||
func (c *TimerPacketConn) Close() error {
|
||||
|
|
|
@ -61,9 +61,9 @@ func (c *TimeoutPacketConn) Timeout() time.Duration {
|
|||
return c.timeout
|
||||
}
|
||||
|
||||
func (c *TimeoutPacketConn) SetTimeout(timeout time.Duration) {
|
||||
func (c *TimeoutPacketConn) SetTimeout(timeout time.Duration) bool {
|
||||
c.timeout = timeout
|
||||
c.PacketConn.SetReadDeadline(time.Now())
|
||||
return c.PacketConn.SetReadDeadline(time.Now()) == nil
|
||||
}
|
||||
|
||||
func (c *TimeoutPacketConn) Close() error {
|
||||
|
|
|
@ -157,6 +157,18 @@ func IndexIndexed[T any](arr []T, block func(index int, it T) bool) int {
|
|||
return -1
|
||||
}
|
||||
|
||||
func Equal[S ~[]E, E comparable](s1, s2 S) bool {
|
||||
if len(s1) != len(s2) {
|
||||
return false
|
||||
}
|
||||
for i := range s1 {
|
||||
if s1[i] != s2[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
//go:norace
|
||||
func Dup[T any](obj T) T {
|
||||
pointer := uintptr(unsafe.Pointer(&obj))
|
||||
|
@ -268,6 +280,14 @@ func Reverse[T any](arr []T) []T {
|
|||
return arr
|
||||
}
|
||||
|
||||
func ReverseMap[K comparable, V comparable](m map[K]V) map[V]K {
|
||||
ret := make(map[V]K, len(m))
|
||||
for k, v := range m {
|
||||
ret[v] = k
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
func Done(ctx context.Context) bool {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
|
@ -362,22 +382,3 @@ func Close(closers ...any) error {
|
|||
}
|
||||
return retErr
|
||||
}
|
||||
|
||||
type Starter interface {
|
||||
Start() error
|
||||
}
|
||||
|
||||
func Start(starters ...any) error {
|
||||
for _, rawStarter := range starters {
|
||||
if rawStarter == nil {
|
||||
continue
|
||||
}
|
||||
if starter, isStarter := rawStarter.(Starter); isStarter {
|
||||
err := starter.Start()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"reflect"
|
||||
)
|
||||
|
||||
// Deprecated: not used
|
||||
func SelectContext(contextList []context.Context) (int, error) {
|
||||
if len(contextList) == 1 {
|
||||
<-contextList[0].Done()
|
||||
|
|
|
@ -9,15 +9,15 @@ import (
|
|||
|
||||
func bindToInterface(conn syscall.RawConn, network string, address string, finder InterfaceFinder, interfaceName string, interfaceIndex int, preferInterfaceName bool) error {
|
||||
return Raw(conn, func(fd uintptr) error {
|
||||
var err error
|
||||
if interfaceIndex == -1 {
|
||||
if finder == nil {
|
||||
return os.ErrInvalid
|
||||
}
|
||||
interfaceIndex, err = finder.InterfaceIndexByName(interfaceName)
|
||||
iif, err := finder.ByName(interfaceName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
interfaceIndex = iif.Index
|
||||
}
|
||||
switch network {
|
||||
case "tcp6", "udp6":
|
||||
|
|
|
@ -1,18 +1,59 @@
|
|||
package control
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"unsafe"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
)
|
||||
|
||||
type InterfaceFinder interface {
|
||||
InterfaceIndexByName(name string) (int, error)
|
||||
InterfaceNameByIndex(index int) (string, error)
|
||||
InterfaceByAddr(addr netip.Addr) (*Interface, error)
|
||||
Update() error
|
||||
Interfaces() []Interface
|
||||
ByName(name string) (*Interface, error)
|
||||
ByIndex(index int) (*Interface, error)
|
||||
ByAddr(addr netip.Addr) (*Interface, error)
|
||||
}
|
||||
|
||||
type Interface struct {
|
||||
Index int
|
||||
MTU int
|
||||
Name string
|
||||
Addresses []netip.Prefix
|
||||
Index int
|
||||
MTU int
|
||||
Name string
|
||||
HardwareAddr net.HardwareAddr
|
||||
Flags net.Flags
|
||||
Addresses []netip.Prefix
|
||||
}
|
||||
|
||||
func (i Interface) Equals(other Interface) bool {
|
||||
return i.Index == other.Index &&
|
||||
i.MTU == other.MTU &&
|
||||
i.Name == other.Name &&
|
||||
common.Equal(i.HardwareAddr, other.HardwareAddr) &&
|
||||
i.Flags == other.Flags &&
|
||||
common.Equal(i.Addresses, other.Addresses)
|
||||
}
|
||||
|
||||
func (i Interface) NetInterface() net.Interface {
|
||||
return *(*net.Interface)(unsafe.Pointer(&i))
|
||||
}
|
||||
|
||||
func InterfaceFromNet(iif net.Interface) (Interface, error) {
|
||||
ifAddrs, err := iif.Addrs()
|
||||
if err != nil {
|
||||
return Interface{}, err
|
||||
}
|
||||
return InterfaceFromNetAddrs(iif, common.Map(ifAddrs, M.PrefixFromNet)), nil
|
||||
}
|
||||
|
||||
func InterfaceFromNetAddrs(iif net.Interface, addresses []netip.Prefix) Interface {
|
||||
return Interface{
|
||||
Index: iif.Index,
|
||||
MTU: iif.MTU,
|
||||
Name: iif.Name,
|
||||
HardwareAddr: iif.HardwareAddr,
|
||||
Flags: iif.Flags,
|
||||
Addresses: addresses,
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,12 +3,12 @@ package control
|
|||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
_ "unsafe"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
)
|
||||
|
||||
var _ InterfaceFinder = (*DefaultInterfaceFinder)(nil)
|
||||
|
||||
type DefaultInterfaceFinder struct {
|
||||
interfaces []Interface
|
||||
}
|
||||
|
@ -24,16 +24,12 @@ func (f *DefaultInterfaceFinder) Update() error {
|
|||
}
|
||||
interfaces := make([]Interface, 0, len(netIfs))
|
||||
for _, netIf := range netIfs {
|
||||
ifAddrs, err := netIf.Addrs()
|
||||
var iif Interface
|
||||
iif, err = InterfaceFromNet(netIf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
interfaces = append(interfaces, Interface{
|
||||
Index: netIf.Index,
|
||||
MTU: netIf.MTU,
|
||||
Name: netIf.Name,
|
||||
Addresses: common.Map(ifAddrs, M.PrefixFromNet),
|
||||
})
|
||||
interfaces = append(interfaces, iif)
|
||||
}
|
||||
f.interfaces = interfaces
|
||||
return nil
|
||||
|
@ -43,38 +39,45 @@ func (f *DefaultInterfaceFinder) UpdateInterfaces(interfaces []Interface) {
|
|||
f.interfaces = interfaces
|
||||
}
|
||||
|
||||
func (f *DefaultInterfaceFinder) InterfaceIndexByName(name string) (int, error) {
|
||||
func (f *DefaultInterfaceFinder) Interfaces() []Interface {
|
||||
return f.interfaces
|
||||
}
|
||||
|
||||
func (f *DefaultInterfaceFinder) ByName(name string) (*Interface, error) {
|
||||
for _, netInterface := range f.interfaces {
|
||||
if netInterface.Name == name {
|
||||
return netInterface.Index, nil
|
||||
return &netInterface, nil
|
||||
}
|
||||
}
|
||||
netInterface, err := net.InterfaceByName(name)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
_, err := net.InterfaceByName(name)
|
||||
if err == nil {
|
||||
err = f.Update()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return f.ByName(name)
|
||||
}
|
||||
f.Update()
|
||||
return netInterface.Index, nil
|
||||
return nil, &net.OpError{Op: "route", Net: "ip+net", Source: nil, Addr: &net.IPAddr{IP: nil}, Err: E.New("no such network interface")}
|
||||
}
|
||||
|
||||
func (f *DefaultInterfaceFinder) InterfaceNameByIndex(index int) (string, error) {
|
||||
func (f *DefaultInterfaceFinder) ByIndex(index int) (*Interface, error) {
|
||||
for _, netInterface := range f.interfaces {
|
||||
if netInterface.Index == index {
|
||||
return netInterface.Name, nil
|
||||
return &netInterface, nil
|
||||
}
|
||||
}
|
||||
netInterface, err := net.InterfaceByIndex(index)
|
||||
if err != nil {
|
||||
return "", err
|
||||
_, err := net.InterfaceByIndex(index)
|
||||
if err == nil {
|
||||
err = f.Update()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return f.ByIndex(index)
|
||||
}
|
||||
f.Update()
|
||||
return netInterface.Name, nil
|
||||
return nil, &net.OpError{Op: "route", Net: "ip+net", Source: nil, Addr: &net.IPAddr{IP: nil}, Err: E.New("no such network interface")}
|
||||
}
|
||||
|
||||
//go:linkname errNoSuchInterface net.errNoSuchInterface
|
||||
var errNoSuchInterface error
|
||||
|
||||
func (f *DefaultInterfaceFinder) InterfaceByAddr(addr netip.Addr) (*Interface, error) {
|
||||
func (f *DefaultInterfaceFinder) ByAddr(addr netip.Addr) (*Interface, error) {
|
||||
for _, netInterface := range f.interfaces {
|
||||
for _, prefix := range netInterface.Addresses {
|
||||
if prefix.Contains(addr) {
|
||||
|
@ -82,16 +85,5 @@ func (f *DefaultInterfaceFinder) InterfaceByAddr(addr netip.Addr) (*Interface, e
|
|||
}
|
||||
}
|
||||
}
|
||||
err := f.Update()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, netInterface := range f.interfaces {
|
||||
for _, prefix := range netInterface.Addresses {
|
||||
if prefix.Contains(addr) {
|
||||
return &netInterface, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil, &net.OpError{Op: "route", Net: "ip+net", Source: nil, Addr: &net.IPAddr{IP: addr.AsSlice()}, Err: errNoSuchInterface}
|
||||
return nil, &net.OpError{Op: "route", Net: "ip+net", Source: nil, Addr: &net.IPAddr{IP: addr.AsSlice()}, Err: E.New("no such network interface")}
|
||||
}
|
||||
|
|
|
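InterfaceFinder now returns the full *Interface record via ByName/ByIndex/ByAddr instead of bare names and indexes, and DefaultInterfaceFinder refreshes its cache and retries before reporting "no such network interface". A small sketch of the updated call pattern (the zero-value finder and the interface name are illustrative):

package example

import (
	"github.com/sagernet/sing/common/control"
)

// lookupInterface resolves an interface by name via the cached finder;
// ByName itself refreshes the cache when the name is not yet known.
func lookupInterface() (int, error) {
	finder := &control.DefaultInterfaceFinder{}
	if err := finder.Update(); err != nil { // populate the cache once
		return 0, err
	}
	iif, err := finder.ByName("eth0") // placeholder interface name
	if err != nil {
		return 0, err
	}
	return iif.Index, nil
}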
@ -19,11 +19,11 @@ func bindToInterface(conn syscall.RawConn, network string, address string, finde
|
|||
if interfaceName == "" {
|
||||
return os.ErrInvalid
|
||||
}
|
||||
var err error
|
||||
interfaceIndex, err = finder.InterfaceIndexByName(interfaceName)
|
||||
iif, err := finder.ByName(interfaceName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
interfaceIndex = iif.Index
|
||||
}
|
||||
err := unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_BINDTOIFINDEX, interfaceIndex)
|
||||
if err == nil {
|
||||
|
|
|
@ -11,19 +11,19 @@ import (
|
|||
|
||||
func bindToInterface(conn syscall.RawConn, network string, address string, finder InterfaceFinder, interfaceName string, interfaceIndex int, preferInterfaceName bool) error {
|
||||
return Raw(conn, func(fd uintptr) error {
|
||||
var err error
|
||||
if interfaceIndex == -1 {
|
||||
if finder == nil {
|
||||
return os.ErrInvalid
|
||||
}
|
||||
interfaceIndex, err = finder.InterfaceIndexByName(interfaceName)
|
||||
iif, err := finder.ByName(interfaceName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
interfaceIndex = iif.Index
|
||||
}
|
||||
handle := syscall.Handle(fd)
|
||||
if M.ParseSocksaddr(address).AddrString() == "" {
|
||||
err = bind4(handle, interfaceIndex)
|
||||
err := bind4(handle, interfaceIndex)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -4,19 +4,26 @@ import (
|
|||
"os"
|
||||
"syscall"
|
||||
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func DisableUDPFragment() Func {
|
||||
return func(network, address string, conn syscall.RawConn) error {
|
||||
if N.NetworkName(network) != N.NetworkUDP {
|
||||
return nil
|
||||
}
|
||||
return Raw(conn, func(fd uintptr) error {
|
||||
switch network {
|
||||
case "udp4":
|
||||
if err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_DONTFRAG, 1); err != nil {
|
||||
if network == "udp" || network == "udp4" {
|
||||
err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_DONTFRAG, 1)
|
||||
if err != nil {
|
||||
return os.NewSyscallError("SETSOCKOPT IP_DONTFRAG", err)
|
||||
}
|
||||
case "udp6":
|
||||
if err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_DONTFRAG, 1); err != nil {
|
||||
}
|
||||
if network == "udp" || network == "udp6" {
|
||||
err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_DONTFRAG, 1)
|
||||
if err != nil {
|
||||
return os.NewSyscallError("SETSOCKOPT IPV6_DONTFRAG", err)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -11,17 +11,19 @@ import (
|
|||
|
||||
func DisableUDPFragment() Func {
|
||||
return func(network, address string, conn syscall.RawConn) error {
|
||||
switch N.NetworkName(network) {
|
||||
case N.NetworkUDP:
|
||||
default:
|
||||
if N.NetworkName(network) != N.NetworkUDP {
|
||||
return nil
|
||||
}
|
||||
return Raw(conn, func(fd uintptr) error {
|
||||
if err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_MTU_DISCOVER, unix.IP_PMTUDISC_DO); err != nil {
|
||||
return os.NewSyscallError("SETSOCKOPT IP_MTU_DISCOVER IP_PMTUDISC_DO", err)
|
||||
if network == "udp" || network == "udp4" {
|
||||
err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_MTU_DISCOVER, unix.IP_PMTUDISC_DO)
|
||||
if err != nil {
|
||||
return os.NewSyscallError("SETSOCKOPT IP_MTU_DISCOVER IP_PMTUDISC_DO", err)
|
||||
}
|
||||
}
|
||||
if network == "udp6" {
|
||||
if err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_MTU_DISCOVER, unix.IP_PMTUDISC_DO); err != nil {
|
||||
if network == "udp" || network == "udp6" {
|
||||
err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_MTU_DISCOVER, unix.IP_PMTUDISC_DO)
|
||||
if err != nil {
|
||||
return os.NewSyscallError("SETSOCKOPT IPV6_MTU_DISCOVER IP_PMTUDISC_DO", err)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -25,17 +25,19 @@ const (
|
|||
|
||||
func DisableUDPFragment() Func {
|
||||
return func(network, address string, conn syscall.RawConn) error {
|
||||
switch N.NetworkName(network) {
|
||||
case N.NetworkUDP:
|
||||
default:
|
||||
if N.NetworkName(network) != N.NetworkUDP {
|
||||
return nil
|
||||
}
|
||||
return Raw(conn, func(fd uintptr) error {
|
||||
if err := windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IP, IP_MTU_DISCOVER, IP_PMTUDISC_DO); err != nil {
|
||||
return os.NewSyscallError("SETSOCKOPT IP_MTU_DISCOVER IP_PMTUDISC_DO", err)
|
||||
if network == "udp" || network == "udp4" {
|
||||
err := windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IP, IP_MTU_DISCOVER, IP_PMTUDISC_DO)
|
||||
if err != nil {
|
||||
return os.NewSyscallError("SETSOCKOPT IP_MTU_DISCOVER IP_PMTUDISC_DO", err)
|
||||
}
|
||||
}
|
||||
if network == "udp6" {
|
||||
if err := windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IPV6, IPV6_MTU_DISCOVER, IP_PMTUDISC_DO); err != nil {
|
||||
if network == "udp" || network == "udp6" {
|
||||
err := windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IPV6, IPV6_MTU_DISCOVER, IP_PMTUDISC_DO)
|
||||
if err != nil {
|
||||
return os.NewSyscallError("SETSOCKOPT IPV6_MTU_DISCOVER IP_PMTUDISC_DO", err)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,6 +3,7 @@ package control
|
|||
import (
|
||||
"syscall"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
)
|
||||
|
||||
|
@ -30,6 +31,14 @@ func Conn(conn syscall.Conn, block func(fd uintptr) error) error {
|
|||
return Raw(rawConn, block)
|
||||
}
|
||||
|
||||
func Conn0[T any](conn syscall.Conn, block func(fd uintptr) (T, error)) (T, error) {
|
||||
rawConn, err := conn.SyscallConn()
|
||||
if err != nil {
|
||||
return common.DefaultValue[T](), err
|
||||
}
|
||||
return Raw0[T](rawConn, block)
|
||||
}
|
||||
|
||||
func Raw(rawConn syscall.RawConn, block func(fd uintptr) error) error {
|
||||
var innerErr error
|
||||
err := rawConn.Control(func(fd uintptr) {
|
||||
|
@ -37,3 +46,14 @@ func Raw(rawConn syscall.RawConn, block func(fd uintptr) error) error {
|
|||
})
|
||||
return E.Errors(innerErr, err)
|
||||
}
|
||||
|
||||
func Raw0[T any](rawConn syscall.RawConn, block func(fd uintptr) (T, error)) (T, error) {
|
||||
var (
|
||||
value T
|
||||
innerErr error
|
||||
)
|
||||
err := rawConn.Control(func(fd uintptr) {
|
||||
value, innerErr = block(fd)
|
||||
})
|
||||
return value, E.Errors(innerErr, err)
|
||||
}
|
||||
|
|
|
@ -4,10 +4,10 @@ import (
|
|||
"syscall"
|
||||
)
|
||||
|
||||
func RoutingMark(mark int) Func {
|
||||
func RoutingMark(mark uint32) Func {
|
||||
return func(network, address string, conn syscall.RawConn) error {
|
||||
return Raw(conn, func(fd uintptr) error {
|
||||
return syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_MARK, mark)
|
||||
return syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_MARK, int(mark))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2,6 +2,6 @@
|
|||
|
||||
package control
|
||||
|
||||
func RoutingMark(mark int) Func {
|
||||
func RoutingMark(mark uint32) Func {
|
||||
return nil
|
||||
}
|
||||
|
|
58
common/control/redirect_darwin.go
Normal file
|
@ -0,0 +1,58 @@
|
|||
package control
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"net"
|
||||
"net/netip"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
const (
|
||||
PF_OUT = 0x2
|
||||
DIOCNATLOOK = 0xc0544417
|
||||
)
|
||||
|
||||
func GetOriginalDestination(conn net.Conn) (netip.AddrPort, error) {
|
||||
pfFd, err := syscall.Open("/dev/pf", 0, syscall.O_RDONLY)
|
||||
if err != nil {
|
||||
return netip.AddrPort{}, err
|
||||
}
|
||||
defer syscall.Close(pfFd)
|
||||
nl := struct {
|
||||
saddr, daddr, rsaddr, rdaddr [16]byte
|
||||
sxport, dxport, rsxport, rdxport [4]byte
|
||||
af, proto, protoVariant, direction uint8
|
||||
}{
|
||||
af: syscall.AF_INET,
|
||||
proto: syscall.IPPROTO_TCP,
|
||||
direction: PF_OUT,
|
||||
}
|
||||
localAddr := M.SocksaddrFromNet(conn.LocalAddr()).Unwrap()
|
||||
removeAddr := M.SocksaddrFromNet(conn.RemoteAddr()).Unwrap()
|
||||
if localAddr.IsIPv4() {
|
||||
copy(nl.saddr[:net.IPv4len], removeAddr.Addr.AsSlice())
|
||||
copy(nl.daddr[:net.IPv4len], localAddr.Addr.AsSlice())
|
||||
nl.af = syscall.AF_INET
|
||||
} else {
|
||||
copy(nl.saddr[:], removeAddr.Addr.AsSlice())
|
||||
copy(nl.daddr[:], localAddr.Addr.AsSlice())
|
||||
nl.af = syscall.AF_INET6
|
||||
}
|
||||
binary.BigEndian.PutUint16(nl.sxport[:], removeAddr.Port)
|
||||
binary.BigEndian.PutUint16(nl.dxport[:], localAddr.Port)
|
||||
if _, _, errno := unix.Syscall(syscall.SYS_IOCTL, uintptr(pfFd), DIOCNATLOOK, uintptr(unsafe.Pointer(&nl))); errno != 0 {
|
||||
return netip.AddrPort{}, errno
|
||||
}
|
||||
var address netip.Addr
|
||||
if nl.af == unix.AF_INET {
|
||||
address = M.AddrFromIP(nl.rdaddr[:net.IPv4len])
|
||||
} else {
|
||||
address = netip.AddrFrom16(nl.rdaddr)
|
||||
}
|
||||
return netip.AddrPortFrom(address, binary.BigEndian.Uint16(nl.rdxport[:])), nil
|
||||
}
|
38
common/control/redirect_linux.go
Normal file
@@ -0,0 +1,38 @@
package control

import (
	"encoding/binary"
	"net"
	"net/netip"
	"os"
	"syscall"

	"github.com/sagernet/sing/common"
	M "github.com/sagernet/sing/common/metadata"

	"golang.org/x/sys/unix"
)

func GetOriginalDestination(conn net.Conn) (netip.AddrPort, error) {
	syscallConn, loaded := common.Cast[syscall.Conn](conn)
	if !loaded {
		return netip.AddrPort{}, os.ErrInvalid
	}
	return Conn0[netip.AddrPort](syscallConn, func(fd uintptr) (netip.AddrPort, error) {
		if M.SocksaddrFromNet(conn.RemoteAddr()).Unwrap().IsIPv4() {
			raw, err := unix.GetsockoptIPv6Mreq(int(fd), unix.IPPROTO_IP, unix.SO_ORIGINAL_DST)
			if err != nil {
				return netip.AddrPort{}, err
			}
			return netip.AddrPortFrom(M.AddrFromIP(raw.Multiaddr[4:8]), uint16(raw.Multiaddr[2])<<8+uint16(raw.Multiaddr[3])), nil
		} else {
			raw, err := unix.GetsockoptIPv6MTUInfo(int(fd), unix.IPPROTO_IPV6, unix.SO_ORIGINAL_DST)
			if err != nil {
				return netip.AddrPort{}, err
			}
			var port [2]byte
			binary.BigEndian.PutUint16(port[:], raw.Addr.Port)
			return netip.AddrPortFrom(M.AddrFromIP(raw.Addr.Addr[:]), binary.LittleEndian.Uint16(port[:])), nil
		}
	})
}
13
common/control/redirect_other.go
Normal file
@@ -0,0 +1,13 @@
//go:build !linux && !darwin

package control

import (
	"net"
	"net/netip"
	"os"
)

func GetOriginalDestination(conn net.Conn) (netip.AddrPort, error) {
	return netip.AddrPort{}, os.ErrInvalid
}
56
common/control/tproxy_linux.go
Normal file
|
@ -0,0 +1,56 @@
|
|||
package control
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"net/netip"
|
||||
"syscall"
|
||||
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func TProxy(fd uintptr, family int) error {
|
||||
err := syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1)
|
||||
if err == nil {
|
||||
err = syscall.SetsockoptInt(int(fd), syscall.SOL_IP, syscall.IP_TRANSPARENT, 1)
|
||||
}
|
||||
if err == nil && family == unix.AF_INET6 {
|
||||
err = syscall.SetsockoptInt(int(fd), syscall.SOL_IPV6, unix.IPV6_TRANSPARENT, 1)
|
||||
}
|
||||
if err == nil {
|
||||
err = syscall.SetsockoptInt(int(fd), syscall.SOL_IP, syscall.IP_RECVORIGDSTADDR, 1)
|
||||
}
|
||||
if err == nil && family == unix.AF_INET6 {
|
||||
err = syscall.SetsockoptInt(int(fd), syscall.SOL_IPV6, unix.IPV6_RECVORIGDSTADDR, 1)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func TProxyWriteBack() Func {
|
||||
return func(network, address string, conn syscall.RawConn) error {
|
||||
return Raw(conn, func(fd uintptr) error {
|
||||
if M.ParseSocksaddr(address).Addr.Is6() {
|
||||
return syscall.SetsockoptInt(int(fd), syscall.SOL_IPV6, unix.IPV6_TRANSPARENT, 1)
|
||||
} else {
|
||||
return syscall.SetsockoptInt(int(fd), syscall.SOL_IP, syscall.IP_TRANSPARENT, 1)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func GetOriginalDestinationFromOOB(oob []byte) (netip.AddrPort, error) {
|
||||
controlMessages, err := unix.ParseSocketControlMessage(oob)
|
||||
if err != nil {
|
||||
return netip.AddrPort{}, err
|
||||
}
|
||||
for _, message := range controlMessages {
|
||||
if message.Header.Level == unix.SOL_IP && message.Header.Type == unix.IP_RECVORIGDSTADDR {
|
||||
return netip.AddrPortFrom(M.AddrFromIP(message.Data[4:8]), binary.BigEndian.Uint16(message.Data[2:4])), nil
|
||||
} else if message.Header.Level == unix.SOL_IPV6 && message.Header.Type == unix.IPV6_RECVORIGDSTADDR {
|
||||
return netip.AddrPortFrom(M.AddrFromIP(message.Data[8:24]), binary.BigEndian.Uint16(message.Data[2:4])), nil
|
||||
}
|
||||
}
|
||||
return netip.AddrPort{}, E.New("not found")
|
||||
}
|
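GetOriginalDestinationFromOOB above parses the IP_RECVORIGDSTADDR / IPV6_RECVORIGDSTADDR control message that a TPROXY UDP socket attaches to each datagram. A rough sketch of reading one datagram together with its original destination (the buffer sizes are placeholders):

package example

import (
	"net"
	"net/netip"

	"github.com/sagernet/sing/common/control"
)

// readWithOriginalDestination reads one UDP datagram from a TPROXY socket
// and recovers the pre-redirect destination from the control messages.
func readWithOriginalDestination(conn *net.UDPConn) (netip.AddrPort, []byte, error) {
	buffer := make([]byte, 65535)
	oob := make([]byte, 1024)
	n, oobN, _, _, err := conn.ReadMsgUDP(buffer, oob)
	if err != nil {
		return netip.AddrPort{}, nil, err
	}
	destination, err := control.GetOriginalDestinationFromOOB(oob[:oobN])
	if err != nil {
		return netip.AddrPort{}, nil, err
	}
	return destination, buffer[:n], nil
}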
20
common/control/tproxy_other.go
Normal file
@@ -0,0 +1,20 @@
//go:build !linux

package control

import (
	"net/netip"
	"os"
)

func TProxy(fd uintptr, isIPv6 bool) error {
	return os.ErrInvalid
}

func TProxyWriteBack() Func {
	return nil
}

func GetOriginalDestinationFromOOB(oob []byte) (netip.AddrPort, error) {
	return netip.AddrPort{}, os.ErrInvalid
}
67
common/domain/adguard_matcher_test.go
Normal file
|
@ -0,0 +1,67 @@
|
|||
package domain_test
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"testing"
|
||||
|
||||
"github.com/sagernet/sing/common/domain"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAdGuardMatcher(t *testing.T) {
|
||||
t.Parallel()
|
||||
ruleLines := []string{
|
||||
"||example.org^",
|
||||
"|example.com^",
|
||||
"example.net^",
|
||||
"||example.edu",
|
||||
"||example.edu.tw^",
|
||||
"|example.gov",
|
||||
"example.arpa",
|
||||
}
|
||||
matcher := domain.NewAdGuardMatcher(ruleLines)
|
||||
require.NotNil(t, matcher)
|
||||
matchDomain := []string{
|
||||
"example.org",
|
||||
"www.example.org",
|
||||
"example.com",
|
||||
"example.net",
|
||||
"isexample.net",
|
||||
"www.example.net",
|
||||
"example.edu",
|
||||
"example.edu.cn",
|
||||
"example.edu.tw",
|
||||
"www.example.edu",
|
||||
"www.example.edu.cn",
|
||||
"example.gov",
|
||||
"example.gov.cn",
|
||||
"example.arpa",
|
||||
"www.example.arpa",
|
||||
"isexample.arpa",
|
||||
"example.arpa.cn",
|
||||
"www.example.arpa.cn",
|
||||
"isexample.arpa.cn",
|
||||
}
|
||||
notMatchDomain := []string{
|
||||
"example.org.cn",
|
||||
"notexample.org",
|
||||
"example.com.cn",
|
||||
"www.example.com.cn",
|
||||
"example.net.cn",
|
||||
"notexample.edu",
|
||||
"notexample.edu.cn",
|
||||
"www.example.gov",
|
||||
"notexample.gov",
|
||||
}
|
||||
for _, domain := range matchDomain {
|
||||
require.True(t, matcher.Match(domain), domain)
|
||||
}
|
||||
for _, domain := range notMatchDomain {
|
||||
require.False(t, matcher.Match(domain), domain)
|
||||
}
|
||||
dLines := matcher.Dump()
|
||||
sort.Strings(ruleLines)
|
||||
sort.Strings(dLines)
|
||||
require.Equal(t, ruleLines, dLines)
|
||||
}
|
172
common/domain/adgurad_matcher.go
Normal file
|
@ -0,0 +1,172 @@
|
|||
package domain
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/varbin"
|
||||
)
|
||||
|
||||
const (
|
||||
anyLabel = '*'
|
||||
suffixLabel = '\b'
|
||||
)
|
||||
|
||||
type AdGuardMatcher struct {
|
||||
set *succinctSet
|
||||
}
|
||||
|
||||
func NewAdGuardMatcher(ruleLines []string) *AdGuardMatcher {
|
||||
ruleList := make([]string, 0, len(ruleLines))
|
||||
for _, ruleLine := range ruleLines {
|
||||
var (
|
||||
isSuffix bool // ||
|
||||
hasStart bool // |
|
||||
hasEnd bool // ^
|
||||
)
|
||||
if strings.HasPrefix(ruleLine, "||") {
|
||||
ruleLine = ruleLine[2:]
|
||||
isSuffix = true
|
||||
} else if strings.HasPrefix(ruleLine, "|") {
|
||||
ruleLine = ruleLine[1:]
|
||||
hasStart = true
|
||||
}
|
||||
if strings.HasSuffix(ruleLine, "^") {
|
||||
ruleLine = ruleLine[:len(ruleLine)-1]
|
||||
hasEnd = true
|
||||
}
|
||||
if isSuffix {
|
||||
ruleLine = string(rootLabel) + ruleLine
|
||||
} else if !hasStart {
|
||||
ruleLine = string(prefixLabel) + ruleLine
|
||||
}
|
||||
if !hasEnd {
|
||||
if strings.HasSuffix(ruleLine, ".") {
|
||||
ruleLine = ruleLine[:len(ruleLine)-1]
|
||||
}
|
||||
ruleLine += string(suffixLabel)
|
||||
}
|
||||
ruleList = append(ruleList, reverseDomain(ruleLine))
|
||||
}
|
||||
ruleList = common.Uniq(ruleList)
|
||||
sort.Strings(ruleList)
|
||||
return &AdGuardMatcher{newSuccinctSet(ruleList)}
|
||||
}
|
||||
|
||||
func ReadAdGuardMatcher(reader varbin.Reader) (*AdGuardMatcher, error) {
|
||||
set, err := readSuccinctSet(reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &AdGuardMatcher{set}, nil
|
||||
}
|
||||
|
||||
func (m *AdGuardMatcher) Write(writer varbin.Writer) error {
|
||||
return m.set.Write(writer)
|
||||
}
|
||||
|
||||
func (m *AdGuardMatcher) Match(domain string) bool {
|
||||
key := reverseDomain(domain)
|
||||
if m.has([]byte(key), 0, 0) {
|
||||
return true
|
||||
}
|
||||
for {
|
||||
if m.has([]byte(string(suffixLabel)+key), 0, 0) {
|
||||
return true
|
||||
}
|
||||
idx := strings.IndexByte(key, '.')
|
||||
if idx == -1 {
|
||||
return false
|
||||
}
|
||||
key = key[idx+1:]
|
||||
}
|
||||
}
|
||||
|
||||
func (m *AdGuardMatcher) has(key []byte, nodeId, bmIdx int) bool {
|
||||
for i := 0; i < len(key); i++ {
|
||||
currentChar := key[i]
|
||||
for ; ; bmIdx++ {
|
||||
if getBit(m.set.labelBitmap, bmIdx) != 0 {
|
||||
return false
|
||||
}
|
||||
nextLabel := m.set.labels[bmIdx-nodeId]
|
||||
if nextLabel == prefixLabel {
|
||||
return true
|
||||
}
|
||||
if nextLabel == rootLabel {
|
||||
nextNodeId := countZeros(m.set.labelBitmap, m.set.ranks, bmIdx+1)
|
||||
hasNext := getBit(m.set.leaves, nextNodeId) != 0
|
||||
if currentChar == '.' && hasNext {
|
||||
return true
|
||||
}
|
||||
}
|
||||
if nextLabel == currentChar {
|
||||
break
|
||||
}
|
||||
if nextLabel == anyLabel {
|
||||
idx := bytes.IndexRune(key[i:], '.')
|
||||
nextNodeId := countZeros(m.set.labelBitmap, m.set.ranks, bmIdx+1)
|
||||
if idx == -1 {
|
||||
if getBit(m.set.leaves, nextNodeId) != 0 {
|
||||
return true
|
||||
}
|
||||
idx = 0
|
||||
}
|
||||
nextBmIdx := selectIthOne(m.set.labelBitmap, m.set.ranks, m.set.selects, nextNodeId-1) + 1
|
||||
if m.has(key[i+idx:], nextNodeId, nextBmIdx) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
nodeId = countZeros(m.set.labelBitmap, m.set.ranks, bmIdx+1)
|
||||
bmIdx = selectIthOne(m.set.labelBitmap, m.set.ranks, m.set.selects, nodeId-1) + 1
|
||||
}
|
||||
if getBit(m.set.leaves, nodeId) != 0 {
|
||||
return true
|
||||
}
|
||||
for ; ; bmIdx++ {
|
||||
if getBit(m.set.labelBitmap, bmIdx) != 0 {
|
||||
return false
|
||||
}
|
||||
nextLabel := m.set.labels[bmIdx-nodeId]
|
||||
if nextLabel == prefixLabel || nextLabel == rootLabel {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *AdGuardMatcher) Dump() (ruleLines []string) {
|
||||
for _, key := range m.set.keys() {
|
||||
key = reverseDomain(key)
|
||||
var (
|
||||
isSuffix bool
|
||||
hasStart bool
|
||||
hasEnd bool
|
||||
)
|
||||
if key[0] == prefixLabel {
|
||||
key = key[1:]
|
||||
} else if key[0] == rootLabel {
|
||||
key = key[1:]
|
||||
isSuffix = true
|
||||
} else {
|
||||
hasStart = true
|
||||
}
|
||||
if key[len(key)-1] == suffixLabel {
|
||||
key = key[:len(key)-1]
|
||||
} else {
|
||||
hasEnd = true
|
||||
}
|
||||
if isSuffix {
|
||||
key = "||" + key
|
||||
} else if hasStart {
|
||||
key = "|" + key
|
||||
}
|
||||
if hasEnd {
|
||||
key += "^"
|
||||
}
|
||||
ruleLines = append(ruleLines, key)
|
||||
}
|
||||
return
|
||||
}
|
|
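NewAdGuardMatcher above compiles AdGuard-style filter lines into the same succinct set used by the plain domain matcher; Match walks the reversed domain label by label and Dump recovers the original rule lines. A short usage sketch based on the rules exercised by the test earlier in this diff:

package example

import (
	"fmt"

	"github.com/sagernet/sing/common/domain"
)

func matchExample() {
	matcher := domain.NewAdGuardMatcher([]string{
		"||example.org^", // suffix rule: example.org and any subdomain
		"|example.com^",  // exact match only
	})
	fmt.Println(matcher.Match("www.example.org")) // true
	fmt.Println(matcher.Match("example.com"))     // true
	fmt.Println(matcher.Match("example.com.cn"))  // false
}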
@ -1,19 +1,22 @@
|
|||
package domain
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"sort"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/sagernet/sing/common/rw"
|
||||
"github.com/sagernet/sing/common/varbin"
|
||||
)
|
||||
|
||||
const (
|
||||
prefixLabel = '\r'
|
||||
rootLabel = '\n'
|
||||
)
|
||||
|
||||
type Matcher struct {
|
||||
set *succinctSet
|
||||
}
|
||||
|
||||
func NewMatcher(domains []string, domainSuffix []string) *Matcher {
|
||||
func NewMatcher(domains []string, domainSuffix []string, generateLegacy bool) *Matcher {
|
||||
domainList := make([]string, 0, len(domains)+2*len(domainSuffix))
|
||||
seen := make(map[string]bool, len(domainList))
|
||||
for _, domain := range domainSuffix {
|
||||
|
@ -22,10 +25,16 @@ func NewMatcher(domains []string, domainSuffix []string) *Matcher {
|
|||
}
|
||||
seen[domain] = true
|
||||
if domain[0] == '.' {
|
||||
domainList = append(domainList, reverseDomainSuffix(domain))
|
||||
} else {
|
||||
domainList = append(domainList, reverseDomain(string(prefixLabel)+domain))
|
||||
} else if generateLegacy {
|
||||
domainList = append(domainList, reverseDomain(domain))
|
||||
domainList = append(domainList, reverseRootDomainSuffix(domain))
|
||||
suffixDomain := "." + domain
|
||||
if !seen[suffixDomain] {
|
||||
seen[suffixDomain] = true
|
||||
domainList = append(domainList, reverseDomain(string(prefixLabel)+suffixDomain))
|
||||
}
|
||||
} else {
|
||||
domainList = append(domainList, reverseDomain(string(rootLabel)+domain))
|
||||
}
|
||||
}
|
||||
for _, domain := range domains {
|
||||
|
@ -39,82 +48,91 @@ func NewMatcher(domains []string, domainSuffix []string) *Matcher {
|
|||
return &Matcher{newSuccinctSet(domainList)}
|
||||
}
|
||||
|
||||
func ReadMatcher(reader io.Reader) (*Matcher, error) {
|
||||
var version uint8
|
||||
err := binary.Read(reader, binary.BigEndian, &version)
|
||||
func ReadMatcher(reader varbin.Reader) (*Matcher, error) {
|
||||
set, err := readSuccinctSet(reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
leavesLength, err := rw.ReadUVariant(reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
leaves := make([]uint64, leavesLength)
|
||||
err = binary.Read(reader, binary.BigEndian, leaves)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
labelBitmapLength, err := rw.ReadUVariant(reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
labelBitmap := make([]uint64, labelBitmapLength)
|
||||
err = binary.Read(reader, binary.BigEndian, labelBitmap)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
labelsLength, err := rw.ReadUVariant(reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
labels := make([]byte, labelsLength)
|
||||
_, err = io.ReadFull(reader, labels)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
set := &succinctSet{
|
||||
leaves: leaves,
|
||||
labelBitmap: labelBitmap,
|
||||
labels: labels,
|
||||
}
|
||||
set.init()
|
||||
return &Matcher{set}, nil
|
||||
}
|
||||
|
||||
func (m *Matcher) Match(domain string) bool {
|
||||
return m.set.Has(reverseDomain(domain))
|
||||
func (m *Matcher) Write(writer varbin.Writer) error {
|
||||
return m.set.Write(writer)
|
||||
}
|
||||
|
||||
func (m *Matcher) Write(writer io.Writer) error {
|
||||
err := binary.Write(writer, binary.BigEndian, byte(1))
|
||||
if err != nil {
|
||||
return err
|
||||
func (m *Matcher) Match(domain string) bool {
|
||||
return m.has(reverseDomain(domain))
|
||||
}
|
||||
|
||||
func (m *Matcher) has(key string) bool {
|
||||
var nodeId, bmIdx int
|
||||
for i := 0; i < len(key); i++ {
|
||||
currentChar := key[i]
|
||||
for ; ; bmIdx++ {
|
||||
if getBit(m.set.labelBitmap, bmIdx) != 0 {
|
||||
return false
|
||||
}
|
||||
nextLabel := m.set.labels[bmIdx-nodeId]
|
||||
if nextLabel == prefixLabel {
|
||||
return true
|
||||
}
|
||||
if nextLabel == rootLabel {
|
||||
nextNodeId := countZeros(m.set.labelBitmap, m.set.ranks, bmIdx+1)
|
||||
hasNext := getBit(m.set.leaves, nextNodeId) != 0
|
||||
if currentChar == '.' && hasNext {
|
||||
return true
|
||||
}
|
||||
}
|
||||
if nextLabel == currentChar {
|
||||
break
|
||||
}
|
||||
}
|
||||
nodeId = countZeros(m.set.labelBitmap, m.set.ranks, bmIdx+1)
|
||||
bmIdx = selectIthOne(m.set.labelBitmap, m.set.ranks, m.set.selects, nodeId-1) + 1
|
||||
}
|
||||
err = rw.WriteUVariant(writer, uint64(len(m.set.leaves)))
|
||||
if err != nil {
|
||||
return err
|
||||
if getBit(m.set.leaves, nodeId) != 0 {
|
||||
return true
|
||||
}
|
||||
err = binary.Write(writer, binary.BigEndian, m.set.leaves)
|
||||
if err != nil {
|
||||
return err
|
||||
for ; ; bmIdx++ {
|
||||
if getBit(m.set.labelBitmap, bmIdx) != 0 {
|
||||
return false
|
||||
}
|
||||
nextLabel := m.set.labels[bmIdx-nodeId]
|
||||
if nextLabel == prefixLabel || nextLabel == rootLabel {
|
||||
return true
|
||||
}
|
||||
}
|
||||
err = rw.WriteUVariant(writer, uint64(len(m.set.labelBitmap)))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
func (m *Matcher) Dump() (domainList []string, prefixList []string) {
|
||||
domainMap := make(map[string]bool)
|
||||
prefixMap := make(map[string]bool)
|
||||
for _, key := range m.set.keys() {
|
||||
key = reverseDomain(key)
|
||||
if key[0] == prefixLabel {
|
||||
prefixMap[key[1:]] = true
|
||||
} else if key[0] == rootLabel {
|
||||
prefixList = append(prefixList, key[1:])
|
||||
} else {
|
||||
domainMap[key] = true
|
||||
}
|
||||
}
|
||||
err = binary.Write(writer, binary.BigEndian, m.set.labelBitmap)
|
||||
if err != nil {
|
||||
return err
|
||||
for rawPrefix := range prefixMap {
|
||||
if rawPrefix[0] == '.' {
|
||||
if rootDomain := rawPrefix[1:]; domainMap[rootDomain] {
|
||||
delete(domainMap, rootDomain)
|
||||
prefixList = append(prefixList, rootDomain)
|
||||
continue
|
||||
}
|
||||
}
|
||||
prefixList = append(prefixList, rawPrefix)
|
||||
}
|
||||
err = rw.WriteUVariant(writer, uint64(len(m.set.labels)))
|
||||
if err != nil {
|
||||
return err
|
||||
for domain := range domainMap {
|
||||
domainList = append(domainList, domain)
|
||||
}
|
||||
_, err = writer.Write(m.set.labels)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
sort.Strings(domainList)
|
||||
sort.Strings(prefixList)
|
||||
return domainList, prefixList
|
||||
}
|
||||
|
||||
func reverseDomain(domain string) string {
|
||||
|
@ -127,28 +145,3 @@ func reverseDomain(domain string) string {
|
|||
}
|
||||
return string(b)
|
||||
}
|
||||
|
||||
func reverseDomainSuffix(domain string) string {
|
||||
l := len(domain)
|
||||
b := make([]byte, l+1)
|
||||
for i := 0; i < l; {
|
||||
r, n := utf8.DecodeRuneInString(domain[i:])
|
||||
i += n
|
||||
utf8.EncodeRune(b[l-i:], r)
|
||||
}
|
||||
b[l] = prefixLabel
|
||||
return string(b)
|
||||
}
|
||||
|
||||
func reverseRootDomainSuffix(domain string) string {
|
||||
l := len(domain)
|
||||
b := make([]byte, l+2)
|
||||
for i := 0; i < l; {
|
||||
r, n := utf8.DecodeRuneInString(domain[i:])
|
||||
i += n
|
||||
utf8.EncodeRune(b[l-i:], r)
|
||||
}
|
||||
b[l] = '.'
|
||||
b[l+1] = prefixLabel
|
||||
return string(b)
|
||||
}
|
||||
|
|
80
common/domain/matcher_test.go
Normal file
|
@ -0,0 +1,80 @@
|
|||
package domain_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"sort"
|
||||
"testing"
|
||||
|
||||
"github.com/sagernet/sing/common/domain"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestMatcher(t *testing.T) {
|
||||
t.Parallel()
|
||||
testDomain := []string{"example.com", "example.org"}
|
||||
testDomainSuffix := []string{".com.cn", ".org.cn", "sagernet.org"}
|
||||
matcher := domain.NewMatcher(testDomain, testDomainSuffix, false)
|
||||
require.NotNil(t, matcher)
|
||||
require.True(t, matcher.Match("example.com"))
|
||||
require.True(t, matcher.Match("example.org"))
|
||||
require.False(t, matcher.Match("example.cn"))
|
||||
require.True(t, matcher.Match("example.com.cn"))
|
||||
require.True(t, matcher.Match("example.org.cn"))
|
||||
require.False(t, matcher.Match("com.cn"))
|
||||
require.False(t, matcher.Match("org.cn"))
|
||||
require.True(t, matcher.Match("sagernet.org"))
|
||||
require.True(t, matcher.Match("sing-box.sagernet.org"))
|
||||
dDomain, dDomainSuffix := matcher.Dump()
|
||||
require.Equal(t, testDomain, dDomain)
|
||||
require.Equal(t, testDomainSuffix, dDomainSuffix)
|
||||
}
|
||||
|
||||
func TestMatcherLegacy(t *testing.T) {
|
||||
t.Parallel()
|
||||
testDomain := []string{"example.com", "example.org"}
|
||||
testDomainSuffix := []string{".com.cn", ".org.cn", "sagernet.org"}
|
||||
matcher := domain.NewMatcher(testDomain, testDomainSuffix, true)
|
||||
require.NotNil(t, matcher)
|
||||
require.True(t, matcher.Match("example.com"))
|
||||
require.True(t, matcher.Match("example.org"))
|
||||
require.False(t, matcher.Match("example.cn"))
|
||||
require.True(t, matcher.Match("example.com.cn"))
|
||||
require.True(t, matcher.Match("example.org.cn"))
|
||||
require.False(t, matcher.Match("com.cn"))
|
||||
require.False(t, matcher.Match("org.cn"))
|
||||
require.True(t, matcher.Match("sagernet.org"))
|
||||
require.True(t, matcher.Match("sing-box.sagernet.org"))
|
||||
dDomain, dDomainSuffix := matcher.Dump()
|
||||
require.Equal(t, testDomain, dDomain)
|
||||
require.Equal(t, testDomainSuffix, dDomainSuffix)
|
||||
}
|
||||
|
||||
type simpleRuleSet struct {
|
||||
Rules []struct {
|
||||
Domain []string `json:"domain"`
|
||||
DomainSuffix []string `json:"domain_suffix"`
|
||||
}
|
||||
}
|
||||
|
||||
func TestDumpLarge(t *testing.T) {
|
||||
t.Parallel()
|
||||
response, err := http.Get("https://raw.githubusercontent.com/MetaCubeX/meta-rules-dat/sing/geo/geosite/cn.json")
|
||||
require.NoError(t, err)
|
||||
defer response.Body.Close()
|
||||
var ruleSet simpleRuleSet
|
||||
err = json.NewDecoder(response.Body).Decode(&ruleSet)
|
||||
require.NoError(t, err)
|
||||
domainList := ruleSet.Rules[0].Domain
|
||||
domainSuffixList := ruleSet.Rules[0].DomainSuffix
|
||||
require.Len(t, ruleSet.Rules, 1)
|
||||
require.True(t, len(domainList)+len(domainSuffixList) > 0)
|
||||
sort.Strings(domainList)
|
||||
sort.Strings(domainSuffixList)
|
||||
matcher := domain.NewMatcher(domainList, domainSuffixList, false)
|
||||
require.NotNil(t, matcher)
|
||||
dDomain, dDomainSuffix := matcher.Dump()
|
||||
require.Equal(t, domainList, dDomain)
|
||||
require.Equal(t, domainSuffixList, dDomainSuffix)
|
||||
}
|
|
@ -1,10 +1,11 @@
|
|||
package domain
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"math/bits"
|
||||
)
|
||||
|
||||
const prefixLabel = '\r'
|
||||
"github.com/sagernet/sing/common/varbin"
|
||||
)
|
||||
|
||||
// mod from https://github.com/openacid/succinct
|
||||
|
||||
|
@ -42,36 +43,61 @@ func newSuccinctSet(keys []string) *succinctSet {
|
|||
return ss
|
||||
}
|
||||
|
||||
func (ss *succinctSet) Has(key string) bool {
|
||||
var nodeId, bmIdx int
|
||||
for i := 0; i < len(key); i++ {
|
||||
currentChar := key[i]
|
||||
func (ss *succinctSet) keys() []string {
|
||||
var result []string
|
||||
var currentKey []byte
|
||||
var bmIdx, nodeId int
|
||||
|
||||
var traverse func(int, int)
|
||||
traverse = func(nodeId, bmIdx int) {
|
||||
if getBit(ss.leaves, nodeId) != 0 {
|
||||
result = append(result, string(currentKey))
|
||||
}
|
||||
|
||||
for ; ; bmIdx++ {
|
||||
if getBit(ss.labelBitmap, bmIdx) != 0 {
|
||||
return false
|
||||
return
|
||||
}
|
||||
nextLabel := ss.labels[bmIdx-nodeId]
|
||||
if nextLabel == prefixLabel {
|
||||
return true
|
||||
}
|
||||
if nextLabel == currentChar {
|
||||
break
|
||||
}
|
||||
}
|
||||
nodeId = countZeros(ss.labelBitmap, ss.ranks, bmIdx+1)
|
||||
bmIdx = selectIthOne(ss.labelBitmap, ss.ranks, ss.selects, nodeId-1) + 1
|
||||
}
|
||||
if getBit(ss.leaves, nodeId) != 0 {
|
||||
return true
|
||||
}
|
||||
for ; ; bmIdx++ {
|
||||
if getBit(ss.labelBitmap, bmIdx) != 0 {
|
||||
return false
|
||||
}
|
||||
if ss.labels[bmIdx-nodeId] == prefixLabel {
|
||||
return true
|
||||
currentKey = append(currentKey, nextLabel)
|
||||
nextNodeId := countZeros(ss.labelBitmap, ss.ranks, bmIdx+1)
|
||||
nextBmIdx := selectIthOne(ss.labelBitmap, ss.ranks, ss.selects, nextNodeId-1) + 1
|
||||
traverse(nextNodeId, nextBmIdx)
|
||||
currentKey = currentKey[:len(currentKey)-1]
|
||||
}
|
||||
}
|
||||
|
||||
traverse(nodeId, bmIdx)
|
||||
return result
|
||||
}
|
||||
|
||||
type succinctSetData struct {
|
||||
Reserved uint8
|
||||
Leaves []uint64
|
||||
LabelBitmap []uint64
|
||||
Labels []byte
|
||||
}
|
||||
|
||||
func readSuccinctSet(reader varbin.Reader) (*succinctSet, error) {
|
||||
matcher, err := varbin.ReadValue[succinctSetData](reader, binary.BigEndian)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
set := &succinctSet{
|
||||
leaves: matcher.Leaves,
|
||||
labelBitmap: matcher.LabelBitmap,
|
||||
labels: matcher.Labels,
|
||||
}
|
||||
set.init()
|
||||
return set, nil
|
||||
}
|
||||
|
||||
func (ss *succinctSet) Write(writer varbin.Writer) error {
|
||||
return varbin.Write(writer, binary.BigEndian, succinctSetData{
|
||||
Leaves: ss.leaves,
|
||||
LabelBitmap: ss.labelBitmap,
|
||||
Labels: ss.labels,
|
||||
})
|
||||
}
|
||||
|
||||
func setBit(bm *[]uint64, i int, v int) {
|
||||
|
|
|
@ -12,3 +12,16 @@ func (e *causeError) Error() string {
|
|||
func (e *causeError) Unwrap() error {
|
||||
return e.cause
|
||||
}
|
||||
|
||||
type causeError1 struct {
|
||||
error
|
||||
cause error
|
||||
}
|
||||
|
||||
func (e *causeError1) Error() string {
|
||||
return e.error.Error() + ": " + e.cause.Error()
|
||||
}
|
||||
|
||||
func (e *causeError1) Unwrap() []error {
|
||||
return []error{e.error, e.cause}
|
||||
}
|
||||
|
|
|
@ -12,6 +12,7 @@ import (
|
|||
F "github.com/sagernet/sing/common/format"
|
||||
)
|
||||
|
||||
// Deprecated: wtf is this?
|
||||
type Handler interface {
|
||||
NewError(ctx context.Context, err error)
|
||||
}
|
||||
|
@ -31,6 +32,13 @@ func Cause(cause error, message ...any) error {
|
|||
return &causeError{F.ToString(message...), cause}
|
||||
}
|
||||
|
||||
func Cause1(err error, cause error) error {
|
||||
if cause == nil {
|
||||
panic("cause on an nil error")
|
||||
}
|
||||
return &causeError1{err, cause}
|
||||
}
|
||||
|
||||
func Extend(cause error, message ...any) error {
|
||||
if cause == nil {
|
||||
panic("extend on an nil error")
|
||||
|
@ -39,11 +47,11 @@ func Extend(cause error, message ...any) error {
|
|||
}
|
||||
|
||||
func IsClosedOrCanceled(err error) bool {
|
||||
return IsMulti(err, io.EOF, net.ErrClosed, io.ErrClosedPipe, os.ErrClosed, syscall.EPIPE, syscall.ECONNRESET, context.Canceled, context.DeadlineExceeded)
|
||||
return IsClosed(err) || IsCanceled(err) || IsTimeout(err)
|
||||
}
|
||||
|
||||
func IsClosed(err error) bool {
|
||||
return IsMulti(err, io.EOF, net.ErrClosed, io.ErrClosedPipe, os.ErrClosed, syscall.EPIPE, syscall.ECONNRESET)
|
||||
return IsMulti(err, io.EOF, net.ErrClosed, io.ErrClosedPipe, os.ErrClosed, syscall.EPIPE, syscall.ECONNRESET, syscall.ENOTCONN)
|
||||
}
|
||||
|
||||
func IsCanceled(err error) bool {
|
||||
|
|
|
@ -1,24 +1,14 @@
|
|||
package exceptions
|
||||
|
||||
import "github.com/sagernet/sing/common"
|
||||
import (
|
||||
"errors"
|
||||
|
||||
type HasInnerError interface {
|
||||
Unwrap() error
|
||||
}
|
||||
"github.com/sagernet/sing/common"
|
||||
)
|
||||
|
||||
// Deprecated: Use errors.Unwrap instead.
|
||||
func Unwrap(err error) error {
|
||||
for {
|
||||
inner, ok := err.(HasInnerError)
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
innerErr := inner.Unwrap()
|
||||
if innerErr == nil {
|
||||
break
|
||||
}
|
||||
err = innerErr
|
||||
}
|
||||
return err
|
||||
return errors.Unwrap(err)
|
||||
}
|
||||
|
||||
func Cast[T any](err error) (T, bool) {
|
||||
|
|
|
@ -63,12 +63,5 @@ func IsMulti(err error, targetList ...error) bool {
|
|||
return true
|
||||
}
|
||||
}
|
||||
err = Unwrap(err)
|
||||
multiErr, isMulti := err.(MultiError)
|
||||
if !isMulti {
|
||||
return false
|
||||
}
|
||||
return common.All(multiErr.Unwrap(), func(it error) bool {
|
||||
return IsMulti(it, targetList...)
|
||||
})
|
||||
return false
|
||||
}
|
||||
|
|
|
@ -1,17 +1,21 @@
|
|||
package exceptions
|
||||
|
||||
import "net"
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
)
|
||||
|
||||
type TimeoutError interface {
|
||||
Timeout() bool
|
||||
}
|
||||
|
||||
func IsTimeout(err error) bool {
|
||||
if netErr, isNetErr := err.(net.Error); isNetErr {
|
||||
//goland:noinspection GoDeprecation
|
||||
var netErr net.Error
|
||||
if errors.As(err, &netErr) {
|
||||
//nolint:staticcheck
|
||||
return netErr.Temporary() && netErr.Timeout()
|
||||
} else if timeoutErr, isTimeout := Cast[TimeoutError](err); isTimeout {
|
||||
}
|
||||
if timeoutErr, isTimeout := Cast[TimeoutError](err); isTimeout {
|
||||
return timeoutErr.Timeout()
|
||||
}
|
||||
return false
|
||||
@ -2,13 +2,14 @@ package badjson
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/json"
|
||||
)
|
||||
|
||||
func Decode(content []byte) (any, error) {
|
||||
decoder := json.NewDecoder(bytes.NewReader(content))
|
||||
func Decode(ctx context.Context, content []byte) (any, error) {
|
||||
decoder := json.NewDecoderContext(ctx, bytes.NewReader(content))
|
||||
return decodeJSON(decoder)
|
||||
}
|
||||
|
||||
@ -1,6 +1,7 @@
|
|||
package badjson
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"reflect"
|
||||
|
||||
|
@ -9,75 +10,75 @@ import (
|
|||
"github.com/sagernet/sing/common/json"
|
||||
)
|
||||
|
||||
func Omitempty[T any](value T) (T, error) {
|
||||
objectContent, err := json.Marshal(value)
|
||||
func Omitempty[T any](ctx context.Context, value T) (T, error) {
|
||||
objectContent, err := json.MarshalContext(ctx, value)
|
||||
if err != nil {
|
||||
return common.DefaultValue[T](), E.Cause(err, "marshal object")
|
||||
}
|
||||
rawNewObject, err := Decode(objectContent)
|
||||
rawNewObject, err := Decode(ctx, objectContent)
|
||||
if err != nil {
|
||||
return common.DefaultValue[T](), err
|
||||
}
|
||||
newObjectContent, err := json.Marshal(rawNewObject)
|
||||
newObjectContent, err := json.MarshalContext(ctx, rawNewObject)
|
||||
if err != nil {
|
||||
return common.DefaultValue[T](), E.Cause(err, "marshal new object")
|
||||
}
|
||||
var newObject T
|
||||
err = json.Unmarshal(newObjectContent, &newObject)
|
||||
err = json.UnmarshalContext(ctx, newObjectContent, &newObject)
|
||||
if err != nil {
|
||||
return common.DefaultValue[T](), E.Cause(err, "unmarshal new object")
|
||||
}
|
||||
return newObject, nil
|
||||
}
|
||||
|
||||
func Merge[T any](source T, destination T) (T, error) {
|
||||
rawSource, err := json.Marshal(source)
|
||||
func Merge[T any](ctx context.Context, source T, destination T, disableAppend bool) (T, error) {
|
||||
rawSource, err := json.MarshalContext(ctx, source)
|
||||
if err != nil {
|
||||
return common.DefaultValue[T](), E.Cause(err, "marshal source")
|
||||
}
|
||||
rawDestination, err := json.Marshal(destination)
|
||||
rawDestination, err := json.MarshalContext(ctx, destination)
|
||||
if err != nil {
|
||||
return common.DefaultValue[T](), E.Cause(err, "marshal destination")
|
||||
}
|
||||
return MergeFrom[T](rawSource, rawDestination)
|
||||
return MergeFrom[T](ctx, rawSource, rawDestination, disableAppend)
|
||||
}
|
||||
|
||||
func MergeFromSource[T any](rawSource json.RawMessage, destination T) (T, error) {
|
||||
func MergeFromSource[T any](ctx context.Context, rawSource json.RawMessage, destination T, disableAppend bool) (T, error) {
|
||||
if rawSource == nil {
|
||||
return destination, nil
|
||||
}
|
||||
rawDestination, err := json.Marshal(destination)
|
||||
rawDestination, err := json.MarshalContext(ctx, destination)
|
||||
if err != nil {
|
||||
return common.DefaultValue[T](), E.Cause(err, "marshal destination")
|
||||
}
|
||||
return MergeFrom[T](rawSource, rawDestination)
|
||||
return MergeFrom[T](ctx, rawSource, rawDestination, disableAppend)
|
||||
}
|
||||
|
||||
func MergeFromDestination[T any](source T, rawDestination json.RawMessage) (T, error) {
|
||||
func MergeFromDestination[T any](ctx context.Context, source T, rawDestination json.RawMessage, disableAppend bool) (T, error) {
|
||||
if rawDestination == nil {
|
||||
return source, nil
|
||||
}
|
||||
rawSource, err := json.Marshal(source)
|
||||
rawSource, err := json.MarshalContext(ctx, source)
|
||||
if err != nil {
|
||||
return common.DefaultValue[T](), E.Cause(err, "marshal source")
|
||||
}
|
||||
return MergeFrom[T](rawSource, rawDestination)
|
||||
return MergeFrom[T](ctx, rawSource, rawDestination, disableAppend)
|
||||
}
|
||||
|
||||
func MergeFrom[T any](rawSource json.RawMessage, rawDestination json.RawMessage) (T, error) {
|
||||
rawMerged, err := MergeJSON(rawSource, rawDestination)
|
||||
func MergeFrom[T any](ctx context.Context, rawSource json.RawMessage, rawDestination json.RawMessage, disableAppend bool) (T, error) {
|
||||
rawMerged, err := MergeJSON(ctx, rawSource, rawDestination, disableAppend)
|
||||
if err != nil {
|
||||
return common.DefaultValue[T](), E.Cause(err, "merge options")
|
||||
}
|
||||
var merged T
|
||||
err = json.Unmarshal(rawMerged, &merged)
|
||||
err = json.UnmarshalContext(ctx, rawMerged, &merged)
|
||||
if err != nil {
|
||||
return common.DefaultValue[T](), E.Cause(err, "unmarshal merged options")
|
||||
}
|
||||
return merged, nil
|
||||
}
|
||||
|
||||
func MergeJSON(rawSource json.RawMessage, rawDestination json.RawMessage) (json.RawMessage, error) {
|
||||
func MergeJSON(ctx context.Context, rawSource json.RawMessage, rawDestination json.RawMessage, disableAppend bool) (json.RawMessage, error) {
|
||||
if rawSource == nil && rawDestination == nil {
|
||||
return nil, os.ErrInvalid
|
||||
} else if rawSource == nil {
|
||||
|
@ -85,34 +86,36 @@ func MergeJSON(rawSource json.RawMessage, rawDestination json.RawMessage) (json.
|
|||
} else if rawDestination == nil {
|
||||
return rawSource, nil
|
||||
}
|
||||
source, err := Decode(rawSource)
|
||||
source, err := Decode(ctx, rawSource)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "decode source")
|
||||
}
|
||||
destination, err := Decode(rawDestination)
|
||||
destination, err := Decode(ctx, rawDestination)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "decode destination")
|
||||
}
|
||||
if source == nil {
|
||||
return json.Marshal(destination)
|
||||
return json.MarshalContext(ctx, destination)
|
||||
} else if destination == nil {
|
||||
return json.Marshal(source)
|
||||
}
|
||||
merged, err := mergeJSON(source, destination)
|
||||
merged, err := mergeJSON(source, destination, disableAppend)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return json.Marshal(merged)
|
||||
return json.MarshalContext(ctx, merged)
|
||||
}
|
||||
|
||||
func mergeJSON(anySource any, anyDestination any) (any, error) {
|
||||
func mergeJSON(anySource any, anyDestination any, disableAppend bool) (any, error) {
|
||||
switch destination := anyDestination.(type) {
|
||||
case JSONArray:
|
||||
switch source := anySource.(type) {
|
||||
case JSONArray:
|
||||
destination = append(destination, source...)
|
||||
default:
|
||||
destination = append(destination, source)
|
||||
if !disableAppend {
|
||||
switch source := anySource.(type) {
|
||||
case JSONArray:
|
||||
destination = append(destination, source...)
|
||||
default:
|
||||
destination = append(destination, source)
|
||||
}
|
||||
}
|
||||
return destination, nil
|
||||
case *JSONObject:
|
||||
|
@ -122,7 +125,7 @@ func mergeJSON(anySource any, anyDestination any) (any, error) {
|
|||
oldValue, loaded := destination.Get(entry.Key)
|
||||
if loaded {
|
||||
var err error
|
||||
entry.Value, err = mergeJSON(entry.Value, oldValue)
|
||||
entry.Value, err = mergeJSON(entry.Value, oldValue, disableAppend)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "merge object item ", entry.Key)
|
||||
}
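A hedged sketch (inputs invented, output formatting approximate) of how the new disableAppend flag changes array handling in MergeJSON: with appending enabled the destination array is extended with the source items, with it disabled the destination array is kept as-is.

package main

import (
    "context"
    "fmt"

    "github.com/sagernet/sing/common/json"
    "github.com/sagernet/sing/common/json/badjson"
)

func main() {
    ctx := context.Background()
    source := json.RawMessage(`{"tags": ["a"]}`)
    destination := json.RawMessage(`{"tags": ["b"]}`)

    appended, _ := badjson.MergeJSON(ctx, source, destination, false)
    fmt.Println(string(appended)) // both lists combined, e.g. {"tags": ["b", "a"]}

    kept, _ := badjson.MergeJSON(ctx, source, destination, true)
    fmt.Println(string(kept)) // destination list kept, e.g. {"tags": ["b"]}
}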
common/json/badjson/merge_objects.go (new file, 68 lines)

package badjson

import (
    "context"
    "reflect"

    E "github.com/sagernet/sing/common/exceptions"
    "github.com/sagernet/sing/common/json"
    cJSON "github.com/sagernet/sing/common/json/internal/contextjson"
)

func MarshallObjects(objects ...any) ([]byte, error) {
    return MarshallObjectsContext(context.Background(), objects...)
}

func MarshallObjectsContext(ctx context.Context, objects ...any) ([]byte, error) {
    if len(objects) == 1 {
        return json.Marshal(objects[0])
    }
    var content JSONObject
    for _, object := range objects {
        objectMap, err := newJSONObject(ctx, object)
        if err != nil {
            return nil, err
        }
        content.PutAll(objectMap)
    }
    return content.MarshalJSONContext(ctx)
}

func UnmarshallExcluded(inputContent []byte, parentObject any, object any) error {
    return UnmarshallExcludedContext(context.Background(), inputContent, parentObject, object)
}

func UnmarshallExcludedContext(ctx context.Context, inputContent []byte, parentObject any, object any) error {
    var content JSONObject
    err := content.UnmarshalJSONContext(ctx, inputContent)
    if err != nil {
        return err
    }
    for _, key := range cJSON.ObjectKeys(reflect.TypeOf(parentObject)) {
        content.Remove(key)
    }
    if object == nil {
        if content.IsEmpty() {
            return nil
        }
        return E.New("unexpected key: ", content.Keys()[0])
    }
    inputContent, err = content.MarshalJSONContext(ctx)
    if err != nil {
        return err
    }
    return json.UnmarshalContextDisallowUnknownFields(ctx, inputContent, object)
}

func newJSONObject(ctx context.Context, object any) (*JSONObject, error) {
    inputContent, err := json.MarshalContext(ctx, object)
    if err != nil {
        return nil, err
    }
    var content JSONObject
    err = content.UnmarshalJSONContext(ctx, inputContent)
    if err != nil {
        return nil, err
    }
    return &content, nil
}
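A hedged usage sketch for the file above (Base and Extra are invented types): MarshallObjects flattens several structs into one JSON object, and UnmarshallExcluded decodes only the keys that do not belong to the parent struct.

package main

import (
    "context"
    "fmt"

    "github.com/sagernet/sing/common/json/badjson"
)

// Base and Extra exist only for this sketch.
type Base struct {
    Type string `json:"type"`
}

type Extra struct {
    Server string `json:"server"`
}

func main() {
    ctx := context.Background()
    content, err := badjson.MarshallObjectsContext(ctx, Base{Type: "example"}, Extra{Server: "example.org"})
    if err != nil {
        panic(err)
    }
    fmt.Println(string(content)) // one flat object with both "type" and "server"

    var extra Extra
    // Keys declared on Base ("type") are stripped before decoding into Extra;
    // leftover unknown keys would be rejected.
    if err = badjson.UnmarshallExcludedContext(ctx, content, &Base{}, &extra); err != nil {
        panic(err)
    }
    fmt.Println(extra.Server) // example.org
}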
|
|
@ -2,6 +2,7 @@ package badjson
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
|
@ -28,6 +29,10 @@ func (m *JSONObject) IsEmpty() bool {
|
|||
}
|
||||
|
||||
func (m *JSONObject) MarshalJSON() ([]byte, error) {
|
||||
return m.MarshalJSONContext(context.Background())
|
||||
}
|
||||
|
||||
func (m *JSONObject) MarshalJSONContext(ctx context.Context) ([]byte, error) {
|
||||
buffer := new(bytes.Buffer)
|
||||
buffer.WriteString("{")
|
||||
items := common.Filter(m.Entries(), func(it collections.MapEntry[string, any]) bool {
|
||||
|
@ -38,13 +43,13 @@ func (m *JSONObject) MarshalJSON() ([]byte, error) {
|
|||
})
|
||||
iLen := len(items)
|
||||
for i, entry := range items {
|
||||
keyContent, err := json.Marshal(entry.Key)
|
||||
keyContent, err := json.MarshalContext(ctx, entry.Key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
buffer.WriteString(strings.TrimSpace(string(keyContent)))
|
||||
buffer.WriteString(": ")
|
||||
valueContent, err := json.Marshal(entry.Value)
|
||||
valueContent, err := json.MarshalContext(ctx, entry.Value)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -58,7 +63,11 @@ func (m *JSONObject) MarshalJSON() ([]byte, error) {
|
|||
}
|
||||
|
||||
func (m *JSONObject) UnmarshalJSON(content []byte) error {
|
||||
decoder := json.NewDecoder(bytes.NewReader(content))
|
||||
return m.UnmarshalJSONContext(context.Background(), content)
|
||||
}
|
||||
|
||||
func (m *JSONObject) UnmarshalJSONContext(ctx context.Context, content []byte) error {
|
||||
decoder := json.NewDecoderContext(ctx, bytes.NewReader(content))
|
||||
m.Clear()
|
||||
objectStart, err := decoder.Token()
|
||||
if err != nil {
|
||||
@ -2,6 +2,7 @@ package badjson
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
|
@ -14,18 +15,22 @@ type TypedMap[K comparable, V any] struct {
|
|||
}
|
||||
|
||||
func (m TypedMap[K, V]) MarshalJSON() ([]byte, error) {
|
||||
return m.MarshalJSONContext(context.Background())
|
||||
}
|
||||
|
||||
func (m TypedMap[K, V]) MarshalJSONContext(ctx context.Context) ([]byte, error) {
|
||||
buffer := new(bytes.Buffer)
|
||||
buffer.WriteString("{")
|
||||
items := m.Entries()
|
||||
iLen := len(items)
|
||||
for i, entry := range items {
|
||||
keyContent, err := json.Marshal(entry.Key)
|
||||
keyContent, err := json.MarshalContext(ctx, entry.Key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
buffer.WriteString(strings.TrimSpace(string(keyContent)))
|
||||
buffer.WriteString(": ")
|
||||
valueContent, err := json.Marshal(entry.Value)
|
||||
valueContent, err := json.MarshalContext(ctx, entry.Value)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -39,7 +44,11 @@ func (m TypedMap[K, V]) MarshalJSON() ([]byte, error) {
|
|||
}
|
||||
|
||||
func (m *TypedMap[K, V]) UnmarshalJSON(content []byte) error {
|
||||
decoder := json.NewDecoder(bytes.NewReader(content))
|
||||
return m.UnmarshalJSONContext(context.Background(), content)
|
||||
}
|
||||
|
||||
func (m *TypedMap[K, V]) UnmarshalJSONContext(ctx context.Context, content []byte) error {
|
||||
decoder := json.NewDecoderContext(ctx, bytes.NewReader(content))
|
||||
m.Clear()
|
||||
objectStart, err := decoder.Token()
|
||||
if err != nil {
|
||||
|
@ -47,7 +56,7 @@ func (m *TypedMap[K, V]) UnmarshalJSON(content []byte) error {
|
|||
} else if objectStart != json.Delim('{') {
|
||||
return E.New("expected json object start, but starts with ", objectStart)
|
||||
}
|
||||
err = m.decodeJSON(decoder)
|
||||
err = m.decodeJSON(ctx, decoder)
|
||||
if err != nil {
|
||||
return E.Cause(err, "decode json object content")
|
||||
}
|
||||
|
@ -60,18 +69,18 @@ func (m *TypedMap[K, V]) UnmarshalJSON(content []byte) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (m *TypedMap[K, V]) decodeJSON(decoder *json.Decoder) error {
|
||||
func (m *TypedMap[K, V]) decodeJSON(ctx context.Context, decoder *json.Decoder) error {
|
||||
for decoder.More() {
|
||||
keyToken, err := decoder.Token()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
keyContent, err := json.Marshal(keyToken)
|
||||
keyContent, err := json.MarshalContext(ctx, keyToken)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var entryKey K
|
||||
err = json.Unmarshal(keyContent, &entryKey)
|
||||
err = json.UnmarshalContext(ctx, keyContent, &entryKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
common/json/badoption/duration.go (new file, 32 lines)

package badoption

import (
    "time"

    "github.com/sagernet/sing/common/json"
    "github.com/sagernet/sing/common/json/badoption/internal/my_time"
)

type Duration time.Duration

func (d Duration) Build() time.Duration {
    return time.Duration(d)
}

func (d Duration) MarshalJSON() ([]byte, error) {
    return json.Marshal((time.Duration)(d).String())
}

func (d *Duration) UnmarshalJSON(bytes []byte) error {
    var value string
    err := json.Unmarshal(bytes, &value)
    if err != nil {
        return err
    }
    duration, err := my_time.ParseDuration(value)
    if err != nil {
        return err
    }
    *d = Duration(duration)
    return nil
}
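A brief usage sketch (the surrounding config struct is invented): badoption.Duration marshals as a string and, via my_time.ParseDuration, also accepts the extended "d" unit.

package main

import (
    "fmt"

    "github.com/sagernet/sing/common/json"
    "github.com/sagernet/sing/common/json/badoption"
)

// cacheOptions exists only for this sketch.
type cacheOptions struct {
    TTL badoption.Duration `json:"ttl"`
}

func main() {
    var options cacheOptions
    if err := json.Unmarshal([]byte(`{"ttl": "1d12h"}`), &options); err != nil {
        panic(err)
    }
    fmt.Println(options.TTL.Build()) // 36h0m0s
}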
|
common/json/badoption/http.go (new file, 15 lines)

package badoption

import "net/http"

type HTTPHeader map[string]Listable[string]

func (h HTTPHeader) Build() http.Header {
    header := make(http.Header)
    for name, values := range h {
        for _, value := range values {
            header.Add(name, value)
        }
    }
    return header
}
|
common/json/badoption/internal/my_time/format.go (new file, 226 lines)
|
@ -0,0 +1,226 @@
|
|||
package my_time
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Copyright 2010 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
const durationDay = 24 * time.Hour
|
||||
|
||||
var unitMap = map[string]uint64{
|
||||
"ns": uint64(time.Nanosecond),
|
||||
"us": uint64(time.Microsecond),
|
||||
"µs": uint64(time.Microsecond), // U+00B5 = micro symbol
|
||||
"μs": uint64(time.Microsecond), // U+03BC = Greek letter mu
|
||||
"ms": uint64(time.Millisecond),
|
||||
"s": uint64(time.Second),
|
||||
"m": uint64(time.Minute),
|
||||
"h": uint64(time.Hour),
|
||||
"d": uint64(durationDay),
|
||||
}
|
||||
|
||||
// ParseDuration parses a duration string.
|
||||
// A duration string is a possibly signed sequence of
|
||||
// decimal numbers, each with optional fraction and a unit suffix,
|
||||
// such as "300ms", "-1.5h" or "2h45m".
|
||||
// Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h", "d".
|
||||
func ParseDuration(s string) (time.Duration, error) {
|
||||
// [-+]?([0-9]*(\.[0-9]*)?[a-z]+)+
|
||||
orig := s
|
||||
var d uint64
|
||||
neg := false
|
||||
|
||||
// Consume [-+]?
|
||||
if s != "" {
|
||||
c := s[0]
|
||||
if c == '-' || c == '+' {
|
||||
neg = c == '-'
|
||||
s = s[1:]
|
||||
}
|
||||
}
|
||||
// Special case: if all that is left is "0", this is zero.
|
||||
if s == "0" {
|
||||
return 0, nil
|
||||
}
|
||||
if s == "" {
|
||||
return 0, errors.New("time: invalid duration " + quote(orig))
|
||||
}
|
||||
for s != "" {
|
||||
var (
|
||||
v, f uint64 // integers before, after decimal point
|
||||
scale float64 = 1 // value = v + f/scale
|
||||
)
|
||||
|
||||
var err error
|
||||
|
||||
// The next character must be [0-9.]
|
||||
if !(s[0] == '.' || '0' <= s[0] && s[0] <= '9') {
|
||||
return 0, errors.New("time: invalid duration " + quote(orig))
|
||||
}
|
||||
// Consume [0-9]*
|
||||
pl := len(s)
|
||||
v, s, err = leadingInt(s)
|
||||
if err != nil {
|
||||
return 0, errors.New("time: invalid duration " + quote(orig))
|
||||
}
|
||||
pre := pl != len(s) // whether we consumed anything before a period
|
||||
|
||||
// Consume (\.[0-9]*)?
|
||||
post := false
|
||||
if s != "" && s[0] == '.' {
|
||||
s = s[1:]
|
||||
pl := len(s)
|
||||
f, scale, s = leadingFraction(s)
|
||||
post = pl != len(s)
|
||||
}
|
||||
if !pre && !post {
|
||||
// no digits (e.g. ".s" or "-.s")
|
||||
return 0, errors.New("time: invalid duration " + quote(orig))
|
||||
}
|
||||
|
||||
// Consume unit.
|
||||
i := 0
|
||||
for ; i < len(s); i++ {
|
||||
c := s[i]
|
||||
if c == '.' || '0' <= c && c <= '9' {
|
||||
break
|
||||
}
|
||||
}
|
||||
if i == 0 {
|
||||
return 0, errors.New("time: missing unit in duration " + quote(orig))
|
||||
}
|
||||
u := s[:i]
|
||||
s = s[i:]
|
||||
unit, ok := unitMap[u]
|
||||
if !ok {
|
||||
return 0, errors.New("time: unknown unit " + quote(u) + " in duration " + quote(orig))
|
||||
}
|
||||
if v > 1<<63/unit {
|
||||
// overflow
|
||||
return 0, errors.New("time: invalid duration " + quote(orig))
|
||||
}
|
||||
v *= unit
|
||||
if f > 0 {
|
||||
// float64 is needed to be nanosecond accurate for fractions of hours.
|
||||
// v >= 0 && (f*unit/scale) <= 3.6e+12 (ns/h, h is the largest unit)
|
||||
v += uint64(float64(f) * (float64(unit) / scale))
|
||||
if v > 1<<63 {
|
||||
// overflow
|
||||
return 0, errors.New("time: invalid duration " + quote(orig))
|
||||
}
|
||||
}
|
||||
d += v
|
||||
if d > 1<<63 {
|
||||
return 0, errors.New("time: invalid duration " + quote(orig))
|
||||
}
|
||||
}
|
||||
if neg {
|
||||
return -time.Duration(d), nil
|
||||
}
|
||||
if d > 1<<63-1 {
|
||||
return 0, errors.New("time: invalid duration " + quote(orig))
|
||||
}
|
||||
return time.Duration(d), nil
|
||||
}
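For example, with the added "d" unit the parser should accept day-based values (a small hedged test sketch, not taken from the repository):

package my_time_test

import (
    "testing"
    "time"

    "github.com/sagernet/sing/common/json/badoption/internal/my_time"
)

func TestParseDayUnit(t *testing.T) {
    d, err := my_time.ParseDuration("1d12h")
    if err != nil || d != 36*time.Hour {
        t.Fatalf("got %v, %v", d, err)
    }
}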
|
||||
|
||||
var errLeadingInt = errors.New("time: bad [0-9]*") // never printed
|
||||
|
||||
// leadingInt consumes the leading [0-9]* from s.
|
||||
func leadingInt[bytes []byte | string](s bytes) (x uint64, rem bytes, err error) {
|
||||
i := 0
|
||||
for ; i < len(s); i++ {
|
||||
c := s[i]
|
||||
if c < '0' || c > '9' {
|
||||
break
|
||||
}
|
||||
if x > 1<<63/10 {
|
||||
// overflow
|
||||
return 0, rem, errLeadingInt
|
||||
}
|
||||
x = x*10 + uint64(c) - '0'
|
||||
if x > 1<<63 {
|
||||
// overflow
|
||||
return 0, rem, errLeadingInt
|
||||
}
|
||||
}
|
||||
return x, s[i:], nil
|
||||
}
|
||||
|
||||
// leadingFraction consumes the leading [0-9]* from s.
|
||||
// It is used only for fractions, so does not return an error on overflow,
|
||||
// it just stops accumulating precision.
|
||||
func leadingFraction(s string) (x uint64, scale float64, rem string) {
|
||||
i := 0
|
||||
scale = 1
|
||||
overflow := false
|
||||
for ; i < len(s); i++ {
|
||||
c := s[i]
|
||||
if c < '0' || c > '9' {
|
||||
break
|
||||
}
|
||||
if overflow {
|
||||
continue
|
||||
}
|
||||
if x > (1<<63-1)/10 {
|
||||
// It's possible for overflow to give a positive number, so take care.
|
||||
overflow = true
|
||||
continue
|
||||
}
|
||||
y := x*10 + uint64(c) - '0'
|
||||
if y > 1<<63 {
|
||||
overflow = true
|
||||
continue
|
||||
}
|
||||
x = y
|
||||
scale *= 10
|
||||
}
|
||||
return x, scale, s[i:]
|
||||
}
|
||||
|
||||
// These are borrowed from unicode/utf8 and strconv and replicate behavior in
|
||||
// that package, since we can't take a dependency on either.
|
||||
const (
|
||||
lowerhex = "0123456789abcdef"
|
||||
runeSelf = 0x80
|
||||
runeError = '\uFFFD'
|
||||
)
|
||||
|
||||
func quote(s string) string {
|
||||
buf := make([]byte, 1, len(s)+2) // slice will be at least len(s) + quotes
|
||||
buf[0] = '"'
|
||||
for i, c := range s {
|
||||
if c >= runeSelf || c < ' ' {
|
||||
// This means you are asking us to parse a time.Duration or
|
||||
// time.Location with unprintable or non-ASCII characters in it.
|
||||
// We don't expect to hit this case very often. We could try to
|
||||
// reproduce strconv.Quote's behavior with full fidelity but
|
||||
// given how rarely we expect to hit these edge cases, speed and
|
||||
// conciseness are better.
|
||||
var width int
|
||||
if c == runeError {
|
||||
width = 1
|
||||
if i+2 < len(s) && s[i:i+3] == string(runeError) {
|
||||
width = 3
|
||||
}
|
||||
} else {
|
||||
width = len(string(c))
|
||||
}
|
||||
for j := 0; j < width; j++ {
|
||||
buf = append(buf, `\x`...)
|
||||
buf = append(buf, lowerhex[s[i+j]>>4])
|
||||
buf = append(buf, lowerhex[s[i+j]&0xF])
|
||||
}
|
||||
} else {
|
||||
if c == '"' || c == '\\' {
|
||||
buf = append(buf, '\\')
|
||||
}
|
||||
buf = append(buf, string(c)...)
|
||||
}
|
||||
}
|
||||
buf = append(buf, '"')
|
||||
return string(buf)
|
||||
}
|
common/json/badoption/listable.go (new file, 35 lines)

package badoption

import (
    "context"

    E "github.com/sagernet/sing/common/exceptions"
    "github.com/sagernet/sing/common/json"
)

type Listable[T any] []T

func (l Listable[T]) MarshalJSONContext(ctx context.Context) ([]byte, error) {
    arrayList := []T(l)
    if len(arrayList) == 1 {
        return json.Marshal(arrayList[0])
    }
    return json.MarshalContext(ctx, arrayList)
}

func (l *Listable[T]) UnmarshalJSONContext(ctx context.Context, content []byte) error {
    if string(content) == "null" {
        return nil
    }
    var singleItem T
    err := json.UnmarshalContextDisallowUnknownFields(ctx, content, &singleItem)
    if err == nil {
        *l = []T{singleItem}
        return nil
    }
    newErr := json.UnmarshalContextDisallowUnknownFields(ctx, content, (*[]T)(l))
    if newErr == nil {
        return nil
    }
    return E.Errors(err, newErr)
}
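A hedged usage sketch (the rule struct is invented): Listable lets a field accept either a single value or an array of values, which is also what HTTPHeader above builds on.

package main

import (
    "fmt"

    "github.com/sagernet/sing/common/json"
    "github.com/sagernet/sing/common/json/badoption"
)

// rule exists only for this sketch.
type rule struct {
    Domain badoption.Listable[string] `json:"domain"`
}

func main() {
    var single, multiple rule
    _ = json.Unmarshal([]byte(`{"domain": "example.org"}`), &single)
    _ = json.Unmarshal([]byte(`{"domain": ["example.org", "example.com"]}`), &multiple)
    fmt.Println(single.Domain)   // [example.org]
    fmt.Println(multiple.Domain) // [example.org example.com]
}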
|
common/json/badoption/netip.go (new file, 98 lines)
|
@ -0,0 +1,98 @@
|
|||
package badoption
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
|
||||
"github.com/sagernet/sing/common/json"
|
||||
)
|
||||
|
||||
type Addr netip.Addr
|
||||
|
||||
func (a *Addr) Build(defaultAddr netip.Addr) netip.Addr {
|
||||
if a == nil {
|
||||
return defaultAddr
|
||||
}
|
||||
return netip.Addr(*a)
|
||||
}
|
||||
|
||||
func (a *Addr) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(netip.Addr(*a).String())
|
||||
}
|
||||
|
||||
func (a *Addr) UnmarshalJSON(content []byte) error {
|
||||
var value string
|
||||
err := json.Unmarshal(content, &value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
addr, err := netip.ParseAddr(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*a = Addr(addr)
|
||||
return nil
|
||||
}
|
||||
|
||||
type Prefix netip.Prefix
|
||||
|
||||
func (p *Prefix) Build(defaultPrefix netip.Prefix) netip.Prefix {
|
||||
if p == nil {
|
||||
return defaultPrefix
|
||||
}
|
||||
return netip.Prefix(*p)
|
||||
}
|
||||
|
||||
func (p *Prefix) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(netip.Prefix(*p).String())
|
||||
}
|
||||
|
||||
func (p *Prefix) UnmarshalJSON(content []byte) error {
|
||||
var value string
|
||||
err := json.Unmarshal(content, &value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
prefix, err := netip.ParsePrefix(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*p = Prefix(prefix)
|
||||
return nil
|
||||
}
|
||||
|
||||
type Prefixable netip.Prefix
|
||||
|
||||
func (p *Prefixable) Build(defaultPrefix netip.Prefix) netip.Prefix {
|
||||
if p == nil {
|
||||
return defaultPrefix
|
||||
}
|
||||
return netip.Prefix(*p)
|
||||
}
|
||||
|
||||
func (p *Prefixable) MarshalJSON() ([]byte, error) {
|
||||
prefix := netip.Prefix(*p)
|
||||
if prefix.Bits() == prefix.Addr().BitLen() {
|
||||
return json.Marshal(prefix.Addr().String())
|
||||
} else {
|
||||
return json.Marshal(prefix.String())
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Prefixable) UnmarshalJSON(content []byte) error {
|
||||
var value string
|
||||
err := json.Unmarshal(content, &value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
prefix, prefixErr := netip.ParsePrefix(value)
|
||||
if prefixErr == nil {
|
||||
*p = Prefixable(prefix)
|
||||
return nil
|
||||
}
|
||||
addr, addrErr := netip.ParseAddr(value)
|
||||
if addrErr == nil {
|
||||
*p = Prefixable(netip.PrefixFrom(addr, addr.BitLen()))
|
||||
return nil
|
||||
}
|
||||
return prefixErr
|
||||
}
|
common/json/badoption/regexp.go (new file, 31 lines)

package badoption

import (
    "regexp"

    "github.com/sagernet/sing/common/json"
)

type Regexp regexp.Regexp

func (r *Regexp) Build() *regexp.Regexp {
    return (*regexp.Regexp)(r)
}

func (r *Regexp) MarshalJSON() ([]byte, error) {
    return json.Marshal((*regexp.Regexp)(r).String())
}

func (r *Regexp) UnmarshalJSON(content []byte) error {
    var stringValue string
    err := json.Unmarshal(content, &stringValue)
    if err != nil {
        return err
    }
    regex, err := regexp.Compile(stringValue)
    if err != nil {
        return err
    }
    *r = Regexp(*regex)
    return nil
}
|
|
@ -1,4 +1,4 @@
|
|||
//go:build go1.21 && !without_contextjson
|
||||
//go:build go1.20 && !without_contextjson
|
||||
|
||||
package json
|
||||
|
||||
|
|
|
@ -1,23 +0,0 @@
|
|||
//go:build !go1.21 && go1.20 && !without_contextjson
|
||||
|
||||
package json
|
||||
|
||||
import (
|
||||
"github.com/sagernet/sing/common/json/internal/contextjson_120"
|
||||
)
|
||||
|
||||
var (
|
||||
Marshal = json.Marshal
|
||||
Unmarshal = json.Unmarshal
|
||||
NewEncoder = json.NewEncoder
|
||||
NewDecoder = json.NewDecoder
|
||||
)
|
||||
|
||||
type (
|
||||
Encoder = json.Encoder
|
||||
Decoder = json.Decoder
|
||||
Token = json.Token
|
||||
Delim = json.Delim
|
||||
SyntaxError = json.SyntaxError
|
||||
RawMessage = json.RawMessage
|
||||
)
|
common/json/context_ext.go (new file, 23 lines)

package json

import (
    "context"

    "github.com/sagernet/sing/common/json/internal/contextjson"
)

var (
    MarshalContext                        = json.MarshalContext
    UnmarshalContext                      = json.UnmarshalContext
    NewEncoderContext                     = json.NewEncoderContext
    NewDecoderContext                     = json.NewDecoderContext
    UnmarshalContextDisallowUnknownFields = json.UnmarshalContextDisallowUnknownFields
)

type ContextMarshaler interface {
    MarshalJSONContext(ctx context.Context) ([]byte, error)
}

type ContextUnmarshaler interface {
    UnmarshalJSONContext(ctx context.Context, content []byte) error
}
|
common/json/internal/contextjson/context.go (new file, 11 lines)

package json

import "context"

type ContextMarshaler interface {
    MarshalJSONContext(ctx context.Context) ([]byte, error)
}

type ContextUnmarshaler interface {
    UnmarshalJSONContext(ctx context.Context, content []byte) error
}
|
common/json/internal/contextjson/context_test.go (new file, 43 lines)

package json_test

import (
    "context"
    "testing"

    "github.com/sagernet/sing/common/json/internal/contextjson"

    "github.com/stretchr/testify/require"
)

type myStruct struct {
    value string
}

func (m *myStruct) MarshalJSONContext(ctx context.Context) ([]byte, error) {
    return json.Marshal(ctx.Value("key").(string))
}

func (m *myStruct) UnmarshalJSONContext(ctx context.Context, content []byte) error {
    m.value = ctx.Value("key").(string)
    return nil
}

//nolint:staticcheck
func TestMarshalContext(t *testing.T) {
    t.Parallel()
    ctx := context.WithValue(context.Background(), "key", "value")
    var s myStruct
    b, err := json.MarshalContext(ctx, &s)
    require.NoError(t, err)
    require.Equal(t, []byte(`"value"`), b)
}

//nolint:staticcheck
func TestUnmarshalContext(t *testing.T) {
    t.Parallel()
    ctx := context.WithValue(context.Background(), "key", "value")
    var s myStruct
    err := json.UnmarshalContext(ctx, []byte(`{}`), &s)
    require.NoError(t, err)
    require.Equal(t, "value", s.value)
}
|
|
@ -8,6 +8,7 @@
|
|||
package json
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
|
@ -95,10 +96,15 @@ import (
|
|||
// Instead, they are replaced by the Unicode replacement
|
||||
// character U+FFFD.
|
||||
func Unmarshal(data []byte, v any) error {
|
||||
return UnmarshalContext(context.Background(), data, v)
|
||||
}
|
||||
|
||||
func UnmarshalContext(ctx context.Context, data []byte, v any) error {
|
||||
// Check for well-formedness.
|
||||
// Avoids filling out half a data structure
|
||||
// before discovering a JSON syntax error.
|
||||
var d decodeState
|
||||
d.ctx = ctx
|
||||
err := checkValid(data, &d.scan)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -209,6 +215,7 @@ type errorContext struct {
|
|||
|
||||
// decodeState represents the state while decoding a JSON value.
|
||||
type decodeState struct {
|
||||
ctx context.Context
|
||||
data []byte
|
||||
off int // next read offset in data
|
||||
opcode int // last read result
|
||||
|
@ -428,7 +435,7 @@ func (d *decodeState) valueQuoted() any {
|
|||
// If it encounters an Unmarshaler, indirect stops and returns that.
|
||||
// If decodingNull is true, indirect stops at the first settable pointer so it
|
||||
// can be set to nil.
|
||||
func indirect(v reflect.Value, decodingNull bool) (Unmarshaler, encoding.TextUnmarshaler, reflect.Value) {
|
||||
func indirect(v reflect.Value, decodingNull bool) (Unmarshaler, ContextUnmarshaler, encoding.TextUnmarshaler, reflect.Value) {
|
||||
// Issue #24153 indicates that it is generally not a guaranteed property
|
||||
// that you may round-trip a reflect.Value by calling Value.Addr().Elem()
|
||||
// and expect the value to still be settable for values derived from
|
||||
|
@ -482,11 +489,14 @@ func indirect(v reflect.Value, decodingNull bool) (Unmarshaler, encoding.TextUnm
|
|||
}
|
||||
if v.Type().NumMethod() > 0 && v.CanInterface() {
|
||||
if u, ok := v.Interface().(Unmarshaler); ok {
|
||||
return u, nil, reflect.Value{}
|
||||
return u, nil, nil, reflect.Value{}
|
||||
}
|
||||
if cu, ok := v.Interface().(ContextUnmarshaler); ok {
|
||||
return nil, cu, nil, reflect.Value{}
|
||||
}
|
||||
if !decodingNull {
|
||||
if u, ok := v.Interface().(encoding.TextUnmarshaler); ok {
|
||||
return nil, u, reflect.Value{}
|
||||
return nil, nil, u, reflect.Value{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -498,14 +508,14 @@ func indirect(v reflect.Value, decodingNull bool) (Unmarshaler, encoding.TextUnm
|
|||
v = v.Elem()
|
||||
}
|
||||
}
|
||||
return nil, nil, v
|
||||
return nil, nil, nil, v
|
||||
}
|
||||
|
||||
// array consumes an array from d.data[d.off-1:], decoding into v.
|
||||
// The first byte of the array ('[') has been read already.
|
||||
func (d *decodeState) array(v reflect.Value) error {
|
||||
// Check for unmarshaler.
|
||||
u, ut, pv := indirect(v, false)
|
||||
u, cu, ut, pv := indirect(v, false)
|
||||
if u != nil {
|
||||
start := d.readIndex()
|
||||
d.skip()
|
||||
|
@ -515,6 +525,15 @@ func (d *decodeState) array(v reflect.Value) error {
|
|||
}
|
||||
return nil
|
||||
}
|
||||
if cu != nil {
|
||||
start := d.readIndex()
|
||||
d.skip()
|
||||
err := cu.UnmarshalJSONContext(d.ctx, d.data[start:d.off])
|
||||
if err != nil {
|
||||
d.saveError(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if ut != nil {
|
||||
d.saveError(&UnmarshalTypeError{Value: "array", Type: v.Type(), Offset: int64(d.off)})
|
||||
d.skip()
|
||||
|
@ -612,7 +631,7 @@ var (
|
|||
// The first byte ('{') of the object has been read already.
|
||||
func (d *decodeState) object(v reflect.Value) error {
|
||||
// Check for unmarshaler.
|
||||
u, ut, pv := indirect(v, false)
|
||||
u, cu, ut, pv := indirect(v, false)
|
||||
if u != nil {
|
||||
start := d.readIndex()
|
||||
d.skip()
|
||||
|
@ -622,6 +641,15 @@ func (d *decodeState) object(v reflect.Value) error {
|
|||
}
|
||||
return nil
|
||||
}
|
||||
if cu != nil {
|
||||
start := d.readIndex()
|
||||
d.skip()
|
||||
err := cu.UnmarshalJSONContext(d.ctx, d.data[start:d.off])
|
||||
if err != nil {
|
||||
d.saveError(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if ut != nil {
|
||||
d.saveError(&UnmarshalTypeError{Value: "object", Type: v.Type(), Offset: int64(d.off)})
|
||||
d.skip()
|
||||
|
@ -870,7 +898,7 @@ func (d *decodeState) literalStore(item []byte, v reflect.Value, fromQuoted bool
|
|||
return nil
|
||||
}
|
||||
isNull := item[0] == 'n' // null
|
||||
u, ut, pv := indirect(v, isNull)
|
||||
u, cu, ut, pv := indirect(v, isNull)
|
||||
if u != nil {
|
||||
err := u.UnmarshalJSON(item)
|
||||
if err != nil {
|
||||
|
@ -878,6 +906,13 @@ func (d *decodeState) literalStore(item []byte, v reflect.Value, fromQuoted bool
|
|||
}
|
||||
return nil
|
||||
}
|
||||
if cu != nil {
|
||||
err := cu.UnmarshalJSONContext(d.ctx, item)
|
||||
if err != nil {
|
||||
d.saveError(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if ut != nil {
|
||||
if item[0] != '"' {
|
||||
if fromQuoted {
|
||||
|
|
|
@ -12,6 +12,7 @@ package json
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
|
@ -156,7 +157,11 @@ import (
|
|||
// handle them. Passing cyclic structures to Marshal will result in
|
||||
// an error.
|
||||
func Marshal(v any) ([]byte, error) {
|
||||
e := newEncodeState()
|
||||
return MarshalContext(context.Background(), v)
|
||||
}
|
||||
|
||||
func MarshalContext(ctx context.Context, v any) ([]byte, error) {
|
||||
e := newEncodeState(ctx)
|
||||
defer encodeStatePool.Put(e)
|
||||
|
||||
err := e.marshal(v, encOpts{escapeHTML: true})
|
||||
|
@ -251,6 +256,7 @@ var hex = "0123456789abcdef"
|
|||
type encodeState struct {
|
||||
bytes.Buffer // accumulated output
|
||||
|
||||
ctx context.Context
|
||||
// Keep track of what pointers we've seen in the current recursive call
|
||||
// path, to avoid cycles that could lead to a stack overflow. Only do
|
||||
// the relatively expensive map operations if ptrLevel is larger than
|
||||
|
@ -264,7 +270,7 @@ const startDetectingCyclesAfter = 1000
|
|||
|
||||
var encodeStatePool sync.Pool
|
||||
|
||||
func newEncodeState() *encodeState {
|
||||
func newEncodeState(ctx context.Context) *encodeState {
|
||||
if v := encodeStatePool.Get(); v != nil {
|
||||
e := v.(*encodeState)
|
||||
e.Reset()
|
||||
|
@ -274,7 +280,7 @@ func newEncodeState() *encodeState {
|
|||
e.ptrLevel = 0
|
||||
return e
|
||||
}
|
||||
return &encodeState{ptrSeen: make(map[any]struct{})}
|
||||
return &encodeState{ctx: ctx, ptrSeen: make(map[any]struct{})}
|
||||
}
|
||||
|
||||
// jsonError is an error wrapper type for internal use only.
|
||||
|
@ -371,8 +377,9 @@ func typeEncoder(t reflect.Type) encoderFunc {
|
|||
}
|
||||
|
||||
var (
|
||||
marshalerType = reflect.TypeOf((*Marshaler)(nil)).Elem()
|
||||
textMarshalerType = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem()
|
||||
marshalerType = reflect.TypeOf((*Marshaler)(nil)).Elem()
|
||||
contextMarshalerType = reflect.TypeOf((*ContextMarshaler)(nil)).Elem()
|
||||
textMarshalerType = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem()
|
||||
)
|
||||
|
||||
// newTypeEncoder constructs an encoderFunc for a type.
|
||||
|
@ -385,9 +392,15 @@ func newTypeEncoder(t reflect.Type, allowAddr bool) encoderFunc {
|
|||
if t.Kind() != reflect.Pointer && allowAddr && reflect.PointerTo(t).Implements(marshalerType) {
|
||||
return newCondAddrEncoder(addrMarshalerEncoder, newTypeEncoder(t, false))
|
||||
}
|
||||
if t.Kind() != reflect.Pointer && allowAddr && reflect.PointerTo(t).Implements(contextMarshalerType) {
|
||||
return newCondAddrEncoder(addrContextMarshalerEncoder, newTypeEncoder(t, false))
|
||||
}
|
||||
if t.Implements(marshalerType) {
|
||||
return marshalerEncoder
|
||||
}
|
||||
if t.Implements(contextMarshalerType) {
|
||||
return contextMarshalerEncoder
|
||||
}
|
||||
if t.Kind() != reflect.Pointer && allowAddr && reflect.PointerTo(t).Implements(textMarshalerType) {
|
||||
return newCondAddrEncoder(addrTextMarshalerEncoder, newTypeEncoder(t, false))
|
||||
}
|
||||
|
@ -442,7 +455,7 @@ func marshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) {
|
|||
b, err := m.MarshalJSON()
|
||||
if err == nil {
|
||||
e.Grow(len(b))
|
||||
out := e.AvailableBuffer()
|
||||
out := availableBuffer(&e.Buffer)
|
||||
out, err = appendCompact(out, b, opts.escapeHTML)
|
||||
e.Buffer.Write(out)
|
||||
}
|
||||
|
@ -461,7 +474,48 @@ func addrMarshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) {
|
|||
b, err := m.MarshalJSON()
|
||||
if err == nil {
|
||||
e.Grow(len(b))
|
||||
out := e.AvailableBuffer()
|
||||
out := availableBuffer(&e.Buffer)
|
||||
out, err = appendCompact(out, b, opts.escapeHTML)
|
||||
e.Buffer.Write(out)
|
||||
}
|
||||
if err != nil {
|
||||
e.error(&MarshalerError{v.Type(), err, "MarshalJSON"})
|
||||
}
|
||||
}
|
||||
|
||||
func contextMarshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) {
|
||||
if v.Kind() == reflect.Pointer && v.IsNil() {
|
||||
e.WriteString("null")
|
||||
return
|
||||
}
|
||||
m, ok := v.Interface().(ContextMarshaler)
|
||||
if !ok {
|
||||
e.WriteString("null")
|
||||
return
|
||||
}
|
||||
b, err := m.MarshalJSONContext(e.ctx)
|
||||
if err == nil {
|
||||
e.Grow(len(b))
|
||||
out := availableBuffer(&e.Buffer)
|
||||
out, err = appendCompact(out, b, opts.escapeHTML)
|
||||
e.Buffer.Write(out)
|
||||
}
|
||||
if err != nil {
|
||||
e.error(&MarshalerError{v.Type(), err, "MarshalJSON"})
|
||||
}
|
||||
}
|
||||
|
||||
func addrContextMarshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) {
|
||||
va := v.Addr()
|
||||
if va.IsNil() {
|
||||
e.WriteString("null")
|
||||
return
|
||||
}
|
||||
m := va.Interface().(ContextMarshaler)
|
||||
b, err := m.MarshalJSONContext(e.ctx)
|
||||
if err == nil {
|
||||
e.Grow(len(b))
|
||||
out := availableBuffer(&e.Buffer)
|
||||
out, err = appendCompact(out, b, opts.escapeHTML)
|
||||
e.Buffer.Write(out)
|
||||
}
|
||||
|
@ -484,7 +538,7 @@ func textMarshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) {
|
|||
if err != nil {
|
||||
e.error(&MarshalerError{v.Type(), err, "MarshalText"})
|
||||
}
|
||||
e.Write(appendString(e.AvailableBuffer(), b, opts.escapeHTML))
|
||||
e.Write(appendString(availableBuffer(&e.Buffer), b, opts.escapeHTML))
|
||||
}
|
||||
|
||||
func addrTextMarshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) {
|
||||
|
@ -498,11 +552,11 @@ func addrTextMarshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) {
|
|||
if err != nil {
|
||||
e.error(&MarshalerError{v.Type(), err, "MarshalText"})
|
||||
}
|
||||
e.Write(appendString(e.AvailableBuffer(), b, opts.escapeHTML))
|
||||
e.Write(appendString(availableBuffer(&e.Buffer), b, opts.escapeHTML))
|
||||
}
|
||||
|
||||
func boolEncoder(e *encodeState, v reflect.Value, opts encOpts) {
|
||||
b := e.AvailableBuffer()
|
||||
b := availableBuffer(&e.Buffer)
|
||||
b = mayAppendQuote(b, opts.quoted)
|
||||
b = strconv.AppendBool(b, v.Bool())
|
||||
b = mayAppendQuote(b, opts.quoted)
|
||||
|
@ -510,7 +564,7 @@ func boolEncoder(e *encodeState, v reflect.Value, opts encOpts) {
|
|||
}
|
||||
|
||||
func intEncoder(e *encodeState, v reflect.Value, opts encOpts) {
|
||||
b := e.AvailableBuffer()
|
||||
b := availableBuffer(&e.Buffer)
|
||||
b = mayAppendQuote(b, opts.quoted)
|
||||
b = strconv.AppendInt(b, v.Int(), 10)
|
||||
b = mayAppendQuote(b, opts.quoted)
|
||||
|
@ -518,7 +572,7 @@ func intEncoder(e *encodeState, v reflect.Value, opts encOpts) {
|
|||
}
|
||||
|
||||
func uintEncoder(e *encodeState, v reflect.Value, opts encOpts) {
|
||||
b := e.AvailableBuffer()
|
||||
b := availableBuffer(&e.Buffer)
|
||||
b = mayAppendQuote(b, opts.quoted)
|
||||
b = strconv.AppendUint(b, v.Uint(), 10)
|
||||
b = mayAppendQuote(b, opts.quoted)
|
||||
|
@ -538,7 +592,7 @@ func (bits floatEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) {
|
|||
// See golang.org/issue/6384 and golang.org/issue/14135.
|
||||
// Like fmt %g, but the exponent cutoffs are different
|
||||
// and exponents themselves are not padded to two digits.
|
||||
b := e.AvailableBuffer()
|
||||
b := availableBuffer(&e.Buffer)
|
||||
b = mayAppendQuote(b, opts.quoted)
|
||||
abs := math.Abs(f)
|
||||
fmt := byte('f')
|
||||
|
@ -577,7 +631,7 @@ func stringEncoder(e *encodeState, v reflect.Value, opts encOpts) {
|
|||
if !isValidNumber(numStr) {
|
||||
e.error(fmt.Errorf("json: invalid number literal %q", numStr))
|
||||
}
|
||||
b := e.AvailableBuffer()
|
||||
b := availableBuffer(&e.Buffer)
|
||||
b = mayAppendQuote(b, opts.quoted)
|
||||
b = append(b, numStr...)
|
||||
b = mayAppendQuote(b, opts.quoted)
|
||||
|
@ -586,9 +640,9 @@ func stringEncoder(e *encodeState, v reflect.Value, opts encOpts) {
|
|||
}
|
||||
if opts.quoted {
|
||||
b := appendString(nil, v.String(), opts.escapeHTML)
|
||||
e.Write(appendString(e.AvailableBuffer(), b, false)) // no need to escape again since it is already escaped
|
||||
e.Write(appendString(availableBuffer(&e.Buffer), b, false)) // no need to escape again since it is already escaped
|
||||
} else {
|
||||
e.Write(appendString(e.AvailableBuffer(), v.String(), opts.escapeHTML))
|
||||
e.Write(appendString(availableBuffer(&e.Buffer), v.String(), opts.escapeHTML))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -754,7 +808,7 @@ func (me mapEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) {
|
|||
if i > 0 {
|
||||
e.WriteByte(',')
|
||||
}
|
||||
e.Write(appendString(e.AvailableBuffer(), kv.ks, opts.escapeHTML))
|
||||
e.Write(appendString(availableBuffer(&e.Buffer), kv.ks, opts.escapeHTML))
|
||||
e.WriteByte(':')
|
||||
me.elemEnc(e, kv.v, opts)
|
||||
}
|
||||
|
@ -786,7 +840,7 @@ func encodeByteSlice(e *encodeState, v reflect.Value, _ encOpts) {
|
|||
e.Grow(len(`"`) + encodedLen + len(`"`))
|
||||
|
||||
// TODO(https://go.dev/issue/53693): Use base64.Encoding.AppendEncode.
|
||||
b := e.AvailableBuffer()
|
||||
b := availableBuffer(&e.Buffer)
|
||||
b = append(b, '"')
|
||||
base64.StdEncoding.Encode(b[len(b):][:encodedLen], s)
|
||||
b = b[:len(b)+encodedLen]
|
||||
|
@ -827,7 +881,7 @@ func newSliceEncoder(t reflect.Type) encoderFunc {
|
|||
// Byte slices get special treatment; arrays don't.
|
||||
if t.Elem().Kind() == reflect.Uint8 {
|
||||
p := reflect.PointerTo(t.Elem())
|
||||
if !p.Implements(marshalerType) && !p.Implements(textMarshalerType) {
|
||||
if !p.Implements(marshalerType) && !p.Implements(contextMarshalerType) && !p.Implements(textMarshalerType) {
|
||||
return encodeByteSlice
|
||||
}
|
||||
}
|
||||
|
|
|
@ -6,6 +6,11 @@ package json
|
|||
|
||||
import "bytes"
|
||||
|
||||
// TODO(https://go.dev/issue/53685): Use bytes.Buffer.AvailableBuffer instead.
|
||||
func availableBuffer(b *bytes.Buffer) []byte {
|
||||
return b.Bytes()[b.Len():]
|
||||
}
|
||||
|
||||
// HTMLEscape appends to dst the JSON-encoded src with <, >, &, U+2028 and U+2029
|
||||
// characters inside string literals changed to \u003c, \u003e, \u0026, \u2028, \u2029
|
||||
// so that the JSON will be safe to embed inside HTML <script> tags.
|
||||
|
@ -13,7 +18,7 @@ import "bytes"
|
|||
// escaping within <script> tags, so an alternative JSON encoding must be used.
|
||||
func HTMLEscape(dst *bytes.Buffer, src []byte) {
|
||||
dst.Grow(len(src))
|
||||
dst.Write(appendHTMLEscape(dst.AvailableBuffer(), src))
|
||||
dst.Write(appendHTMLEscape(availableBuffer(dst), src))
|
||||
}
|
||||
|
||||
func appendHTMLEscape(dst, src []byte) []byte {
|
||||
|
@ -40,7 +45,7 @@ func appendHTMLEscape(dst, src []byte) []byte {
|
|||
// insignificant space characters elided.
|
||||
func Compact(dst *bytes.Buffer, src []byte) error {
|
||||
dst.Grow(len(src))
|
||||
b := dst.AvailableBuffer()
|
||||
b := availableBuffer(dst)
|
||||
b, err := appendCompact(b, src, false)
|
||||
dst.Write(b)
|
||||
return err
|
||||
|
@ -109,7 +114,7 @@ const indentGrowthFactor = 2
|
|||
// if src ends in a trailing newline, so will dst.
|
||||
func Indent(dst *bytes.Buffer, src []byte, prefix, indent string) error {
|
||||
dst.Grow(indentGrowthFactor * len(src))
|
||||
b := dst.AvailableBuffer()
|
||||
b := availableBuffer(dst)
|
||||
b, err := appendIndent(b, src, prefix, indent)
|
||||
dst.Write(b)
|
||||
return err
|
||||
|
|
common/json/internal/contextjson/keys.go (new file, 20 lines)

package json

import (
    "reflect"

    "github.com/sagernet/sing/common"
)

func ObjectKeys(object reflect.Type) []string {
    switch object.Kind() {
    case reflect.Pointer:
        return ObjectKeys(object.Elem())
    case reflect.Struct:
    default:
        panic("invalid non-struct input")
    }
    return common.Map(cachedTypeFields(object).list, func(field field) string {
        return field.name
    })
}
|
common/json/internal/contextjson/keys_test.go (new file, 26 lines)

package json_test

import (
    "reflect"
    "testing"

    json "github.com/sagernet/sing/common/json/internal/contextjson"

    "github.com/stretchr/testify/require"
)

type MyObject struct {
    Hello string `json:"hello,omitempty"`
    MyWorld
    MyWorld2 string `json:"-"`
}

type MyWorld struct {
    World string `json:"world,omitempty"`
}

func TestObjectKeys(t *testing.T) {
    t.Parallel()
    keys := json.ObjectKeys(reflect.TypeOf(&MyObject{}))
    require.Equal(t, []string{"hello", "world"}, keys)
}
|
|
@ -300,7 +300,7 @@ func stateEndValue(s *scanner, c byte) int {
|
|||
case parseObjectValue:
|
||||
if c == ',' {
|
||||
s.parseState[n-1] = parseObjectKey
|
||||
s.step = stateBeginString
|
||||
s.step = stateBeginStringOrEmpty
|
||||
return scanObjectValue
|
||||
}
|
||||
if c == '}' {
|
||||
|
@ -310,7 +310,7 @@ func stateEndValue(s *scanner, c byte) int {
|
|||
return s.error(c, "after object key:value pair")
|
||||
case parseArrayValue:
|
||||
if c == ',' {
|
||||
s.step = stateBeginValue
|
||||
s.step = stateBeginValueOrEmpty
|
||||
return scanArrayValue
|
||||
}
|
||||
if c == ']' {
|
||||
|
|
|
@ -6,6 +6,7 @@ package json
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
)
|
||||
|
@ -29,7 +30,11 @@ type Decoder struct {
|
|||
// The decoder introduces its own buffering and may
|
||||
// read data from r beyond the JSON values requested.
|
||||
func NewDecoder(r io.Reader) *Decoder {
|
||||
return &Decoder{r: r}
|
||||
return NewDecoderContext(context.Background(), r)
|
||||
}
|
||||
|
||||
func NewDecoderContext(ctx context.Context, r io.Reader) *Decoder {
|
||||
return &Decoder{r: r, d: decodeState{ctx: ctx}}
|
||||
}
|
||||
|
||||
// UseNumber causes the Decoder to unmarshal a number into an interface{} as a
|
||||
|
@ -153,6 +158,10 @@ func (dec *Decoder) refill() error {
|
|||
dec.scanp = 0
|
||||
}
|
||||
|
||||
return dec.refill0()
|
||||
}
|
||||
|
||||
func (dec *Decoder) refill0() error {
|
||||
// Grow buffer if not large enough.
|
||||
const minRead = 512
|
||||
if cap(dec.buf)-len(dec.buf) < minRead {
|
||||
|
@ -179,6 +188,7 @@ func nonSpace(b []byte) bool {
|
|||
|
||||
// An Encoder writes JSON values to an output stream.
|
||||
type Encoder struct {
|
||||
ctx context.Context
|
||||
w io.Writer
|
||||
err error
|
||||
escapeHTML bool
|
||||
|
@ -190,7 +200,11 @@ type Encoder struct {
|
|||
|
||||
// NewEncoder returns a new encoder that writes to w.
|
||||
func NewEncoder(w io.Writer) *Encoder {
|
||||
return &Encoder{w: w, escapeHTML: true}
|
||||
return NewEncoderContext(context.Background(), w)
|
||||
}
|
||||
|
||||
func NewEncoderContext(ctx context.Context, w io.Writer) *Encoder {
|
||||
return &Encoder{ctx: ctx, w: w, escapeHTML: true}
|
||||
}
|
||||
|
||||
// Encode writes the JSON encoding of v to the stream,
|
||||
|
@ -203,7 +217,7 @@ func (enc *Encoder) Encode(v any) error {
|
|||
return enc.err
|
||||
}
|
||||
|
||||
e := newEncodeState()
|
||||
e := newEncodeState(enc.ctx)
|
||||
defer encodeStatePool.Put(e)
|
||||
|
||||
err := e.marshal(v, encOpts{escapeHTML: enc.escapeHTML})
|
||||
|
@ -402,7 +416,7 @@ func (dec *Decoder) Token() (Token, error) {
|
|||
return Delim('{'), nil
|
||||
|
||||
case '}':
|
||||
if dec.tokenState != tokenObjectStart && dec.tokenState != tokenObjectComma {
|
||||
if dec.tokenState != tokenObjectStart && dec.tokenState != tokenObjectComma && dec.tokenState != tokenObjectKey {
|
||||
return dec.tokenError(c)
|
||||
}
|
||||
dec.scanp++
|
||||
|
@ -410,7 +424,6 @@ func (dec *Decoder) Token() (Token, error) {
|
|||
dec.tokenStack = dec.tokenStack[:len(dec.tokenStack)-1]
|
||||
dec.tokenValueEnd()
|
||||
return Delim('}'), nil
|
||||
|
||||
case ':':
|
||||
if dec.tokenState != tokenObjectColon {
|
||||
return dec.tokenError(c)
|
||||
|
@ -483,7 +496,26 @@ func (dec *Decoder) tokenError(c byte) (Token, error) {
|
|||
// current array or object being parsed.
|
||||
func (dec *Decoder) More() bool {
|
||||
c, err := dec.peek()
|
||||
return err == nil && c != ']' && c != '}'
|
||||
// return err == nil && c != ']' && c != '}'
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
if c == ']' || c == '}' {
|
||||
return false
|
||||
}
|
||||
if c == ',' {
|
||||
scanp := dec.scanp
|
||||
dec.scanp++
|
||||
c, err = dec.peekNoRefill()
|
||||
dec.scanp = scanp
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
if c == ']' || c == '}' {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
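These scanner and decoder tweaks appear to make the forked contextjson tolerant of trailing commas in objects and arrays; a hedged sketch of the expected effect (this is not standard encoding/json behavior):

package main

import (
    "fmt"

    "github.com/sagernet/sing/common/json"
)

func main() {
    var value map[string]any
    // The trailing commas would be syntax errors for the standard library
    // decoder, but are expected to be accepted by this fork.
    err := json.Unmarshal([]byte(`{"a": 1, "b": [1, 2,],}`), &value)
    fmt.Println(err, value)
}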
|
||||
|
||||
func (dec *Decoder) peek() (byte, error) {
|
||||
|
@ -505,6 +537,25 @@ func (dec *Decoder) peek() (byte, error) {
|
|||
}
|
||||
}
|
||||
|
||||
func (dec *Decoder) peekNoRefill() (byte, error) {
|
||||
var err error
|
||||
for {
|
||||
for i := dec.scanp; i < len(dec.buf); i++ {
|
||||
c := dec.buf[i]
|
||||
if isSpace(c) {
|
||||
continue
|
||||
}
|
||||
dec.scanp = i
|
||||
return c, nil
|
||||
}
|
||||
// buffer has been scanned, now report any error
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
err = dec.refill0()
|
||||
}
|
||||
}
|
||||
|
||||
// InputOffset returns the input stream byte offset of the current decoder position.
|
||||
// The offset gives the location of the end of the most recently returned token
|
||||
// and the beginning of the next token.
|
||||
|
|
common/json/internal/contextjson/unmarshal.go (new file, 26 lines)

package json

import "context"

func UnmarshalDisallowUnknownFields(data []byte, v any) error {
    var d decodeState
    d.disallowUnknownFields = true
    err := checkValid(data, &d.scan)
    if err != nil {
        return err
    }
    d.init(data)
    return d.unmarshal(v)
}

func UnmarshalContextDisallowUnknownFields(ctx context.Context, data []byte, v any) error {
    var d decodeState
    d.ctx = ctx
    d.disallowUnknownFields = true
    err := checkValid(data, &d.scan)
    if err != nil {
        return err
    }
    d.init(data)
    return d.unmarshal(v)
}
|
|
@ -1,3 +0,0 @@
|
|||
# contextjson
|
||||
|
||||
mod from go1.20.11
|
[diff not shown: file too large]
|
@ -1,49 +0,0 @@
|
|||
package json
|
||||
|
||||
import "strconv"
|
||||
|
||||
type decodeContext struct {
|
||||
parent *decodeContext
|
||||
index int
|
||||
key string
|
||||
}
|
||||
|
||||
func (d *decodeState) formatContext() string {
|
||||
var description string
|
||||
context := d.context
|
||||
var appendDot bool
|
||||
for context != nil {
|
||||
if appendDot {
|
||||
description = "." + description
|
||||
}
|
||||
if context.key != "" {
|
||||
description = context.key + description
|
||||
appendDot = true
|
||||
} else {
|
||||
description = "[" + strconv.Itoa(context.index) + "]" + description
|
||||
appendDot = false
|
||||
}
|
||||
context = context.parent
|
||||
}
|
||||
return description
|
||||
}
|
||||
|
||||
type contextError struct {
|
||||
parent error
|
||||
context string
|
||||
index bool
|
||||
}
|
||||
|
||||
func (c *contextError) Unwrap() error {
|
||||
return c.parent
|
||||
}
|
||||
|
||||
func (c *contextError) Error() string {
|
||||
//goland:noinspection GoTypeAssertionOnErrors
|
||||
switch c.parent.(type) {
|
||||
case *contextError:
|
||||
return c.context + "." + c.parent.Error()
|
||||
default:
|
||||
return c.context + ": " + c.parent.Error()
|
||||
}
|
||||
}
|
[diff not shown: file too large]
|
@ -1,141 +0,0 @@
|
|||
// Copyright 2013 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package json
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
const (
|
||||
caseMask = ^byte(0x20) // Mask to ignore case in ASCII.
|
||||
kelvin = '\u212a'
|
||||
smallLongEss = '\u017f'
|
||||
)
|
||||
|
||||
// foldFunc returns one of four different case folding equivalence
|
||||
// functions, from most general (and slow) to fastest:
|
||||
//
|
||||
// 1) bytes.EqualFold, if the key s contains any non-ASCII UTF-8
|
||||
// 2) equalFoldRight, if s contains special folding ASCII ('k', 'K', 's', 'S')
|
||||
// 3) asciiEqualFold, no special, but includes non-letters (including _)
|
||||
// 4) simpleLetterEqualFold, no specials, no non-letters.
|
||||
//
|
||||
// The letters S and K are special because they map to 3 runes, not just 2:
|
||||
// - S maps to s and to U+017F 'ſ' Latin small letter long s
|
||||
// - k maps to K and to U+212A 'K' Kelvin sign
|
||||
//
|
||||
// See https://play.golang.org/p/tTxjOc0OGo
|
||||
//
|
||||
// The returned function is specialized for matching against s and
|
||||
// should only be given s. It's not curried for performance reasons.
|
||||
func foldFunc(s []byte) func(s, t []byte) bool {
|
||||
nonLetter := false
|
||||
special := false // special letter
|
||||
for _, b := range s {
|
||||
if b >= utf8.RuneSelf {
|
||||
return bytes.EqualFold
|
||||
}
|
||||
upper := b & caseMask
|
||||
if upper < 'A' || upper > 'Z' {
|
||||
nonLetter = true
|
||||
} else if upper == 'K' || upper == 'S' {
|
||||
// See above for why these letters are special.
|
||||
special = true
|
||||
}
|
||||
}
|
||||
if special {
|
||||
return equalFoldRight
|
||||
}
|
||||
if nonLetter {
|
||||
return asciiEqualFold
|
||||
}
|
||||
return simpleLetterEqualFold
|
||||
}
|
||||
|
||||
// equalFoldRight is a specialization of bytes.EqualFold when s is
|
||||
// known to be all ASCII (including punctuation), but contains an 's',
|
||||
// 'S', 'k', or 'K', requiring a Unicode fold on the bytes in t.
|
||||
// See comments on foldFunc.
|
||||
func equalFoldRight(s, t []byte) bool {
|
||||
for _, sb := range s {
|
||||
if len(t) == 0 {
|
||||
return false
|
||||
}
|
||||
tb := t[0]
|
||||
if tb < utf8.RuneSelf {
|
||||
if sb != tb {
|
||||
sbUpper := sb & caseMask
|
||||
if 'A' <= sbUpper && sbUpper <= 'Z' {
|
||||
if sbUpper != tb&caseMask {
|
||||
return false
|
||||
}
|
||||
} else {
|
||||
return false
|
||||
}
|
||||
}
|
||||
t = t[1:]
|
||||
continue
|
||||
}
|
||||
// sb is ASCII and t is not. t must be either kelvin
|
||||
// sign or long s; sb must be s, S, k, or K.
|
||||
tr, size := utf8.DecodeRune(t)
|
||||
switch sb {
|
||||
case 's', 'S':
|
||||
if tr != smallLongEss {
|
||||
return false
|
||||
}
|
||||
case 'k', 'K':
|
||||
if tr != kelvin {
|
||||
return false
|
||||
}
|
||||
default:
|
||||
return false
|
||||
}
|
||||
t = t[size:]
|
||||
|
||||
}
|
||||
return len(t) == 0
|
||||
}
|
||||
|
||||
// asciiEqualFold is a specialization of bytes.EqualFold for use when
|
||||
// s is all ASCII (but may contain non-letters) and contains no
|
||||
// special-folding letters.
|
||||
// See comments on foldFunc.
|
||||
func asciiEqualFold(s, t []byte) bool {
|
||||
if len(s) != len(t) {
|
||||
return false
|
||||
}
|
||||
for i, sb := range s {
|
||||
tb := t[i]
|
||||
if sb == tb {
|
||||
continue
|
||||
}
|
||||
if ('a' <= sb && sb <= 'z') || ('A' <= sb && sb <= 'Z') {
|
||||
if sb&caseMask != tb&caseMask {
|
||||
return false
|
||||
}
|
||||
} else {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// simpleLetterEqualFold is a specialization of bytes.EqualFold for
|
||||
// use when s is all ASCII letters (no underscores, etc) and also
|
||||
// doesn't contain 'k', 'K', 's', or 'S'.
|
||||
// See comments on foldFunc.
|
||||
func simpleLetterEqualFold(s, t []byte) bool {
|
||||
if len(s) != len(t) {
|
||||
return false
|
||||
}
|
||||
for i, b := range s {
|
||||
if b&caseMask != t[i]&caseMask {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
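To see why 'k'/'K' and 's'/'S' get the special treatment described above, a quick standard-library check (bytes.EqualFold is the general fallback that foldFunc returns for non-ASCII keys):

package main

import (
    "bytes"
    "fmt"
)

func main() {
    // The Kelvin sign (U+212A) folds to k/K and the long s (U+017F) folds to s/S,
    // so keys containing those ASCII letters cannot use the simple ASCII-only comparison.
    fmt.Println(bytes.EqualFold([]byte("K"), []byte("\u212a")))      // true
    fmt.Println(bytes.EqualFold([]byte("s"), []byte("\u017f")))      // true
    fmt.Println(bytes.EqualFold([]byte("Kelvin"), []byte("kelvin"))) // true either way
}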
|
@ -1,42 +0,0 @@
|
|||
// Copyright 2019 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
//go:build gofuzz
|
||||
|
||||
package json
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func Fuzz(data []byte) (score int) {
|
||||
for _, ctor := range []func() any{
|
||||
func() any { return new(any) },
|
||||
func() any { return new(map[string]any) },
|
||||
func() any { return new([]any) },
|
||||
} {
|
||||
v := ctor()
|
||||
err := Unmarshal(data, v)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
score = 1
|
||||
|
||||
m, err := Marshal(v)
|
||||
if err != nil {
|
||||
fmt.Printf("v=%#v\n", v)
|
||||
panic(err)
|
||||
}
|
||||
|
||||
u := ctor()
|
||||
err = Unmarshal(m, u)
|
||||
if err != nil {
|
||||
fmt.Printf("v=%#v\n", v)
|
||||
fmt.Printf("m=%s\n", m)
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
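The property the fuzz target asserts, namely that any input that unmarshals must also marshal and unmarshal again, can be reproduced outside the fuzzer. A minimal sketch using the standard encoding/json package, which shares this API:

package main

import (
    "encoding/json"
    "fmt"
)

func main() {
    data := []byte(`{"a":[1,2,3],"b":null}`)

    var v any
    if err := json.Unmarshal(data, &v); err != nil {
        return // invalid input is simply skipped by the fuzz target
    }

    m, err := json.Marshal(v)
    if err != nil {
        panic(err) // a round-trip failure is what the fuzz target reports
    }

    var u any
    if err := json.Unmarshal(m, &u); err != nil {
        panic(err)
    }
    fmt.Printf("round-tripped: %s\n", m)
}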
|
@ -1,143 +0,0 @@
|
|||
// Copyright 2010 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package json
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
)
|
||||
|
||||
// Compact appends to dst the JSON-encoded src with
|
||||
// insignificant space characters elided.
|
||||
func Compact(dst *bytes.Buffer, src []byte) error {
|
||||
return compact(dst, src, false)
|
||||
}
|
||||
|
||||
func compact(dst *bytes.Buffer, src []byte, escape bool) error {
|
||||
origLen := dst.Len()
|
||||
scan := newScanner()
|
||||
defer freeScanner(scan)
|
||||
start := 0
|
||||
for i, c := range src {
|
||||
if escape && (c == '<' || c == '>' || c == '&') {
|
||||
if start < i {
|
||||
dst.Write(src[start:i])
|
||||
}
|
||||
dst.WriteString(`\u00`)
|
||||
dst.WriteByte(hex[c>>4])
|
||||
dst.WriteByte(hex[c&0xF])
|
||||
start = i + 1
|
||||
}
|
||||
// Convert U+2028 and U+2029 (E2 80 A8 and E2 80 A9).
|
||||
if escape && c == 0xE2 && i+2 < len(src) && src[i+1] == 0x80 && src[i+2]&^1 == 0xA8 {
|
||||
if start < i {
|
||||
dst.Write(src[start:i])
|
||||
}
|
||||
dst.WriteString(`\u202`)
|
||||
dst.WriteByte(hex[src[i+2]&0xF])
|
||||
start = i + 3
|
||||
}
|
||||
v := scan.step(scan, c)
|
||||
if v >= scanSkipSpace {
|
||||
if v == scanError {
|
||||
break
|
||||
}
|
||||
if start < i {
|
||||
dst.Write(src[start:i])
|
||||
}
|
||||
start = i + 1
|
||||
}
|
||||
}
|
||||
if scan.eof() == scanError {
|
||||
dst.Truncate(origLen)
|
||||
return scan.err
|
||||
}
|
||||
if start < len(src) {
|
||||
dst.Write(src[start:])
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func newline(dst *bytes.Buffer, prefix, indent string, depth int) {
|
||||
dst.WriteByte('\n')
|
||||
dst.WriteString(prefix)
|
||||
for i := 0; i < depth; i++ {
|
||||
dst.WriteString(indent)
|
||||
}
|
||||
}
|
||||
|
||||
// Indent appends to dst an indented form of the JSON-encoded src.
|
||||
// Each element in a JSON object or array begins on a new,
|
||||
// indented line beginning with prefix followed by one or more
|
||||
// copies of indent according to the indentation nesting.
|
||||
// The data appended to dst does not begin with the prefix nor
|
||||
// any indentation, to make it easier to embed inside other formatted JSON data.
|
||||
// Although leading space characters (space, tab, carriage return, newline)
|
||||
// at the beginning of src are dropped, trailing space characters
|
||||
// at the end of src are preserved and copied to dst.
|
||||
// For example, if src has no trailing spaces, neither will dst;
|
||||
// if src ends in a trailing newline, so will dst.
|
||||
func Indent(dst *bytes.Buffer, src []byte, prefix, indent string) error {
|
||||
origLen := dst.Len()
|
||||
scan := newScanner()
|
||||
defer freeScanner(scan)
|
||||
needIndent := false
|
||||
depth := 0
|
||||
for _, c := range src {
|
||||
scan.bytes++
|
||||
v := scan.step(scan, c)
|
||||
if v == scanSkipSpace {
|
||||
continue
|
||||
}
|
||||
if v == scanError {
|
||||
break
|
||||
}
|
||||
if needIndent && v != scanEndObject && v != scanEndArray {
|
||||
needIndent = false
|
||||
depth++
|
||||
newline(dst, prefix, indent, depth)
|
||||
}
|
||||
|
||||
// Emit semantically uninteresting bytes
|
||||
// (in particular, punctuation in strings) unmodified.
|
||||
if v == scanContinue {
|
||||
dst.WriteByte(c)
|
||||
continue
|
||||
}
|
||||
|
||||
// Add spacing around real punctuation.
|
||||
switch c {
|
||||
case '{', '[':
|
||||
// delay indent so that empty object and array are formatted as {} and [].
|
||||
needIndent = true
|
||||
dst.WriteByte(c)
|
||||
|
||||
case ',':
|
||||
dst.WriteByte(c)
|
||||
newline(dst, prefix, indent, depth)
|
||||
|
||||
case ':':
|
||||
dst.WriteByte(c)
|
||||
dst.WriteByte(' ')
|
||||
|
||||
case '}', ']':
|
||||
if needIndent {
|
||||
// suppress indent in empty object/array
|
||||
needIndent = false
|
||||
} else {
|
||||
depth--
|
||||
newline(dst, prefix, indent, depth)
|
||||
}
|
||||
dst.WriteByte(c)
|
||||
|
||||
default:
|
||||
dst.WriteByte(c)
|
||||
}
|
||||
}
|
||||
if scan.eof() == scanError {
|
||||
dst.Truncate(origLen)
|
||||
return scan.err
|
||||
}
|
||||
return nil
|
||||
}
|
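Compact and Indent keep the same signatures as their encoding/json counterparts, so a standard-library sketch shows the expected behaviour:

package main

import (
    "bytes"
    "encoding/json"
    "fmt"
)

func main() {
    src := []byte(`{ "name": "sing", "tags": [ "proxy", "json" ] }`)

    var compacted bytes.Buffer
    if err := json.Compact(&compacted, src); err != nil {
        panic(err)
    }
    fmt.Println(compacted.String()) // {"name":"sing","tags":["proxy","json"]}

    var indented bytes.Buffer
    if err := json.Indent(&indented, compacted.Bytes(), "", "  "); err != nil {
        panic(err)
    }
    fmt.Println(indented.String()) // same value, two-space indentation, empty {} and [] stay compact
}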
|
@ -1,610 +0,0 @@
|
|||
// Copyright 2010 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package json
|
||||
|
||||
// JSON value parser state machine.
|
||||
// Just about at the limit of what is reasonable to write by hand.
|
||||
// Some parts are a bit tedious, but overall it nicely factors out the
|
||||
// otherwise common code from the multiple scanning functions
|
||||
// in this package (Compact, Indent, checkValid, etc).
|
||||
//
|
||||
// This file starts with two simple examples using the scanner
|
||||
// before diving into the scanner itself.
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Valid reports whether data is a valid JSON encoding.
|
||||
func Valid(data []byte) bool {
|
||||
scan := newScanner()
|
||||
defer freeScanner(scan)
|
||||
return checkValid(data, scan) == nil
|
||||
}
|
||||
|
||||
// checkValid verifies that data is valid JSON-encoded data.
|
||||
// scan is passed in for use by checkValid to avoid an allocation.
|
||||
// checkValid returns nil or a SyntaxError.
|
||||
func checkValid(data []byte, scan *scanner) error {
|
||||
scan.reset()
|
||||
for _, c := range data {
|
||||
scan.bytes++
|
||||
if scan.step(scan, c) == scanError {
|
||||
return scan.err
|
||||
}
|
||||
}
|
||||
if scan.eof() == scanError {
|
||||
return scan.err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// A SyntaxError is a description of a JSON syntax error.
|
||||
// Unmarshal will return a SyntaxError if the JSON can't be parsed.
|
||||
type SyntaxError struct {
|
||||
msg string // description of error
|
||||
Offset int64 // error occurred after reading Offset bytes
|
||||
}
|
||||
|
||||
func (e *SyntaxError) Error() string { return e.msg }
|
||||
|
||||
// A scanner is a JSON scanning state machine.
|
||||
// Callers call scan.reset and then pass bytes in one at a time
|
||||
// by calling scan.step(&scan, c) for each byte.
|
||||
// The return value, referred to as an opcode, tells the
|
||||
// caller about significant parsing events like beginning
|
||||
// and ending literals, objects, and arrays, so that the
|
||||
// caller can follow along if it wishes.
|
||||
// The return value scanEnd indicates that a single top-level
|
||||
// JSON value has been completed, *before* the byte that
|
||||
// just got passed in. (The indication must be delayed in order
|
||||
// to recognize the end of numbers: is 123 a whole value or
|
||||
// the beginning of 12345e+6?).
|
||||
type scanner struct {
|
||||
// The step is a func to be called to execute the next transition.
|
||||
// Also tried using an integer constant and a single func
|
||||
// with a switch, but using the func directly was 10% faster
|
||||
// on a 64-bit Mac Mini, and it's nicer to read.
|
||||
step func(*scanner, byte) int
|
||||
|
||||
// Reached end of top-level value.
|
||||
endTop bool
|
||||
|
||||
// Stack of what we're in the middle of - array values, object keys, object values.
|
||||
parseState []int
|
||||
|
||||
// Error that happened, if any.
|
||||
err error
|
||||
|
||||
// total bytes consumed, updated by decoder.Decode (and deliberately
|
||||
// not set to zero by scan.reset)
|
||||
bytes int64
|
||||
}
|
||||
|
||||
var scannerPool = sync.Pool{
|
||||
New: func() any {
|
||||
return &scanner{}
|
||||
},
|
||||
}
|
||||
|
||||
func newScanner() *scanner {
|
||||
scan := scannerPool.Get().(*scanner)
|
||||
// scan.reset by design doesn't set bytes to zero
|
||||
scan.bytes = 0
|
||||
scan.reset()
|
||||
return scan
|
||||
}
|
||||
|
||||
func freeScanner(scan *scanner) {
|
||||
// Avoid hanging on to too much memory in extreme cases.
|
||||
if len(scan.parseState) > 1024 {
|
||||
scan.parseState = nil
|
||||
}
|
||||
scannerPool.Put(scan)
|
||||
}
|
||||
|
||||
// These values are returned by the state transition functions
|
||||
// assigned to scanner.state and the method scanner.eof.
|
||||
// They give details about the current state of the scan that
|
||||
// callers might be interested to know about.
|
||||
// It is okay to ignore the return value of any particular
|
||||
// call to scanner.state: if one call returns scanError,
|
||||
// every subsequent call will return scanError too.
|
||||
const (
|
||||
// Continue.
|
||||
scanContinue = iota // uninteresting byte
|
||||
scanBeginLiteral // end implied by next result != scanContinue
|
||||
scanBeginObject // begin object
|
||||
scanObjectKey // just finished object key (string)
|
||||
scanObjectValue // just finished non-last object value
|
||||
scanEndObject // end object (implies scanObjectValue if possible)
|
||||
scanBeginArray // begin array
|
||||
scanArrayValue // just finished array value
|
||||
scanEndArray // end array (implies scanArrayValue if possible)
|
||||
scanSkipSpace // space byte; can skip; known to be last "continue" result
|
||||
|
||||
// Stop.
|
||||
scanEnd // top-level value ended *before* this byte; known to be first "stop" result
|
||||
scanError // hit an error, scanner.err.
|
||||
)
|
||||
|
||||
// These values are stored in the parseState stack.
|
||||
// They give the current state of a composite value
|
||||
// being scanned. If the parser is inside a nested value
|
||||
// the parseState describes the nested state, outermost at entry 0.
|
||||
const (
|
||||
parseObjectKey = iota // parsing object key (before colon)
|
||||
parseObjectValue // parsing object value (after colon)
|
||||
parseArrayValue // parsing array value
|
||||
)
|
||||
|
||||
// This limits the max nesting depth to prevent stack overflow.
|
||||
// This is permitted by https://tools.ietf.org/html/rfc7159#section-9
|
||||
const maxNestingDepth = 10000
|
||||
|
||||
// reset prepares the scanner for use.
|
||||
// It must be called before calling s.step.
|
||||
func (s *scanner) reset() {
|
||||
s.step = stateBeginValue
|
||||
s.parseState = s.parseState[0:0]
|
||||
s.err = nil
|
||||
s.endTop = false
|
||||
}
|
||||
|
||||
// eof tells the scanner that the end of input has been reached.
|
||||
// It returns a scan status just as s.step does.
|
||||
func (s *scanner) eof() int {
|
||||
if s.err != nil {
|
||||
return scanError
|
||||
}
|
||||
if s.endTop {
|
||||
return scanEnd
|
||||
}
|
||||
s.step(s, ' ')
|
||||
if s.endTop {
|
||||
return scanEnd
|
||||
}
|
||||
if s.err == nil {
|
||||
s.err = &SyntaxError{"unexpected end of JSON input", s.bytes}
|
||||
}
|
||||
return scanError
|
||||
}
|
||||
|
||||
// pushParseState pushes a new parse state p onto the parse stack.
|
||||
// an error state is returned if maxNestingDepth was exceeded, otherwise successState is returned.
|
||||
func (s *scanner) pushParseState(c byte, newParseState int, successState int) int {
|
||||
s.parseState = append(s.parseState, newParseState)
|
||||
if len(s.parseState) <= maxNestingDepth {
|
||||
return successState
|
||||
}
|
||||
return s.error(c, "exceeded max depth")
|
||||
}
|
||||
|
||||
// popParseState pops a parse state (already obtained) off the stack
|
||||
// and updates s.step accordingly.
|
||||
func (s *scanner) popParseState() {
|
||||
n := len(s.parseState) - 1
|
||||
s.parseState = s.parseState[0:n]
|
||||
if n == 0 {
|
||||
s.step = stateEndTop
|
||||
s.endTop = true
|
||||
} else {
|
||||
s.step = stateEndValue
|
||||
}
|
||||
}
|
||||
|
||||
func isSpace(c byte) bool {
|
||||
return c <= ' ' && (c == ' ' || c == '\t' || c == '\r' || c == '\n')
|
||||
}
|
||||
|
||||
// stateBeginValueOrEmpty is the state after reading `[`.
|
||||
func stateBeginValueOrEmpty(s *scanner, c byte) int {
|
||||
if isSpace(c) {
|
||||
return scanSkipSpace
|
||||
}
|
||||
if c == ']' {
|
||||
return stateEndValue(s, c)
|
||||
}
|
||||
return stateBeginValue(s, c)
|
||||
}
|
||||
|
||||
// stateBeginValue is the state at the beginning of the input.
|
||||
func stateBeginValue(s *scanner, c byte) int {
|
||||
if isSpace(c) {
|
||||
return scanSkipSpace
|
||||
}
|
||||
switch c {
|
||||
case '{':
|
||||
s.step = stateBeginStringOrEmpty
|
||||
return s.pushParseState(c, parseObjectKey, scanBeginObject)
|
||||
case '[':
|
||||
s.step = stateBeginValueOrEmpty
|
||||
return s.pushParseState(c, parseArrayValue, scanBeginArray)
|
||||
case '"':
|
||||
s.step = stateInString
|
||||
return scanBeginLiteral
|
||||
case '-':
|
||||
s.step = stateNeg
|
||||
return scanBeginLiteral
|
||||
case '0': // beginning of 0.123
|
||||
s.step = state0
|
||||
return scanBeginLiteral
|
||||
case 't': // beginning of true
|
||||
s.step = stateT
|
||||
return scanBeginLiteral
|
||||
case 'f': // beginning of false
|
||||
s.step = stateF
|
||||
return scanBeginLiteral
|
||||
case 'n': // beginning of null
|
||||
s.step = stateN
|
||||
return scanBeginLiteral
|
||||
}
|
||||
if '1' <= c && c <= '9' { // beginning of 1234.5
|
||||
s.step = state1
|
||||
return scanBeginLiteral
|
||||
}
|
||||
return s.error(c, "looking for beginning of value")
|
||||
}
|
||||
|
||||
// stateBeginStringOrEmpty is the state after reading `{`.
|
||||
func stateBeginStringOrEmpty(s *scanner, c byte) int {
|
||||
if isSpace(c) {
|
||||
return scanSkipSpace
|
||||
}
|
||||
if c == '}' {
|
||||
n := len(s.parseState)
|
||||
s.parseState[n-1] = parseObjectValue
|
||||
return stateEndValue(s, c)
|
||||
}
|
||||
return stateBeginString(s, c)
|
||||
}
|
||||
|
||||
// stateBeginString is the state after reading `{"key": value,`.
|
||||
func stateBeginString(s *scanner, c byte) int {
|
||||
if isSpace(c) {
|
||||
return scanSkipSpace
|
||||
}
|
||||
if c == '"' {
|
||||
s.step = stateInString
|
||||
return scanBeginLiteral
|
||||
}
|
||||
return s.error(c, "looking for beginning of object key string")
|
||||
}
|
||||
|
||||
// stateEndValue is the state after completing a value,
|
||||
// such as after reading `{}` or `true` or `["x"`.
|
||||
func stateEndValue(s *scanner, c byte) int {
|
||||
n := len(s.parseState)
|
||||
if n == 0 {
|
||||
// Completed top-level before the current byte.
|
||||
s.step = stateEndTop
|
||||
s.endTop = true
|
||||
return stateEndTop(s, c)
|
||||
}
|
||||
if isSpace(c) {
|
||||
s.step = stateEndValue
|
||||
return scanSkipSpace
|
||||
}
|
||||
ps := s.parseState[n-1]
|
||||
switch ps {
|
||||
case parseObjectKey:
|
||||
if c == ':' {
|
||||
s.parseState[n-1] = parseObjectValue
|
||||
s.step = stateBeginValue
|
||||
return scanObjectKey
|
||||
}
|
||||
return s.error(c, "after object key")
|
||||
case parseObjectValue:
|
||||
if c == ',' {
|
||||
s.parseState[n-1] = parseObjectKey
|
||||
s.step = stateBeginString
|
||||
return scanObjectValue
|
||||
}
|
||||
if c == '}' {
|
||||
s.popParseState()
|
||||
return scanEndObject
|
||||
}
|
||||
return s.error(c, "after object key:value pair")
|
||||
case parseArrayValue:
|
||||
if c == ',' {
|
||||
s.step = stateBeginValue
|
||||
return scanArrayValue
|
||||
}
|
||||
if c == ']' {
|
||||
s.popParseState()
|
||||
return scanEndArray
|
||||
}
|
||||
return s.error(c, "after array element")
|
||||
}
|
||||
return s.error(c, "")
|
||||
}
|
||||
|
||||
// stateEndTop is the state after finishing the top-level value,
|
||||
// such as after reading `{}` or `[1,2,3]`.
|
||||
// Only space characters should be seen now.
|
||||
func stateEndTop(s *scanner, c byte) int {
|
||||
if !isSpace(c) {
|
||||
// Complain about non-space byte on next call.
|
||||
s.error(c, "after top-level value")
|
||||
}
|
||||
return scanEnd
|
||||
}
|
||||
|
||||
// stateInString is the state after reading `"`.
|
||||
func stateInString(s *scanner, c byte) int {
|
||||
if c == '"' {
|
||||
s.step = stateEndValue
|
||||
return scanContinue
|
||||
}
|
||||
if c == '\\' {
|
||||
s.step = stateInStringEsc
|
||||
return scanContinue
|
||||
}
|
||||
if c < 0x20 {
|
||||
return s.error(c, "in string literal")
|
||||
}
|
||||
return scanContinue
|
||||
}
|
||||
|
||||
// stateInStringEsc is the state after reading `"\` during a quoted string.
|
||||
func stateInStringEsc(s *scanner, c byte) int {
|
||||
switch c {
|
||||
case 'b', 'f', 'n', 'r', 't', '\\', '/', '"':
|
||||
s.step = stateInString
|
||||
return scanContinue
|
||||
case 'u':
|
||||
s.step = stateInStringEscU
|
||||
return scanContinue
|
||||
}
|
||||
return s.error(c, "in string escape code")
|
||||
}
|
||||
|
||||
// stateInStringEscU is the state after reading `"\u` during a quoted string.
|
||||
func stateInStringEscU(s *scanner, c byte) int {
|
||||
if '0' <= c && c <= '9' || 'a' <= c && c <= 'f' || 'A' <= c && c <= 'F' {
|
||||
s.step = stateInStringEscU1
|
||||
return scanContinue
|
||||
}
|
||||
// numbers
|
||||
return s.error(c, "in \\u hexadecimal character escape")
|
||||
}
|
||||
|
||||
// stateInStringEscU1 is the state after reading `"\u1` during a quoted string.
|
||||
func stateInStringEscU1(s *scanner, c byte) int {
|
||||
if '0' <= c && c <= '9' || 'a' <= c && c <= 'f' || 'A' <= c && c <= 'F' {
|
||||
s.step = stateInStringEscU12
|
||||
return scanContinue
|
||||
}
|
||||
// numbers
|
||||
return s.error(c, "in \\u hexadecimal character escape")
|
||||
}
|
||||
|
||||
// stateInStringEscU12 is the state after reading `"\u12` during a quoted string.
|
||||
func stateInStringEscU12(s *scanner, c byte) int {
|
||||
if '0' <= c && c <= '9' || 'a' <= c && c <= 'f' || 'A' <= c && c <= 'F' {
|
||||
s.step = stateInStringEscU123
|
||||
return scanContinue
|
||||
}
|
||||
// numbers
|
||||
return s.error(c, "in \\u hexadecimal character escape")
|
||||
}
|
||||
|
||||
// stateInStringEscU123 is the state after reading `"\u123` during a quoted string.
|
||||
func stateInStringEscU123(s *scanner, c byte) int {
|
||||
if '0' <= c && c <= '9' || 'a' <= c && c <= 'f' || 'A' <= c && c <= 'F' {
|
||||
s.step = stateInString
|
||||
return scanContinue
|
||||
}
|
||||
// numbers
|
||||
return s.error(c, "in \\u hexadecimal character escape")
|
||||
}
|
||||
|
||||
// stateNeg is the state after reading `-` during a number.
|
||||
func stateNeg(s *scanner, c byte) int {
|
||||
if c == '0' {
|
||||
s.step = state0
|
||||
return scanContinue
|
||||
}
|
||||
if '1' <= c && c <= '9' {
|
||||
s.step = state1
|
||||
return scanContinue
|
||||
}
|
||||
return s.error(c, "in numeric literal")
|
||||
}
|
||||
|
||||
// state1 is the state after reading a non-zero integer during a number,
|
||||
// such as after reading `1` or `100` but not `0`.
|
||||
func state1(s *scanner, c byte) int {
|
||||
if '0' <= c && c <= '9' {
|
||||
s.step = state1
|
||||
return scanContinue
|
||||
}
|
||||
return state0(s, c)
|
||||
}
|
||||
|
||||
// state0 is the state after reading `0` during a number.
|
||||
func state0(s *scanner, c byte) int {
|
||||
if c == '.' {
|
||||
s.step = stateDot
|
||||
return scanContinue
|
||||
}
|
||||
if c == 'e' || c == 'E' {
|
||||
s.step = stateE
|
||||
return scanContinue
|
||||
}
|
||||
return stateEndValue(s, c)
|
||||
}
|
||||
|
||||
// stateDot is the state after reading the integer and decimal point in a number,
|
||||
// such as after reading `1.`.
|
||||
func stateDot(s *scanner, c byte) int {
|
||||
if '0' <= c && c <= '9' {
|
||||
s.step = stateDot0
|
||||
return scanContinue
|
||||
}
|
||||
return s.error(c, "after decimal point in numeric literal")
|
||||
}
|
||||
|
||||
// stateDot0 is the state after reading the integer, decimal point, and subsequent
|
||||
// digits of a number, such as after reading `3.14`.
|
||||
func stateDot0(s *scanner, c byte) int {
|
||||
if '0' <= c && c <= '9' {
|
||||
return scanContinue
|
||||
}
|
||||
if c == 'e' || c == 'E' {
|
||||
s.step = stateE
|
||||
return scanContinue
|
||||
}
|
||||
return stateEndValue(s, c)
|
||||
}
|
||||
|
||||
// stateE is the state after reading the mantissa and e in a number,
|
||||
// such as after reading `314e` or `0.314e`.
|
||||
func stateE(s *scanner, c byte) int {
|
||||
if c == '+' || c == '-' {
|
||||
s.step = stateESign
|
||||
return scanContinue
|
||||
}
|
||||
return stateESign(s, c)
|
||||
}
|
||||
|
||||
// stateESign is the state after reading the mantissa, e, and sign in a number,
|
||||
// such as after reading `314e-` or `0.314e+`.
|
||||
func stateESign(s *scanner, c byte) int {
|
||||
if '0' <= c && c <= '9' {
|
||||
s.step = stateE0
|
||||
return scanContinue
|
||||
}
|
||||
return s.error(c, "in exponent of numeric literal")
|
||||
}
|
||||
|
||||
// stateE0 is the state after reading the mantissa, e, optional sign,
|
||||
// and at least one digit of the exponent in a number,
|
||||
// such as after reading `314e-2` or `0.314e+1` or `3.14e0`.
|
||||
func stateE0(s *scanner, c byte) int {
|
||||
if '0' <= c && c <= '9' {
|
||||
return scanContinue
|
||||
}
|
||||
return stateEndValue(s, c)
|
||||
}
|
||||
|
||||
// stateT is the state after reading `t`.
|
||||
func stateT(s *scanner, c byte) int {
|
||||
if c == 'r' {
|
||||
s.step = stateTr
|
||||
return scanContinue
|
||||
}
|
||||
return s.error(c, "in literal true (expecting 'r')")
|
||||
}
|
||||
|
||||
// stateTr is the state after reading `tr`.
|
||||
func stateTr(s *scanner, c byte) int {
|
||||
if c == 'u' {
|
||||
s.step = stateTru
|
||||
return scanContinue
|
||||
}
|
||||
return s.error(c, "in literal true (expecting 'u')")
|
||||
}
|
||||
|
||||
// stateTru is the state after reading `tru`.
|
||||
func stateTru(s *scanner, c byte) int {
|
||||
if c == 'e' {
|
||||
s.step = stateEndValue
|
||||
return scanContinue
|
||||
}
|
||||
return s.error(c, "in literal true (expecting 'e')")
|
||||
}
|
||||
|
||||
// stateF is the state after reading `f`.
|
||||
func stateF(s *scanner, c byte) int {
|
||||
if c == 'a' {
|
||||
s.step = stateFa
|
||||
return scanContinue
|
||||
}
|
||||
return s.error(c, "in literal false (expecting 'a')")
|
||||
}
|
||||
|
||||
// stateFa is the state after reading `fa`.
|
||||
func stateFa(s *scanner, c byte) int {
|
||||
if c == 'l' {
|
||||
s.step = stateFal
|
||||
return scanContinue
|
||||
}
|
||||
return s.error(c, "in literal false (expecting 'l')")
|
||||
}
|
||||
|
||||
// stateFal is the state after reading `fal`.
|
||||
func stateFal(s *scanner, c byte) int {
|
||||
if c == 's' {
|
||||
s.step = stateFals
|
||||
return scanContinue
|
||||
}
|
||||
return s.error(c, "in literal false (expecting 's')")
|
||||
}
|
||||
|
||||
// stateFals is the state after reading `fals`.
|
||||
func stateFals(s *scanner, c byte) int {
|
||||
if c == 'e' {
|
||||
s.step = stateEndValue
|
||||
return scanContinue
|
||||
}
|
||||
return s.error(c, "in literal false (expecting 'e')")
|
||||
}
|
||||
|
||||
// stateN is the state after reading `n`.
|
||||
func stateN(s *scanner, c byte) int {
|
||||
if c == 'u' {
|
||||
s.step = stateNu
|
||||
return scanContinue
|
||||
}
|
||||
return s.error(c, "in literal null (expecting 'u')")
|
||||
}
|
||||
|
||||
// stateNu is the state after reading `nu`.
|
||||
func stateNu(s *scanner, c byte) int {
|
||||
if c == 'l' {
|
||||
s.step = stateNul
|
||||
return scanContinue
|
||||
}
|
||||
return s.error(c, "in literal null (expecting 'l')")
|
||||
}
|
||||
|
||||
// stateNul is the state after reading `nul`.
|
||||
func stateNul(s *scanner, c byte) int {
|
||||
if c == 'l' {
|
||||
s.step = stateEndValue
|
||||
return scanContinue
|
||||
}
|
||||
return s.error(c, "in literal null (expecting 'l')")
|
||||
}
|
||||
|
||||
// stateError is the state after reaching a syntax error,
|
||||
// such as after reading `[1}` or `5.1.2`.
|
||||
func stateError(s *scanner, c byte) int {
|
||||
return scanError
|
||||
}
|
||||
|
||||
// error records an error and switches to the error state.
|
||||
func (s *scanner) error(c byte, context string) int {
|
||||
s.step = stateError
|
||||
s.err = &SyntaxError{"invalid character " + quoteChar(c) + " " + context, s.bytes}
|
||||
return scanError
|
||||
}
|
||||
|
||||
// quoteChar formats c as a quoted character literal.
|
||||
func quoteChar(c byte) string {
|
||||
// special cases - different from quoted strings
|
||||
if c == '\'' {
|
||||
return `'\''`
|
||||
}
|
||||
if c == '"' {
|
||||
return `'"'`
|
||||
}
|
||||
|
||||
// use quoted string with different quotation marks
|
||||
s := strconv.Quote(string(c))
|
||||
return "'" + s[1:len(s)-1] + "'"
|
||||
}
|
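Valid is the simplest consumer of the scanner above; a quick sketch against the identical encoding/json function:

package main

import (
    "encoding/json"
    "fmt"
)

func main() {
    fmt.Println(json.Valid([]byte(`{"port": 443}`))) // true
    fmt.Println(json.Valid([]byte(`{"port": 443`)))  // false: unexpected end of JSON input
    fmt.Println(json.Valid([]byte(`[1}`)))           // false: invalid character '}' after array element
}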
|
@ -1,517 +0,0 @@
|
|||
// Copyright 2010 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package json
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"io"
|
||||
)
|
||||
|
||||
// A Decoder reads and decodes JSON values from an input stream.
|
||||
type Decoder struct {
|
||||
r io.Reader
|
||||
buf []byte
|
||||
d decodeState
|
||||
scanp int // start of unread data in buf
|
||||
scanned int64 // amount of data already scanned
|
||||
scan scanner
|
||||
err error
|
||||
|
||||
tokenState int
|
||||
tokenStack []int
|
||||
}
|
||||
|
||||
// NewDecoder returns a new decoder that reads from r.
|
||||
//
|
||||
// The decoder introduces its own buffering and may
|
||||
// read data from r beyond the JSON values requested.
|
||||
func NewDecoder(r io.Reader) *Decoder {
|
||||
return &Decoder{r: r}
|
||||
}
|
||||
|
||||
// UseNumber causes the Decoder to unmarshal a number into an interface{} as a
|
||||
// Number instead of as a float64.
|
||||
func (dec *Decoder) UseNumber() { dec.d.useNumber = true }
|
||||
|
||||
// DisallowUnknownFields causes the Decoder to return an error when the destination
|
||||
// is a struct and the input contains object keys which do not match any
|
||||
// non-ignored, exported fields in the destination.
|
||||
func (dec *Decoder) DisallowUnknownFields() { dec.d.disallowUnknownFields = true }
|
||||
|
||||
// Decode reads the next JSON-encoded value from its
|
||||
// input and stores it in the value pointed to by v.
|
||||
//
|
||||
// See the documentation for Unmarshal for details about
|
||||
// the conversion of JSON into a Go value.
|
||||
func (dec *Decoder) Decode(v any) error {
|
||||
if dec.err != nil {
|
||||
return dec.err
|
||||
}
|
||||
|
||||
if err := dec.tokenPrepareForDecode(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !dec.tokenValueAllowed() {
|
||||
return &SyntaxError{msg: "not at beginning of value", Offset: dec.InputOffset()}
|
||||
}
|
||||
|
||||
// Read whole value into buffer.
|
||||
n, err := dec.readValue()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
dec.d.init(dec.buf[dec.scanp : dec.scanp+n])
|
||||
dec.scanp += n
|
||||
|
||||
// Don't save err from unmarshal into dec.err:
|
||||
// the connection is still usable since we read a complete JSON
|
||||
// object from it before the error happened.
|
||||
err = dec.d.unmarshal(v)
|
||||
|
||||
// fixup token streaming state
|
||||
dec.tokenValueEnd()
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// Buffered returns a reader of the data remaining in the Decoder's
|
||||
// buffer. The reader is valid until the next call to Decode.
|
||||
func (dec *Decoder) Buffered() io.Reader {
|
||||
return bytes.NewReader(dec.buf[dec.scanp:])
|
||||
}
|
||||
|
||||
// readValue reads a JSON value into dec.buf.
|
||||
// It returns the length of the encoding.
|
||||
func (dec *Decoder) readValue() (int, error) {
|
||||
dec.scan.reset()
|
||||
|
||||
scanp := dec.scanp
|
||||
var err error
|
||||
Input:
|
||||
// help the compiler see that scanp is never negative, so it can remove
|
||||
// some bounds checks below.
|
||||
for scanp >= 0 {
|
||||
|
||||
// Look in the buffer for a new value.
|
||||
for ; scanp < len(dec.buf); scanp++ {
|
||||
c := dec.buf[scanp]
|
||||
dec.scan.bytes++
|
||||
switch dec.scan.step(&dec.scan, c) {
|
||||
case scanEnd:
|
||||
// scanEnd is delayed one byte so we decrement
|
||||
// the scanner bytes count by 1 to ensure that
|
||||
// this value is correct in the next call of Decode.
|
||||
dec.scan.bytes--
|
||||
break Input
|
||||
case scanEndObject, scanEndArray:
|
||||
// scanEnd is delayed one byte.
|
||||
// We might block trying to get that byte from src,
|
||||
// so instead invent a space byte.
|
||||
if stateEndValue(&dec.scan, ' ') == scanEnd {
|
||||
scanp++
|
||||
break Input
|
||||
}
|
||||
case scanError:
|
||||
dec.err = dec.scan.err
|
||||
return 0, dec.scan.err
|
||||
}
|
||||
}
|
||||
|
||||
// Did the last read have an error?
|
||||
// Delayed until now to allow buffer scan.
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
if dec.scan.step(&dec.scan, ' ') == scanEnd {
|
||||
break Input
|
||||
}
|
||||
if nonSpace(dec.buf) {
|
||||
err = io.ErrUnexpectedEOF
|
||||
}
|
||||
}
|
||||
dec.err = err
|
||||
return 0, err
|
||||
}
|
||||
|
||||
n := scanp - dec.scanp
|
||||
err = dec.refill()
|
||||
scanp = dec.scanp + n
|
||||
}
|
||||
return scanp - dec.scanp, nil
|
||||
}
|
||||
|
||||
func (dec *Decoder) refill() error {
|
||||
// Make room to read more into the buffer.
|
||||
// First slide down data already consumed.
|
||||
if dec.scanp > 0 {
|
||||
dec.scanned += int64(dec.scanp)
|
||||
n := copy(dec.buf, dec.buf[dec.scanp:])
|
||||
dec.buf = dec.buf[:n]
|
||||
dec.scanp = 0
|
||||
}
|
||||
|
||||
// Grow buffer if not large enough.
|
||||
const minRead = 512
|
||||
if cap(dec.buf)-len(dec.buf) < minRead {
|
||||
newBuf := make([]byte, len(dec.buf), 2*cap(dec.buf)+minRead)
|
||||
copy(newBuf, dec.buf)
|
||||
dec.buf = newBuf
|
||||
}
|
||||
|
||||
// Read. Delay error for next iteration (after scan).
|
||||
n, err := dec.r.Read(dec.buf[len(dec.buf):cap(dec.buf)])
|
||||
dec.buf = dec.buf[0 : len(dec.buf)+n]
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func nonSpace(b []byte) bool {
|
||||
for _, c := range b {
|
||||
if !isSpace(c) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// An Encoder writes JSON values to an output stream.
|
||||
type Encoder struct {
|
||||
w io.Writer
|
||||
err error
|
||||
escapeHTML bool
|
||||
|
||||
indentBuf *bytes.Buffer
|
||||
indentPrefix string
|
||||
indentValue string
|
||||
}
|
||||
|
||||
// NewEncoder returns a new encoder that writes to w.
|
||||
func NewEncoder(w io.Writer) *Encoder {
|
||||
return &Encoder{w: w, escapeHTML: true}
|
||||
}
|
||||
|
||||
// Encode writes the JSON encoding of v to the stream,
|
||||
// followed by a newline character.
|
||||
//
|
||||
// See the documentation for Marshal for details about the
|
||||
// conversion of Go values to JSON.
|
||||
func (enc *Encoder) Encode(v any) error {
|
||||
if enc.err != nil {
|
||||
return enc.err
|
||||
}
|
||||
|
||||
e := newEncodeState()
|
||||
defer encodeStatePool.Put(e)
|
||||
|
||||
err := e.marshal(v, encOpts{escapeHTML: enc.escapeHTML})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Terminate each value with a newline.
|
||||
// This makes the output look a little nicer
|
||||
// when debugging, and some kind of space
|
||||
// is required if the encoded value was a number,
|
||||
// so that the reader knows there aren't more
|
||||
// digits coming.
|
||||
e.WriteByte('\n')
|
||||
|
||||
b := e.Bytes()
|
||||
if enc.indentPrefix != "" || enc.indentValue != "" {
|
||||
if enc.indentBuf == nil {
|
||||
enc.indentBuf = new(bytes.Buffer)
|
||||
}
|
||||
enc.indentBuf.Reset()
|
||||
err = Indent(enc.indentBuf, b, enc.indentPrefix, enc.indentValue)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
b = enc.indentBuf.Bytes()
|
||||
}
|
||||
if _, err = enc.w.Write(b); err != nil {
|
||||
enc.err = err
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// SetIndent instructs the encoder to format each subsequent encoded
|
||||
// value as if indented by the package-level function Indent(dst, src, prefix, indent).
|
||||
// Calling SetIndent("", "") disables indentation.
|
||||
func (enc *Encoder) SetIndent(prefix, indent string) {
|
||||
enc.indentPrefix = prefix
|
||||
enc.indentValue = indent
|
||||
}
|
||||
|
||||
// SetEscapeHTML specifies whether problematic HTML characters
|
||||
// should be escaped inside JSON quoted strings.
|
||||
// The default behavior is to escape &, <, and > to \u0026, \u003c, and \u003e
|
||||
// to avoid certain safety problems that can arise when embedding JSON in HTML.
|
||||
//
|
||||
// In non-HTML settings where the escaping interferes with the readability
|
||||
// of the output, SetEscapeHTML(false) disables this behavior.
|
||||
func (enc *Encoder) SetEscapeHTML(on bool) {
|
||||
enc.escapeHTML = on
|
||||
}
|
||||
|
||||
// RawMessage is a raw encoded JSON value.
|
||||
// It implements Marshaler and Unmarshaler and can
|
||||
// be used to delay JSON decoding or precompute a JSON encoding.
|
||||
type RawMessage []byte
|
||||
|
||||
// MarshalJSON returns m as the JSON encoding of m.
|
||||
func (m RawMessage) MarshalJSON() ([]byte, error) {
|
||||
if m == nil {
|
||||
return []byte("null"), nil
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// UnmarshalJSON sets *m to a copy of data.
|
||||
func (m *RawMessage) UnmarshalJSON(data []byte) error {
|
||||
if m == nil {
|
||||
return errors.New("json.RawMessage: UnmarshalJSON on nil pointer")
|
||||
}
|
||||
*m = append((*m)[0:0], data...)
|
||||
return nil
|
||||
}
|
||||
|
||||
var (
|
||||
_ Marshaler = (*RawMessage)(nil)
|
||||
_ Unmarshaler = (*RawMessage)(nil)
|
||||
)
|
||||
|
||||
// A Token holds a value of one of these types:
|
||||
//
|
||||
// Delim, for the four JSON delimiters [ ] { }
|
||||
// bool, for JSON booleans
|
||||
// float64, for JSON numbers
|
||||
// Number, for JSON numbers
|
||||
// string, for JSON string literals
|
||||
// nil, for JSON null
|
||||
type Token any
|
||||
|
||||
const (
|
||||
tokenTopValue = iota
|
||||
tokenArrayStart
|
||||
tokenArrayValue
|
||||
tokenArrayComma
|
||||
tokenObjectStart
|
||||
tokenObjectKey
|
||||
tokenObjectColon
|
||||
tokenObjectValue
|
||||
tokenObjectComma
|
||||
)
|
||||
|
||||
// advance tokenstate from a separator state to a value state
|
||||
func (dec *Decoder) tokenPrepareForDecode() error {
|
||||
// Note: Not calling peek before switch, to avoid
|
||||
// putting peek into the standard Decode path.
|
||||
// peek is only called when using the Token API.
|
||||
switch dec.tokenState {
|
||||
case tokenArrayComma:
|
||||
c, err := dec.peek()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if c != ',' {
|
||||
return &SyntaxError{"expected comma after array element", dec.InputOffset()}
|
||||
}
|
||||
dec.scanp++
|
||||
dec.tokenState = tokenArrayValue
|
||||
case tokenObjectColon:
|
||||
c, err := dec.peek()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if c != ':' {
|
||||
return &SyntaxError{"expected colon after object key", dec.InputOffset()}
|
||||
}
|
||||
dec.scanp++
|
||||
dec.tokenState = tokenObjectValue
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (dec *Decoder) tokenValueAllowed() bool {
|
||||
switch dec.tokenState {
|
||||
case tokenTopValue, tokenArrayStart, tokenArrayValue, tokenObjectValue:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (dec *Decoder) tokenValueEnd() {
|
||||
switch dec.tokenState {
|
||||
case tokenArrayStart, tokenArrayValue:
|
||||
dec.tokenState = tokenArrayComma
|
||||
case tokenObjectValue:
|
||||
dec.tokenState = tokenObjectComma
|
||||
}
|
||||
}
|
||||
|
||||
// A Delim is a JSON array or object delimiter, one of [ ] { or }.
|
||||
type Delim rune
|
||||
|
||||
func (d Delim) String() string {
|
||||
return string(d)
|
||||
}
|
||||
|
||||
// Token returns the next JSON token in the input stream.
|
||||
// At the end of the input stream, Token returns nil, io.EOF.
|
||||
//
|
||||
// Token guarantees that the delimiters [ ] { } it returns are
|
||||
// properly nested and matched: if Token encounters an unexpected
|
||||
// delimiter in the input, it will return an error.
|
||||
//
|
||||
// The input stream consists of basic JSON values—bool, string,
|
||||
// number, and null—along with delimiters [ ] { } of type Delim
|
||||
// to mark the start and end of arrays and objects.
|
||||
// Commas and colons are elided.
|
||||
func (dec *Decoder) Token() (Token, error) {
|
||||
for {
|
||||
c, err := dec.peek()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
switch c {
|
||||
case '[':
|
||||
if !dec.tokenValueAllowed() {
|
||||
return dec.tokenError(c)
|
||||
}
|
||||
dec.scanp++
|
||||
dec.tokenStack = append(dec.tokenStack, dec.tokenState)
|
||||
dec.tokenState = tokenArrayStart
|
||||
return Delim('['), nil
|
||||
|
||||
case ']':
|
||||
if dec.tokenState != tokenArrayStart && dec.tokenState != tokenArrayComma {
|
||||
return dec.tokenError(c)
|
||||
}
|
||||
dec.scanp++
|
||||
dec.tokenState = dec.tokenStack[len(dec.tokenStack)-1]
|
||||
dec.tokenStack = dec.tokenStack[:len(dec.tokenStack)-1]
|
||||
dec.tokenValueEnd()
|
||||
return Delim(']'), nil
|
||||
|
||||
case '{':
|
||||
if !dec.tokenValueAllowed() {
|
||||
return dec.tokenError(c)
|
||||
}
|
||||
dec.scanp++
|
||||
dec.tokenStack = append(dec.tokenStack, dec.tokenState)
|
||||
dec.tokenState = tokenObjectStart
|
||||
return Delim('{'), nil
|
||||
|
||||
case '}':
|
||||
if dec.tokenState != tokenObjectStart && dec.tokenState != tokenObjectComma {
|
||||
return dec.tokenError(c)
|
||||
}
|
||||
dec.scanp++
|
||||
dec.tokenState = dec.tokenStack[len(dec.tokenStack)-1]
|
||||
dec.tokenStack = dec.tokenStack[:len(dec.tokenStack)-1]
|
||||
dec.tokenValueEnd()
|
||||
return Delim('}'), nil
|
||||
|
||||
case ':':
|
||||
if dec.tokenState != tokenObjectColon {
|
||||
return dec.tokenError(c)
|
||||
}
|
||||
dec.scanp++
|
||||
dec.tokenState = tokenObjectValue
|
||||
continue
|
||||
|
||||
case ',':
|
||||
if dec.tokenState == tokenArrayComma {
|
||||
dec.scanp++
|
||||
dec.tokenState = tokenArrayValue
|
||||
continue
|
||||
}
|
||||
if dec.tokenState == tokenObjectComma {
|
||||
dec.scanp++
|
||||
dec.tokenState = tokenObjectKey
|
||||
continue
|
||||
}
|
||||
return dec.tokenError(c)
|
||||
|
||||
case '"':
|
||||
if dec.tokenState == tokenObjectStart || dec.tokenState == tokenObjectKey {
|
||||
var x string
|
||||
old := dec.tokenState
|
||||
dec.tokenState = tokenTopValue
|
||||
err := dec.Decode(&x)
|
||||
dec.tokenState = old
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
dec.tokenState = tokenObjectColon
|
||||
return x, nil
|
||||
}
|
||||
fallthrough
|
||||
|
||||
default:
|
||||
if !dec.tokenValueAllowed() {
|
||||
return dec.tokenError(c)
|
||||
}
|
||||
var x any
|
||||
if err := dec.Decode(&x); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return x, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (dec *Decoder) tokenError(c byte) (Token, error) {
|
||||
var context string
|
||||
switch dec.tokenState {
|
||||
case tokenTopValue:
|
||||
context = " looking for beginning of value"
|
||||
case tokenArrayStart, tokenArrayValue, tokenObjectValue:
|
||||
context = " looking for beginning of value"
|
||||
case tokenArrayComma:
|
||||
context = " after array element"
|
||||
case tokenObjectKey:
|
||||
context = " looking for beginning of object key string"
|
||||
case tokenObjectColon:
|
||||
context = " after object key"
|
||||
case tokenObjectComma:
|
||||
context = " after object key:value pair"
|
||||
}
|
||||
return nil, &SyntaxError{"invalid character " + quoteChar(c) + context, dec.InputOffset()}
|
||||
}
|
||||
|
||||
// More reports whether there is another element in the
|
||||
// current array or object being parsed.
|
||||
func (dec *Decoder) More() bool {
|
||||
c, err := dec.peek()
|
||||
return err == nil && c != ']' && c != '}'
|
||||
}
|
||||
|
||||
func (dec *Decoder) peek() (byte, error) {
|
||||
var err error
|
||||
for {
|
||||
for i := dec.scanp; i < len(dec.buf); i++ {
|
||||
c := dec.buf[i]
|
||||
if isSpace(c) {
|
||||
continue
|
||||
}
|
||||
dec.scanp = i
|
||||
return c, nil
|
||||
}
|
||||
// buffer has been scanned, now report any error
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
err = dec.refill()
|
||||
}
|
||||
}
|
||||
|
||||
// InputOffset returns the input stream byte offset of the current decoder position.
|
||||
// The offset gives the location of the end of the most recently returned token
|
||||
// and the beginning of the next token.
|
||||
func (dec *Decoder) InputOffset() int64 {
|
||||
return dec.scanned + int64(dec.scanp)
|
||||
}
|
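The Token, More and Decode methods above combine into the usual streaming pattern; a standalone sketch against the identical encoding/json API:

package main

import (
    "encoding/json"
    "fmt"
    "strings"
)

func main() {
    dec := json.NewDecoder(strings.NewReader(`[{"tag":"a"},{"tag":"b"}]`))

    if _, err := dec.Token(); err != nil { // consume the opening '[' delimiter
        panic(err)
    }
    for dec.More() {
        var item struct {
            Tag string `json:"tag"`
        }
        if err := dec.Decode(&item); err != nil {
            panic(err)
        }
        fmt.Println(item.Tag)
    }
    if _, err := dec.Token(); err != nil { // consume the closing ']' delimiter
        panic(err)
    }
}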
|
@ -1,218 +0,0 @@
|
|||
// Copyright 2016 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package json
|
||||
|
||||
import "unicode/utf8"
|
||||
|
||||
// safeSet holds the value true if the ASCII character with the given array
|
||||
// position can be represented inside a JSON string without any further
|
||||
// escaping.
|
||||
//
|
||||
// All values are true except for the ASCII control characters (0-31), the
|
||||
// double quote ("), and the backslash character ("\").
|
||||
var safeSet = [utf8.RuneSelf]bool{
|
||||
' ': true,
|
||||
'!': true,
|
||||
'"': false,
|
||||
'#': true,
|
||||
'$': true,
|
||||
'%': true,
|
||||
'&': true,
|
||||
'\'': true,
|
||||
'(': true,
|
||||
')': true,
|
||||
'*': true,
|
||||
'+': true,
|
||||
',': true,
|
||||
'-': true,
|
||||
'.': true,
|
||||
'/': true,
|
||||
'0': true,
|
||||
'1': true,
|
||||
'2': true,
|
||||
'3': true,
|
||||
'4': true,
|
||||
'5': true,
|
||||
'6': true,
|
||||
'7': true,
|
||||
'8': true,
|
||||
'9': true,
|
||||
':': true,
|
||||
';': true,
|
||||
'<': true,
|
||||
'=': true,
|
||||
'>': true,
|
||||
'?': true,
|
||||
'@': true,
|
||||
'A': true,
|
||||
'B': true,
|
||||
'C': true,
|
||||
'D': true,
|
||||
'E': true,
|
||||
'F': true,
|
||||
'G': true,
|
||||
'H': true,
|
||||
'I': true,
|
||||
'J': true,
|
||||
'K': true,
|
||||
'L': true,
|
||||
'M': true,
|
||||
'N': true,
|
||||
'O': true,
|
||||
'P': true,
|
||||
'Q': true,
|
||||
'R': true,
|
||||
'S': true,
|
||||
'T': true,
|
||||
'U': true,
|
||||
'V': true,
|
||||
'W': true,
|
||||
'X': true,
|
||||
'Y': true,
|
||||
'Z': true,
|
||||
'[': true,
|
||||
'\\': false,
|
||||
']': true,
|
||||
'^': true,
|
||||
'_': true,
|
||||
'`': true,
|
||||
'a': true,
|
||||
'b': true,
|
||||
'c': true,
|
||||
'd': true,
|
||||
'e': true,
|
||||
'f': true,
|
||||
'g': true,
|
||||
'h': true,
|
||||
'i': true,
|
||||
'j': true,
|
||||
'k': true,
|
||||
'l': true,
|
||||
'm': true,
|
||||
'n': true,
|
||||
'o': true,
|
||||
'p': true,
|
||||
'q': true,
|
||||
'r': true,
|
||||
's': true,
|
||||
't': true,
|
||||
'u': true,
|
||||
'v': true,
|
||||
'w': true,
|
||||
'x': true,
|
||||
'y': true,
|
||||
'z': true,
|
||||
'{': true,
|
||||
'|': true,
|
||||
'}': true,
|
||||
'~': true,
|
||||
'\u007f': true,
|
||||
}
|
||||
|
||||
// htmlSafeSet holds the value true if the ASCII character with the given
|
||||
// array position can be safely represented inside a JSON string, embedded
|
||||
// inside of HTML <script> tags, without any additional escaping.
|
||||
//
|
||||
// All values are true except for the ASCII control characters (0-31), the
|
||||
// double quote ("), the backslash character ("\"), HTML opening and closing
|
||||
// tags ("<" and ">"), and the ampersand ("&").
|
||||
var htmlSafeSet = [utf8.RuneSelf]bool{
|
||||
' ': true,
|
||||
'!': true,
|
||||
'"': false,
|
||||
'#': true,
|
||||
'$': true,
|
||||
'%': true,
|
||||
'&': false,
|
||||
'\'': true,
|
||||
'(': true,
|
||||
')': true,
|
||||
'*': true,
|
||||
'+': true,
|
||||
',': true,
|
||||
'-': true,
|
||||
'.': true,
|
||||
'/': true,
|
||||
'0': true,
|
||||
'1': true,
|
||||
'2': true,
|
||||
'3': true,
|
||||
'4': true,
|
||||
'5': true,
|
||||
'6': true,
|
||||
'7': true,
|
||||
'8': true,
|
||||
'9': true,
|
||||
':': true,
|
||||
';': true,
|
||||
'<': false,
|
||||
'=': true,
|
||||
'>': false,
|
||||
'?': true,
|
||||
'@': true,
|
||||
'A': true,
|
||||
'B': true,
|
||||
'C': true,
|
||||
'D': true,
|
||||
'E': true,
|
||||
'F': true,
|
||||
'G': true,
|
||||
'H': true,
|
||||
'I': true,
|
||||
'J': true,
|
||||
'K': true,
|
||||
'L': true,
|
||||
'M': true,
|
||||
'N': true,
|
||||
'O': true,
|
||||
'P': true,
|
||||
'Q': true,
|
||||
'R': true,
|
||||
'S': true,
|
||||
'T': true,
|
||||
'U': true,
|
||||
'V': true,
|
||||
'W': true,
|
||||
'X': true,
|
||||
'Y': true,
|
||||
'Z': true,
|
||||
'[': true,
|
||||
'\\': false,
|
||||
']': true,
|
||||
'^': true,
|
||||
'_': true,
|
||||
'`': true,
|
||||
'a': true,
|
||||
'b': true,
|
||||
'c': true,
|
||||
'd': true,
|
||||
'e': true,
|
||||
'f': true,
|
||||
'g': true,
|
||||
'h': true,
|
||||
'i': true,
|
||||
'j': true,
|
||||
'k': true,
|
||||
'l': true,
|
||||
'm': true,
|
||||
'n': true,
|
||||
'o': true,
|
||||
'p': true,
|
||||
'q': true,
|
||||
'r': true,
|
||||
's': true,
|
||||
't': true,
|
||||
'u': true,
|
||||
'v': true,
|
||||
'w': true,
|
||||
'x': true,
|
||||
'y': true,
|
||||
'z': true,
|
||||
'{': true,
|
||||
'|': true,
|
||||
'}': true,
|
||||
'~': true,
|
||||
'\u007f': true,
|
||||
}
|
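The difference between safeSet and htmlSafeSet is exactly what Encoder.SetEscapeHTML toggles; a standard-library sketch of the observable effect:

package main

import (
    "bytes"
    "encoding/json"
    "fmt"
)

func main() {
    value := map[string]string{"html": "<b>&</b>"}

    var escaped bytes.Buffer
    if err := json.NewEncoder(&escaped).Encode(value); err != nil {
        panic(err)
    }
    fmt.Print(escaped.String()) // {"html":"\u003cb\u003e\u0026\u003c/b\u003e"} (htmlSafeSet)

    var plain bytes.Buffer
    enc := json.NewEncoder(&plain)
    enc.SetEscapeHTML(false) // only safeSet applies: <, > and & stay literal
    if err := enc.Encode(value); err != nil {
        panic(err)
    }
    fmt.Print(plain.String()) // {"html":"<b>&</b>"}
}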
|
@ -1,38 +0,0 @@
|
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package json

import (
    "strings"
)

// tagOptions is the string following a comma in a struct field's "json"
// tag, or the empty string. It does not include the leading comma.
type tagOptions string

// parseTag splits a struct field's json tag into its name and
// comma-separated options.
func parseTag(tag string) (string, tagOptions) {
    tag, opt, _ := strings.Cut(tag, ",")
    return tag, tagOptions(opt)
}

// Contains reports whether a comma-separated list of options
// contains a particular substr flag. substr must be surrounded by a
// string boundary or commas.
func (o tagOptions) Contains(optionName string) bool {
    if len(o) == 0 {
        return false
    }
    s := string(o)
    for s != "" {
        var name string
        name, s, _ = strings.Cut(s, ",")
        if name == optionName {
            return true
        }
    }
    return false
}
|
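In practice parseTag and Contains split a struct tag such as `json:"server_port,omitempty"` into a name and an option list; a minimal sketch of the same strings.Cut logic (the field name is invented for the example):

package main

import (
    "fmt"
    "strings"
)

func main() {
    tag := "server_port,omitempty,string"

    name, opts, _ := strings.Cut(tag, ",")
    fmt.Println(name) // server_port

    // tagOptions.Contains scans the remaining comma-separated options one by one.
    for s := opts; s != ""; {
        var opt string
        opt, s, _ = strings.Cut(s, ",")
        fmt.Println("option:", opt) // omitempty, then string
    }
}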
|
@ -2,6 +2,8 @@ package json
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
|
@ -9,13 +11,18 @@ import (
|
|||
)
|
||||
|
||||
func UnmarshalExtended[T any](content []byte) (T, error) {
|
||||
decoder := NewDecoder(NewCommentFilter(bytes.NewReader(content)))
|
||||
return UnmarshalExtendedContext[T](context.Background(), content)
|
||||
}
|
||||
|
||||
func UnmarshalExtendedContext[T any](ctx context.Context, content []byte) (T, error) {
|
||||
decoder := NewDecoderContext(ctx, NewCommentFilter(bytes.NewReader(content)))
|
||||
var value T
|
||||
err := decoder.Decode(&value)
|
||||
if err == nil {
|
||||
return value, err
|
||||
}
|
||||
if syntaxError, isSyntaxError := err.(*SyntaxError); isSyntaxError {
|
||||
var syntaxError *SyntaxError
|
||||
if errors.As(err, &syntaxError) {
|
||||
prefix := string(content[:syntaxError.Offset])
|
||||
row := strings.Count(prefix, "\n") + 1
|
||||
column := len(prefix) - strings.LastIndex(prefix, "\n") - 1
|
||||
|
|
common/json/unmarshal_context.go (new file, 9 lines)
|
@ -0,0 +1,9 @@
//go:build go1.20 && !without_contextjson

package json

import (
    json "github.com/sagernet/sing/common/json/internal/contextjson"
)

var UnmarshalDisallowUnknownFields = json.UnmarshalDisallowUnknownFields
|
common/json/unmarshal_std.go (new file, 13 lines)
|
@ -0,0 +1,13 @@
|
//go:build !go1.20 || without_contextjson

package json

import (
    "bytes"
)

func UnmarshalDisallowUnknownFields(content []byte, value any) error {
    decoder := NewDecoder(bytes.NewReader(content))
    decoder.DisallowUnknownFields()
    return decoder.Decode(value)
}
|
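Both build variants expose the same helper; a usage sketch (the Options struct and its listen field are invented for illustration, and the exact error text depends on which variant is compiled in):

package main

import (
    "fmt"

    "github.com/sagernet/sing/common/json"
)

type Options struct {
    Listen string `json:"listen"`
}

func main() {
    content := []byte(`{"listen": "::", "listen_port": 443}`)

    var options Options
    err := json.UnmarshalDisallowUnknownFields(content, &options)
    fmt.Println(err) // reports the unknown key "listen_port" instead of silently dropping it
}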
|
@ -1,5 +1,6 @@
|
|||
package metadata
|
||||
|
||||
// Deprecated: wtf is this?
|
||||
type Metadata struct {
|
||||
Protocol string
|
||||
Source Socksaddr
|
||||
|
|
|
@ -8,7 +8,6 @@ import (
|
|||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/rw"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -116,7 +115,7 @@ func (s *Serializer) WriteAddrPort(writer io.Writer, destination Socksaddr) erro
|
|||
return err
|
||||
}
|
||||
if !isBuffer {
|
||||
err = rw.WriteBytes(writer, buffer.Bytes())
|
||||
err = common.Error(writer.Write(buffer.Bytes()))
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
@ -130,7 +129,8 @@ func (s *Serializer) AddrPortLen(destination Socksaddr) int {
|
|||
}
|
||||
|
||||
func (s *Serializer) ReadAddress(reader io.Reader) (Socksaddr, error) {
|
||||
af, err := rw.ReadByte(reader)
|
||||
var af byte
|
||||
err := binary.Read(reader, binary.BigEndian, &af)
|
||||
if err != nil {
|
||||
return Socksaddr{}, err
|
||||
}
|
||||
|
@ -164,11 +164,12 @@ func (s *Serializer) ReadAddress(reader io.Reader) (Socksaddr, error) {
|
|||
}
|
||||
|
||||
func (s *Serializer) ReadPort(reader io.Reader) (uint16, error) {
|
||||
port, err := rw.ReadBytes(reader, 2)
|
||||
var port uint16
|
||||
err := binary.Read(reader, binary.BigEndian, &port)
|
||||
if err != nil {
|
||||
return 0, E.Cause(err, "read port")
|
||||
}
|
||||
return binary.BigEndian.Uint16(port), nil
|
||||
return port, nil
|
||||
}
|
||||
|
||||
func (s *Serializer) ReadAddrPort(reader io.Reader) (destination Socksaddr, err error) {
|
||||
|
@ -195,11 +196,17 @@ func (s *Serializer) ReadAddrPort(reader io.Reader) (destination Socksaddr, err
|
|||
}
|
||||
|
||||
func ReadSockString(reader io.Reader) (string, error) {
|
||||
strLen, err := rw.ReadByte(reader)
|
||||
var strLen byte
|
||||
err := binary.Read(reader, binary.BigEndian, &strLen)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return rw.ReadString(reader, int(strLen))
|
||||
strBytes := make([]byte, strLen)
|
||||
_, err = io.ReadFull(reader, strBytes)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(strBytes), nil
|
||||
}
|
||||
|
||||
func WriteSocksString(buffer *buf.Buffer, str string) error {
|
||||
|
|
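The switch from the rw helpers to binary.Read in ReadPort comes down to this pattern (standalone standard-library sketch):

package main

import (
    "bytes"
    "encoding/binary"
    "fmt"
)

func main() {
    // A SOCKS-style port is two big-endian bytes: 0x01 0xBB is 443.
    reader := bytes.NewReader([]byte{0x01, 0xBB})

    var port uint16
    if err := binary.Read(reader, binary.BigEndian, &port); err != nil {
        panic(err)
    }
    fmt.Println(port) // 443
}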
common/minmax.go (new file, 15 lines)
|
@ -0,0 +1,15 @@
|
//go:build go1.21

package common

import (
    "cmp"
)

func Min[T cmp.Ordered](x, y T) T {
    return min(x, y)
}

func Max[T cmp.Ordered](x, y T) T {
    return max(x, y)
}
|
common/minmax_compat.go (new file, 19 lines)
|
@ -0,0 +1,19 @@
|
//go:build go1.20 && !go1.21

package common

import "github.com/sagernet/sing/common/x/constraints"

func Min[T constraints.Ordered](x, y T) T {
    if x < y {
        return x
    }
    return y
}

func Max[T constraints.Ordered](x, y T) T {
    if x < y {
        return y
    }
    return x
}
|
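Both minmax variants present the same generic call shape; a usage sketch (assuming the module's common package import path):

package main

import (
    "fmt"

    "github.com/sagernet/sing/common"
)

func main() {
    fmt.Println(common.Min(3, 5))           // 3
    fmt.Println(common.Max(1.5, 2.5))       // 2.5
    fmt.Println(common.Max("ipv4", "ipv6")) // ipv6: strings are ordered too
}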
|
@ -4,6 +4,7 @@ import (
|
|||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
|
@ -70,8 +71,39 @@ type ExtendedConn interface {
|
|||
net.Conn
|
||||
}
|
||||
|
||||
type CloseHandlerFunc = func(it error)
|
||||
|
||||
func AppendClose(parent CloseHandlerFunc, onClose CloseHandlerFunc) CloseHandlerFunc {
|
||||
if onClose == nil {
|
||||
panic("nil onClose")
|
||||
}
|
||||
if parent == nil {
|
||||
return onClose
|
||||
}
|
||||
return func(it error) {
|
||||
onClose(it)
|
||||
parent(it)
|
||||
}
|
||||
}
|
||||
|
||||
func OnceClose(onClose CloseHandlerFunc) CloseHandlerFunc {
|
||||
var once sync.Once
|
||||
return func(it error) {
|
||||
once.Do(func() {
|
||||
onClose(it)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Deprecated: Use TCPConnectionHandlerEx instead.
|
||||
type TCPConnectionHandler interface {
|
||||
NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error
|
||||
NewConnection(ctx context.Context, conn net.Conn,
|
||||
//nolint:staticcheck
|
||||
metadata M.Metadata) error
|
||||
}
|
||||
|
||||
type TCPConnectionHandlerEx interface {
|
||||
NewConnectionEx(ctx context.Context, conn net.Conn, source M.Socksaddr, destination M.Socksaddr, onClose CloseHandlerFunc)
|
||||
}
|
||||
|
||||
type NetPacketConn interface {
|
||||
|
@ -85,12 +117,26 @@ type BindPacketConn interface {
|
|||
net.Conn
|
||||
}
|
||||
|
||||
// Deprecated: Use UDPHandlerEx instead.
|
||||
type UDPHandler interface {
|
||||
NewPacket(ctx context.Context, conn PacketConn, buffer *buf.Buffer, metadata M.Metadata) error
|
||||
NewPacket(ctx context.Context, conn PacketConn, buffer *buf.Buffer,
|
||||
//nolint:staticcheck
|
||||
metadata M.Metadata) error
|
||||
}
|
||||
|
||||
type UDPHandlerEx interface {
|
||||
NewPacketEx(buffer *buf.Buffer, source M.Socksaddr)
|
||||
}
|
||||
|
||||
// Deprecated: Use UDPConnectionHandlerEx instead.
|
||||
type UDPConnectionHandler interface {
|
||||
NewPacketConnection(ctx context.Context, conn PacketConn, metadata M.Metadata) error
|
||||
NewPacketConnection(ctx context.Context, conn PacketConn,
|
||||
//nolint:staticcheck
|
||||
metadata M.Metadata) error
|
||||
}
|
||||
|
||||
type UDPConnectionHandlerEx interface {
|
||||
NewPacketConnectionEx(ctx context.Context, conn PacketConn, source M.Socksaddr, destination M.Socksaddr, onClose CloseHandlerFunc)
|
||||
}
|
||||
|
||||
type CachedReader interface {
|
||||
|
@ -101,11 +147,6 @@ type CachedPacketReader interface {
|
|||
ReadCachedPacket() *PacketBuffer
|
||||
}
|
||||
|
||||
type PacketBuffer struct {
|
||||
Buffer *buf.Buffer
|
||||
Destination M.Socksaddr
|
||||
}
|
||||
|
||||
type WithUpstreamReader interface {
|
||||
UpstreamReader() any
|
||||
}
|
||||
|
|
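The new CloseHandlerFunc helpers compose as shown below; a sketch assuming the usual N alias for the common/network package (both callbacks are invented for the example):

package main

import (
    "fmt"

    N "github.com/sagernet/sing/common/network"
)

func main() {
    logClose := func(err error) {
        fmt.Println("connection closed:", err)
    }
    releaseResources := func(err error) {
        fmt.Println("resources released")
    }

    // AppendClose runs the new handler first and the parent second;
    // OnceClose makes sure the combined handler fires at most once.
    onClose := N.OnceClose(N.AppendClose(logClose, releaseResources))

    onClose(nil)
    onClose(nil) // second call is a no-op
}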
|
@ -13,10 +13,6 @@ type Dialer interface {
|
|||
ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error)
|
||||
}
|
||||
|
||||
type PayloadDialer interface {
|
||||
DialPayloadContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error)
|
||||
}
|
||||
|
||||
type ParallelDialer interface {
|
||||
Dialer
|
||||
DialParallel(ctx context.Context, network string, destination M.Socksaddr, destinationAddresses []netip.Addr) (net.Conn, error)
|
||||
|
|
Some files were not shown because too many files have changed in this diff.