diff --git a/http3/server.go b/http3/server.go index 75b6dc2d..d872b31e 100644 --- a/http3/server.go +++ b/http3/server.go @@ -28,6 +28,22 @@ var ( const nextProtoH3 = "h3-24" +// contextKey is a value for use with context.WithValue. It's used as +// a pointer so it fits in an interface{} without allocation. +type contextKey struct { + name string +} + +func (k *contextKey) String() string { return "quic-go/http3 context value " + k.name } + +var ( + // ServerContextKey is a context key. It can be used in HTTP + // handlers with Context.Value to access the server that + // started the handler. The associated value will be of + // type *http3.Server. + ServerContextKey = &contextKey{"http3-server"} +) + type requestError struct { err error streamErr errorCode @@ -248,7 +264,10 @@ func (s *Server) handleRequest(sess quic.Session, str quic.Stream, decoder *qpac s.logger.Infof("%s %s%s", req.Method, req.Host, req.RequestURI) } - req = req.WithContext(str.Context()) + ctx := str.Context() + ctx = context.WithValue(ctx, ServerContextKey, s) + ctx = context.WithValue(ctx, http.LocalAddrContextKey, sess.LocalAddr()) + req = req.WithContext(ctx) responseWriter := newResponseWriter(str, s.logger) handler := s.Handler if handler == nil { diff --git a/http3/server_test.go b/http3/server_test.go index ce359ccb..a951aaee 100644 --- a/http3/server_test.go +++ b/http3/server_test.go @@ -108,6 +108,7 @@ var _ = Describe("Server", func() { sess = mockquic.NewMockEarlySession(mockCtrl) addr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} sess.EXPECT().RemoteAddr().Return(addr).AnyTimes() + sess.EXPECT().LocalAddr().AnyTimes() }) It("calls the HTTP handler function", func() { @@ -127,6 +128,7 @@ var _ = Describe("Server", func() { Eventually(requestChan).Should(Receive(&req)) Expect(req.Host).To(Equal("www.example.com")) Expect(req.RemoteAddr).To(Equal("127.0.0.1:1337")) + Expect(req.Context().Value(ServerContextKey)).To(Equal(s)) }) It("returns 200 with an empty handler", func() { @@ -176,6 +178,7 @@ var _ = Describe("Server", func() { sess.EXPECT().AcceptStream(gomock.Any()).Return(str, nil) sess.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done")) sess.EXPECT().RemoteAddr().Return(addr).AnyTimes() + sess.EXPECT().LocalAddr().AnyTimes() }) It("cancels reading when client sends a body in GET request", func() {