mirror of
https://github.com/binwiederhier/ntfy.git
synced 2025-04-05 06:07:34 +03:00
Add test, fails
This commit is contained in:
parent
4ab450309f
commit
29340e7e24
5 changed files with 89 additions and 45 deletions
|
@ -1889,6 +1889,49 @@ func TestServer_AnonymousUser_And_NonTierUser_Are_Same_Visitor(t *testing.T) {
|
|||
require.Equal(t, int64(2), account.Stats.Messages)
|
||||
}
|
||||
|
||||
func TestServer_SubscriberRateLimiting(t *testing.T) {
|
||||
c := newTestConfigWithAuthFile(t)
|
||||
c.VisitorRequestLimitBurst = 3
|
||||
s := newTestServer(t, c)
|
||||
|
||||
subscriber1Fn := func(r *http.Request) {
|
||||
r.RemoteAddr = "1.2.3.4"
|
||||
}
|
||||
rr := request(t, s, "GET", "/subscriber1topic/json?poll=1", "", map[string]string{
|
||||
"Subscriber-Rate-Limit-Topics": "mytopic1",
|
||||
}, subscriber1Fn)
|
||||
require.Equal(t, 200, rr.Code)
|
||||
require.Equal(t, "", rr.Body.String())
|
||||
|
||||
subscriber2Fn := func(r *http.Request) {
|
||||
r.RemoteAddr = "8.7.7.1"
|
||||
}
|
||||
rr = request(t, s, "GET", "/upSUB2topic/json?poll=1", "", nil, subscriber2Fn)
|
||||
require.Equal(t, 200, rr.Code)
|
||||
require.Equal(t, "", rr.Body.String())
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
rr := request(t, s, "PUT", "/subscriber1topic", "some message", nil)
|
||||
require.Equal(t, 200, rr.Code)
|
||||
}
|
||||
rr = request(t, s, "PUT", "/subscriber1topic", "some message", nil)
|
||||
require.Equal(t, 429, rr.Code)
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
rr := request(t, s, "PUT", "/upSUB2topic", "some message", nil)
|
||||
require.Equal(t, 200, rr.Code) // If we fail here, handlePublish is using the wrong visitor!
|
||||
}
|
||||
rr = request(t, s, "PUT", "/upSUB2topic", "some message", nil)
|
||||
require.Equal(t, 429, rr.Code)
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
rr := request(t, s, "PUT", "/some-other-topic", "some message", nil)
|
||||
require.Equal(t, 200, rr.Code)
|
||||
}
|
||||
rr = request(t, s, "PUT", "/some-other-topic", "some message", nil)
|
||||
require.Equal(t, 429, rr.Code)
|
||||
}
|
||||
|
||||
func newTestConfig(t *testing.T) *Config {
|
||||
conf := NewConfig()
|
||||
conf.BaseURL = "http://127.0.0.1:12345"
|
||||
|
@ -1914,7 +1957,7 @@ func newTestServer(t *testing.T, config *Config) *Server {
|
|||
return server
|
||||
}
|
||||
|
||||
func request(t *testing.T, s *Server, method, url, body string, headers map[string]string) *httptest.ResponseRecorder {
|
||||
func request(t *testing.T, s *Server, method, url, body string, headers map[string]string, fn ...func(r *http.Request)) *httptest.ResponseRecorder {
|
||||
rr := httptest.NewRecorder()
|
||||
req, err := http.NewRequest(method, url, strings.NewReader(body))
|
||||
if err != nil {
|
||||
|
@ -1924,6 +1967,9 @@ func request(t *testing.T, s *Server, method, url, body string, headers map[stri
|
|||
for k, v := range headers {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
for _, f := range fn {
|
||||
f(req)
|
||||
}
|
||||
s.handle(rr, req)
|
||||
return rr
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue