badjson: Add context marshaler/unmarshaler

This commit is contained in:
世界 2024-10-31 20:21:00 +08:00
parent a4eb7fa900
commit c80c8f907c
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
13 changed files with 285 additions and 60 deletions

View file

@ -2,13 +2,14 @@ package badjson
import (
"bytes"
"context"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/json"
)
func Decode(content []byte) (any, error) {
decoder := json.NewDecoder(bytes.NewReader(content))
func Decode(ctx context.Context, content []byte) (any, error) {
decoder := json.NewDecoderContext(ctx, bytes.NewReader(content))
return decodeJSON(decoder)
}

View file

@ -1,6 +1,7 @@
package badjson
import (
"context"
"os"
"reflect"
@ -9,75 +10,75 @@ import (
"github.com/sagernet/sing/common/json"
)
func Omitempty[T any](value T) (T, error) {
func Omitempty[T any](ctx context.Context, value T) (T, error) {
objectContent, err := json.Marshal(value)
if err != nil {
return common.DefaultValue[T](), E.Cause(err, "marshal object")
}
rawNewObject, err := Decode(objectContent)
rawNewObject, err := Decode(ctx, objectContent)
if err != nil {
return common.DefaultValue[T](), err
}
newObjectContent, err := json.Marshal(rawNewObject)
newObjectContent, err := json.MarshalContext(ctx, rawNewObject)
if err != nil {
return common.DefaultValue[T](), E.Cause(err, "marshal new object")
}
var newObject T
err = json.Unmarshal(newObjectContent, &newObject)
err = json.UnmarshalContext(ctx, newObjectContent, &newObject)
if err != nil {
return common.DefaultValue[T](), E.Cause(err, "unmarshal new object")
}
return newObject, nil
}
func Merge[T any](source T, destination T, disableAppend bool) (T, error) {
rawSource, err := json.Marshal(source)
func Merge[T any](ctx context.Context, source T, destination T, disableAppend bool) (T, error) {
rawSource, err := json.MarshalContext(ctx, source)
if err != nil {
return common.DefaultValue[T](), E.Cause(err, "marshal source")
}
rawDestination, err := json.Marshal(destination)
rawDestination, err := json.MarshalContext(ctx, destination)
if err != nil {
return common.DefaultValue[T](), E.Cause(err, "marshal destination")
}
return MergeFrom[T](rawSource, rawDestination, disableAppend)
return MergeFrom[T](ctx, rawSource, rawDestination, disableAppend)
}
func MergeFromSource[T any](rawSource json.RawMessage, destination T, disableAppend bool) (T, error) {
func MergeFromSource[T any](ctx context.Context, rawSource json.RawMessage, destination T, disableAppend bool) (T, error) {
if rawSource == nil {
return destination, nil
}
rawDestination, err := json.Marshal(destination)
rawDestination, err := json.MarshalContext(ctx, destination)
if err != nil {
return common.DefaultValue[T](), E.Cause(err, "marshal destination")
}
return MergeFrom[T](rawSource, rawDestination, disableAppend)
return MergeFrom[T](ctx, rawSource, rawDestination, disableAppend)
}
func MergeFromDestination[T any](source T, rawDestination json.RawMessage, disableAppend bool) (T, error) {
func MergeFromDestination[T any](ctx context.Context, source T, rawDestination json.RawMessage, disableAppend bool) (T, error) {
if rawDestination == nil {
return source, nil
}
rawSource, err := json.Marshal(source)
rawSource, err := json.MarshalContext(ctx, source)
if err != nil {
return common.DefaultValue[T](), E.Cause(err, "marshal source")
}
return MergeFrom[T](rawSource, rawDestination, disableAppend)
return MergeFrom[T](ctx, rawSource, rawDestination, disableAppend)
}
func MergeFrom[T any](rawSource json.RawMessage, rawDestination json.RawMessage, disableAppend bool) (T, error) {
rawMerged, err := MergeJSON(rawSource, rawDestination, disableAppend)
func MergeFrom[T any](ctx context.Context, rawSource json.RawMessage, rawDestination json.RawMessage, disableAppend bool) (T, error) {
rawMerged, err := MergeJSON(ctx, rawSource, rawDestination, disableAppend)
if err != nil {
return common.DefaultValue[T](), E.Cause(err, "merge options")
}
var merged T
err = json.Unmarshal(rawMerged, &merged)
err = json.UnmarshalContext(ctx, rawMerged, &merged)
if err != nil {
return common.DefaultValue[T](), E.Cause(err, "unmarshal merged options")
}
return merged, nil
}
func MergeJSON(rawSource json.RawMessage, rawDestination json.RawMessage, disableAppend bool) (json.RawMessage, error) {
func MergeJSON(ctx context.Context, rawSource json.RawMessage, rawDestination json.RawMessage, disableAppend bool) (json.RawMessage, error) {
if rawSource == nil && rawDestination == nil {
return nil, os.ErrInvalid
} else if rawSource == nil {
@ -85,16 +86,16 @@ func MergeJSON(rawSource json.RawMessage, rawDestination json.RawMessage, disabl
} else if rawDestination == nil {
return rawSource, nil
}
source, err := Decode(rawSource)
source, err := Decode(ctx, rawSource)
if err != nil {
return nil, E.Cause(err, "decode source")
}
destination, err := Decode(rawDestination)
destination, err := Decode(ctx, rawDestination)
if err != nil {
return nil, E.Cause(err, "decode destination")
}
if source == nil {
return json.Marshal(destination)
return json.MarshalContext(ctx, destination)
} else if destination == nil {
return json.Marshal(source)
}
@ -102,7 +103,7 @@ func MergeJSON(rawSource json.RawMessage, rawDestination json.RawMessage, disabl
if err != nil {
return nil, err
}
return json.Marshal(merged)
return json.MarshalContext(ctx, merged)
}
func mergeJSON(anySource any, anyDestination any, disableAppend bool) (any, error) {

View file

@ -1,32 +1,42 @@
package badjson
import (
"context"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/json"
)
func MarshallObjects(objects ...any) ([]byte, error) {
return MarshallObjectsContext(context.Background(), objects...)
}
func MarshallObjectsContext(ctx context.Context, objects ...any) ([]byte, error) {
if len(objects) == 1 {
return json.Marshal(objects[0])
}
var content JSONObject
for _, object := range objects {
objectMap, err := newJSONObject(object)
objectMap, err := newJSONObject(ctx, object)
if err != nil {
return nil, err
}
content.PutAll(objectMap)
}
return content.MarshalJSON()
return content.MarshalJSONContext(ctx)
}
func UnmarshallExcluded(inputContent []byte, parentObject any, object any) error {
parentContent, err := newJSONObject(parentObject)
return UnmarshallExcludedContext(context.Background(), inputContent, parentObject, object)
}
func UnmarshallExcludedContext(ctx context.Context, inputContent []byte, parentObject any, object any) error {
parentContent, err := newJSONObject(ctx, parentObject)
if err != nil {
return err
}
var content JSONObject
err = content.UnmarshalJSON(inputContent)
err = content.UnmarshalJSONContext(ctx, inputContent)
if err != nil {
return err
}
@ -39,20 +49,20 @@ func UnmarshallExcluded(inputContent []byte, parentObject any, object any) error
}
return E.New("unexpected key: ", content.Keys()[0])
}
inputContent, err = content.MarshalJSON()
inputContent, err = content.MarshalJSONContext(ctx)
if err != nil {
return err
}
return json.UnmarshalDisallowUnknownFields(inputContent, object)
return json.UnmarshalContextDisallowUnknownFields(ctx, inputContent, object)
}
func newJSONObject(object any) (*JSONObject, error) {
inputContent, err := json.Marshal(object)
func newJSONObject(ctx context.Context, object any) (*JSONObject, error) {
inputContent, err := json.MarshalContext(ctx, object)
if err != nil {
return nil, err
}
var content JSONObject
err = content.UnmarshalJSON(inputContent)
err = content.UnmarshalJSONContext(ctx, inputContent)
if err != nil {
return nil, err
}

View file

@ -2,6 +2,7 @@ package badjson
import (
"bytes"
"context"
"strings"
"github.com/sagernet/sing/common"
@ -28,6 +29,10 @@ func (m *JSONObject) IsEmpty() bool {
}
func (m *JSONObject) MarshalJSON() ([]byte, error) {
return m.MarshalJSONContext(context.Background())
}
func (m *JSONObject) MarshalJSONContext(ctx context.Context) ([]byte, error) {
buffer := new(bytes.Buffer)
buffer.WriteString("{")
items := common.Filter(m.Entries(), func(it collections.MapEntry[string, any]) bool {
@ -38,13 +43,13 @@ func (m *JSONObject) MarshalJSON() ([]byte, error) {
})
iLen := len(items)
for i, entry := range items {
keyContent, err := json.Marshal(entry.Key)
keyContent, err := json.MarshalContext(ctx, entry.Key)
if err != nil {
return nil, err
}
buffer.WriteString(strings.TrimSpace(string(keyContent)))
buffer.WriteString(": ")
valueContent, err := json.Marshal(entry.Value)
valueContent, err := json.MarshalContext(ctx, entry.Value)
if err != nil {
return nil, err
}
@ -58,7 +63,11 @@ func (m *JSONObject) MarshalJSON() ([]byte, error) {
}
func (m *JSONObject) UnmarshalJSON(content []byte) error {
decoder := json.NewDecoder(bytes.NewReader(content))
return m.UnmarshalJSONContext(context.Background(), content)
}
func (m *JSONObject) UnmarshalJSONContext(ctx context.Context, content []byte) error {
decoder := json.NewDecoderContext(ctx, bytes.NewReader(content))
m.Clear()
objectStart, err := decoder.Token()
if err != nil {

View file

@ -2,6 +2,7 @@ package badjson
import (
"bytes"
"context"
"strings"
E "github.com/sagernet/sing/common/exceptions"
@ -14,18 +15,22 @@ type TypedMap[K comparable, V any] struct {
}
func (m TypedMap[K, V]) MarshalJSON() ([]byte, error) {
return m.MarshalJSONContext(context.Background())
}
func (m TypedMap[K, V]) MarshalJSONContext(ctx context.Context) ([]byte, error) {
buffer := new(bytes.Buffer)
buffer.WriteString("{")
items := m.Entries()
iLen := len(items)
for i, entry := range items {
keyContent, err := json.Marshal(entry.Key)
keyContent, err := json.MarshalContext(ctx, entry.Key)
if err != nil {
return nil, err
}
buffer.WriteString(strings.TrimSpace(string(keyContent)))
buffer.WriteString(": ")
valueContent, err := json.Marshal(entry.Value)
valueContent, err := json.MarshalContext(ctx, entry.Value)
if err != nil {
return nil, err
}
@ -39,7 +44,11 @@ func (m TypedMap[K, V]) MarshalJSON() ([]byte, error) {
}
func (m *TypedMap[K, V]) UnmarshalJSON(content []byte) error {
decoder := json.NewDecoder(bytes.NewReader(content))
return m.UnmarshalJSONContext(context.Background(), content)
}
func (m *TypedMap[K, V]) UnmarshalJSONContext(ctx context.Context, content []byte) error {
decoder := json.NewDecoderContext(ctx, bytes.NewReader(content))
m.Clear()
objectStart, err := decoder.Token()
if err != nil {
@ -47,7 +56,7 @@ func (m *TypedMap[K, V]) UnmarshalJSON(content []byte) error {
} else if objectStart != json.Delim('{') {
return E.New("expected json object start, but starts with ", objectStart)
}
err = m.decodeJSON(decoder)
err = m.decodeJSON(ctx, decoder)
if err != nil {
return E.Cause(err, "decode json object content")
}
@ -60,18 +69,18 @@ func (m *TypedMap[K, V]) UnmarshalJSON(content []byte) error {
return nil
}
func (m *TypedMap[K, V]) decodeJSON(decoder *json.Decoder) error {
func (m *TypedMap[K, V]) decodeJSON(ctx context.Context, decoder *json.Decoder) error {
for decoder.More() {
keyToken, err := decoder.Token()
if err != nil {
return err
}
keyContent, err := json.Marshal(keyToken)
keyContent, err := json.MarshalContext(ctx, keyToken)
if err != nil {
return err
}
var entryKey K
err = json.Unmarshal(keyContent, &entryKey)
err = json.UnmarshalContext(ctx, keyContent, &entryKey)
if err != nil {
return err
}

View file

@ -0,0 +1,23 @@
package json
import (
"context"
"github.com/sagernet/sing/common/json/internal/contextjson"
)
var (
MarshalContext = json.MarshalContext
UnmarshalContext = json.UnmarshalContext
NewEncoderContext = json.NewEncoderContext
NewDecoderContext = json.NewDecoderContext
UnmarshalContextDisallowUnknownFields = json.UnmarshalContextDisallowUnknownFields
)
type ContextMarshaler interface {
MarshalJSONContext(ctx context.Context) ([]byte, error)
}
type ContextUnmarshaler interface {
UnmarshalJSONContext(ctx context.Context, content []byte) error
}

View file

@ -0,0 +1,11 @@
package json
import "context"
type ContextMarshaler interface {
MarshalJSONContext(ctx context.Context) ([]byte, error)
}
type ContextUnmarshaler interface {
UnmarshalJSONContext(ctx context.Context, content []byte) error
}

View file

@ -0,0 +1,43 @@
package json_test
import (
"context"
"testing"
"github.com/sagernet/sing/common/json/internal/contextjson"
"github.com/stretchr/testify/require"
)
type myStruct struct {
value string
}
func (m *myStruct) MarshalJSONContext(ctx context.Context) ([]byte, error) {
return json.Marshal(ctx.Value("key").(string))
}
func (m *myStruct) UnmarshalJSONContext(ctx context.Context, content []byte) error {
m.value = ctx.Value("key").(string)
return nil
}
//nolint:staticcheck
func TestMarshalContext(t *testing.T) {
t.Parallel()
ctx := context.WithValue(context.Background(), "key", "value")
var s myStruct
b, err := json.MarshalContext(ctx, &s)
require.NoError(t, err)
require.Equal(t, []byte(`"value"`), b)
}
//nolint:staticcheck
func TestUnmarshalContext(t *testing.T) {
t.Parallel()
ctx := context.WithValue(context.Background(), "key", "value")
var s myStruct
err := json.UnmarshalContext(ctx, []byte(`{}`), &s)
require.NoError(t, err)
require.Equal(t, "value", s.value)
}

View file

@ -8,6 +8,7 @@
package json
import (
"context"
"encoding"
"encoding/base64"
"fmt"
@ -95,10 +96,15 @@ import (
// Instead, they are replaced by the Unicode replacement
// character U+FFFD.
func Unmarshal(data []byte, v any) error {
return UnmarshalContext(context.Background(), data, v)
}
func UnmarshalContext(ctx context.Context, data []byte, v any) error {
// Check for well-formedness.
// Avoids filling out half a data structure
// before discovering a JSON syntax error.
var d decodeState
d.ctx = ctx
err := checkValid(data, &d.scan)
if err != nil {
return err
@ -209,6 +215,7 @@ type errorContext struct {
// decodeState represents the state while decoding a JSON value.
type decodeState struct {
ctx context.Context
data []byte
off int // next read offset in data
opcode int // last read result
@ -428,7 +435,7 @@ func (d *decodeState) valueQuoted() any {
// If it encounters an Unmarshaler, indirect stops and returns that.
// If decodingNull is true, indirect stops at the first settable pointer so it
// can be set to nil.
func indirect(v reflect.Value, decodingNull bool) (Unmarshaler, encoding.TextUnmarshaler, reflect.Value) {
func indirect(v reflect.Value, decodingNull bool) (Unmarshaler, ContextUnmarshaler, encoding.TextUnmarshaler, reflect.Value) {
// Issue #24153 indicates that it is generally not a guaranteed property
// that you may round-trip a reflect.Value by calling Value.Addr().Elem()
// and expect the value to still be settable for values derived from
@ -482,11 +489,14 @@ func indirect(v reflect.Value, decodingNull bool) (Unmarshaler, encoding.TextUnm
}
if v.Type().NumMethod() > 0 && v.CanInterface() {
if u, ok := v.Interface().(Unmarshaler); ok {
return u, nil, reflect.Value{}
return u, nil, nil, reflect.Value{}
}
if cu, ok := v.Interface().(ContextUnmarshaler); ok {
return nil, cu, nil, reflect.Value{}
}
if !decodingNull {
if u, ok := v.Interface().(encoding.TextUnmarshaler); ok {
return nil, u, reflect.Value{}
return nil, nil, u, reflect.Value{}
}
}
}
@ -498,14 +508,14 @@ func indirect(v reflect.Value, decodingNull bool) (Unmarshaler, encoding.TextUnm
v = v.Elem()
}
}
return nil, nil, v
return nil, nil, nil, v
}
// array consumes an array from d.data[d.off-1:], decoding into v.
// The first byte of the array ('[') has been read already.
func (d *decodeState) array(v reflect.Value) error {
// Check for unmarshaler.
u, ut, pv := indirect(v, false)
u, cu, ut, pv := indirect(v, false)
if u != nil {
start := d.readIndex()
d.skip()
@ -515,6 +525,15 @@ func (d *decodeState) array(v reflect.Value) error {
}
return nil
}
if cu != nil {
start := d.readIndex()
d.skip()
err := cu.UnmarshalJSONContext(d.ctx, d.data[start:d.off])
if err != nil {
d.saveError(err)
}
return nil
}
if ut != nil {
d.saveError(&UnmarshalTypeError{Value: "array", Type: v.Type(), Offset: int64(d.off)})
d.skip()
@ -612,7 +631,7 @@ var (
// The first byte ('{') of the object has been read already.
func (d *decodeState) object(v reflect.Value) error {
// Check for unmarshaler.
u, ut, pv := indirect(v, false)
u, cu, ut, pv := indirect(v, false)
if u != nil {
start := d.readIndex()
d.skip()
@ -622,6 +641,15 @@ func (d *decodeState) object(v reflect.Value) error {
}
return nil
}
if cu != nil {
start := d.readIndex()
d.skip()
err := cu.UnmarshalJSONContext(d.ctx, d.data[start:d.off])
if err != nil {
d.saveError(err)
}
return nil
}
if ut != nil {
d.saveError(&UnmarshalTypeError{Value: "object", Type: v.Type(), Offset: int64(d.off)})
d.skip()
@ -870,7 +898,7 @@ func (d *decodeState) literalStore(item []byte, v reflect.Value, fromQuoted bool
return nil
}
isNull := item[0] == 'n' // null
u, ut, pv := indirect(v, isNull)
u, cu, ut, pv := indirect(v, isNull)
if u != nil {
err := u.UnmarshalJSON(item)
if err != nil {
@ -878,6 +906,13 @@ func (d *decodeState) literalStore(item []byte, v reflect.Value, fromQuoted bool
}
return nil
}
if cu != nil {
err := cu.UnmarshalJSONContext(d.ctx, item)
if err != nil {
d.saveError(err)
}
return nil
}
if ut != nil {
if item[0] != '"' {
if fromQuoted {

View file

@ -12,6 +12,7 @@ package json
import (
"bytes"
"context"
"encoding"
"encoding/base64"
"fmt"
@ -156,7 +157,11 @@ import (
// handle them. Passing cyclic structures to Marshal will result in
// an error.
func Marshal(v any) ([]byte, error) {
e := newEncodeState()
return MarshalContext(context.Background(), v)
}
func MarshalContext(ctx context.Context, v any) ([]byte, error) {
e := newEncodeState(ctx)
defer encodeStatePool.Put(e)
err := e.marshal(v, encOpts{escapeHTML: true})
@ -251,6 +256,7 @@ var hex = "0123456789abcdef"
type encodeState struct {
bytes.Buffer // accumulated output
ctx context.Context
// Keep track of what pointers we've seen in the current recursive call
// path, to avoid cycles that could lead to a stack overflow. Only do
// the relatively expensive map operations if ptrLevel is larger than
@ -264,7 +270,7 @@ const startDetectingCyclesAfter = 1000
var encodeStatePool sync.Pool
func newEncodeState() *encodeState {
func newEncodeState(ctx context.Context) *encodeState {
if v := encodeStatePool.Get(); v != nil {
e := v.(*encodeState)
e.Reset()
@ -274,7 +280,7 @@ func newEncodeState() *encodeState {
e.ptrLevel = 0
return e
}
return &encodeState{ptrSeen: make(map[any]struct{})}
return &encodeState{ctx: ctx, ptrSeen: make(map[any]struct{})}
}
// jsonError is an error wrapper type for internal use only.
@ -371,8 +377,9 @@ func typeEncoder(t reflect.Type) encoderFunc {
}
var (
marshalerType = reflect.TypeOf((*Marshaler)(nil)).Elem()
textMarshalerType = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem()
marshalerType = reflect.TypeOf((*Marshaler)(nil)).Elem()
contextMarshalerType = reflect.TypeOf((*ContextMarshaler)(nil)).Elem()
textMarshalerType = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem()
)
// newTypeEncoder constructs an encoderFunc for a type.
@ -385,9 +392,15 @@ func newTypeEncoder(t reflect.Type, allowAddr bool) encoderFunc {
if t.Kind() != reflect.Pointer && allowAddr && reflect.PointerTo(t).Implements(marshalerType) {
return newCondAddrEncoder(addrMarshalerEncoder, newTypeEncoder(t, false))
}
if t.Kind() != reflect.Pointer && allowAddr && reflect.PointerTo(t).Implements(contextMarshalerType) {
return newCondAddrEncoder(addrContextMarshalerEncoder, newTypeEncoder(t, false))
}
if t.Implements(marshalerType) {
return marshalerEncoder
}
if t.Implements(contextMarshalerType) {
return contextMarshalerEncoder
}
if t.Kind() != reflect.Pointer && allowAddr && reflect.PointerTo(t).Implements(textMarshalerType) {
return newCondAddrEncoder(addrTextMarshalerEncoder, newTypeEncoder(t, false))
}
@ -470,6 +483,47 @@ func addrMarshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) {
}
}
func contextMarshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) {
if v.Kind() == reflect.Pointer && v.IsNil() {
e.WriteString("null")
return
}
m, ok := v.Interface().(ContextMarshaler)
if !ok {
e.WriteString("null")
return
}
b, err := m.MarshalJSONContext(e.ctx)
if err == nil {
e.Grow(len(b))
out := availableBuffer(&e.Buffer)
out, err = appendCompact(out, b, opts.escapeHTML)
e.Buffer.Write(out)
}
if err != nil {
e.error(&MarshalerError{v.Type(), err, "MarshalJSON"})
}
}
func addrContextMarshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) {
va := v.Addr()
if va.IsNil() {
e.WriteString("null")
return
}
m := va.Interface().(ContextMarshaler)
b, err := m.MarshalJSONContext(e.ctx)
if err == nil {
e.Grow(len(b))
out := availableBuffer(&e.Buffer)
out, err = appendCompact(out, b, opts.escapeHTML)
e.Buffer.Write(out)
}
if err != nil {
e.error(&MarshalerError{v.Type(), err, "MarshalJSON"})
}
}
func textMarshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) {
if v.Kind() == reflect.Pointer && v.IsNil() {
e.WriteString("null")
@ -827,7 +881,7 @@ func newSliceEncoder(t reflect.Type) encoderFunc {
// Byte slices get special treatment; arrays don't.
if t.Elem().Kind() == reflect.Uint8 {
p := reflect.PointerTo(t.Elem())
if !p.Implements(marshalerType) && !p.Implements(textMarshalerType) {
if !p.Implements(marshalerType) && !p.Implements(contextMarshalerType) && !p.Implements(textMarshalerType) {
return encodeByteSlice
}
}

View file

@ -6,6 +6,7 @@ package json
import (
"bytes"
"context"
"errors"
"io"
)
@ -29,7 +30,11 @@ type Decoder struct {
// The decoder introduces its own buffering and may
// read data from r beyond the JSON values requested.
func NewDecoder(r io.Reader) *Decoder {
return &Decoder{r: r}
return NewDecoderContext(context.Background(), r)
}
func NewDecoderContext(ctx context.Context, r io.Reader) *Decoder {
return &Decoder{r: r, d: decodeState{ctx: ctx}}
}
// UseNumber causes the Decoder to unmarshal a number into an interface{} as a
@ -183,6 +188,7 @@ func nonSpace(b []byte) bool {
// An Encoder writes JSON values to an output stream.
type Encoder struct {
ctx context.Context
w io.Writer
err error
escapeHTML bool
@ -194,7 +200,11 @@ type Encoder struct {
// NewEncoder returns a new encoder that writes to w.
func NewEncoder(w io.Writer) *Encoder {
return &Encoder{w: w, escapeHTML: true}
return NewEncoderContext(context.Background(), w)
}
func NewEncoderContext(ctx context.Context, w io.Writer) *Encoder {
return &Encoder{ctx: ctx, w: w, escapeHTML: true}
}
// Encode writes the JSON encoding of v to the stream,
@ -207,7 +217,7 @@ func (enc *Encoder) Encode(v any) error {
return enc.err
}
e := newEncodeState()
e := newEncodeState(enc.ctx)
defer encodeStatePool.Put(e)
err := e.marshal(v, encOpts{escapeHTML: enc.escapeHTML})

View file

@ -1,5 +1,7 @@
package json
import "context"
func UnmarshalDisallowUnknownFields(data []byte, v any) error {
var d decodeState
d.disallowUnknownFields = true
@ -10,3 +12,15 @@ func UnmarshalDisallowUnknownFields(data []byte, v any) error {
d.init(data)
return d.unmarshal(v)
}
func UnmarshalContextDisallowUnknownFields(ctx context.Context, data []byte, v any) error {
var d decodeState
d.ctx = ctx
d.disallowUnknownFields = true
err := checkValid(data, &d.scan)
if err != nil {
return err
}
d.init(data)
return d.unmarshal(v)
}

View file

@ -2,6 +2,7 @@ package json
import (
"bytes"
"context"
"errors"
"strings"
@ -10,7 +11,11 @@ import (
)
func UnmarshalExtended[T any](content []byte) (T, error) {
decoder := NewDecoder(NewCommentFilter(bytes.NewReader(content)))
return UnmarshalExtendedContext[T](context.Background(), content)
}
func UnmarshalExtendedContext[T any](ctx context.Context, content []byte) (T, error) {
decoder := NewDecoderContext(ctx, NewCommentFilter(bytes.NewReader(content)))
var value T
err := decoder.Decode(&value)
if err == nil {