Compare commits

...

136 commits
v0.3.1 ... dev

Author SHA1 Message Date
世界
159e489fc3
Add E.Cause1 2025-04-03 18:26:45 +08:00
世界
d39c2c2fdd
socks: Add custom udp listener 2025-03-26 13:18:24 +08:00
世界
ea82ac275f
Add freelru.GetWithLifetimeNoExpire 2025-03-26 13:18:18 +08:00
世界
ea0ac932ae
Add winiphlpapi 2025-03-26 13:18:17 +08:00
世界
2b41455f5a
Fix udpnat2 handler again 2025-03-26 12:46:15 +08:00
世界
23b0180a1b
Fix crash on udpnat2 handler 2025-03-24 18:11:10 +08:00
世界
ce1b4851a4
Fix socks5 UDP 2025-03-16 10:23:29 +08:00
世界
2238a05966
Fix merge objects 2025-03-16 10:23:29 +08:00
世界
b55d1c78b3
bufio: Add destination NAT packet conn 2025-03-09 15:20:32 +08:00
世界
d54716612c
Fix syscall packet read waiter for Windows 2025-02-28 12:07:45 +08:00
世界
9eafc7fc62
udpnat2: Fix crash 2025-02-10 15:08:18 +08:00
世界
d8153df67f
Add ENOTCONN to IsClosed 2025-02-06 08:41:32 +08:00
世界
d9f6eb136d
Fix set windows system time 2025-01-09 23:30:25 +08:00
世界
4dabb9be97
freelru: Fix GetAndRefreshOrAdd 2025-01-09 15:59:26 +08:00
世界
be9840c70f
listable: Fix incorrect unmarshaling of null to []T{null} 2025-01-09 15:57:12 +08:00
世界
aa7d2543a3
Fix errors usage 2024-12-16 09:20:34 +08:00
世界
33beacc053
Fix socks5 UDP handshake 2024-12-14 18:16:15 +08:00
世界
442cceb9fa
Fix disable UDP fragment 2024-12-12 20:43:56 +08:00
世界
3374a45475
Fix socks5 UDP implementation 2024-12-10 19:53:57 +08:00
世界
73776cf797
Fix lru test 2024-12-10 19:42:55 +08:00
世界
957166799e
Fix CloseOnHandshakeFailure 2024-12-04 17:14:58 +08:00
世界
809d8eca13
freelru: fix PurgeExpired 2024-12-04 11:36:20 +08:00
世界
9f69e7f9f7
E: IsClosedOrCanceled check IsTimeout 2024-12-01 20:19:37 +08:00
世界
478265cd45
badoption: Finish netip options 2024-12-01 14:33:23 +08:00
世界
3f30aaf25e
freelru: purge all expired items 2024-11-30 16:06:59 +08:00
世界
39040e06dc
udpnat2: Fix concurrency 2024-11-28 13:51:17 +08:00
世界
6edd2ce0ea
freelru: Update source and add GetAndRefreshOrAdd 2024-11-28 13:51:17 +08:00
世界
0a2e2a3eaf
udpnat2: Fix timeout 2024-11-27 18:02:22 +08:00
世界
4ba1eb123c
Fix set timeout 2024-11-27 17:28:18 +08:00
世界
c44912a861
freelru: Fix purge 2024-11-27 13:51:08 +08:00
世界
a8f5bf4eb0
udpnat2: Add timeout check 2024-11-26 19:08:35 +08:00
世界
30e9d91b57
Fix AppendClose 2024-11-26 12:21:37 +08:00
世界
7fd3517e4d
udpnat2: Add purge expire ticker 2024-11-26 12:21:37 +08:00
世界
a8285e06a5
udpnat2: Implement set timeout for nat conn 2024-11-26 12:21:37 +08:00
世界
3613ead480
freelru: Add PeekWithLifetime and UpdateLifetime 2024-11-26 11:29:14 +08:00
世界
c8f251c668
Fix copy count 2024-11-24 19:02:21 +08:00
世界
fa5355e99e
bufio: more copy funcs 2024-11-20 11:27:20 +08:00
世界
30fbafd954
udpnat2: Add cache funcs 2024-11-18 12:14:35 +08:00
世界
fdca9b3f8e
badjson: Fix Listable 2024-11-16 16:03:00 +08:00
世界
e52e04f721
Fix HandshakeFailure usages 2024-11-15 16:27:03 +08:00
世界
7f621fdd78
Add freelru.SetUpdateLifetimeOnGet/GetWithLifetime 2024-11-14 17:49:49 +08:00
世界
ae139d9ee1
Update N.PayloadDialer 2024-11-14 17:49:49 +08:00
世界
c432befd02
http: Fix proxying websocket 2024-11-13 19:02:07 +08:00
世界
cc7e630923
control: Refactor interface finder 2024-11-12 20:15:50 +08:00
世界
0998999911
udpnat2: Fix missing shared impl 2024-11-09 11:40:27 +08:00
世界
72ff654ee0
shared: Add SetHealthCheck to interface 2024-11-09 11:40:27 +08:00
世界
11ffb962ae
freelru: Fix impl 2024-11-09 11:40:27 +08:00
世界
fcb19641e6
freelru: Copy shared source 2024-11-09 11:40:27 +08:00
世界
524a6bd0d1
udpnat2: Set upstream to writer 2024-11-09 11:40:27 +08:00
世界
b5f9e70ffd
badjson: Fix Listable 2024-11-09 11:40:27 +08:00
世界
c80c8f907c
badjson: Add context marshaler/unmarshaler 2024-11-05 18:43:05 +08:00
世界
a4eb7fa900
udpnat2: Add SetHandler 2024-11-05 18:43:05 +08:00
世界
7ec09d6045
udpnat2: New synced udp nat service 2024-11-05 18:43:04 +08:00
世界
0641c71805
maphash: copy source from v0.1.0 2024-11-05 18:43:04 +08:00
世界
e7ec021b81
freelru: copy source from v0.14.0 2024-11-05 18:43:04 +08:00
世界
0f2447a95b
Crazy sekai overturns the small pond 2024-11-05 18:43:04 +08:00
世界
72db784fc7
Add bind.Interface.Flags 2024-11-04 11:05:38 +08:00
世界
d59ac57aaa
Add go1.21 compat funcs 2024-10-19 09:09:15 +08:00
世界
c63546470b
Add Update() error to control.InterfaceFinder 2024-09-22 22:15:12 +08:00
世界
55908bea36
Update linter configuration 2024-09-14 21:36:41 +08:00
世界
6567829958
Fix cached conn eats up read deadlines 2024-09-14 10:11:50 +08:00
世界
c324d4143d
json: Add badoption templates 2024-09-10 23:57:22 +08:00
世界
0acb36c118
Minor fixes 2024-09-10 23:46:03 +08:00
世界
26511a251f
udpnat: Fix read deadline not initialized 2024-08-19 17:56:31 +08:00
世界
afd8993773
windnsapi: Fix incorrect error checking 2024-08-19 17:47:42 +08:00
世界
96bef0733f
Fix bad group usages 2024-08-18 11:15:20 +08:00
世界
ec1df651e8
Update golangci-lint configuration 2024-08-18 11:15:15 +08:00
世界
e33b1d67d5
bufio: Add ReadBufferSize and ReadPacketSize 2024-08-18 09:14:52 +08:00
世界
ed6cde73f7
udpnat: Implement read deadline 2024-08-18 09:14:44 +08:00
世界
73cc65605e
pipe: Make pipeDeadline public for use 2024-08-18 08:59:04 +08:00
世界
6c19e0736d
windows: Migrate to mkwinsyscall 2024-08-09 12:00:21 +08:00
世界
08e8c02fb1
Fix usage of PowerUnregisterSuspendResumeNotification 2024-08-08 15:27:35 +08:00
世界
7beca62e4f
Improve winpowrprof callback 2024-08-06 13:19:09 +08:00
世界
e422e3d048
Reuse winpowrprof callback 2024-08-06 12:23:56 +08:00
世界
fa81eabc29
ntp: Fix a bad context usage 2024-08-06 12:23:56 +08:00
世界
4498e57839
task: Fix context not continuous 2024-07-31 18:06:42 +08:00
世界
f97054e917
ntp: Ignore setup error 2024-07-31 10:25:01 +08:00
世界
a2f9fef936
domain: Add adguard matcher 2024-07-26 08:00:09 +08:00
世界
7893a74f75
json: Add UnmarshalDisallowUnknownFields 2024-07-22 12:02:46 +08:00
世界
332e470075
domain: Add a new label type for domain suffix 2024-07-17 15:55:30 +08:00
世界
2bf9cc7253
Remove unused legacy variable 2024-07-17 15:54:12 +08:00
世界
bf8fc103a4
Fix golangci-lint configuration 2024-07-17 15:53:40 +08:00
世界
774893928c
Add crazy badges to README 2024-07-02 16:39:51 +08:00
renovate[bot]
7ceaf63d41
[dependencies] Update github-actions
Co-authored-by: renovate[bot] <29139614+renovate[bot]@users.noreply.github.com>
2024-07-02 16:32:54 +08:00
世界
8806e421f2
Fix renovate configuration 2024-07-02 16:29:55 +08:00
世界
c37f988a4f
Add dump for domain matcher 2024-07-02 15:49:37 +08:00
世界
0b4c0a1283
varbin: Unique pointer value format for array or map value 2024-06-24 19:51:17 +08:00
世界
4745c34b4c
varbin: Accept fixed slices 2024-06-24 13:51:49 +08:00
世界
e0196407a3
Remove bad rw usages 2024-06-24 09:42:23 +08:00
世界
d8ec9c46cc
Deprecated bad rw funcs 2024-06-24 09:42:23 +08:00
世界
a33349366d
Remove bad linkname usage 2024-06-24 09:42:23 +08:00
世界
3155c16990
Use varbin for domain matcher read write 2024-06-24 09:42:23 +08:00
世界
caa4340dc9
binary: Move varint utils to new package 2024-06-24 09:42:23 +08:00
世界
a31dba8ad2
bianry: Improve varint read and write 2024-06-23 14:07:07 +08:00
世界
9571124cf4
badjson: Add disableAppend option 2024-06-22 20:07:31 +08:00
世界
1c495c9b07
Make RoutingMark uint32 2024-06-17 12:38:23 +08:00
wwqgtxx
0f95dfe0e3
Recover context json for go1.20 2024-06-17 12:38:22 +08:00
世界
f3380c8dfe
Drop support for go1.18 and go1.19 2024-06-17 12:38:22 +08:00
世界
ab4353dd13
binary: Add variant data reader/writer 2024-06-17 12:38:22 +08:00
世界
e0e490af7b
binary: init from encoding/binary 2024-06-17 12:38:22 +08:00
世界
ad4d59e2ed
Add control.Interface.HardwareAddr 2024-06-17 12:38:21 +08:00
世界
aca2a85545
Add redirect and tproxy controls 2024-06-17 12:38:21 +08:00
世界
589c7eb4df
contentjson: Add support for tailing comma 2024-06-17 12:38:21 +08:00
世界
2873799b6d
Drop context json for go1.20 2024-06-17 12:38:21 +08:00
世界
47cc308abf
Add InterfaceFinder.Interfaces() []control.Interface 2024-06-17 12:38:21 +08:00
HystericalDragon
d9f2559214
Fix http socks close 2024-06-17 12:37:52 +08:00
世界
b8736cc58d
Update workflows 2024-06-11 21:13:49 +08:00
世界
e82ff8e2e6
Fix get source address from X-Forwarded-For 2024-06-06 22:18:55 +08:00
世界
ba68e017a9
Fix PrefixFromNet 2024-06-06 22:18:20 +08:00
世界
967afcf6c1
Update dependencies 2024-06-06 22:18:13 +08:00
世界
0c110ad733
Update dependencies 2024-05-27 19:35:12 +08:00
世界
de1b0bd772
Fix Fix socks5 packet conn 2024-05-21 15:09:17 +08:00
wwqgtxx
f67a0988a6
Remove linkname usages of x/sys/windows 2024-05-18 20:51:36 +08:00
世界
e0ee7f49e2
Remove disallowed linkname usages 2024-05-18 13:20:31 +08:00
dyhkwong
284cb5ce98
Fix socks5 packet conn 2024-05-17 21:24:05 +08:00
世界
8fb1634c9a
Fix calculate reader headroom 2024-05-17 13:20:35 +08:00
世界
4ab8cac5eb
Update dependencies 2024-04-23 23:36:51 +08:00
世界
eec2fc325a
Improve interface finder 2024-04-12 09:05:39 +08:00
世界
2fa039945c
Add control.SetKeepAlivePeriod 2024-04-10 20:25:47 +08:00
世界
e5825dcb59
Merge time service to library 2024-04-10 20:25:47 +08:00
wwqgtxx
6b73a57a24
Add winpowrprof package 2024-04-10 20:25:46 +08:00
世界
3e2631ef0b
badjson: Improve omitempty 2024-04-10 20:25:46 +08:00
世界
4d96f15eca
Improve domain suffix match behavior 2024-04-10 20:25:46 +08:00
dyhkwong
f9c59e9940
Improve bufio.NATPacketConn 2024-04-10 20:24:58 +08:00
dyhkwong
8b68fc4d7a
Fix canceler.PacketConn 2024-04-10 20:17:34 +08:00
世界
5bfc326913
Fix syscall packet read waiter 2024-03-25 01:47:03 +08:00
世界
04152ea672
Fix canceler 2024-03-24 00:12:02 +08:00
wwqgtxx
a069af4787
Fix TypedValue 2024-03-13 16:07:00 +08:00
世界
807a51bb81
Fix crash in HTTP server again 2024-03-10 16:52:22 +08:00
DuFoxit
ec2595f010
Fix SwitchyOmega authentication failed 2024-03-05 13:04:20 +08:00
世界
c98e8b6921
Update dependencies 2024-03-05 12:51:24 +08:00
世界
a4a9ec42c6
Fix crash in HTTP server 2024-03-02 14:25:20 +08:00
世界
6e3921083b
Fix SO_BINDTOIFINDEX usage 2024-02-29 12:56:57 +08:00
世界
8e89f9b4dc
Update workflow 2024-02-29 12:54:19 +08:00
世界
5f02cb1cff
Fix HTTP server authenticate 2024-02-22 20:46:01 +08:00
世界
ef00a1ec1e
Fix badjson merge 2024-02-22 17:27:13 +08:00
189 changed files with 8980 additions and 5603 deletions

View file

@ -1,11 +1,13 @@
{
"$schema": "https://docs.renovatebot.com/renovate-schema.json",
"commitMessagePrefix": "[dependencies]",
"branchName": "main",
"extends": [
"config:base",
":disableRateLimiting"
],
"baseBranches": [
"dev"
],
"packageRules": [
{
"matchManagers": [
@ -15,9 +17,9 @@
},
{
"matchManagers": [
"gomod"
"dockerfile"
],
"groupName": "gomod"
"groupName": "Dockerfile"
}
]
}

View file

@ -1,8 +1,9 @@
name: Lint
name: lint
on:
push:
branches:
- main
- dev
paths-ignore:
- '**.md'
@ -10,6 +11,7 @@ on:
- '!.github/workflows/lint.yml'
pull_request:
branches:
- main
- dev
jobs:
@ -22,16 +24,16 @@ jobs:
with:
fetch-depth: 0
- name: Setup Go
uses: actions/setup-go@v4
uses: actions/setup-go@v5
with:
go-version: ^1.21
go-version: ^1.23
- name: Cache go module
uses: actions/cache@v3
uses: actions/cache@v4
with:
path: |
~/go/pkg/mod
key: go-${{ hashFiles('**/go.sum') }}
- name: golangci-lint
uses: golangci/golangci-lint-action@v3
uses: golangci/golangci-lint-action@v6
with:
version: latest

View file

@ -1,8 +1,9 @@
name: Debug build
name: test
on:
push:
branches:
- main
- dev
paths-ignore:
- '**.md'
@ -10,11 +11,12 @@ on:
- '!.github/workflows/debug.yml'
pull_request:
branches:
- main
- dev
jobs:
build:
name: Linux Debug build
name: Linux
runs-on: ubuntu-latest
steps:
- name: Checkout
@ -22,55 +24,14 @@ jobs:
with:
fetch-depth: 0
- name: Setup Go
uses: actions/setup-go@v4
uses: actions/setup-go@v5
with:
go-version: ^1.21
- name: Add cache to Go proxy
run: |
version=`git rev-parse HEAD`
mkdir build
pushd build
go mod init build
go get -v github.com/sagernet/sing@$version
popd
continue-on-error: true
- name: Build
run: |
make test
build_go118:
name: Linux Debug build (Go 1.18)
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Setup Go
uses: actions/setup-go@v4
with:
go-version: ~1.18
continue-on-error: true
- name: Build
run: |
make test
build_go119:
name: Linux Debug build (Go 1.19)
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Setup Go
uses: actions/setup-go@v4
with:
go-version: ~1.19
continue-on-error: true
go-version: ^1.23
- name: Build
run: |
make test
build_go120:
name: Linux Debug build (Go 1.20)
name: Linux (Go 1.20)
runs-on: ubuntu-latest
steps:
- name: Checkout
@ -78,15 +39,47 @@ jobs:
with:
fetch-depth: 0
- name: Setup Go
uses: actions/setup-go@v4
uses: actions/setup-go@v5
with:
go-version: ~1.20
continue-on-error: true
- name: Build
run: |
make test
build__windows:
name: Windows Debug build
build_go121:
name: Linux (Go 1.21)
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Setup Go
uses: actions/setup-go@v5
with:
go-version: ~1.21
continue-on-error: true
- name: Build
run: |
make test
build_go122:
name: Linux (Go 1.22)
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Setup Go
uses: actions/setup-go@v5
with:
go-version: ~1.22
continue-on-error: true
- name: Build
run: |
make test
build_windows:
name: Windows
runs-on: windows-latest
steps:
- name: Checkout
@ -94,15 +87,15 @@ jobs:
with:
fetch-depth: 0
- name: Setup Go
uses: actions/setup-go@v4
uses: actions/setup-go@v5
with:
go-version: ^1.21
go-version: ^1.23
continue-on-error: true
- name: Build
run: |
make test
build_darwin:
name: macOS Debug build
name: macOS
runs-on: macos-latest
steps:
- name: Checkout
@ -110,9 +103,9 @@ jobs:
with:
fetch-depth: 0
- name: Setup Go
uses: actions/setup-go@v4
uses: actions/setup-go@v5
with:
go-version: ^1.21
go-version: ^1.23
continue-on-error: true
- name: Build
run: |

View file

@ -5,6 +5,8 @@ linters:
- govet
- gci
- staticcheck
- paralleltest
- ineffassign
linters-settings:
gci:
@ -14,4 +16,9 @@ linters-settings:
- prefix(github.com/sagernet/)
- default
staticcheck:
go: '1.20'
checks:
- all
- -SA1003
run:
go: "1.23"

View file

@ -8,14 +8,14 @@ fmt_install:
go install -v github.com/daixiang0/gci@latest
lint:
GOOS=linux golangci-lint run ./...
GOOS=android golangci-lint run ./...
GOOS=windows golangci-lint run ./...
GOOS=darwin golangci-lint run ./...
GOOS=freebsd golangci-lint run ./...
GOOS=linux golangci-lint run
GOOS=android golangci-lint run
GOOS=windows golangci-lint run
GOOS=darwin golangci-lint run
GOOS=freebsd golangci-lint run
lint_install:
go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest
test:
go test $(shell go list ./... | grep -v /internal/)
go test ./...

View file

@ -1,3 +1,6 @@
# sing
![test](https://github.com/sagernet/sing/actions/workflows/test.yml/badge.svg)
![lint](https://github.com/sagernet/sing/actions/workflows/lint.yml/badge.svg)
Do you hear the people sing?

View file

@ -10,26 +10,37 @@ type TypedValue[T any] struct {
value atomic.Value
}
// typedValue is a struct with determined type to resolve atomic.Value usages with interface types
// https://github.com/golang/go/issues/22550
//
// The intention to have an atomic value store for errors. However, running this code panics:
// panic: sync/atomic: store of inconsistently typed value into Value
// This is because atomic.Value requires that the underlying concrete type be the same (which is a reasonable expectation for its implementation).
// When going through the atomic.Value.Store method call, the fact that both these are of the error interface is lost.
type typedValue[T any] struct {
value T
}
func (t *TypedValue[T]) Load() T {
value := t.value.Load()
if value == nil {
return common.DefaultValue[T]()
}
return value.(T)
return value.(typedValue[T]).value
}
func (t *TypedValue[T]) Store(value T) {
t.value.Store(value)
t.value.Store(typedValue[T]{value})
}
func (t *TypedValue[T]) Swap(new T) T {
old := t.value.Swap(new)
old := t.value.Swap(typedValue[T]{new})
if old == nil {
return common.DefaultValue[T]()
}
return old.(T)
return old.(typedValue[T]).value
}
func (t *TypedValue[T]) CompareAndSwap(old, new T) bool {
return t.value.CompareAndSwap(old, new)
return t.value.CompareAndSwap(typedValue[T]{old}, typedValue[T]{new})
}

View file

@ -2,11 +2,10 @@ package baderror
import (
"context"
"errors"
"io"
"net"
"strings"
E "github.com/sagernet/sing/common/exceptions"
)
func Contains(err error, msgList ...string) bool {
@ -22,8 +21,7 @@ func WrapH2(err error) error {
if err == nil {
return nil
}
err = E.Unwrap(err)
if err == io.ErrUnexpectedEOF {
if errors.Is(err, io.ErrUnexpectedEOF) {
return io.EOF
}
if Contains(err, "client disconnected", "body closed by handler", "response body closed", "; CANCEL") {

3
common/binary/README.md Normal file
View file

@ -0,0 +1,3 @@
# binary
mod from go 1.22.3

817
common/binary/binary.go Normal file
View file

@ -0,0 +1,817 @@
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package binary implements simple translation between numbers and byte
// sequences and encoding and decoding of varints.
//
// Numbers are translated by reading and writing fixed-size values.
// A fixed-size value is either a fixed-size arithmetic
// type (bool, int8, uint8, int16, float32, complex64, ...)
// or an array or struct containing only fixed-size values.
//
// The varint functions encode and decode single integer values using
// a variable-length encoding; smaller values require fewer bytes.
// For a specification, see
// https://developers.google.com/protocol-buffers/docs/encoding.
//
// This package favors simplicity over efficiency. Clients that require
// high-performance serialization, especially for large data structures,
// should look at more advanced solutions such as the [encoding/gob]
// package or [google.golang.org/protobuf] for protocol buffers.
package binary
import (
"errors"
"io"
"math"
"reflect"
"sync"
)
// A ByteOrder specifies how to convert byte slices into
// 16-, 32-, or 64-bit unsigned integers.
//
// It is implemented by [LittleEndian], [BigEndian], and [NativeEndian].
type ByteOrder interface {
Uint16([]byte) uint16
Uint32([]byte) uint32
Uint64([]byte) uint64
PutUint16([]byte, uint16)
PutUint32([]byte, uint32)
PutUint64([]byte, uint64)
String() string
}
// AppendByteOrder specifies how to append 16-, 32-, or 64-bit unsigned integers
// into a byte slice.
//
// It is implemented by [LittleEndian], [BigEndian], and [NativeEndian].
type AppendByteOrder interface {
AppendUint16([]byte, uint16) []byte
AppendUint32([]byte, uint32) []byte
AppendUint64([]byte, uint64) []byte
String() string
}
// LittleEndian is the little-endian implementation of [ByteOrder] and [AppendByteOrder].
var LittleEndian littleEndian
// BigEndian is the big-endian implementation of [ByteOrder] and [AppendByteOrder].
var BigEndian bigEndian
type littleEndian struct{}
func (littleEndian) Uint16(b []byte) uint16 {
_ = b[1] // bounds check hint to compiler; see golang.org/issue/14808
return uint16(b[0]) | uint16(b[1])<<8
}
func (littleEndian) PutUint16(b []byte, v uint16) {
_ = b[1] // early bounds check to guarantee safety of writes below
b[0] = byte(v)
b[1] = byte(v >> 8)
}
func (littleEndian) AppendUint16(b []byte, v uint16) []byte {
return append(b,
byte(v),
byte(v>>8),
)
}
func (littleEndian) Uint32(b []byte) uint32 {
_ = b[3] // bounds check hint to compiler; see golang.org/issue/14808
return uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24
}
func (littleEndian) PutUint32(b []byte, v uint32) {
_ = b[3] // early bounds check to guarantee safety of writes below
b[0] = byte(v)
b[1] = byte(v >> 8)
b[2] = byte(v >> 16)
b[3] = byte(v >> 24)
}
func (littleEndian) AppendUint32(b []byte, v uint32) []byte {
return append(b,
byte(v),
byte(v>>8),
byte(v>>16),
byte(v>>24),
)
}
func (littleEndian) Uint64(b []byte) uint64 {
_ = b[7] // bounds check hint to compiler; see golang.org/issue/14808
return uint64(b[0]) | uint64(b[1])<<8 | uint64(b[2])<<16 | uint64(b[3])<<24 |
uint64(b[4])<<32 | uint64(b[5])<<40 | uint64(b[6])<<48 | uint64(b[7])<<56
}
func (littleEndian) PutUint64(b []byte, v uint64) {
_ = b[7] // early bounds check to guarantee safety of writes below
b[0] = byte(v)
b[1] = byte(v >> 8)
b[2] = byte(v >> 16)
b[3] = byte(v >> 24)
b[4] = byte(v >> 32)
b[5] = byte(v >> 40)
b[6] = byte(v >> 48)
b[7] = byte(v >> 56)
}
func (littleEndian) AppendUint64(b []byte, v uint64) []byte {
return append(b,
byte(v),
byte(v>>8),
byte(v>>16),
byte(v>>24),
byte(v>>32),
byte(v>>40),
byte(v>>48),
byte(v>>56),
)
}
func (littleEndian) String() string { return "LittleEndian" }
func (littleEndian) GoString() string { return "binary.LittleEndian" }
type bigEndian struct{}
func (bigEndian) Uint16(b []byte) uint16 {
_ = b[1] // bounds check hint to compiler; see golang.org/issue/14808
return uint16(b[1]) | uint16(b[0])<<8
}
func (bigEndian) PutUint16(b []byte, v uint16) {
_ = b[1] // early bounds check to guarantee safety of writes below
b[0] = byte(v >> 8)
b[1] = byte(v)
}
func (bigEndian) AppendUint16(b []byte, v uint16) []byte {
return append(b,
byte(v>>8),
byte(v),
)
}
func (bigEndian) Uint32(b []byte) uint32 {
_ = b[3] // bounds check hint to compiler; see golang.org/issue/14808
return uint32(b[3]) | uint32(b[2])<<8 | uint32(b[1])<<16 | uint32(b[0])<<24
}
func (bigEndian) PutUint32(b []byte, v uint32) {
_ = b[3] // early bounds check to guarantee safety of writes below
b[0] = byte(v >> 24)
b[1] = byte(v >> 16)
b[2] = byte(v >> 8)
b[3] = byte(v)
}
func (bigEndian) AppendUint32(b []byte, v uint32) []byte {
return append(b,
byte(v>>24),
byte(v>>16),
byte(v>>8),
byte(v),
)
}
func (bigEndian) Uint64(b []byte) uint64 {
_ = b[7] // bounds check hint to compiler; see golang.org/issue/14808
return uint64(b[7]) | uint64(b[6])<<8 | uint64(b[5])<<16 | uint64(b[4])<<24 |
uint64(b[3])<<32 | uint64(b[2])<<40 | uint64(b[1])<<48 | uint64(b[0])<<56
}
func (bigEndian) PutUint64(b []byte, v uint64) {
_ = b[7] // early bounds check to guarantee safety of writes below
b[0] = byte(v >> 56)
b[1] = byte(v >> 48)
b[2] = byte(v >> 40)
b[3] = byte(v >> 32)
b[4] = byte(v >> 24)
b[5] = byte(v >> 16)
b[6] = byte(v >> 8)
b[7] = byte(v)
}
func (bigEndian) AppendUint64(b []byte, v uint64) []byte {
return append(b,
byte(v>>56),
byte(v>>48),
byte(v>>40),
byte(v>>32),
byte(v>>24),
byte(v>>16),
byte(v>>8),
byte(v),
)
}
func (bigEndian) String() string { return "BigEndian" }
func (bigEndian) GoString() string { return "binary.BigEndian" }
func (nativeEndian) String() string { return "NativeEndian" }
func (nativeEndian) GoString() string { return "binary.NativeEndian" }
// Read reads structured binary data from r into data.
// Data must be a pointer to a fixed-size value or a slice
// of fixed-size values.
// Bytes read from r are decoded using the specified byte order
// and written to successive fields of the data.
// When decoding boolean values, a zero byte is decoded as false, and
// any other non-zero byte is decoded as true.
// When reading into structs, the field data for fields with
// blank (_) field names is skipped; i.e., blank field names
// may be used for padding.
// When reading into a struct, all non-blank fields must be exported
// or Read may panic.
//
// The error is [io.EOF] only if no bytes were read.
// If an [io.EOF] happens after reading some but not all the bytes,
// Read returns [io.ErrUnexpectedEOF].
func Read(r io.Reader, order ByteOrder, data any) error {
// Fast path for basic types and slices.
if n := intDataSize(data); n != 0 {
bs := make([]byte, n)
if _, err := io.ReadFull(r, bs); err != nil {
return err
}
switch data := data.(type) {
case *bool:
*data = bs[0] != 0
case *int8:
*data = int8(bs[0])
case *uint8:
*data = bs[0]
case *int16:
*data = int16(order.Uint16(bs))
case *uint16:
*data = order.Uint16(bs)
case *int32:
*data = int32(order.Uint32(bs))
case *uint32:
*data = order.Uint32(bs)
case *int64:
*data = int64(order.Uint64(bs))
case *uint64:
*data = order.Uint64(bs)
case *float32:
*data = math.Float32frombits(order.Uint32(bs))
case *float64:
*data = math.Float64frombits(order.Uint64(bs))
case []bool:
for i, x := range bs { // Easier to loop over the input for 8-bit values.
data[i] = x != 0
}
case []int8:
for i, x := range bs {
data[i] = int8(x)
}
case []uint8:
copy(data, bs)
case []int16:
for i := range data {
data[i] = int16(order.Uint16(bs[2*i:]))
}
case []uint16:
for i := range data {
data[i] = order.Uint16(bs[2*i:])
}
case []int32:
for i := range data {
data[i] = int32(order.Uint32(bs[4*i:]))
}
case []uint32:
for i := range data {
data[i] = order.Uint32(bs[4*i:])
}
case []int64:
for i := range data {
data[i] = int64(order.Uint64(bs[8*i:]))
}
case []uint64:
for i := range data {
data[i] = order.Uint64(bs[8*i:])
}
case []float32:
for i := range data {
data[i] = math.Float32frombits(order.Uint32(bs[4*i:]))
}
case []float64:
for i := range data {
data[i] = math.Float64frombits(order.Uint64(bs[8*i:]))
}
default:
n = 0 // fast path doesn't apply
}
if n != 0 {
return nil
}
}
// Fallback to reflect-based decoding.
v := reflect.ValueOf(data)
size := -1
switch v.Kind() {
case reflect.Pointer:
v = v.Elem()
size = dataSize(v)
case reflect.Slice:
size = dataSize(v)
}
if size < 0 {
return errors.New("binary.Read: invalid type " + reflect.TypeOf(data).String())
}
d := &decoder{order: order, buf: make([]byte, size)}
if _, err := io.ReadFull(r, d.buf); err != nil {
return err
}
d.value(v)
return nil
}
// Write writes the binary representation of data into w.
// Data must be a fixed-size value or a slice of fixed-size
// values, or a pointer to such data.
// Boolean values encode as one byte: 1 for true, and 0 for false.
// Bytes written to w are encoded using the specified byte order
// and read from successive fields of the data.
// When writing structs, zero values are written for fields
// with blank (_) field names.
func Write(w io.Writer, order ByteOrder, data any) error {
// Fast path for basic types and slices.
if n := intDataSize(data); n != 0 {
bs := make([]byte, n)
switch v := data.(type) {
case *bool:
if *v {
bs[0] = 1
} else {
bs[0] = 0
}
case bool:
if v {
bs[0] = 1
} else {
bs[0] = 0
}
case []bool:
for i, x := range v {
if x {
bs[i] = 1
} else {
bs[i] = 0
}
}
case *int8:
bs[0] = byte(*v)
case int8:
bs[0] = byte(v)
case []int8:
for i, x := range v {
bs[i] = byte(x)
}
case *uint8:
bs[0] = *v
case uint8:
bs[0] = v
case []uint8:
bs = v
case *int16:
order.PutUint16(bs, uint16(*v))
case int16:
order.PutUint16(bs, uint16(v))
case []int16:
for i, x := range v {
order.PutUint16(bs[2*i:], uint16(x))
}
case *uint16:
order.PutUint16(bs, *v)
case uint16:
order.PutUint16(bs, v)
case []uint16:
for i, x := range v {
order.PutUint16(bs[2*i:], x)
}
case *int32:
order.PutUint32(bs, uint32(*v))
case int32:
order.PutUint32(bs, uint32(v))
case []int32:
for i, x := range v {
order.PutUint32(bs[4*i:], uint32(x))
}
case *uint32:
order.PutUint32(bs, *v)
case uint32:
order.PutUint32(bs, v)
case []uint32:
for i, x := range v {
order.PutUint32(bs[4*i:], x)
}
case *int64:
order.PutUint64(bs, uint64(*v))
case int64:
order.PutUint64(bs, uint64(v))
case []int64:
for i, x := range v {
order.PutUint64(bs[8*i:], uint64(x))
}
case *uint64:
order.PutUint64(bs, *v)
case uint64:
order.PutUint64(bs, v)
case []uint64:
for i, x := range v {
order.PutUint64(bs[8*i:], x)
}
case *float32:
order.PutUint32(bs, math.Float32bits(*v))
case float32:
order.PutUint32(bs, math.Float32bits(v))
case []float32:
for i, x := range v {
order.PutUint32(bs[4*i:], math.Float32bits(x))
}
case *float64:
order.PutUint64(bs, math.Float64bits(*v))
case float64:
order.PutUint64(bs, math.Float64bits(v))
case []float64:
for i, x := range v {
order.PutUint64(bs[8*i:], math.Float64bits(x))
}
}
_, err := w.Write(bs)
return err
}
// Fallback to reflect-based encoding.
v := reflect.Indirect(reflect.ValueOf(data))
size := dataSize(v)
if size < 0 {
return errors.New("binary.Write: some values are not fixed-sized in type " + reflect.TypeOf(data).String())
}
buf := make([]byte, size)
e := &encoder{order: order, buf: buf}
e.value(v)
_, err := w.Write(buf)
return err
}
// Size returns how many bytes [Write] would generate to encode the value v, which
// must be a fixed-size value or a slice of fixed-size values, or a pointer to such data.
// If v is neither of these, Size returns -1.
func Size(v any) int {
return dataSize(reflect.Indirect(reflect.ValueOf(v)))
}
var structSize sync.Map // map[reflect.Type]int
// dataSize returns the number of bytes the actual data represented by v occupies in memory.
// For compound structures, it sums the sizes of the elements. Thus, for instance, for a slice
// it returns the length of the slice times the element size and does not count the memory
// occupied by the header. If the type of v is not acceptable, dataSize returns -1.
func dataSize(v reflect.Value) int {
switch v.Kind() {
case reflect.Slice:
if s := sizeof(v.Type().Elem()); s >= 0 {
return s * v.Len()
}
case reflect.Struct:
t := v.Type()
if size, ok := structSize.Load(t); ok {
return size.(int)
}
size := sizeof(t)
structSize.Store(t, size)
return size
default:
if v.IsValid() {
return sizeof(v.Type())
}
}
return -1
}
// sizeof returns the size >= 0 of variables for the given type or -1 if the type is not acceptable.
func sizeof(t reflect.Type) int {
switch t.Kind() {
case reflect.Array:
if s := sizeof(t.Elem()); s >= 0 {
return s * t.Len()
}
case reflect.Struct:
sum := 0
for i, n := 0, t.NumField(); i < n; i++ {
s := sizeof(t.Field(i).Type)
if s < 0 {
return -1
}
sum += s
}
return sum
case reflect.Bool,
reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128:
return int(t.Size())
}
return -1
}
type coder struct {
order ByteOrder
buf []byte
offset int
}
type (
decoder coder
encoder coder
)
func (d *decoder) bool() bool {
x := d.buf[d.offset]
d.offset++
return x != 0
}
func (e *encoder) bool(x bool) {
if x {
e.buf[e.offset] = 1
} else {
e.buf[e.offset] = 0
}
e.offset++
}
func (d *decoder) uint8() uint8 {
x := d.buf[d.offset]
d.offset++
return x
}
func (e *encoder) uint8(x uint8) {
e.buf[e.offset] = x
e.offset++
}
func (d *decoder) uint16() uint16 {
x := d.order.Uint16(d.buf[d.offset : d.offset+2])
d.offset += 2
return x
}
func (e *encoder) uint16(x uint16) {
e.order.PutUint16(e.buf[e.offset:e.offset+2], x)
e.offset += 2
}
func (d *decoder) uint32() uint32 {
x := d.order.Uint32(d.buf[d.offset : d.offset+4])
d.offset += 4
return x
}
func (e *encoder) uint32(x uint32) {
e.order.PutUint32(e.buf[e.offset:e.offset+4], x)
e.offset += 4
}
func (d *decoder) uint64() uint64 {
x := d.order.Uint64(d.buf[d.offset : d.offset+8])
d.offset += 8
return x
}
func (e *encoder) uint64(x uint64) {
e.order.PutUint64(e.buf[e.offset:e.offset+8], x)
e.offset += 8
}
func (d *decoder) int8() int8 { return int8(d.uint8()) }
func (e *encoder) int8(x int8) { e.uint8(uint8(x)) }
func (d *decoder) int16() int16 { return int16(d.uint16()) }
func (e *encoder) int16(x int16) { e.uint16(uint16(x)) }
func (d *decoder) int32() int32 { return int32(d.uint32()) }
func (e *encoder) int32(x int32) { e.uint32(uint32(x)) }
func (d *decoder) int64() int64 { return int64(d.uint64()) }
func (e *encoder) int64(x int64) { e.uint64(uint64(x)) }
func (d *decoder) value(v reflect.Value) {
switch v.Kind() {
case reflect.Array:
l := v.Len()
for i := 0; i < l; i++ {
d.value(v.Index(i))
}
case reflect.Struct:
t := v.Type()
l := v.NumField()
for i := 0; i < l; i++ {
// Note: Calling v.CanSet() below is an optimization.
// It would be sufficient to check the field name,
// but creating the StructField info for each field is
// costly (run "go test -bench=ReadStruct" and compare
// results when making changes to this code).
if v := v.Field(i); v.CanSet() || t.Field(i).Name != "_" {
d.value(v)
} else {
d.skip(v)
}
}
case reflect.Slice:
l := v.Len()
for i := 0; i < l; i++ {
d.value(v.Index(i))
}
case reflect.Bool:
v.SetBool(d.bool())
case reflect.Int8:
v.SetInt(int64(d.int8()))
case reflect.Int16:
v.SetInt(int64(d.int16()))
case reflect.Int32:
v.SetInt(int64(d.int32()))
case reflect.Int64:
v.SetInt(d.int64())
case reflect.Uint8:
v.SetUint(uint64(d.uint8()))
case reflect.Uint16:
v.SetUint(uint64(d.uint16()))
case reflect.Uint32:
v.SetUint(uint64(d.uint32()))
case reflect.Uint64:
v.SetUint(d.uint64())
case reflect.Float32:
v.SetFloat(float64(math.Float32frombits(d.uint32())))
case reflect.Float64:
v.SetFloat(math.Float64frombits(d.uint64()))
case reflect.Complex64:
v.SetComplex(complex(
float64(math.Float32frombits(d.uint32())),
float64(math.Float32frombits(d.uint32())),
))
case reflect.Complex128:
v.SetComplex(complex(
math.Float64frombits(d.uint64()),
math.Float64frombits(d.uint64()),
))
}
}
func (e *encoder) value(v reflect.Value) {
switch v.Kind() {
case reflect.Array:
l := v.Len()
for i := 0; i < l; i++ {
e.value(v.Index(i))
}
case reflect.Struct:
t := v.Type()
l := v.NumField()
for i := 0; i < l; i++ {
// see comment for corresponding code in decoder.value()
if v := v.Field(i); v.CanSet() || t.Field(i).Name != "_" {
e.value(v)
} else {
e.skip(v)
}
}
case reflect.Slice:
l := v.Len()
for i := 0; i < l; i++ {
e.value(v.Index(i))
}
case reflect.Bool:
e.bool(v.Bool())
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
switch v.Type().Kind() {
case reflect.Int8:
e.int8(int8(v.Int()))
case reflect.Int16:
e.int16(int16(v.Int()))
case reflect.Int32:
e.int32(int32(v.Int()))
case reflect.Int64:
e.int64(v.Int())
}
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
switch v.Type().Kind() {
case reflect.Uint8:
e.uint8(uint8(v.Uint()))
case reflect.Uint16:
e.uint16(uint16(v.Uint()))
case reflect.Uint32:
e.uint32(uint32(v.Uint()))
case reflect.Uint64:
e.uint64(v.Uint())
}
case reflect.Float32, reflect.Float64:
switch v.Type().Kind() {
case reflect.Float32:
e.uint32(math.Float32bits(float32(v.Float())))
case reflect.Float64:
e.uint64(math.Float64bits(v.Float()))
}
case reflect.Complex64, reflect.Complex128:
switch v.Type().Kind() {
case reflect.Complex64:
x := v.Complex()
e.uint32(math.Float32bits(float32(real(x))))
e.uint32(math.Float32bits(float32(imag(x))))
case reflect.Complex128:
x := v.Complex()
e.uint64(math.Float64bits(real(x)))
e.uint64(math.Float64bits(imag(x)))
}
}
}
func (d *decoder) skip(v reflect.Value) {
d.offset += dataSize(v)
}
func (e *encoder) skip(v reflect.Value) {
n := dataSize(v)
zero := e.buf[e.offset : e.offset+n]
for i := range zero {
zero[i] = 0
}
e.offset += n
}
// intDataSize returns the size of the data required to represent the data when encoded.
// It returns zero if the type cannot be implemented by the fast path in Read or Write.
func intDataSize(data any) int {
switch data := data.(type) {
case bool, int8, uint8, *bool, *int8, *uint8:
return 1
case []bool:
return len(data)
case []int8:
return len(data)
case []uint8:
return len(data)
case int16, uint16, *int16, *uint16:
return 2
case []int16:
return 2 * len(data)
case []uint16:
return 2 * len(data)
case int32, uint32, *int32, *uint32:
return 4
case []int32:
return 4 * len(data)
case []uint32:
return 4 * len(data)
case int64, uint64, *int64, *uint64:
return 8
case []int64:
return 8 * len(data)
case []uint64:
return 8 * len(data)
case float32, *float32:
return 4
case float64, *float64:
return 8
case []float32:
return 4 * len(data)
case []float64:
return 8 * len(data)
}
return 0
}

18
common/binary/export.go Normal file
View file

@ -0,0 +1,18 @@
package binary
import (
"encoding/binary"
"reflect"
)
func DataSize(t reflect.Value) int {
return dataSize(t)
}
func EncodeValue(order binary.ByteOrder, buf []byte, v reflect.Value) {
(&encoder{order: order, buf: buf}).value(v)
}
func DecodeValue(order binary.ByteOrder, buf []byte, v reflect.Value) {
(&decoder{order: order, buf: buf}).value(v)
}

View file

@ -0,0 +1,14 @@
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build armbe || arm64be || m68k || mips || mips64 || mips64p32 || ppc || ppc64 || s390 || s390x || shbe || sparc || sparc64
package binary
type nativeEndian struct {
bigEndian
}
// NativeEndian is the native-endian implementation of [ByteOrder] and [AppendByteOrder].
var NativeEndian nativeEndian

View file

@ -0,0 +1,14 @@
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build 386 || amd64 || amd64p32 || alpha || arm || arm64 || loong64 || mipsle || mips64le || mips64p32le || nios2 || ppc64le || riscv || riscv64 || sh || wasm
package binary
type nativeEndian struct {
littleEndian
}
// NativeEndian is the native-endian implementation of [ByteOrder] and [AppendByteOrder].
var NativeEndian nativeEndian

166
common/binary/varint.go Normal file
View file

@ -0,0 +1,166 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package binary
// This file implements "varint" encoding of 64-bit integers.
// The encoding is:
// - unsigned integers are serialized 7 bits at a time, starting with the
// least significant bits
// - the most significant bit (msb) in each output byte indicates if there
// is a continuation byte (msb = 1)
// - signed integers are mapped to unsigned integers using "zig-zag"
// encoding: Positive values x are written as 2*x + 0, negative values
// are written as 2*(^x) + 1; that is, negative numbers are complemented
// and whether to complement is encoded in bit 0.
//
// Design note:
// At most 10 bytes are needed for 64-bit values. The encoding could
// be more dense: a full 64-bit value needs an extra byte just to hold bit 63.
// Instead, the msb of the previous byte could be used to hold bit 63 since we
// know there can't be more than 64 bits. This is a trivial improvement and
// would reduce the maximum encoding length to 9 bytes. However, it breaks the
// invariant that the msb is always the "continuation bit" and thus makes the
// format incompatible with a varint encoding for larger numbers (say 128-bit).
import (
"errors"
"io"
)
// MaxVarintLenN is the maximum length of a varint-encoded N-bit integer.
const (
MaxVarintLen16 = 3
MaxVarintLen32 = 5
MaxVarintLen64 = 10
)
// AppendUvarint appends the varint-encoded form of x,
// as generated by [PutUvarint], to buf and returns the extended buffer.
func AppendUvarint(buf []byte, x uint64) []byte {
for x >= 0x80 {
buf = append(buf, byte(x)|0x80)
x >>= 7
}
return append(buf, byte(x))
}
// PutUvarint encodes a uint64 into buf and returns the number of bytes written.
// If the buffer is too small, PutUvarint will panic.
func PutUvarint(buf []byte, x uint64) int {
i := 0
for x >= 0x80 {
buf[i] = byte(x) | 0x80
x >>= 7
i++
}
buf[i] = byte(x)
return i + 1
}
// Uvarint decodes a uint64 from buf and returns that value and the
// number of bytes read (> 0). If an error occurred, the value is 0
// and the number of bytes n is <= 0 meaning:
//
// n == 0: buf too small
// n < 0: value larger than 64 bits (overflow)
// and -n is the number of bytes read
func Uvarint(buf []byte) (uint64, int) {
var x uint64
var s uint
for i, b := range buf {
if i == MaxVarintLen64 {
// Catch byte reads past MaxVarintLen64.
// See issue https://golang.org/issues/41185
return 0, -(i + 1) // overflow
}
if b < 0x80 {
if i == MaxVarintLen64-1 && b > 1 {
return 0, -(i + 1) // overflow
}
return x | uint64(b)<<s, i + 1
}
x |= uint64(b&0x7f) << s
s += 7
}
return 0, 0
}
// AppendVarint appends the varint-encoded form of x,
// as generated by [PutVarint], to buf and returns the extended buffer.
func AppendVarint(buf []byte, x int64) []byte {
ux := uint64(x) << 1
if x < 0 {
ux = ^ux
}
return AppendUvarint(buf, ux)
}
// PutVarint encodes an int64 into buf and returns the number of bytes written.
// If the buffer is too small, PutVarint will panic.
func PutVarint(buf []byte, x int64) int {
ux := uint64(x) << 1
if x < 0 {
ux = ^ux
}
return PutUvarint(buf, ux)
}
// Varint decodes an int64 from buf and returns that value and the
// number of bytes read (> 0). If an error occurred, the value is 0
// and the number of bytes n is <= 0 with the following meaning:
//
// n == 0: buf too small
// n < 0: value larger than 64 bits (overflow)
// and -n is the number of bytes read
func Varint(buf []byte) (int64, int) {
ux, n := Uvarint(buf) // ok to continue in presence of error
x := int64(ux >> 1)
if ux&1 != 0 {
x = ^x
}
return x, n
}
var errOverflow = errors.New("binary: varint overflows a 64-bit integer")
// ReadUvarint reads an encoded unsigned integer from r and returns it as a uint64.
// The error is [io.EOF] only if no bytes were read.
// If an [io.EOF] happens after reading some but not all the bytes,
// ReadUvarint returns [io.ErrUnexpectedEOF].
func ReadUvarint(r io.ByteReader) (uint64, error) {
var x uint64
var s uint
for i := 0; i < MaxVarintLen64; i++ {
b, err := r.ReadByte()
if err != nil {
if i > 0 && err == io.EOF {
err = io.ErrUnexpectedEOF
}
return x, err
}
if b < 0x80 {
if i == MaxVarintLen64-1 && b > 1 {
return x, errOverflow
}
return x | uint64(b)<<s, nil
}
x |= uint64(b&0x7f) << s
s += 7
}
return x, errOverflow
}
// ReadVarint reads an encoded signed integer from r and returns it as an int64.
// The error is [io.EOF] only if no bytes were read.
// If an [io.EOF] happens after reading some but not all the bytes,
// ReadVarint returns [io.ErrUnexpectedEOF].
func ReadVarint(r io.ByteReader) (int64, error) {
ux, err := ReadUvarint(r) // ok to continue in presence of error
x := int64(ux >> 1)
if ux&1 != 0 {
x = ^x
}
return x, err
}

View file

@ -9,19 +9,20 @@ import (
type AddrConn struct {
net.Conn
M.Metadata
Source M.Socksaddr
Destination M.Socksaddr
}
func (c *AddrConn) LocalAddr() net.Addr {
if c.Metadata.Destination.IsValid() {
return c.Metadata.Destination.TCPAddr()
if c.Destination.IsValid() {
return c.Destination.TCPAddr()
}
return c.Conn.LocalAddr()
}
func (c *AddrConn) RemoteAddr() net.Addr {
if c.Metadata.Source.IsValid() {
return c.Metadata.Source.TCPAddr()
if c.Source.IsValid() {
return c.Source.TCPAddr()
}
return c.Conn.RemoteAddr()
}

View file

@ -39,7 +39,7 @@ func (c *bindPacketConn) CreateReadWaiter() (N.ReadWaiter, bool) {
if !isReadWaiter {
return nil, false
}
return &BindPacketReadWaiter{readWaiter}, true
return &bindPacketReadWaiter{readWaiter}, true
}
func (c *bindPacketConn) RemoteAddr() net.Addr {
@ -104,9 +104,62 @@ func (c *UnbindPacketConn) CreateReadWaiter() (N.PacketReadWaiter, bool) {
if !isReadWaiter {
return nil, false
}
return &UnbindPacketReadWaiter{readWaiter, c.addr}, true
return &unbindPacketReadWaiter{readWaiter, c.addr}, true
}
func (c *UnbindPacketConn) Upstream() any {
return c.ExtendedConn
}
func NewServerPacketConn(conn net.PacketConn) N.ExtendedConn {
return &serverPacketConn{
NetPacketConn: NewPacketConn(conn),
}
}
type serverPacketConn struct {
N.NetPacketConn
remoteAddr M.Socksaddr
}
func (c *serverPacketConn) Read(p []byte) (n int, err error) {
n, addr, err := c.NetPacketConn.ReadFrom(p)
if err != nil {
return
}
c.remoteAddr = M.SocksaddrFromNet(addr)
return
}
func (c *serverPacketConn) ReadBuffer(buffer *buf.Buffer) error {
destination, err := c.NetPacketConn.ReadPacket(buffer)
if err != nil {
return err
}
c.remoteAddr = destination
return nil
}
func (c *serverPacketConn) Write(p []byte) (n int, err error) {
return c.NetPacketConn.WriteTo(p, c.remoteAddr.UDPAddr())
}
func (c *serverPacketConn) WriteBuffer(buffer *buf.Buffer) error {
return c.NetPacketConn.WritePacket(buffer, c.remoteAddr)
}
func (c *serverPacketConn) RemoteAddr() net.Addr {
return c.remoteAddr
}
func (c *serverPacketConn) Upstream() any {
return c.NetPacketConn
}
func (c *serverPacketConn) CreateReadWaiter() (N.ReadWaiter, bool) {
readWaiter, isReadWaiter := CreatePacketReadWaiter(c.NetPacketConn)
if !isReadWaiter {
return nil, false
}
return &serverPacketReadWaiter{c, readWaiter}, true
}

View file

@ -6,33 +6,33 @@ import (
N "github.com/sagernet/sing/common/network"
)
var _ N.ReadWaiter = (*BindPacketReadWaiter)(nil)
var _ N.ReadWaiter = (*bindPacketReadWaiter)(nil)
type BindPacketReadWaiter struct {
type bindPacketReadWaiter struct {
readWaiter N.PacketReadWaiter
}
func (w *BindPacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
func (w *bindPacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
return w.readWaiter.InitializeReadWaiter(options)
}
func (w *BindPacketReadWaiter) WaitReadBuffer() (buffer *buf.Buffer, err error) {
func (w *bindPacketReadWaiter) WaitReadBuffer() (buffer *buf.Buffer, err error) {
buffer, _, err = w.readWaiter.WaitReadPacket()
return
}
var _ N.PacketReadWaiter = (*UnbindPacketReadWaiter)(nil)
var _ N.PacketReadWaiter = (*unbindPacketReadWaiter)(nil)
type UnbindPacketReadWaiter struct {
type unbindPacketReadWaiter struct {
readWaiter N.ReadWaiter
addr M.Socksaddr
}
func (w *UnbindPacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
func (w *unbindPacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
return w.readWaiter.InitializeReadWaiter(options)
}
func (w *UnbindPacketReadWaiter) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
func (w *unbindPacketReadWaiter) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
buffer, err = w.readWaiter.WaitReadBuffer()
if err != nil {
return
@ -40,3 +40,23 @@ func (w *UnbindPacketReadWaiter) WaitReadPacket() (buffer *buf.Buffer, destinati
destination = w.addr
return
}
var _ N.ReadWaiter = (*serverPacketReadWaiter)(nil)
type serverPacketReadWaiter struct {
*serverPacketConn
readWaiter N.PacketReadWaiter
}
func (w *serverPacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
return w.readWaiter.InitializeReadWaiter(options)
}
func (w *serverPacketReadWaiter) WaitReadBuffer() (buffer *buf.Buffer, err error) {
buffer, destination, err := w.readWaiter.WaitReadPacket()
if err != nil {
return
}
w.remoteAddr = destination
return
}

View file

@ -4,6 +4,7 @@ import (
"io"
"sync"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
)
@ -41,6 +42,25 @@ func (w *BufferedWriter) Write(p []byte) (n int, err error) {
}
}
func (w *BufferedWriter) WriteByte(c byte) error {
w.access.Lock()
defer w.access.Unlock()
if w.buffer == nil {
return common.Error(w.upstream.Write([]byte{c}))
}
for {
err := w.buffer.WriteByte(c)
if err == nil {
return nil
}
_, err = w.upstream.Write(w.buffer.Bytes())
if err != nil {
return err
}
w.buffer.Reset()
}
}
func (w *BufferedWriter) Fallthrough() error {
w.access.Lock()
defer w.access.Unlock()

View file

@ -3,7 +3,6 @@ package bufio
import (
"io"
"net"
"time"
"github.com/sagernet/sing/common/buf"
M "github.com/sagernet/sing/common/metadata"
@ -60,13 +59,6 @@ func (c *CachedConn) WriteTo(w io.Writer) (n int64, err error) {
return
}
func (c *CachedConn) SetReadDeadline(t time.Time) error {
if c.buffer != nil && !c.buffer.IsEmpty() {
return nil
}
return c.Conn.SetReadDeadline(t)
}
func (c *CachedConn) ReadFrom(r io.Reader) (n int64, err error) {
return Copy(c.Conn, r)
}
@ -192,10 +184,12 @@ func (c *CachedPacketConn) ReadCachedPacket() *N.PacketBuffer {
if buffer != nil {
buffer.DecRef()
}
return &N.PacketBuffer{
packet := N.NewPacketBuffer()
*packet = N.PacketBuffer{
Buffer: buffer,
Destination: c.destination,
}
return packet
}
func (c *CachedPacketConn) Upstream() any {

View file

@ -35,14 +35,7 @@ func (w *ExtendedUDPConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
func (w *ExtendedUDPConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
defer buffer.Release()
if destination.IsFqdn() {
udpAddr, err := net.ResolveUDPAddr("udp", destination.String())
if err != nil {
return err
}
return common.Error(w.UDPConn.WriteTo(buffer.Bytes(), udpAddr))
}
return common.Error(w.UDPConn.WriteToUDP(buffer.Bytes(), destination.UDPAddr()))
return common.Error(w.UDPConn.WriteToUDPAddrPort(buffer.Bytes(), destination.AddrPort()))
}
func (w *ExtendedUDPConn) Upstream() any {

View file

@ -12,7 +12,6 @@ import (
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/rw"
"github.com/sagernet/sing/common/task"
)
@ -30,28 +29,36 @@ func Copy(destination io.Writer, source io.Reader) (n int64, err error) {
if cachedSrc, isCached := source.(N.CachedReader); isCached {
cachedBuffer := cachedSrc.ReadCached()
if cachedBuffer != nil {
if !cachedBuffer.IsEmpty() {
_, err = destination.Write(cachedBuffer.Bytes())
if err != nil {
cachedBuffer.Release()
return
}
}
dataLen := cachedBuffer.Len()
_, err = destination.Write(cachedBuffer.Bytes())
cachedBuffer.Release()
if err != nil {
return
}
for _, counter := range readCounters {
counter(int64(dataLen))
}
for _, counter := range writeCounters {
counter(int64(dataLen))
}
continue
}
}
srcSyscallConn, srcIsSyscall := source.(syscall.Conn)
dstSyscallConn, dstIsSyscall := destination.(syscall.Conn)
if srcIsSyscall && dstIsSyscall {
var handled bool
handled, n, err = copyDirect(srcSyscallConn, dstSyscallConn, readCounters, writeCounters)
if handled {
return
}
}
break
}
return CopyWithCounters(destination, source, originSource, readCounters, writeCounters)
}
func CopyWithCounters(destination io.Writer, source io.Reader, originSource io.Reader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
srcSyscallConn, srcIsSyscall := source.(syscall.Conn)
dstSyscallConn, dstIsSyscall := destination.(syscall.Conn)
if srcIsSyscall && dstIsSyscall {
var handled bool
handled, n, err = copyDirect(srcSyscallConn, dstSyscallConn, readCounters, writeCounters)
if handled {
return
}
}
return CopyExtended(originSource, NewExtendedWriter(destination), NewExtendedReader(source), readCounters, writeCounters)
}
@ -76,6 +83,7 @@ func CopyExtended(originSource io.Reader, destination N.ExtendedWriter, source N
return CopyExtendedWithPool(originSource, destination, source, readCounters, writeCounters)
}
// Deprecated: not used
func CopyExtendedBuffer(originSource io.Writer, destination N.ExtendedWriter, source N.ExtendedReader, buffer *buf.Buffer, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
buffer.IncRef()
defer buffer.DecRef()
@ -114,19 +122,10 @@ func CopyExtendedBuffer(originSource io.Writer, destination N.ExtendedWriter, so
}
func CopyExtendedWithPool(originSource io.Reader, destination N.ExtendedWriter, source N.ExtendedReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
frontHeadroom := N.CalculateFrontHeadroom(destination)
rearHeadroom := N.CalculateRearHeadroom(destination)
bufferSize := N.CalculateMTU(source, destination)
if bufferSize > 0 {
bufferSize += frontHeadroom + rearHeadroom
} else {
bufferSize = buf.BufferSize
}
options := N.NewReadWaitOptions(source, destination)
var notFirstTime bool
for {
buffer := buf.NewSize(bufferSize)
buffer.Resize(frontHeadroom, 0)
buffer.Reserve(rearHeadroom)
buffer := options.NewBuffer()
err = source.ReadBuffer(buffer)
if err != nil {
buffer.Release()
@ -137,7 +136,7 @@ func CopyExtendedWithPool(originSource io.Reader, destination N.ExtendedWriter,
return
}
dataLen := buffer.Len()
buffer.OverCap(rearHeadroom)
options.PostReturn(buffer)
err = destination.WriteBuffer(buffer)
if err != nil {
buffer.Leak()
@ -158,16 +157,12 @@ func CopyExtendedWithPool(originSource io.Reader, destination N.ExtendedWriter,
}
func CopyConn(ctx context.Context, source net.Conn, destination net.Conn) error {
return CopyConnContextList([]context.Context{ctx}, source, destination)
}
func CopyConnContextList(contextList []context.Context, source net.Conn, destination net.Conn) error {
var group task.Group
if _, dstDuplex := common.Cast[rw.WriteCloser](destination); dstDuplex {
if _, dstDuplex := common.Cast[N.WriteCloser](destination); dstDuplex {
group.Append("upload", func(ctx context.Context) error {
err := common.Error(Copy(destination, source))
if err == nil {
rw.CloseWrite(destination)
N.CloseWrite(destination)
} else {
common.Close(destination)
}
@ -179,11 +174,11 @@ func CopyConnContextList(contextList []context.Context, source net.Conn, destina
return common.Error(Copy(destination, source))
})
}
if _, srcDuplex := common.Cast[rw.WriteCloser](source); srcDuplex {
if _, srcDuplex := common.Cast[N.WriteCloser](source); srcDuplex {
group.Append("download", func(ctx context.Context) error {
err := common.Error(Copy(source, destination))
if err == nil {
rw.CloseWrite(source)
N.CloseWrite(source)
} else {
common.Close(source)
}
@ -198,7 +193,7 @@ func CopyConnContextList(contextList []context.Context, source net.Conn, destina
group.Cleanup(func() {
common.Close(source, destination)
})
return group.RunContextList(contextList)
return group.Run(ctx)
}
func CopyPacket(destinationConn N.PacketWriter, source N.PacketReader) (n int64, err error) {
@ -218,24 +213,24 @@ func CopyPacket(destinationConn N.PacketWriter, source N.PacketReader) (n int64,
break
}
if cachedPackets != nil {
n, err = WritePacketWithPool(originSource, destinationConn, cachedPackets)
n, err = WritePacketWithPool(originSource, destinationConn, cachedPackets, readCounters, writeCounters)
if err != nil {
return
}
}
frontHeadroom := N.CalculateFrontHeadroom(destinationConn)
rearHeadroom := N.CalculateRearHeadroom(destinationConn)
copeN, err := CopyPacketWithCounters(destinationConn, source, originSource, readCounters, writeCounters)
n += copeN
return
}
func CopyPacketWithCounters(destinationConn N.PacketWriter, source N.PacketReader, originSource N.PacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
var (
handled bool
copeN int64
)
readWaiter, isReadWaiter := CreatePacketReadWaiter(source)
if isReadWaiter {
needCopy := readWaiter.InitializeReadWaiter(N.ReadWaitOptions{
FrontHeadroom: frontHeadroom,
RearHeadroom: rearHeadroom,
MTU: N.CalculateMTU(source, destinationConn),
})
needCopy := readWaiter.InitializeReadWaiter(N.NewReadWaitOptions(source, destinationConn))
if !needCopy || common.LowMemory {
handled, copeN, err = copyPacketWaitWithPool(originSource, destinationConn, readWaiter, readCounters, writeCounters, n > 0)
if handled {
@ -249,28 +244,19 @@ func CopyPacket(destinationConn N.PacketWriter, source N.PacketReader) (n int64,
return
}
func CopyPacketWithPool(originSource N.PacketReader, destinationConn N.PacketWriter, source N.PacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (n int64, err error) {
frontHeadroom := N.CalculateFrontHeadroom(destinationConn)
rearHeadroom := N.CalculateRearHeadroom(destinationConn)
bufferSize := N.CalculateMTU(source, destinationConn)
if bufferSize > 0 {
bufferSize += frontHeadroom + rearHeadroom
} else {
bufferSize = buf.UDPBufferSize
}
var destination M.Socksaddr
func CopyPacketWithPool(originSource N.PacketReader, destination N.PacketWriter, source N.PacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (n int64, err error) {
options := N.NewReadWaitOptions(source, destination)
var destinationAddress M.Socksaddr
for {
buffer := buf.NewSize(bufferSize)
buffer.Resize(frontHeadroom, 0)
buffer.Reserve(rearHeadroom)
destination, err = source.ReadPacket(buffer)
buffer := options.NewPacketBuffer()
destinationAddress, err = source.ReadPacket(buffer)
if err != nil {
buffer.Release()
return
}
dataLen := buffer.Len()
buffer.OverCap(rearHeadroom)
err = destinationConn.WritePacket(buffer, destination)
options.PostReturn(buffer)
err = destination.WritePacket(buffer, destinationAddress)
if err != nil {
buffer.Leak()
if !notFirstTime {
@ -278,34 +264,25 @@ func CopyPacketWithPool(originSource N.PacketReader, destinationConn N.PacketWri
}
return
}
n += int64(dataLen)
for _, counter := range readCounters {
counter(int64(dataLen))
}
for _, counter := range writeCounters {
counter(int64(dataLen))
}
n += int64(dataLen)
notFirstTime = true
}
}
func WritePacketWithPool(originSource N.PacketReader, destinationConn N.PacketWriter, packetBuffers []*N.PacketBuffer) (n int64, err error) {
frontHeadroom := N.CalculateFrontHeadroom(destinationConn)
rearHeadroom := N.CalculateRearHeadroom(destinationConn)
func WritePacketWithPool(originSource N.PacketReader, destination N.PacketWriter, packetBuffers []*N.PacketBuffer, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
options := N.NewReadWaitOptions(nil, destination)
var notFirstTime bool
for _, packetBuffer := range packetBuffers {
buffer := buf.NewPacket()
buffer.Resize(frontHeadroom, 0)
buffer.Reserve(rearHeadroom)
_, err = buffer.Write(packetBuffer.Buffer.Bytes())
packetBuffer.Buffer.Release()
if err != nil {
buffer.Release()
continue
}
buffer := options.Copy(packetBuffer.Buffer)
dataLen := buffer.Len()
buffer.OverCap(rearHeadroom)
err = destinationConn.WritePacket(buffer, packetBuffer.Destination)
err = destination.WritePacket(buffer, packetBuffer.Destination)
N.PutPacketBuffer(packetBuffer)
if err != nil {
buffer.Leak()
if !notFirstTime {
@ -313,16 +290,19 @@ func WritePacketWithPool(originSource N.PacketReader, destinationConn N.PacketWr
}
return
}
for _, counter := range readCounters {
counter(int64(dataLen))
}
for _, counter := range writeCounters {
counter(int64(dataLen))
}
n += int64(dataLen)
notFirstTime = true
}
return
}
func CopyPacketConn(ctx context.Context, source N.PacketConn, destination N.PacketConn) error {
return CopyPacketConnContextList([]context.Context{ctx}, source, destination)
}
func CopyPacketConnContextList(contextList []context.Context, source N.PacketConn, destination N.PacketConn) error {
var group task.Group
group.Append("upload", func(ctx context.Context) error {
return common.Error(CopyPacket(destination, source))
@ -334,5 +314,5 @@ func CopyPacketConnContextList(contextList []context.Context, source N.PacketCon
common.Close(source, destination)
})
group.FastFail()
return group.RunContextList(contextList)
return group.Run(ctx)
}

View file

@ -47,6 +47,7 @@ func (w *syscallReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (nee
} else {
buffer.Release()
}
//goland:noinspection GoDirectComparisonOfErrors
if w.readErr == syscall.EAGAIN {
return false
}
@ -105,23 +106,21 @@ func (w *syscallPacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions
var readN int
var from syscall.Sockaddr
readN, _, _, from, w.readErr = syscall.Recvmsg(int(fd), buffer.FreeBytes(), nil, 0)
//goland:noinspection GoDirectComparisonOfErrors
if w.readErr != nil {
buffer.Release()
return w.readErr != syscall.EAGAIN
}
if readN > 0 {
buffer.Truncate(readN)
w.options.PostReturn(buffer)
w.buffer = buffer
} else {
buffer.Release()
}
if w.readErr == syscall.EAGAIN {
return false
}
if from != nil {
switch fromAddr := from.(type) {
case *syscall.SockaddrInet4:
w.readFrom = M.SocksaddrFrom(netip.AddrFrom4(fromAddr.Addr), uint16(fromAddr.Port))
case *syscall.SockaddrInet6:
w.readFrom = M.SocksaddrFrom(netip.AddrFrom16(fromAddr.Addr), uint16(fromAddr.Port)).Unwrap()
}
w.options.PostReturn(buffer)
w.buffer = buffer
switch fromAddr := from.(type) {
case *syscall.SockaddrInet4:
w.readFrom = M.SocksaddrFrom(netip.AddrFrom4(fromAddr.Addr), uint16(fromAddr.Port))
case *syscall.SockaddrInet6:
w.readFrom = M.SocksaddrFrom(netip.AddrFrom16(fromAddr.Addr), uint16(fromAddr.Port)).Unwrap()
}
return true
}

View file

@ -5,7 +5,6 @@ import (
"net/netip"
"os"
"syscall"
"unsafe"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
@ -15,27 +14,6 @@ import (
"golang.org/x/sys/windows"
)
//go:linkname modws2_32 golang.org/x/sys/windows.modws2_32
var modws2_32 *windows.LazyDLL
var procrecv = modws2_32.NewProc("recv")
//go:linkname errnoErr golang.org/x/sys/windows.errnoErr
func errnoErr(e syscall.Errno) error
func recv(s windows.Handle, buf []byte, flags int32) (n int32, err error) {
var _p0 *byte
if len(buf) > 0 {
_p0 = &buf[0]
}
r0, _, e1 := syscall.SyscallN(procrecv.Addr(), uintptr(s), uintptr(unsafe.Pointer(_p0)), uintptr(len(buf)), uintptr(flags))
n = int32(r0)
if n == -1 {
err = errnoErr(e1)
}
return
}
var _ N.ReadWaiter = (*syscallReadWaiter)(nil)
type syscallReadWaiter struct {
@ -142,16 +120,16 @@ func (w *syscallPacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions
var readN int
var from windows.Sockaddr
readN, from, w.readErr = windows.Recvfrom(windows.Handle(fd), buffer.FreeBytes(), 0)
//goland:noinspection GoDirectComparisonOfErrors
if w.readErr != nil {
buffer.Release()
return w.readErr != windows.WSAEWOULDBLOCK
}
if readN > 0 {
buffer.Truncate(readN)
w.options.PostReturn(buffer)
w.buffer = buffer
} else {
buffer.Release()
}
if w.readErr == windows.WSAEWOULDBLOCK {
return false
}
w.options.PostReturn(buffer)
w.buffer = buffer
if from != nil {
switch fromAddr := from.(type) {
case *windows.SockaddrInet4:

View file

@ -25,6 +25,45 @@ func ReadPacket(reader N.PacketReader, buffer *buf.Buffer) (n int, addr net.Addr
return
}
func ReadBufferSize(reader io.Reader, bufferSize int) (buffer *buf.Buffer, err error) {
readWaiter, isReadWaiter := CreateReadWaiter(reader)
if isReadWaiter {
readWaiter.InitializeReadWaiter(N.ReadWaitOptions{
MTU: bufferSize,
})
return readWaiter.WaitReadBuffer()
}
buffer = buf.NewSize(bufferSize)
if extendedReader, isExtendedReader := reader.(N.ExtendedReader); isExtendedReader {
err = extendedReader.ReadBuffer(buffer)
} else {
_, err = buffer.ReadOnceFrom(reader)
}
if err != nil {
buffer.Release()
buffer = nil
}
return
}
func ReadPacketSize(reader N.PacketReader, packetSize int) (buffer *buf.Buffer, destination M.Socksaddr, err error) {
readWaiter, isReadWaiter := CreatePacketReadWaiter(reader)
if isReadWaiter {
readWaiter.InitializeReadWaiter(N.ReadWaitOptions{
MTU: packetSize,
})
buffer, destination, err = readWaiter.WaitReadPacket()
return
}
buffer = buf.NewSize(packetSize)
destination, err = reader.ReadPacket(buffer)
if err != nil {
buffer.Release()
buffer = nil
}
return
}
func Write(writer io.Writer, data []byte) (n int, err error) {
if extendedWriter, isExtended := writer.(N.ExtendedWriter); isExtended {
return WriteBuffer(extendedWriter, buf.As(data))

View file

@ -17,13 +17,21 @@ type NATPacketConn interface {
func NewUnidirectionalNATPacketConn(conn N.NetPacketConn, origin M.Socksaddr, destination M.Socksaddr) NATPacketConn {
return &unidirectionalNATPacketConn{
NetPacketConn: conn,
origin: origin,
destination: destination,
origin: socksaddrWithoutPort(origin),
destination: socksaddrWithoutPort(destination),
}
}
func NewNATPacketConn(conn N.NetPacketConn, origin M.Socksaddr, destination M.Socksaddr) NATPacketConn {
return &bidirectionalNATPacketConn{
NetPacketConn: conn,
origin: socksaddrWithoutPort(origin),
destination: socksaddrWithoutPort(destination),
}
}
func NewDestinationNATPacketConn(conn N.NetPacketConn, origin M.Socksaddr, destination M.Socksaddr) NATPacketConn {
return &destinationNATPacketConn{
NetPacketConn: conn,
origin: origin,
destination: destination,
@ -37,15 +45,24 @@ type unidirectionalNATPacketConn struct {
}
func (c *unidirectionalNATPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
if M.SocksaddrFromNet(addr) == c.destination {
addr = c.origin.UDPAddr()
destination := M.SocksaddrFromNet(addr)
if socksaddrWithoutPort(destination) == c.destination {
destination = M.Socksaddr{
Addr: c.origin.Addr,
Fqdn: c.origin.Fqdn,
Port: destination.Port,
}
}
return c.NetPacketConn.WriteTo(p, addr)
return c.NetPacketConn.WriteTo(p, destination.UDPAddr())
}
func (c *unidirectionalNATPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
if destination == c.destination {
destination = c.origin
if socksaddrWithoutPort(destination) == c.destination {
destination = M.Socksaddr{
Addr: c.origin.Addr,
Fqdn: c.origin.Fqdn,
Port: destination.Port,
}
}
return c.NetPacketConn.WritePacket(buffer, destination)
}
@ -54,6 +71,10 @@ func (c *unidirectionalNATPacketConn) UpdateDestination(destinationAddress netip
c.destination = M.SocksaddrFrom(destinationAddress, c.destination.Port)
}
func (c *unidirectionalNATPacketConn) RemoteAddr() net.Addr {
return c.destination.UDPAddr()
}
func (c *unidirectionalNATPacketConn) Upstream() any {
return c.NetPacketConn
}
@ -66,30 +87,55 @@ type bidirectionalNATPacketConn struct {
func (c *bidirectionalNATPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
n, addr, err = c.NetPacketConn.ReadFrom(p)
if err == nil && M.SocksaddrFromNet(addr) == c.origin {
addr = c.destination.UDPAddr()
if err != nil {
return
}
destination := M.SocksaddrFromNet(addr)
if socksaddrWithoutPort(destination) == c.origin {
destination = M.Socksaddr{
Addr: c.destination.Addr,
Fqdn: c.destination.Fqdn,
Port: destination.Port,
}
}
addr = destination.UDPAddr()
return
}
func (c *bidirectionalNATPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
if M.SocksaddrFromNet(addr) == c.destination {
addr = c.origin.UDPAddr()
destination := M.SocksaddrFromNet(addr)
if socksaddrWithoutPort(destination) == c.destination {
destination = M.Socksaddr{
Addr: c.origin.Addr,
Fqdn: c.origin.Fqdn,
Port: destination.Port,
}
}
return c.NetPacketConn.WriteTo(p, addr)
return c.NetPacketConn.WriteTo(p, destination.UDPAddr())
}
func (c *bidirectionalNATPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
destination, err = c.NetPacketConn.ReadPacket(buffer)
if destination == c.origin {
destination = c.destination
if err != nil {
return
}
if socksaddrWithoutPort(destination) == c.origin {
destination = M.Socksaddr{
Addr: c.destination.Addr,
Fqdn: c.destination.Fqdn,
Port: destination.Port,
}
}
return
}
func (c *bidirectionalNATPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
if destination == c.destination {
destination = c.origin
if socksaddrWithoutPort(destination) == c.destination {
destination = M.Socksaddr{
Addr: c.origin.Addr,
Fqdn: c.origin.Fqdn,
Port: destination.Port,
}
}
return c.NetPacketConn.WritePacket(buffer, destination)
}
@ -101,3 +147,66 @@ func (c *bidirectionalNATPacketConn) UpdateDestination(destinationAddress netip.
func (c *bidirectionalNATPacketConn) Upstream() any {
return c.NetPacketConn
}
func (c *bidirectionalNATPacketConn) RemoteAddr() net.Addr {
return c.destination.UDPAddr()
}
type destinationNATPacketConn struct {
N.NetPacketConn
origin M.Socksaddr
destination M.Socksaddr
}
func (c *destinationNATPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
n, addr, err = c.NetPacketConn.ReadFrom(p)
if err != nil {
return
}
if M.SocksaddrFromNet(addr) == c.origin {
addr = c.destination.UDPAddr()
}
return
}
func (c *destinationNATPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
if M.SocksaddrFromNet(addr) == c.destination {
addr = c.origin.UDPAddr()
}
return c.NetPacketConn.WriteTo(p, addr)
}
func (c *destinationNATPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
destination, err = c.NetPacketConn.ReadPacket(buffer)
if err != nil {
return
}
if destination == c.origin {
destination = c.destination
}
return
}
func (c *destinationNATPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
if destination == c.destination {
destination = c.origin
}
return c.NetPacketConn.WritePacket(buffer, destination)
}
func (c *destinationNATPacketConn) UpdateDestination(destinationAddress netip.Addr) {
c.destination = M.SocksaddrFrom(destinationAddress, c.destination.Port)
}
func (c *destinationNATPacketConn) Upstream() any {
return c.NetPacketConn
}
func (c *destinationNATPacketConn) RemoteAddr() net.Addr {
return c.destination.UDPAddr()
}
func socksaddrWithoutPort(destination M.Socksaddr) M.Socksaddr {
destination.Port = 0
return destination
}

39
common/bufio/nat_wait.go Normal file
View file

@ -0,0 +1,39 @@
package bufio
import (
"github.com/sagernet/sing/common/buf"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
func (c *bidirectionalNATPacketConn) CreatePacketReadWaiter() (N.PacketReadWaiter, bool) {
waiter, created := CreatePacketReadWaiter(c.NetPacketConn)
if !created {
return nil, false
}
return &waitBidirectionalNATPacketConn{c, waiter}, true
}
type waitBidirectionalNATPacketConn struct {
*bidirectionalNATPacketConn
readWaiter N.PacketReadWaiter
}
func (c *waitBidirectionalNATPacketConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
return c.readWaiter.InitializeReadWaiter(options)
}
func (c *waitBidirectionalNATPacketConn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
buffer, destination, err = c.readWaiter.WaitReadPacket()
if err != nil {
return
}
if socksaddrWithoutPort(destination) == c.origin {
destination = M.Socksaddr{
Addr: c.destination.Addr,
Fqdn: c.destination.Fqdn,
Port: destination.Port,
}
}
return
}

View file

@ -36,7 +36,7 @@ func TCPPipe(t *testing.T) (net.Conn, net.Conn) {
clientConn, clientErr = net.Dial("tcp", listener.Addr().String())
return clientErr
})
err = group.Run()
err = group.Run(context.Background())
require.NoError(t, err)
listener.Close()
t.Cleanup(func() {

View file

@ -0,0 +1,5 @@
package bufio
//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go syscall_windows.go
//sys recv(s windows.Handle, buf []byte, flags int32) (n int32, err error) [failretval == -1] = ws2_32.recv

View file

@ -38,7 +38,6 @@ func (w *SyscallVectorisedWriter) WriteVectorised(buffers []*buf.Buffer) error {
var innerErr unix.Errno
err := w.rawConn.Write(func(fd uintptr) (done bool) {
//nolint:staticcheck
//goland:noinspection GoDeprecation
_, _, innerErr = unix.Syscall(unix.SYS_WRITEV, fd, uintptr(unsafe.Pointer(&iovecList[0])), uintptr(len(iovecList)))
return innerErr != unix.EAGAIN && innerErr != unix.EWOULDBLOCK
})

View file

@ -0,0 +1,57 @@
// Code generated by 'go generate'; DO NOT EDIT.
package bufio
import (
"syscall"
"unsafe"
"golang.org/x/sys/windows"
)
var _ unsafe.Pointer
// Do the interface allocations only once for common
// Errno values.
const (
errnoERROR_IO_PENDING = 997
)
var (
errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
errERROR_EINVAL error = syscall.EINVAL
)
// errnoErr returns common boxed Errno values, to prevent
// allocations at runtime.
func errnoErr(e syscall.Errno) error {
switch e {
case 0:
return errERROR_EINVAL
case errnoERROR_IO_PENDING:
return errERROR_IO_PENDING
}
// TODO: add more here, after collecting data on the common
// error values see on Windows. (perhaps when running
// all.bat?)
return e
}
var (
modws2_32 = windows.NewLazySystemDLL("ws2_32.dll")
procrecv = modws2_32.NewProc("recv")
)
func recv(s windows.Handle, buf []byte, flags int32) (n int32, err error) {
var _p0 *byte
if len(buf) > 0 {
_p0 = &buf[0]
}
r0, _, e1 := syscall.Syscall6(procrecv.Addr(), 4, uintptr(s), uintptr(unsafe.Pointer(_p0)), uintptr(len(buf)), uintptr(flags), 0, 0)
n = int32(r0)
if n == -1 {
err = errnoErr(e1)
}
return
}

View file

@ -41,9 +41,9 @@ func (i *Instance) Timeout() time.Duration {
return i.timeout
}
func (i *Instance) SetTimeout(timeout time.Duration) {
func (i *Instance) SetTimeout(timeout time.Duration) bool {
i.timeout = timeout
i.Update()
return i.Update()
}
func (i *Instance) wait() {

View file

@ -13,7 +13,7 @@ import (
type PacketConn interface {
N.PacketConn
Timeout() time.Duration
SetTimeout(timeout time.Duration)
SetTimeout(timeout time.Duration) bool
}
type TimerPacketConn struct {
@ -21,13 +21,15 @@ type TimerPacketConn struct {
instance *Instance
}
func NewPacketConn(ctx context.Context, conn N.PacketConn, timeout time.Duration) (context.Context, PacketConn) {
func NewPacketConn(ctx context.Context, conn N.PacketConn, timeout time.Duration) (context.Context, N.PacketConn) {
if timeoutConn, isTimeoutConn := common.Cast[PacketConn](conn); isTimeoutConn {
oldTimeout := timeoutConn.Timeout()
if timeout < oldTimeout {
timeoutConn.SetTimeout(timeout)
if oldTimeout > 0 && timeout >= oldTimeout {
return ctx, conn
}
if timeoutConn.SetTimeout(timeout) {
return ctx, conn
}
return ctx, timeoutConn
}
err := conn.SetReadDeadline(time.Time{})
if err == nil {
@ -58,8 +60,8 @@ func (c *TimerPacketConn) Timeout() time.Duration {
return c.instance.Timeout()
}
func (c *TimerPacketConn) SetTimeout(timeout time.Duration) {
c.instance.SetTimeout(timeout)
func (c *TimerPacketConn) SetTimeout(timeout time.Duration) bool {
return c.instance.SetTimeout(timeout)
}
func (c *TimerPacketConn) Close() error {

View file

@ -2,6 +2,7 @@ package canceler
import (
"context"
"net"
"time"
"github.com/sagernet/sing/common"
@ -31,7 +32,7 @@ func (c *TimeoutPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksa
for {
err = c.PacketConn.SetReadDeadline(time.Now().Add(c.timeout))
if err != nil {
return M.Socksaddr{}, err
return
}
destination, err = c.PacketConn.ReadPacket(buffer)
if err == nil {
@ -43,7 +44,7 @@ func (c *TimeoutPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksa
return
}
} else {
return M.Socksaddr{}, err
return
}
}
}
@ -60,12 +61,13 @@ func (c *TimeoutPacketConn) Timeout() time.Duration {
return c.timeout
}
func (c *TimeoutPacketConn) SetTimeout(timeout time.Duration) {
func (c *TimeoutPacketConn) SetTimeout(timeout time.Duration) bool {
c.timeout = timeout
c.PacketConn.SetReadDeadline(time.Now())
return c.PacketConn.SetReadDeadline(time.Now()) == nil
}
func (c *TimeoutPacketConn) Close() error {
c.cancel(net.ErrClosed)
return c.PacketConn.Close()
}

View file

@ -157,6 +157,18 @@ func IndexIndexed[T any](arr []T, block func(index int, it T) bool) int {
return -1
}
func Equal[S ~[]E, E comparable](s1, s2 S) bool {
if len(s1) != len(s2) {
return false
}
for i := range s1 {
if s1[i] != s2[i] {
return false
}
}
return true
}
//go:norace
func Dup[T any](obj T) T {
pointer := uintptr(unsafe.Pointer(&obj))
@ -268,6 +280,14 @@ func Reverse[T any](arr []T) []T {
return arr
}
func ReverseMap[K comparable, V comparable](m map[K]V) map[V]K {
ret := make(map[V]K, len(m))
for k, v := range m {
ret[v] = k
}
return ret
}
func Done(ctx context.Context) bool {
select {
case <-ctx.Done():
@ -362,22 +382,3 @@ func Close(closers ...any) error {
}
return retErr
}
type Starter interface {
Start() error
}
func Start(starters ...any) error {
for _, rawStarter := range starters {
if rawStarter == nil {
continue
}
if starter, isStarter := rawStarter.(Starter); isStarter {
err := starter.Start()
if err != nil {
return err
}
}
}
return nil
}

View file

@ -5,6 +5,7 @@ import (
"reflect"
)
// Deprecated: not used
func SelectContext(contextList []context.Context) (int, error) {
if len(contextList) == 1 {
<-contextList[0].Done()

View file

@ -9,15 +9,15 @@ import (
func bindToInterface(conn syscall.RawConn, network string, address string, finder InterfaceFinder, interfaceName string, interfaceIndex int, preferInterfaceName bool) error {
return Raw(conn, func(fd uintptr) error {
var err error
if interfaceIndex == -1 {
if finder == nil {
return os.ErrInvalid
}
interfaceIndex, err = finder.InterfaceIndexByName(interfaceName)
iif, err := finder.ByName(interfaceName)
if err != nil {
return err
}
interfaceIndex = iif.Index
}
switch network {
case "tcp6", "udp6":

View file

@ -1,30 +1,59 @@
package control
import "net"
import (
"net"
"net/netip"
"unsafe"
"github.com/sagernet/sing/common"
M "github.com/sagernet/sing/common/metadata"
)
type InterfaceFinder interface {
InterfaceIndexByName(name string) (int, error)
InterfaceNameByIndex(index int) (string, error)
Update() error
Interfaces() []Interface
ByName(name string) (*Interface, error)
ByIndex(index int) (*Interface, error)
ByAddr(addr netip.Addr) (*Interface, error)
}
func DefaultInterfaceFinder() InterfaceFinder {
return (*netInterfaceFinder)(nil)
type Interface struct {
Index int
MTU int
Name string
HardwareAddr net.HardwareAddr
Flags net.Flags
Addresses []netip.Prefix
}
type netInterfaceFinder struct{}
func (i Interface) Equals(other Interface) bool {
return i.Index == other.Index &&
i.MTU == other.MTU &&
i.Name == other.Name &&
common.Equal(i.HardwareAddr, other.HardwareAddr) &&
i.Flags == other.Flags &&
common.Equal(i.Addresses, other.Addresses)
}
func (w *netInterfaceFinder) InterfaceIndexByName(name string) (int, error) {
netInterface, err := net.InterfaceByName(name)
func (i Interface) NetInterface() net.Interface {
return *(*net.Interface)(unsafe.Pointer(&i))
}
func InterfaceFromNet(iif net.Interface) (Interface, error) {
ifAddrs, err := iif.Addrs()
if err != nil {
return 0, err
return Interface{}, err
}
return netInterface.Index, nil
return InterfaceFromNetAddrs(iif, common.Map(ifAddrs, M.PrefixFromNet)), nil
}
func (w *netInterfaceFinder) InterfaceNameByIndex(index int) (string, error) {
netInterface, err := net.InterfaceByIndex(index)
if err != nil {
return "", err
func InterfaceFromNetAddrs(iif net.Interface, addresses []netip.Prefix) Interface {
return Interface{
Index: iif.Index,
MTU: iif.MTU,
Name: iif.Name,
HardwareAddr: iif.HardwareAddr,
Flags: iif.Flags,
Addresses: addresses,
}
return netInterface.Name, nil
}

View file

@ -0,0 +1,89 @@
package control
import (
"net"
"net/netip"
E "github.com/sagernet/sing/common/exceptions"
)
var _ InterfaceFinder = (*DefaultInterfaceFinder)(nil)
type DefaultInterfaceFinder struct {
interfaces []Interface
}
func NewDefaultInterfaceFinder() *DefaultInterfaceFinder {
return &DefaultInterfaceFinder{}
}
func (f *DefaultInterfaceFinder) Update() error {
netIfs, err := net.Interfaces()
if err != nil {
return err
}
interfaces := make([]Interface, 0, len(netIfs))
for _, netIf := range netIfs {
var iif Interface
iif, err = InterfaceFromNet(netIf)
if err != nil {
return err
}
interfaces = append(interfaces, iif)
}
f.interfaces = interfaces
return nil
}
func (f *DefaultInterfaceFinder) UpdateInterfaces(interfaces []Interface) {
f.interfaces = interfaces
}
func (f *DefaultInterfaceFinder) Interfaces() []Interface {
return f.interfaces
}
func (f *DefaultInterfaceFinder) ByName(name string) (*Interface, error) {
for _, netInterface := range f.interfaces {
if netInterface.Name == name {
return &netInterface, nil
}
}
_, err := net.InterfaceByName(name)
if err == nil {
err = f.Update()
if err != nil {
return nil, err
}
return f.ByName(name)
}
return nil, &net.OpError{Op: "route", Net: "ip+net", Source: nil, Addr: &net.IPAddr{IP: nil}, Err: E.New("no such network interface")}
}
func (f *DefaultInterfaceFinder) ByIndex(index int) (*Interface, error) {
for _, netInterface := range f.interfaces {
if netInterface.Index == index {
return &netInterface, nil
}
}
_, err := net.InterfaceByIndex(index)
if err == nil {
err = f.Update()
if err != nil {
return nil, err
}
return f.ByIndex(index)
}
return nil, &net.OpError{Op: "route", Net: "ip+net", Source: nil, Addr: &net.IPAddr{IP: nil}, Err: E.New("no such network interface")}
}
func (f *DefaultInterfaceFinder) ByAddr(addr netip.Addr) (*Interface, error) {
for _, netInterface := range f.interfaces {
for _, prefix := range netInterface.Addresses {
if prefix.Contains(addr) {
return &netInterface, nil
}
}
}
return nil, &net.OpError{Op: "route", Net: "ip+net", Source: nil, Addr: &net.IPAddr{IP: addr.AsSlice()}, Err: E.New("no such network interface")}
}

View file

@ -14,19 +14,18 @@ var ifIndexDisabled atomic.Bool
func bindToInterface(conn syscall.RawConn, network string, address string, finder InterfaceFinder, interfaceName string, interfaceIndex int, preferInterfaceName bool) error {
return Raw(conn, func(fd uintptr) error {
if interfaceIndex != -1 {
return unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_BINDTOIFINDEX, interfaceIndex)
}
if interfaceName == "" {
return os.ErrInvalid
}
if !preferInterfaceName && finder != nil && !ifIndexDisabled.Load() {
var err error
interfaceIndex, err = finder.InterfaceIndexByName(interfaceName)
if err != nil {
return err
if !preferInterfaceName && !ifIndexDisabled.Load() {
if interfaceIndex == -1 {
if interfaceName == "" {
return os.ErrInvalid
}
iif, err := finder.ByName(interfaceName)
if err != nil {
return err
}
interfaceIndex = iif.Index
}
err = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_BINDTOIFINDEX, interfaceIndex)
err := unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_BINDTOIFINDEX, interfaceIndex)
if err == nil {
return nil
} else if E.IsMulti(err, unix.ENOPROTOOPT, unix.EINVAL) {
@ -35,6 +34,9 @@ func bindToInterface(conn syscall.RawConn, network string, address string, finde
return err
}
}
if interfaceName == "" {
return os.ErrInvalid
}
return unix.BindToDevice(int(fd), interfaceName)
})
}

View file

@ -11,19 +11,19 @@ import (
func bindToInterface(conn syscall.RawConn, network string, address string, finder InterfaceFinder, interfaceName string, interfaceIndex int, preferInterfaceName bool) error {
return Raw(conn, func(fd uintptr) error {
var err error
if interfaceIndex == -1 {
if finder == nil {
return os.ErrInvalid
}
interfaceIndex, err = finder.InterfaceIndexByName(interfaceName)
iif, err := finder.ByName(interfaceName)
if err != nil {
return err
}
interfaceIndex = iif.Index
}
handle := syscall.Handle(fd)
if M.ParseSocksaddr(address).AddrString() == "" {
err = bind4(handle, interfaceIndex)
err := bind4(handle, interfaceIndex)
if err != nil {
return err
}

View file

@ -4,19 +4,26 @@ import (
"os"
"syscall"
N "github.com/sagernet/sing/common/network"
"golang.org/x/sys/unix"
)
func DisableUDPFragment() Func {
return func(network, address string, conn syscall.RawConn) error {
if N.NetworkName(network) != N.NetworkUDP {
return nil
}
return Raw(conn, func(fd uintptr) error {
switch network {
case "udp4":
if err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_DONTFRAG, 1); err != nil {
if network == "udp" || network == "udp4" {
err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_DONTFRAG, 1)
if err != nil {
return os.NewSyscallError("SETSOCKOPT IP_DONTFRAG", err)
}
case "udp6":
if err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_DONTFRAG, 1); err != nil {
}
if network == "udp" || network == "udp6" {
err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_DONTFRAG, 1)
if err != nil {
return os.NewSyscallError("SETSOCKOPT IPV6_DONTFRAG", err)
}
}

View file

@ -11,17 +11,19 @@ import (
func DisableUDPFragment() Func {
return func(network, address string, conn syscall.RawConn) error {
switch N.NetworkName(network) {
case N.NetworkUDP:
default:
if N.NetworkName(network) != N.NetworkUDP {
return nil
}
return Raw(conn, func(fd uintptr) error {
if err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_MTU_DISCOVER, unix.IP_PMTUDISC_DO); err != nil {
return os.NewSyscallError("SETSOCKOPT IP_MTU_DISCOVER IP_PMTUDISC_DO", err)
if network == "udp" || network == "udp4" {
err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_MTU_DISCOVER, unix.IP_PMTUDISC_DO)
if err != nil {
return os.NewSyscallError("SETSOCKOPT IP_MTU_DISCOVER IP_PMTUDISC_DO", err)
}
}
if network == "udp6" {
if err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_MTU_DISCOVER, unix.IP_PMTUDISC_DO); err != nil {
if network == "udp" || network == "udp6" {
err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_MTU_DISCOVER, unix.IP_PMTUDISC_DO)
if err != nil {
return os.NewSyscallError("SETSOCKOPT IPV6_MTU_DISCOVER IP_PMTUDISC_DO", err)
}
}

View file

@ -25,17 +25,19 @@ const (
func DisableUDPFragment() Func {
return func(network, address string, conn syscall.RawConn) error {
switch N.NetworkName(network) {
case N.NetworkUDP:
default:
if N.NetworkName(network) != N.NetworkUDP {
return nil
}
return Raw(conn, func(fd uintptr) error {
if err := windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IP, IP_MTU_DISCOVER, IP_PMTUDISC_DO); err != nil {
return os.NewSyscallError("SETSOCKOPT IP_MTU_DISCOVER IP_PMTUDISC_DO", err)
if network == "udp" || network == "udp4" {
err := windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IP, IP_MTU_DISCOVER, IP_PMTUDISC_DO)
if err != nil {
return os.NewSyscallError("SETSOCKOPT IP_MTU_DISCOVER IP_PMTUDISC_DO", err)
}
}
if network == "udp6" {
if err := windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IPV6, IPV6_MTU_DISCOVER, IP_PMTUDISC_DO); err != nil {
if network == "udp" || network == "udp6" {
err := windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IPV6, IPV6_MTU_DISCOVER, IP_PMTUDISC_DO)
if err != nil {
return os.NewSyscallError("SETSOCKOPT IPV6_MTU_DISCOVER IP_PMTUDISC_DO", err)
}
}

View file

@ -3,6 +3,7 @@ package control
import (
"syscall"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
)
@ -30,6 +31,14 @@ func Conn(conn syscall.Conn, block func(fd uintptr) error) error {
return Raw(rawConn, block)
}
func Conn0[T any](conn syscall.Conn, block func(fd uintptr) (T, error)) (T, error) {
rawConn, err := conn.SyscallConn()
if err != nil {
return common.DefaultValue[T](), err
}
return Raw0[T](rawConn, block)
}
func Raw(rawConn syscall.RawConn, block func(fd uintptr) error) error {
var innerErr error
err := rawConn.Control(func(fd uintptr) {
@ -37,3 +46,14 @@ func Raw(rawConn syscall.RawConn, block func(fd uintptr) error) error {
})
return E.Errors(innerErr, err)
}
func Raw0[T any](rawConn syscall.RawConn, block func(fd uintptr) (T, error)) (T, error) {
var (
value T
innerErr error
)
err := rawConn.Control(func(fd uintptr) {
value, innerErr = block(fd)
})
return value, E.Errors(innerErr, err)
}

View file

@ -4,10 +4,10 @@ import (
"syscall"
)
func RoutingMark(mark int) Func {
func RoutingMark(mark uint32) Func {
return func(network, address string, conn syscall.RawConn) error {
return Raw(conn, func(fd uintptr) error {
return syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_MARK, mark)
return syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_MARK, int(mark))
})
}
}

View file

@ -2,6 +2,6 @@
package control
func RoutingMark(mark int) Func {
func RoutingMark(mark uint32) Func {
return nil
}

View file

@ -0,0 +1,58 @@
package control
import (
"encoding/binary"
"net"
"net/netip"
"syscall"
"unsafe"
M "github.com/sagernet/sing/common/metadata"
"golang.org/x/sys/unix"
)
const (
PF_OUT = 0x2
DIOCNATLOOK = 0xc0544417
)
func GetOriginalDestination(conn net.Conn) (netip.AddrPort, error) {
pfFd, err := syscall.Open("/dev/pf", 0, syscall.O_RDONLY)
if err != nil {
return netip.AddrPort{}, err
}
defer syscall.Close(pfFd)
nl := struct {
saddr, daddr, rsaddr, rdaddr [16]byte
sxport, dxport, rsxport, rdxport [4]byte
af, proto, protoVariant, direction uint8
}{
af: syscall.AF_INET,
proto: syscall.IPPROTO_TCP,
direction: PF_OUT,
}
localAddr := M.SocksaddrFromNet(conn.LocalAddr()).Unwrap()
removeAddr := M.SocksaddrFromNet(conn.RemoteAddr()).Unwrap()
if localAddr.IsIPv4() {
copy(nl.saddr[:net.IPv4len], removeAddr.Addr.AsSlice())
copy(nl.daddr[:net.IPv4len], localAddr.Addr.AsSlice())
nl.af = syscall.AF_INET
} else {
copy(nl.saddr[:], removeAddr.Addr.AsSlice())
copy(nl.daddr[:], localAddr.Addr.AsSlice())
nl.af = syscall.AF_INET6
}
binary.BigEndian.PutUint16(nl.sxport[:], removeAddr.Port)
binary.BigEndian.PutUint16(nl.dxport[:], localAddr.Port)
if _, _, errno := unix.Syscall(syscall.SYS_IOCTL, uintptr(pfFd), DIOCNATLOOK, uintptr(unsafe.Pointer(&nl))); errno != 0 {
return netip.AddrPort{}, errno
}
var address netip.Addr
if nl.af == unix.AF_INET {
address = M.AddrFromIP(nl.rdaddr[:net.IPv4len])
} else {
address = netip.AddrFrom16(nl.rdaddr)
}
return netip.AddrPortFrom(address, binary.BigEndian.Uint16(nl.rdxport[:])), nil
}

View file

@ -0,0 +1,38 @@
package control
import (
"encoding/binary"
"net"
"net/netip"
"os"
"syscall"
"github.com/sagernet/sing/common"
M "github.com/sagernet/sing/common/metadata"
"golang.org/x/sys/unix"
)
func GetOriginalDestination(conn net.Conn) (netip.AddrPort, error) {
syscallConn, loaded := common.Cast[syscall.Conn](conn)
if !loaded {
return netip.AddrPort{}, os.ErrInvalid
}
return Conn0[netip.AddrPort](syscallConn, func(fd uintptr) (netip.AddrPort, error) {
if M.SocksaddrFromNet(conn.RemoteAddr()).Unwrap().IsIPv4() {
raw, err := unix.GetsockoptIPv6Mreq(int(fd), unix.IPPROTO_IP, unix.SO_ORIGINAL_DST)
if err != nil {
return netip.AddrPort{}, err
}
return netip.AddrPortFrom(M.AddrFromIP(raw.Multiaddr[4:8]), uint16(raw.Multiaddr[2])<<8+uint16(raw.Multiaddr[3])), nil
} else {
raw, err := unix.GetsockoptIPv6MTUInfo(int(fd), unix.IPPROTO_IPV6, unix.SO_ORIGINAL_DST)
if err != nil {
return netip.AddrPort{}, err
}
var port [2]byte
binary.BigEndian.PutUint16(port[:], raw.Addr.Port)
return netip.AddrPortFrom(M.AddrFromIP(raw.Addr.Addr[:]), binary.LittleEndian.Uint16(port[:])), nil
}
})
}

View file

@ -0,0 +1,13 @@
//go:build !linux && !darwin
package control
import (
"net"
"net/netip"
"os"
)
func GetOriginalDestination(conn net.Conn) (netip.AddrPort, error) {
return netip.AddrPort{}, os.ErrInvalid
}

View file

@ -0,0 +1,30 @@
package control
import (
"syscall"
"time"
_ "unsafe"
E "github.com/sagernet/sing/common/exceptions"
N "github.com/sagernet/sing/common/network"
"golang.org/x/sys/unix"
)
func SetKeepAlivePeriod(idle time.Duration, interval time.Duration) Func {
return func(network, address string, conn syscall.RawConn) error {
if N.NetworkName(network) != N.NetworkTCP {
return nil
}
return Raw(conn, func(fd uintptr) error {
return E.Errors(
unix.SetsockoptInt(int(fd), unix.IPPROTO_TCP, unix.TCP_KEEPIDLE, int(roundDurationUp(idle, time.Second))),
unix.SetsockoptInt(int(fd), unix.IPPROTO_TCP, unix.TCP_KEEPINTVL, int(roundDurationUp(interval, time.Second))),
)
})
}
}
func roundDurationUp(d time.Duration, to time.Duration) time.Duration {
return (d + to - 1) / to
}

View file

@ -0,0 +1,11 @@
//go:build !linux
package control
import (
"time"
)
func SetKeepAlivePeriod(idle time.Duration, interval time.Duration) Func {
return nil
}

View file

@ -0,0 +1,56 @@
package control
import (
"encoding/binary"
"net/netip"
"syscall"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
"golang.org/x/sys/unix"
)
func TProxy(fd uintptr, family int) error {
err := syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1)
if err == nil {
err = syscall.SetsockoptInt(int(fd), syscall.SOL_IP, syscall.IP_TRANSPARENT, 1)
}
if err == nil && family == unix.AF_INET6 {
err = syscall.SetsockoptInt(int(fd), syscall.SOL_IPV6, unix.IPV6_TRANSPARENT, 1)
}
if err == nil {
err = syscall.SetsockoptInt(int(fd), syscall.SOL_IP, syscall.IP_RECVORIGDSTADDR, 1)
}
if err == nil && family == unix.AF_INET6 {
err = syscall.SetsockoptInt(int(fd), syscall.SOL_IPV6, unix.IPV6_RECVORIGDSTADDR, 1)
}
return err
}
func TProxyWriteBack() Func {
return func(network, address string, conn syscall.RawConn) error {
return Raw(conn, func(fd uintptr) error {
if M.ParseSocksaddr(address).Addr.Is6() {
return syscall.SetsockoptInt(int(fd), syscall.SOL_IPV6, unix.IPV6_TRANSPARENT, 1)
} else {
return syscall.SetsockoptInt(int(fd), syscall.SOL_IP, syscall.IP_TRANSPARENT, 1)
}
})
}
}
func GetOriginalDestinationFromOOB(oob []byte) (netip.AddrPort, error) {
controlMessages, err := unix.ParseSocketControlMessage(oob)
if err != nil {
return netip.AddrPort{}, err
}
for _, message := range controlMessages {
if message.Header.Level == unix.SOL_IP && message.Header.Type == unix.IP_RECVORIGDSTADDR {
return netip.AddrPortFrom(M.AddrFromIP(message.Data[4:8]), binary.BigEndian.Uint16(message.Data[2:4])), nil
} else if message.Header.Level == unix.SOL_IPV6 && message.Header.Type == unix.IPV6_RECVORIGDSTADDR {
return netip.AddrPortFrom(M.AddrFromIP(message.Data[8:24]), binary.BigEndian.Uint16(message.Data[2:4])), nil
}
}
return netip.AddrPort{}, E.New("not found")
}

View file

@ -0,0 +1,20 @@
//go:build !linux
package control
import (
"net/netip"
"os"
)
func TProxy(fd uintptr, isIPv6 bool) error {
return os.ErrInvalid
}
func TProxyWriteBack() Func {
return nil
}
func GetOriginalDestinationFromOOB(oob []byte) (netip.AddrPort, error) {
return netip.AddrPort{}, os.ErrInvalid
}

View file

@ -0,0 +1,67 @@
package domain_test
import (
"sort"
"testing"
"github.com/sagernet/sing/common/domain"
"github.com/stretchr/testify/require"
)
func TestAdGuardMatcher(t *testing.T) {
t.Parallel()
ruleLines := []string{
"||example.org^",
"|example.com^",
"example.net^",
"||example.edu",
"||example.edu.tw^",
"|example.gov",
"example.arpa",
}
matcher := domain.NewAdGuardMatcher(ruleLines)
require.NotNil(t, matcher)
matchDomain := []string{
"example.org",
"www.example.org",
"example.com",
"example.net",
"isexample.net",
"www.example.net",
"example.edu",
"example.edu.cn",
"example.edu.tw",
"www.example.edu",
"www.example.edu.cn",
"example.gov",
"example.gov.cn",
"example.arpa",
"www.example.arpa",
"isexample.arpa",
"example.arpa.cn",
"www.example.arpa.cn",
"isexample.arpa.cn",
}
notMatchDomain := []string{
"example.org.cn",
"notexample.org",
"example.com.cn",
"www.example.com.cn",
"example.net.cn",
"notexample.edu",
"notexample.edu.cn",
"www.example.gov",
"notexample.gov",
}
for _, domain := range matchDomain {
require.True(t, matcher.Match(domain), domain)
}
for _, domain := range notMatchDomain {
require.False(t, matcher.Match(domain), domain)
}
dLines := matcher.Dump()
sort.Strings(ruleLines)
sort.Strings(dLines)
require.Equal(t, ruleLines, dLines)
}

View file

@ -0,0 +1,172 @@
package domain
import (
"bytes"
"sort"
"strings"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/varbin"
)
const (
anyLabel = '*'
suffixLabel = '\b'
)
type AdGuardMatcher struct {
set *succinctSet
}
func NewAdGuardMatcher(ruleLines []string) *AdGuardMatcher {
ruleList := make([]string, 0, len(ruleLines))
for _, ruleLine := range ruleLines {
var (
isSuffix bool // ||
hasStart bool // |
hasEnd bool // ^
)
if strings.HasPrefix(ruleLine, "||") {
ruleLine = ruleLine[2:]
isSuffix = true
} else if strings.HasPrefix(ruleLine, "|") {
ruleLine = ruleLine[1:]
hasStart = true
}
if strings.HasSuffix(ruleLine, "^") {
ruleLine = ruleLine[:len(ruleLine)-1]
hasEnd = true
}
if isSuffix {
ruleLine = string(rootLabel) + ruleLine
} else if !hasStart {
ruleLine = string(prefixLabel) + ruleLine
}
if !hasEnd {
if strings.HasSuffix(ruleLine, ".") {
ruleLine = ruleLine[:len(ruleLine)-1]
}
ruleLine += string(suffixLabel)
}
ruleList = append(ruleList, reverseDomain(ruleLine))
}
ruleList = common.Uniq(ruleList)
sort.Strings(ruleList)
return &AdGuardMatcher{newSuccinctSet(ruleList)}
}
func ReadAdGuardMatcher(reader varbin.Reader) (*AdGuardMatcher, error) {
set, err := readSuccinctSet(reader)
if err != nil {
return nil, err
}
return &AdGuardMatcher{set}, nil
}
func (m *AdGuardMatcher) Write(writer varbin.Writer) error {
return m.set.Write(writer)
}
func (m *AdGuardMatcher) Match(domain string) bool {
key := reverseDomain(domain)
if m.has([]byte(key), 0, 0) {
return true
}
for {
if m.has([]byte(string(suffixLabel)+key), 0, 0) {
return true
}
idx := strings.IndexByte(key, '.')
if idx == -1 {
return false
}
key = key[idx+1:]
}
}
func (m *AdGuardMatcher) has(key []byte, nodeId, bmIdx int) bool {
for i := 0; i < len(key); i++ {
currentChar := key[i]
for ; ; bmIdx++ {
if getBit(m.set.labelBitmap, bmIdx) != 0 {
return false
}
nextLabel := m.set.labels[bmIdx-nodeId]
if nextLabel == prefixLabel {
return true
}
if nextLabel == rootLabel {
nextNodeId := countZeros(m.set.labelBitmap, m.set.ranks, bmIdx+1)
hasNext := getBit(m.set.leaves, nextNodeId) != 0
if currentChar == '.' && hasNext {
return true
}
}
if nextLabel == currentChar {
break
}
if nextLabel == anyLabel {
idx := bytes.IndexRune(key[i:], '.')
nextNodeId := countZeros(m.set.labelBitmap, m.set.ranks, bmIdx+1)
if idx == -1 {
if getBit(m.set.leaves, nextNodeId) != 0 {
return true
}
idx = 0
}
nextBmIdx := selectIthOne(m.set.labelBitmap, m.set.ranks, m.set.selects, nextNodeId-1) + 1
if m.has(key[i+idx:], nextNodeId, nextBmIdx) {
return true
}
}
}
nodeId = countZeros(m.set.labelBitmap, m.set.ranks, bmIdx+1)
bmIdx = selectIthOne(m.set.labelBitmap, m.set.ranks, m.set.selects, nodeId-1) + 1
}
if getBit(m.set.leaves, nodeId) != 0 {
return true
}
for ; ; bmIdx++ {
if getBit(m.set.labelBitmap, bmIdx) != 0 {
return false
}
nextLabel := m.set.labels[bmIdx-nodeId]
if nextLabel == prefixLabel || nextLabel == rootLabel {
return true
}
}
}
func (m *AdGuardMatcher) Dump() (ruleLines []string) {
for _, key := range m.set.keys() {
key = reverseDomain(key)
var (
isSuffix bool
hasStart bool
hasEnd bool
)
if key[0] == prefixLabel {
key = key[1:]
} else if key[0] == rootLabel {
key = key[1:]
isSuffix = true
} else {
hasStart = true
}
if key[len(key)-1] == suffixLabel {
key = key[:len(key)-1]
} else {
hasEnd = true
}
if isSuffix {
key = "||" + key
} else if hasStart {
key = "|" + key
}
if hasEnd {
key += "^"
}
ruleLines = append(ruleLines, key)
}
return
}

View file

@ -1,27 +1,41 @@
package domain
import (
"encoding/binary"
"io"
"sort"
"unicode/utf8"
"github.com/sagernet/sing/common/rw"
"github.com/sagernet/sing/common/varbin"
)
const (
prefixLabel = '\r'
rootLabel = '\n'
)
type Matcher struct {
set *succinctSet
}
func NewMatcher(domains []string, domainSuffix []string) *Matcher {
domainList := make([]string, 0, len(domains)+len(domainSuffix))
func NewMatcher(domains []string, domainSuffix []string, generateLegacy bool) *Matcher {
domainList := make([]string, 0, len(domains)+2*len(domainSuffix))
seen := make(map[string]bool, len(domainList))
for _, domain := range domainSuffix {
if seen[domain] {
continue
}
seen[domain] = true
domainList = append(domainList, reverseDomainSuffix(domain))
if domain[0] == '.' {
domainList = append(domainList, reverseDomain(string(prefixLabel)+domain))
} else if generateLegacy {
domainList = append(domainList, reverseDomain(domain))
suffixDomain := "." + domain
if !seen[suffixDomain] {
seen[suffixDomain] = true
domainList = append(domainList, reverseDomain(string(prefixLabel)+suffixDomain))
}
} else {
domainList = append(domainList, reverseDomain(string(rootLabel)+domain))
}
}
for _, domain := range domains {
if seen[domain] {
@ -34,82 +48,91 @@ func NewMatcher(domains []string, domainSuffix []string) *Matcher {
return &Matcher{newSuccinctSet(domainList)}
}
func ReadMatcher(reader io.Reader) (*Matcher, error) {
var version uint8
err := binary.Read(reader, binary.BigEndian, &version)
func ReadMatcher(reader varbin.Reader) (*Matcher, error) {
set, err := readSuccinctSet(reader)
if err != nil {
return nil, err
}
leavesLength, err := rw.ReadUVariant(reader)
if err != nil {
return nil, err
}
leaves := make([]uint64, leavesLength)
err = binary.Read(reader, binary.BigEndian, leaves)
if err != nil {
return nil, err
}
labelBitmapLength, err := rw.ReadUVariant(reader)
if err != nil {
return nil, err
}
labelBitmap := make([]uint64, labelBitmapLength)
err = binary.Read(reader, binary.BigEndian, labelBitmap)
if err != nil {
return nil, err
}
labelsLength, err := rw.ReadUVariant(reader)
if err != nil {
return nil, err
}
labels := make([]byte, labelsLength)
_, err = io.ReadFull(reader, labels)
if err != nil {
return nil, err
}
set := &succinctSet{
leaves: leaves,
labelBitmap: labelBitmap,
labels: labels,
}
set.init()
return &Matcher{set}, nil
}
func (m *Matcher) Match(domain string) bool {
return m.set.Has(reverseDomain(domain))
func (m *Matcher) Write(writer varbin.Writer) error {
return m.set.Write(writer)
}
func (m *Matcher) Write(writer io.Writer) error {
err := binary.Write(writer, binary.BigEndian, byte(1))
if err != nil {
return err
func (m *Matcher) Match(domain string) bool {
return m.has(reverseDomain(domain))
}
func (m *Matcher) has(key string) bool {
var nodeId, bmIdx int
for i := 0; i < len(key); i++ {
currentChar := key[i]
for ; ; bmIdx++ {
if getBit(m.set.labelBitmap, bmIdx) != 0 {
return false
}
nextLabel := m.set.labels[bmIdx-nodeId]
if nextLabel == prefixLabel {
return true
}
if nextLabel == rootLabel {
nextNodeId := countZeros(m.set.labelBitmap, m.set.ranks, bmIdx+1)
hasNext := getBit(m.set.leaves, nextNodeId) != 0
if currentChar == '.' && hasNext {
return true
}
}
if nextLabel == currentChar {
break
}
}
nodeId = countZeros(m.set.labelBitmap, m.set.ranks, bmIdx+1)
bmIdx = selectIthOne(m.set.labelBitmap, m.set.ranks, m.set.selects, nodeId-1) + 1
}
err = rw.WriteUVariant(writer, uint64(len(m.set.leaves)))
if err != nil {
return err
if getBit(m.set.leaves, nodeId) != 0 {
return true
}
err = binary.Write(writer, binary.BigEndian, m.set.leaves)
if err != nil {
return err
for ; ; bmIdx++ {
if getBit(m.set.labelBitmap, bmIdx) != 0 {
return false
}
nextLabel := m.set.labels[bmIdx-nodeId]
if nextLabel == prefixLabel || nextLabel == rootLabel {
return true
}
}
err = rw.WriteUVariant(writer, uint64(len(m.set.labelBitmap)))
if err != nil {
return err
}
func (m *Matcher) Dump() (domainList []string, prefixList []string) {
domainMap := make(map[string]bool)
prefixMap := make(map[string]bool)
for _, key := range m.set.keys() {
key = reverseDomain(key)
if key[0] == prefixLabel {
prefixMap[key[1:]] = true
} else if key[0] == rootLabel {
prefixList = append(prefixList, key[1:])
} else {
domainMap[key] = true
}
}
err = binary.Write(writer, binary.BigEndian, m.set.labelBitmap)
if err != nil {
return err
for rawPrefix := range prefixMap {
if rawPrefix[0] == '.' {
if rootDomain := rawPrefix[1:]; domainMap[rootDomain] {
delete(domainMap, rootDomain)
prefixList = append(prefixList, rootDomain)
continue
}
}
prefixList = append(prefixList, rawPrefix)
}
err = rw.WriteUVariant(writer, uint64(len(m.set.labels)))
if err != nil {
return err
for domain := range domainMap {
domainList = append(domainList, domain)
}
_, err = writer.Write(m.set.labels)
if err != nil {
return err
}
return nil
sort.Strings(domainList)
sort.Strings(prefixList)
return domainList, prefixList
}
func reverseDomain(domain string) string {
@ -122,15 +145,3 @@ func reverseDomain(domain string) string {
}
return string(b)
}
func reverseDomainSuffix(domain string) string {
l := len(domain)
b := make([]byte, l+1)
for i := 0; i < l; {
r, n := utf8.DecodeRuneInString(domain[i:])
i += n
utf8.EncodeRune(b[l-i:], r)
}
b[l] = prefixLabel
return string(b)
}

View file

@ -0,0 +1,80 @@
package domain_test
import (
"encoding/json"
"net/http"
"sort"
"testing"
"github.com/sagernet/sing/common/domain"
"github.com/stretchr/testify/require"
)
func TestMatcher(t *testing.T) {
t.Parallel()
testDomain := []string{"example.com", "example.org"}
testDomainSuffix := []string{".com.cn", ".org.cn", "sagernet.org"}
matcher := domain.NewMatcher(testDomain, testDomainSuffix, false)
require.NotNil(t, matcher)
require.True(t, matcher.Match("example.com"))
require.True(t, matcher.Match("example.org"))
require.False(t, matcher.Match("example.cn"))
require.True(t, matcher.Match("example.com.cn"))
require.True(t, matcher.Match("example.org.cn"))
require.False(t, matcher.Match("com.cn"))
require.False(t, matcher.Match("org.cn"))
require.True(t, matcher.Match("sagernet.org"))
require.True(t, matcher.Match("sing-box.sagernet.org"))
dDomain, dDomainSuffix := matcher.Dump()
require.Equal(t, testDomain, dDomain)
require.Equal(t, testDomainSuffix, dDomainSuffix)
}
func TestMatcherLegacy(t *testing.T) {
t.Parallel()
testDomain := []string{"example.com", "example.org"}
testDomainSuffix := []string{".com.cn", ".org.cn", "sagernet.org"}
matcher := domain.NewMatcher(testDomain, testDomainSuffix, true)
require.NotNil(t, matcher)
require.True(t, matcher.Match("example.com"))
require.True(t, matcher.Match("example.org"))
require.False(t, matcher.Match("example.cn"))
require.True(t, matcher.Match("example.com.cn"))
require.True(t, matcher.Match("example.org.cn"))
require.False(t, matcher.Match("com.cn"))
require.False(t, matcher.Match("org.cn"))
require.True(t, matcher.Match("sagernet.org"))
require.True(t, matcher.Match("sing-box.sagernet.org"))
dDomain, dDomainSuffix := matcher.Dump()
require.Equal(t, testDomain, dDomain)
require.Equal(t, testDomainSuffix, dDomainSuffix)
}
type simpleRuleSet struct {
Rules []struct {
Domain []string `json:"domain"`
DomainSuffix []string `json:"domain_suffix"`
}
}
func TestDumpLarge(t *testing.T) {
t.Parallel()
response, err := http.Get("https://raw.githubusercontent.com/MetaCubeX/meta-rules-dat/sing/geo/geosite/cn.json")
require.NoError(t, err)
defer response.Body.Close()
var ruleSet simpleRuleSet
err = json.NewDecoder(response.Body).Decode(&ruleSet)
require.NoError(t, err)
domainList := ruleSet.Rules[0].Domain
domainSuffixList := ruleSet.Rules[0].DomainSuffix
require.Len(t, ruleSet.Rules, 1)
require.True(t, len(domainList)+len(domainSuffixList) > 0)
sort.Strings(domainList)
sort.Strings(domainSuffixList)
matcher := domain.NewMatcher(domainList, domainSuffixList, false)
require.NotNil(t, matcher)
dDomain, dDomainSuffix := matcher.Dump()
require.Equal(t, domainList, dDomain)
require.Equal(t, domainSuffixList, dDomainSuffix)
}

View file

@ -1,10 +1,11 @@
package domain
import (
"encoding/binary"
"math/bits"
)
const prefixLabel = '\r'
"github.com/sagernet/sing/common/varbin"
)
// mod from https://github.com/openacid/succinct
@ -42,36 +43,61 @@ func newSuccinctSet(keys []string) *succinctSet {
return ss
}
func (ss *succinctSet) Has(key string) bool {
var nodeId, bmIdx int
for i := 0; i < len(key); i++ {
currentChar := key[i]
func (ss *succinctSet) keys() []string {
var result []string
var currentKey []byte
var bmIdx, nodeId int
var traverse func(int, int)
traverse = func(nodeId, bmIdx int) {
if getBit(ss.leaves, nodeId) != 0 {
result = append(result, string(currentKey))
}
for ; ; bmIdx++ {
if getBit(ss.labelBitmap, bmIdx) != 0 {
return false
return
}
nextLabel := ss.labels[bmIdx-nodeId]
if nextLabel == prefixLabel {
return true
}
if nextLabel == currentChar {
break
}
}
nodeId = countZeros(ss.labelBitmap, ss.ranks, bmIdx+1)
bmIdx = selectIthOne(ss.labelBitmap, ss.ranks, ss.selects, nodeId-1) + 1
}
if getBit(ss.leaves, nodeId) != 0 {
return true
}
for ; ; bmIdx++ {
if getBit(ss.labelBitmap, bmIdx) != 0 {
return false
}
if ss.labels[bmIdx-nodeId] == prefixLabel {
return true
currentKey = append(currentKey, nextLabel)
nextNodeId := countZeros(ss.labelBitmap, ss.ranks, bmIdx+1)
nextBmIdx := selectIthOne(ss.labelBitmap, ss.ranks, ss.selects, nextNodeId-1) + 1
traverse(nextNodeId, nextBmIdx)
currentKey = currentKey[:len(currentKey)-1]
}
}
traverse(nodeId, bmIdx)
return result
}
type succinctSetData struct {
Reserved uint8
Leaves []uint64
LabelBitmap []uint64
Labels []byte
}
func readSuccinctSet(reader varbin.Reader) (*succinctSet, error) {
matcher, err := varbin.ReadValue[succinctSetData](reader, binary.BigEndian)
if err != nil {
return nil, err
}
set := &succinctSet{
leaves: matcher.Leaves,
labelBitmap: matcher.LabelBitmap,
labels: matcher.Labels,
}
set.init()
return set, nil
}
func (ss *succinctSet) Write(writer varbin.Writer) error {
return varbin.Write(writer, binary.BigEndian, succinctSetData{
Leaves: ss.leaves,
LabelBitmap: ss.labelBitmap,
Labels: ss.labels,
})
}
func setBit(bm *[]uint64, i int, v int) {

View file

@ -12,3 +12,16 @@ func (e *causeError) Error() string {
func (e *causeError) Unwrap() error {
return e.cause
}
type causeError1 struct {
error
cause error
}
func (e *causeError1) Error() string {
return e.error.Error() + ": " + e.cause.Error()
}
func (e *causeError1) Unwrap() []error {
return []error{e.error, e.cause}
}

View file

@ -12,6 +12,7 @@ import (
F "github.com/sagernet/sing/common/format"
)
// Deprecated: wtf is this?
type Handler interface {
NewError(ctx context.Context, err error)
}
@ -31,6 +32,13 @@ func Cause(cause error, message ...any) error {
return &causeError{F.ToString(message...), cause}
}
func Cause1(err error, cause error) error {
if cause == nil {
panic("cause on an nil error")
}
return &causeError1{err, cause}
}
func Extend(cause error, message ...any) error {
if cause == nil {
panic("extend on an nil error")
@ -39,11 +47,11 @@ func Extend(cause error, message ...any) error {
}
func IsClosedOrCanceled(err error) bool {
return IsMulti(err, io.EOF, net.ErrClosed, io.ErrClosedPipe, os.ErrClosed, syscall.EPIPE, syscall.ECONNRESET, context.Canceled, context.DeadlineExceeded)
return IsClosed(err) || IsCanceled(err) || IsTimeout(err)
}
func IsClosed(err error) bool {
return IsMulti(err, io.EOF, net.ErrClosed, io.ErrClosedPipe, os.ErrClosed, syscall.EPIPE, syscall.ECONNRESET)
return IsMulti(err, io.EOF, net.ErrClosed, io.ErrClosedPipe, os.ErrClosed, syscall.EPIPE, syscall.ECONNRESET, syscall.ENOTCONN)
}
func IsCanceled(err error) bool {

View file

@ -1,24 +1,14 @@
package exceptions
import "github.com/sagernet/sing/common"
import (
"errors"
type HasInnerError interface {
Unwrap() error
}
"github.com/sagernet/sing/common"
)
// Deprecated: Use errors.Unwrap instead.
func Unwrap(err error) error {
for {
inner, ok := err.(HasInnerError)
if !ok {
break
}
innerErr := inner.Unwrap()
if innerErr == nil {
break
}
err = innerErr
}
return err
return errors.Unwrap(err)
}
func Cast[T any](err error) (T, bool) {

View file

@ -37,10 +37,13 @@ func Errors(errors ...error) error {
}
func Expand(err error) []error {
if multiErr, isMultiErr := err.(MultiError); isMultiErr {
return ExpandAll(multiErr.Unwrap())
if err == nil {
return nil
} else if multiErr, isMultiErr := err.(MultiError); isMultiErr {
return ExpandAll(common.FilterNotNil(multiErr.Unwrap()))
} else {
return []error{err}
}
return []error{err}
}
func ExpandAll(errs []error) []error {
@ -60,12 +63,5 @@ func IsMulti(err error, targetList ...error) bool {
return true
}
}
err = Unwrap(err)
multiErr, isMulti := err.(MultiError)
if !isMulti {
return false
}
return common.All(multiErr.Unwrap(), func(it error) bool {
return IsMulti(it, targetList...)
})
return false
}

View file

@ -1,17 +1,21 @@
package exceptions
import "net"
import (
"errors"
"net"
)
type TimeoutError interface {
Timeout() bool
}
func IsTimeout(err error) bool {
if netErr, isNetErr := err.(net.Error); isNetErr {
//goland:noinspection GoDeprecation
var netErr net.Error
if errors.As(err, &netErr) {
//nolint:staticcheck
return netErr.Temporary() && netErr.Timeout()
} else if timeoutErr, isTimeout := Cast[TimeoutError](err); isTimeout {
}
if timeoutErr, isTimeout := Cast[TimeoutError](err); isTimeout {
return timeoutErr.Timeout()
}
return false

View file

@ -3,12 +3,25 @@ package badjson
import (
"bytes"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/json"
)
type JSONArray []any
func (a JSONArray) IsEmpty() bool {
if len(a) == 0 {
return true
}
return common.All(a, func(it any) bool {
if valueInterface, valueMaybeEmpty := it.(isEmpty); valueMaybeEmpty && valueInterface.IsEmpty() {
return true
}
return false
})
}
func (a JSONArray) MarshalJSON() ([]byte, error) {
return json.Marshal([]any(a))
}

View file

@ -0,0 +1,5 @@
package badjson
type isEmpty interface {
IsEmpty() bool
}

View file

@ -2,13 +2,14 @@ package badjson
import (
"bytes"
"context"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/json"
)
func Decode(content []byte) (any, error) {
decoder := json.NewDecoder(bytes.NewReader(content))
func Decode(ctx context.Context, content []byte) (any, error) {
decoder := json.NewDecoderContext(ctx, bytes.NewReader(content))
return decodeJSON(decoder)
}

View file

@ -1,6 +1,7 @@
package badjson
import (
"context"
"os"
"reflect"
@ -9,75 +10,75 @@ import (
"github.com/sagernet/sing/common/json"
)
func Omitempty[T any](value T) (T, error) {
objectContent, err := json.Marshal(value)
func Omitempty[T any](ctx context.Context, value T) (T, error) {
objectContent, err := json.MarshalContext(ctx, value)
if err != nil {
return common.DefaultValue[T](), E.Cause(err, "marshal object")
}
rawNewObject, err := Decode(objectContent)
rawNewObject, err := Decode(ctx, objectContent)
if err != nil {
return common.DefaultValue[T](), err
}
newObjectContent, err := json.Marshal(rawNewObject)
newObjectContent, err := json.MarshalContext(ctx, rawNewObject)
if err != nil {
return common.DefaultValue[T](), E.Cause(err, "marshal new object")
}
var newObject T
err = json.Unmarshal(newObjectContent, &newObject)
err = json.UnmarshalContext(ctx, newObjectContent, &newObject)
if err != nil {
return common.DefaultValue[T](), E.Cause(err, "unmarshal new object")
}
return newObject, nil
}
func Merge[T any](source T, destination T) (T, error) {
rawSource, err := json.Marshal(source)
func Merge[T any](ctx context.Context, source T, destination T, disableAppend bool) (T, error) {
rawSource, err := json.MarshalContext(ctx, source)
if err != nil {
return common.DefaultValue[T](), E.Cause(err, "marshal source")
}
rawDestination, err := json.Marshal(destination)
rawDestination, err := json.MarshalContext(ctx, destination)
if err != nil {
return common.DefaultValue[T](), E.Cause(err, "marshal destination")
}
return MergeFrom[T](rawSource, rawDestination)
return MergeFrom[T](ctx, rawSource, rawDestination, disableAppend)
}
func MergeFromSource[T any](rawSource json.RawMessage, destination T) (T, error) {
func MergeFromSource[T any](ctx context.Context, rawSource json.RawMessage, destination T, disableAppend bool) (T, error) {
if rawSource == nil {
return destination, nil
}
rawDestination, err := json.Marshal(destination)
rawDestination, err := json.MarshalContext(ctx, destination)
if err != nil {
return common.DefaultValue[T](), E.Cause(err, "marshal destination")
}
return MergeFrom[T](rawSource, rawDestination)
return MergeFrom[T](ctx, rawSource, rawDestination, disableAppend)
}
func MergeFromDestination[T any](source T, rawDestination json.RawMessage) (T, error) {
func MergeFromDestination[T any](ctx context.Context, source T, rawDestination json.RawMessage, disableAppend bool) (T, error) {
if rawDestination == nil {
return source, nil
}
rawSource, err := json.Marshal(source)
rawSource, err := json.MarshalContext(ctx, source)
if err != nil {
return common.DefaultValue[T](), E.Cause(err, "marshal source")
}
return MergeFrom[T](rawSource, rawDestination)
return MergeFrom[T](ctx, rawSource, rawDestination, disableAppend)
}
func MergeFrom[T any](rawSource json.RawMessage, rawDestination json.RawMessage) (T, error) {
rawMerged, err := MergeJSON(rawSource, rawDestination)
func MergeFrom[T any](ctx context.Context, rawSource json.RawMessage, rawDestination json.RawMessage, disableAppend bool) (T, error) {
rawMerged, err := MergeJSON(ctx, rawSource, rawDestination, disableAppend)
if err != nil {
return common.DefaultValue[T](), E.Cause(err, "merge options")
}
var merged T
err = json.Unmarshal(rawMerged, &merged)
err = json.UnmarshalContext(ctx, rawMerged, &merged)
if err != nil {
return common.DefaultValue[T](), E.Cause(err, "unmarshal merged options")
}
return merged, nil
}
func MergeJSON(rawSource json.RawMessage, rawDestination json.RawMessage) (json.RawMessage, error) {
func MergeJSON(ctx context.Context, rawSource json.RawMessage, rawDestination json.RawMessage, disableAppend bool) (json.RawMessage, error) {
if rawSource == nil && rawDestination == nil {
return nil, os.ErrInvalid
} else if rawSource == nil {
@ -85,29 +86,36 @@ func MergeJSON(rawSource json.RawMessage, rawDestination json.RawMessage) (json.
} else if rawDestination == nil {
return rawSource, nil
}
source, err := Decode(rawSource)
source, err := Decode(ctx, rawSource)
if err != nil {
return nil, E.Cause(err, "decode source")
}
destination, err := Decode(rawDestination)
destination, err := Decode(ctx, rawDestination)
if err != nil {
return nil, E.Cause(err, "decode destination")
}
merged, err := mergeJSON(source, destination)
if source == nil {
return json.MarshalContext(ctx, destination)
} else if destination == nil {
return json.Marshal(source)
}
merged, err := mergeJSON(source, destination, disableAppend)
if err != nil {
return nil, err
}
return json.Marshal(merged)
return json.MarshalContext(ctx, merged)
}
func mergeJSON(anySource any, anyDestination any) (any, error) {
func mergeJSON(anySource any, anyDestination any, disableAppend bool) (any, error) {
switch destination := anyDestination.(type) {
case JSONArray:
switch source := anySource.(type) {
case JSONArray:
destination = append(destination, source...)
default:
destination = append(destination, source)
if !disableAppend {
switch source := anySource.(type) {
case JSONArray:
destination = append(destination, source...)
default:
destination = append(destination, source)
}
}
return destination, nil
case *JSONObject:
@ -117,7 +125,7 @@ func mergeJSON(anySource any, anyDestination any) (any, error) {
oldValue, loaded := destination.Get(entry.Key)
if loaded {
var err error
entry.Value, err = mergeJSON(entry.Value, oldValue)
entry.Value, err = mergeJSON(entry.Value, oldValue, disableAppend)
if err != nil {
return nil, E.Cause(err, "merge object item ", entry.Key)
}
@ -125,7 +133,7 @@ func mergeJSON(anySource any, anyDestination any) (any, error) {
destination.Put(entry.Key, entry.Value)
}
default:
return nil, E.New("cannot merge json object into ", reflect.TypeOf(destination))
return nil, E.New("cannot merge json object into ", reflect.TypeOf(source))
}
return destination, nil
default:

View file

@ -0,0 +1,68 @@
package badjson
import (
"context"
"reflect"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/json"
cJSON "github.com/sagernet/sing/common/json/internal/contextjson"
)
func MarshallObjects(objects ...any) ([]byte, error) {
return MarshallObjectsContext(context.Background(), objects...)
}
func MarshallObjectsContext(ctx context.Context, objects ...any) ([]byte, error) {
if len(objects) == 1 {
return json.Marshal(objects[0])
}
var content JSONObject
for _, object := range objects {
objectMap, err := newJSONObject(ctx, object)
if err != nil {
return nil, err
}
content.PutAll(objectMap)
}
return content.MarshalJSONContext(ctx)
}
func UnmarshallExcluded(inputContent []byte, parentObject any, object any) error {
return UnmarshallExcludedContext(context.Background(), inputContent, parentObject, object)
}
func UnmarshallExcludedContext(ctx context.Context, inputContent []byte, parentObject any, object any) error {
var content JSONObject
err := content.UnmarshalJSONContext(ctx, inputContent)
if err != nil {
return err
}
for _, key := range cJSON.ObjectKeys(reflect.TypeOf(parentObject)) {
content.Remove(key)
}
if object == nil {
if content.IsEmpty() {
return nil
}
return E.New("unexpected key: ", content.Keys()[0])
}
inputContent, err = content.MarshalJSONContext(ctx)
if err != nil {
return err
}
return json.UnmarshalContextDisallowUnknownFields(ctx, inputContent, object)
}
func newJSONObject(ctx context.Context, object any) (*JSONObject, error) {
inputContent, err := json.MarshalContext(ctx, object)
if err != nil {
return nil, err
}
var content JSONObject
err = content.UnmarshalJSONContext(ctx, inputContent)
if err != nil {
return nil, err
}
return &content, nil
}

View file

@ -2,6 +2,7 @@ package badjson
import (
"bytes"
"context"
"strings"
"github.com/sagernet/sing/common"
@ -15,24 +16,40 @@ type JSONObject struct {
linkedhashmap.Map[string, any]
}
func (m JSONObject) MarshalJSON() ([]byte, error) {
func (m *JSONObject) IsEmpty() bool {
if m.Size() == 0 {
return true
}
return common.All(m.Entries(), func(it collections.MapEntry[string, any]) bool {
if valueInterface, valueMaybeEmpty := it.Value.(isEmpty); valueMaybeEmpty && valueInterface.IsEmpty() {
return true
}
return false
})
}
func (m *JSONObject) MarshalJSON() ([]byte, error) {
return m.MarshalJSONContext(context.Background())
}
func (m *JSONObject) MarshalJSONContext(ctx context.Context) ([]byte, error) {
buffer := new(bytes.Buffer)
buffer.WriteString("{")
items := common.Filter(m.Entries(), func(it collections.MapEntry[string, any]) bool {
if valueObject, valueIsJSONObject := it.Value.(*JSONObject); valueIsJSONObject && valueObject.IsEmpty() {
if valueInterface, valueMaybeEmpty := it.Value.(isEmpty); valueMaybeEmpty && valueInterface.IsEmpty() {
return false
}
return true
})
iLen := len(items)
for i, entry := range items {
keyContent, err := json.Marshal(entry.Key)
keyContent, err := json.MarshalContext(ctx, entry.Key)
if err != nil {
return nil, err
}
buffer.WriteString(strings.TrimSpace(string(keyContent)))
buffer.WriteString(": ")
valueContent, err := json.Marshal(entry.Value)
valueContent, err := json.MarshalContext(ctx, entry.Value)
if err != nil {
return nil, err
}
@ -46,7 +63,11 @@ func (m JSONObject) MarshalJSON() ([]byte, error) {
}
func (m *JSONObject) UnmarshalJSON(content []byte) error {
decoder := json.NewDecoder(bytes.NewReader(content))
return m.UnmarshalJSONContext(context.Background(), content)
}
func (m *JSONObject) UnmarshalJSONContext(ctx context.Context, content []byte) error {
decoder := json.NewDecoderContext(ctx, bytes.NewReader(content))
m.Clear()
objectStart, err := decoder.Token()
if err != nil {

View file

@ -2,6 +2,7 @@ package badjson
import (
"bytes"
"context"
"strings"
E "github.com/sagernet/sing/common/exceptions"
@ -14,18 +15,22 @@ type TypedMap[K comparable, V any] struct {
}
func (m TypedMap[K, V]) MarshalJSON() ([]byte, error) {
return m.MarshalJSONContext(context.Background())
}
func (m TypedMap[K, V]) MarshalJSONContext(ctx context.Context) ([]byte, error) {
buffer := new(bytes.Buffer)
buffer.WriteString("{")
items := m.Entries()
iLen := len(items)
for i, entry := range items {
keyContent, err := json.Marshal(entry.Key)
keyContent, err := json.MarshalContext(ctx, entry.Key)
if err != nil {
return nil, err
}
buffer.WriteString(strings.TrimSpace(string(keyContent)))
buffer.WriteString(": ")
valueContent, err := json.Marshal(entry.Value)
valueContent, err := json.MarshalContext(ctx, entry.Value)
if err != nil {
return nil, err
}
@ -39,7 +44,11 @@ func (m TypedMap[K, V]) MarshalJSON() ([]byte, error) {
}
func (m *TypedMap[K, V]) UnmarshalJSON(content []byte) error {
decoder := json.NewDecoder(bytes.NewReader(content))
return m.UnmarshalJSONContext(context.Background(), content)
}
func (m *TypedMap[K, V]) UnmarshalJSONContext(ctx context.Context, content []byte) error {
decoder := json.NewDecoderContext(ctx, bytes.NewReader(content))
m.Clear()
objectStart, err := decoder.Token()
if err != nil {
@ -47,7 +56,7 @@ func (m *TypedMap[K, V]) UnmarshalJSON(content []byte) error {
} else if objectStart != json.Delim('{') {
return E.New("expected json object start, but starts with ", objectStart)
}
err = m.decodeJSON(decoder)
err = m.decodeJSON(ctx, decoder)
if err != nil {
return E.Cause(err, "decode json object content")
}
@ -60,18 +69,18 @@ func (m *TypedMap[K, V]) UnmarshalJSON(content []byte) error {
return nil
}
func (m *TypedMap[K, V]) decodeJSON(decoder *json.Decoder) error {
func (m *TypedMap[K, V]) decodeJSON(ctx context.Context, decoder *json.Decoder) error {
for decoder.More() {
keyToken, err := decoder.Token()
if err != nil {
return err
}
keyContent, err := json.Marshal(keyToken)
keyContent, err := json.MarshalContext(ctx, keyToken)
if err != nil {
return err
}
var entryKey K
err = json.Unmarshal(keyContent, &entryKey)
err = json.UnmarshalContext(ctx, keyContent, &entryKey)
if err != nil {
return err
}

View file

@ -0,0 +1,32 @@
package badoption
import (
"time"
"github.com/sagernet/sing/common/json"
"github.com/sagernet/sing/common/json/badoption/internal/my_time"
)
type Duration time.Duration
func (d Duration) Build() time.Duration {
return time.Duration(d)
}
func (d Duration) MarshalJSON() ([]byte, error) {
return json.Marshal((time.Duration)(d).String())
}
func (d *Duration) UnmarshalJSON(bytes []byte) error {
var value string
err := json.Unmarshal(bytes, &value)
if err != nil {
return err
}
duration, err := my_time.ParseDuration(value)
if err != nil {
return err
}
*d = Duration(duration)
return nil
}

View file

@ -0,0 +1,15 @@
package badoption
import "net/http"
type HTTPHeader map[string]Listable[string]
func (h HTTPHeader) Build() http.Header {
header := make(http.Header)
for name, values := range h {
for _, value := range values {
header.Add(name, value)
}
}
return header
}

View file

@ -0,0 +1,226 @@
package my_time
import (
"errors"
"time"
)
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
const durationDay = 24 * time.Hour
var unitMap = map[string]uint64{
"ns": uint64(time.Nanosecond),
"us": uint64(time.Microsecond),
"µs": uint64(time.Microsecond), // U+00B5 = micro symbol
"μs": uint64(time.Microsecond), // U+03BC = Greek letter mu
"ms": uint64(time.Millisecond),
"s": uint64(time.Second),
"m": uint64(time.Minute),
"h": uint64(time.Hour),
"d": uint64(durationDay),
}
// ParseDuration parses a duration string.
// A duration string is a possibly signed sequence of
// decimal numbers, each with optional fraction and a unit suffix,
// such as "300ms", "-1.5h" or "2h45m".
// Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h".
func ParseDuration(s string) (time.Duration, error) {
// [-+]?([0-9]*(\.[0-9]*)?[a-z]+)+
orig := s
var d uint64
neg := false
// Consume [-+]?
if s != "" {
c := s[0]
if c == '-' || c == '+' {
neg = c == '-'
s = s[1:]
}
}
// Special case: if all that is left is "0", this is zero.
if s == "0" {
return 0, nil
}
if s == "" {
return 0, errors.New("time: invalid duration " + quote(orig))
}
for s != "" {
var (
v, f uint64 // integers before, after decimal point
scale float64 = 1 // value = v + f/scale
)
var err error
// The next character must be [0-9.]
if !(s[0] == '.' || '0' <= s[0] && s[0] <= '9') {
return 0, errors.New("time: invalid duration " + quote(orig))
}
// Consume [0-9]*
pl := len(s)
v, s, err = leadingInt(s)
if err != nil {
return 0, errors.New("time: invalid duration " + quote(orig))
}
pre := pl != len(s) // whether we consumed anything before a period
// Consume (\.[0-9]*)?
post := false
if s != "" && s[0] == '.' {
s = s[1:]
pl := len(s)
f, scale, s = leadingFraction(s)
post = pl != len(s)
}
if !pre && !post {
// no digits (e.g. ".s" or "-.s")
return 0, errors.New("time: invalid duration " + quote(orig))
}
// Consume unit.
i := 0
for ; i < len(s); i++ {
c := s[i]
if c == '.' || '0' <= c && c <= '9' {
break
}
}
if i == 0 {
return 0, errors.New("time: missing unit in duration " + quote(orig))
}
u := s[:i]
s = s[i:]
unit, ok := unitMap[u]
if !ok {
return 0, errors.New("time: unknown unit " + quote(u) + " in duration " + quote(orig))
}
if v > 1<<63/unit {
// overflow
return 0, errors.New("time: invalid duration " + quote(orig))
}
v *= unit
if f > 0 {
// float64 is needed to be nanosecond accurate for fractions of hours.
// v >= 0 && (f*unit/scale) <= 3.6e+12 (ns/h, h is the largest unit)
v += uint64(float64(f) * (float64(unit) / scale))
if v > 1<<63 {
// overflow
return 0, errors.New("time: invalid duration " + quote(orig))
}
}
d += v
if d > 1<<63 {
return 0, errors.New("time: invalid duration " + quote(orig))
}
}
if neg {
return -time.Duration(d), nil
}
if d > 1<<63-1 {
return 0, errors.New("time: invalid duration " + quote(orig))
}
return time.Duration(d), nil
}
var errLeadingInt = errors.New("time: bad [0-9]*") // never printed
// leadingInt consumes the leading [0-9]* from s.
func leadingInt[bytes []byte | string](s bytes) (x uint64, rem bytes, err error) {
i := 0
for ; i < len(s); i++ {
c := s[i]
if c < '0' || c > '9' {
break
}
if x > 1<<63/10 {
// overflow
return 0, rem, errLeadingInt
}
x = x*10 + uint64(c) - '0'
if x > 1<<63 {
// overflow
return 0, rem, errLeadingInt
}
}
return x, s[i:], nil
}
// leadingFraction consumes the leading [0-9]* from s.
// It is used only for fractions, so does not return an error on overflow,
// it just stops accumulating precision.
func leadingFraction(s string) (x uint64, scale float64, rem string) {
i := 0
scale = 1
overflow := false
for ; i < len(s); i++ {
c := s[i]
if c < '0' || c > '9' {
break
}
if overflow {
continue
}
if x > (1<<63-1)/10 {
// It's possible for overflow to give a positive number, so take care.
overflow = true
continue
}
y := x*10 + uint64(c) - '0'
if y > 1<<63 {
overflow = true
continue
}
x = y
scale *= 10
}
return x, scale, s[i:]
}
// These are borrowed from unicode/utf8 and strconv and replicate behavior in
// that package, since we can't take a dependency on either.
const (
lowerhex = "0123456789abcdef"
runeSelf = 0x80
runeError = '\uFFFD'
)
func quote(s string) string {
buf := make([]byte, 1, len(s)+2) // slice will be at least len(s) + quotes
buf[0] = '"'
for i, c := range s {
if c >= runeSelf || c < ' ' {
// This means you are asking us to parse a time.Duration or
// time.Location with unprintable or non-ASCII characters in it.
// We don't expect to hit this case very often. We could try to
// reproduce strconv.Quote's behavior with full fidelity but
// given how rarely we expect to hit these edge cases, speed and
// conciseness are better.
var width int
if c == runeError {
width = 1
if i+2 < len(s) && s[i:i+3] == string(runeError) {
width = 3
}
} else {
width = len(string(c))
}
for j := 0; j < width; j++ {
buf = append(buf, `\x`...)
buf = append(buf, lowerhex[s[i+j]>>4])
buf = append(buf, lowerhex[s[i+j]&0xF])
}
} else {
if c == '"' || c == '\\' {
buf = append(buf, '\\')
}
buf = append(buf, string(c)...)
}
}
buf = append(buf, '"')
return string(buf)
}

View file

@ -0,0 +1,35 @@
package badoption
import (
"context"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/json"
)
type Listable[T any] []T
func (l Listable[T]) MarshalJSONContext(ctx context.Context) ([]byte, error) {
arrayList := []T(l)
if len(arrayList) == 1 {
return json.Marshal(arrayList[0])
}
return json.MarshalContext(ctx, arrayList)
}
func (l *Listable[T]) UnmarshalJSONContext(ctx context.Context, content []byte) error {
if string(content) == "null" {
return nil
}
var singleItem T
err := json.UnmarshalContextDisallowUnknownFields(ctx, content, &singleItem)
if err == nil {
*l = []T{singleItem}
return nil
}
newErr := json.UnmarshalContextDisallowUnknownFields(ctx, content, (*[]T)(l))
if newErr == nil {
return nil
}
return E.Errors(err, newErr)
}

View file

@ -0,0 +1,98 @@
package badoption
import (
"net/netip"
"github.com/sagernet/sing/common/json"
)
type Addr netip.Addr
func (a *Addr) Build(defaultAddr netip.Addr) netip.Addr {
if a == nil {
return defaultAddr
}
return netip.Addr(*a)
}
func (a *Addr) MarshalJSON() ([]byte, error) {
return json.Marshal(netip.Addr(*a).String())
}
func (a *Addr) UnmarshalJSON(content []byte) error {
var value string
err := json.Unmarshal(content, &value)
if err != nil {
return err
}
addr, err := netip.ParseAddr(value)
if err != nil {
return err
}
*a = Addr(addr)
return nil
}
type Prefix netip.Prefix
func (p *Prefix) Build(defaultPrefix netip.Prefix) netip.Prefix {
if p == nil {
return defaultPrefix
}
return netip.Prefix(*p)
}
func (p *Prefix) MarshalJSON() ([]byte, error) {
return json.Marshal(netip.Prefix(*p).String())
}
func (p *Prefix) UnmarshalJSON(content []byte) error {
var value string
err := json.Unmarshal(content, &value)
if err != nil {
return err
}
prefix, err := netip.ParsePrefix(value)
if err != nil {
return err
}
*p = Prefix(prefix)
return nil
}
type Prefixable netip.Prefix
func (p *Prefixable) Build(defaultPrefix netip.Prefix) netip.Prefix {
if p == nil {
return defaultPrefix
}
return netip.Prefix(*p)
}
func (p *Prefixable) MarshalJSON() ([]byte, error) {
prefix := netip.Prefix(*p)
if prefix.Bits() == prefix.Addr().BitLen() {
return json.Marshal(prefix.Addr().String())
} else {
return json.Marshal(prefix.String())
}
}
func (p *Prefixable) UnmarshalJSON(content []byte) error {
var value string
err := json.Unmarshal(content, &value)
if err != nil {
return err
}
prefix, prefixErr := netip.ParsePrefix(value)
if prefixErr == nil {
*p = Prefixable(prefix)
return nil
}
addr, addrErr := netip.ParseAddr(value)
if addrErr == nil {
*p = Prefixable(netip.PrefixFrom(addr, addr.BitLen()))
return nil
}
return prefixErr
}

View file

@ -0,0 +1,31 @@
package badoption
import (
"regexp"
"github.com/sagernet/sing/common/json"
)
type Regexp regexp.Regexp
func (r *Regexp) Build() *regexp.Regexp {
return (*regexp.Regexp)(r)
}
func (r *Regexp) MarshalJSON() ([]byte, error) {
return json.Marshal((*regexp.Regexp)(r).String())
}
func (r *Regexp) UnmarshalJSON(content []byte) error {
var stringValue string
err := json.Unmarshal(content, &stringValue)
if err != nil {
return err
}
regex, err := regexp.Compile(stringValue)
if err != nil {
return err
}
*r = Regexp(*regex)
return nil
}

View file

@ -1,4 +1,4 @@
//go:build go1.21 && !without_contextjson
//go:build go1.20 && !without_contextjson
package json

View file

@ -1,23 +0,0 @@
//go:build !go1.21 && go1.20 && !without_contextjson
package json
import (
"github.com/sagernet/sing/common/json/internal/contextjson_120"
)
var (
Marshal = json.Marshal
Unmarshal = json.Unmarshal
NewEncoder = json.NewEncoder
NewDecoder = json.NewDecoder
)
type (
Encoder = json.Encoder
Decoder = json.Decoder
Token = json.Token
Delim = json.Delim
SyntaxError = json.SyntaxError
RawMessage = json.RawMessage
)

View file

@ -0,0 +1,23 @@
package json
import (
"context"
"github.com/sagernet/sing/common/json/internal/contextjson"
)
var (
MarshalContext = json.MarshalContext
UnmarshalContext = json.UnmarshalContext
NewEncoderContext = json.NewEncoderContext
NewDecoderContext = json.NewDecoderContext
UnmarshalContextDisallowUnknownFields = json.UnmarshalContextDisallowUnknownFields
)
type ContextMarshaler interface {
MarshalJSONContext(ctx context.Context) ([]byte, error)
}
type ContextUnmarshaler interface {
UnmarshalJSONContext(ctx context.Context, content []byte) error
}

View file

@ -0,0 +1,11 @@
package json
import "context"
type ContextMarshaler interface {
MarshalJSONContext(ctx context.Context) ([]byte, error)
}
type ContextUnmarshaler interface {
UnmarshalJSONContext(ctx context.Context, content []byte) error
}

View file

@ -0,0 +1,43 @@
package json_test
import (
"context"
"testing"
"github.com/sagernet/sing/common/json/internal/contextjson"
"github.com/stretchr/testify/require"
)
type myStruct struct {
value string
}
func (m *myStruct) MarshalJSONContext(ctx context.Context) ([]byte, error) {
return json.Marshal(ctx.Value("key").(string))
}
func (m *myStruct) UnmarshalJSONContext(ctx context.Context, content []byte) error {
m.value = ctx.Value("key").(string)
return nil
}
//nolint:staticcheck
func TestMarshalContext(t *testing.T) {
t.Parallel()
ctx := context.WithValue(context.Background(), "key", "value")
var s myStruct
b, err := json.MarshalContext(ctx, &s)
require.NoError(t, err)
require.Equal(t, []byte(`"value"`), b)
}
//nolint:staticcheck
func TestUnmarshalContext(t *testing.T) {
t.Parallel()
ctx := context.WithValue(context.Background(), "key", "value")
var s myStruct
err := json.UnmarshalContext(ctx, []byte(`{}`), &s)
require.NoError(t, err)
require.Equal(t, "value", s.value)
}

View file

@ -8,6 +8,7 @@
package json
import (
"context"
"encoding"
"encoding/base64"
"fmt"
@ -95,10 +96,15 @@ import (
// Instead, they are replaced by the Unicode replacement
// character U+FFFD.
func Unmarshal(data []byte, v any) error {
return UnmarshalContext(context.Background(), data, v)
}
func UnmarshalContext(ctx context.Context, data []byte, v any) error {
// Check for well-formedness.
// Avoids filling out half a data structure
// before discovering a JSON syntax error.
var d decodeState
d.ctx = ctx
err := checkValid(data, &d.scan)
if err != nil {
return err
@ -209,6 +215,7 @@ type errorContext struct {
// decodeState represents the state while decoding a JSON value.
type decodeState struct {
ctx context.Context
data []byte
off int // next read offset in data
opcode int // last read result
@ -428,7 +435,7 @@ func (d *decodeState) valueQuoted() any {
// If it encounters an Unmarshaler, indirect stops and returns that.
// If decodingNull is true, indirect stops at the first settable pointer so it
// can be set to nil.
func indirect(v reflect.Value, decodingNull bool) (Unmarshaler, encoding.TextUnmarshaler, reflect.Value) {
func indirect(v reflect.Value, decodingNull bool) (Unmarshaler, ContextUnmarshaler, encoding.TextUnmarshaler, reflect.Value) {
// Issue #24153 indicates that it is generally not a guaranteed property
// that you may round-trip a reflect.Value by calling Value.Addr().Elem()
// and expect the value to still be settable for values derived from
@ -482,11 +489,14 @@ func indirect(v reflect.Value, decodingNull bool) (Unmarshaler, encoding.TextUnm
}
if v.Type().NumMethod() > 0 && v.CanInterface() {
if u, ok := v.Interface().(Unmarshaler); ok {
return u, nil, reflect.Value{}
return u, nil, nil, reflect.Value{}
}
if cu, ok := v.Interface().(ContextUnmarshaler); ok {
return nil, cu, nil, reflect.Value{}
}
if !decodingNull {
if u, ok := v.Interface().(encoding.TextUnmarshaler); ok {
return nil, u, reflect.Value{}
return nil, nil, u, reflect.Value{}
}
}
}
@ -498,14 +508,14 @@ func indirect(v reflect.Value, decodingNull bool) (Unmarshaler, encoding.TextUnm
v = v.Elem()
}
}
return nil, nil, v
return nil, nil, nil, v
}
// array consumes an array from d.data[d.off-1:], decoding into v.
// The first byte of the array ('[') has been read already.
func (d *decodeState) array(v reflect.Value) error {
// Check for unmarshaler.
u, ut, pv := indirect(v, false)
u, cu, ut, pv := indirect(v, false)
if u != nil {
start := d.readIndex()
d.skip()
@ -515,6 +525,15 @@ func (d *decodeState) array(v reflect.Value) error {
}
return nil
}
if cu != nil {
start := d.readIndex()
d.skip()
err := cu.UnmarshalJSONContext(d.ctx, d.data[start:d.off])
if err != nil {
d.saveError(err)
}
return nil
}
if ut != nil {
d.saveError(&UnmarshalTypeError{Value: "array", Type: v.Type(), Offset: int64(d.off)})
d.skip()
@ -612,7 +631,7 @@ var (
// The first byte ('{') of the object has been read already.
func (d *decodeState) object(v reflect.Value) error {
// Check for unmarshaler.
u, ut, pv := indirect(v, false)
u, cu, ut, pv := indirect(v, false)
if u != nil {
start := d.readIndex()
d.skip()
@ -622,6 +641,15 @@ func (d *decodeState) object(v reflect.Value) error {
}
return nil
}
if cu != nil {
start := d.readIndex()
d.skip()
err := cu.UnmarshalJSONContext(d.ctx, d.data[start:d.off])
if err != nil {
d.saveError(err)
}
return nil
}
if ut != nil {
d.saveError(&UnmarshalTypeError{Value: "object", Type: v.Type(), Offset: int64(d.off)})
d.skip()
@ -870,7 +898,7 @@ func (d *decodeState) literalStore(item []byte, v reflect.Value, fromQuoted bool
return nil
}
isNull := item[0] == 'n' // null
u, ut, pv := indirect(v, isNull)
u, cu, ut, pv := indirect(v, isNull)
if u != nil {
err := u.UnmarshalJSON(item)
if err != nil {
@ -878,6 +906,13 @@ func (d *decodeState) literalStore(item []byte, v reflect.Value, fromQuoted bool
}
return nil
}
if cu != nil {
err := cu.UnmarshalJSONContext(d.ctx, item)
if err != nil {
d.saveError(err)
}
return nil
}
if ut != nil {
if item[0] != '"' {
if fromQuoted {

View file

@ -12,6 +12,7 @@ package json
import (
"bytes"
"context"
"encoding"
"encoding/base64"
"fmt"
@ -156,7 +157,11 @@ import (
// handle them. Passing cyclic structures to Marshal will result in
// an error.
func Marshal(v any) ([]byte, error) {
e := newEncodeState()
return MarshalContext(context.Background(), v)
}
func MarshalContext(ctx context.Context, v any) ([]byte, error) {
e := newEncodeState(ctx)
defer encodeStatePool.Put(e)
err := e.marshal(v, encOpts{escapeHTML: true})
@ -251,6 +256,7 @@ var hex = "0123456789abcdef"
type encodeState struct {
bytes.Buffer // accumulated output
ctx context.Context
// Keep track of what pointers we've seen in the current recursive call
// path, to avoid cycles that could lead to a stack overflow. Only do
// the relatively expensive map operations if ptrLevel is larger than
@ -264,7 +270,7 @@ const startDetectingCyclesAfter = 1000
var encodeStatePool sync.Pool
func newEncodeState() *encodeState {
func newEncodeState(ctx context.Context) *encodeState {
if v := encodeStatePool.Get(); v != nil {
e := v.(*encodeState)
e.Reset()
@ -274,7 +280,7 @@ func newEncodeState() *encodeState {
e.ptrLevel = 0
return e
}
return &encodeState{ptrSeen: make(map[any]struct{})}
return &encodeState{ctx: ctx, ptrSeen: make(map[any]struct{})}
}
// jsonError is an error wrapper type for internal use only.
@ -371,8 +377,9 @@ func typeEncoder(t reflect.Type) encoderFunc {
}
var (
marshalerType = reflect.TypeOf((*Marshaler)(nil)).Elem()
textMarshalerType = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem()
marshalerType = reflect.TypeOf((*Marshaler)(nil)).Elem()
contextMarshalerType = reflect.TypeOf((*ContextMarshaler)(nil)).Elem()
textMarshalerType = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem()
)
// newTypeEncoder constructs an encoderFunc for a type.
@ -385,9 +392,15 @@ func newTypeEncoder(t reflect.Type, allowAddr bool) encoderFunc {
if t.Kind() != reflect.Pointer && allowAddr && reflect.PointerTo(t).Implements(marshalerType) {
return newCondAddrEncoder(addrMarshalerEncoder, newTypeEncoder(t, false))
}
if t.Kind() != reflect.Pointer && allowAddr && reflect.PointerTo(t).Implements(contextMarshalerType) {
return newCondAddrEncoder(addrContextMarshalerEncoder, newTypeEncoder(t, false))
}
if t.Implements(marshalerType) {
return marshalerEncoder
}
if t.Implements(contextMarshalerType) {
return contextMarshalerEncoder
}
if t.Kind() != reflect.Pointer && allowAddr && reflect.PointerTo(t).Implements(textMarshalerType) {
return newCondAddrEncoder(addrTextMarshalerEncoder, newTypeEncoder(t, false))
}
@ -442,7 +455,7 @@ func marshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) {
b, err := m.MarshalJSON()
if err == nil {
e.Grow(len(b))
out := e.AvailableBuffer()
out := availableBuffer(&e.Buffer)
out, err = appendCompact(out, b, opts.escapeHTML)
e.Buffer.Write(out)
}
@ -461,7 +474,48 @@ func addrMarshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) {
b, err := m.MarshalJSON()
if err == nil {
e.Grow(len(b))
out := e.AvailableBuffer()
out := availableBuffer(&e.Buffer)
out, err = appendCompact(out, b, opts.escapeHTML)
e.Buffer.Write(out)
}
if err != nil {
e.error(&MarshalerError{v.Type(), err, "MarshalJSON"})
}
}
func contextMarshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) {
if v.Kind() == reflect.Pointer && v.IsNil() {
e.WriteString("null")
return
}
m, ok := v.Interface().(ContextMarshaler)
if !ok {
e.WriteString("null")
return
}
b, err := m.MarshalJSONContext(e.ctx)
if err == nil {
e.Grow(len(b))
out := availableBuffer(&e.Buffer)
out, err = appendCompact(out, b, opts.escapeHTML)
e.Buffer.Write(out)
}
if err != nil {
e.error(&MarshalerError{v.Type(), err, "MarshalJSON"})
}
}
func addrContextMarshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) {
va := v.Addr()
if va.IsNil() {
e.WriteString("null")
return
}
m := va.Interface().(ContextMarshaler)
b, err := m.MarshalJSONContext(e.ctx)
if err == nil {
e.Grow(len(b))
out := availableBuffer(&e.Buffer)
out, err = appendCompact(out, b, opts.escapeHTML)
e.Buffer.Write(out)
}
@ -484,7 +538,7 @@ func textMarshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) {
if err != nil {
e.error(&MarshalerError{v.Type(), err, "MarshalText"})
}
e.Write(appendString(e.AvailableBuffer(), b, opts.escapeHTML))
e.Write(appendString(availableBuffer(&e.Buffer), b, opts.escapeHTML))
}
func addrTextMarshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) {
@ -498,11 +552,11 @@ func addrTextMarshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) {
if err != nil {
e.error(&MarshalerError{v.Type(), err, "MarshalText"})
}
e.Write(appendString(e.AvailableBuffer(), b, opts.escapeHTML))
e.Write(appendString(availableBuffer(&e.Buffer), b, opts.escapeHTML))
}
func boolEncoder(e *encodeState, v reflect.Value, opts encOpts) {
b := e.AvailableBuffer()
b := availableBuffer(&e.Buffer)
b = mayAppendQuote(b, opts.quoted)
b = strconv.AppendBool(b, v.Bool())
b = mayAppendQuote(b, opts.quoted)
@ -510,7 +564,7 @@ func boolEncoder(e *encodeState, v reflect.Value, opts encOpts) {
}
func intEncoder(e *encodeState, v reflect.Value, opts encOpts) {
b := e.AvailableBuffer()
b := availableBuffer(&e.Buffer)
b = mayAppendQuote(b, opts.quoted)
b = strconv.AppendInt(b, v.Int(), 10)
b = mayAppendQuote(b, opts.quoted)
@ -518,7 +572,7 @@ func intEncoder(e *encodeState, v reflect.Value, opts encOpts) {
}
func uintEncoder(e *encodeState, v reflect.Value, opts encOpts) {
b := e.AvailableBuffer()
b := availableBuffer(&e.Buffer)
b = mayAppendQuote(b, opts.quoted)
b = strconv.AppendUint(b, v.Uint(), 10)
b = mayAppendQuote(b, opts.quoted)
@ -538,7 +592,7 @@ func (bits floatEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) {
// See golang.org/issue/6384 and golang.org/issue/14135.
// Like fmt %g, but the exponent cutoffs are different
// and exponents themselves are not padded to two digits.
b := e.AvailableBuffer()
b := availableBuffer(&e.Buffer)
b = mayAppendQuote(b, opts.quoted)
abs := math.Abs(f)
fmt := byte('f')
@ -577,7 +631,7 @@ func stringEncoder(e *encodeState, v reflect.Value, opts encOpts) {
if !isValidNumber(numStr) {
e.error(fmt.Errorf("json: invalid number literal %q", numStr))
}
b := e.AvailableBuffer()
b := availableBuffer(&e.Buffer)
b = mayAppendQuote(b, opts.quoted)
b = append(b, numStr...)
b = mayAppendQuote(b, opts.quoted)
@ -586,9 +640,9 @@ func stringEncoder(e *encodeState, v reflect.Value, opts encOpts) {
}
if opts.quoted {
b := appendString(nil, v.String(), opts.escapeHTML)
e.Write(appendString(e.AvailableBuffer(), b, false)) // no need to escape again since it is already escaped
e.Write(appendString(availableBuffer(&e.Buffer), b, false)) // no need to escape again since it is already escaped
} else {
e.Write(appendString(e.AvailableBuffer(), v.String(), opts.escapeHTML))
e.Write(appendString(availableBuffer(&e.Buffer), v.String(), opts.escapeHTML))
}
}
@ -754,7 +808,7 @@ func (me mapEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) {
if i > 0 {
e.WriteByte(',')
}
e.Write(appendString(e.AvailableBuffer(), kv.ks, opts.escapeHTML))
e.Write(appendString(availableBuffer(&e.Buffer), kv.ks, opts.escapeHTML))
e.WriteByte(':')
me.elemEnc(e, kv.v, opts)
}
@ -786,7 +840,7 @@ func encodeByteSlice(e *encodeState, v reflect.Value, _ encOpts) {
e.Grow(len(`"`) + encodedLen + len(`"`))
// TODO(https://go.dev/issue/53693): Use base64.Encoding.AppendEncode.
b := e.AvailableBuffer()
b := availableBuffer(&e.Buffer)
b = append(b, '"')
base64.StdEncoding.Encode(b[len(b):][:encodedLen], s)
b = b[:len(b)+encodedLen]
@ -827,7 +881,7 @@ func newSliceEncoder(t reflect.Type) encoderFunc {
// Byte slices get special treatment; arrays don't.
if t.Elem().Kind() == reflect.Uint8 {
p := reflect.PointerTo(t.Elem())
if !p.Implements(marshalerType) && !p.Implements(textMarshalerType) {
if !p.Implements(marshalerType) && !p.Implements(contextMarshalerType) && !p.Implements(textMarshalerType) {
return encodeByteSlice
}
}

View file

@ -6,6 +6,11 @@ package json
import "bytes"
// TODO(https://go.dev/issue/53685): Use bytes.Buffer.AvailableBuffer instead.
func availableBuffer(b *bytes.Buffer) []byte {
return b.Bytes()[b.Len():]
}
// HTMLEscape appends to dst the JSON-encoded src with <, >, &, U+2028 and U+2029
// characters inside string literals changed to \u003c, \u003e, \u0026, \u2028, \u2029
// so that the JSON will be safe to embed inside HTML <script> tags.
@ -13,7 +18,7 @@ import "bytes"
// escaping within <script> tags, so an alternative JSON encoding must be used.
func HTMLEscape(dst *bytes.Buffer, src []byte) {
dst.Grow(len(src))
dst.Write(appendHTMLEscape(dst.AvailableBuffer(), src))
dst.Write(appendHTMLEscape(availableBuffer(dst), src))
}
func appendHTMLEscape(dst, src []byte) []byte {
@ -40,7 +45,7 @@ func appendHTMLEscape(dst, src []byte) []byte {
// insignificant space characters elided.
func Compact(dst *bytes.Buffer, src []byte) error {
dst.Grow(len(src))
b := dst.AvailableBuffer()
b := availableBuffer(dst)
b, err := appendCompact(b, src, false)
dst.Write(b)
return err
@ -109,7 +114,7 @@ const indentGrowthFactor = 2
// if src ends in a trailing newline, so will dst.
func Indent(dst *bytes.Buffer, src []byte, prefix, indent string) error {
dst.Grow(indentGrowthFactor * len(src))
b := dst.AvailableBuffer()
b := availableBuffer(dst)
b, err := appendIndent(b, src, prefix, indent)
dst.Write(b)
return err

View file

@ -0,0 +1,20 @@
package json
import (
"reflect"
"github.com/sagernet/sing/common"
)
func ObjectKeys(object reflect.Type) []string {
switch object.Kind() {
case reflect.Pointer:
return ObjectKeys(object.Elem())
case reflect.Struct:
default:
panic("invalid non-struct input")
}
return common.Map(cachedTypeFields(object).list, func(field field) string {
return field.name
})
}

View file

@ -0,0 +1,26 @@
package json_test
import (
"reflect"
"testing"
json "github.com/sagernet/sing/common/json/internal/contextjson"
"github.com/stretchr/testify/require"
)
type MyObject struct {
Hello string `json:"hello,omitempty"`
MyWorld
MyWorld2 string `json:"-"`
}
type MyWorld struct {
World string `json:"world,omitempty"`
}
func TestObjectKeys(t *testing.T) {
t.Parallel()
keys := json.ObjectKeys(reflect.TypeOf(&MyObject{}))
require.Equal(t, []string{"hello", "world"}, keys)
}

View file

@ -300,7 +300,7 @@ func stateEndValue(s *scanner, c byte) int {
case parseObjectValue:
if c == ',' {
s.parseState[n-1] = parseObjectKey
s.step = stateBeginString
s.step = stateBeginStringOrEmpty
return scanObjectValue
}
if c == '}' {
@ -310,7 +310,7 @@ func stateEndValue(s *scanner, c byte) int {
return s.error(c, "after object key:value pair")
case parseArrayValue:
if c == ',' {
s.step = stateBeginValue
s.step = stateBeginValueOrEmpty
return scanArrayValue
}
if c == ']' {

View file

@ -6,6 +6,7 @@ package json
import (
"bytes"
"context"
"errors"
"io"
)
@ -29,7 +30,11 @@ type Decoder struct {
// The decoder introduces its own buffering and may
// read data from r beyond the JSON values requested.
func NewDecoder(r io.Reader) *Decoder {
return &Decoder{r: r}
return NewDecoderContext(context.Background(), r)
}
func NewDecoderContext(ctx context.Context, r io.Reader) *Decoder {
return &Decoder{r: r, d: decodeState{ctx: ctx}}
}
// UseNumber causes the Decoder to unmarshal a number into an interface{} as a
@ -153,6 +158,10 @@ func (dec *Decoder) refill() error {
dec.scanp = 0
}
return dec.refill0()
}
func (dec *Decoder) refill0() error {
// Grow buffer if not large enough.
const minRead = 512
if cap(dec.buf)-len(dec.buf) < minRead {
@ -179,6 +188,7 @@ func nonSpace(b []byte) bool {
// An Encoder writes JSON values to an output stream.
type Encoder struct {
ctx context.Context
w io.Writer
err error
escapeHTML bool
@ -190,7 +200,11 @@ type Encoder struct {
// NewEncoder returns a new encoder that writes to w.
func NewEncoder(w io.Writer) *Encoder {
return &Encoder{w: w, escapeHTML: true}
return NewEncoderContext(context.Background(), w)
}
func NewEncoderContext(ctx context.Context, w io.Writer) *Encoder {
return &Encoder{ctx: ctx, w: w, escapeHTML: true}
}
// Encode writes the JSON encoding of v to the stream,
@ -203,7 +217,7 @@ func (enc *Encoder) Encode(v any) error {
return enc.err
}
e := newEncodeState()
e := newEncodeState(enc.ctx)
defer encodeStatePool.Put(e)
err := e.marshal(v, encOpts{escapeHTML: enc.escapeHTML})
@ -402,7 +416,7 @@ func (dec *Decoder) Token() (Token, error) {
return Delim('{'), nil
case '}':
if dec.tokenState != tokenObjectStart && dec.tokenState != tokenObjectComma {
if dec.tokenState != tokenObjectStart && dec.tokenState != tokenObjectComma && dec.tokenState != tokenObjectKey {
return dec.tokenError(c)
}
dec.scanp++
@ -410,7 +424,6 @@ func (dec *Decoder) Token() (Token, error) {
dec.tokenStack = dec.tokenStack[:len(dec.tokenStack)-1]
dec.tokenValueEnd()
return Delim('}'), nil
case ':':
if dec.tokenState != tokenObjectColon {
return dec.tokenError(c)
@ -483,7 +496,26 @@ func (dec *Decoder) tokenError(c byte) (Token, error) {
// current array or object being parsed.
func (dec *Decoder) More() bool {
c, err := dec.peek()
return err == nil && c != ']' && c != '}'
// return err == nil && c != ']' && c != '}'
if err != nil {
return false
}
if c == ']' || c == '}' {
return false
}
if c == ',' {
scanp := dec.scanp
dec.scanp++
c, err = dec.peekNoRefill()
dec.scanp = scanp
if err != nil {
return false
}
if c == ']' || c == '}' {
return false
}
}
return true
}
func (dec *Decoder) peek() (byte, error) {
@ -505,6 +537,25 @@ func (dec *Decoder) peek() (byte, error) {
}
}
func (dec *Decoder) peekNoRefill() (byte, error) {
var err error
for {
for i := dec.scanp; i < len(dec.buf); i++ {
c := dec.buf[i]
if isSpace(c) {
continue
}
dec.scanp = i
return c, nil
}
// buffer has been scanned, now report any error
if err != nil {
return 0, err
}
err = dec.refill0()
}
}
// InputOffset returns the input stream byte offset of the current decoder position.
// The offset gives the location of the end of the most recently returned token
// and the beginning of the next token.

View file

@ -0,0 +1,26 @@
package json
import "context"
func UnmarshalDisallowUnknownFields(data []byte, v any) error {
var d decodeState
d.disallowUnknownFields = true
err := checkValid(data, &d.scan)
if err != nil {
return err
}
d.init(data)
return d.unmarshal(v)
}
func UnmarshalContextDisallowUnknownFields(ctx context.Context, data []byte, v any) error {
var d decodeState
d.ctx = ctx
d.disallowUnknownFields = true
err := checkValid(data, &d.scan)
if err != nil {
return err
}
d.init(data)
return d.unmarshal(v)
}

View file

@ -1,3 +0,0 @@
# contextjson
mod from go1.20.11

File diff suppressed because it is too large Load diff

View file

@ -1,49 +0,0 @@
package json
import "strconv"
type decodeContext struct {
parent *decodeContext
index int
key string
}
func (d *decodeState) formatContext() string {
var description string
context := d.context
var appendDot bool
for context != nil {
if appendDot {
description = "." + description
}
if context.key != "" {
description = context.key + description
appendDot = true
} else {
description = "[" + strconv.Itoa(context.index) + "]" + description
appendDot = false
}
context = context.parent
}
return description
}
type contextError struct {
parent error
context string
index bool
}
func (c *contextError) Unwrap() error {
return c.parent
}
func (c *contextError) Error() string {
//goland:noinspection GoTypeAssertionOnErrors
switch c.parent.(type) {
case *contextError:
return c.context + "." + c.parent.Error()
default:
return c.context + ": " + c.parent.Error()
}
}

File diff suppressed because it is too large Load diff

View file

@ -1,141 +0,0 @@
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package json
import (
"bytes"
"unicode/utf8"
)
const (
caseMask = ^byte(0x20) // Mask to ignore case in ASCII.
kelvin = '\u212a'
smallLongEss = '\u017f'
)
// foldFunc returns one of four different case folding equivalence
// functions, from most general (and slow) to fastest:
//
// 1) bytes.EqualFold, if the key s contains any non-ASCII UTF-8
// 2) equalFoldRight, if s contains special folding ASCII ('k', 'K', 's', 'S')
// 3) asciiEqualFold, no special, but includes non-letters (including _)
// 4) simpleLetterEqualFold, no specials, no non-letters.
//
// The letters S and K are special because they map to 3 runes, not just 2:
// - S maps to s and to U+017F 'ſ' Latin small letter long s
// - k maps to K and to U+212A 'K' Kelvin sign
//
// See https://play.golang.org/p/tTxjOc0OGo
//
// The returned function is specialized for matching against s and
// should only be given s. It's not curried for performance reasons.
func foldFunc(s []byte) func(s, t []byte) bool {
nonLetter := false
special := false // special letter
for _, b := range s {
if b >= utf8.RuneSelf {
return bytes.EqualFold
}
upper := b & caseMask
if upper < 'A' || upper > 'Z' {
nonLetter = true
} else if upper == 'K' || upper == 'S' {
// See above for why these letters are special.
special = true
}
}
if special {
return equalFoldRight
}
if nonLetter {
return asciiEqualFold
}
return simpleLetterEqualFold
}
// equalFoldRight is a specialization of bytes.EqualFold when s is
// known to be all ASCII (including punctuation), but contains an 's',
// 'S', 'k', or 'K', requiring a Unicode fold on the bytes in t.
// See comments on foldFunc.
func equalFoldRight(s, t []byte) bool {
for _, sb := range s {
if len(t) == 0 {
return false
}
tb := t[0]
if tb < utf8.RuneSelf {
if sb != tb {
sbUpper := sb & caseMask
if 'A' <= sbUpper && sbUpper <= 'Z' {
if sbUpper != tb&caseMask {
return false
}
} else {
return false
}
}
t = t[1:]
continue
}
// sb is ASCII and t is not. t must be either kelvin
// sign or long s; sb must be s, S, k, or K.
tr, size := utf8.DecodeRune(t)
switch sb {
case 's', 'S':
if tr != smallLongEss {
return false
}
case 'k', 'K':
if tr != kelvin {
return false
}
default:
return false
}
t = t[size:]
}
return len(t) == 0
}
// asciiEqualFold is a specialization of bytes.EqualFold for use when
// s is all ASCII (but may contain non-letters) and contains no
// special-folding letters.
// See comments on foldFunc.
func asciiEqualFold(s, t []byte) bool {
if len(s) != len(t) {
return false
}
for i, sb := range s {
tb := t[i]
if sb == tb {
continue
}
if ('a' <= sb && sb <= 'z') || ('A' <= sb && sb <= 'Z') {
if sb&caseMask != tb&caseMask {
return false
}
} else {
return false
}
}
return true
}
// simpleLetterEqualFold is a specialization of bytes.EqualFold for
// use when s is all ASCII letters (no underscores, etc) and also
// doesn't contain 'k', 'K', 's', or 'S'.
// See comments on foldFunc.
func simpleLetterEqualFold(s, t []byte) bool {
if len(s) != len(t) {
return false
}
for i, b := range s {
if b&caseMask != t[i]&caseMask {
return false
}
}
return true
}

View file

@ -1,42 +0,0 @@
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build gofuzz
package json
import (
"fmt"
)
func Fuzz(data []byte) (score int) {
for _, ctor := range []func() any{
func() any { return new(any) },
func() any { return new(map[string]any) },
func() any { return new([]any) },
} {
v := ctor()
err := Unmarshal(data, v)
if err != nil {
continue
}
score = 1
m, err := Marshal(v)
if err != nil {
fmt.Printf("v=%#v\n", v)
panic(err)
}
u := ctor()
err = Unmarshal(m, u)
if err != nil {
fmt.Printf("v=%#v\n", v)
fmt.Printf("m=%s\n", m)
panic(err)
}
}
return
}

View file

@ -1,143 +0,0 @@
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package json
import (
"bytes"
)
// Compact appends to dst the JSON-encoded src with
// insignificant space characters elided.
func Compact(dst *bytes.Buffer, src []byte) error {
return compact(dst, src, false)
}
func compact(dst *bytes.Buffer, src []byte, escape bool) error {
origLen := dst.Len()
scan := newScanner()
defer freeScanner(scan)
start := 0
for i, c := range src {
if escape && (c == '<' || c == '>' || c == '&') {
if start < i {
dst.Write(src[start:i])
}
dst.WriteString(`\u00`)
dst.WriteByte(hex[c>>4])
dst.WriteByte(hex[c&0xF])
start = i + 1
}
// Convert U+2028 and U+2029 (E2 80 A8 and E2 80 A9).
if escape && c == 0xE2 && i+2 < len(src) && src[i+1] == 0x80 && src[i+2]&^1 == 0xA8 {
if start < i {
dst.Write(src[start:i])
}
dst.WriteString(`\u202`)
dst.WriteByte(hex[src[i+2]&0xF])
start = i + 3
}
v := scan.step(scan, c)
if v >= scanSkipSpace {
if v == scanError {
break
}
if start < i {
dst.Write(src[start:i])
}
start = i + 1
}
}
if scan.eof() == scanError {
dst.Truncate(origLen)
return scan.err
}
if start < len(src) {
dst.Write(src[start:])
}
return nil
}
func newline(dst *bytes.Buffer, prefix, indent string, depth int) {
dst.WriteByte('\n')
dst.WriteString(prefix)
for i := 0; i < depth; i++ {
dst.WriteString(indent)
}
}
// Indent appends to dst an indented form of the JSON-encoded src.
// Each element in a JSON object or array begins on a new,
// indented line beginning with prefix followed by one or more
// copies of indent according to the indentation nesting.
// The data appended to dst does not begin with the prefix nor
// any indentation, to make it easier to embed inside other formatted JSON data.
// Although leading space characters (space, tab, carriage return, newline)
// at the beginning of src are dropped, trailing space characters
// at the end of src are preserved and copied to dst.
// For example, if src has no trailing spaces, neither will dst;
// if src ends in a trailing newline, so will dst.
func Indent(dst *bytes.Buffer, src []byte, prefix, indent string) error {
origLen := dst.Len()
scan := newScanner()
defer freeScanner(scan)
needIndent := false
depth := 0
for _, c := range src {
scan.bytes++
v := scan.step(scan, c)
if v == scanSkipSpace {
continue
}
if v == scanError {
break
}
if needIndent && v != scanEndObject && v != scanEndArray {
needIndent = false
depth++
newline(dst, prefix, indent, depth)
}
// Emit semantically uninteresting bytes
// (in particular, punctuation in strings) unmodified.
if v == scanContinue {
dst.WriteByte(c)
continue
}
// Add spacing around real punctuation.
switch c {
case '{', '[':
// delay indent so that empty object and array are formatted as {} and [].
needIndent = true
dst.WriteByte(c)
case ',':
dst.WriteByte(c)
newline(dst, prefix, indent, depth)
case ':':
dst.WriteByte(c)
dst.WriteByte(' ')
case '}', ']':
if needIndent {
// suppress indent in empty object/array
needIndent = false
} else {
depth--
newline(dst, prefix, indent, depth)
}
dst.WriteByte(c)
default:
dst.WriteByte(c)
}
}
if scan.eof() == scanError {
dst.Truncate(origLen)
return scan.err
}
return nil
}

View file

@ -1,610 +0,0 @@
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package json
// JSON value parser state machine.
// Just about at the limit of what is reasonable to write by hand.
// Some parts are a bit tedious, but overall it nicely factors out the
// otherwise common code from the multiple scanning functions
// in this package (Compact, Indent, checkValid, etc).
//
// This file starts with two simple examples using the scanner
// before diving into the scanner itself.
import (
"strconv"
"sync"
)
// Valid reports whether data is a valid JSON encoding.
func Valid(data []byte) bool {
scan := newScanner()
defer freeScanner(scan)
return checkValid(data, scan) == nil
}
// checkValid verifies that data is valid JSON-encoded data.
// scan is passed in for use by checkValid to avoid an allocation.
// checkValid returns nil or a SyntaxError.
func checkValid(data []byte, scan *scanner) error {
scan.reset()
for _, c := range data {
scan.bytes++
if scan.step(scan, c) == scanError {
return scan.err
}
}
if scan.eof() == scanError {
return scan.err
}
return nil
}
// A SyntaxError is a description of a JSON syntax error.
// Unmarshal will return a SyntaxError if the JSON can't be parsed.
type SyntaxError struct {
msg string // description of error
Offset int64 // error occurred after reading Offset bytes
}
func (e *SyntaxError) Error() string { return e.msg }
// A scanner is a JSON scanning state machine.
// Callers call scan.reset and then pass bytes in one at a time
// by calling scan.step(&scan, c) for each byte.
// The return value, referred to as an opcode, tells the
// caller about significant parsing events like beginning
// and ending literals, objects, and arrays, so that the
// caller can follow along if it wishes.
// The return value scanEnd indicates that a single top-level
// JSON value has been completed, *before* the byte that
// just got passed in. (The indication must be delayed in order
// to recognize the end of numbers: is 123 a whole value or
// the beginning of 12345e+6?).
type scanner struct {
// The step is a func to be called to execute the next transition.
// Also tried using an integer constant and a single func
// with a switch, but using the func directly was 10% faster
// on a 64-bit Mac Mini, and it's nicer to read.
step func(*scanner, byte) int
// Reached end of top-level value.
endTop bool
// Stack of what we're in the middle of - array values, object keys, object values.
parseState []int
// Error that happened, if any.
err error
// total bytes consumed, updated by decoder.Decode (and deliberately
// not set to zero by scan.reset)
bytes int64
}
var scannerPool = sync.Pool{
New: func() any {
return &scanner{}
},
}
func newScanner() *scanner {
scan := scannerPool.Get().(*scanner)
// scan.reset by design doesn't set bytes to zero
scan.bytes = 0
scan.reset()
return scan
}
func freeScanner(scan *scanner) {
// Avoid hanging on to too much memory in extreme cases.
if len(scan.parseState) > 1024 {
scan.parseState = nil
}
scannerPool.Put(scan)
}
// These values are returned by the state transition functions
// assigned to scanner.state and the method scanner.eof.
// They give details about the current state of the scan that
// callers might be interested to know about.
// It is okay to ignore the return value of any particular
// call to scanner.state: if one call returns scanError,
// every subsequent call will return scanError too.
const (
// Continue.
scanContinue = iota // uninteresting byte
scanBeginLiteral // end implied by next result != scanContinue
scanBeginObject // begin object
scanObjectKey // just finished object key (string)
scanObjectValue // just finished non-last object value
scanEndObject // end object (implies scanObjectValue if possible)
scanBeginArray // begin array
scanArrayValue // just finished array value
scanEndArray // end array (implies scanArrayValue if possible)
scanSkipSpace // space byte; can skip; known to be last "continue" result
// Stop.
scanEnd // top-level value ended *before* this byte; known to be first "stop" result
scanError // hit an error, scanner.err.
)
// These values are stored in the parseState stack.
// They give the current state of a composite value
// being scanned. If the parser is inside a nested value
// the parseState describes the nested state, outermost at entry 0.
const (
parseObjectKey = iota // parsing object key (before colon)
parseObjectValue // parsing object value (after colon)
parseArrayValue // parsing array value
)
// This limits the max nesting depth to prevent stack overflow.
// This is permitted by https://tools.ietf.org/html/rfc7159#section-9
const maxNestingDepth = 10000
// reset prepares the scanner for use.
// It must be called before calling s.step.
func (s *scanner) reset() {
s.step = stateBeginValue
s.parseState = s.parseState[0:0]
s.err = nil
s.endTop = false
}
// eof tells the scanner that the end of input has been reached.
// It returns a scan status just as s.step does.
func (s *scanner) eof() int {
if s.err != nil {
return scanError
}
if s.endTop {
return scanEnd
}
s.step(s, ' ')
if s.endTop {
return scanEnd
}
if s.err == nil {
s.err = &SyntaxError{"unexpected end of JSON input", s.bytes}
}
return scanError
}
// pushParseState pushes a new parse state p onto the parse stack.
// an error state is returned if maxNestingDepth was exceeded, otherwise successState is returned.
func (s *scanner) pushParseState(c byte, newParseState int, successState int) int {
s.parseState = append(s.parseState, newParseState)
if len(s.parseState) <= maxNestingDepth {
return successState
}
return s.error(c, "exceeded max depth")
}
// popParseState pops a parse state (already obtained) off the stack
// and updates s.step accordingly.
func (s *scanner) popParseState() {
n := len(s.parseState) - 1
s.parseState = s.parseState[0:n]
if n == 0 {
s.step = stateEndTop
s.endTop = true
} else {
s.step = stateEndValue
}
}
func isSpace(c byte) bool {
return c <= ' ' && (c == ' ' || c == '\t' || c == '\r' || c == '\n')
}
// stateBeginValueOrEmpty is the state after reading `[`.
func stateBeginValueOrEmpty(s *scanner, c byte) int {
if isSpace(c) {
return scanSkipSpace
}
if c == ']' {
return stateEndValue(s, c)
}
return stateBeginValue(s, c)
}
// stateBeginValue is the state at the beginning of the input.
func stateBeginValue(s *scanner, c byte) int {
if isSpace(c) {
return scanSkipSpace
}
switch c {
case '{':
s.step = stateBeginStringOrEmpty
return s.pushParseState(c, parseObjectKey, scanBeginObject)
case '[':
s.step = stateBeginValueOrEmpty
return s.pushParseState(c, parseArrayValue, scanBeginArray)
case '"':
s.step = stateInString
return scanBeginLiteral
case '-':
s.step = stateNeg
return scanBeginLiteral
case '0': // beginning of 0.123
s.step = state0
return scanBeginLiteral
case 't': // beginning of true
s.step = stateT
return scanBeginLiteral
case 'f': // beginning of false
s.step = stateF
return scanBeginLiteral
case 'n': // beginning of null
s.step = stateN
return scanBeginLiteral
}
if '1' <= c && c <= '9' { // beginning of 1234.5
s.step = state1
return scanBeginLiteral
}
return s.error(c, "looking for beginning of value")
}
// stateBeginStringOrEmpty is the state after reading `{`.
func stateBeginStringOrEmpty(s *scanner, c byte) int {
if isSpace(c) {
return scanSkipSpace
}
if c == '}' {
n := len(s.parseState)
s.parseState[n-1] = parseObjectValue
return stateEndValue(s, c)
}
return stateBeginString(s, c)
}
// stateBeginString is the state after reading `{"key": value,`.
func stateBeginString(s *scanner, c byte) int {
if isSpace(c) {
return scanSkipSpace
}
if c == '"' {
s.step = stateInString
return scanBeginLiteral
}
return s.error(c, "looking for beginning of object key string")
}
// stateEndValue is the state after completing a value,
// such as after reading `{}` or `true` or `["x"`.
func stateEndValue(s *scanner, c byte) int {
n := len(s.parseState)
if n == 0 {
// Completed top-level before the current byte.
s.step = stateEndTop
s.endTop = true
return stateEndTop(s, c)
}
if isSpace(c) {
s.step = stateEndValue
return scanSkipSpace
}
ps := s.parseState[n-1]
switch ps {
case parseObjectKey:
if c == ':' {
s.parseState[n-1] = parseObjectValue
s.step = stateBeginValue
return scanObjectKey
}
return s.error(c, "after object key")
case parseObjectValue:
if c == ',' {
s.parseState[n-1] = parseObjectKey
s.step = stateBeginString
return scanObjectValue
}
if c == '}' {
s.popParseState()
return scanEndObject
}
return s.error(c, "after object key:value pair")
case parseArrayValue:
if c == ',' {
s.step = stateBeginValue
return scanArrayValue
}
if c == ']' {
s.popParseState()
return scanEndArray
}
return s.error(c, "after array element")
}
return s.error(c, "")
}
// stateEndTop is the state after finishing the top-level value,
// such as after reading `{}` or `[1,2,3]`.
// Only space characters should be seen now.
func stateEndTop(s *scanner, c byte) int {
if !isSpace(c) {
// Complain about non-space byte on next call.
s.error(c, "after top-level value")
}
return scanEnd
}
// stateInString is the state after reading `"`.
func stateInString(s *scanner, c byte) int {
if c == '"' {
s.step = stateEndValue
return scanContinue
}
if c == '\\' {
s.step = stateInStringEsc
return scanContinue
}
if c < 0x20 {
return s.error(c, "in string literal")
}
return scanContinue
}
// stateInStringEsc is the state after reading `"\` during a quoted string.
func stateInStringEsc(s *scanner, c byte) int {
switch c {
case 'b', 'f', 'n', 'r', 't', '\\', '/', '"':
s.step = stateInString
return scanContinue
case 'u':
s.step = stateInStringEscU
return scanContinue
}
return s.error(c, "in string escape code")
}
// stateInStringEscU is the state after reading `"\u` during a quoted string.
func stateInStringEscU(s *scanner, c byte) int {
if '0' <= c && c <= '9' || 'a' <= c && c <= 'f' || 'A' <= c && c <= 'F' {
s.step = stateInStringEscU1
return scanContinue
}
// numbers
return s.error(c, "in \\u hexadecimal character escape")
}
// stateInStringEscU1 is the state after reading `"\u1` during a quoted string.
func stateInStringEscU1(s *scanner, c byte) int {
if '0' <= c && c <= '9' || 'a' <= c && c <= 'f' || 'A' <= c && c <= 'F' {
s.step = stateInStringEscU12
return scanContinue
}
// numbers
return s.error(c, "in \\u hexadecimal character escape")
}
// stateInStringEscU12 is the state after reading `"\u12` during a quoted string.
func stateInStringEscU12(s *scanner, c byte) int {
if '0' <= c && c <= '9' || 'a' <= c && c <= 'f' || 'A' <= c && c <= 'F' {
s.step = stateInStringEscU123
return scanContinue
}
// numbers
return s.error(c, "in \\u hexadecimal character escape")
}
// stateInStringEscU123 is the state after reading `"\u123` during a quoted string.
func stateInStringEscU123(s *scanner, c byte) int {
if '0' <= c && c <= '9' || 'a' <= c && c <= 'f' || 'A' <= c && c <= 'F' {
s.step = stateInString
return scanContinue
}
// numbers
return s.error(c, "in \\u hexadecimal character escape")
}
// stateNeg is the state after reading `-` during a number.
func stateNeg(s *scanner, c byte) int {
if c == '0' {
s.step = state0
return scanContinue
}
if '1' <= c && c <= '9' {
s.step = state1
return scanContinue
}
return s.error(c, "in numeric literal")
}
// state1 is the state after reading a non-zero integer during a number,
// such as after reading `1` or `100` but not `0`.
func state1(s *scanner, c byte) int {
if '0' <= c && c <= '9' {
s.step = state1
return scanContinue
}
return state0(s, c)
}
// state0 is the state after reading `0` during a number.
func state0(s *scanner, c byte) int {
if c == '.' {
s.step = stateDot
return scanContinue
}
if c == 'e' || c == 'E' {
s.step = stateE
return scanContinue
}
return stateEndValue(s, c)
}
// stateDot is the state after reading the integer and decimal point in a number,
// such as after reading `1.`.
func stateDot(s *scanner, c byte) int {
if '0' <= c && c <= '9' {
s.step = stateDot0
return scanContinue
}
return s.error(c, "after decimal point in numeric literal")
}
// stateDot0 is the state after reading the integer, decimal point, and subsequent
// digits of a number, such as after reading `3.14`.
func stateDot0(s *scanner, c byte) int {
if '0' <= c && c <= '9' {
return scanContinue
}
if c == 'e' || c == 'E' {
s.step = stateE
return scanContinue
}
return stateEndValue(s, c)
}
// stateE is the state after reading the mantissa and e in a number,
// such as after reading `314e` or `0.314e`.
func stateE(s *scanner, c byte) int {
if c == '+' || c == '-' {
s.step = stateESign
return scanContinue
}
return stateESign(s, c)
}
// stateESign is the state after reading the mantissa, e, and sign in a number,
// such as after reading `314e-` or `0.314e+`.
func stateESign(s *scanner, c byte) int {
if '0' <= c && c <= '9' {
s.step = stateE0
return scanContinue
}
return s.error(c, "in exponent of numeric literal")
}
// stateE0 is the state after reading the mantissa, e, optional sign,
// and at least one digit of the exponent in a number,
// such as after reading `314e-2` or `0.314e+1` or `3.14e0`.
func stateE0(s *scanner, c byte) int {
if '0' <= c && c <= '9' {
return scanContinue
}
return stateEndValue(s, c)
}
// stateT is the state after reading `t`.
func stateT(s *scanner, c byte) int {
if c == 'r' {
s.step = stateTr
return scanContinue
}
return s.error(c, "in literal true (expecting 'r')")
}
// stateTr is the state after reading `tr`.
func stateTr(s *scanner, c byte) int {
if c == 'u' {
s.step = stateTru
return scanContinue
}
return s.error(c, "in literal true (expecting 'u')")
}
// stateTru is the state after reading `tru`.
func stateTru(s *scanner, c byte) int {
if c == 'e' {
s.step = stateEndValue
return scanContinue
}
return s.error(c, "in literal true (expecting 'e')")
}
// stateF is the state after reading `f`.
func stateF(s *scanner, c byte) int {
if c == 'a' {
s.step = stateFa
return scanContinue
}
return s.error(c, "in literal false (expecting 'a')")
}
// stateFa is the state after reading `fa`.
func stateFa(s *scanner, c byte) int {
if c == 'l' {
s.step = stateFal
return scanContinue
}
return s.error(c, "in literal false (expecting 'l')")
}
// stateFal is the state after reading `fal`.
func stateFal(s *scanner, c byte) int {
if c == 's' {
s.step = stateFals
return scanContinue
}
return s.error(c, "in literal false (expecting 's')")
}
// stateFals is the state after reading `fals`.
func stateFals(s *scanner, c byte) int {
if c == 'e' {
s.step = stateEndValue
return scanContinue
}
return s.error(c, "in literal false (expecting 'e')")
}
// stateN is the state after reading `n`.
func stateN(s *scanner, c byte) int {
if c == 'u' {
s.step = stateNu
return scanContinue
}
return s.error(c, "in literal null (expecting 'u')")
}
// stateNu is the state after reading `nu`.
func stateNu(s *scanner, c byte) int {
if c == 'l' {
s.step = stateNul
return scanContinue
}
return s.error(c, "in literal null (expecting 'l')")
}
// stateNul is the state after reading `nul`.
func stateNul(s *scanner, c byte) int {
if c == 'l' {
s.step = stateEndValue
return scanContinue
}
return s.error(c, "in literal null (expecting 'l')")
}
// stateError is the state after reaching a syntax error,
// such as after reading `[1}` or `5.1.2`.
func stateError(s *scanner, c byte) int {
return scanError
}
// error records an error and switches to the error state.
func (s *scanner) error(c byte, context string) int {
s.step = stateError
s.err = &SyntaxError{"invalid character " + quoteChar(c) + " " + context, s.bytes}
return scanError
}
// quoteChar formats c as a quoted character literal.
func quoteChar(c byte) string {
// special cases - different from quoted strings
if c == '\'' {
return `'\''`
}
if c == '"' {
return `'"'`
}
// use quoted string with different quotation marks
s := strconv.Quote(string(c))
return "'" + s[1:len(s)-1] + "'"
}

View file

@ -1,517 +0,0 @@
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package json
import (
"bytes"
"errors"
"io"
)
// A Decoder reads and decodes JSON values from an input stream.
type Decoder struct {
r io.Reader
buf []byte
d decodeState
scanp int // start of unread data in buf
scanned int64 // amount of data already scanned
scan scanner
err error
tokenState int
tokenStack []int
}
// NewDecoder returns a new decoder that reads from r.
//
// The decoder introduces its own buffering and may
// read data from r beyond the JSON values requested.
func NewDecoder(r io.Reader) *Decoder {
return &Decoder{r: r}
}
// UseNumber causes the Decoder to unmarshal a number into an interface{} as a
// Number instead of as a float64.
func (dec *Decoder) UseNumber() { dec.d.useNumber = true }
// DisallowUnknownFields causes the Decoder to return an error when the destination
// is a struct and the input contains object keys which do not match any
// non-ignored, exported fields in the destination.
func (dec *Decoder) DisallowUnknownFields() { dec.d.disallowUnknownFields = true }
// Decode reads the next JSON-encoded value from its
// input and stores it in the value pointed to by v.
//
// See the documentation for Unmarshal for details about
// the conversion of JSON into a Go value.
func (dec *Decoder) Decode(v any) error {
if dec.err != nil {
return dec.err
}
if err := dec.tokenPrepareForDecode(); err != nil {
return err
}
if !dec.tokenValueAllowed() {
return &SyntaxError{msg: "not at beginning of value", Offset: dec.InputOffset()}
}
// Read whole value into buffer.
n, err := dec.readValue()
if err != nil {
return err
}
dec.d.init(dec.buf[dec.scanp : dec.scanp+n])
dec.scanp += n
// Don't save err from unmarshal into dec.err:
// the connection is still usable since we read a complete JSON
// object from it before the error happened.
err = dec.d.unmarshal(v)
// fixup token streaming state
dec.tokenValueEnd()
return err
}
// Buffered returns a reader of the data remaining in the Decoder's
// buffer. The reader is valid until the next call to Decode.
func (dec *Decoder) Buffered() io.Reader {
return bytes.NewReader(dec.buf[dec.scanp:])
}
// readValue reads a JSON value into dec.buf.
// It returns the length of the encoding.
func (dec *Decoder) readValue() (int, error) {
dec.scan.reset()
scanp := dec.scanp
var err error
Input:
// help the compiler see that scanp is never negative, so it can remove
// some bounds checks below.
for scanp >= 0 {
// Look in the buffer for a new value.
for ; scanp < len(dec.buf); scanp++ {
c := dec.buf[scanp]
dec.scan.bytes++
switch dec.scan.step(&dec.scan, c) {
case scanEnd:
// scanEnd is delayed one byte so we decrement
// the scanner bytes count by 1 to ensure that
// this value is correct in the next call of Decode.
dec.scan.bytes--
break Input
case scanEndObject, scanEndArray:
// scanEnd is delayed one byte.
// We might block trying to get that byte from src,
// so instead invent a space byte.
if stateEndValue(&dec.scan, ' ') == scanEnd {
scanp++
break Input
}
case scanError:
dec.err = dec.scan.err
return 0, dec.scan.err
}
}
// Did the last read have an error?
// Delayed until now to allow buffer scan.
if err != nil {
if err == io.EOF {
if dec.scan.step(&dec.scan, ' ') == scanEnd {
break Input
}
if nonSpace(dec.buf) {
err = io.ErrUnexpectedEOF
}
}
dec.err = err
return 0, err
}
n := scanp - dec.scanp
err = dec.refill()
scanp = dec.scanp + n
}
return scanp - dec.scanp, nil
}
func (dec *Decoder) refill() error {
// Make room to read more into the buffer.
// First slide down data already consumed.
if dec.scanp > 0 {
dec.scanned += int64(dec.scanp)
n := copy(dec.buf, dec.buf[dec.scanp:])
dec.buf = dec.buf[:n]
dec.scanp = 0
}
// Grow buffer if not large enough.
const minRead = 512
if cap(dec.buf)-len(dec.buf) < minRead {
newBuf := make([]byte, len(dec.buf), 2*cap(dec.buf)+minRead)
copy(newBuf, dec.buf)
dec.buf = newBuf
}
// Read. Delay error for next iteration (after scan).
n, err := dec.r.Read(dec.buf[len(dec.buf):cap(dec.buf)])
dec.buf = dec.buf[0 : len(dec.buf)+n]
return err
}
func nonSpace(b []byte) bool {
for _, c := range b {
if !isSpace(c) {
return true
}
}
return false
}
// An Encoder writes JSON values to an output stream.
type Encoder struct {
w io.Writer
err error
escapeHTML bool
indentBuf *bytes.Buffer
indentPrefix string
indentValue string
}
// NewEncoder returns a new encoder that writes to w.
func NewEncoder(w io.Writer) *Encoder {
return &Encoder{w: w, escapeHTML: true}
}
// Encode writes the JSON encoding of v to the stream,
// followed by a newline character.
//
// See the documentation for Marshal for details about the
// conversion of Go values to JSON.
func (enc *Encoder) Encode(v any) error {
if enc.err != nil {
return enc.err
}
e := newEncodeState()
defer encodeStatePool.Put(e)
err := e.marshal(v, encOpts{escapeHTML: enc.escapeHTML})
if err != nil {
return err
}
// Terminate each value with a newline.
// This makes the output look a little nicer
// when debugging, and some kind of space
// is required if the encoded value was a number,
// so that the reader knows there aren't more
// digits coming.
e.WriteByte('\n')
b := e.Bytes()
if enc.indentPrefix != "" || enc.indentValue != "" {
if enc.indentBuf == nil {
enc.indentBuf = new(bytes.Buffer)
}
enc.indentBuf.Reset()
err = Indent(enc.indentBuf, b, enc.indentPrefix, enc.indentValue)
if err != nil {
return err
}
b = enc.indentBuf.Bytes()
}
if _, err = enc.w.Write(b); err != nil {
enc.err = err
}
return err
}
// SetIndent instructs the encoder to format each subsequent encoded
// value as if indented by the package-level function Indent(dst, src, prefix, indent).
// Calling SetIndent("", "") disables indentation.
func (enc *Encoder) SetIndent(prefix, indent string) {
enc.indentPrefix = prefix
enc.indentValue = indent
}
// SetEscapeHTML specifies whether problematic HTML characters
// should be escaped inside JSON quoted strings.
// The default behavior is to escape &, <, and > to \u0026, \u003c, and \u003e
// to avoid certain safety problems that can arise when embedding JSON in HTML.
//
// In non-HTML settings where the escaping interferes with the readability
// of the output, SetEscapeHTML(false) disables this behavior.
func (enc *Encoder) SetEscapeHTML(on bool) {
enc.escapeHTML = on
}
// RawMessage is a raw encoded JSON value.
// It implements Marshaler and Unmarshaler and can
// be used to delay JSON decoding or precompute a JSON encoding.
type RawMessage []byte
// MarshalJSON returns m as the JSON encoding of m.
func (m RawMessage) MarshalJSON() ([]byte, error) {
if m == nil {
return []byte("null"), nil
}
return m, nil
}
// UnmarshalJSON sets *m to a copy of data.
func (m *RawMessage) UnmarshalJSON(data []byte) error {
if m == nil {
return errors.New("json.RawMessage: UnmarshalJSON on nil pointer")
}
*m = append((*m)[0:0], data...)
return nil
}
var (
_ Marshaler = (*RawMessage)(nil)
_ Unmarshaler = (*RawMessage)(nil)
)
// A Token holds a value of one of these types:
//
// Delim, for the four JSON delimiters [ ] { }
// bool, for JSON booleans
// float64, for JSON numbers
// Number, for JSON numbers
// string, for JSON string literals
// nil, for JSON null
type Token any
const (
tokenTopValue = iota
tokenArrayStart
tokenArrayValue
tokenArrayComma
tokenObjectStart
tokenObjectKey
tokenObjectColon
tokenObjectValue
tokenObjectComma
)
// advance tokenstate from a separator state to a value state
func (dec *Decoder) tokenPrepareForDecode() error {
// Note: Not calling peek before switch, to avoid
// putting peek into the standard Decode path.
// peek is only called when using the Token API.
switch dec.tokenState {
case tokenArrayComma:
c, err := dec.peek()
if err != nil {
return err
}
if c != ',' {
return &SyntaxError{"expected comma after array element", dec.InputOffset()}
}
dec.scanp++
dec.tokenState = tokenArrayValue
case tokenObjectColon:
c, err := dec.peek()
if err != nil {
return err
}
if c != ':' {
return &SyntaxError{"expected colon after object key", dec.InputOffset()}
}
dec.scanp++
dec.tokenState = tokenObjectValue
}
return nil
}
func (dec *Decoder) tokenValueAllowed() bool {
switch dec.tokenState {
case tokenTopValue, tokenArrayStart, tokenArrayValue, tokenObjectValue:
return true
}
return false
}
func (dec *Decoder) tokenValueEnd() {
switch dec.tokenState {
case tokenArrayStart, tokenArrayValue:
dec.tokenState = tokenArrayComma
case tokenObjectValue:
dec.tokenState = tokenObjectComma
}
}
// A Delim is a JSON array or object delimiter, one of [ ] { or }.
type Delim rune
func (d Delim) String() string {
return string(d)
}
// Token returns the next JSON token in the input stream.
// At the end of the input stream, Token returns nil, io.EOF.
//
// Token guarantees that the delimiters [ ] { } it returns are
// properly nested and matched: if Token encounters an unexpected
// delimiter in the input, it will return an error.
//
// The input stream consists of basic JSON values—bool, string,
// number, and null—along with delimiters [ ] { } of type Delim
// to mark the start and end of arrays and objects.
// Commas and colons are elided.
func (dec *Decoder) Token() (Token, error) {
for {
c, err := dec.peek()
if err != nil {
return nil, err
}
switch c {
case '[':
if !dec.tokenValueAllowed() {
return dec.tokenError(c)
}
dec.scanp++
dec.tokenStack = append(dec.tokenStack, dec.tokenState)
dec.tokenState = tokenArrayStart
return Delim('['), nil
case ']':
if dec.tokenState != tokenArrayStart && dec.tokenState != tokenArrayComma {
return dec.tokenError(c)
}
dec.scanp++
dec.tokenState = dec.tokenStack[len(dec.tokenStack)-1]
dec.tokenStack = dec.tokenStack[:len(dec.tokenStack)-1]
dec.tokenValueEnd()
return Delim(']'), nil
case '{':
if !dec.tokenValueAllowed() {
return dec.tokenError(c)
}
dec.scanp++
dec.tokenStack = append(dec.tokenStack, dec.tokenState)
dec.tokenState = tokenObjectStart
return Delim('{'), nil
case '}':
if dec.tokenState != tokenObjectStart && dec.tokenState != tokenObjectComma {
return dec.tokenError(c)
}
dec.scanp++
dec.tokenState = dec.tokenStack[len(dec.tokenStack)-1]
dec.tokenStack = dec.tokenStack[:len(dec.tokenStack)-1]
dec.tokenValueEnd()
return Delim('}'), nil
case ':':
if dec.tokenState != tokenObjectColon {
return dec.tokenError(c)
}
dec.scanp++
dec.tokenState = tokenObjectValue
continue
case ',':
if dec.tokenState == tokenArrayComma {
dec.scanp++
dec.tokenState = tokenArrayValue
continue
}
if dec.tokenState == tokenObjectComma {
dec.scanp++
dec.tokenState = tokenObjectKey
continue
}
return dec.tokenError(c)
case '"':
if dec.tokenState == tokenObjectStart || dec.tokenState == tokenObjectKey {
var x string
old := dec.tokenState
dec.tokenState = tokenTopValue
err := dec.Decode(&x)
dec.tokenState = old
if err != nil {
return nil, err
}
dec.tokenState = tokenObjectColon
return x, nil
}
fallthrough
default:
if !dec.tokenValueAllowed() {
return dec.tokenError(c)
}
var x any
if err := dec.Decode(&x); err != nil {
return nil, err
}
return x, nil
}
}
}
func (dec *Decoder) tokenError(c byte) (Token, error) {
var context string
switch dec.tokenState {
case tokenTopValue:
context = " looking for beginning of value"
case tokenArrayStart, tokenArrayValue, tokenObjectValue:
context = " looking for beginning of value"
case tokenArrayComma:
context = " after array element"
case tokenObjectKey:
context = " looking for beginning of object key string"
case tokenObjectColon:
context = " after object key"
case tokenObjectComma:
context = " after object key:value pair"
}
return nil, &SyntaxError{"invalid character " + quoteChar(c) + context, dec.InputOffset()}
}
// More reports whether there is another element in the
// current array or object being parsed.
func (dec *Decoder) More() bool {
c, err := dec.peek()
return err == nil && c != ']' && c != '}'
}
func (dec *Decoder) peek() (byte, error) {
var err error
for {
for i := dec.scanp; i < len(dec.buf); i++ {
c := dec.buf[i]
if isSpace(c) {
continue
}
dec.scanp = i
return c, nil
}
// buffer has been scanned, now report any error
if err != nil {
return 0, err
}
err = dec.refill()
}
}
// InputOffset returns the input stream byte offset of the current decoder position.
// The offset gives the location of the end of the most recently returned token
// and the beginning of the next token.
func (dec *Decoder) InputOffset() int64 {
return dec.scanned + int64(dec.scanp)
}

View file

@ -1,218 +0,0 @@
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package json
import "unicode/utf8"
// safeSet holds the value true if the ASCII character with the given array
// position can be represented inside a JSON string without any further
// escaping.
//
// All values are true except for the ASCII control characters (0-31), the
// double quote ("), and the backslash character ("\").
var safeSet = [utf8.RuneSelf]bool{
' ': true,
'!': true,
'"': false,
'#': true,
'$': true,
'%': true,
'&': true,
'\'': true,
'(': true,
')': true,
'*': true,
'+': true,
',': true,
'-': true,
'.': true,
'/': true,
'0': true,
'1': true,
'2': true,
'3': true,
'4': true,
'5': true,
'6': true,
'7': true,
'8': true,
'9': true,
':': true,
';': true,
'<': true,
'=': true,
'>': true,
'?': true,
'@': true,
'A': true,
'B': true,
'C': true,
'D': true,
'E': true,
'F': true,
'G': true,
'H': true,
'I': true,
'J': true,
'K': true,
'L': true,
'M': true,
'N': true,
'O': true,
'P': true,
'Q': true,
'R': true,
'S': true,
'T': true,
'U': true,
'V': true,
'W': true,
'X': true,
'Y': true,
'Z': true,
'[': true,
'\\': false,
']': true,
'^': true,
'_': true,
'`': true,
'a': true,
'b': true,
'c': true,
'd': true,
'e': true,
'f': true,
'g': true,
'h': true,
'i': true,
'j': true,
'k': true,
'l': true,
'm': true,
'n': true,
'o': true,
'p': true,
'q': true,
'r': true,
's': true,
't': true,
'u': true,
'v': true,
'w': true,
'x': true,
'y': true,
'z': true,
'{': true,
'|': true,
'}': true,
'~': true,
'\u007f': true,
}
// htmlSafeSet holds the value true if the ASCII character with the given
// array position can be safely represented inside a JSON string, embedded
// inside of HTML <script> tags, without any additional escaping.
//
// All values are true except for the ASCII control characters (0-31), the
// double quote ("), the backslash character ("\"), HTML opening and closing
// tags ("<" and ">"), and the ampersand ("&").
var htmlSafeSet = [utf8.RuneSelf]bool{
' ': true,
'!': true,
'"': false,
'#': true,
'$': true,
'%': true,
'&': false,
'\'': true,
'(': true,
')': true,
'*': true,
'+': true,
',': true,
'-': true,
'.': true,
'/': true,
'0': true,
'1': true,
'2': true,
'3': true,
'4': true,
'5': true,
'6': true,
'7': true,
'8': true,
'9': true,
':': true,
';': true,
'<': false,
'=': true,
'>': false,
'?': true,
'@': true,
'A': true,
'B': true,
'C': true,
'D': true,
'E': true,
'F': true,
'G': true,
'H': true,
'I': true,
'J': true,
'K': true,
'L': true,
'M': true,
'N': true,
'O': true,
'P': true,
'Q': true,
'R': true,
'S': true,
'T': true,
'U': true,
'V': true,
'W': true,
'X': true,
'Y': true,
'Z': true,
'[': true,
'\\': false,
']': true,
'^': true,
'_': true,
'`': true,
'a': true,
'b': true,
'c': true,
'd': true,
'e': true,
'f': true,
'g': true,
'h': true,
'i': true,
'j': true,
'k': true,
'l': true,
'm': true,
'n': true,
'o': true,
'p': true,
'q': true,
'r': true,
's': true,
't': true,
'u': true,
'v': true,
'w': true,
'x': true,
'y': true,
'z': true,
'{': true,
'|': true,
'}': true,
'~': true,
'\u007f': true,
}

View file

@ -1,38 +0,0 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package json
import (
"strings"
)
// tagOptions is the string following a comma in a struct field's "json"
// tag, or the empty string. It does not include the leading comma.
type tagOptions string
// parseTag splits a struct field's json tag into its name and
// comma-separated options.
func parseTag(tag string) (string, tagOptions) {
tag, opt, _ := strings.Cut(tag, ",")
return tag, tagOptions(opt)
}
// Contains reports whether a comma-separated list of options
// contains a particular substr flag. substr must be surrounded by a
// string boundary or commas.
func (o tagOptions) Contains(optionName string) bool {
if len(o) == 0 {
return false
}
s := string(o)
for s != "" {
var name string
name, s, _ = strings.Cut(s, ",")
if name == optionName {
return true
}
}
return false
}

Some files were not shown because too many files have changed in this diff Show more