mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-04 20:57:36 +03:00
introduce a http3.RoundTripOpt to prevent closing of request stream (#3411)
This commit is contained in:
parent
6511723c28
commit
de5f08171b
8 changed files with 107 additions and 73 deletions
|
@ -236,10 +236,10 @@ func (c *client) maxHeaderBytes() uint64 {
|
||||||
return uint64(c.opts.MaxHeaderBytes)
|
return uint64(c.opts.MaxHeaderBytes)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RoundTrip executes a request and returns a response
|
// RoundTripOpt executes a request and returns a response
|
||||||
func (c *client) RoundTrip(req *http.Request) (*http.Response, error) {
|
func (c *client) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) {
|
||||||
if authorityAddr("https", hostnameFromRequest(req)) != c.hostname {
|
if authorityAddr("https", hostnameFromRequest(req)) != c.hostname {
|
||||||
return nil, fmt.Errorf("http3 client BUG: RoundTrip called for the wrong client (expected %s, got %s)", c.hostname, req.Host)
|
return nil, fmt.Errorf("http3 client BUG: RoundTripOpt called for the wrong client (expected %s, got %s)", c.hostname, req.Host)
|
||||||
}
|
}
|
||||||
|
|
||||||
c.dialOnce.Do(func() {
|
c.dialOnce.Do(func() {
|
||||||
|
@ -268,7 +268,7 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Request Cancellation:
|
// Request Cancellation:
|
||||||
// This go routine keeps running even after RoundTrip() returns.
|
// This go routine keeps running even after RoundTripOpt() returns.
|
||||||
// It is shut down when the application is done processing the body.
|
// It is shut down when the application is done processing the body.
|
||||||
reqDone := make(chan struct{})
|
reqDone := make(chan struct{})
|
||||||
go func() {
|
go func() {
|
||||||
|
@ -280,7 +280,7 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
rsp, rerr := c.doRequest(req, str, reqDone)
|
rsp, rerr := c.doRequest(req, str, opt, reqDone)
|
||||||
if rerr.err != nil { // if any error occurred
|
if rerr.err != nil { // if any error occurred
|
||||||
close(reqDone)
|
close(reqDone)
|
||||||
if rerr.streamErr != 0 { // if it was a stream error
|
if rerr.streamErr != 0 { // if it was a stream error
|
||||||
|
@ -297,16 +297,12 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
return rsp, rerr.err
|
return rsp, rerr.err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *client) doRequest(
|
func (c *client) doRequest(req *http.Request, str quic.Stream, opt RoundTripOpt, reqDone chan struct{}) (*http.Response, requestError) {
|
||||||
req *http.Request,
|
|
||||||
str quic.Stream,
|
|
||||||
reqDone chan struct{},
|
|
||||||
) (*http.Response, requestError) {
|
|
||||||
var requestGzip bool
|
var requestGzip bool
|
||||||
if !c.opts.DisableCompression && req.Method != "HEAD" && req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" {
|
if !c.opts.DisableCompression && req.Method != "HEAD" && req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" {
|
||||||
requestGzip = true
|
requestGzip = true
|
||||||
}
|
}
|
||||||
if err := c.requestWriter.WriteRequest(str, req, requestGzip); err != nil {
|
if err := c.requestWriter.WriteRequest(str, req, opt.DontCloseRequestStream, requestGzip); err != nil {
|
||||||
return nil, newStreamError(errorInternalError, err)
|
return nil, newStreamError(errorInternalError, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -72,7 +72,7 @@ var _ = Describe("Client", func() {
|
||||||
dialAddrCalled = true
|
dialAddrCalled = true
|
||||||
return nil, errors.New("test done")
|
return nil, errors.New("test done")
|
||||||
}
|
}
|
||||||
client.RoundTrip(req)
|
client.RoundTripOpt(req, RoundTripOpt{})
|
||||||
Expect(dialAddrCalled).To(BeTrue())
|
Expect(dialAddrCalled).To(BeTrue())
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -87,7 +87,7 @@ var _ = Describe("Client", func() {
|
||||||
}
|
}
|
||||||
req, err := http.NewRequest("GET", "https://quic.clemente.io:443", nil)
|
req, err := http.NewRequest("GET", "https://quic.clemente.io:443", nil)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
client.RoundTrip(req)
|
client.RoundTripOpt(req, RoundTripOpt{})
|
||||||
Expect(dialAddrCalled).To(BeTrue())
|
Expect(dialAddrCalled).To(BeTrue())
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -108,7 +108,7 @@ var _ = Describe("Client", func() {
|
||||||
dialAddrCalled = true
|
dialAddrCalled = true
|
||||||
return nil, errors.New("test done")
|
return nil, errors.New("test done")
|
||||||
}
|
}
|
||||||
client.RoundTrip(req)
|
client.RoundTripOpt(req, RoundTripOpt{})
|
||||||
Expect(dialAddrCalled).To(BeTrue())
|
Expect(dialAddrCalled).To(BeTrue())
|
||||||
// make sure the original tls.Config was not modified
|
// make sure the original tls.Config was not modified
|
||||||
Expect(tlsConf.NextProtos).To(Equal([]string{"proto foo", "proto bar"}))
|
Expect(tlsConf.NextProtos).To(Equal([]string{"proto foo", "proto bar"}))
|
||||||
|
@ -131,7 +131,7 @@ var _ = Describe("Client", func() {
|
||||||
}
|
}
|
||||||
client, err := newClient("localhost:1337", tlsConf, &roundTripperOpts{}, quicConf, dialer)
|
client, err := newClient("localhost:1337", tlsConf, &roundTripperOpts{}, quicConf, dialer)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
_, err = client.RoundTrip(req.WithContext(ctx))
|
_, err = client.RoundTripOpt(req.WithContext(ctx), RoundTripOpt{})
|
||||||
Expect(err).To(MatchError(testErr))
|
Expect(err).To(MatchError(testErr))
|
||||||
Expect(dialerCalled).To(BeTrue())
|
Expect(dialerCalled).To(BeTrue())
|
||||||
})
|
})
|
||||||
|
@ -144,7 +144,7 @@ var _ = Describe("Client", func() {
|
||||||
Expect(quicConf.EnableDatagrams).To(BeTrue())
|
Expect(quicConf.EnableDatagrams).To(BeTrue())
|
||||||
return nil, testErr
|
return nil, testErr
|
||||||
}
|
}
|
||||||
_, err = client.RoundTrip(req)
|
_, err = client.RoundTripOpt(req, RoundTripOpt{})
|
||||||
Expect(err).To(MatchError(testErr))
|
Expect(err).To(MatchError(testErr))
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -155,7 +155,7 @@ var _ = Describe("Client", func() {
|
||||||
dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
|
dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
|
||||||
return nil, testErr
|
return nil, testErr
|
||||||
}
|
}
|
||||||
_, err = client.RoundTrip(req)
|
_, err = client.RoundTripOpt(req, RoundTripOpt{})
|
||||||
Expect(err).To(MatchError(testErr))
|
Expect(err).To(MatchError(testErr))
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -169,8 +169,8 @@ var _ = Describe("Client", func() {
|
||||||
It("refuses to do requests for the wrong host", func() {
|
It("refuses to do requests for the wrong host", func() {
|
||||||
req, err := http.NewRequest("https", "https://quic.clemente.io:1336/foobar.html", nil)
|
req, err := http.NewRequest("https", "https://quic.clemente.io:1336/foobar.html", nil)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
_, err = client.RoundTrip(req)
|
_, err = client.RoundTripOpt(req, RoundTripOpt{})
|
||||||
Expect(err).To(MatchError("http3 client BUG: RoundTrip called for the wrong client (expected quic.clemente.io:1337, got quic.clemente.io:1336)"))
|
Expect(err).To(MatchError("http3 client BUG: RoundTripOpt called for the wrong client (expected quic.clemente.io:1337, got quic.clemente.io:1336)"))
|
||||||
})
|
})
|
||||||
|
|
||||||
It("allows requests using a different scheme", func() {
|
It("allows requests using a different scheme", func() {
|
||||||
|
@ -180,14 +180,14 @@ var _ = Describe("Client", func() {
|
||||||
dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
|
dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
|
||||||
return nil, testErr
|
return nil, testErr
|
||||||
}
|
}
|
||||||
_, err = client.RoundTrip(req)
|
_, err = client.RoundTripOpt(req, RoundTripOpt{})
|
||||||
Expect(err).To(MatchError(testErr))
|
Expect(err).To(MatchError(testErr))
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
Context("hijacking unidirectional streams", func() {
|
Context("hijacking unidirectional streams", func() {
|
||||||
var (
|
var (
|
||||||
request *http.Request
|
req *http.Request
|
||||||
conn *mockquic.MockEarlyConnection
|
conn *mockquic.MockEarlyConnection
|
||||||
settingsFrameWritten chan struct{}
|
settingsFrameWritten chan struct{}
|
||||||
)
|
)
|
||||||
|
@ -209,7 +209,7 @@ var _ = Describe("Client", func() {
|
||||||
return conn, nil
|
return conn, nil
|
||||||
}
|
}
|
||||||
var err error
|
var err error
|
||||||
request, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil)
|
req, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -236,7 +236,7 @@ var _ = Describe("Client", func() {
|
||||||
<-testDone
|
<-testDone
|
||||||
return nil, errors.New("test done")
|
return nil, errors.New("test done")
|
||||||
})
|
})
|
||||||
_, err := client.RoundTrip(request)
|
_, err := client.RoundTripOpt(req, RoundTripOpt{})
|
||||||
Expect(err).To(MatchError("done"))
|
Expect(err).To(MatchError("done"))
|
||||||
Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54)))
|
Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54)))
|
||||||
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
|
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
|
||||||
|
@ -261,7 +261,7 @@ var _ = Describe("Client", func() {
|
||||||
<-testDone
|
<-testDone
|
||||||
return nil, errors.New("test done")
|
return nil, errors.New("test done")
|
||||||
})
|
})
|
||||||
_, err := client.RoundTrip(request)
|
_, err := client.RoundTripOpt(req, RoundTripOpt{})
|
||||||
Expect(err).To(MatchError("done"))
|
Expect(err).To(MatchError("done"))
|
||||||
Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54)))
|
Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54)))
|
||||||
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
|
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
|
||||||
|
@ -270,7 +270,7 @@ var _ = Describe("Client", func() {
|
||||||
|
|
||||||
Context("control stream handling", func() {
|
Context("control stream handling", func() {
|
||||||
var (
|
var (
|
||||||
request *http.Request
|
req *http.Request
|
||||||
conn *mockquic.MockEarlyConnection
|
conn *mockquic.MockEarlyConnection
|
||||||
settingsFrameWritten chan struct{}
|
settingsFrameWritten chan struct{}
|
||||||
)
|
)
|
||||||
|
@ -291,7 +291,7 @@ var _ = Describe("Client", func() {
|
||||||
return conn, nil
|
return conn, nil
|
||||||
}
|
}
|
||||||
var err error
|
var err error
|
||||||
request, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil)
|
req, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -313,7 +313,7 @@ var _ = Describe("Client", func() {
|
||||||
<-testDone
|
<-testDone
|
||||||
return nil, errors.New("test done")
|
return nil, errors.New("test done")
|
||||||
})
|
})
|
||||||
_, err := client.RoundTrip(request)
|
_, err := client.RoundTripOpt(req, RoundTripOpt{})
|
||||||
Expect(err).To(MatchError("done"))
|
Expect(err).To(MatchError("done"))
|
||||||
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
|
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
|
||||||
})
|
})
|
||||||
|
@ -338,7 +338,7 @@ var _ = Describe("Client", func() {
|
||||||
<-testDone
|
<-testDone
|
||||||
return nil, errors.New("test done")
|
return nil, errors.New("test done")
|
||||||
})
|
})
|
||||||
_, err := client.RoundTrip(request)
|
_, err := client.RoundTripOpt(req, RoundTripOpt{})
|
||||||
Expect(err).To(MatchError("done"))
|
Expect(err).To(MatchError("done"))
|
||||||
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to str.CancelRead
|
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to str.CancelRead
|
||||||
})
|
})
|
||||||
|
@ -361,7 +361,7 @@ var _ = Describe("Client", func() {
|
||||||
<-testDone
|
<-testDone
|
||||||
return nil, errors.New("test done")
|
return nil, errors.New("test done")
|
||||||
})
|
})
|
||||||
_, err := client.RoundTrip(request)
|
_, err := client.RoundTripOpt(req, RoundTripOpt{})
|
||||||
Expect(err).To(MatchError("done"))
|
Expect(err).To(MatchError("done"))
|
||||||
Eventually(done).Should(BeClosed())
|
Eventually(done).Should(BeClosed())
|
||||||
})
|
})
|
||||||
|
@ -385,7 +385,7 @@ var _ = Describe("Client", func() {
|
||||||
Expect(code).To(BeEquivalentTo(errorMissingSettings))
|
Expect(code).To(BeEquivalentTo(errorMissingSettings))
|
||||||
close(done)
|
close(done)
|
||||||
})
|
})
|
||||||
_, err := client.RoundTrip(request)
|
_, err := client.RoundTripOpt(req, RoundTripOpt{})
|
||||||
Expect(err).To(MatchError("done"))
|
Expect(err).To(MatchError("done"))
|
||||||
Eventually(done).Should(BeClosed())
|
Eventually(done).Should(BeClosed())
|
||||||
})
|
})
|
||||||
|
@ -411,7 +411,7 @@ var _ = Describe("Client", func() {
|
||||||
Expect(code).To(BeEquivalentTo(errorFrameError))
|
Expect(code).To(BeEquivalentTo(errorFrameError))
|
||||||
close(done)
|
close(done)
|
||||||
})
|
})
|
||||||
_, err := client.RoundTrip(request)
|
_, err := client.RoundTripOpt(req, RoundTripOpt{})
|
||||||
Expect(err).To(MatchError("done"))
|
Expect(err).To(MatchError("done"))
|
||||||
Eventually(done).Should(BeClosed())
|
Eventually(done).Should(BeClosed())
|
||||||
})
|
})
|
||||||
|
@ -434,7 +434,7 @@ var _ = Describe("Client", func() {
|
||||||
Expect(code).To(BeEquivalentTo(errorIDError))
|
Expect(code).To(BeEquivalentTo(errorIDError))
|
||||||
close(done)
|
close(done)
|
||||||
})
|
})
|
||||||
_, err := client.RoundTrip(request)
|
_, err := client.RoundTripOpt(req, RoundTripOpt{})
|
||||||
Expect(err).To(MatchError("done"))
|
Expect(err).To(MatchError("done"))
|
||||||
Eventually(done).Should(BeClosed())
|
Eventually(done).Should(BeClosed())
|
||||||
})
|
})
|
||||||
|
@ -461,7 +461,7 @@ var _ = Describe("Client", func() {
|
||||||
Expect(reason).To(Equal("missing QUIC Datagram support"))
|
Expect(reason).To(Equal("missing QUIC Datagram support"))
|
||||||
close(done)
|
close(done)
|
||||||
})
|
})
|
||||||
_, err := client.RoundTrip(request)
|
_, err := client.RoundTripOpt(req, RoundTripOpt{})
|
||||||
Expect(err).To(MatchError("done"))
|
Expect(err).To(MatchError("done"))
|
||||||
Eventually(done).Should(BeClosed())
|
Eventually(done).Should(BeClosed())
|
||||||
})
|
})
|
||||||
|
@ -469,7 +469,7 @@ var _ = Describe("Client", func() {
|
||||||
|
|
||||||
Context("Doing requests", func() {
|
Context("Doing requests", func() {
|
||||||
var (
|
var (
|
||||||
request *http.Request
|
req *http.Request
|
||||||
str *mockquic.MockStream
|
str *mockquic.MockStream
|
||||||
conn *mockquic.MockEarlyConnection
|
conn *mockquic.MockEarlyConnection
|
||||||
settingsFrameWritten chan struct{}
|
settingsFrameWritten chan struct{}
|
||||||
|
@ -540,7 +540,7 @@ var _ = Describe("Client", func() {
|
||||||
return conn, nil
|
return conn, nil
|
||||||
}
|
}
|
||||||
var err error
|
var err error
|
||||||
request, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil)
|
req, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -554,13 +554,13 @@ var _ = Describe("Client", func() {
|
||||||
conn.EXPECT().OpenStreamSync(context.Background()).Return(nil, testErr)
|
conn.EXPECT().OpenStreamSync(context.Background()).Return(nil, testErr)
|
||||||
conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).MaxTimes(1)
|
conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).MaxTimes(1)
|
||||||
conn.EXPECT().HandshakeComplete().Return(handshakeCtx)
|
conn.EXPECT().HandshakeComplete().Return(handshakeCtx)
|
||||||
_, err := client.RoundTrip(request)
|
_, err := client.RoundTripOpt(req, RoundTripOpt{})
|
||||||
Expect(err).To(MatchError(testErr))
|
Expect(err).To(MatchError(testErr))
|
||||||
})
|
})
|
||||||
|
|
||||||
It("performs a 0-RTT request", func() {
|
It("performs a 0-RTT request", func() {
|
||||||
testErr := errors.New("stream open error")
|
testErr := errors.New("stream open error")
|
||||||
request.Method = MethodGet0RTT
|
req.Method = MethodGet0RTT
|
||||||
// don't EXPECT any calls to HandshakeComplete()
|
// don't EXPECT any calls to HandshakeComplete()
|
||||||
conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
|
conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
|
||||||
buf := &bytes.Buffer{}
|
buf := &bytes.Buffer{}
|
||||||
|
@ -570,7 +570,7 @@ var _ = Describe("Client", func() {
|
||||||
str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) {
|
str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) {
|
||||||
return 0, testErr
|
return 0, testErr
|
||||||
})
|
})
|
||||||
_, err := client.RoundTrip(request)
|
_, err := client.RoundTripOpt(req, RoundTripOpt{})
|
||||||
Expect(err).To(MatchError(testErr))
|
Expect(err).To(MatchError(testErr))
|
||||||
Expect(decodeHeader(buf)).To(HaveKeyWithValue(":method", "GET"))
|
Expect(decodeHeader(buf)).To(HaveKeyWithValue(":method", "GET"))
|
||||||
})
|
})
|
||||||
|
@ -585,7 +585,7 @@ var _ = Describe("Client", func() {
|
||||||
str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil })
|
str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil })
|
||||||
str.EXPECT().Close()
|
str.EXPECT().Close()
|
||||||
str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes()
|
str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes()
|
||||||
rsp, err := client.RoundTrip(request)
|
rsp, err := client.RoundTripOpt(req, RoundTripOpt{})
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(rsp.Proto).To(Equal("HTTP/3"))
|
Expect(rsp.Proto).To(Equal("HTTP/3"))
|
||||||
Expect(rsp.ProtoMajor).To(Equal(3))
|
Expect(rsp.ProtoMajor).To(Equal(3))
|
||||||
|
@ -604,7 +604,7 @@ var _ = Describe("Client", func() {
|
||||||
body := &mockBody{}
|
body := &mockBody{}
|
||||||
body.SetData([]byte("request body"))
|
body.SetData([]byte("request body"))
|
||||||
var err error
|
var err error
|
||||||
request, err = http.NewRequest("POST", "https://quic.clemente.io:1337/upload", body)
|
req, err = http.NewRequest("POST", "https://quic.clemente.io:1337/upload", body)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
str.EXPECT().Write(gomock.Any()).DoAndReturn(strBuf.Write).AnyTimes()
|
str.EXPECT().Write(gomock.Any()).DoAndReturn(strBuf.Write).AnyTimes()
|
||||||
})
|
})
|
||||||
|
@ -620,7 +620,7 @@ var _ = Describe("Client", func() {
|
||||||
<-done
|
<-done
|
||||||
return 0, errors.New("test done")
|
return 0, errors.New("test done")
|
||||||
})
|
})
|
||||||
_, err := client.RoundTrip(request)
|
_, err := client.RoundTripOpt(req, RoundTripOpt{})
|
||||||
Expect(err).To(MatchError("test done"))
|
Expect(err).To(MatchError("test done"))
|
||||||
hfs := decodeHeader(strBuf)
|
hfs := decodeHeader(strBuf)
|
||||||
Expect(hfs).To(HaveKeyWithValue(":method", "POST"))
|
Expect(hfs).To(HaveKeyWithValue(":method", "POST"))
|
||||||
|
@ -628,7 +628,7 @@ var _ = Describe("Client", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
It("returns the error that occurred when reading the body", func() {
|
It("returns the error that occurred when reading the body", func() {
|
||||||
request.Body.(*mockBody).readErr = errors.New("testErr")
|
req.Body.(*mockBody).readErr = errors.New("testErr")
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
gomock.InOrder(
|
gomock.InOrder(
|
||||||
str.EXPECT().CancelWrite(quic.StreamErrorCode(errorRequestCanceled)).Do(func(quic.StreamErrorCode) {
|
str.EXPECT().CancelWrite(quic.StreamErrorCode(errorRequestCanceled)).Do(func(quic.StreamErrorCode) {
|
||||||
|
@ -642,7 +642,7 @@ var _ = Describe("Client", func() {
|
||||||
<-done
|
<-done
|
||||||
return 0, errors.New("test done")
|
return 0, errors.New("test done")
|
||||||
})
|
})
|
||||||
_, err := client.RoundTrip(request)
|
_, err := client.RoundTripOpt(req, RoundTripOpt{})
|
||||||
Expect(err).To(MatchError("test done"))
|
Expect(err).To(MatchError("test done"))
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -660,7 +660,7 @@ var _ = Describe("Client", func() {
|
||||||
str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1) // when reading the response errors
|
str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1) // when reading the response errors
|
||||||
// the response body is sent asynchronously, while already reading the response
|
// the response body is sent asynchronously, while already reading the response
|
||||||
str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
|
str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
|
||||||
req, err := client.RoundTrip(request)
|
req, err := client.RoundTripOpt(req, RoundTripOpt{})
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(req.ContentLength).To(BeEquivalentTo(1337))
|
Expect(req.ContentLength).To(BeEquivalentTo(1337))
|
||||||
Eventually(done).Should(BeClosed())
|
Eventually(done).Should(BeClosed())
|
||||||
|
@ -673,7 +673,7 @@ var _ = Describe("Client", func() {
|
||||||
closed := make(chan struct{})
|
closed := make(chan struct{})
|
||||||
str.EXPECT().Close().Do(func() { close(closed) })
|
str.EXPECT().Close().Do(func() { close(closed) })
|
||||||
str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
|
str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
|
||||||
_, err := client.RoundTrip(request)
|
_, err := client.RoundTripOpt(req, RoundTripOpt{})
|
||||||
Expect(err).To(MatchError("expected first frame to be a HEADERS frame"))
|
Expect(err).To(MatchError("expected first frame to be a HEADERS frame"))
|
||||||
Eventually(closed).Should(BeClosed())
|
Eventually(closed).Should(BeClosed())
|
||||||
})
|
})
|
||||||
|
@ -685,7 +685,7 @@ var _ = Describe("Client", func() {
|
||||||
closed := make(chan struct{})
|
closed := make(chan struct{})
|
||||||
str.EXPECT().Close().Do(func() { close(closed) })
|
str.EXPECT().Close().Do(func() { close(closed) })
|
||||||
str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
|
str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
|
||||||
_, err := client.RoundTrip(request)
|
_, err := client.RoundTripOpt(req, RoundTripOpt{})
|
||||||
Expect(err).To(MatchError("HEADERS frame too large: 1338 bytes (max: 1337)"))
|
Expect(err).To(MatchError("HEADERS frame too large: 1338 bytes (max: 1337)"))
|
||||||
Eventually(closed).Should(BeClosed())
|
Eventually(closed).Should(BeClosed())
|
||||||
})
|
})
|
||||||
|
@ -694,12 +694,12 @@ var _ = Describe("Client", func() {
|
||||||
Context("request cancellations", func() {
|
Context("request cancellations", func() {
|
||||||
It("cancels a request while waiting for the handshake to complete", func() {
|
It("cancels a request while waiting for the handshake to complete", func() {
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
req := request.WithContext(ctx)
|
req := req.WithContext(ctx)
|
||||||
conn.EXPECT().HandshakeComplete().Return(context.Background())
|
conn.EXPECT().HandshakeComplete().Return(context.Background())
|
||||||
|
|
||||||
errChan := make(chan error)
|
errChan := make(chan error)
|
||||||
go func() {
|
go func() {
|
||||||
_, err := client.RoundTrip(req)
|
_, err := client.RoundTripOpt(req, RoundTripOpt{})
|
||||||
errChan <- err
|
errChan <- err
|
||||||
}()
|
}()
|
||||||
Consistently(errChan).ShouldNot(Receive())
|
Consistently(errChan).ShouldNot(Receive())
|
||||||
|
@ -709,7 +709,7 @@ var _ = Describe("Client", func() {
|
||||||
|
|
||||||
It("cancels a request while the request is still in flight", func() {
|
It("cancels a request while the request is still in flight", func() {
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
req := request.WithContext(ctx)
|
req := req.WithContext(ctx)
|
||||||
conn.EXPECT().HandshakeComplete().Return(handshakeCtx)
|
conn.EXPECT().HandshakeComplete().Return(handshakeCtx)
|
||||||
conn.EXPECT().OpenStreamSync(ctx).Return(str, nil)
|
conn.EXPECT().OpenStreamSync(ctx).Return(str, nil)
|
||||||
buf := &bytes.Buffer{}
|
buf := &bytes.Buffer{}
|
||||||
|
@ -729,7 +729,7 @@ var _ = Describe("Client", func() {
|
||||||
<-canceled
|
<-canceled
|
||||||
return 0, errors.New("test done")
|
return 0, errors.New("test done")
|
||||||
})
|
})
|
||||||
_, err := client.RoundTrip(req)
|
_, err := client.RoundTripOpt(req, RoundTripOpt{})
|
||||||
Expect(err).To(MatchError("test done"))
|
Expect(err).To(MatchError("test done"))
|
||||||
Eventually(done).Should(BeClosed())
|
Eventually(done).Should(BeClosed())
|
||||||
})
|
})
|
||||||
|
@ -738,7 +738,7 @@ var _ = Describe("Client", func() {
|
||||||
rspBuf := bytes.NewBuffer(getResponse(404))
|
rspBuf := bytes.NewBuffer(getResponse(404))
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
req := request.WithContext(ctx)
|
req := req.WithContext(ctx)
|
||||||
conn.EXPECT().HandshakeComplete().Return(handshakeCtx)
|
conn.EXPECT().HandshakeComplete().Return(handshakeCtx)
|
||||||
conn.EXPECT().OpenStreamSync(ctx).Return(str, nil)
|
conn.EXPECT().OpenStreamSync(ctx).Return(str, nil)
|
||||||
conn.EXPECT().ConnectionState().Return(quic.ConnectionState{})
|
conn.EXPECT().ConnectionState().Return(quic.ConnectionState{})
|
||||||
|
@ -750,7 +750,7 @@ var _ = Describe("Client", func() {
|
||||||
str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes()
|
str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes()
|
||||||
str.EXPECT().CancelWrite(quic.StreamErrorCode(errorRequestCanceled))
|
str.EXPECT().CancelWrite(quic.StreamErrorCode(errorRequestCanceled))
|
||||||
str.EXPECT().CancelRead(quic.StreamErrorCode(errorRequestCanceled)).Do(func(quic.StreamErrorCode) { close(done) })
|
str.EXPECT().CancelRead(quic.StreamErrorCode(errorRequestCanceled)).Do(func(quic.StreamErrorCode) { close(done) })
|
||||||
_, err := client.RoundTrip(req)
|
_, err := client.RoundTripOpt(req, RoundTripOpt{})
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
cancel()
|
cancel()
|
||||||
Eventually(done).Should(BeClosed())
|
Eventually(done).Should(BeClosed())
|
||||||
|
@ -771,7 +771,7 @@ var _ = Describe("Client", func() {
|
||||||
str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1), // when the Read errors
|
str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1), // when the Read errors
|
||||||
)
|
)
|
||||||
str.EXPECT().Read(gomock.Any()).Return(0, errors.New("test done"))
|
str.EXPECT().Read(gomock.Any()).Return(0, errors.New("test done"))
|
||||||
_, err := client.RoundTrip(request)
|
_, err := client.RoundTripOpt(req, RoundTripOpt{})
|
||||||
Expect(err).To(MatchError("test done"))
|
Expect(err).To(MatchError("test done"))
|
||||||
hfs := decodeHeader(buf)
|
hfs := decodeHeader(buf)
|
||||||
Expect(hfs).To(HaveKeyWithValue("accept-encoding", "gzip"))
|
Expect(hfs).To(HaveKeyWithValue("accept-encoding", "gzip"))
|
||||||
|
@ -788,7 +788,7 @@ var _ = Describe("Client", func() {
|
||||||
str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1), // when the Read errors
|
str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1), // when the Read errors
|
||||||
)
|
)
|
||||||
str.EXPECT().Read(gomock.Any()).Return(0, errors.New("test done"))
|
str.EXPECT().Read(gomock.Any()).Return(0, errors.New("test done"))
|
||||||
_, err = client.RoundTrip(request)
|
_, err = client.RoundTripOpt(req, RoundTripOpt{})
|
||||||
Expect(err).To(MatchError("test done"))
|
Expect(err).To(MatchError("test done"))
|
||||||
hfs := decodeHeader(buf)
|
hfs := decodeHeader(buf)
|
||||||
Expect(hfs).ToNot(HaveKey("accept-encoding"))
|
Expect(hfs).ToNot(HaveKey("accept-encoding"))
|
||||||
|
@ -810,7 +810,7 @@ var _ = Describe("Client", func() {
|
||||||
str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
|
str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
|
||||||
str.EXPECT().Close()
|
str.EXPECT().Close()
|
||||||
|
|
||||||
rsp, err := client.RoundTrip(request)
|
rsp, err := client.RoundTripOpt(req, RoundTripOpt{})
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
data, err := ioutil.ReadAll(rsp.Body)
|
data, err := ioutil.ReadAll(rsp.Body)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
@ -833,7 +833,7 @@ var _ = Describe("Client", func() {
|
||||||
str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
|
str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
|
||||||
str.EXPECT().Close()
|
str.EXPECT().Close()
|
||||||
|
|
||||||
rsp, err := client.RoundTrip(request)
|
rsp, err := client.RoundTripOpt(req, RoundTripOpt{})
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
data, err := ioutil.ReadAll(rsp.Body)
|
data, err := ioutil.ReadAll(rsp.Body)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
21
http3/conn.go
Normal file
21
http3/conn.go
Normal file
|
@ -0,0 +1,21 @@
|
||||||
|
package http3
|
||||||
|
|
||||||
|
import "github.com/lucas-clemente/quic-go"
|
||||||
|
|
||||||
|
type ConnState struct {
|
||||||
|
SupportsDatagram bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type Conn struct {
|
||||||
|
conn quic.Connection
|
||||||
|
|
||||||
|
supportsDatagram bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) State() ConnState {
|
||||||
|
return ConnState{SupportsDatagram: c.supportsDatagram}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) SendDatagram(b []byte) error {
|
||||||
|
return c.conn.SendMessage(b)
|
||||||
|
}
|
|
@ -38,7 +38,7 @@ func newRequestWriter(logger utils.Logger) *requestWriter {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *requestWriter) WriteRequest(str quic.Stream, req *http.Request, gzip bool) error {
|
func (w *requestWriter) WriteRequest(str quic.Stream, req *http.Request, dontCloseStr, gzip bool) error {
|
||||||
buf := &bytes.Buffer{}
|
buf := &bytes.Buffer{}
|
||||||
if err := w.writeHeaders(buf, req, gzip); err != nil {
|
if err := w.writeHeaders(buf, req, gzip); err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -48,7 +48,9 @@ func (w *requestWriter) WriteRequest(str quic.Stream, req *http.Request, gzip bo
|
||||||
}
|
}
|
||||||
// TODO: add support for trailers
|
// TODO: add support for trailers
|
||||||
if req.Body == nil {
|
if req.Body == nil {
|
||||||
str.Close()
|
if !dontCloseStr {
|
||||||
|
str.Close()
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -84,7 +86,9 @@ func (w *requestWriter) WriteRequest(str quic.Stream, req *http.Request, gzip bo
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
str.Close()
|
if !dontCloseStr {
|
||||||
|
str.Close()
|
||||||
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|
|
@ -60,7 +60,7 @@ var _ = Describe("Request Writer", func() {
|
||||||
str.EXPECT().Close()
|
str.EXPECT().Close()
|
||||||
req, err := http.NewRequest(http.MethodGet, "https://quic.clemente.io/index.html?foo=bar", nil)
|
req, err := http.NewRequest(http.MethodGet, "https://quic.clemente.io/index.html?foo=bar", nil)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(rw.WriteRequest(str, req, false)).To(Succeed())
|
Expect(rw.WriteRequest(str, req, false, false)).To(Succeed())
|
||||||
headerFields := decode(strBuf)
|
headerFields := decode(strBuf)
|
||||||
Expect(headerFields).To(HaveKeyWithValue(":authority", "quic.clemente.io"))
|
Expect(headerFields).To(HaveKeyWithValue(":authority", "quic.clemente.io"))
|
||||||
Expect(headerFields).To(HaveKeyWithValue(":method", "GET"))
|
Expect(headerFields).To(HaveKeyWithValue(":method", "GET"))
|
||||||
|
@ -69,13 +69,21 @@ var _ = Describe("Request Writer", func() {
|
||||||
Expect(headerFields).ToNot(HaveKey("accept-encoding"))
|
Expect(headerFields).ToNot(HaveKey("accept-encoding"))
|
||||||
})
|
})
|
||||||
|
|
||||||
|
It("writes a GET request without closing the stream", func() {
|
||||||
|
req, err := http.NewRequest(http.MethodGet, "https://quic.clemente.io", nil)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(rw.WriteRequest(str, req, true, false)).To(Succeed())
|
||||||
|
headerFields := decode(strBuf)
|
||||||
|
Expect(headerFields).To(HaveKeyWithValue(":authority", "quic.clemente.io"))
|
||||||
|
})
|
||||||
|
|
||||||
It("writes a POST request", func() {
|
It("writes a POST request", func() {
|
||||||
closed := make(chan struct{})
|
closed := make(chan struct{})
|
||||||
str.EXPECT().Close().Do(func() { close(closed) })
|
str.EXPECT().Close().Do(func() { close(closed) })
|
||||||
postData := bytes.NewReader([]byte("foobar"))
|
postData := bytes.NewReader([]byte("foobar"))
|
||||||
req, err := http.NewRequest(http.MethodPost, "https://quic.clemente.io/upload.html", postData)
|
req, err := http.NewRequest(http.MethodPost, "https://quic.clemente.io/upload.html", postData)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(rw.WriteRequest(str, req, false)).To(Succeed())
|
Expect(rw.WriteRequest(str, req, false, false)).To(Succeed())
|
||||||
|
|
||||||
Eventually(closed).Should(BeClosed())
|
Eventually(closed).Should(BeClosed())
|
||||||
headerFields := decode(strBuf)
|
headerFields := decode(strBuf)
|
||||||
|
@ -96,7 +104,7 @@ var _ = Describe("Request Writer", func() {
|
||||||
str.EXPECT().Close().Do(func() { close(closed) })
|
str.EXPECT().Close().Do(func() { close(closed) })
|
||||||
req, err := http.NewRequest(http.MethodPost, "https://quic.clemente.io/upload.html", &foobarReader{})
|
req, err := http.NewRequest(http.MethodPost, "https://quic.clemente.io/upload.html", &foobarReader{})
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(rw.WriteRequest(str, req, false)).To(Succeed())
|
Expect(rw.WriteRequest(str, req, false, false)).To(Succeed())
|
||||||
|
|
||||||
Eventually(closed).Should(BeClosed())
|
Eventually(closed).Should(BeClosed())
|
||||||
headerFields := decode(strBuf)
|
headerFields := decode(strBuf)
|
||||||
|
@ -122,7 +130,7 @@ var _ = Describe("Request Writer", func() {
|
||||||
}
|
}
|
||||||
req.AddCookie(cookie1)
|
req.AddCookie(cookie1)
|
||||||
req.AddCookie(cookie2)
|
req.AddCookie(cookie2)
|
||||||
Expect(rw.WriteRequest(str, req, false)).To(Succeed())
|
Expect(rw.WriteRequest(str, req, false, false)).To(Succeed())
|
||||||
headerFields := decode(strBuf)
|
headerFields := decode(strBuf)
|
||||||
Expect(headerFields).To(HaveKeyWithValue("cookie", `Cookie #1="Value #1"; Cookie #2="Value #2"`))
|
Expect(headerFields).To(HaveKeyWithValue("cookie", `Cookie #1="Value #1"; Cookie #2="Value #2"`))
|
||||||
})
|
})
|
||||||
|
@ -131,7 +139,7 @@ var _ = Describe("Request Writer", func() {
|
||||||
str.EXPECT().Close()
|
str.EXPECT().Close()
|
||||||
req, err := http.NewRequest(http.MethodGet, "https://quic.clemente.io/", nil)
|
req, err := http.NewRequest(http.MethodGet, "https://quic.clemente.io/", nil)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(rw.WriteRequest(str, req, true)).To(Succeed())
|
Expect(rw.WriteRequest(str, req, false, true)).To(Succeed())
|
||||||
headerFields := decode(strBuf)
|
headerFields := decode(strBuf)
|
||||||
Expect(headerFields).To(HaveKeyWithValue("accept-encoding", "gzip"))
|
Expect(headerFields).To(HaveKeyWithValue("accept-encoding", "gzip"))
|
||||||
})
|
})
|
||||||
|
@ -140,7 +148,7 @@ var _ = Describe("Request Writer", func() {
|
||||||
str.EXPECT().Close()
|
str.EXPECT().Close()
|
||||||
req, err := http.NewRequest(http.MethodConnect, "https://quic.clemente.io/", nil)
|
req, err := http.NewRequest(http.MethodConnect, "https://quic.clemente.io/", nil)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(rw.WriteRequest(str, req, false)).To(Succeed())
|
Expect(rw.WriteRequest(str, req, false, false)).To(Succeed())
|
||||||
headerFields := decode(strBuf)
|
headerFields := decode(strBuf)
|
||||||
Expect(headerFields).To(HaveKeyWithValue(":method", "CONNECT"))
|
Expect(headerFields).To(HaveKeyWithValue(":method", "CONNECT"))
|
||||||
Expect(headerFields).To(HaveKeyWithValue(":authority", "quic.clemente.io"))
|
Expect(headerFields).To(HaveKeyWithValue(":authority", "quic.clemente.io"))
|
||||||
|
@ -154,7 +162,7 @@ var _ = Describe("Request Writer", func() {
|
||||||
req, err := http.NewRequest(http.MethodConnect, "https://quic.clemente.io/foobar", nil)
|
req, err := http.NewRequest(http.MethodConnect, "https://quic.clemente.io/foobar", nil)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
req.Proto = "webtransport"
|
req.Proto = "webtransport"
|
||||||
Expect(rw.WriteRequest(str, req, false)).To(Succeed())
|
Expect(rw.WriteRequest(str, req, false, false)).To(Succeed())
|
||||||
headerFields := decode(strBuf)
|
headerFields := decode(strBuf)
|
||||||
Expect(headerFields).To(HaveKeyWithValue(":authority", "quic.clemente.io"))
|
Expect(headerFields).To(HaveKeyWithValue(":authority", "quic.clemente.io"))
|
||||||
Expect(headerFields).To(HaveKeyWithValue(":method", "CONNECT"))
|
Expect(headerFields).To(HaveKeyWithValue(":method", "CONNECT"))
|
||||||
|
|
|
@ -16,7 +16,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type roundTripCloser interface {
|
type roundTripCloser interface {
|
||||||
http.RoundTripper
|
RoundTripOpt(*http.Request, RoundTripOpt) (*http.Response, error)
|
||||||
io.Closer
|
io.Closer
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -77,11 +77,16 @@ type RoundTripper struct {
|
||||||
// RoundTripOpt are options for the Transport.RoundTripOpt method.
|
// RoundTripOpt are options for the Transport.RoundTripOpt method.
|
||||||
type RoundTripOpt struct {
|
type RoundTripOpt struct {
|
||||||
// OnlyCachedConn controls whether the RoundTripper may create a new QUIC connection.
|
// OnlyCachedConn controls whether the RoundTripper may create a new QUIC connection.
|
||||||
// If set true and no cached connection is available, RoundTrip will return ErrNoCachedConn.
|
// If set true and no cached connection is available, RoundTripOpt will return ErrNoCachedConn.
|
||||||
OnlyCachedConn bool
|
OnlyCachedConn bool
|
||||||
|
// DontCloseRequestStream controls whether the request stream is closed after sending the request.
|
||||||
|
DontCloseRequestStream bool
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ roundTripCloser = &RoundTripper{}
|
var (
|
||||||
|
_ http.RoundTripper = &RoundTripper{}
|
||||||
|
_ io.Closer = &RoundTripper{}
|
||||||
|
)
|
||||||
|
|
||||||
// ErrNoCachedConn is returned when RoundTripper.OnlyCachedConn is set
|
// ErrNoCachedConn is returned when RoundTripper.OnlyCachedConn is set
|
||||||
var ErrNoCachedConn = errors.New("http3: no cached connection was available")
|
var ErrNoCachedConn = errors.New("http3: no cached connection was available")
|
||||||
|
@ -127,7 +132,7 @@ func (r *RoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return cl.RoundTrip(req)
|
return cl.RoundTripOpt(req, opt)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RoundTrip does a round trip.
|
// RoundTrip does a round trip.
|
||||||
|
@ -135,7 +140,7 @@ func (r *RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
return r.RoundTripOpt(req, RoundTripOpt{})
|
return r.RoundTripOpt(req, RoundTripOpt{})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *RoundTripper) getClient(hostname string, onlyCached bool) (http.RoundTripper, error) {
|
func (r *RoundTripper) getClient(hostname string, onlyCached bool) (roundTripCloser, error) {
|
||||||
r.mutex.Lock()
|
r.mutex.Lock()
|
||||||
defer r.mutex.Unlock()
|
defer r.mutex.Unlock()
|
||||||
|
|
||||||
|
|
|
@ -20,7 +20,7 @@ type mockClient struct {
|
||||||
closed bool
|
closed bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockClient) RoundTrip(req *http.Request) (*http.Response, error) {
|
func (m *mockClient) RoundTripOpt(req *http.Request, _ RoundTripOpt) (*http.Response, error) {
|
||||||
return &http.Response{Request: req}, nil
|
return &http.Response{Request: req}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -138,7 +138,7 @@ var _ = Describe("Server", func() {
|
||||||
closed := make(chan struct{})
|
closed := make(chan struct{})
|
||||||
str.EXPECT().Close().Do(func() { close(closed) })
|
str.EXPECT().Close().Do(func() { close(closed) })
|
||||||
rw := newRequestWriter(utils.DefaultLogger)
|
rw := newRequestWriter(utils.DefaultLogger)
|
||||||
Expect(rw.WriteRequest(str, req, false)).To(Succeed())
|
Expect(rw.WriteRequest(str, req, false, false)).To(Succeed())
|
||||||
Eventually(closed).Should(BeClosed())
|
Eventually(closed).Should(BeClosed())
|
||||||
return buf.Bytes()
|
return buf.Bytes()
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue