From caa4340dc9b89b1284edda63268d9f7136a5b76b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sun, 23 Jun 2024 15:51:39 +0800 Subject: [PATCH] binary: Move varint utils to new package --- common/binary/export.go | 18 ++ common/bufio/buffer.go | 20 ++ common/bufio/copy.go | 9 +- common/{rw => network}/duplex.go | 2 +- .../{binary/varint_data.go => varbin/data.go} | 174 ++++++++++-------- common/varbin/data_if.go | 34 ++++ .../varint_write.go => varbin/uvarint.go} | 2 +- .../value_slices_unsafe.go} | 2 +- .../value_slices_unsafe_test.go} | 8 +- 9 files changed, 180 insertions(+), 89 deletions(-) create mode 100644 common/binary/export.go rename common/{rw => network}/duplex.go (96%) rename common/{binary/varint_data.go => varbin/data.go} (64%) create mode 100644 common/varbin/data_if.go rename common/{binary/varint_write.go => varbin/uvarint.go} (97%) rename common/{binary/varint_unsafe.go => varbin/value_slices_unsafe.go} (99%) rename common/{binary/varint_unsafe_test.go => varbin/value_slices_unsafe_test.go} (90%) diff --git a/common/binary/export.go b/common/binary/export.go new file mode 100644 index 0000000..cdccbeb --- /dev/null +++ b/common/binary/export.go @@ -0,0 +1,18 @@ +package binary + +import ( + "encoding/binary" + "reflect" +) + +func DataSize(t reflect.Value) int { + return dataSize(t) +} + +func EncodeValue(order binary.ByteOrder, buf []byte, v reflect.Value) { + (&encoder{order: order, buf: buf}).value(v) +} + +func DecodeValue(order binary.ByteOrder, buf []byte, v reflect.Value) { + (&decoder{order: order, buf: buf}).value(v) +} diff --git a/common/bufio/buffer.go b/common/bufio/buffer.go index cdd2896..74db53e 100644 --- a/common/bufio/buffer.go +++ b/common/bufio/buffer.go @@ -4,6 +4,7 @@ import ( "io" "sync" + "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" ) @@ -41,6 +42,25 @@ func (w *BufferedWriter) Write(p []byte) (n int, err error) { } } +func (w *BufferedWriter) WriteByte(c byte) error { + w.access.Lock() + defer w.access.Unlock() + if w.buffer == nil { + return common.Error(w.upstream.Write([]byte{c})) + } + for { + err := w.buffer.WriteByte(c) + if err == nil { + return nil + } + _, err = w.upstream.Write(w.buffer.Bytes()) + if err != nil { + return err + } + w.buffer.Reset() + } +} + func (w *BufferedWriter) Fallthrough() error { w.access.Lock() defer w.access.Unlock() diff --git a/common/bufio/copy.go b/common/bufio/copy.go index f8e63cd..309f56d 100644 --- a/common/bufio/copy.go +++ b/common/bufio/copy.go @@ -12,7 +12,6 @@ import ( E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" - "github.com/sagernet/sing/common/rw" "github.com/sagernet/sing/common/task" ) @@ -163,11 +162,11 @@ func CopyConn(ctx context.Context, source net.Conn, destination net.Conn) error func CopyConnContextList(contextList []context.Context, source net.Conn, destination net.Conn) error { var group task.Group - if _, dstDuplex := common.Cast[rw.WriteCloser](destination); dstDuplex { + if _, dstDuplex := common.Cast[N.WriteCloser](destination); dstDuplex { group.Append("upload", func(ctx context.Context) error { err := common.Error(Copy(destination, source)) if err == nil { - rw.CloseWrite(destination) + N.CloseWrite(destination) } else { common.Close(destination) } @@ -179,11 +178,11 @@ func CopyConnContextList(contextList []context.Context, source net.Conn, destina return common.Error(Copy(destination, source)) }) } - if _, srcDuplex := common.Cast[rw.WriteCloser](source); srcDuplex { + if _, srcDuplex := common.Cast[N.WriteCloser](source); srcDuplex { group.Append("download", func(ctx context.Context) error { err := common.Error(Copy(source, destination)) if err == nil { - rw.CloseWrite(source) + N.CloseWrite(source) } else { common.Close(source) } diff --git a/common/rw/duplex.go b/common/network/duplex.go similarity index 96% rename from common/rw/duplex.go rename to common/network/duplex.go index ba5a754..a59fa7f 100644 --- a/common/rw/duplex.go +++ b/common/network/duplex.go @@ -1,4 +1,4 @@ -package rw +package network import ( "github.com/sagernet/sing/common" diff --git a/common/binary/varint_data.go b/common/varbin/data.go similarity index 64% rename from common/binary/varint_data.go rename to common/varbin/data.go index d1c691f..2bbfc31 100644 --- a/common/binary/varint_data.go +++ b/common/varbin/data.go @@ -1,57 +1,61 @@ -package binary +package varbin import ( "errors" "io" "reflect" + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/binary" + "github.com/sagernet/sing/common/buf" + "github.com/sagernet/sing/common/bufio" E "github.com/sagernet/sing/common/exceptions" ) -type Reader interface { - io.Reader - io.ByteReader -} - -type Writer interface { - io.Writer - io.ByteWriter -} - -func ReadData(r Reader, order ByteOrder, rawData any) error { +func Read(r io.Reader, order binary.ByteOrder, rawData any) error { + reader := StubReader(r) switch data := rawData.(type) { case *[]bool: - return readBaseData(r, order, data) + return readBase(reader, order, data) case *[]int8: - return readBaseData(r, order, data) + return readBase(reader, order, data) case *[]uint8: - return readBaseData(r, order, data) + return readBase(reader, order, data) case *[]int16: - return readBaseData(r, order, data) + return readBase(reader, order, data) case *[]uint16: - return readBaseData(r, order, data) + return readBase(reader, order, data) case *[]int32: - return readBaseData(r, order, data) + return readBase(reader, order, data) case *[]uint32: - return readBaseData(r, order, data) + return readBase(reader, order, data) case *[]int64: - return readBaseData(r, order, data) + return readBase(reader, order, data) case *[]uint64: - return readBaseData(r, order, data) + return readBase(reader, order, data) case *[]float32: - return readBaseData(r, order, data) + return readBase(reader, order, data) case *[]float64: - return readBaseData(r, order, data) + return readBase(reader, order, data) default: if intBaseDataSize(rawData) != 0 { - return Read(r, order, rawData) + return binary.Read(reader, order, rawData) } } - return readData(r, order, reflect.Indirect(reflect.ValueOf(rawData))) + return read(reader, order, reflect.Indirect(reflect.ValueOf(rawData))) } -func readBaseData[T any](r Reader, order ByteOrder, data *[]T) error { - dataLen, err := ReadUvarint(r) +func ReadValue[T any](r io.Reader, order binary.ByteOrder) (T, error) { + var value T + err := Read(r, order, &value) + if err != nil { + return common.DefaultValue[T](), err + } + return value, nil +} + +func readBase[T any](r Reader, order binary.ByteOrder, data *[]T) error { + dataLen, err := binary.ReadUvarint(r) if err != nil { return E.Cause(err, "slice length") } @@ -60,7 +64,7 @@ func readBaseData[T any](r Reader, order ByteOrder, data *[]T) error { return nil } dataSlices := make([]T, dataLen) - err = Read(r, order, dataSlices) + err = binary.Read(r, order, dataSlices) if err != nil { return err } @@ -68,7 +72,7 @@ func readBaseData[T any](r Reader, order ByteOrder, data *[]T) error { return nil } -func readData(r Reader, order ByteOrder, data reflect.Value) error { +func read(r Reader, order binary.ByteOrder, data reflect.Value) error { switch data.Kind() { case reflect.Pointer: pointerValue, err := r.ReadByte() @@ -82,9 +86,9 @@ func readData(r Reader, order ByteOrder, data reflect.Value) error { if data.IsNil() { data.Set(reflect.New(data.Type().Elem())) } - return readData(r, order, data.Elem()) + return read(r, order, data.Elem()) case reflect.String: - stringLength, err := ReadUvarint(r) + stringLength, err := binary.ReadUvarint(r) if err != nil { return E.Cause(err, "string length") } @@ -100,25 +104,24 @@ func readData(r Reader, order ByteOrder, data reflect.Value) error { } case reflect.Array: arrayLen := data.Len() - itemSize := sizeof(data.Type()) + itemSize := data.Type().Elem().Len() if itemSize > 0 { buf := make([]byte, itemSize*arrayLen) _, err := io.ReadFull(r, buf) if err != nil { return err } - d := &decoder{order: order, buf: buf} - d.value(data) + binary.DecodeValue(order, buf, data) } else { for i := 0; i < arrayLen; i++ { - err := readData(r, order, data.Index(i)) + err := read(r, order, data.Index(i)) if err != nil { return E.Cause(err, "[", i, "]") } } } case reflect.Slice: - sliceLength, err := ReadUvarint(r) + sliceLength, err := binary.ReadUvarint(r) if err != nil { return E.Cause(err, "slice length") } @@ -127,7 +130,7 @@ func readData(r Reader, order ByteOrder, data reflect.Value) error { } else { dataSlices := makeBaseDataSlices(data, int(sliceLength)) if dataSlices != nil { - err = Read(r, order, dataSlices) + err = binary.Read(r, order, dataSlices) if err != nil { return err } @@ -139,7 +142,7 @@ func readData(r Reader, order ByteOrder, data reflect.Value) error { data.Set(reflect.MakeSlice(data.Type(), int(sliceLength), int(sliceLength))) } for i := 0; i < int(sliceLength); i++ { - err = readData(r, order, data.Index(i)) + err = read(r, order, data.Index(i)) if err != nil { return E.Cause(err, "[", i, "]") } @@ -147,19 +150,19 @@ func readData(r Reader, order ByteOrder, data reflect.Value) error { } } case reflect.Map: - mapLength, err := ReadUvarint(r) + mapLength, err := binary.ReadUvarint(r) if err != nil { return E.Cause(err, "map length") } data.Set(reflect.MakeMap(data.Type())) for index := 0; index < int(mapLength); index++ { key := reflect.New(data.Type().Key()).Elem() - err = readData(r, order, key) + err = read(r, order, key) if err != nil { return E.Cause(err, "[", index, "].key") } value := reflect.New(data.Type().Elem()).Elem() - err = readData(r, order, value) + err = read(r, order, value) if err != nil { return E.Cause(err, "[", index, "].value") } @@ -172,71 +175,90 @@ func readData(r Reader, order ByteOrder, data reflect.Value) error { field := data.Field(i) fieldName := fieldType.Field(i).Name if field.CanSet() || fieldName != "_" { - err := readData(r, order, field) + err := read(r, order, field) if err != nil { return E.Cause(err, fieldName) } } } default: - size := dataSize(data) + size := binary.DataSize(data) if size < 0 { return errors.New("invalid type " + reflect.TypeOf(data).String()) } - d := &decoder{order: order, buf: make([]byte, size)} - _, err := io.ReadFull(r, d.buf) + buf := make([]byte, size) + _, err := io.ReadFull(r, buf) if err != nil { return err } - d.value(data) + binary.DecodeValue(order, buf, data) } return nil } -func WriteData(writer Writer, order ByteOrder, rawData any) error { +func Write(w io.Writer, order binary.ByteOrder, rawData any) error { + if intBaseDataSize(rawData) != 0 { + return binary.Write(w, order, rawData) + } + var ( + writer Writer + bufferedWriter *bufio.BufferedWriter + ) + if bw, ok := w.(Writer); ok { + writer = bw + } else { + bufferedWriter = bufio.NewBufferedWriter(w, buf.NewSize(1024)) + writer = bufferedWriter + } switch data := rawData.(type) { case []bool: - return writeBaseData(writer, order, data) + return writeBase(writer, order, data) case []int8: - return writeBaseData(writer, order, data) + return writeBase(writer, order, data) case []uint8: - return writeBaseData(writer, order, data) + return writeBase(writer, order, data) case []int16: - return writeBaseData(writer, order, data) + return writeBase(writer, order, data) case []uint16: - return writeBaseData(writer, order, data) + return writeBase(writer, order, data) case []int32: - return writeBaseData(writer, order, data) + return writeBase(writer, order, data) case []uint32: - return writeBaseData(writer, order, data) + return writeBase(writer, order, data) case []int64: - return writeBaseData(writer, order, data) + return writeBase(writer, order, data) case []uint64: - return writeBaseData(writer, order, data) + return writeBase(writer, order, data) case []float32: - return writeBaseData(writer, order, data) + return writeBase(writer, order, data) case []float64: - return writeBaseData(writer, order, data) - default: - if intBaseDataSize(rawData) != 0 { - return Write(writer, order, rawData) + return writeBase(writer, order, data) + } + err := write(writer, order, reflect.Indirect(reflect.ValueOf(rawData))) + if err != nil { + return err + } + if bufferedWriter != nil { + err = bufferedWriter.Fallthrough() + if err != nil { + return err } } - return writeData(writer, order, reflect.Indirect(reflect.ValueOf(rawData))) + return nil } -func writeBaseData[T any](writer Writer, order ByteOrder, data []T) error { +func writeBase[T any](writer Writer, order binary.ByteOrder, data []T) error { _, err := WriteUvarint(writer, uint64(len(data))) if err != nil { return err } if len(data) > 0 { - return Write(writer, order, data) + return binary.Write(writer, order, data) } return nil } -func writeData(writer Writer, order ByteOrder, data reflect.Value) error { +func write(writer Writer, order binary.ByteOrder, data reflect.Value) error { switch data.Kind() { case reflect.Pointer: if data.IsNil() { @@ -249,7 +271,7 @@ func writeData(writer Writer, order ByteOrder, data reflect.Value) error { if err != nil { return err } - return writeData(writer, order, data.Elem()) + return write(writer, order, data.Elem()) } case reflect.String: stringValue := data.String() @@ -269,15 +291,14 @@ func writeData(writer Writer, order ByteOrder, data reflect.Value) error { itemSize := intItemBaseDataSize(data) if itemSize > 0 { buf := make([]byte, itemSize*dataLen) - e := &encoder{order: order, buf: buf} - e.value(data) + binary.EncodeValue(order, buf, data) _, err := writer.Write(buf) if err != nil { return E.Cause(err, reflect.TypeOf(data).String()) } } else { for i := 0; i < dataLen; i++ { - err := writeData(writer, order, data.Index(i)) + err := write(writer, order, data.Index(i)) if err != nil { return E.Cause(err, "[", i, "]") } @@ -293,13 +314,13 @@ func writeData(writer Writer, order ByteOrder, data reflect.Value) error { if dataLen > 0 { dataSlices := baseDataSlices(data) if dataSlices != nil { - err = Write(writer, order, dataSlices) + err = binary.Write(writer, order, dataSlices) if err != nil { return err } } else { for i := 0; i < dataLen; i++ { - err = writeData(writer, order, data.Index(i)) + err = write(writer, order, data.Index(i)) if err != nil { return E.Cause(err, "[", i, "]") } @@ -314,11 +335,11 @@ func writeData(writer Writer, order ByteOrder, data reflect.Value) error { } if dataLen > 0 { for index, key := range data.MapKeys() { - err = writeData(writer, order, key) + err = write(writer, order, key) if err != nil { return E.Cause(err, "[", index, "].key") } - err = writeData(writer, order, data.MapIndex(key)) + err = write(writer, order, data.MapIndex(key)) if err != nil { return E.Cause(err, "[", index, "].value") } @@ -331,20 +352,19 @@ func writeData(writer Writer, order ByteOrder, data reflect.Value) error { field := data.Field(i) fieldName := fieldType.Field(i).Name if field.CanSet() || fieldName != "_" { - err := writeData(writer, order, field) + err := write(writer, order, field) if err != nil { return E.Cause(err, fieldName) } } } default: - size := dataSize(data) + size := binary.DataSize(data) if size < 0 { return errors.New("binary.Write: some values are not fixed-sized in type " + data.Type().String()) } buf := make([]byte, size) - e := &encoder{order: order, buf: buf} - e.value(data) + binary.EncodeValue(order, buf, data) _, err := writer.Write(buf) if err != nil { return E.Cause(err, reflect.TypeOf(data).String()) diff --git a/common/varbin/data_if.go b/common/varbin/data_if.go new file mode 100644 index 0000000..54e523d --- /dev/null +++ b/common/varbin/data_if.go @@ -0,0 +1,34 @@ +package varbin + +import ( + "io" +) + +type Reader interface { + io.Reader + io.ByteReader +} + +type Writer interface { + io.Writer + io.ByteWriter +} + +var _ Reader = stubReader{} + +func StubReader(reader io.Reader) Reader { + if r, ok := reader.(Reader); ok { + return r + } + return stubReader{reader} +} + +type stubReader struct { + io.Reader +} + +func (r stubReader) ReadByte() (byte, error) { + var b [1]byte + _, err := r.Read(b[:]) + return b[0], err +} diff --git a/common/binary/varint_write.go b/common/varbin/uvarint.go similarity index 97% rename from common/binary/varint_write.go rename to common/varbin/uvarint.go index fb32eed..9c57443 100644 --- a/common/binary/varint_write.go +++ b/common/varbin/uvarint.go @@ -1,4 +1,4 @@ -package binary +package varbin import "io" diff --git a/common/binary/varint_unsafe.go b/common/varbin/value_slices_unsafe.go similarity index 99% rename from common/binary/varint_unsafe.go rename to common/varbin/value_slices_unsafe.go index ac782e4..5f96085 100644 --- a/common/binary/varint_unsafe.go +++ b/common/varbin/value_slices_unsafe.go @@ -1,4 +1,4 @@ -package binary +package varbin import ( "reflect" diff --git a/common/binary/varint_unsafe_test.go b/common/varbin/value_slices_unsafe_test.go similarity index 90% rename from common/binary/varint_unsafe_test.go rename to common/varbin/value_slices_unsafe_test.go index 7921df6..ed876a4 100644 --- a/common/binary/varint_unsafe_test.go +++ b/common/varbin/value_slices_unsafe_test.go @@ -1,7 +1,7 @@ -package binary +package varbin import ( - "math/rand/v2" + "math/rand" "reflect" "testing" @@ -11,7 +11,7 @@ import ( func TestSlicesValue(t *testing.T) { int64Arr := make([]int64, 64) for i := range int64Arr { - int64Arr[i] = rand.Int64() + int64Arr[i] = rand.Int63() } require.Equal(t, int64Arr, slicesValue[int64](reflect.ValueOf(int64Arr))) require.Equal(t, int64Arr, baseDataSlices(reflect.ValueOf(int64Arr))) @@ -22,7 +22,7 @@ func TestSetSliceValue(t *testing.T) { value := reflect.Indirect(reflect.ValueOf(&int64Arr)) newInt64Arr := make([]int64, 64) for i := range newInt64Arr { - newInt64Arr[i] = rand.Int64() + newInt64Arr[i] = rand.Int63() } setSliceValue[int64](value, newInt64Arr) require.Equal(t, newInt64Arr, slicesValue[int64](value))