Auth rate limiter

This commit is contained in:
binwiederhier 2023-02-08 15:20:44 -05:00
parent 3ac315a9e7
commit e1a4a74905
16 changed files with 152 additions and 60 deletions

View file

@ -34,9 +34,9 @@ import (
/*
- HIGH Rate limiting: Sensitive endpoints (account/login/change-password/...)
- HIGH Account limit creation triggers when account is taken!
- HIGH Docs
- tiers
- api
- HIGH Self-review
- MEDIUM: Test for expiring messages after reservation removal
- MEDIUM: Test new token endpoints & never-expiring token
@ -1540,18 +1540,6 @@ func (s *Server) sendDelayedMessage(v *visitor, m *message) error {
return nil
}
func (s *Server) limitRequests(next handleFunc) handleFunc {
return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
if util.ContainsIP(s.config.VisitorRequestExemptIPAddrs, v.ip) {
return next(w, r, v)
} else if err := v.RequestAllowed(); err != nil {
logvr(v, r).Err(err).Trace("Request not allowed by rate limiter")
return errHTTPTooManyRequestsLimitRequests
}
return next(w, r, v)
}
}
// transformBodyJSON peeks the request body, reads the JSON, and converts it to headers
// before passing it on to the next handler. This is meant to be used in combination with handlePublish.
func (s *Server) transformBodyJSON(next handleFunc) handleFunc {
@ -1648,43 +1636,65 @@ func (s *Server) autorizeTopic(next handleFunc, perm user.Permission) handleFunc
}
}
// maybeAuthenticate creates or retrieves a rate.Limiter for the given visitor.
// Note that this function will always return a visitor, even if an error occurs.
func (s *Server) maybeAuthenticate(r *http.Request) (v *visitor, err error) {
// maybeAuthenticate reads the "Authorization" header and will try to authenticate the user
// if it is set.
//
// - If the header is not set, an IP-based visitor is returned
// - If the header is set, authenticate will be called to check the username/password (Basic auth),
// or the token (Bearer auth), and read the user from the database
//
// This function will ALWAYS return a visitor, even if an error occurs (e.g. unauthorized), so
// that subsequent logging calls still have a visitor context.
func (s *Server) maybeAuthenticate(r *http.Request) (*visitor, error) {
// Read "Authorization" header value, and exit out early if it's not set
ip := extractIPAddress(r, s.config.BehindProxy)
var u *user.User // may stay nil if no auth header!
if u, err = s.authenticate(r); err != nil {
logr(r).Err(err).Debug("Authentication failed: %s", err.Error())
err = errHTTPUnauthorized // Always return visitor, even when error occurs!
vip := s.visitor(ip, nil)
header, err := readAuthHeader(r)
if err != nil {
return vip, err
} else if header == "" {
return vip, nil
} else if s.userManager == nil {
return vip, errHTTPUnauthorized
}
v = s.visitor(ip, u)
v.SetUser(u) // Update visitor user with latest from database!
return v, err // Always return visitor, even when error occurs!
// If we're trying to auth, check the rate limiter first
if !vip.AuthAllowed() {
return vip, errHTTPTooManyRequestsLimitAuthFailure // Always return visitor, even when error occurs!
}
u, err := s.authenticate(r, header)
if err != nil {
vip.AuthFailed()
logr(r).Err(err).Debug("Authentication failed")
return vip, errHTTPUnauthorized // Always return visitor, even when error occurs!
}
// Authentication with user was successful
return s.visitor(ip, u), nil
}
// authenticate a user based on basic auth username/password (Authorization: Basic ...), or token auth (Authorization: Bearer ...).
// The Authorization header can be passed as a header or the ?auth=... query param. The latter is required only to
// support the WebSocket JavaScript class, which does not support passing headers during the initial request. The auth
// query param is effectively double base64 encoded. Its format is base64(Basic base64(user:pass)).
func (s *Server) authenticate(r *http.Request) (user *user.User, err error) {
// query param is effectively doubly base64 encoded. Its format is base64(Basic base64(user:pass)).
func (s *Server) authenticate(r *http.Request, header string) (user *user.User, err error) {
if strings.HasPrefix(header, "Bearer") {
return s.authenticateBearerAuth(r, strings.TrimSpace(strings.TrimPrefix(header, "Bearer")))
}
return s.authenticateBasicAuth(r, header)
}
// readAuthHeader reads the raw value of the Authorization header, either from the actual HTTP header,
// or from the ?auth... query parameter
func readAuthHeader(r *http.Request) (string, error) {
value := strings.TrimSpace(r.Header.Get("Authorization"))
queryParam := readQueryParam(r, "authorization", "auth")
if queryParam != "" {
a, err := base64.RawURLEncoding.DecodeString(queryParam)
if err != nil {
return nil, err
return "", err
}
value = strings.TrimSpace(string(a))
}
if value == "" {
return nil, nil
} else if s.userManager == nil {
return nil, errHTTPUnauthorized
}
if strings.HasPrefix(value, "Bearer") {
return s.authenticateBearerAuth(r, strings.TrimSpace(strings.TrimPrefix(value, "Bearer")))
}
return s.authenticateBasicAuth(r, value)
return value, nil
}
func (s *Server) authenticateBasicAuth(r *http.Request, value string) (user *user.User, err error) {
@ -1721,6 +1731,7 @@ func (s *Server) visitor(ip netip.Addr, user *user.User) *visitor {
return s.visitors[id]
}
v.Keepalive()
v.SetUser(user) // Always update with the latest user, may be nil!
return v
}