binary: Move varint utils to new package

This commit is contained in:
世界 2024-06-23 15:51:39 +08:00
parent a31dba8ad2
commit caa4340dc9
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
9 changed files with 180 additions and 89 deletions

18
common/binary/export.go Normal file
View file

@ -0,0 +1,18 @@
package binary
import (
"encoding/binary"
"reflect"
)
func DataSize(t reflect.Value) int {
return dataSize(t)
}
func EncodeValue(order binary.ByteOrder, buf []byte, v reflect.Value) {
(&encoder{order: order, buf: buf}).value(v)
}
func DecodeValue(order binary.ByteOrder, buf []byte, v reflect.Value) {
(&decoder{order: order, buf: buf}).value(v)
}

View file

@ -4,6 +4,7 @@ import (
"io" "io"
"sync" "sync"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/buf"
) )
@ -41,6 +42,25 @@ func (w *BufferedWriter) Write(p []byte) (n int, err error) {
} }
} }
func (w *BufferedWriter) WriteByte(c byte) error {
w.access.Lock()
defer w.access.Unlock()
if w.buffer == nil {
return common.Error(w.upstream.Write([]byte{c}))
}
for {
err := w.buffer.WriteByte(c)
if err == nil {
return nil
}
_, err = w.upstream.Write(w.buffer.Bytes())
if err != nil {
return err
}
w.buffer.Reset()
}
}
func (w *BufferedWriter) Fallthrough() error { func (w *BufferedWriter) Fallthrough() error {
w.access.Lock() w.access.Lock()
defer w.access.Unlock() defer w.access.Unlock()

View file

@ -12,7 +12,6 @@ import (
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network" N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/rw"
"github.com/sagernet/sing/common/task" "github.com/sagernet/sing/common/task"
) )
@ -163,11 +162,11 @@ func CopyConn(ctx context.Context, source net.Conn, destination net.Conn) error
func CopyConnContextList(contextList []context.Context, source net.Conn, destination net.Conn) error { func CopyConnContextList(contextList []context.Context, source net.Conn, destination net.Conn) error {
var group task.Group var group task.Group
if _, dstDuplex := common.Cast[rw.WriteCloser](destination); dstDuplex { if _, dstDuplex := common.Cast[N.WriteCloser](destination); dstDuplex {
group.Append("upload", func(ctx context.Context) error { group.Append("upload", func(ctx context.Context) error {
err := common.Error(Copy(destination, source)) err := common.Error(Copy(destination, source))
if err == nil { if err == nil {
rw.CloseWrite(destination) N.CloseWrite(destination)
} else { } else {
common.Close(destination) common.Close(destination)
} }
@ -179,11 +178,11 @@ func CopyConnContextList(contextList []context.Context, source net.Conn, destina
return common.Error(Copy(destination, source)) return common.Error(Copy(destination, source))
}) })
} }
if _, srcDuplex := common.Cast[rw.WriteCloser](source); srcDuplex { if _, srcDuplex := common.Cast[N.WriteCloser](source); srcDuplex {
group.Append("download", func(ctx context.Context) error { group.Append("download", func(ctx context.Context) error {
err := common.Error(Copy(source, destination)) err := common.Error(Copy(source, destination))
if err == nil { if err == nil {
rw.CloseWrite(source) N.CloseWrite(source)
} else { } else {
common.Close(source) common.Close(source)
} }

View file

@ -1,4 +1,4 @@
package rw package network
import ( import (
"github.com/sagernet/sing/common" "github.com/sagernet/sing/common"

View file

@ -1,57 +1,61 @@
package binary package varbin
import ( import (
"errors" "errors"
"io" "io"
"reflect" "reflect"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/binary"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
) )
type Reader interface { func Read(r io.Reader, order binary.ByteOrder, rawData any) error {
io.Reader reader := StubReader(r)
io.ByteReader
}
type Writer interface {
io.Writer
io.ByteWriter
}
func ReadData(r Reader, order ByteOrder, rawData any) error {
switch data := rawData.(type) { switch data := rawData.(type) {
case *[]bool: case *[]bool:
return readBaseData(r, order, data) return readBase(reader, order, data)
case *[]int8: case *[]int8:
return readBaseData(r, order, data) return readBase(reader, order, data)
case *[]uint8: case *[]uint8:
return readBaseData(r, order, data) return readBase(reader, order, data)
case *[]int16: case *[]int16:
return readBaseData(r, order, data) return readBase(reader, order, data)
case *[]uint16: case *[]uint16:
return readBaseData(r, order, data) return readBase(reader, order, data)
case *[]int32: case *[]int32:
return readBaseData(r, order, data) return readBase(reader, order, data)
case *[]uint32: case *[]uint32:
return readBaseData(r, order, data) return readBase(reader, order, data)
case *[]int64: case *[]int64:
return readBaseData(r, order, data) return readBase(reader, order, data)
case *[]uint64: case *[]uint64:
return readBaseData(r, order, data) return readBase(reader, order, data)
case *[]float32: case *[]float32:
return readBaseData(r, order, data) return readBase(reader, order, data)
case *[]float64: case *[]float64:
return readBaseData(r, order, data) return readBase(reader, order, data)
default: default:
if intBaseDataSize(rawData) != 0 { if intBaseDataSize(rawData) != 0 {
return Read(r, order, rawData) return binary.Read(reader, order, rawData)
} }
} }
return readData(r, order, reflect.Indirect(reflect.ValueOf(rawData))) return read(reader, order, reflect.Indirect(reflect.ValueOf(rawData)))
} }
func readBaseData[T any](r Reader, order ByteOrder, data *[]T) error { func ReadValue[T any](r io.Reader, order binary.ByteOrder) (T, error) {
dataLen, err := ReadUvarint(r) var value T
err := Read(r, order, &value)
if err != nil {
return common.DefaultValue[T](), err
}
return value, nil
}
func readBase[T any](r Reader, order binary.ByteOrder, data *[]T) error {
dataLen, err := binary.ReadUvarint(r)
if err != nil { if err != nil {
return E.Cause(err, "slice length") return E.Cause(err, "slice length")
} }
@ -60,7 +64,7 @@ func readBaseData[T any](r Reader, order ByteOrder, data *[]T) error {
return nil return nil
} }
dataSlices := make([]T, dataLen) dataSlices := make([]T, dataLen)
err = Read(r, order, dataSlices) err = binary.Read(r, order, dataSlices)
if err != nil { if err != nil {
return err return err
} }
@ -68,7 +72,7 @@ func readBaseData[T any](r Reader, order ByteOrder, data *[]T) error {
return nil return nil
} }
func readData(r Reader, order ByteOrder, data reflect.Value) error { func read(r Reader, order binary.ByteOrder, data reflect.Value) error {
switch data.Kind() { switch data.Kind() {
case reflect.Pointer: case reflect.Pointer:
pointerValue, err := r.ReadByte() pointerValue, err := r.ReadByte()
@ -82,9 +86,9 @@ func readData(r Reader, order ByteOrder, data reflect.Value) error {
if data.IsNil() { if data.IsNil() {
data.Set(reflect.New(data.Type().Elem())) data.Set(reflect.New(data.Type().Elem()))
} }
return readData(r, order, data.Elem()) return read(r, order, data.Elem())
case reflect.String: case reflect.String:
stringLength, err := ReadUvarint(r) stringLength, err := binary.ReadUvarint(r)
if err != nil { if err != nil {
return E.Cause(err, "string length") return E.Cause(err, "string length")
} }
@ -100,25 +104,24 @@ func readData(r Reader, order ByteOrder, data reflect.Value) error {
} }
case reflect.Array: case reflect.Array:
arrayLen := data.Len() arrayLen := data.Len()
itemSize := sizeof(data.Type()) itemSize := data.Type().Elem().Len()
if itemSize > 0 { if itemSize > 0 {
buf := make([]byte, itemSize*arrayLen) buf := make([]byte, itemSize*arrayLen)
_, err := io.ReadFull(r, buf) _, err := io.ReadFull(r, buf)
if err != nil { if err != nil {
return err return err
} }
d := &decoder{order: order, buf: buf} binary.DecodeValue(order, buf, data)
d.value(data)
} else { } else {
for i := 0; i < arrayLen; i++ { for i := 0; i < arrayLen; i++ {
err := readData(r, order, data.Index(i)) err := read(r, order, data.Index(i))
if err != nil { if err != nil {
return E.Cause(err, "[", i, "]") return E.Cause(err, "[", i, "]")
} }
} }
} }
case reflect.Slice: case reflect.Slice:
sliceLength, err := ReadUvarint(r) sliceLength, err := binary.ReadUvarint(r)
if err != nil { if err != nil {
return E.Cause(err, "slice length") return E.Cause(err, "slice length")
} }
@ -127,7 +130,7 @@ func readData(r Reader, order ByteOrder, data reflect.Value) error {
} else { } else {
dataSlices := makeBaseDataSlices(data, int(sliceLength)) dataSlices := makeBaseDataSlices(data, int(sliceLength))
if dataSlices != nil { if dataSlices != nil {
err = Read(r, order, dataSlices) err = binary.Read(r, order, dataSlices)
if err != nil { if err != nil {
return err return err
} }
@ -139,7 +142,7 @@ func readData(r Reader, order ByteOrder, data reflect.Value) error {
data.Set(reflect.MakeSlice(data.Type(), int(sliceLength), int(sliceLength))) data.Set(reflect.MakeSlice(data.Type(), int(sliceLength), int(sliceLength)))
} }
for i := 0; i < int(sliceLength); i++ { for i := 0; i < int(sliceLength); i++ {
err = readData(r, order, data.Index(i)) err = read(r, order, data.Index(i))
if err != nil { if err != nil {
return E.Cause(err, "[", i, "]") return E.Cause(err, "[", i, "]")
} }
@ -147,19 +150,19 @@ func readData(r Reader, order ByteOrder, data reflect.Value) error {
} }
} }
case reflect.Map: case reflect.Map:
mapLength, err := ReadUvarint(r) mapLength, err := binary.ReadUvarint(r)
if err != nil { if err != nil {
return E.Cause(err, "map length") return E.Cause(err, "map length")
} }
data.Set(reflect.MakeMap(data.Type())) data.Set(reflect.MakeMap(data.Type()))
for index := 0; index < int(mapLength); index++ { for index := 0; index < int(mapLength); index++ {
key := reflect.New(data.Type().Key()).Elem() key := reflect.New(data.Type().Key()).Elem()
err = readData(r, order, key) err = read(r, order, key)
if err != nil { if err != nil {
return E.Cause(err, "[", index, "].key") return E.Cause(err, "[", index, "].key")
} }
value := reflect.New(data.Type().Elem()).Elem() value := reflect.New(data.Type().Elem()).Elem()
err = readData(r, order, value) err = read(r, order, value)
if err != nil { if err != nil {
return E.Cause(err, "[", index, "].value") return E.Cause(err, "[", index, "].value")
} }
@ -172,71 +175,90 @@ func readData(r Reader, order ByteOrder, data reflect.Value) error {
field := data.Field(i) field := data.Field(i)
fieldName := fieldType.Field(i).Name fieldName := fieldType.Field(i).Name
if field.CanSet() || fieldName != "_" { if field.CanSet() || fieldName != "_" {
err := readData(r, order, field) err := read(r, order, field)
if err != nil { if err != nil {
return E.Cause(err, fieldName) return E.Cause(err, fieldName)
} }
} }
} }
default: default:
size := dataSize(data) size := binary.DataSize(data)
if size < 0 { if size < 0 {
return errors.New("invalid type " + reflect.TypeOf(data).String()) return errors.New("invalid type " + reflect.TypeOf(data).String())
} }
d := &decoder{order: order, buf: make([]byte, size)} buf := make([]byte, size)
_, err := io.ReadFull(r, d.buf) _, err := io.ReadFull(r, buf)
if err != nil { if err != nil {
return err return err
} }
d.value(data) binary.DecodeValue(order, buf, data)
} }
return nil return nil
} }
func WriteData(writer Writer, order ByteOrder, rawData any) error { func Write(w io.Writer, order binary.ByteOrder, rawData any) error {
if intBaseDataSize(rawData) != 0 {
return binary.Write(w, order, rawData)
}
var (
writer Writer
bufferedWriter *bufio.BufferedWriter
)
if bw, ok := w.(Writer); ok {
writer = bw
} else {
bufferedWriter = bufio.NewBufferedWriter(w, buf.NewSize(1024))
writer = bufferedWriter
}
switch data := rawData.(type) { switch data := rawData.(type) {
case []bool: case []bool:
return writeBaseData(writer, order, data) return writeBase(writer, order, data)
case []int8: case []int8:
return writeBaseData(writer, order, data) return writeBase(writer, order, data)
case []uint8: case []uint8:
return writeBaseData(writer, order, data) return writeBase(writer, order, data)
case []int16: case []int16:
return writeBaseData(writer, order, data) return writeBase(writer, order, data)
case []uint16: case []uint16:
return writeBaseData(writer, order, data) return writeBase(writer, order, data)
case []int32: case []int32:
return writeBaseData(writer, order, data) return writeBase(writer, order, data)
case []uint32: case []uint32:
return writeBaseData(writer, order, data) return writeBase(writer, order, data)
case []int64: case []int64:
return writeBaseData(writer, order, data) return writeBase(writer, order, data)
case []uint64: case []uint64:
return writeBaseData(writer, order, data) return writeBase(writer, order, data)
case []float32: case []float32:
return writeBaseData(writer, order, data) return writeBase(writer, order, data)
case []float64: case []float64:
return writeBaseData(writer, order, data) return writeBase(writer, order, data)
default: }
if intBaseDataSize(rawData) != 0 { err := write(writer, order, reflect.Indirect(reflect.ValueOf(rawData)))
return Write(writer, order, rawData) if err != nil {
return err
}
if bufferedWriter != nil {
err = bufferedWriter.Fallthrough()
if err != nil {
return err
} }
} }
return writeData(writer, order, reflect.Indirect(reflect.ValueOf(rawData))) return nil
} }
func writeBaseData[T any](writer Writer, order ByteOrder, data []T) error { func writeBase[T any](writer Writer, order binary.ByteOrder, data []T) error {
_, err := WriteUvarint(writer, uint64(len(data))) _, err := WriteUvarint(writer, uint64(len(data)))
if err != nil { if err != nil {
return err return err
} }
if len(data) > 0 { if len(data) > 0 {
return Write(writer, order, data) return binary.Write(writer, order, data)
} }
return nil return nil
} }
func writeData(writer Writer, order ByteOrder, data reflect.Value) error { func write(writer Writer, order binary.ByteOrder, data reflect.Value) error {
switch data.Kind() { switch data.Kind() {
case reflect.Pointer: case reflect.Pointer:
if data.IsNil() { if data.IsNil() {
@ -249,7 +271,7 @@ func writeData(writer Writer, order ByteOrder, data reflect.Value) error {
if err != nil { if err != nil {
return err return err
} }
return writeData(writer, order, data.Elem()) return write(writer, order, data.Elem())
} }
case reflect.String: case reflect.String:
stringValue := data.String() stringValue := data.String()
@ -269,15 +291,14 @@ func writeData(writer Writer, order ByteOrder, data reflect.Value) error {
itemSize := intItemBaseDataSize(data) itemSize := intItemBaseDataSize(data)
if itemSize > 0 { if itemSize > 0 {
buf := make([]byte, itemSize*dataLen) buf := make([]byte, itemSize*dataLen)
e := &encoder{order: order, buf: buf} binary.EncodeValue(order, buf, data)
e.value(data)
_, err := writer.Write(buf) _, err := writer.Write(buf)
if err != nil { if err != nil {
return E.Cause(err, reflect.TypeOf(data).String()) return E.Cause(err, reflect.TypeOf(data).String())
} }
} else { } else {
for i := 0; i < dataLen; i++ { for i := 0; i < dataLen; i++ {
err := writeData(writer, order, data.Index(i)) err := write(writer, order, data.Index(i))
if err != nil { if err != nil {
return E.Cause(err, "[", i, "]") return E.Cause(err, "[", i, "]")
} }
@ -293,13 +314,13 @@ func writeData(writer Writer, order ByteOrder, data reflect.Value) error {
if dataLen > 0 { if dataLen > 0 {
dataSlices := baseDataSlices(data) dataSlices := baseDataSlices(data)
if dataSlices != nil { if dataSlices != nil {
err = Write(writer, order, dataSlices) err = binary.Write(writer, order, dataSlices)
if err != nil { if err != nil {
return err return err
} }
} else { } else {
for i := 0; i < dataLen; i++ { for i := 0; i < dataLen; i++ {
err = writeData(writer, order, data.Index(i)) err = write(writer, order, data.Index(i))
if err != nil { if err != nil {
return E.Cause(err, "[", i, "]") return E.Cause(err, "[", i, "]")
} }
@ -314,11 +335,11 @@ func writeData(writer Writer, order ByteOrder, data reflect.Value) error {
} }
if dataLen > 0 { if dataLen > 0 {
for index, key := range data.MapKeys() { for index, key := range data.MapKeys() {
err = writeData(writer, order, key) err = write(writer, order, key)
if err != nil { if err != nil {
return E.Cause(err, "[", index, "].key") return E.Cause(err, "[", index, "].key")
} }
err = writeData(writer, order, data.MapIndex(key)) err = write(writer, order, data.MapIndex(key))
if err != nil { if err != nil {
return E.Cause(err, "[", index, "].value") return E.Cause(err, "[", index, "].value")
} }
@ -331,20 +352,19 @@ func writeData(writer Writer, order ByteOrder, data reflect.Value) error {
field := data.Field(i) field := data.Field(i)
fieldName := fieldType.Field(i).Name fieldName := fieldType.Field(i).Name
if field.CanSet() || fieldName != "_" { if field.CanSet() || fieldName != "_" {
err := writeData(writer, order, field) err := write(writer, order, field)
if err != nil { if err != nil {
return E.Cause(err, fieldName) return E.Cause(err, fieldName)
} }
} }
} }
default: default:
size := dataSize(data) size := binary.DataSize(data)
if size < 0 { if size < 0 {
return errors.New("binary.Write: some values are not fixed-sized in type " + data.Type().String()) return errors.New("binary.Write: some values are not fixed-sized in type " + data.Type().String())
} }
buf := make([]byte, size) buf := make([]byte, size)
e := &encoder{order: order, buf: buf} binary.EncodeValue(order, buf, data)
e.value(data)
_, err := writer.Write(buf) _, err := writer.Write(buf)
if err != nil { if err != nil {
return E.Cause(err, reflect.TypeOf(data).String()) return E.Cause(err, reflect.TypeOf(data).String())

34
common/varbin/data_if.go Normal file
View file

@ -0,0 +1,34 @@
package varbin
import (
"io"
)
type Reader interface {
io.Reader
io.ByteReader
}
type Writer interface {
io.Writer
io.ByteWriter
}
var _ Reader = stubReader{}
func StubReader(reader io.Reader) Reader {
if r, ok := reader.(Reader); ok {
return r
}
return stubReader{reader}
}
type stubReader struct {
io.Reader
}
func (r stubReader) ReadByte() (byte, error) {
var b [1]byte
_, err := r.Read(b[:])
return b[0], err
}

View file

@ -1,4 +1,4 @@
package binary package varbin
import "io" import "io"

View file

@ -1,4 +1,4 @@
package binary package varbin
import ( import (
"reflect" "reflect"

View file

@ -1,7 +1,7 @@
package binary package varbin
import ( import (
"math/rand/v2" "math/rand"
"reflect" "reflect"
"testing" "testing"
@ -11,7 +11,7 @@ import (
func TestSlicesValue(t *testing.T) { func TestSlicesValue(t *testing.T) {
int64Arr := make([]int64, 64) int64Arr := make([]int64, 64)
for i := range int64Arr { for i := range int64Arr {
int64Arr[i] = rand.Int64() int64Arr[i] = rand.Int63()
} }
require.Equal(t, int64Arr, slicesValue[int64](reflect.ValueOf(int64Arr))) require.Equal(t, int64Arr, slicesValue[int64](reflect.ValueOf(int64Arr)))
require.Equal(t, int64Arr, baseDataSlices(reflect.ValueOf(int64Arr))) require.Equal(t, int64Arr, baseDataSlices(reflect.ValueOf(int64Arr)))
@ -22,7 +22,7 @@ func TestSetSliceValue(t *testing.T) {
value := reflect.Indirect(reflect.ValueOf(&int64Arr)) value := reflect.Indirect(reflect.ValueOf(&int64Arr))
newInt64Arr := make([]int64, 64) newInt64Arr := make([]int64, 64)
for i := range newInt64Arr { for i := range newInt64Arr {
newInt64Arr[i] = rand.Int64() newInt64Arr[i] = rand.Int63()
} }
setSliceValue[int64](value, newInt64Arr) setSliceValue[int64](value, newInt64Arr)
require.Equal(t, newInt64Arr, slicesValue[int64](value)) require.Equal(t, newInt64Arr, slicesValue[int64](value))