From a31dba8ad2ffb857ba15d9a96f62f627d57e95c7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sun, 23 Jun 2024 14:07:07 +0800 Subject: [PATCH] bianry: Improve varint read and write --- common/binary/variant_data.go | 305 ---------------------- common/binary/varint_data.go | 386 ++++++++++++++++++++++++++++ common/binary/varint_unsafe.go | 106 ++++++++ common/binary/varint_unsafe_test.go | 35 +++ common/binary/varint_write.go | 41 +++ 5 files changed, 568 insertions(+), 305 deletions(-) delete mode 100644 common/binary/variant_data.go create mode 100644 common/binary/varint_data.go create mode 100644 common/binary/varint_unsafe.go create mode 100644 common/binary/varint_unsafe_test.go create mode 100644 common/binary/varint_write.go diff --git a/common/binary/variant_data.go b/common/binary/variant_data.go deleted file mode 100644 index eb2b128..0000000 --- a/common/binary/variant_data.go +++ /dev/null @@ -1,305 +0,0 @@ -package binary - -import ( - "bufio" - "errors" - "io" - "reflect" - - E "github.com/sagernet/sing/common/exceptions" -) - -func ReadDataSlice(r *bufio.Reader, order ByteOrder, data ...any) error { - for index, item := range data { - err := ReadData(r, order, item) - if err != nil { - return E.Cause(err, "[", index, "]") - } - } - return nil -} - -func ReadData(r *bufio.Reader, order ByteOrder, data any) error { - switch dataPtr := data.(type) { - case *[]uint8: - bytesLen, err := ReadUvarint(r) - if err != nil { - return E.Cause(err, "bytes length") - } - newBytes := make([]uint8, bytesLen) - _, err = io.ReadFull(r, newBytes) - if err != nil { - return E.Cause(err, "bytes value") - } - *dataPtr = newBytes - default: - if intBaseDataSize(data) != 0 { - return Read(r, order, data) - } - } - dataValue := reflect.ValueOf(data) - if dataValue.Kind() == reflect.Pointer { - dataValue = dataValue.Elem() - } - return readData(r, order, dataValue) -} - -func readData(r *bufio.Reader, order ByteOrder, data reflect.Value) error { - switch data.Kind() { - case reflect.Pointer: - pointerValue, err := r.ReadByte() - if err != nil { - return err - } - if pointerValue == 0 { - data.SetZero() - return nil - } - if data.IsNil() { - data.Set(reflect.New(data.Type().Elem())) - } - return readData(r, order, data.Elem()) - case reflect.String: - stringLength, err := ReadUvarint(r) - if err != nil { - return E.Cause(err, "string length") - } - if stringLength == 0 { - data.SetZero() - } else { - stringData := make([]byte, stringLength) - _, err = io.ReadFull(r, stringData) - if err != nil { - return E.Cause(err, "string value") - } - data.SetString(string(stringData)) - } - case reflect.Array: - arrayLen := data.Len() - for i := 0; i < arrayLen; i++ { - err := readData(r, order, data.Index(i)) - if err != nil { - return E.Cause(err, "[", i, "]") - } - } - case reflect.Slice: - sliceLength, err := ReadUvarint(r) - if err != nil { - return E.Cause(err, "slice length") - } - if !data.IsNil() && data.Cap() >= int(sliceLength) { - data.SetLen(int(sliceLength)) - } else if sliceLength > 0 { - data.Set(reflect.MakeSlice(data.Type(), int(sliceLength), int(sliceLength))) - } - if sliceLength > 0 { - if data.Type().Elem().Kind() == reflect.Uint8 { - _, err = io.ReadFull(r, data.Bytes()) - if err != nil { - return E.Cause(err, "bytes value") - } - } else { - for index := 0; index < int(sliceLength); index++ { - err = readData(r, order, data.Index(index)) - if err != nil { - return E.Cause(err, "[", index, "]") - } - } - } - } - case reflect.Map: - mapLength, err := ReadUvarint(r) - if err != nil { - return E.Cause(err, "map length") - } - data.Set(reflect.MakeMap(data.Type())) - for index := 0; index < int(mapLength); index++ { - key := reflect.New(data.Type().Key()).Elem() - err = readData(r, order, key) - if err != nil { - return E.Cause(err, "[", index, "].key") - } - value := reflect.New(data.Type().Elem()).Elem() - err = readData(r, order, value) - if err != nil { - return E.Cause(err, "[", index, "].value") - } - data.SetMapIndex(key, value) - } - case reflect.Struct: - fieldType := data.Type() - fieldLen := data.NumField() - for i := 0; i < fieldLen; i++ { - field := data.Field(i) - fieldName := fieldType.Field(i).Name - if field.CanSet() || fieldName != "_" { - err := readData(r, order, field) - if err != nil { - return E.Cause(err, fieldName) - } - } - } - default: - size := dataSize(data) - if size < 0 { - return errors.New("invalid type " + reflect.TypeOf(data).String()) - } - d := &decoder{order: order, buf: make([]byte, size)} - _, err := io.ReadFull(r, d.buf) - if err != nil { - return err - } - d.value(data) - } - return nil -} - -func WriteDataSlice(writer *bufio.Writer, order ByteOrder, data ...any) error { - for index, item := range data { - err := WriteData(writer, order, item) - if err != nil { - return E.Cause(err, "[", index, "]") - } - } - return nil -} - -func WriteData(writer *bufio.Writer, order ByteOrder, data any) error { - switch dataPtr := data.(type) { - case []uint8: - _, err := writer.Write(AppendUvarint(writer.AvailableBuffer(), uint64(len(dataPtr)))) - if err != nil { - return E.Cause(err, "bytes length") - } - _, err = writer.Write(dataPtr) - if err != nil { - return E.Cause(err, "bytes value") - } - default: - if intBaseDataSize(data) != 0 { - return Write(writer, order, data) - } - } - return writeData(writer, order, reflect.Indirect(reflect.ValueOf(data))) -} - -func writeData(writer *bufio.Writer, order ByteOrder, data reflect.Value) error { - switch data.Kind() { - case reflect.Pointer: - if data.IsNil() { - err := writer.WriteByte(0) - if err != nil { - return err - } - } else { - err := writer.WriteByte(1) - if err != nil { - return err - } - return writeData(writer, order, data.Elem()) - } - case reflect.String: - stringValue := data.String() - _, err := writer.Write(AppendUvarint(writer.AvailableBuffer(), uint64(len(stringValue)))) - if err != nil { - return E.Cause(err, "string length") - } - if stringValue != "" { - _, err = writer.WriteString(stringValue) - if err != nil { - return E.Cause(err, "string value") - } - } - case reflect.Array: - dataLen := data.Len() - for i := 0; i < dataLen; i++ { - err := writeData(writer, order, data.Index(i)) - if err != nil { - return E.Cause(err, "[", i, "]") - } - } - case reflect.Slice: - dataLen := data.Len() - _, err := writer.Write(AppendUvarint(writer.AvailableBuffer(), uint64(dataLen))) - if err != nil { - return E.Cause(err, "slice length") - } - if dataLen > 0 { - if data.Type().Elem().Kind() == reflect.Uint8 { - _, err = writer.Write(data.Bytes()) - if err != nil { - return E.Cause(err, "bytes value") - } - } else { - for i := 0; i < dataLen; i++ { - err = writeData(writer, order, data.Index(i)) - if err != nil { - return E.Cause(err, "[", i, "]") - } - } - } - } - case reflect.Map: - dataLen := data.Len() - _, err := writer.Write(AppendUvarint(writer.AvailableBuffer(), uint64(dataLen))) - if err != nil { - return E.Cause(err, "map length") - } - if dataLen > 0 { - for index, key := range data.MapKeys() { - err = writeData(writer, order, key) - if err != nil { - return E.Cause(err, "[", index, "].key") - } - err = writeData(writer, order, data.MapIndex(key)) - if err != nil { - return E.Cause(err, "[", index, "].value") - } - } - } - case reflect.Struct: - fieldType := data.Type() - fieldLen := data.NumField() - for i := 0; i < fieldLen; i++ { - field := data.Field(i) - fieldName := fieldType.Field(i).Name - if field.CanSet() || fieldName != "_" { - err := writeData(writer, order, field) - if err != nil { - return E.Cause(err, fieldName) - } - } - } - default: - size := dataSize(data) - if size < 0 { - return errors.New("binary.Write: some values are not fixed-sized in type " + data.Type().String()) - } - buf := make([]byte, size) - e := &encoder{order: order, buf: buf} - e.value(data) - _, err := writer.Write(buf) - if err != nil { - return E.Cause(err, reflect.TypeOf(data).String()) - } - } - return nil -} - -func intBaseDataSize(data any) int { - switch data.(type) { - case bool, int8, uint8: - return 1 - case int16, uint16: - return 2 - case int32, uint32: - return 4 - case int64, uint64: - return 8 - case float32: - return 4 - case float64: - return 8 - } - return 0 -} diff --git a/common/binary/varint_data.go b/common/binary/varint_data.go new file mode 100644 index 0000000..d1c691f --- /dev/null +++ b/common/binary/varint_data.go @@ -0,0 +1,386 @@ +package binary + +import ( + "errors" + "io" + "reflect" + + E "github.com/sagernet/sing/common/exceptions" +) + +type Reader interface { + io.Reader + io.ByteReader +} + +type Writer interface { + io.Writer + io.ByteWriter +} + +func ReadData(r Reader, order ByteOrder, rawData any) error { + switch data := rawData.(type) { + case *[]bool: + return readBaseData(r, order, data) + case *[]int8: + return readBaseData(r, order, data) + case *[]uint8: + return readBaseData(r, order, data) + case *[]int16: + return readBaseData(r, order, data) + case *[]uint16: + return readBaseData(r, order, data) + case *[]int32: + return readBaseData(r, order, data) + case *[]uint32: + return readBaseData(r, order, data) + case *[]int64: + return readBaseData(r, order, data) + case *[]uint64: + return readBaseData(r, order, data) + case *[]float32: + return readBaseData(r, order, data) + case *[]float64: + return readBaseData(r, order, data) + default: + if intBaseDataSize(rawData) != 0 { + return Read(r, order, rawData) + } + } + return readData(r, order, reflect.Indirect(reflect.ValueOf(rawData))) +} + +func readBaseData[T any](r Reader, order ByteOrder, data *[]T) error { + dataLen, err := ReadUvarint(r) + if err != nil { + return E.Cause(err, "slice length") + } + if dataLen == 0 { + *data = nil + return nil + } + dataSlices := make([]T, dataLen) + err = Read(r, order, dataSlices) + if err != nil { + return err + } + *data = dataSlices + return nil +} + +func readData(r Reader, order ByteOrder, data reflect.Value) error { + switch data.Kind() { + case reflect.Pointer: + pointerValue, err := r.ReadByte() + if err != nil { + return err + } + if pointerValue == 0 { + data.SetZero() + return nil + } + if data.IsNil() { + data.Set(reflect.New(data.Type().Elem())) + } + return readData(r, order, data.Elem()) + case reflect.String: + stringLength, err := ReadUvarint(r) + if err != nil { + return E.Cause(err, "string length") + } + if stringLength == 0 { + data.SetZero() + } else { + stringData := make([]byte, stringLength) + _, err = io.ReadFull(r, stringData) + if err != nil { + return E.Cause(err, "string value") + } + data.SetString(string(stringData)) + } + case reflect.Array: + arrayLen := data.Len() + itemSize := sizeof(data.Type()) + if itemSize > 0 { + buf := make([]byte, itemSize*arrayLen) + _, err := io.ReadFull(r, buf) + if err != nil { + return err + } + d := &decoder{order: order, buf: buf} + d.value(data) + } else { + for i := 0; i < arrayLen; i++ { + err := readData(r, order, data.Index(i)) + if err != nil { + return E.Cause(err, "[", i, "]") + } + } + } + case reflect.Slice: + sliceLength, err := ReadUvarint(r) + if err != nil { + return E.Cause(err, "slice length") + } + if sliceLength == 0 { + data.SetZero() + } else { + dataSlices := makeBaseDataSlices(data, int(sliceLength)) + if dataSlices != nil { + err = Read(r, order, dataSlices) + if err != nil { + return err + } + setBaseDataSlices(data, dataSlices) + } else { + if !data.IsNil() && data.Cap() >= int(sliceLength) { + data.SetLen(int(sliceLength)) + } else if sliceLength > 0 { + data.Set(reflect.MakeSlice(data.Type(), int(sliceLength), int(sliceLength))) + } + for i := 0; i < int(sliceLength); i++ { + err = readData(r, order, data.Index(i)) + if err != nil { + return E.Cause(err, "[", i, "]") + } + } + } + } + case reflect.Map: + mapLength, err := ReadUvarint(r) + if err != nil { + return E.Cause(err, "map length") + } + data.Set(reflect.MakeMap(data.Type())) + for index := 0; index < int(mapLength); index++ { + key := reflect.New(data.Type().Key()).Elem() + err = readData(r, order, key) + if err != nil { + return E.Cause(err, "[", index, "].key") + } + value := reflect.New(data.Type().Elem()).Elem() + err = readData(r, order, value) + if err != nil { + return E.Cause(err, "[", index, "].value") + } + data.SetMapIndex(key, value) + } + case reflect.Struct: + fieldType := data.Type() + fieldLen := data.NumField() + for i := 0; i < fieldLen; i++ { + field := data.Field(i) + fieldName := fieldType.Field(i).Name + if field.CanSet() || fieldName != "_" { + err := readData(r, order, field) + if err != nil { + return E.Cause(err, fieldName) + } + } + } + default: + size := dataSize(data) + if size < 0 { + return errors.New("invalid type " + reflect.TypeOf(data).String()) + } + d := &decoder{order: order, buf: make([]byte, size)} + _, err := io.ReadFull(r, d.buf) + if err != nil { + return err + } + d.value(data) + } + return nil +} + +func WriteData(writer Writer, order ByteOrder, rawData any) error { + switch data := rawData.(type) { + case []bool: + return writeBaseData(writer, order, data) + case []int8: + return writeBaseData(writer, order, data) + case []uint8: + return writeBaseData(writer, order, data) + case []int16: + return writeBaseData(writer, order, data) + case []uint16: + return writeBaseData(writer, order, data) + case []int32: + return writeBaseData(writer, order, data) + case []uint32: + return writeBaseData(writer, order, data) + case []int64: + return writeBaseData(writer, order, data) + case []uint64: + return writeBaseData(writer, order, data) + case []float32: + return writeBaseData(writer, order, data) + case []float64: + return writeBaseData(writer, order, data) + default: + if intBaseDataSize(rawData) != 0 { + return Write(writer, order, rawData) + } + } + return writeData(writer, order, reflect.Indirect(reflect.ValueOf(rawData))) +} + +func writeBaseData[T any](writer Writer, order ByteOrder, data []T) error { + _, err := WriteUvarint(writer, uint64(len(data))) + if err != nil { + return err + } + if len(data) > 0 { + return Write(writer, order, data) + } + return nil +} + +func writeData(writer Writer, order ByteOrder, data reflect.Value) error { + switch data.Kind() { + case reflect.Pointer: + if data.IsNil() { + err := writer.WriteByte(0) + if err != nil { + return err + } + } else { + err := writer.WriteByte(1) + if err != nil { + return err + } + return writeData(writer, order, data.Elem()) + } + case reflect.String: + stringValue := data.String() + _, err := WriteUvarint(writer, uint64(len(stringValue))) + if err != nil { + return E.Cause(err, "string length") + } + if stringValue != "" { + _, err = writer.Write([]byte(stringValue)) + if err != nil { + return E.Cause(err, "string value") + } + } + case reflect.Array: + dataLen := data.Len() + if dataLen > 0 { + itemSize := intItemBaseDataSize(data) + if itemSize > 0 { + buf := make([]byte, itemSize*dataLen) + e := &encoder{order: order, buf: buf} + e.value(data) + _, err := writer.Write(buf) + if err != nil { + return E.Cause(err, reflect.TypeOf(data).String()) + } + } else { + for i := 0; i < dataLen; i++ { + err := writeData(writer, order, data.Index(i)) + if err != nil { + return E.Cause(err, "[", i, "]") + } + } + } + } + case reflect.Slice: + dataLen := data.Len() + _, err := WriteUvarint(writer, uint64(dataLen)) + if err != nil { + return E.Cause(err, "slice length") + } + if dataLen > 0 { + dataSlices := baseDataSlices(data) + if dataSlices != nil { + err = Write(writer, order, dataSlices) + if err != nil { + return err + } + } else { + for i := 0; i < dataLen; i++ { + err = writeData(writer, order, data.Index(i)) + if err != nil { + return E.Cause(err, "[", i, "]") + } + } + } + } + case reflect.Map: + dataLen := data.Len() + _, err := WriteUvarint(writer, uint64(dataLen)) + if err != nil { + return E.Cause(err, "map length") + } + if dataLen > 0 { + for index, key := range data.MapKeys() { + err = writeData(writer, order, key) + if err != nil { + return E.Cause(err, "[", index, "].key") + } + err = writeData(writer, order, data.MapIndex(key)) + if err != nil { + return E.Cause(err, "[", index, "].value") + } + } + } + case reflect.Struct: + fieldType := data.Type() + fieldLen := data.NumField() + for i := 0; i < fieldLen; i++ { + field := data.Field(i) + fieldName := fieldType.Field(i).Name + if field.CanSet() || fieldName != "_" { + err := writeData(writer, order, field) + if err != nil { + return E.Cause(err, fieldName) + } + } + } + default: + size := dataSize(data) + if size < 0 { + return errors.New("binary.Write: some values are not fixed-sized in type " + data.Type().String()) + } + buf := make([]byte, size) + e := &encoder{order: order, buf: buf} + e.value(data) + _, err := writer.Write(buf) + if err != nil { + return E.Cause(err, reflect.TypeOf(data).String()) + } + } + return nil +} + +func intItemBaseDataSize(data reflect.Value) int { + itemType := data.Type().Elem() + switch itemType.Kind() { + case reflect.Bool, + reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, + reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128: + return itemType.Len() + default: + return -1 + } +} + +func intBaseDataSize(data any) int { + switch data.(type) { + case bool, int8, uint8, + *bool, *int8, *uint8: + return 1 + case int16, uint16, *int16, *uint16: + return 2 + case int32, uint32, *int32, *uint32: + return 4 + case int64, uint64, *int64, *uint64: + return 8 + case float32, *float32: + return 4 + case float64, *float64: + return 8 + } + return 0 +} diff --git a/common/binary/varint_unsafe.go b/common/binary/varint_unsafe.go new file mode 100644 index 0000000..ac782e4 --- /dev/null +++ b/common/binary/varint_unsafe.go @@ -0,0 +1,106 @@ +package binary + +import ( + "reflect" + "unsafe" +) + +type myValue struct { + typ_ *any + ptr unsafe.Pointer +} + +func slicesValue[T any](value reflect.Value) []T { + v := (*myValue)(unsafe.Pointer(&value)) + return *(*[]T)(v.ptr) +} + +func setSliceValue[T any](value reflect.Value, x []T) { + v := (*myValue)(unsafe.Pointer(&value)) + *(*[]T)(v.ptr) = x +} + +func baseDataSlices(data reflect.Value) any { + switch data.Type().Elem().Kind() { + case reflect.Bool: + return slicesValue[bool](data) + case reflect.Int8: + return slicesValue[int8](data) + case reflect.Uint8: + return slicesValue[uint8](data) + case reflect.Int16: + return slicesValue[int16](data) + case reflect.Uint16: + return slicesValue[uint16](data) + case reflect.Int32: + return slicesValue[int32](data) + case reflect.Uint32: + return slicesValue[uint32](data) + case reflect.Int64: + return slicesValue[int64](data) + case reflect.Uint64: + return slicesValue[uint64](data) + case reflect.Float32: + return slicesValue[float32](data) + case reflect.Float64: + return slicesValue[float64](data) + default: + return nil + } +} + +func makeBaseDataSlices(data reflect.Value, dataLen int) any { + switch data.Type().Elem().Kind() { + case reflect.Bool: + return make([]bool, dataLen) + case reflect.Int8: + return make([]int8, dataLen) + case reflect.Uint8: + return make([]uint8, dataLen) + case reflect.Int16: + return make([]int16, dataLen) + case reflect.Uint16: + return make([]uint16, dataLen) + case reflect.Int32: + return make([]int32, dataLen) + case reflect.Uint32: + return make([]uint32, dataLen) + case reflect.Int64: + return make([]int64, dataLen) + case reflect.Uint64: + return make([]uint64, dataLen) + case reflect.Float32: + return make([]float32, dataLen) + case reflect.Float64: + return make([]float64, dataLen) + default: + return nil + } +} + +func setBaseDataSlices(data reflect.Value, rawDataSlices any) { + switch dataSlices := rawDataSlices.(type) { + case []bool: + setSliceValue(data, dataSlices) + case []int8: + setSliceValue(data, dataSlices) + case []uint8: + setSliceValue(data, dataSlices) + case []int16: + setSliceValue(data, dataSlices) + case []uint16: + setSliceValue(data, dataSlices) + case []int32: + setSliceValue(data, dataSlices) + case []uint32: + setSliceValue(data, dataSlices) + case []int64: + setSliceValue(data, dataSlices) + case []uint64: + setSliceValue(data, dataSlices) + case []float32: + setSliceValue(data, dataSlices) + case []float64: + setSliceValue(data, dataSlices) + } +} diff --git a/common/binary/varint_unsafe_test.go b/common/binary/varint_unsafe_test.go new file mode 100644 index 0000000..7921df6 --- /dev/null +++ b/common/binary/varint_unsafe_test.go @@ -0,0 +1,35 @@ +package binary + +import ( + "math/rand/v2" + "reflect" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestSlicesValue(t *testing.T) { + int64Arr := make([]int64, 64) + for i := range int64Arr { + int64Arr[i] = rand.Int64() + } + require.Equal(t, int64Arr, slicesValue[int64](reflect.ValueOf(int64Arr))) + require.Equal(t, int64Arr, baseDataSlices(reflect.ValueOf(int64Arr))) +} + +func TestSetSliceValue(t *testing.T) { + int64Arr := make([]int64, 64) + value := reflect.Indirect(reflect.ValueOf(&int64Arr)) + newInt64Arr := make([]int64, 64) + for i := range newInt64Arr { + newInt64Arr[i] = rand.Int64() + } + setSliceValue[int64](value, newInt64Arr) + require.Equal(t, newInt64Arr, slicesValue[int64](value)) + newInt64Arr2 := makeBaseDataSlices(value, 64) + copy(newInt64Arr2.([]int64), newInt64Arr) + require.Equal(t, newInt64Arr, newInt64Arr2) + value.SetZero() + setBaseDataSlices(value, newInt64Arr2) + require.Equal(t, newInt64Arr, slicesValue[int64](value)) +} diff --git a/common/binary/varint_write.go b/common/binary/varint_write.go new file mode 100644 index 0000000..fb32eed --- /dev/null +++ b/common/binary/varint_write.go @@ -0,0 +1,41 @@ +package binary + +import "io" + +func WriteUvarint(writer io.ByteWriter, value uint64) (int, error) { + var writeN int + for value >= 0x80 { + err := writer.WriteByte(byte(value) | 0x80) + if err != nil { + return writeN, err + } + value >>= 7 + writeN++ + } + err := writer.WriteByte(byte(value)) + if err != nil { + return writeN, err + } + return writeN + 1, nil +} + +func UvarintLen(x uint64) int { + switch { + case x < 1<<(7*1): + return 1 + case x < 1<<(7*2): + return 2 + case x < 1<<(7*3): + return 3 + case x < 1<<(7*4): + return 4 + case x < 1<<(7*5): + return 5 + case x < 1<<(7*6): + return 6 + case x < 1<<(7*7): + return 7 + default: + return 8 + } +}