diff --git a/common/json/badjson/merge.go b/common/json/badjson/merge.go index b6534d8..61b5f4c 100644 --- a/common/json/badjson/merge.go +++ b/common/json/badjson/merge.go @@ -1,6 +1,7 @@ package badjson import ( + "os" "reflect" "github.com/sagernet/sing/common" @@ -42,6 +43,9 @@ func Merge[T any](source T, destination T) (T, error) { } func MergeFromSource[T any](rawSource json.RawMessage, destination T) (T, error) { + if rawSource == nil { + return destination, nil + } rawDestination, err := json.Marshal(destination) if err != nil { return common.DefaultValue[T](), E.Cause(err, "marshal destination") @@ -50,6 +54,9 @@ func MergeFromSource[T any](rawSource json.RawMessage, destination T) (T, error) } func MergeFromDestination[T any](source T, rawDestination json.RawMessage) (T, error) { + if rawDestination == nil { + return source, nil + } rawSource, err := json.Marshal(source) if err != nil { return common.DefaultValue[T](), E.Cause(err, "marshal source") @@ -71,6 +78,13 @@ func MergeFrom[T any](rawSource json.RawMessage, rawDestination json.RawMessage) } func MergeJSON(rawSource json.RawMessage, rawDestination json.RawMessage) (json.RawMessage, error) { + if rawSource == nil && rawDestination == nil { + return nil, os.ErrInvalid + } else if rawSource == nil { + return rawDestination, nil + } else if rawDestination == nil { + return rawSource, nil + } source, err := Decode(rawSource) if err != nil { return nil, E.Cause(err, "decode source")