Add test, fails

This commit is contained in:
binwiederhier 2023-02-22 21:00:56 -05:00
parent 4ab450309f
commit 29340e7e24
5 changed files with 89 additions and 45 deletions

View file

@ -1889,6 +1889,49 @@ func TestServer_AnonymousUser_And_NonTierUser_Are_Same_Visitor(t *testing.T) {
require.Equal(t, int64(2), account.Stats.Messages)
}
func TestServer_SubscriberRateLimiting(t *testing.T) {
c := newTestConfigWithAuthFile(t)
c.VisitorRequestLimitBurst = 3
s := newTestServer(t, c)
subscriber1Fn := func(r *http.Request) {
r.RemoteAddr = "1.2.3.4"
}
rr := request(t, s, "GET", "/subscriber1topic/json?poll=1", "", map[string]string{
"Subscriber-Rate-Limit-Topics": "mytopic1",
}, subscriber1Fn)
require.Equal(t, 200, rr.Code)
require.Equal(t, "", rr.Body.String())
subscriber2Fn := func(r *http.Request) {
r.RemoteAddr = "8.7.7.1"
}
rr = request(t, s, "GET", "/upSUB2topic/json?poll=1", "", nil, subscriber2Fn)
require.Equal(t, 200, rr.Code)
require.Equal(t, "", rr.Body.String())
for i := 0; i < 3; i++ {
rr := request(t, s, "PUT", "/subscriber1topic", "some message", nil)
require.Equal(t, 200, rr.Code)
}
rr = request(t, s, "PUT", "/subscriber1topic", "some message", nil)
require.Equal(t, 429, rr.Code)
for i := 0; i < 3; i++ {
rr := request(t, s, "PUT", "/upSUB2topic", "some message", nil)
require.Equal(t, 200, rr.Code) // If we fail here, handlePublish is using the wrong visitor!
}
rr = request(t, s, "PUT", "/upSUB2topic", "some message", nil)
require.Equal(t, 429, rr.Code)
for i := 0; i < 3; i++ {
rr := request(t, s, "PUT", "/some-other-topic", "some message", nil)
require.Equal(t, 200, rr.Code)
}
rr = request(t, s, "PUT", "/some-other-topic", "some message", nil)
require.Equal(t, 429, rr.Code)
}
func newTestConfig(t *testing.T) *Config {
conf := NewConfig()
conf.BaseURL = "http://127.0.0.1:12345"
@ -1914,7 +1957,7 @@ func newTestServer(t *testing.T, config *Config) *Server {
return server
}
func request(t *testing.T, s *Server, method, url, body string, headers map[string]string) *httptest.ResponseRecorder {
func request(t *testing.T, s *Server, method, url, body string, headers map[string]string, fn ...func(r *http.Request)) *httptest.ResponseRecorder {
rr := httptest.NewRecorder()
req, err := http.NewRequest(method, url, strings.NewReader(body))
if err != nil {
@ -1924,6 +1967,9 @@ func request(t *testing.T, s *Server, method, url, body string, headers map[stri
for k, v := range headers {
req.Header.Set(k, v)
}
for _, f := range fn {
f(req)
}
s.handle(rr, req)
return rr
}