mirror of
https://github.com/navidrome/navidrome.git
synced 2025-04-03 20:47:35 +03:00
172 lines
4.2 KiB
Go
172 lines
4.2 KiB
Go
package persistence
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"reflect"
|
|
"strings"
|
|
"sync"
|
|
|
|
. "github.com/Masterminds/squirrel"
|
|
"github.com/deluan/rest"
|
|
"github.com/fatih/structs"
|
|
"github.com/navidrome/navidrome/log"
|
|
"github.com/navidrome/navidrome/model"
|
|
)
|
|
|
|
// filterFunc builds a Sqlizer condition from a filter's field name and value.
// It matches the call shape parseRestFilters uses for entries found in
// filterMappings: the lowercased filter name and the raw filter value.
type filterFunc = func(field string, value any) Sqlizer
|
|
|
|
func (r *sqlRepository) parseRestFilters(ctx context.Context, options rest.QueryOptions) Sqlizer {
|
|
if len(options.Filters) == 0 {
|
|
return nil
|
|
}
|
|
filters := And{}
|
|
for f, v := range options.Filters {
|
|
// Ignore filters with empty values
|
|
if v == "" {
|
|
continue
|
|
}
|
|
// Look for a custom filter function
|
|
f = strings.ToLower(f)
|
|
if ff, ok := r.filterMappings[f]; ok {
|
|
filters = append(filters, ff(f, v))
|
|
continue
|
|
}
|
|
// Ignore invalid filters (not based on a field or filter function)
|
|
if r.isFieldWhiteListed != nil && !r.isFieldWhiteListed(f) {
|
|
log.Warn(ctx, "Ignoring filter not whitelisted", "filter", f)
|
|
continue
|
|
}
|
|
// For fields ending in "id", use an exact match
|
|
if strings.HasSuffix(f, "id") {
|
|
filters = append(filters, eqFilter(f, v))
|
|
continue
|
|
}
|
|
// Default to a "starts with" filter
|
|
filters = append(filters, startsWithFilter(f, v))
|
|
}
|
|
return filters
|
|
}
|
|
|
|
func (r *sqlRepository) parseRestOptions(ctx context.Context, options ...rest.QueryOptions) model.QueryOptions {
|
|
qo := model.QueryOptions{}
|
|
if len(options) > 0 {
|
|
qo.Sort, qo.Order = r.sanitizeSort(options[0].Sort, options[0].Order)
|
|
qo.Max = options[0].Max
|
|
qo.Offset = options[0].Offset
|
|
if seed, ok := options[0].Filters["seed"].(string); ok {
|
|
qo.Seed = seed
|
|
delete(options[0].Filters, "seed")
|
|
}
|
|
qo.Filters = r.parseRestFilters(ctx, options[0])
|
|
}
|
|
return qo
|
|
}
|
|
|
|
func (r sqlRepository) sanitizeSort(sort, order string) (string, string) {
|
|
if sort != "" {
|
|
sort = toSnakeCase(sort)
|
|
if mapped, ok := r.sortMappings[sort]; ok {
|
|
sort = mapped
|
|
} else {
|
|
if !r.isFieldWhiteListed(sort) {
|
|
log.Warn(r.ctx, "Ignoring sort not whitelisted", "sort", sort)
|
|
sort = ""
|
|
}
|
|
}
|
|
}
|
|
if order != "" {
|
|
order = strings.ToLower(order)
|
|
if order != "desc" {
|
|
order = "asc"
|
|
}
|
|
}
|
|
return sort, order
|
|
}
|
|
|
|
func eqFilter(field string, value any) Sqlizer {
|
|
return Eq{field: value}
|
|
}
|
|
|
|
func startsWithFilter(field string, value any) Sqlizer {
|
|
return Like{field: fmt.Sprintf("%s%%", value)}
|
|
}
|
|
|
|
func containsFilter(field string) func(string, any) Sqlizer {
|
|
return func(_ string, value any) Sqlizer {
|
|
return Like{field: fmt.Sprintf("%%%s%%", value)}
|
|
}
|
|
}
|
|
|
|
func booleanFilter(field string, value any) Sqlizer {
|
|
v := strings.ToLower(value.(string))
|
|
return Eq{field: strings.ToLower(v) == "true"}
|
|
}
|
|
|
|
func fullTextFilter(_ string, value any) Sqlizer {
|
|
return fullTextExpr(value.(string))
|
|
}
|
|
|
|
func substringFilter(field string, value any) Sqlizer {
|
|
parts := strings.Split(value.(string), " ")
|
|
filters := And{}
|
|
for _, part := range parts {
|
|
filters = append(filters, Like{field: "%" + part + "%"})
|
|
}
|
|
return filters
|
|
}
|
|
|
|
func idFilter(tableName string) func(string, any) Sqlizer {
|
|
return func(field string, value any) Sqlizer {
|
|
return Eq{tableName + ".id": value}
|
|
}
|
|
}
|
|
|
|
func invalidFilter(ctx context.Context) func(string, any) Sqlizer {
|
|
return func(field string, value any) Sqlizer {
|
|
log.Warn(ctx, "Invalid filter", "fieldName", field, "value", value)
|
|
return Eq{"1": "0"}
|
|
}
|
|
}
|
|
|
|
// whiteList holds, per registered name, the set of field names that are
// accepted for that name.
var whiteList = make(map[string]map[string]struct{})

// mutex guards access to whiteList.
var mutex sync.RWMutex
|
|
|
|
func registerModelWhiteList(instance any) fieldWhiteListedFunc {
|
|
name := reflect.TypeOf(instance).String()
|
|
registerFieldWhiteList(name, instance)
|
|
return getFieldWhiteListedFunc(name)
|
|
}
|
|
|
|
func registerFieldWhiteList(name string, instance any) {
|
|
mutex.Lock()
|
|
defer mutex.Unlock()
|
|
if whiteList[name] != nil {
|
|
return
|
|
}
|
|
m := structs.Map(instance)
|
|
whiteList[name] = map[string]struct{}{}
|
|
for k := range m {
|
|
whiteList[name][toSnakeCase(k)] = struct{}{}
|
|
}
|
|
ma := structs.Map(model.Annotations{})
|
|
for k := range ma {
|
|
whiteList[name][toSnakeCase(k)] = struct{}{}
|
|
}
|
|
}
|
|
|
|
// fieldWhiteListedFunc reports whether the given field name is whitelisted.
type fieldWhiteListedFunc func(field string) bool
|
|
|
|
func getFieldWhiteListedFunc(tableName string) fieldWhiteListedFunc {
|
|
return func(field string) bool {
|
|
mutex.RLock()
|
|
defer mutex.RUnlock()
|
|
if _, ok := whiteList[tableName]; !ok {
|
|
return false
|
|
}
|
|
_, ok := whiteList[tableName][field]
|
|
return ok
|
|
}
|
|
}
|