mirror of
https://github.com/binwiederhier/ntfy.git
synced 2025-04-05 22:17:40 +03:00
Rate limiting, docs
This commit is contained in:
parent
e1c9fef6dc
commit
23cf77e0b7
7 changed files with 180 additions and 57 deletions
120
server/server.go
120
server/server.go
|
@ -4,11 +4,12 @@ import (
|
|||
"bytes"
|
||||
_ "embed" // required for go:embed
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"golang.org/x/time/rate"
|
||||
"heckel.io/ntfy/config"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
@ -16,19 +17,33 @@ import (
|
|||
"time"
|
||||
)
|
||||
|
||||
// Server is the main server
|
||||
type Server struct {
|
||||
config *config.Config
|
||||
topics map[string]*topic
|
||||
mu sync.Mutex
|
||||
config *config.Config
|
||||
topics map[string]*topic
|
||||
visitors map[string]*visitor
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
type message struct {
|
||||
Time int64 `json:"time"`
|
||||
Message string `json:"message"`
|
||||
// visitor represents an API user, and its associated rate.Limiter used for rate limiting
|
||||
type visitor struct {
|
||||
limiter *rate.Limiter
|
||||
seen time.Time
|
||||
}
|
||||
|
||||
// errHTTP is a generic HTTP error for any non-200 HTTP error
|
||||
type errHTTP struct {
|
||||
Code int
|
||||
Status string
|
||||
}
|
||||
|
||||
func (e errHTTP) Error() string {
|
||||
return fmt.Sprintf("http: %s", e.Status)
|
||||
}
|
||||
|
||||
const (
|
||||
messageLimit = 1024
|
||||
messageLimit = 1024
|
||||
visitorExpungeAfter = 30 * time.Minute
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -40,18 +55,26 @@ var (
|
|||
//go:embed "index.html"
|
||||
indexSource string
|
||||
|
||||
errTopicNotFound = errors.New("topic not found")
|
||||
errHTTPNotFound = &errHTTP{http.StatusNotFound, http.StatusText(http.StatusNotFound)}
|
||||
errHTTPTooManyRequests = &errHTTP{http.StatusTooManyRequests, http.StatusText(http.StatusTooManyRequests)}
|
||||
)
|
||||
|
||||
func New(conf *config.Config) *Server {
|
||||
return &Server{
|
||||
config: conf,
|
||||
topics: make(map[string]*topic),
|
||||
config: conf,
|
||||
topics: make(map[string]*topic),
|
||||
visitors: make(map[string]*visitor),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) Run() error {
|
||||
go s.runMonitor()
|
||||
go func() {
|
||||
ticker := time.NewTicker(s.config.ManagerInterval)
|
||||
for {
|
||||
<-ticker.C
|
||||
s.updateStatsAndExpire()
|
||||
}
|
||||
}()
|
||||
return s.listenAndServe()
|
||||
}
|
||||
|
||||
|
@ -61,29 +84,43 @@ func (s *Server) listenAndServe() error {
|
|||
return http.ListenAndServe(s.config.ListenHTTP, nil)
|
||||
}
|
||||
|
||||
func (s *Server) runMonitor() {
|
||||
for {
|
||||
time.Sleep(30 * time.Second)
|
||||
s.mu.Lock()
|
||||
var subscribers, messages int
|
||||
for _, t := range s.topics {
|
||||
subs, msgs := t.Stats()
|
||||
subscribers += subs
|
||||
messages += msgs
|
||||
func (s *Server) updateStatsAndExpire() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// Expire visitors from rate visitors map
|
||||
for ip, v := range s.visitors {
|
||||
if time.Since(v.seen) > visitorExpungeAfter {
|
||||
delete(s.visitors, ip)
|
||||
}
|
||||
log.Printf("Stats: %d topic(s), %d subscriber(s), %d message(s) sent", len(s.topics), subscribers, messages)
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
// Print stats
|
||||
var subscribers, messages int
|
||||
for _, t := range s.topics {
|
||||
subs, msgs := t.Stats()
|
||||
subscribers += subs
|
||||
messages += msgs
|
||||
}
|
||||
log.Printf("Stats: %d topic(s), %d subscriber(s), %d message(s) sent, %d visitor(s)",
|
||||
len(s.topics), subscribers, messages, len(s.visitors))
|
||||
}
|
||||
|
||||
func (s *Server) handle(w http.ResponseWriter, r *http.Request) {
|
||||
if err := s.handleInternal(w, r); err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
_, _ = io.WriteString(w, err.Error()+"\n")
|
||||
if e, ok := err.(*errHTTP); ok {
|
||||
s.fail(w, r, e.Code, e)
|
||||
} else {
|
||||
s.fail(w, r, http.StatusInternalServerError, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request) error {
|
||||
v := s.visitor(r.RemoteAddr)
|
||||
if !v.limiter.Allow() {
|
||||
return errHTTPTooManyRequests
|
||||
}
|
||||
if r.Method == http.MethodGet && r.URL.Path == "/" {
|
||||
return s.handleHome(w, r)
|
||||
} else if r.Method == http.MethodGet && jsonRegex.MatchString(r.URL.Path) {
|
||||
|
@ -95,8 +132,7 @@ func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request) error {
|
|||
} else if (r.Method == http.MethodPut || r.Method == http.MethodPost) && topicRegex.MatchString(r.URL.Path) {
|
||||
return s.handlePublishHTTP(w, r)
|
||||
}
|
||||
http.NotFound(w, r)
|
||||
return nil
|
||||
return errHTTPNotFound
|
||||
}
|
||||
|
||||
func (s *Server) handleHome(w http.ResponseWriter, r *http.Request) error {
|
||||
|
@ -206,7 +242,7 @@ func (s *Server) topic(topicID string) (*topic, error) {
|
|||
defer s.mu.Unlock()
|
||||
c, ok := s.topics[topicID]
|
||||
if !ok {
|
||||
return nil, errTopicNotFound
|
||||
return nil, errHTTPNotFound
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
@ -218,3 +254,31 @@ func (s *Server) unsubscribe(t *topic, subscriberID int) {
|
|||
delete(s.topics, t.id)
|
||||
}
|
||||
}
|
||||
|
||||
// visitor creates or retrieves a rate.Limiter for the given visitor.
|
||||
// This function was taken from https://www.alexedwards.net/blog/how-to-rate-limit-http-requests (MIT).
|
||||
func (s *Server) visitor(remoteAddr string) *visitor {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
ip, _, err := net.SplitHostPort(remoteAddr)
|
||||
if err != nil {
|
||||
ip = remoteAddr // This should not happen in real life; only in tests.
|
||||
}
|
||||
v, exists := s.visitors[ip]
|
||||
if !exists {
|
||||
v = &visitor{
|
||||
rate.NewLimiter(s.config.Limit, s.config.LimitBurst),
|
||||
time.Now(),
|
||||
}
|
||||
s.visitors[ip] = v
|
||||
return v
|
||||
}
|
||||
v.seen = time.Now()
|
||||
return v
|
||||
}
|
||||
|
||||
func (s *Server) fail(w http.ResponseWriter, r *http.Request, code int, err error) {
|
||||
log.Printf("[%s] %s - %d - %s", r.RemoteAddr, r.Method, code, err.Error())
|
||||
w.WriteHeader(code)
|
||||
io.WriteString(w, fmt.Sprintf("%s\n", http.StatusText(code)))
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue