mirror of
https://github.com/SagerNet/sing.git
synced 2025-04-02 03:17:37 +03:00
badjson: Add context marshaler/unmarshaler
This commit is contained in:
parent
a4eb7fa900
commit
c80c8f907c
13 changed files with 285 additions and 60 deletions
|
@ -2,13 +2,14 @@ package badjson
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/json"
|
||||
)
|
||||
|
||||
func Decode(content []byte) (any, error) {
|
||||
decoder := json.NewDecoder(bytes.NewReader(content))
|
||||
func Decode(ctx context.Context, content []byte) (any, error) {
|
||||
decoder := json.NewDecoderContext(ctx, bytes.NewReader(content))
|
||||
return decodeJSON(decoder)
|
||||
}
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package badjson
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"reflect"
|
||||
|
||||
|
@ -9,75 +10,75 @@ import (
|
|||
"github.com/sagernet/sing/common/json"
|
||||
)
|
||||
|
||||
func Omitempty[T any](value T) (T, error) {
|
||||
func Omitempty[T any](ctx context.Context, value T) (T, error) {
|
||||
objectContent, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
return common.DefaultValue[T](), E.Cause(err, "marshal object")
|
||||
}
|
||||
rawNewObject, err := Decode(objectContent)
|
||||
rawNewObject, err := Decode(ctx, objectContent)
|
||||
if err != nil {
|
||||
return common.DefaultValue[T](), err
|
||||
}
|
||||
newObjectContent, err := json.Marshal(rawNewObject)
|
||||
newObjectContent, err := json.MarshalContext(ctx, rawNewObject)
|
||||
if err != nil {
|
||||
return common.DefaultValue[T](), E.Cause(err, "marshal new object")
|
||||
}
|
||||
var newObject T
|
||||
err = json.Unmarshal(newObjectContent, &newObject)
|
||||
err = json.UnmarshalContext(ctx, newObjectContent, &newObject)
|
||||
if err != nil {
|
||||
return common.DefaultValue[T](), E.Cause(err, "unmarshal new object")
|
||||
}
|
||||
return newObject, nil
|
||||
}
|
||||
|
||||
func Merge[T any](source T, destination T, disableAppend bool) (T, error) {
|
||||
rawSource, err := json.Marshal(source)
|
||||
func Merge[T any](ctx context.Context, source T, destination T, disableAppend bool) (T, error) {
|
||||
rawSource, err := json.MarshalContext(ctx, source)
|
||||
if err != nil {
|
||||
return common.DefaultValue[T](), E.Cause(err, "marshal source")
|
||||
}
|
||||
rawDestination, err := json.Marshal(destination)
|
||||
rawDestination, err := json.MarshalContext(ctx, destination)
|
||||
if err != nil {
|
||||
return common.DefaultValue[T](), E.Cause(err, "marshal destination")
|
||||
}
|
||||
return MergeFrom[T](rawSource, rawDestination, disableAppend)
|
||||
return MergeFrom[T](ctx, rawSource, rawDestination, disableAppend)
|
||||
}
|
||||
|
||||
func MergeFromSource[T any](rawSource json.RawMessage, destination T, disableAppend bool) (T, error) {
|
||||
func MergeFromSource[T any](ctx context.Context, rawSource json.RawMessage, destination T, disableAppend bool) (T, error) {
|
||||
if rawSource == nil {
|
||||
return destination, nil
|
||||
}
|
||||
rawDestination, err := json.Marshal(destination)
|
||||
rawDestination, err := json.MarshalContext(ctx, destination)
|
||||
if err != nil {
|
||||
return common.DefaultValue[T](), E.Cause(err, "marshal destination")
|
||||
}
|
||||
return MergeFrom[T](rawSource, rawDestination, disableAppend)
|
||||
return MergeFrom[T](ctx, rawSource, rawDestination, disableAppend)
|
||||
}
|
||||
|
||||
func MergeFromDestination[T any](source T, rawDestination json.RawMessage, disableAppend bool) (T, error) {
|
||||
func MergeFromDestination[T any](ctx context.Context, source T, rawDestination json.RawMessage, disableAppend bool) (T, error) {
|
||||
if rawDestination == nil {
|
||||
return source, nil
|
||||
}
|
||||
rawSource, err := json.Marshal(source)
|
||||
rawSource, err := json.MarshalContext(ctx, source)
|
||||
if err != nil {
|
||||
return common.DefaultValue[T](), E.Cause(err, "marshal source")
|
||||
}
|
||||
return MergeFrom[T](rawSource, rawDestination, disableAppend)
|
||||
return MergeFrom[T](ctx, rawSource, rawDestination, disableAppend)
|
||||
}
|
||||
|
||||
func MergeFrom[T any](rawSource json.RawMessage, rawDestination json.RawMessage, disableAppend bool) (T, error) {
|
||||
rawMerged, err := MergeJSON(rawSource, rawDestination, disableAppend)
|
||||
func MergeFrom[T any](ctx context.Context, rawSource json.RawMessage, rawDestination json.RawMessage, disableAppend bool) (T, error) {
|
||||
rawMerged, err := MergeJSON(ctx, rawSource, rawDestination, disableAppend)
|
||||
if err != nil {
|
||||
return common.DefaultValue[T](), E.Cause(err, "merge options")
|
||||
}
|
||||
var merged T
|
||||
err = json.Unmarshal(rawMerged, &merged)
|
||||
err = json.UnmarshalContext(ctx, rawMerged, &merged)
|
||||
if err != nil {
|
||||
return common.DefaultValue[T](), E.Cause(err, "unmarshal merged options")
|
||||
}
|
||||
return merged, nil
|
||||
}
|
||||
|
||||
func MergeJSON(rawSource json.RawMessage, rawDestination json.RawMessage, disableAppend bool) (json.RawMessage, error) {
|
||||
func MergeJSON(ctx context.Context, rawSource json.RawMessage, rawDestination json.RawMessage, disableAppend bool) (json.RawMessage, error) {
|
||||
if rawSource == nil && rawDestination == nil {
|
||||
return nil, os.ErrInvalid
|
||||
} else if rawSource == nil {
|
||||
|
@ -85,16 +86,16 @@ func MergeJSON(rawSource json.RawMessage, rawDestination json.RawMessage, disabl
|
|||
} else if rawDestination == nil {
|
||||
return rawSource, nil
|
||||
}
|
||||
source, err := Decode(rawSource)
|
||||
source, err := Decode(ctx, rawSource)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "decode source")
|
||||
}
|
||||
destination, err := Decode(rawDestination)
|
||||
destination, err := Decode(ctx, rawDestination)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "decode destination")
|
||||
}
|
||||
if source == nil {
|
||||
return json.Marshal(destination)
|
||||
return json.MarshalContext(ctx, destination)
|
||||
} else if destination == nil {
|
||||
return json.Marshal(source)
|
||||
}
|
||||
|
@ -102,7 +103,7 @@ func MergeJSON(rawSource json.RawMessage, rawDestination json.RawMessage, disabl
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return json.Marshal(merged)
|
||||
return json.MarshalContext(ctx, merged)
|
||||
}
|
||||
|
||||
func mergeJSON(anySource any, anyDestination any, disableAppend bool) (any, error) {
|
||||
|
|
|
@ -1,32 +1,42 @@
|
|||
package badjson
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/json"
|
||||
)
|
||||
|
||||
func MarshallObjects(objects ...any) ([]byte, error) {
|
||||
return MarshallObjectsContext(context.Background(), objects...)
|
||||
}
|
||||
|
||||
func MarshallObjectsContext(ctx context.Context, objects ...any) ([]byte, error) {
|
||||
if len(objects) == 1 {
|
||||
return json.Marshal(objects[0])
|
||||
}
|
||||
var content JSONObject
|
||||
for _, object := range objects {
|
||||
objectMap, err := newJSONObject(object)
|
||||
objectMap, err := newJSONObject(ctx, object)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
content.PutAll(objectMap)
|
||||
}
|
||||
return content.MarshalJSON()
|
||||
return content.MarshalJSONContext(ctx)
|
||||
}
|
||||
|
||||
func UnmarshallExcluded(inputContent []byte, parentObject any, object any) error {
|
||||
parentContent, err := newJSONObject(parentObject)
|
||||
return UnmarshallExcludedContext(context.Background(), inputContent, parentObject, object)
|
||||
}
|
||||
|
||||
func UnmarshallExcludedContext(ctx context.Context, inputContent []byte, parentObject any, object any) error {
|
||||
parentContent, err := newJSONObject(ctx, parentObject)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var content JSONObject
|
||||
err = content.UnmarshalJSON(inputContent)
|
||||
err = content.UnmarshalJSONContext(ctx, inputContent)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -39,20 +49,20 @@ func UnmarshallExcluded(inputContent []byte, parentObject any, object any) error
|
|||
}
|
||||
return E.New("unexpected key: ", content.Keys()[0])
|
||||
}
|
||||
inputContent, err = content.MarshalJSON()
|
||||
inputContent, err = content.MarshalJSONContext(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return json.UnmarshalDisallowUnknownFields(inputContent, object)
|
||||
return json.UnmarshalContextDisallowUnknownFields(ctx, inputContent, object)
|
||||
}
|
||||
|
||||
func newJSONObject(object any) (*JSONObject, error) {
|
||||
inputContent, err := json.Marshal(object)
|
||||
func newJSONObject(ctx context.Context, object any) (*JSONObject, error) {
|
||||
inputContent, err := json.MarshalContext(ctx, object)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var content JSONObject
|
||||
err = content.UnmarshalJSON(inputContent)
|
||||
err = content.UnmarshalJSONContext(ctx, inputContent)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -2,6 +2,7 @@ package badjson
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
|
@ -28,6 +29,10 @@ func (m *JSONObject) IsEmpty() bool {
|
|||
}
|
||||
|
||||
func (m *JSONObject) MarshalJSON() ([]byte, error) {
|
||||
return m.MarshalJSONContext(context.Background())
|
||||
}
|
||||
|
||||
func (m *JSONObject) MarshalJSONContext(ctx context.Context) ([]byte, error) {
|
||||
buffer := new(bytes.Buffer)
|
||||
buffer.WriteString("{")
|
||||
items := common.Filter(m.Entries(), func(it collections.MapEntry[string, any]) bool {
|
||||
|
@ -38,13 +43,13 @@ func (m *JSONObject) MarshalJSON() ([]byte, error) {
|
|||
})
|
||||
iLen := len(items)
|
||||
for i, entry := range items {
|
||||
keyContent, err := json.Marshal(entry.Key)
|
||||
keyContent, err := json.MarshalContext(ctx, entry.Key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
buffer.WriteString(strings.TrimSpace(string(keyContent)))
|
||||
buffer.WriteString(": ")
|
||||
valueContent, err := json.Marshal(entry.Value)
|
||||
valueContent, err := json.MarshalContext(ctx, entry.Value)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -58,7 +63,11 @@ func (m *JSONObject) MarshalJSON() ([]byte, error) {
|
|||
}
|
||||
|
||||
func (m *JSONObject) UnmarshalJSON(content []byte) error {
|
||||
decoder := json.NewDecoder(bytes.NewReader(content))
|
||||
return m.UnmarshalJSONContext(context.Background(), content)
|
||||
}
|
||||
|
||||
func (m *JSONObject) UnmarshalJSONContext(ctx context.Context, content []byte) error {
|
||||
decoder := json.NewDecoderContext(ctx, bytes.NewReader(content))
|
||||
m.Clear()
|
||||
objectStart, err := decoder.Token()
|
||||
if err != nil {
|
||||
|
|
|
@ -2,6 +2,7 @@ package badjson
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
|
@ -14,18 +15,22 @@ type TypedMap[K comparable, V any] struct {
|
|||
}
|
||||
|
||||
func (m TypedMap[K, V]) MarshalJSON() ([]byte, error) {
|
||||
return m.MarshalJSONContext(context.Background())
|
||||
}
|
||||
|
||||
func (m TypedMap[K, V]) MarshalJSONContext(ctx context.Context) ([]byte, error) {
|
||||
buffer := new(bytes.Buffer)
|
||||
buffer.WriteString("{")
|
||||
items := m.Entries()
|
||||
iLen := len(items)
|
||||
for i, entry := range items {
|
||||
keyContent, err := json.Marshal(entry.Key)
|
||||
keyContent, err := json.MarshalContext(ctx, entry.Key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
buffer.WriteString(strings.TrimSpace(string(keyContent)))
|
||||
buffer.WriteString(": ")
|
||||
valueContent, err := json.Marshal(entry.Value)
|
||||
valueContent, err := json.MarshalContext(ctx, entry.Value)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -39,7 +44,11 @@ func (m TypedMap[K, V]) MarshalJSON() ([]byte, error) {
|
|||
}
|
||||
|
||||
func (m *TypedMap[K, V]) UnmarshalJSON(content []byte) error {
|
||||
decoder := json.NewDecoder(bytes.NewReader(content))
|
||||
return m.UnmarshalJSONContext(context.Background(), content)
|
||||
}
|
||||
|
||||
func (m *TypedMap[K, V]) UnmarshalJSONContext(ctx context.Context, content []byte) error {
|
||||
decoder := json.NewDecoderContext(ctx, bytes.NewReader(content))
|
||||
m.Clear()
|
||||
objectStart, err := decoder.Token()
|
||||
if err != nil {
|
||||
|
@ -47,7 +56,7 @@ func (m *TypedMap[K, V]) UnmarshalJSON(content []byte) error {
|
|||
} else if objectStart != json.Delim('{') {
|
||||
return E.New("expected json object start, but starts with ", objectStart)
|
||||
}
|
||||
err = m.decodeJSON(decoder)
|
||||
err = m.decodeJSON(ctx, decoder)
|
||||
if err != nil {
|
||||
return E.Cause(err, "decode json object content")
|
||||
}
|
||||
|
@ -60,18 +69,18 @@ func (m *TypedMap[K, V]) UnmarshalJSON(content []byte) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (m *TypedMap[K, V]) decodeJSON(decoder *json.Decoder) error {
|
||||
func (m *TypedMap[K, V]) decodeJSON(ctx context.Context, decoder *json.Decoder) error {
|
||||
for decoder.More() {
|
||||
keyToken, err := decoder.Token()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
keyContent, err := json.Marshal(keyToken)
|
||||
keyContent, err := json.MarshalContext(ctx, keyToken)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var entryKey K
|
||||
err = json.Unmarshal(keyContent, &entryKey)
|
||||
err = json.UnmarshalContext(ctx, keyContent, &entryKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
23
common/json/context_ext.go
Normal file
23
common/json/context_ext.go
Normal file
|
@ -0,0 +1,23 @@
|
|||
package json
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/sagernet/sing/common/json/internal/contextjson"
|
||||
)
|
||||
|
||||
var (
|
||||
MarshalContext = json.MarshalContext
|
||||
UnmarshalContext = json.UnmarshalContext
|
||||
NewEncoderContext = json.NewEncoderContext
|
||||
NewDecoderContext = json.NewDecoderContext
|
||||
UnmarshalContextDisallowUnknownFields = json.UnmarshalContextDisallowUnknownFields
|
||||
)
|
||||
|
||||
type ContextMarshaler interface {
|
||||
MarshalJSONContext(ctx context.Context) ([]byte, error)
|
||||
}
|
||||
|
||||
type ContextUnmarshaler interface {
|
||||
UnmarshalJSONContext(ctx context.Context, content []byte) error
|
||||
}
|
11
common/json/internal/contextjson/context.go
Normal file
11
common/json/internal/contextjson/context.go
Normal file
|
@ -0,0 +1,11 @@
|
|||
package json
|
||||
|
||||
import "context"
|
||||
|
||||
type ContextMarshaler interface {
|
||||
MarshalJSONContext(ctx context.Context) ([]byte, error)
|
||||
}
|
||||
|
||||
type ContextUnmarshaler interface {
|
||||
UnmarshalJSONContext(ctx context.Context, content []byte) error
|
||||
}
|
43
common/json/internal/contextjson/context_test.go
Normal file
43
common/json/internal/contextjson/context_test.go
Normal file
|
@ -0,0 +1,43 @@
|
|||
package json_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/sagernet/sing/common/json/internal/contextjson"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type myStruct struct {
|
||||
value string
|
||||
}
|
||||
|
||||
func (m *myStruct) MarshalJSONContext(ctx context.Context) ([]byte, error) {
|
||||
return json.Marshal(ctx.Value("key").(string))
|
||||
}
|
||||
|
||||
func (m *myStruct) UnmarshalJSONContext(ctx context.Context, content []byte) error {
|
||||
m.value = ctx.Value("key").(string)
|
||||
return nil
|
||||
}
|
||||
|
||||
//nolint:staticcheck
|
||||
func TestMarshalContext(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.WithValue(context.Background(), "key", "value")
|
||||
var s myStruct
|
||||
b, err := json.MarshalContext(ctx, &s)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []byte(`"value"`), b)
|
||||
}
|
||||
|
||||
//nolint:staticcheck
|
||||
func TestUnmarshalContext(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.WithValue(context.Background(), "key", "value")
|
||||
var s myStruct
|
||||
err := json.UnmarshalContext(ctx, []byte(`{}`), &s)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "value", s.value)
|
||||
}
|
|
@ -8,6 +8,7 @@
|
|||
package json
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
|
@ -95,10 +96,15 @@ import (
|
|||
// Instead, they are replaced by the Unicode replacement
|
||||
// character U+FFFD.
|
||||
func Unmarshal(data []byte, v any) error {
|
||||
return UnmarshalContext(context.Background(), data, v)
|
||||
}
|
||||
|
||||
func UnmarshalContext(ctx context.Context, data []byte, v any) error {
|
||||
// Check for well-formedness.
|
||||
// Avoids filling out half a data structure
|
||||
// before discovering a JSON syntax error.
|
||||
var d decodeState
|
||||
d.ctx = ctx
|
||||
err := checkValid(data, &d.scan)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -209,6 +215,7 @@ type errorContext struct {
|
|||
|
||||
// decodeState represents the state while decoding a JSON value.
|
||||
type decodeState struct {
|
||||
ctx context.Context
|
||||
data []byte
|
||||
off int // next read offset in data
|
||||
opcode int // last read result
|
||||
|
@ -428,7 +435,7 @@ func (d *decodeState) valueQuoted() any {
|
|||
// If it encounters an Unmarshaler, indirect stops and returns that.
|
||||
// If decodingNull is true, indirect stops at the first settable pointer so it
|
||||
// can be set to nil.
|
||||
func indirect(v reflect.Value, decodingNull bool) (Unmarshaler, encoding.TextUnmarshaler, reflect.Value) {
|
||||
func indirect(v reflect.Value, decodingNull bool) (Unmarshaler, ContextUnmarshaler, encoding.TextUnmarshaler, reflect.Value) {
|
||||
// Issue #24153 indicates that it is generally not a guaranteed property
|
||||
// that you may round-trip a reflect.Value by calling Value.Addr().Elem()
|
||||
// and expect the value to still be settable for values derived from
|
||||
|
@ -482,11 +489,14 @@ func indirect(v reflect.Value, decodingNull bool) (Unmarshaler, encoding.TextUnm
|
|||
}
|
||||
if v.Type().NumMethod() > 0 && v.CanInterface() {
|
||||
if u, ok := v.Interface().(Unmarshaler); ok {
|
||||
return u, nil, reflect.Value{}
|
||||
return u, nil, nil, reflect.Value{}
|
||||
}
|
||||
if cu, ok := v.Interface().(ContextUnmarshaler); ok {
|
||||
return nil, cu, nil, reflect.Value{}
|
||||
}
|
||||
if !decodingNull {
|
||||
if u, ok := v.Interface().(encoding.TextUnmarshaler); ok {
|
||||
return nil, u, reflect.Value{}
|
||||
return nil, nil, u, reflect.Value{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -498,14 +508,14 @@ func indirect(v reflect.Value, decodingNull bool) (Unmarshaler, encoding.TextUnm
|
|||
v = v.Elem()
|
||||
}
|
||||
}
|
||||
return nil, nil, v
|
||||
return nil, nil, nil, v
|
||||
}
|
||||
|
||||
// array consumes an array from d.data[d.off-1:], decoding into v.
|
||||
// The first byte of the array ('[') has been read already.
|
||||
func (d *decodeState) array(v reflect.Value) error {
|
||||
// Check for unmarshaler.
|
||||
u, ut, pv := indirect(v, false)
|
||||
u, cu, ut, pv := indirect(v, false)
|
||||
if u != nil {
|
||||
start := d.readIndex()
|
||||
d.skip()
|
||||
|
@ -515,6 +525,15 @@ func (d *decodeState) array(v reflect.Value) error {
|
|||
}
|
||||
return nil
|
||||
}
|
||||
if cu != nil {
|
||||
start := d.readIndex()
|
||||
d.skip()
|
||||
err := cu.UnmarshalJSONContext(d.ctx, d.data[start:d.off])
|
||||
if err != nil {
|
||||
d.saveError(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if ut != nil {
|
||||
d.saveError(&UnmarshalTypeError{Value: "array", Type: v.Type(), Offset: int64(d.off)})
|
||||
d.skip()
|
||||
|
@ -612,7 +631,7 @@ var (
|
|||
// The first byte ('{') of the object has been read already.
|
||||
func (d *decodeState) object(v reflect.Value) error {
|
||||
// Check for unmarshaler.
|
||||
u, ut, pv := indirect(v, false)
|
||||
u, cu, ut, pv := indirect(v, false)
|
||||
if u != nil {
|
||||
start := d.readIndex()
|
||||
d.skip()
|
||||
|
@ -622,6 +641,15 @@ func (d *decodeState) object(v reflect.Value) error {
|
|||
}
|
||||
return nil
|
||||
}
|
||||
if cu != nil {
|
||||
start := d.readIndex()
|
||||
d.skip()
|
||||
err := cu.UnmarshalJSONContext(d.ctx, d.data[start:d.off])
|
||||
if err != nil {
|
||||
d.saveError(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if ut != nil {
|
||||
d.saveError(&UnmarshalTypeError{Value: "object", Type: v.Type(), Offset: int64(d.off)})
|
||||
d.skip()
|
||||
|
@ -870,7 +898,7 @@ func (d *decodeState) literalStore(item []byte, v reflect.Value, fromQuoted bool
|
|||
return nil
|
||||
}
|
||||
isNull := item[0] == 'n' // null
|
||||
u, ut, pv := indirect(v, isNull)
|
||||
u, cu, ut, pv := indirect(v, isNull)
|
||||
if u != nil {
|
||||
err := u.UnmarshalJSON(item)
|
||||
if err != nil {
|
||||
|
@ -878,6 +906,13 @@ func (d *decodeState) literalStore(item []byte, v reflect.Value, fromQuoted bool
|
|||
}
|
||||
return nil
|
||||
}
|
||||
if cu != nil {
|
||||
err := cu.UnmarshalJSONContext(d.ctx, item)
|
||||
if err != nil {
|
||||
d.saveError(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if ut != nil {
|
||||
if item[0] != '"' {
|
||||
if fromQuoted {
|
||||
|
|
|
@ -12,6 +12,7 @@ package json
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
|
@ -156,7 +157,11 @@ import (
|
|||
// handle them. Passing cyclic structures to Marshal will result in
|
||||
// an error.
|
||||
func Marshal(v any) ([]byte, error) {
|
||||
e := newEncodeState()
|
||||
return MarshalContext(context.Background(), v)
|
||||
}
|
||||
|
||||
func MarshalContext(ctx context.Context, v any) ([]byte, error) {
|
||||
e := newEncodeState(ctx)
|
||||
defer encodeStatePool.Put(e)
|
||||
|
||||
err := e.marshal(v, encOpts{escapeHTML: true})
|
||||
|
@ -251,6 +256,7 @@ var hex = "0123456789abcdef"
|
|||
type encodeState struct {
|
||||
bytes.Buffer // accumulated output
|
||||
|
||||
ctx context.Context
|
||||
// Keep track of what pointers we've seen in the current recursive call
|
||||
// path, to avoid cycles that could lead to a stack overflow. Only do
|
||||
// the relatively expensive map operations if ptrLevel is larger than
|
||||
|
@ -264,7 +270,7 @@ const startDetectingCyclesAfter = 1000
|
|||
|
||||
var encodeStatePool sync.Pool
|
||||
|
||||
func newEncodeState() *encodeState {
|
||||
func newEncodeState(ctx context.Context) *encodeState {
|
||||
if v := encodeStatePool.Get(); v != nil {
|
||||
e := v.(*encodeState)
|
||||
e.Reset()
|
||||
|
@ -274,7 +280,7 @@ func newEncodeState() *encodeState {
|
|||
e.ptrLevel = 0
|
||||
return e
|
||||
}
|
||||
return &encodeState{ptrSeen: make(map[any]struct{})}
|
||||
return &encodeState{ctx: ctx, ptrSeen: make(map[any]struct{})}
|
||||
}
|
||||
|
||||
// jsonError is an error wrapper type for internal use only.
|
||||
|
@ -371,8 +377,9 @@ func typeEncoder(t reflect.Type) encoderFunc {
|
|||
}
|
||||
|
||||
var (
|
||||
marshalerType = reflect.TypeOf((*Marshaler)(nil)).Elem()
|
||||
textMarshalerType = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem()
|
||||
marshalerType = reflect.TypeOf((*Marshaler)(nil)).Elem()
|
||||
contextMarshalerType = reflect.TypeOf((*ContextMarshaler)(nil)).Elem()
|
||||
textMarshalerType = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem()
|
||||
)
|
||||
|
||||
// newTypeEncoder constructs an encoderFunc for a type.
|
||||
|
@ -385,9 +392,15 @@ func newTypeEncoder(t reflect.Type, allowAddr bool) encoderFunc {
|
|||
if t.Kind() != reflect.Pointer && allowAddr && reflect.PointerTo(t).Implements(marshalerType) {
|
||||
return newCondAddrEncoder(addrMarshalerEncoder, newTypeEncoder(t, false))
|
||||
}
|
||||
if t.Kind() != reflect.Pointer && allowAddr && reflect.PointerTo(t).Implements(contextMarshalerType) {
|
||||
return newCondAddrEncoder(addrContextMarshalerEncoder, newTypeEncoder(t, false))
|
||||
}
|
||||
if t.Implements(marshalerType) {
|
||||
return marshalerEncoder
|
||||
}
|
||||
if t.Implements(contextMarshalerType) {
|
||||
return contextMarshalerEncoder
|
||||
}
|
||||
if t.Kind() != reflect.Pointer && allowAddr && reflect.PointerTo(t).Implements(textMarshalerType) {
|
||||
return newCondAddrEncoder(addrTextMarshalerEncoder, newTypeEncoder(t, false))
|
||||
}
|
||||
|
@ -470,6 +483,47 @@ func addrMarshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) {
|
|||
}
|
||||
}
|
||||
|
||||
func contextMarshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) {
|
||||
if v.Kind() == reflect.Pointer && v.IsNil() {
|
||||
e.WriteString("null")
|
||||
return
|
||||
}
|
||||
m, ok := v.Interface().(ContextMarshaler)
|
||||
if !ok {
|
||||
e.WriteString("null")
|
||||
return
|
||||
}
|
||||
b, err := m.MarshalJSONContext(e.ctx)
|
||||
if err == nil {
|
||||
e.Grow(len(b))
|
||||
out := availableBuffer(&e.Buffer)
|
||||
out, err = appendCompact(out, b, opts.escapeHTML)
|
||||
e.Buffer.Write(out)
|
||||
}
|
||||
if err != nil {
|
||||
e.error(&MarshalerError{v.Type(), err, "MarshalJSON"})
|
||||
}
|
||||
}
|
||||
|
||||
func addrContextMarshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) {
|
||||
va := v.Addr()
|
||||
if va.IsNil() {
|
||||
e.WriteString("null")
|
||||
return
|
||||
}
|
||||
m := va.Interface().(ContextMarshaler)
|
||||
b, err := m.MarshalJSONContext(e.ctx)
|
||||
if err == nil {
|
||||
e.Grow(len(b))
|
||||
out := availableBuffer(&e.Buffer)
|
||||
out, err = appendCompact(out, b, opts.escapeHTML)
|
||||
e.Buffer.Write(out)
|
||||
}
|
||||
if err != nil {
|
||||
e.error(&MarshalerError{v.Type(), err, "MarshalJSON"})
|
||||
}
|
||||
}
|
||||
|
||||
func textMarshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) {
|
||||
if v.Kind() == reflect.Pointer && v.IsNil() {
|
||||
e.WriteString("null")
|
||||
|
@ -827,7 +881,7 @@ func newSliceEncoder(t reflect.Type) encoderFunc {
|
|||
// Byte slices get special treatment; arrays don't.
|
||||
if t.Elem().Kind() == reflect.Uint8 {
|
||||
p := reflect.PointerTo(t.Elem())
|
||||
if !p.Implements(marshalerType) && !p.Implements(textMarshalerType) {
|
||||
if !p.Implements(marshalerType) && !p.Implements(contextMarshalerType) && !p.Implements(textMarshalerType) {
|
||||
return encodeByteSlice
|
||||
}
|
||||
}
|
||||
|
|
|
@ -6,6 +6,7 @@ package json
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
)
|
||||
|
@ -29,7 +30,11 @@ type Decoder struct {
|
|||
// The decoder introduces its own buffering and may
|
||||
// read data from r beyond the JSON values requested.
|
||||
func NewDecoder(r io.Reader) *Decoder {
|
||||
return &Decoder{r: r}
|
||||
return NewDecoderContext(context.Background(), r)
|
||||
}
|
||||
|
||||
func NewDecoderContext(ctx context.Context, r io.Reader) *Decoder {
|
||||
return &Decoder{r: r, d: decodeState{ctx: ctx}}
|
||||
}
|
||||
|
||||
// UseNumber causes the Decoder to unmarshal a number into an interface{} as a
|
||||
|
@ -183,6 +188,7 @@ func nonSpace(b []byte) bool {
|
|||
|
||||
// An Encoder writes JSON values to an output stream.
|
||||
type Encoder struct {
|
||||
ctx context.Context
|
||||
w io.Writer
|
||||
err error
|
||||
escapeHTML bool
|
||||
|
@ -194,7 +200,11 @@ type Encoder struct {
|
|||
|
||||
// NewEncoder returns a new encoder that writes to w.
|
||||
func NewEncoder(w io.Writer) *Encoder {
|
||||
return &Encoder{w: w, escapeHTML: true}
|
||||
return NewEncoderContext(context.Background(), w)
|
||||
}
|
||||
|
||||
func NewEncoderContext(ctx context.Context, w io.Writer) *Encoder {
|
||||
return &Encoder{ctx: ctx, w: w, escapeHTML: true}
|
||||
}
|
||||
|
||||
// Encode writes the JSON encoding of v to the stream,
|
||||
|
@ -207,7 +217,7 @@ func (enc *Encoder) Encode(v any) error {
|
|||
return enc.err
|
||||
}
|
||||
|
||||
e := newEncodeState()
|
||||
e := newEncodeState(enc.ctx)
|
||||
defer encodeStatePool.Put(e)
|
||||
|
||||
err := e.marshal(v, encOpts{escapeHTML: enc.escapeHTML})
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
package json
|
||||
|
||||
import "context"
|
||||
|
||||
func UnmarshalDisallowUnknownFields(data []byte, v any) error {
|
||||
var d decodeState
|
||||
d.disallowUnknownFields = true
|
||||
|
@ -10,3 +12,15 @@ func UnmarshalDisallowUnknownFields(data []byte, v any) error {
|
|||
d.init(data)
|
||||
return d.unmarshal(v)
|
||||
}
|
||||
|
||||
func UnmarshalContextDisallowUnknownFields(ctx context.Context, data []byte, v any) error {
|
||||
var d decodeState
|
||||
d.ctx = ctx
|
||||
d.disallowUnknownFields = true
|
||||
err := checkValid(data, &d.scan)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
d.init(data)
|
||||
return d.unmarshal(v)
|
||||
}
|
||||
|
|
|
@ -2,6 +2,7 @@ package json
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
|
@ -10,7 +11,11 @@ import (
|
|||
)
|
||||
|
||||
func UnmarshalExtended[T any](content []byte) (T, error) {
|
||||
decoder := NewDecoder(NewCommentFilter(bytes.NewReader(content)))
|
||||
return UnmarshalExtendedContext[T](context.Background(), content)
|
||||
}
|
||||
|
||||
func UnmarshalExtendedContext[T any](ctx context.Context, content []byte) (T, error) {
|
||||
decoder := NewDecoderContext(ctx, NewCommentFilter(bytes.NewReader(content)))
|
||||
var value T
|
||||
err := decoder.Decode(&value)
|
||||
if err == nil {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue