mirror of
https://github.com/binwiederhier/ntfy.git
synced 2025-04-07 06:57:38 +03:00
Rename auth package to user; add extendToken feature
This commit is contained in:
parent
3aac1b2715
commit
d4c7ad4beb
14 changed files with 368 additions and 276 deletions
|
@ -9,6 +9,7 @@ import (
|
|||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"heckel.io/ntfy/user"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
|
@ -30,17 +31,17 @@ import (
|
|||
"github.com/emersion/go-smtp"
|
||||
"github.com/gorilla/websocket"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"heckel.io/ntfy/auth"
|
||||
"heckel.io/ntfy/util"
|
||||
)
|
||||
|
||||
/*
|
||||
TODO
|
||||
expire tokens
|
||||
auto-extend tokens from UI
|
||||
use token auth in "SubscribeDialog"
|
||||
upload files based on user limit
|
||||
database migration
|
||||
publishXHR + poll should pick current user, not from userManager
|
||||
expire tokens
|
||||
auto-refresh tokens from UI
|
||||
reserve topics
|
||||
purge accounts that were not logged into in X
|
||||
sync subscription display name
|
||||
|
@ -55,7 +56,11 @@ import (
|
|||
Polishing:
|
||||
aria-label for everything
|
||||
|
||||
|
||||
Tests:
|
||||
- APIs
|
||||
- CRUD tokens
|
||||
- Expire tokens
|
||||
-
|
||||
*/
|
||||
|
||||
// Server is the main server, providing the UI and API for ntfy
|
||||
|
@ -71,7 +76,7 @@ type Server struct {
|
|||
visitors map[string]*visitor // ip:<ip> or user:<user>
|
||||
firebaseClient *firebaseClient
|
||||
messages int64
|
||||
auth auth.Manager
|
||||
userManager user.Manager
|
||||
messageCache *messageCache
|
||||
fileCache *fileCache
|
||||
closeChan chan bool
|
||||
|
@ -159,9 +164,9 @@ func New(conf *Config) (*Server, error) {
|
|||
return nil, err
|
||||
}
|
||||
}
|
||||
var auther auth.Manager
|
||||
var auther user.Manager
|
||||
if conf.AuthFile != "" {
|
||||
auther, err = auth.NewSQLiteAuthManager(conf.AuthFile, conf.AuthDefaultRead, conf.AuthDefaultWrite)
|
||||
auther, err = user.NewSQLiteAuthManager(conf.AuthFile, conf.AuthDefaultRead, conf.AuthDefaultWrite)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -181,7 +186,7 @@ func New(conf *Config) (*Server, error) {
|
|||
firebaseClient: firebaseClient,
|
||||
smtpSender: mailer,
|
||||
topics: topics,
|
||||
auth: auther,
|
||||
userManager: auther,
|
||||
visitors: make(map[string]*visitor),
|
||||
}, nil
|
||||
}
|
||||
|
@ -342,11 +347,13 @@ func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request, v *visit
|
|||
return s.handleAccountDelete(w, r, v)
|
||||
} else if r.Method == http.MethodPost && r.URL.Path == accountPasswordPath {
|
||||
return s.handleAccountPasswordChange(w, r, v)
|
||||
} else if r.Method == http.MethodGet && r.URL.Path == accountTokenPath {
|
||||
return s.handleAccountTokenGet(w, r, v)
|
||||
} else if r.Method == http.MethodPost && r.URL.Path == accountTokenPath {
|
||||
return s.handleAccountTokenIssue(w, r, v)
|
||||
} else if r.Method == http.MethodPatch && r.URL.Path == accountTokenPath {
|
||||
return s.handleAccountTokenExtend(w, r, v)
|
||||
} else if r.Method == http.MethodDelete && r.URL.Path == accountTokenPath {
|
||||
return s.handleAccountTokenDelete(w, r, v)
|
||||
} else if r.Method == http.MethodPost && r.URL.Path == accountSettingsPath {
|
||||
} else if r.Method == http.MethodPatch && r.URL.Path == accountSettingsPath {
|
||||
return s.handleAccountSettingsChange(w, r, v)
|
||||
} else if r.Method == http.MethodPost && r.URL.Path == accountSubscriptionPath {
|
||||
return s.handleAccountSubscriptionAdd(w, r, v)
|
||||
|
@ -557,7 +564,7 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes
|
|||
}
|
||||
v.IncrMessages()
|
||||
if v.user != nil {
|
||||
s.auth.EnqueueUpdateStats(v.user)
|
||||
s.userManager.EnqueueStats(v.user)
|
||||
}
|
||||
s.mu.Lock()
|
||||
s.messages++
|
||||
|
@ -1122,7 +1129,7 @@ func parseSince(r *http.Request, poll bool) (sinceMarker, error) {
|
|||
}
|
||||
|
||||
func (s *Server) handleOptions(w http.ResponseWriter, _ *http.Request, _ *visitor) error {
|
||||
w.Header().Set("Access-Control-Allow-Methods", "GET, PUT, POST, DELETE")
|
||||
w.Header().Set("Access-Control-Allow-Methods", "GET, PUT, POST, PATCH, DELETE")
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
|
||||
w.Header().Set("Access-Control-Allow-Headers", "*") // CORS, allow auth via JS // FIXME is this terrible?
|
||||
return nil
|
||||
|
@ -1192,6 +1199,11 @@ func (s *Server) updateStatsAndPrune() {
|
|||
s.mu.Unlock()
|
||||
log.Debug("Manager: Deleted %d stale visitor(s)", staleVisitors)
|
||||
|
||||
// Delete expired user tokens
|
||||
if err := s.userManager.RemoveExpiredTokens(); err != nil {
|
||||
log.Warn("Error expiring user tokens: %s", err.Error())
|
||||
}
|
||||
|
||||
// Delete expired attachments
|
||||
if s.fileCache != nil && s.config.AttachmentExpiryDuration > 0 {
|
||||
olderThan := time.Now().Add(-1 * s.config.AttachmentExpiryDuration)
|
||||
|
@ -1323,7 +1335,7 @@ func (s *Server) sendDelayedMessages() error {
|
|||
for _, m := range messages {
|
||||
var v *visitor
|
||||
if m.User != "" {
|
||||
user, err := s.auth.User(m.User)
|
||||
user, err := s.userManager.User(m.User)
|
||||
if err != nil {
|
||||
log.Warn("%s Error sending delayed message: %s", logMessagePrefix(v, m), err.Error())
|
||||
continue
|
||||
|
@ -1457,16 +1469,16 @@ func (s *Server) transformMatrixJSON(next handleFunc) handleFunc {
|
|||
}
|
||||
|
||||
func (s *Server) authorizeTopicWrite(next handleFunc) handleFunc {
|
||||
return s.autorizeTopic(next, auth.PermissionWrite)
|
||||
return s.autorizeTopic(next, user.PermissionWrite)
|
||||
}
|
||||
|
||||
func (s *Server) authorizeTopicRead(next handleFunc) handleFunc {
|
||||
return s.autorizeTopic(next, auth.PermissionRead)
|
||||
return s.autorizeTopic(next, user.PermissionRead)
|
||||
}
|
||||
|
||||
func (s *Server) autorizeTopic(next handleFunc, perm auth.Permission) handleFunc {
|
||||
func (s *Server) autorizeTopic(next handleFunc, perm user.Permission) handleFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||
if s.auth == nil {
|
||||
if s.userManager == nil {
|
||||
return next(w, r, v)
|
||||
}
|
||||
topics, _, err := s.topicsFromPath(r.URL.Path)
|
||||
|
@ -1474,7 +1486,7 @@ func (s *Server) autorizeTopic(next handleFunc, perm auth.Permission) handleFunc
|
|||
return err
|
||||
}
|
||||
for _, t := range topics {
|
||||
if err := s.auth.Authorize(v.user, t.ID, perm); err != nil {
|
||||
if err := s.userManager.Authorize(v.user, t.ID, perm); err != nil {
|
||||
log.Info("unauthorized: %s", err.Error())
|
||||
return errHTTPForbidden
|
||||
}
|
||||
|
@ -1487,7 +1499,7 @@ func (s *Server) autorizeTopic(next handleFunc, perm auth.Permission) handleFunc
|
|||
// Note that this function will always return a visitor, even if an error occurs.
|
||||
func (s *Server) visitor(r *http.Request) (v *visitor, err error) {
|
||||
ip := extractIPAddress(r, s.config.BehindProxy)
|
||||
var user *auth.User // may stay nil if no auth header!
|
||||
var user *user.User // may stay nil if no auth header!
|
||||
if user, err = s.authenticate(r); err != nil {
|
||||
log.Debug("authentication failed: %s", err.Error())
|
||||
err = errHTTPUnauthorized // Always return visitor, even when error occurs!
|
||||
|
@ -1505,7 +1517,7 @@ func (s *Server) visitor(r *http.Request) (v *visitor, err error) {
|
|||
// The Authorization header can be passed as a header or the ?auth=... query param. The latter is required only to
|
||||
// support the WebSocket JavaScript class, which does not support passing headers during the initial request. The auth
|
||||
// query param is effectively double base64 encoded. Its format is base64(Basic base64(user:pass)).
|
||||
func (s *Server) authenticate(r *http.Request) (user *auth.User, err error) {
|
||||
func (s *Server) authenticate(r *http.Request) (user *user.User, err error) {
|
||||
value := r.Header.Get("Authorization")
|
||||
queryParam := readQueryParam(r, "authorization", "auth")
|
||||
if queryParam != "" {
|
||||
|
@ -1524,21 +1536,21 @@ func (s *Server) authenticate(r *http.Request) (user *auth.User, err error) {
|
|||
return s.authenticateBasicAuth(r, value)
|
||||
}
|
||||
|
||||
func (s *Server) authenticateBasicAuth(r *http.Request, value string) (user *auth.User, err error) {
|
||||
func (s *Server) authenticateBasicAuth(r *http.Request, value string) (user *user.User, err error) {
|
||||
r.Header.Set("Authorization", value)
|
||||
username, password, ok := r.BasicAuth()
|
||||
if !ok {
|
||||
return nil, errors.New("invalid basic auth")
|
||||
}
|
||||
return s.auth.Authenticate(username, password)
|
||||
return s.userManager.Authenticate(username, password)
|
||||
}
|
||||
|
||||
func (s *Server) authenticateBearerAuth(value string) (user *auth.User, err error) {
|
||||
func (s *Server) authenticateBearerAuth(value string) (user *user.User, err error) {
|
||||
token := strings.TrimSpace(strings.TrimPrefix(value, "Bearer"))
|
||||
return s.auth.AuthenticateToken(token)
|
||||
return s.userManager.AuthenticateToken(token)
|
||||
}
|
||||
|
||||
func (s *Server) visitorFromID(visitorID string, ip netip.Addr, user *auth.User) *visitor {
|
||||
func (s *Server) visitorFromID(visitorID string, ip netip.Addr, user *user.User) *visitor {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
v, exists := s.visitors[visitorID]
|
||||
|
@ -1554,6 +1566,6 @@ func (s *Server) visitorFromIP(ip netip.Addr) *visitor {
|
|||
return s.visitorFromID(fmt.Sprintf("ip:%s", ip.String()), ip, nil)
|
||||
}
|
||||
|
||||
func (s *Server) visitorFromUser(user *auth.User, ip netip.Addr) *visitor {
|
||||
func (s *Server) visitorFromUser(user *user.User, ip netip.Addr) *visitor {
|
||||
return s.visitorFromID(fmt.Sprintf("user:%s", user.Name), ip, user)
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue