diff --git a/common/json/badjson/typed.go b/common/json/badjson/typed.go index 0f83303..66f41a6 100644 --- a/common/json/badjson/typed.go +++ b/common/json/badjson/typed.go @@ -9,11 +9,11 @@ import ( "github.com/sagernet/sing/common/x/linkedhashmap" ) -type TypedMap[T any] struct { - linkedhashmap.Map[string, T] +type TypedMap[K comparable, V any] struct { + linkedhashmap.Map[K, V] } -func (m TypedMap[T]) MarshalJSON() ([]byte, error) { +func (m TypedMap[K, V]) MarshalJSON() ([]byte, error) { buffer := new(bytes.Buffer) buffer.WriteString("{") items := m.Entries() @@ -38,7 +38,7 @@ func (m TypedMap[T]) MarshalJSON() ([]byte, error) { return buffer.Bytes(), nil } -func (m *TypedMap[T]) UnmarshalJSON(content []byte) error { +func (m *TypedMap[K, V]) UnmarshalJSON(content []byte) error { decoder := json.NewDecoder(bytes.NewReader(content)) m.Clear() objectStart, err := decoder.Token() @@ -60,15 +60,22 @@ func (m *TypedMap[T]) UnmarshalJSON(content []byte) error { return nil } -func (m *TypedMap[T]) decodeJSON(decoder *json.Decoder) error { +func (m *TypedMap[K, V]) decodeJSON(decoder *json.Decoder) error { for decoder.More() { - var entryKey string keyToken, err := decoder.Token() if err != nil { return err } - entryKey = keyToken.(string) - var entryValue T + keyContent, err := json.Marshal(keyToken) + if err != nil { + return err + } + var entryKey K + err = json.Unmarshal(keyContent, &entryKey) + if err != nil { + return err + } + var entryValue V err = decoder.Decode(&entryValue) if err != nil { return err