package server import ( "net/http" "net/http/httptest" "os" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" ) var _ = Describe("middlewares", func() { var nextCalled bool next := func(w http.ResponseWriter, r *http.Request) { nextCalled = true } Describe("robotsTXT", func() { BeforeEach(func() { nextCalled = false }) It("returns the robot.txt when requested from root", func() { r := httptest.NewRequest("GET", "/robots.txt", nil) w := httptest.NewRecorder() robotsTXT(os.DirFS("tests/fixtures"))(http.HandlerFunc(next)).ServeHTTP(w, r) Expect(nextCalled).To(BeFalse()) Expect(w.Body.String()).To(HavePrefix("User-agent:")) }) It("allows prefixes", func() { r := httptest.NewRequest("GET", "/app/robots.txt", nil) w := httptest.NewRecorder() robotsTXT(os.DirFS("tests/fixtures"))(http.HandlerFunc(next)).ServeHTTP(w, r) Expect(nextCalled).To(BeFalse()) Expect(w.Body.String()).To(HavePrefix("User-agent:")) }) It("passes through requests for other files", func() { r := httptest.NewRequest("GET", "/this_is_not_a_robots.txt_file", nil) w := httptest.NewRecorder() robotsTXT(os.DirFS("tests/fixtures"))(http.HandlerFunc(next)).ServeHTTP(w, r) Expect(nextCalled).To(BeTrue()) }) }) Describe("serverAddressMiddleware", func() { var ( nextHandler http.Handler middleware http.Handler recorder *httptest.ResponseRecorder req *http.Request ) BeforeEach(func() { nextHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) }) middleware = serverAddressMiddleware(nextHandler) recorder = httptest.NewRecorder() }) Context("with no X-Forwarded headers", func() { BeforeEach(func() { req, _ = http.NewRequest("GET", "http://example.com", nil) }) It("should not modify the request", func() { middleware.ServeHTTP(recorder, req) Expect(req.Host).To(Equal("example.com")) Expect(req.URL.Scheme).To(Equal("http")) }) }) Context("with X-Forwarded-Host header", func() { BeforeEach(func() { req, _ = http.NewRequest("GET", "http://example.com", nil) req.Header.Set("X-Forwarded-Host", "forwarded.example.com") }) It("should modify the request with the X-Forwarded-Host header value", func() { middleware.ServeHTTP(recorder, req) Expect(req.Host).To(Equal("forwarded.example.com")) Expect(req.URL.Scheme).To(Equal("http")) }) }) Context("with X-Forwarded-Proto header", func() { BeforeEach(func() { req, _ = http.NewRequest("GET", "http://example.com", nil) req.Header.Set("X-Forwarded-Proto", "https") }) It("should modify the request with the X-Forwarded-Proto header value", func() { middleware.ServeHTTP(recorder, req) Expect(req.Host).To(Equal("example.com")) Expect(req.URL.Scheme).To(Equal("https")) }) }) Context("with X-Forwarded-Scheme header", func() { BeforeEach(func() { req, _ = http.NewRequest("GET", "http://example.com", nil) req.Header.Set("X-Forwarded-Scheme", "https") }) It("should modify the request with the X-Forwarded-Scheme header value", func() { middleware.ServeHTTP(recorder, req) Expect(req.Host).To(Equal("example.com")) Expect(req.URL.Scheme).To(Equal("https")) }) }) Context("with multiple X-Forwarded headers", func() { BeforeEach(func() { req, _ = http.NewRequest("GET", "http://example.com", nil) req.Header.Set("X-Forwarded-Host", "forwarded.example.com") req.Header.Set("X-Forwarded-Proto", "https") req.Header.Set("X-Forwarded-Scheme", "http") }) It("should modify the request with the first non-empty X-Forwarded header value", func() { middleware.ServeHTTP(recorder, req) Expect(req.Host).To(Equal("forwarded.example.com")) Expect(req.URL.Scheme).To(Equal("https")) }) }) Context("with multiple values in X-Forwarded-Host header", func() { BeforeEach(func() { req, _ = http.NewRequest("GET", "http://example.com", nil) req.Header.Set("X-Forwarded-Host", "forwarded1.example.com, forwarded2.example.com") }) It("should modify the request with the first value in X-Forwarded-Host header", func() { middleware.ServeHTTP(recorder, req) Expect(req.Host).To(Equal("forwarded1.example.com")) Expect(req.URL.Scheme).To(Equal("http")) }) }) }) })