http3: add ConnContext to the server (#4230)

* Add ConnContext to http3.Server

ConnContext can be used to modify the context used by a new http
Request.

* Make linter happy

* Add nil check and integration test

* Add the ServerContextKey check to the ConnContext func

* Update integrationtests/self/http_test.go

Co-authored-by: Marten Seemann <martenseemann@gmail.com>

* Update http3/server.go

Co-authored-by: Marten Seemann <martenseemann@gmail.com>

---------

Co-authored-by: Marten Seemann <martenseemann@gmail.com>
This commit is contained in:
Robin Thellend 2024-01-04 19:13:53 -08:00 committed by GitHub
parent f1b3bdbcb0
commit 3ff50295ce
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 47 additions and 0 deletions

View file

@ -211,6 +211,11 @@ type Server struct {
// In that case, the stream type will not be set.
UniStreamHijacker func(StreamType, quic.Connection, quic.ReceiveStream, error) (hijacked bool)
// ConnContext optionally specifies a function that modifies
// the context used for a new connection c. The provided ctx
// has a ServerContextKey value.
ConnContext func(ctx context.Context, c quic.Connection) context.Context
mutex sync.RWMutex
listeners map[*QUICEarlyListener]listenerInfo
@ -610,6 +615,12 @@ func (s *Server) handleRequest(conn quic.Connection, str quic.Stream, decoder *q
ctx = context.WithValue(ctx, ServerContextKey, s)
ctx = context.WithValue(ctx, http.LocalAddrContextKey, conn.LocalAddr())
ctx = context.WithValue(ctx, RemoteAddrContextKey, conn.RemoteAddr())
if s.ConnContext != nil {
ctx = s.ConnContext(ctx, conn)
if ctx == nil {
panic("http3: ConnContext returned nil")
}
}
req = req.WithContext(ctx)
r := newResponseWriter(str, conn, s.logger)
if req.Method == http.MethodHead {

View file

@ -67,11 +67,15 @@ var _ = Describe("Server", func() {
s *Server
origQuicListenAddr = quicListenAddr
)
type testConnContextKey string
BeforeEach(func() {
s = &Server{
TLSConfig: testdata.GetTLSConfig(),
logger: utils.DefaultLogger,
ConnContext: func(ctx context.Context, c quic.Connection) context.Context {
return context.WithValue(ctx, testConnContextKey("test"), c)
},
}
origQuicListenAddr = quicListenAddr
})
@ -163,6 +167,7 @@ var _ = Describe("Server", func() {
Expect(req.Host).To(Equal("www.example.com"))
Expect(req.RemoteAddr).To(Equal("127.0.0.1:1337"))
Expect(req.Context().Value(ServerContextKey)).To(Equal(s))
Expect(req.Context().Value(testConnContextKey("test"))).ToNot(Equal(nil))
})
It("returns 200 with an empty handler", func() {

View file

@ -528,4 +528,35 @@ var _ = Describe("HTTP tests", func() {
Expect(err).ToNot(HaveOccurred())
Expect(resp.StatusCode).To(Equal(200))
})
It("sets conn context", func() {
type ctxKey int
server.ConnContext = func(ctx context.Context, c quic.Connection) context.Context {
serv, ok := ctx.Value(http3.ServerContextKey).(*http3.Server)
Expect(ok).To(BeTrue())
Expect(serv).To(Equal(server))
ctx = context.WithValue(ctx, ctxKey(0), "Hello")
ctx = context.WithValue(ctx, ctxKey(1), c)
return ctx
}
mux.HandleFunc("/conn-context", func(w http.ResponseWriter, r *http.Request) {
defer GinkgoRecover()
v, ok := r.Context().Value(ctxKey(0)).(string)
Expect(ok).To(BeTrue())
Expect(v).To(Equal("Hello"))
c, ok := r.Context().Value(ctxKey(1)).(quic.Connection)
Expect(ok).To(BeTrue())
Expect(c).ToNot(BeNil())
serv, ok := r.Context().Value(http3.ServerContextKey).(*http3.Server)
Expect(ok).To(BeTrue())
Expect(serv).To(Equal(server))
})
resp, err := client.Get(fmt.Sprintf("https://localhost:%d/conn-context", port))
Expect(err).ToNot(HaveOccurred())
Expect(resp.StatusCode).To(Equal(200))
})
})