diff --git a/core/auth/auth.go b/core/auth/auth.go index aa644ad8b..e03e8ed39 100644 --- a/core/auth/auth.go +++ b/core/auth/auth.go @@ -16,7 +16,7 @@ import ( var ( once sync.Once - JwtSecret []byte + Secret []byte TokenAuth *jwtauth.JWTAuth sessionTimeOut time.Duration ) @@ -27,8 +27,8 @@ func InitTokenAuth(ds model.DataStore) { if err != nil { log.Error("No JWT secret found in DB. Setting a temp one, but please report this error", err) } - JwtSecret = []byte(secret) - TokenAuth = jwtauth.New("HS256", JwtSecret, nil) + Secret = []byte(secret) + TokenAuth = jwtauth.New("HS256", Secret, nil) }) } @@ -57,19 +57,21 @@ func TouchToken(token *jwt.Token) (string, error) { claims := token.Claims.(jwt.MapClaims) claims["exp"] = expireIn - return token.SignedString(JwtSecret) + return token.SignedString(Secret) +} + +func keyFunc(token *jwt.Token) (interface{}, error) { + // Don't forget to validate the alg is what you expect: + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) + } + + // hmacSampleSecret is a []byte containing your secret, e.g. []byte("my_secret_key") + return Secret, nil } func Validate(tokenStr string) (jwt.MapClaims, error) { - token, err := jwt.Parse(tokenStr, func(token *jwt.Token) (interface{}, error) { - // Don't forget to validate the alg is what you expect: - if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { - return nil, fmt.Errorf("Unexpected signing method: %v", token.Header["alg"]) - } - - // hmacSampleSecret is a []byte containing your secret, e.g. []byte("my_secret_key") - return JwtSecret, nil - }) + token, err := jwt.Parse(tokenStr, keyFunc) if err != nil { return nil, err } diff --git a/core/auth/auth_test.go b/core/auth/auth_test.go index b84c98ab2..f4414579e 100644 --- a/core/auth/auth_test.go +++ b/core/auth/auth_test.go @@ -4,6 +4,11 @@ import ( "testing" "time" + "github.com/navidrome/navidrome/conf" + + "github.com/navidrome/navidrome/consts" + "github.com/navidrome/navidrome/model" + "github.com/dgrijalva/jwt-go" "github.com/navidrome/navidrome/core/auth" "github.com/navidrome/navidrome/log" @@ -17,13 +22,21 @@ func TestAuth(t *testing.T) { RunSpecs(t, "Auth Test Suite") } -const testJWTSecret = "not so secret" +const ( + testJWTSecret = "not so secret" + oneDay = 24 * time.Hour +) var _ = Describe("Auth", func() { - BeforeEach(func() { - auth.JwtSecret = []byte(testJWTSecret) + BeforeSuite(func() { + conf.Server.SessionTimeout = 2 * oneDay }) - Context("Validate", func() { + + BeforeEach(func() { + auth.Secret = []byte(testJWTSecret) + }) + + Describe("Validate", func() { It("returns error with an invalid JWT token", func() { _, err := auth.Validate("invalid.token") Expect(err).To(Not(BeNil())) @@ -34,7 +47,7 @@ var _ = Describe("Auth", func() { claims := token.Claims.(jwt.MapClaims) claims["iss"] = "issuer" claims["exp"] = time.Now().Add(1 * time.Minute).Unix() - tokenStr, _ := token.SignedString(auth.JwtSecret) + tokenStr, _ := token.SignedString(auth.Secret) decodedClaims, err := auth.Validate(tokenStr) Expect(err).To(BeNil()) @@ -46,10 +59,51 @@ var _ = Describe("Auth", func() { claims := token.Claims.(jwt.MapClaims) claims["iss"] = "issuer" claims["exp"] = time.Now().Add(-1 * time.Minute).Unix() - tokenStr, _ := token.SignedString(auth.JwtSecret) + tokenStr, _ := token.SignedString(auth.Secret) _, err := auth.Validate(tokenStr) Expect(err).To(MatchError("Token is expired")) }) }) + + Describe("CreateToken", func() { + It("creates a valid token", func() { + u := &model.User{ + ID: "123", + UserName: "johndoe", + IsAdmin: true, + } + tokenStr, err := auth.CreateToken(u) + Expect(err).To(BeNil()) + + claims, err := auth.Validate(tokenStr) + Expect(err).To(BeNil()) + + Expect(claims["iss"]).To(Equal(consts.JWTIssuer)) + Expect(claims["sub"]).To(Equal("johndoe")) + Expect(claims["uid"]).To(Equal("123")) + Expect(claims["adm"]).To(Equal(true)) + + exp := time.Unix(int64(claims["exp"].(float64)), 0) + Expect(exp).To(BeTemporally(">", time.Now())) + }) + }) + + Describe("TouchToken", func() { + It("updates the expiration time", func() { + yesterday := time.Now().Add(-oneDay) + token := jwt.New(jwt.SigningMethodHS256) + claims := token.Claims.(jwt.MapClaims) + claims["iss"] = "issuer" + claims["exp"] = yesterday.Unix() + + touched, err := auth.TouchToken(token) + Expect(err).To(BeNil()) + + decodedClaims, err := auth.Validate(touched) + Expect(err).To(BeNil()) + expiration := time.Unix(int64(decodedClaims["exp"].(float64)), 0) + Expect(expiration.Sub(yesterday)).To(BeNumerically(">=", oneDay)) + }) + }) })