Polish a little

This commit is contained in:
binwiederhier 2023-02-22 22:26:43 -05:00
parent 21b27b5dbe
commit bdeec4d297
5 changed files with 63 additions and 40 deletions

View file

@ -104,15 +104,15 @@ var (
)
const (
firebaseControlTopic = "~control" // See Android if changed
firebasePollTopic = "~poll" // See iOS if changed
emptyMessageBody = "triggered" // Used if message body is empty
newMessageBody = "New message" // Used in poll requests as generic message
defaultAttachmentMessage = "You received a file: %s" // Used if message body is empty, and there is an attachment
encodingBase64 = "base64" // Used mainly for binary UnifiedPush messages
jsonBodyBytesLimit = 16384
unifiedPushTopicPrefix = "up" // Temporarily, we rate limit all "up*" topics based on the subscriber
rateVisitorExpiryDuration = 12 * time.Hour
firebaseControlTopic = "~control" // See Android if changed
firebasePollTopic = "~poll" // See iOS if changed
emptyMessageBody = "triggered" // Used if message body is empty
newMessageBody = "New message" // Used in poll requests as generic message
defaultAttachmentMessage = "You received a file: %s" // Used if message body is empty, and there is an attachment
encodingBase64 = "base64" // Used mainly for binary UnifiedPush messages
jsonBodyBytesLimit = 16384
unifiedPushTopicPrefix = "up" // Temporarily, we rate limit all "up*" topics based on the subscriber
rateTopicsWildcard = "*" // Allows defining all topics in the request subscriber-rate-limited topics
)
// WebSocket constants
@ -571,11 +571,11 @@ func (s *Server) handleMatrixDiscovery(w http.ResponseWriter) error {
}
func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*message, error) {
vrate, ok := r.Context().Value("vRate").(*visitor)
vrate, ok := r.Context().Value(contextRateVisitor).(*visitor)
if !ok {
return nil, errHTTPInternalError
}
t, ok := r.Context().Value("topic").(*topic)
t, ok := r.Context().Value(contextTopic).(*topic)
if !ok {
return nil, errHTTPInternalError
}
@ -709,7 +709,7 @@ func (s *Server) forwardPollRequest(v *visitor, m *message) {
}
}
func (s *Server) parsePublishParams(r *http.Request, vRate *visitor, m *message) (cache bool, firebase bool, email string, unifiedpush bool, err error) {
func (s *Server) parsePublishParams(r *http.Request, vrate *visitor, m *message) (cache bool, firebase bool, email string, unifiedpush bool, err error) {
cache = readBoolParam(r, true, "x-cache", "cache")
firebase = readBoolParam(r, true, "x-firebase", "firebase")
m.Title = readParam(r, "x-title", "title", "t")
@ -749,7 +749,7 @@ func (s *Server) parsePublishParams(r *http.Request, vRate *visitor, m *message)
}
email = readParam(r, "x-email", "x-e-mail", "email", "e-mail", "mail", "e")
if email != "" {
if !vRate.EmailAllowed() {
if !vrate.EmailAllowed() {
return false, false, "", false, errHTTPTooManyRequestsLimitEmails
}
}
@ -954,7 +954,7 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *
if err != nil {
return err
}
poll, since, scheduled, filters, subscriberRateTopics, err := parseSubscribeParams(r)
poll, since, scheduled, filters, rateTopics, err := parseSubscribeParams(r)
if err != nil {
return err
}
@ -984,12 +984,7 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *
}
return nil
}
for _, t := range topics {
subscriberRateLimited := util.Contains(subscriberRateTopics, t.ID) || strings.HasPrefix(t.ID, unifiedPushTopicPrefix) // temporarily do prefix as well
if subscriberRateLimited {
t.SetRateVisitor(v)
}
}
registerRateVisitors(topics, rateTopics, v)
w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
w.Header().Set("Content-Type", contentType+"; charset=utf-8") // Android/Volley client needs charset!
if poll {
@ -1042,7 +1037,7 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
if err != nil {
return err
}
poll, since, scheduled, filters, subscriberRateTopics, err := parseSubscribeParams(r)
poll, since, scheduled, filters, rateTopics, err := parseSubscribeParams(r)
if err != nil {
return err
}
@ -1125,12 +1120,7 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
}
return conn.WriteJSON(msg)
}
for _, t := range topics {
subscriberRateLimited := util.Contains(subscriberRateTopics, t.ID) || strings.HasPrefix(t.ID, unifiedPushTopicPrefix) // temporarily do prefix as well
if subscriberRateLimited {
t.SetRateVisitor(v)
}
}
registerRateVisitors(topics, rateTopics, v)
w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
if poll {
return s.sendOldMessages(topics, since, scheduled, v, sub)
@ -1158,7 +1148,7 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
return err
}
func parseSubscribeParams(r *http.Request) (poll bool, since sinceMarker, scheduled bool, filters *queryFilter, subscriberTopics []string, err error) {
func parseSubscribeParams(r *http.Request) (poll bool, since sinceMarker, scheduled bool, filters *queryFilter, rateTopics []string, err error) {
poll = readBoolParam(r, false, "x-poll", "poll", "po")
scheduled = readBoolParam(r, false, "x-scheduled", "scheduled", "sched")
since, err = parseSince(r, poll)
@ -1169,10 +1159,29 @@ func parseSubscribeParams(r *http.Request) (poll bool, since sinceMarker, schedu
if err != nil {
return
}
subscriberTopics = readCommaSeparatedParam(r, "subscriber-rate-limit-topics", "x-subscriber-rate-limit-topics", "srlt")
rateTopics = readCommaSeparatedParam(r, "x-rate-topics", "rate-topics")
return
}
// registerRateVisitors sets the rate visitor on a topic, indicating that all messages published to that topic
// will be rate limited against the rate visitor instead of the publishing visitor.
//
// Note: This TEMPORARILY also registers all topics starting with "up" (= UnifiedPush). This is to ease the transition
// until the Android app will send the "Rate-Topics" header.
func registerRateVisitors(topics []*topic, rateTopics []string, v *visitor) {
if len(rateTopics) > 0 && rateTopics[0] == rateTopicsWildcard {
for _, t := range topics {
t.SetRateVisitor(v)
}
} else {
for _, t := range topics {
if util.Contains(rateTopics, t.ID) || strings.HasPrefix(t.ID, unifiedPushTopicPrefix) {
t.SetRateVisitor(v)
}
}
}
}
// sendOldMessages selects old messages from the messageCache and calls sub for each of them. It uses since as the
// marker, returning only messages that are newer than the marker.
func (s *Server) sendOldMessages(topics []*topic, since sinceMarker, scheduled bool, v *visitor, sub subscriber) error {