http3: return the context cancellation error from RoundTrip (#4203)

This commit is contained in:
Marten Seemann 2023-12-21 11:16:30 +07:00 committed by GitHub
parent d3c5f389d4
commit 2243fdefbf
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 25 additions and 1 deletions

View file

@ -254,6 +254,15 @@ func (c *client) maxHeaderBytes() uint64 {
// RoundTripOpt executes a request and returns a response
func (c *client) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) {
rsp, err := c.roundTripOpt(req, opt)
if err != nil && req.Context().Err() != nil {
// if the context was canceled, return the context cancellation error
err = req.Context().Err()
}
return rsp, err
}
func (c *client) roundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) {
if authorityAddr("https", hostnameFromRequest(req)) != c.hostname {
return nil, fmt.Errorf("http3 client BUG: RoundTripOpt called for the wrong client (expected %s, got %s)", c.hostname, req.Host)
}

View file

@ -922,7 +922,7 @@ var _ = Describe("Client", func() {
return 0, errors.New("test done")
})
_, err := cl.RoundTripOpt(req, roundTripOpt)
Expect(err).To(MatchError("test done"))
Expect(err).To(MatchError(context.Canceled))
Eventually(done).Should(BeClosed())
})
})

View file

@ -319,6 +319,21 @@ var _ = Describe("HTTP tests", func() {
Expect(string(body)).To(Equal("Hello, World!\n"))
})
It("handles context cancellations", func() {
mux.HandleFunc("/cancel", func(w http.ResponseWriter, r *http.Request) {
<-r.Context().Done()
})
ctx, cancel := context.WithCancel(context.Background())
req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("https://localhost:%d/cancel", port), nil)
Expect(err).ToNot(HaveOccurred())
time.AfterFunc(50*time.Millisecond, cancel)
_, err = client.Do(req)
Expect(err).To(HaveOccurred())
Expect(err).To(MatchError(context.Canceled))
})
It("cancels requests", func() {
handlerCalled := make(chan struct{})
mux.HandleFunc("/cancel", func(w http.ResponseWriter, r *http.Request) {