diff --git a/cmd/root.go b/cmd/root.go index 8e3ba9c5c..f6ea22a0d 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -75,6 +75,9 @@ func startServer() (func() error, func(err error)) { a := CreateServer(conf.Server.MusicFolder) a.MountRouter("Subsonic API", consts.URLPathSubsonicAPI, CreateSubsonicAPIRouter()) a.MountRouter("Native API", consts.URLPathNativeAPI, CreateNativeAPIRouter()) + if conf.Server.DevEnableScrobble { + a.MountRouter("LastFM Auth", consts.URLPathNativeAPI+"/lastfm", CreateLastFMRouter()) + } return a.Run(fmt.Sprintf("%s:%d", conf.Server.Address, conf.Server.Port)) }, func(err error) { if err != nil { diff --git a/cmd/wire_gen.go b/cmd/wire_gen.go index a35e0d4a0..a46934ad7 100644 --- a/cmd/wire_gen.go +++ b/cmd/wire_gen.go @@ -8,6 +8,7 @@ package cmd import ( "github.com/google/wire" "github.com/navidrome/navidrome/core" + "github.com/navidrome/navidrome/core/agents/lastfm" "github.com/navidrome/navidrome/core/scrobbler" "github.com/navidrome/navidrome/core/transcoder" "github.com/navidrome/navidrome/persistence" @@ -53,6 +54,12 @@ func CreateSubsonicAPIRouter() *subsonic.Router { return router } +func CreateLastFMRouter() *lastfm.Router { + dataStore := persistence.New() + router := lastfm.NewRouter(dataStore) + return router +} + func createScanner() scanner.Scanner { dataStore := persistence.New() artworkCache := core.GetImageCache() @@ -75,7 +82,7 @@ func createScheduler() scheduler.Scheduler { // wire_injectors.go: -var allProviders = wire.NewSet(core.Set, subsonic.New, nativeapi.New, persistence.New, GetBroker) +var allProviders = wire.NewSet(core.Set, subsonic.New, nativeapi.New, persistence.New, lastfm.NewRouter, GetBroker) // Scanner must be a Singleton var ( diff --git a/cmd/wire_injectors.go b/cmd/wire_injectors.go index 352ed2683..95e9c913b 100644 --- a/cmd/wire_injectors.go +++ b/cmd/wire_injectors.go @@ -7,6 +7,7 @@ import ( "github.com/google/wire" "github.com/navidrome/navidrome/core" + "github.com/navidrome/navidrome/core/agents/lastfm" "github.com/navidrome/navidrome/persistence" "github.com/navidrome/navidrome/scanner" "github.com/navidrome/navidrome/scheduler" @@ -21,6 +22,7 @@ var allProviders = wire.NewSet( subsonic.New, nativeapi.New, persistence.New, + lastfm.NewRouter, GetBroker, ) @@ -44,6 +46,12 @@ func CreateSubsonicAPIRouter() *subsonic.Router { )) } +func CreateLastFMRouter() *lastfm.Router { + panic(wire.Build( + allProviders, + )) +} + // Scanner must be a Singleton var ( onceScanner sync.Once diff --git a/core/agents/lastfm/auth_router.go b/core/agents/lastfm/auth_router.go new file mode 100644 index 000000000..7fce1dd17 --- /dev/null +++ b/core/agents/lastfm/auth_router.go @@ -0,0 +1,195 @@ +package lastfm + +import ( + "context" + "errors" + "fmt" + "net/http" + "net/url" + "time" + + "github.com/navidrome/navidrome/model/request" + "github.com/navidrome/navidrome/server" + + "github.com/ReneKroon/ttlcache/v2" + + "github.com/deluan/rest" + + "github.com/go-chi/chi/v5" + "github.com/navidrome/navidrome/conf" + "github.com/navidrome/navidrome/log" + "github.com/navidrome/navidrome/model" +) + +const ( + authURL = "https://www.last.fm/api/auth/" + sessionKeyPropertyPrefix = "LastFMSessionKey_" +) + +var ( + ErrLinkPending = errors.New("linking pending") + ErrUnlinked = errors.New("account not linked") +) + +type Router struct { + http.Handler + ds model.DataStore + client *Client + sessionMan *sessionMan + apiKey string + secret string +} + +func NewRouter(ds model.DataStore) *Router { + r := &Router{ds: ds, apiKey: lastFMAPIKey, secret: lastFMAPISecret} + r.Handler = r.routes() + if conf.Server.LastFM.ApiKey != "" { + r.apiKey = conf.Server.LastFM.ApiKey + r.secret = conf.Server.LastFM.Secret + } + r.client = NewClient(r.apiKey, r.secret, "en", http.DefaultClient) + r.sessionMan = newSessionMan(ds, r.client) + return r +} + +func (s *Router) routes() http.Handler { + r := chi.NewRouter() + + r.Use(server.Authenticator(s.ds)) + r.Use(server.JWTRefresher) + + r.Get("/link", s.starLink) + r.Get("/link/status", s.getLinkStatus) + r.Delete("/link", s.unlink) + + return r +} + +func (s *Router) starLink(w http.ResponseWriter, r *http.Request) { + token, err := s.client.GetToken(r.Context()) + if err != nil { + log.Error(r.Context(), "Error obtaining token from LastFM", err) + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte(fmt.Sprintf("Error obtaining token from LastFM: %s", err))) + return + } + username, _ := request.UsernameFrom(r.Context()) + s.sessionMan.FetchSession(username, token) + params := url.Values{} + params.Add("api_key", s.apiKey) + params.Add("token", token) + http.Redirect(w, r, authURL+"?"+params.Encode(), http.StatusFound) +} + +func (s *Router) getLinkStatus(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + username, _ := request.UsernameFrom(ctx) + _, err := s.sessionMan.Session(ctx, username) + resp := map[string]string{"status": "linked"} + if err != nil { + switch err { + case ErrLinkPending: + resp["status"] = "pending" + case ErrUnlinked: + resp["status"] = "unlinked" + default: + resp["status"] = "unlinked" + resp["error"] = err.Error() + _ = rest.RespondWithJSON(w, http.StatusInternalServerError, resp) + return + } + } + _ = rest.RespondWithJSON(w, http.StatusOK, resp) +} + +func (s *Router) unlink(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + username, _ := request.UsernameFrom(ctx) + err := s.sessionMan.RemoveSession(ctx, username) + if err != nil { + _ = rest.RespondWithError(w, http.StatusInternalServerError, err.Error()) + } else { + _ = rest.RespondWithJSON(w, http.StatusOK, map[string]string{}) + } +} + +type sessionMan struct { + ds model.DataStore + client *Client + tokens *ttlcache.Cache +} + +func newSessionMan(ds model.DataStore, client *Client) *sessionMan { + s := &sessionMan{ + ds: ds, + client: client, + } + s.tokens = ttlcache.NewCache() + s.tokens.SetCacheSizeLimit(0) + _ = s.tokens.SetTTL(30 * time.Second) + s.tokens.SkipTTLExtensionOnHit(true) + go s.run() + return s +} + +func (s *sessionMan) FetchSession(username, token string) { + _ = s.ds.Property(context.Background()).Delete(sessionKeyPropertyPrefix + username) + _ = s.tokens.Set(username, token) +} + +func (s *sessionMan) Session(ctx context.Context, username string) (string, error) { + properties := s.ds.Property(context.Background()) + key, err := properties.Get(sessionKeyPropertyPrefix + username) + if key != "" { + return key, nil + } + if err != nil && err != model.ErrNotFound { + return "", err + } + _, err = s.tokens.Get(username) + if err == nil { + return "", ErrLinkPending + } + return "", ErrUnlinked +} + +func (s *sessionMan) RemoveSession(ctx context.Context, username string) error { + _ = s.tokens.Remove(username) + properties := s.ds.Property(context.Background()) + return properties.Delete(sessionKeyPropertyPrefix + username) +} + +func (s *sessionMan) run() { + t := time.NewTicker(2 * time.Second) + defer t.Stop() + for { + <-t.C + if s.tokens.Count() == 0 { + continue + } + s.fetchSessions() + } +} + +func (s *sessionMan) fetchSessions() { + ctx := context.Background() + for _, username := range s.tokens.GetKeys() { + token, err := s.tokens.Get(username) + if err != nil { + log.Error("Error retrieving token from cache", "username", username, err) + _ = s.tokens.Remove(username) + continue + } + sessionKey, err := s.client.GetSession(ctx, token.(string)) + log.Debug(ctx, "Fetching session", "username", username, "sessionKey", sessionKey, "token", token, err) + if err != nil { + continue + } + properties := s.ds.Property(ctx) + err = properties.Put(sessionKeyPropertyPrefix+username, sessionKey) + if err != nil { + log.Error("Could not save LastFM session key", "username", username, err) + } + _ = s.tokens.Remove(username) + } +} diff --git a/model/properties.go b/model/properties.go index 697a38789..0c3f100cb 100644 --- a/model/properties.go +++ b/model/properties.go @@ -12,5 +12,6 @@ type Property struct { type PropertyRepository interface { Put(id string, value string) error Get(id string) (string, error) + Delete(id string) error DefaultGet(id string, defaultValue string) (string, error) } diff --git a/persistence/property_repository.go b/persistence/property_repository.go index d2a6ddc09..b2c9542ce 100644 --- a/persistence/property_repository.go +++ b/persistence/property_repository.go @@ -3,7 +3,7 @@ package persistence import ( "context" - "github.com/Masterminds/squirrel" + . "github.com/Masterminds/squirrel" "github.com/astaxie/beego/orm" "github.com/navidrome/navidrome/model" ) @@ -21,7 +21,7 @@ func NewPropertyRepository(ctx context.Context, o orm.Ormer) model.PropertyRepos } func (r propertyRepository) Put(id string, value string) error { - update := squirrel.Update(r.tableName).Set("value", value).Where(squirrel.Eq{"id": id}) + update := Update(r.tableName).Set("value", value).Where(Eq{"id": id}) count, err := r.executeSQL(update) if err != nil { return nil @@ -29,13 +29,13 @@ func (r propertyRepository) Put(id string, value string) error { if count > 0 { return nil } - insert := squirrel.Insert(r.tableName).Columns("id", "value").Values(id, value) + insert := Insert(r.tableName).Columns("id", "value").Values(id, value) _, err = r.executeSQL(insert) return err } func (r propertyRepository) Get(id string) (string, error) { - sel := squirrel.Select("value").From(r.tableName).Where(squirrel.Eq{"id": id}) + sel := Select("value").From(r.tableName).Where(Eq{"id": id}) resp := struct { Value string }{} @@ -56,3 +56,7 @@ func (r propertyRepository) DefaultGet(id string, defaultValue string) (string, } return value, nil } + +func (r propertyRepository) Delete(id string) error { + return r.delete(Eq{"id": id}) +} diff --git a/server/server.go b/server/server.go index 1395738e0..a602d969a 100644 --- a/server/server.go +++ b/server/server.go @@ -15,7 +15,6 @@ import ( "github.com/navidrome/navidrome/log" "github.com/navidrome/navidrome/model" "github.com/navidrome/navidrome/ui" - "github.com/navidrome/navidrome/utils" ) type Server struct { @@ -85,15 +84,6 @@ func (s *Server) initRoutes() { r.Post("/createAdmin", createAdmin(s.ds)) }) - r.Get("/api/lastfm/link/status", func(w http.ResponseWriter, r *http.Request) { - rs := "false" - c := utils.ParamInt(r, "c", 0) - if (c == 4) { - rs = "true" - } - _, _ = w.Write([]byte(rs)) - }) - // Redirect root to UI URL r.Get("/*", func(w http.ResponseWriter, r *http.Request) { http.Redirect(w, r, s.appRoot+"/", http.StatusFound) diff --git a/tests/mock_property_repo.go b/tests/mock_property_repo.go index 0ca66a357..dfc9c7c96 100644 --- a/tests/mock_property_repo.go +++ b/tests/mock_property_repo.go @@ -3,7 +3,7 @@ package tests import "github.com/navidrome/navidrome/model" type MockedPropertyRepo struct { - model.UserRepository + model.PropertyRepository data map[string]string err error }