introduce a http3.RoundTripOpt to prevent closing of request stream (#3411)

This commit is contained in:
Marten Seemann 2022-05-20 11:54:31 +02:00 committed by GitHub
parent 6511723c28
commit de5f08171b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 107 additions and 73 deletions

View file

@ -236,10 +236,10 @@ func (c *client) maxHeaderBytes() uint64 {
return uint64(c.opts.MaxHeaderBytes)
}
// RoundTrip executes a request and returns a response
func (c *client) RoundTrip(req *http.Request) (*http.Response, error) {
// RoundTripOpt executes a request and returns a response
func (c *client) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) {
if authorityAddr("https", hostnameFromRequest(req)) != c.hostname {
return nil, fmt.Errorf("http3 client BUG: RoundTrip called for the wrong client (expected %s, got %s)", c.hostname, req.Host)
return nil, fmt.Errorf("http3 client BUG: RoundTripOpt called for the wrong client (expected %s, got %s)", c.hostname, req.Host)
}
c.dialOnce.Do(func() {
@ -268,7 +268,7 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) {
}
// Request Cancellation:
// This go routine keeps running even after RoundTrip() returns.
// This go routine keeps running even after RoundTripOpt() returns.
// It is shut down when the application is done processing the body.
reqDone := make(chan struct{})
go func() {
@ -280,7 +280,7 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) {
}
}()
rsp, rerr := c.doRequest(req, str, reqDone)
rsp, rerr := c.doRequest(req, str, opt, reqDone)
if rerr.err != nil { // if any error occurred
close(reqDone)
if rerr.streamErr != 0 { // if it was a stream error
@ -297,16 +297,12 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) {
return rsp, rerr.err
}
func (c *client) doRequest(
req *http.Request,
str quic.Stream,
reqDone chan struct{},
) (*http.Response, requestError) {
func (c *client) doRequest(req *http.Request, str quic.Stream, opt RoundTripOpt, reqDone chan struct{}) (*http.Response, requestError) {
var requestGzip bool
if !c.opts.DisableCompression && req.Method != "HEAD" && req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" {
requestGzip = true
}
if err := c.requestWriter.WriteRequest(str, req, requestGzip); err != nil {
if err := c.requestWriter.WriteRequest(str, req, opt.DontCloseRequestStream, requestGzip); err != nil {
return nil, newStreamError(errorInternalError, err)
}

View file

@ -72,7 +72,7 @@ var _ = Describe("Client", func() {
dialAddrCalled = true
return nil, errors.New("test done")
}
client.RoundTrip(req)
client.RoundTripOpt(req, RoundTripOpt{})
Expect(dialAddrCalled).To(BeTrue())
})
@ -87,7 +87,7 @@ var _ = Describe("Client", func() {
}
req, err := http.NewRequest("GET", "https://quic.clemente.io:443", nil)
Expect(err).ToNot(HaveOccurred())
client.RoundTrip(req)
client.RoundTripOpt(req, RoundTripOpt{})
Expect(dialAddrCalled).To(BeTrue())
})
@ -108,7 +108,7 @@ var _ = Describe("Client", func() {
dialAddrCalled = true
return nil, errors.New("test done")
}
client.RoundTrip(req)
client.RoundTripOpt(req, RoundTripOpt{})
Expect(dialAddrCalled).To(BeTrue())
// make sure the original tls.Config was not modified
Expect(tlsConf.NextProtos).To(Equal([]string{"proto foo", "proto bar"}))
@ -131,7 +131,7 @@ var _ = Describe("Client", func() {
}
client, err := newClient("localhost:1337", tlsConf, &roundTripperOpts{}, quicConf, dialer)
Expect(err).ToNot(HaveOccurred())
_, err = client.RoundTrip(req.WithContext(ctx))
_, err = client.RoundTripOpt(req.WithContext(ctx), RoundTripOpt{})
Expect(err).To(MatchError(testErr))
Expect(dialerCalled).To(BeTrue())
})
@ -144,7 +144,7 @@ var _ = Describe("Client", func() {
Expect(quicConf.EnableDatagrams).To(BeTrue())
return nil, testErr
}
_, err = client.RoundTrip(req)
_, err = client.RoundTripOpt(req, RoundTripOpt{})
Expect(err).To(MatchError(testErr))
})
@ -155,7 +155,7 @@ var _ = Describe("Client", func() {
dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
return nil, testErr
}
_, err = client.RoundTrip(req)
_, err = client.RoundTripOpt(req, RoundTripOpt{})
Expect(err).To(MatchError(testErr))
})
@ -169,8 +169,8 @@ var _ = Describe("Client", func() {
It("refuses to do requests for the wrong host", func() {
req, err := http.NewRequest("https", "https://quic.clemente.io:1336/foobar.html", nil)
Expect(err).ToNot(HaveOccurred())
_, err = client.RoundTrip(req)
Expect(err).To(MatchError("http3 client BUG: RoundTrip called for the wrong client (expected quic.clemente.io:1337, got quic.clemente.io:1336)"))
_, err = client.RoundTripOpt(req, RoundTripOpt{})
Expect(err).To(MatchError("http3 client BUG: RoundTripOpt called for the wrong client (expected quic.clemente.io:1337, got quic.clemente.io:1336)"))
})
It("allows requests using a different scheme", func() {
@ -180,14 +180,14 @@ var _ = Describe("Client", func() {
dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
return nil, testErr
}
_, err = client.RoundTrip(req)
_, err = client.RoundTripOpt(req, RoundTripOpt{})
Expect(err).To(MatchError(testErr))
})
})
Context("hijacking unidirectional streams", func() {
var (
request *http.Request
req *http.Request
conn *mockquic.MockEarlyConnection
settingsFrameWritten chan struct{}
)
@ -209,7 +209,7 @@ var _ = Describe("Client", func() {
return conn, nil
}
var err error
request, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil)
req, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil)
Expect(err).ToNot(HaveOccurred())
})
@ -236,7 +236,7 @@ var _ = Describe("Client", func() {
<-testDone
return nil, errors.New("test done")
})
_, err := client.RoundTrip(request)
_, err := client.RoundTripOpt(req, RoundTripOpt{})
Expect(err).To(MatchError("done"))
Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54)))
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
@ -261,7 +261,7 @@ var _ = Describe("Client", func() {
<-testDone
return nil, errors.New("test done")
})
_, err := client.RoundTrip(request)
_, err := client.RoundTripOpt(req, RoundTripOpt{})
Expect(err).To(MatchError("done"))
Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54)))
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
@ -270,7 +270,7 @@ var _ = Describe("Client", func() {
Context("control stream handling", func() {
var (
request *http.Request
req *http.Request
conn *mockquic.MockEarlyConnection
settingsFrameWritten chan struct{}
)
@ -291,7 +291,7 @@ var _ = Describe("Client", func() {
return conn, nil
}
var err error
request, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil)
req, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil)
Expect(err).ToNot(HaveOccurred())
})
@ -313,7 +313,7 @@ var _ = Describe("Client", func() {
<-testDone
return nil, errors.New("test done")
})
_, err := client.RoundTrip(request)
_, err := client.RoundTripOpt(req, RoundTripOpt{})
Expect(err).To(MatchError("done"))
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
})
@ -338,7 +338,7 @@ var _ = Describe("Client", func() {
<-testDone
return nil, errors.New("test done")
})
_, err := client.RoundTrip(request)
_, err := client.RoundTripOpt(req, RoundTripOpt{})
Expect(err).To(MatchError("done"))
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to str.CancelRead
})
@ -361,7 +361,7 @@ var _ = Describe("Client", func() {
<-testDone
return nil, errors.New("test done")
})
_, err := client.RoundTrip(request)
_, err := client.RoundTripOpt(req, RoundTripOpt{})
Expect(err).To(MatchError("done"))
Eventually(done).Should(BeClosed())
})
@ -385,7 +385,7 @@ var _ = Describe("Client", func() {
Expect(code).To(BeEquivalentTo(errorMissingSettings))
close(done)
})
_, err := client.RoundTrip(request)
_, err := client.RoundTripOpt(req, RoundTripOpt{})
Expect(err).To(MatchError("done"))
Eventually(done).Should(BeClosed())
})
@ -411,7 +411,7 @@ var _ = Describe("Client", func() {
Expect(code).To(BeEquivalentTo(errorFrameError))
close(done)
})
_, err := client.RoundTrip(request)
_, err := client.RoundTripOpt(req, RoundTripOpt{})
Expect(err).To(MatchError("done"))
Eventually(done).Should(BeClosed())
})
@ -434,7 +434,7 @@ var _ = Describe("Client", func() {
Expect(code).To(BeEquivalentTo(errorIDError))
close(done)
})
_, err := client.RoundTrip(request)
_, err := client.RoundTripOpt(req, RoundTripOpt{})
Expect(err).To(MatchError("done"))
Eventually(done).Should(BeClosed())
})
@ -461,7 +461,7 @@ var _ = Describe("Client", func() {
Expect(reason).To(Equal("missing QUIC Datagram support"))
close(done)
})
_, err := client.RoundTrip(request)
_, err := client.RoundTripOpt(req, RoundTripOpt{})
Expect(err).To(MatchError("done"))
Eventually(done).Should(BeClosed())
})
@ -469,7 +469,7 @@ var _ = Describe("Client", func() {
Context("Doing requests", func() {
var (
request *http.Request
req *http.Request
str *mockquic.MockStream
conn *mockquic.MockEarlyConnection
settingsFrameWritten chan struct{}
@ -540,7 +540,7 @@ var _ = Describe("Client", func() {
return conn, nil
}
var err error
request, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil)
req, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil)
Expect(err).ToNot(HaveOccurred())
})
@ -554,13 +554,13 @@ var _ = Describe("Client", func() {
conn.EXPECT().OpenStreamSync(context.Background()).Return(nil, testErr)
conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).MaxTimes(1)
conn.EXPECT().HandshakeComplete().Return(handshakeCtx)
_, err := client.RoundTrip(request)
_, err := client.RoundTripOpt(req, RoundTripOpt{})
Expect(err).To(MatchError(testErr))
})
It("performs a 0-RTT request", func() {
testErr := errors.New("stream open error")
request.Method = MethodGet0RTT
req.Method = MethodGet0RTT
// don't EXPECT any calls to HandshakeComplete()
conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
buf := &bytes.Buffer{}
@ -570,7 +570,7 @@ var _ = Describe("Client", func() {
str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) {
return 0, testErr
})
_, err := client.RoundTrip(request)
_, err := client.RoundTripOpt(req, RoundTripOpt{})
Expect(err).To(MatchError(testErr))
Expect(decodeHeader(buf)).To(HaveKeyWithValue(":method", "GET"))
})
@ -585,7 +585,7 @@ var _ = Describe("Client", func() {
str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil })
str.EXPECT().Close()
str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes()
rsp, err := client.RoundTrip(request)
rsp, err := client.RoundTripOpt(req, RoundTripOpt{})
Expect(err).ToNot(HaveOccurred())
Expect(rsp.Proto).To(Equal("HTTP/3"))
Expect(rsp.ProtoMajor).To(Equal(3))
@ -604,7 +604,7 @@ var _ = Describe("Client", func() {
body := &mockBody{}
body.SetData([]byte("request body"))
var err error
request, err = http.NewRequest("POST", "https://quic.clemente.io:1337/upload", body)
req, err = http.NewRequest("POST", "https://quic.clemente.io:1337/upload", body)
Expect(err).ToNot(HaveOccurred())
str.EXPECT().Write(gomock.Any()).DoAndReturn(strBuf.Write).AnyTimes()
})
@ -620,7 +620,7 @@ var _ = Describe("Client", func() {
<-done
return 0, errors.New("test done")
})
_, err := client.RoundTrip(request)
_, err := client.RoundTripOpt(req, RoundTripOpt{})
Expect(err).To(MatchError("test done"))
hfs := decodeHeader(strBuf)
Expect(hfs).To(HaveKeyWithValue(":method", "POST"))
@ -628,7 +628,7 @@ var _ = Describe("Client", func() {
})
It("returns the error that occurred when reading the body", func() {
request.Body.(*mockBody).readErr = errors.New("testErr")
req.Body.(*mockBody).readErr = errors.New("testErr")
done := make(chan struct{})
gomock.InOrder(
str.EXPECT().CancelWrite(quic.StreamErrorCode(errorRequestCanceled)).Do(func(quic.StreamErrorCode) {
@ -642,7 +642,7 @@ var _ = Describe("Client", func() {
<-done
return 0, errors.New("test done")
})
_, err := client.RoundTrip(request)
_, err := client.RoundTripOpt(req, RoundTripOpt{})
Expect(err).To(MatchError("test done"))
})
@ -660,7 +660,7 @@ var _ = Describe("Client", func() {
str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1) // when reading the response errors
// the response body is sent asynchronously, while already reading the response
str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
req, err := client.RoundTrip(request)
req, err := client.RoundTripOpt(req, RoundTripOpt{})
Expect(err).ToNot(HaveOccurred())
Expect(req.ContentLength).To(BeEquivalentTo(1337))
Eventually(done).Should(BeClosed())
@ -673,7 +673,7 @@ var _ = Describe("Client", func() {
closed := make(chan struct{})
str.EXPECT().Close().Do(func() { close(closed) })
str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
_, err := client.RoundTrip(request)
_, err := client.RoundTripOpt(req, RoundTripOpt{})
Expect(err).To(MatchError("expected first frame to be a HEADERS frame"))
Eventually(closed).Should(BeClosed())
})
@ -685,7 +685,7 @@ var _ = Describe("Client", func() {
closed := make(chan struct{})
str.EXPECT().Close().Do(func() { close(closed) })
str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
_, err := client.RoundTrip(request)
_, err := client.RoundTripOpt(req, RoundTripOpt{})
Expect(err).To(MatchError("HEADERS frame too large: 1338 bytes (max: 1337)"))
Eventually(closed).Should(BeClosed())
})
@ -694,12 +694,12 @@ var _ = Describe("Client", func() {
Context("request cancellations", func() {
It("cancels a request while waiting for the handshake to complete", func() {
ctx, cancel := context.WithCancel(context.Background())
req := request.WithContext(ctx)
req := req.WithContext(ctx)
conn.EXPECT().HandshakeComplete().Return(context.Background())
errChan := make(chan error)
go func() {
_, err := client.RoundTrip(req)
_, err := client.RoundTripOpt(req, RoundTripOpt{})
errChan <- err
}()
Consistently(errChan).ShouldNot(Receive())
@ -709,7 +709,7 @@ var _ = Describe("Client", func() {
It("cancels a request while the request is still in flight", func() {
ctx, cancel := context.WithCancel(context.Background())
req := request.WithContext(ctx)
req := req.WithContext(ctx)
conn.EXPECT().HandshakeComplete().Return(handshakeCtx)
conn.EXPECT().OpenStreamSync(ctx).Return(str, nil)
buf := &bytes.Buffer{}
@ -729,7 +729,7 @@ var _ = Describe("Client", func() {
<-canceled
return 0, errors.New("test done")
})
_, err := client.RoundTrip(req)
_, err := client.RoundTripOpt(req, RoundTripOpt{})
Expect(err).To(MatchError("test done"))
Eventually(done).Should(BeClosed())
})
@ -738,7 +738,7 @@ var _ = Describe("Client", func() {
rspBuf := bytes.NewBuffer(getResponse(404))
ctx, cancel := context.WithCancel(context.Background())
req := request.WithContext(ctx)
req := req.WithContext(ctx)
conn.EXPECT().HandshakeComplete().Return(handshakeCtx)
conn.EXPECT().OpenStreamSync(ctx).Return(str, nil)
conn.EXPECT().ConnectionState().Return(quic.ConnectionState{})
@ -750,7 +750,7 @@ var _ = Describe("Client", func() {
str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes()
str.EXPECT().CancelWrite(quic.StreamErrorCode(errorRequestCanceled))
str.EXPECT().CancelRead(quic.StreamErrorCode(errorRequestCanceled)).Do(func(quic.StreamErrorCode) { close(done) })
_, err := client.RoundTrip(req)
_, err := client.RoundTripOpt(req, RoundTripOpt{})
Expect(err).ToNot(HaveOccurred())
cancel()
Eventually(done).Should(BeClosed())
@ -771,7 +771,7 @@ var _ = Describe("Client", func() {
str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1), // when the Read errors
)
str.EXPECT().Read(gomock.Any()).Return(0, errors.New("test done"))
_, err := client.RoundTrip(request)
_, err := client.RoundTripOpt(req, RoundTripOpt{})
Expect(err).To(MatchError("test done"))
hfs := decodeHeader(buf)
Expect(hfs).To(HaveKeyWithValue("accept-encoding", "gzip"))
@ -788,7 +788,7 @@ var _ = Describe("Client", func() {
str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1), // when the Read errors
)
str.EXPECT().Read(gomock.Any()).Return(0, errors.New("test done"))
_, err = client.RoundTrip(request)
_, err = client.RoundTripOpt(req, RoundTripOpt{})
Expect(err).To(MatchError("test done"))
hfs := decodeHeader(buf)
Expect(hfs).ToNot(HaveKey("accept-encoding"))
@ -810,7 +810,7 @@ var _ = Describe("Client", func() {
str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
str.EXPECT().Close()
rsp, err := client.RoundTrip(request)
rsp, err := client.RoundTripOpt(req, RoundTripOpt{})
Expect(err).ToNot(HaveOccurred())
data, err := ioutil.ReadAll(rsp.Body)
Expect(err).ToNot(HaveOccurred())
@ -833,7 +833,7 @@ var _ = Describe("Client", func() {
str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
str.EXPECT().Close()
rsp, err := client.RoundTrip(request)
rsp, err := client.RoundTripOpt(req, RoundTripOpt{})
Expect(err).ToNot(HaveOccurred())
data, err := ioutil.ReadAll(rsp.Body)
Expect(err).ToNot(HaveOccurred())

21
http3/conn.go Normal file
View file

@ -0,0 +1,21 @@
package http3
import "github.com/lucas-clemente/quic-go"
type ConnState struct {
SupportsDatagram bool
}
type Conn struct {
conn quic.Connection
supportsDatagram bool
}
func (c *Conn) State() ConnState {
return ConnState{SupportsDatagram: c.supportsDatagram}
}
func (c *Conn) SendDatagram(b []byte) error {
return c.conn.SendMessage(b)
}

View file

@ -38,7 +38,7 @@ func newRequestWriter(logger utils.Logger) *requestWriter {
}
}
func (w *requestWriter) WriteRequest(str quic.Stream, req *http.Request, gzip bool) error {
func (w *requestWriter) WriteRequest(str quic.Stream, req *http.Request, dontCloseStr, gzip bool) error {
buf := &bytes.Buffer{}
if err := w.writeHeaders(buf, req, gzip); err != nil {
return err
@ -48,7 +48,9 @@ func (w *requestWriter) WriteRequest(str quic.Stream, req *http.Request, gzip bo
}
// TODO: add support for trailers
if req.Body == nil {
if !dontCloseStr {
str.Close()
}
return nil
}
@ -84,7 +86,9 @@ func (w *requestWriter) WriteRequest(str quic.Stream, req *http.Request, gzip bo
return
}
}
if !dontCloseStr {
str.Close()
}
}()
return nil

View file

@ -60,7 +60,7 @@ var _ = Describe("Request Writer", func() {
str.EXPECT().Close()
req, err := http.NewRequest(http.MethodGet, "https://quic.clemente.io/index.html?foo=bar", nil)
Expect(err).ToNot(HaveOccurred())
Expect(rw.WriteRequest(str, req, false)).To(Succeed())
Expect(rw.WriteRequest(str, req, false, false)).To(Succeed())
headerFields := decode(strBuf)
Expect(headerFields).To(HaveKeyWithValue(":authority", "quic.clemente.io"))
Expect(headerFields).To(HaveKeyWithValue(":method", "GET"))
@ -69,13 +69,21 @@ var _ = Describe("Request Writer", func() {
Expect(headerFields).ToNot(HaveKey("accept-encoding"))
})
It("writes a GET request without closing the stream", func() {
req, err := http.NewRequest(http.MethodGet, "https://quic.clemente.io", nil)
Expect(err).ToNot(HaveOccurred())
Expect(rw.WriteRequest(str, req, true, false)).To(Succeed())
headerFields := decode(strBuf)
Expect(headerFields).To(HaveKeyWithValue(":authority", "quic.clemente.io"))
})
It("writes a POST request", func() {
closed := make(chan struct{})
str.EXPECT().Close().Do(func() { close(closed) })
postData := bytes.NewReader([]byte("foobar"))
req, err := http.NewRequest(http.MethodPost, "https://quic.clemente.io/upload.html", postData)
Expect(err).ToNot(HaveOccurred())
Expect(rw.WriteRequest(str, req, false)).To(Succeed())
Expect(rw.WriteRequest(str, req, false, false)).To(Succeed())
Eventually(closed).Should(BeClosed())
headerFields := decode(strBuf)
@ -96,7 +104,7 @@ var _ = Describe("Request Writer", func() {
str.EXPECT().Close().Do(func() { close(closed) })
req, err := http.NewRequest(http.MethodPost, "https://quic.clemente.io/upload.html", &foobarReader{})
Expect(err).ToNot(HaveOccurred())
Expect(rw.WriteRequest(str, req, false)).To(Succeed())
Expect(rw.WriteRequest(str, req, false, false)).To(Succeed())
Eventually(closed).Should(BeClosed())
headerFields := decode(strBuf)
@ -122,7 +130,7 @@ var _ = Describe("Request Writer", func() {
}
req.AddCookie(cookie1)
req.AddCookie(cookie2)
Expect(rw.WriteRequest(str, req, false)).To(Succeed())
Expect(rw.WriteRequest(str, req, false, false)).To(Succeed())
headerFields := decode(strBuf)
Expect(headerFields).To(HaveKeyWithValue("cookie", `Cookie #1="Value #1"; Cookie #2="Value #2"`))
})
@ -131,7 +139,7 @@ var _ = Describe("Request Writer", func() {
str.EXPECT().Close()
req, err := http.NewRequest(http.MethodGet, "https://quic.clemente.io/", nil)
Expect(err).ToNot(HaveOccurred())
Expect(rw.WriteRequest(str, req, true)).To(Succeed())
Expect(rw.WriteRequest(str, req, false, true)).To(Succeed())
headerFields := decode(strBuf)
Expect(headerFields).To(HaveKeyWithValue("accept-encoding", "gzip"))
})
@ -140,7 +148,7 @@ var _ = Describe("Request Writer", func() {
str.EXPECT().Close()
req, err := http.NewRequest(http.MethodConnect, "https://quic.clemente.io/", nil)
Expect(err).ToNot(HaveOccurred())
Expect(rw.WriteRequest(str, req, false)).To(Succeed())
Expect(rw.WriteRequest(str, req, false, false)).To(Succeed())
headerFields := decode(strBuf)
Expect(headerFields).To(HaveKeyWithValue(":method", "CONNECT"))
Expect(headerFields).To(HaveKeyWithValue(":authority", "quic.clemente.io"))
@ -154,7 +162,7 @@ var _ = Describe("Request Writer", func() {
req, err := http.NewRequest(http.MethodConnect, "https://quic.clemente.io/foobar", nil)
Expect(err).ToNot(HaveOccurred())
req.Proto = "webtransport"
Expect(rw.WriteRequest(str, req, false)).To(Succeed())
Expect(rw.WriteRequest(str, req, false, false)).To(Succeed())
headerFields := decode(strBuf)
Expect(headerFields).To(HaveKeyWithValue(":authority", "quic.clemente.io"))
Expect(headerFields).To(HaveKeyWithValue(":method", "CONNECT"))

View file

@ -16,7 +16,7 @@ import (
)
type roundTripCloser interface {
http.RoundTripper
RoundTripOpt(*http.Request, RoundTripOpt) (*http.Response, error)
io.Closer
}
@ -77,11 +77,16 @@ type RoundTripper struct {
// RoundTripOpt are options for the Transport.RoundTripOpt method.
type RoundTripOpt struct {
// OnlyCachedConn controls whether the RoundTripper may create a new QUIC connection.
// If set true and no cached connection is available, RoundTrip will return ErrNoCachedConn.
// If set true and no cached connection is available, RoundTripOpt will return ErrNoCachedConn.
OnlyCachedConn bool
// DontCloseRequestStream controls whether the request stream is closed after sending the request.
DontCloseRequestStream bool
}
var _ roundTripCloser = &RoundTripper{}
var (
_ http.RoundTripper = &RoundTripper{}
_ io.Closer = &RoundTripper{}
)
// ErrNoCachedConn is returned when RoundTripper.OnlyCachedConn is set
var ErrNoCachedConn = errors.New("http3: no cached connection was available")
@ -127,7 +132,7 @@ func (r *RoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.
if err != nil {
return nil, err
}
return cl.RoundTrip(req)
return cl.RoundTripOpt(req, opt)
}
// RoundTrip does a round trip.
@ -135,7 +140,7 @@ func (r *RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
return r.RoundTripOpt(req, RoundTripOpt{})
}
func (r *RoundTripper) getClient(hostname string, onlyCached bool) (http.RoundTripper, error) {
func (r *RoundTripper) getClient(hostname string, onlyCached bool) (roundTripCloser, error) {
r.mutex.Lock()
defer r.mutex.Unlock()

View file

@ -20,7 +20,7 @@ type mockClient struct {
closed bool
}
func (m *mockClient) RoundTrip(req *http.Request) (*http.Response, error) {
func (m *mockClient) RoundTripOpt(req *http.Request, _ RoundTripOpt) (*http.Response, error) {
return &http.Response{Request: req}, nil
}

View file

@ -138,7 +138,7 @@ var _ = Describe("Server", func() {
closed := make(chan struct{})
str.EXPECT().Close().Do(func() { close(closed) })
rw := newRequestWriter(utils.DefaultLogger)
Expect(rw.WriteRequest(str, req, false)).To(Succeed())
Expect(rw.WriteRequest(str, req, false, false)).To(Succeed())
Eventually(closed).Should(BeClosed())
return buf.Bytes()
}