mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-04 12:47:36 +03:00
introduce a http3.RoundTripOpt to prevent closing of request stream (#3411)
This commit is contained in:
parent
6511723c28
commit
de5f08171b
8 changed files with 107 additions and 73 deletions
|
@ -236,10 +236,10 @@ func (c *client) maxHeaderBytes() uint64 {
|
|||
return uint64(c.opts.MaxHeaderBytes)
|
||||
}
|
||||
|
||||
// RoundTrip executes a request and returns a response
|
||||
func (c *client) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
// RoundTripOpt executes a request and returns a response
|
||||
func (c *client) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) {
|
||||
if authorityAddr("https", hostnameFromRequest(req)) != c.hostname {
|
||||
return nil, fmt.Errorf("http3 client BUG: RoundTrip called for the wrong client (expected %s, got %s)", c.hostname, req.Host)
|
||||
return nil, fmt.Errorf("http3 client BUG: RoundTripOpt called for the wrong client (expected %s, got %s)", c.hostname, req.Host)
|
||||
}
|
||||
|
||||
c.dialOnce.Do(func() {
|
||||
|
@ -268,7 +268,7 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) {
|
|||
}
|
||||
|
||||
// Request Cancellation:
|
||||
// This go routine keeps running even after RoundTrip() returns.
|
||||
// This go routine keeps running even after RoundTripOpt() returns.
|
||||
// It is shut down when the application is done processing the body.
|
||||
reqDone := make(chan struct{})
|
||||
go func() {
|
||||
|
@ -280,7 +280,7 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) {
|
|||
}
|
||||
}()
|
||||
|
||||
rsp, rerr := c.doRequest(req, str, reqDone)
|
||||
rsp, rerr := c.doRequest(req, str, opt, reqDone)
|
||||
if rerr.err != nil { // if any error occurred
|
||||
close(reqDone)
|
||||
if rerr.streamErr != 0 { // if it was a stream error
|
||||
|
@ -297,16 +297,12 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) {
|
|||
return rsp, rerr.err
|
||||
}
|
||||
|
||||
func (c *client) doRequest(
|
||||
req *http.Request,
|
||||
str quic.Stream,
|
||||
reqDone chan struct{},
|
||||
) (*http.Response, requestError) {
|
||||
func (c *client) doRequest(req *http.Request, str quic.Stream, opt RoundTripOpt, reqDone chan struct{}) (*http.Response, requestError) {
|
||||
var requestGzip bool
|
||||
if !c.opts.DisableCompression && req.Method != "HEAD" && req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" {
|
||||
requestGzip = true
|
||||
}
|
||||
if err := c.requestWriter.WriteRequest(str, req, requestGzip); err != nil {
|
||||
if err := c.requestWriter.WriteRequest(str, req, opt.DontCloseRequestStream, requestGzip); err != nil {
|
||||
return nil, newStreamError(errorInternalError, err)
|
||||
}
|
||||
|
||||
|
|
|
@ -72,7 +72,7 @@ var _ = Describe("Client", func() {
|
|||
dialAddrCalled = true
|
||||
return nil, errors.New("test done")
|
||||
}
|
||||
client.RoundTrip(req)
|
||||
client.RoundTripOpt(req, RoundTripOpt{})
|
||||
Expect(dialAddrCalled).To(BeTrue())
|
||||
})
|
||||
|
||||
|
@ -87,7 +87,7 @@ var _ = Describe("Client", func() {
|
|||
}
|
||||
req, err := http.NewRequest("GET", "https://quic.clemente.io:443", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
client.RoundTrip(req)
|
||||
client.RoundTripOpt(req, RoundTripOpt{})
|
||||
Expect(dialAddrCalled).To(BeTrue())
|
||||
})
|
||||
|
||||
|
@ -108,7 +108,7 @@ var _ = Describe("Client", func() {
|
|||
dialAddrCalled = true
|
||||
return nil, errors.New("test done")
|
||||
}
|
||||
client.RoundTrip(req)
|
||||
client.RoundTripOpt(req, RoundTripOpt{})
|
||||
Expect(dialAddrCalled).To(BeTrue())
|
||||
// make sure the original tls.Config was not modified
|
||||
Expect(tlsConf.NextProtos).To(Equal([]string{"proto foo", "proto bar"}))
|
||||
|
@ -131,7 +131,7 @@ var _ = Describe("Client", func() {
|
|||
}
|
||||
client, err := newClient("localhost:1337", tlsConf, &roundTripperOpts{}, quicConf, dialer)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = client.RoundTrip(req.WithContext(ctx))
|
||||
_, err = client.RoundTripOpt(req.WithContext(ctx), RoundTripOpt{})
|
||||
Expect(err).To(MatchError(testErr))
|
||||
Expect(dialerCalled).To(BeTrue())
|
||||
})
|
||||
|
@ -144,7 +144,7 @@ var _ = Describe("Client", func() {
|
|||
Expect(quicConf.EnableDatagrams).To(BeTrue())
|
||||
return nil, testErr
|
||||
}
|
||||
_, err = client.RoundTrip(req)
|
||||
_, err = client.RoundTripOpt(req, RoundTripOpt{})
|
||||
Expect(err).To(MatchError(testErr))
|
||||
})
|
||||
|
||||
|
@ -155,7 +155,7 @@ var _ = Describe("Client", func() {
|
|||
dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
|
||||
return nil, testErr
|
||||
}
|
||||
_, err = client.RoundTrip(req)
|
||||
_, err = client.RoundTripOpt(req, RoundTripOpt{})
|
||||
Expect(err).To(MatchError(testErr))
|
||||
})
|
||||
|
||||
|
@ -169,8 +169,8 @@ var _ = Describe("Client", func() {
|
|||
It("refuses to do requests for the wrong host", func() {
|
||||
req, err := http.NewRequest("https", "https://quic.clemente.io:1336/foobar.html", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = client.RoundTrip(req)
|
||||
Expect(err).To(MatchError("http3 client BUG: RoundTrip called for the wrong client (expected quic.clemente.io:1337, got quic.clemente.io:1336)"))
|
||||
_, err = client.RoundTripOpt(req, RoundTripOpt{})
|
||||
Expect(err).To(MatchError("http3 client BUG: RoundTripOpt called for the wrong client (expected quic.clemente.io:1337, got quic.clemente.io:1336)"))
|
||||
})
|
||||
|
||||
It("allows requests using a different scheme", func() {
|
||||
|
@ -180,14 +180,14 @@ var _ = Describe("Client", func() {
|
|||
dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
|
||||
return nil, testErr
|
||||
}
|
||||
_, err = client.RoundTrip(req)
|
||||
_, err = client.RoundTripOpt(req, RoundTripOpt{})
|
||||
Expect(err).To(MatchError(testErr))
|
||||
})
|
||||
})
|
||||
|
||||
Context("hijacking unidirectional streams", func() {
|
||||
var (
|
||||
request *http.Request
|
||||
req *http.Request
|
||||
conn *mockquic.MockEarlyConnection
|
||||
settingsFrameWritten chan struct{}
|
||||
)
|
||||
|
@ -209,7 +209,7 @@ var _ = Describe("Client", func() {
|
|||
return conn, nil
|
||||
}
|
||||
var err error
|
||||
request, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil)
|
||||
req, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
|
@ -236,7 +236,7 @@ var _ = Describe("Client", func() {
|
|||
<-testDone
|
||||
return nil, errors.New("test done")
|
||||
})
|
||||
_, err := client.RoundTrip(request)
|
||||
_, err := client.RoundTripOpt(req, RoundTripOpt{})
|
||||
Expect(err).To(MatchError("done"))
|
||||
Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54)))
|
||||
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
|
||||
|
@ -261,7 +261,7 @@ var _ = Describe("Client", func() {
|
|||
<-testDone
|
||||
return nil, errors.New("test done")
|
||||
})
|
||||
_, err := client.RoundTrip(request)
|
||||
_, err := client.RoundTripOpt(req, RoundTripOpt{})
|
||||
Expect(err).To(MatchError("done"))
|
||||
Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54)))
|
||||
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
|
||||
|
@ -270,7 +270,7 @@ var _ = Describe("Client", func() {
|
|||
|
||||
Context("control stream handling", func() {
|
||||
var (
|
||||
request *http.Request
|
||||
req *http.Request
|
||||
conn *mockquic.MockEarlyConnection
|
||||
settingsFrameWritten chan struct{}
|
||||
)
|
||||
|
@ -291,7 +291,7 @@ var _ = Describe("Client", func() {
|
|||
return conn, nil
|
||||
}
|
||||
var err error
|
||||
request, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil)
|
||||
req, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
|
@ -313,7 +313,7 @@ var _ = Describe("Client", func() {
|
|||
<-testDone
|
||||
return nil, errors.New("test done")
|
||||
})
|
||||
_, err := client.RoundTrip(request)
|
||||
_, err := client.RoundTripOpt(req, RoundTripOpt{})
|
||||
Expect(err).To(MatchError("done"))
|
||||
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
|
||||
})
|
||||
|
@ -338,7 +338,7 @@ var _ = Describe("Client", func() {
|
|||
<-testDone
|
||||
return nil, errors.New("test done")
|
||||
})
|
||||
_, err := client.RoundTrip(request)
|
||||
_, err := client.RoundTripOpt(req, RoundTripOpt{})
|
||||
Expect(err).To(MatchError("done"))
|
||||
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to str.CancelRead
|
||||
})
|
||||
|
@ -361,7 +361,7 @@ var _ = Describe("Client", func() {
|
|||
<-testDone
|
||||
return nil, errors.New("test done")
|
||||
})
|
||||
_, err := client.RoundTrip(request)
|
||||
_, err := client.RoundTripOpt(req, RoundTripOpt{})
|
||||
Expect(err).To(MatchError("done"))
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
@ -385,7 +385,7 @@ var _ = Describe("Client", func() {
|
|||
Expect(code).To(BeEquivalentTo(errorMissingSettings))
|
||||
close(done)
|
||||
})
|
||||
_, err := client.RoundTrip(request)
|
||||
_, err := client.RoundTripOpt(req, RoundTripOpt{})
|
||||
Expect(err).To(MatchError("done"))
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
@ -411,7 +411,7 @@ var _ = Describe("Client", func() {
|
|||
Expect(code).To(BeEquivalentTo(errorFrameError))
|
||||
close(done)
|
||||
})
|
||||
_, err := client.RoundTrip(request)
|
||||
_, err := client.RoundTripOpt(req, RoundTripOpt{})
|
||||
Expect(err).To(MatchError("done"))
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
@ -434,7 +434,7 @@ var _ = Describe("Client", func() {
|
|||
Expect(code).To(BeEquivalentTo(errorIDError))
|
||||
close(done)
|
||||
})
|
||||
_, err := client.RoundTrip(request)
|
||||
_, err := client.RoundTripOpt(req, RoundTripOpt{})
|
||||
Expect(err).To(MatchError("done"))
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
@ -461,7 +461,7 @@ var _ = Describe("Client", func() {
|
|||
Expect(reason).To(Equal("missing QUIC Datagram support"))
|
||||
close(done)
|
||||
})
|
||||
_, err := client.RoundTrip(request)
|
||||
_, err := client.RoundTripOpt(req, RoundTripOpt{})
|
||||
Expect(err).To(MatchError("done"))
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
@ -469,7 +469,7 @@ var _ = Describe("Client", func() {
|
|||
|
||||
Context("Doing requests", func() {
|
||||
var (
|
||||
request *http.Request
|
||||
req *http.Request
|
||||
str *mockquic.MockStream
|
||||
conn *mockquic.MockEarlyConnection
|
||||
settingsFrameWritten chan struct{}
|
||||
|
@ -540,7 +540,7 @@ var _ = Describe("Client", func() {
|
|||
return conn, nil
|
||||
}
|
||||
var err error
|
||||
request, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil)
|
||||
req, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
|
@ -554,13 +554,13 @@ var _ = Describe("Client", func() {
|
|||
conn.EXPECT().OpenStreamSync(context.Background()).Return(nil, testErr)
|
||||
conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).MaxTimes(1)
|
||||
conn.EXPECT().HandshakeComplete().Return(handshakeCtx)
|
||||
_, err := client.RoundTrip(request)
|
||||
_, err := client.RoundTripOpt(req, RoundTripOpt{})
|
||||
Expect(err).To(MatchError(testErr))
|
||||
})
|
||||
|
||||
It("performs a 0-RTT request", func() {
|
||||
testErr := errors.New("stream open error")
|
||||
request.Method = MethodGet0RTT
|
||||
req.Method = MethodGet0RTT
|
||||
// don't EXPECT any calls to HandshakeComplete()
|
||||
conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
|
||||
buf := &bytes.Buffer{}
|
||||
|
@ -570,7 +570,7 @@ var _ = Describe("Client", func() {
|
|||
str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) {
|
||||
return 0, testErr
|
||||
})
|
||||
_, err := client.RoundTrip(request)
|
||||
_, err := client.RoundTripOpt(req, RoundTripOpt{})
|
||||
Expect(err).To(MatchError(testErr))
|
||||
Expect(decodeHeader(buf)).To(HaveKeyWithValue(":method", "GET"))
|
||||
})
|
||||
|
@ -585,7 +585,7 @@ var _ = Describe("Client", func() {
|
|||
str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil })
|
||||
str.EXPECT().Close()
|
||||
str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes()
|
||||
rsp, err := client.RoundTrip(request)
|
||||
rsp, err := client.RoundTripOpt(req, RoundTripOpt{})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(rsp.Proto).To(Equal("HTTP/3"))
|
||||
Expect(rsp.ProtoMajor).To(Equal(3))
|
||||
|
@ -604,7 +604,7 @@ var _ = Describe("Client", func() {
|
|||
body := &mockBody{}
|
||||
body.SetData([]byte("request body"))
|
||||
var err error
|
||||
request, err = http.NewRequest("POST", "https://quic.clemente.io:1337/upload", body)
|
||||
req, err = http.NewRequest("POST", "https://quic.clemente.io:1337/upload", body)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
str.EXPECT().Write(gomock.Any()).DoAndReturn(strBuf.Write).AnyTimes()
|
||||
})
|
||||
|
@ -620,7 +620,7 @@ var _ = Describe("Client", func() {
|
|||
<-done
|
||||
return 0, errors.New("test done")
|
||||
})
|
||||
_, err := client.RoundTrip(request)
|
||||
_, err := client.RoundTripOpt(req, RoundTripOpt{})
|
||||
Expect(err).To(MatchError("test done"))
|
||||
hfs := decodeHeader(strBuf)
|
||||
Expect(hfs).To(HaveKeyWithValue(":method", "POST"))
|
||||
|
@ -628,7 +628,7 @@ var _ = Describe("Client", func() {
|
|||
})
|
||||
|
||||
It("returns the error that occurred when reading the body", func() {
|
||||
request.Body.(*mockBody).readErr = errors.New("testErr")
|
||||
req.Body.(*mockBody).readErr = errors.New("testErr")
|
||||
done := make(chan struct{})
|
||||
gomock.InOrder(
|
||||
str.EXPECT().CancelWrite(quic.StreamErrorCode(errorRequestCanceled)).Do(func(quic.StreamErrorCode) {
|
||||
|
@ -642,7 +642,7 @@ var _ = Describe("Client", func() {
|
|||
<-done
|
||||
return 0, errors.New("test done")
|
||||
})
|
||||
_, err := client.RoundTrip(request)
|
||||
_, err := client.RoundTripOpt(req, RoundTripOpt{})
|
||||
Expect(err).To(MatchError("test done"))
|
||||
})
|
||||
|
||||
|
@ -660,7 +660,7 @@ var _ = Describe("Client", func() {
|
|||
str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1) // when reading the response errors
|
||||
// the response body is sent asynchronously, while already reading the response
|
||||
str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
|
||||
req, err := client.RoundTrip(request)
|
||||
req, err := client.RoundTripOpt(req, RoundTripOpt{})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(req.ContentLength).To(BeEquivalentTo(1337))
|
||||
Eventually(done).Should(BeClosed())
|
||||
|
@ -673,7 +673,7 @@ var _ = Describe("Client", func() {
|
|||
closed := make(chan struct{})
|
||||
str.EXPECT().Close().Do(func() { close(closed) })
|
||||
str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
|
||||
_, err := client.RoundTrip(request)
|
||||
_, err := client.RoundTripOpt(req, RoundTripOpt{})
|
||||
Expect(err).To(MatchError("expected first frame to be a HEADERS frame"))
|
||||
Eventually(closed).Should(BeClosed())
|
||||
})
|
||||
|
@ -685,7 +685,7 @@ var _ = Describe("Client", func() {
|
|||
closed := make(chan struct{})
|
||||
str.EXPECT().Close().Do(func() { close(closed) })
|
||||
str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
|
||||
_, err := client.RoundTrip(request)
|
||||
_, err := client.RoundTripOpt(req, RoundTripOpt{})
|
||||
Expect(err).To(MatchError("HEADERS frame too large: 1338 bytes (max: 1337)"))
|
||||
Eventually(closed).Should(BeClosed())
|
||||
})
|
||||
|
@ -694,12 +694,12 @@ var _ = Describe("Client", func() {
|
|||
Context("request cancellations", func() {
|
||||
It("cancels a request while waiting for the handshake to complete", func() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
req := request.WithContext(ctx)
|
||||
req := req.WithContext(ctx)
|
||||
conn.EXPECT().HandshakeComplete().Return(context.Background())
|
||||
|
||||
errChan := make(chan error)
|
||||
go func() {
|
||||
_, err := client.RoundTrip(req)
|
||||
_, err := client.RoundTripOpt(req, RoundTripOpt{})
|
||||
errChan <- err
|
||||
}()
|
||||
Consistently(errChan).ShouldNot(Receive())
|
||||
|
@ -709,7 +709,7 @@ var _ = Describe("Client", func() {
|
|||
|
||||
It("cancels a request while the request is still in flight", func() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
req := request.WithContext(ctx)
|
||||
req := req.WithContext(ctx)
|
||||
conn.EXPECT().HandshakeComplete().Return(handshakeCtx)
|
||||
conn.EXPECT().OpenStreamSync(ctx).Return(str, nil)
|
||||
buf := &bytes.Buffer{}
|
||||
|
@ -729,7 +729,7 @@ var _ = Describe("Client", func() {
|
|||
<-canceled
|
||||
return 0, errors.New("test done")
|
||||
})
|
||||
_, err := client.RoundTrip(req)
|
||||
_, err := client.RoundTripOpt(req, RoundTripOpt{})
|
||||
Expect(err).To(MatchError("test done"))
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
@ -738,7 +738,7 @@ var _ = Describe("Client", func() {
|
|||
rspBuf := bytes.NewBuffer(getResponse(404))
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
req := request.WithContext(ctx)
|
||||
req := req.WithContext(ctx)
|
||||
conn.EXPECT().HandshakeComplete().Return(handshakeCtx)
|
||||
conn.EXPECT().OpenStreamSync(ctx).Return(str, nil)
|
||||
conn.EXPECT().ConnectionState().Return(quic.ConnectionState{})
|
||||
|
@ -750,7 +750,7 @@ var _ = Describe("Client", func() {
|
|||
str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes()
|
||||
str.EXPECT().CancelWrite(quic.StreamErrorCode(errorRequestCanceled))
|
||||
str.EXPECT().CancelRead(quic.StreamErrorCode(errorRequestCanceled)).Do(func(quic.StreamErrorCode) { close(done) })
|
||||
_, err := client.RoundTrip(req)
|
||||
_, err := client.RoundTripOpt(req, RoundTripOpt{})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
cancel()
|
||||
Eventually(done).Should(BeClosed())
|
||||
|
@ -771,7 +771,7 @@ var _ = Describe("Client", func() {
|
|||
str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1), // when the Read errors
|
||||
)
|
||||
str.EXPECT().Read(gomock.Any()).Return(0, errors.New("test done"))
|
||||
_, err := client.RoundTrip(request)
|
||||
_, err := client.RoundTripOpt(req, RoundTripOpt{})
|
||||
Expect(err).To(MatchError("test done"))
|
||||
hfs := decodeHeader(buf)
|
||||
Expect(hfs).To(HaveKeyWithValue("accept-encoding", "gzip"))
|
||||
|
@ -788,7 +788,7 @@ var _ = Describe("Client", func() {
|
|||
str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1), // when the Read errors
|
||||
)
|
||||
str.EXPECT().Read(gomock.Any()).Return(0, errors.New("test done"))
|
||||
_, err = client.RoundTrip(request)
|
||||
_, err = client.RoundTripOpt(req, RoundTripOpt{})
|
||||
Expect(err).To(MatchError("test done"))
|
||||
hfs := decodeHeader(buf)
|
||||
Expect(hfs).ToNot(HaveKey("accept-encoding"))
|
||||
|
@ -810,7 +810,7 @@ var _ = Describe("Client", func() {
|
|||
str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
|
||||
str.EXPECT().Close()
|
||||
|
||||
rsp, err := client.RoundTrip(request)
|
||||
rsp, err := client.RoundTripOpt(req, RoundTripOpt{})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
data, err := ioutil.ReadAll(rsp.Body)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
@ -833,7 +833,7 @@ var _ = Describe("Client", func() {
|
|||
str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
|
||||
str.EXPECT().Close()
|
||||
|
||||
rsp, err := client.RoundTrip(request)
|
||||
rsp, err := client.RoundTripOpt(req, RoundTripOpt{})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
data, err := ioutil.ReadAll(rsp.Body)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
|
21
http3/conn.go
Normal file
21
http3/conn.go
Normal file
|
@ -0,0 +1,21 @@
|
|||
package http3
|
||||
|
||||
import "github.com/lucas-clemente/quic-go"
|
||||
|
||||
type ConnState struct {
|
||||
SupportsDatagram bool
|
||||
}
|
||||
|
||||
type Conn struct {
|
||||
conn quic.Connection
|
||||
|
||||
supportsDatagram bool
|
||||
}
|
||||
|
||||
func (c *Conn) State() ConnState {
|
||||
return ConnState{SupportsDatagram: c.supportsDatagram}
|
||||
}
|
||||
|
||||
func (c *Conn) SendDatagram(b []byte) error {
|
||||
return c.conn.SendMessage(b)
|
||||
}
|
|
@ -38,7 +38,7 @@ func newRequestWriter(logger utils.Logger) *requestWriter {
|
|||
}
|
||||
}
|
||||
|
||||
func (w *requestWriter) WriteRequest(str quic.Stream, req *http.Request, gzip bool) error {
|
||||
func (w *requestWriter) WriteRequest(str quic.Stream, req *http.Request, dontCloseStr, gzip bool) error {
|
||||
buf := &bytes.Buffer{}
|
||||
if err := w.writeHeaders(buf, req, gzip); err != nil {
|
||||
return err
|
||||
|
@ -48,7 +48,9 @@ func (w *requestWriter) WriteRequest(str quic.Stream, req *http.Request, gzip bo
|
|||
}
|
||||
// TODO: add support for trailers
|
||||
if req.Body == nil {
|
||||
str.Close()
|
||||
if !dontCloseStr {
|
||||
str.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -84,7 +86,9 @@ func (w *requestWriter) WriteRequest(str quic.Stream, req *http.Request, gzip bo
|
|||
return
|
||||
}
|
||||
}
|
||||
str.Close()
|
||||
if !dontCloseStr {
|
||||
str.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
|
|
|
@ -60,7 +60,7 @@ var _ = Describe("Request Writer", func() {
|
|||
str.EXPECT().Close()
|
||||
req, err := http.NewRequest(http.MethodGet, "https://quic.clemente.io/index.html?foo=bar", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(rw.WriteRequest(str, req, false)).To(Succeed())
|
||||
Expect(rw.WriteRequest(str, req, false, false)).To(Succeed())
|
||||
headerFields := decode(strBuf)
|
||||
Expect(headerFields).To(HaveKeyWithValue(":authority", "quic.clemente.io"))
|
||||
Expect(headerFields).To(HaveKeyWithValue(":method", "GET"))
|
||||
|
@ -69,13 +69,21 @@ var _ = Describe("Request Writer", func() {
|
|||
Expect(headerFields).ToNot(HaveKey("accept-encoding"))
|
||||
})
|
||||
|
||||
It("writes a GET request without closing the stream", func() {
|
||||
req, err := http.NewRequest(http.MethodGet, "https://quic.clemente.io", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(rw.WriteRequest(str, req, true, false)).To(Succeed())
|
||||
headerFields := decode(strBuf)
|
||||
Expect(headerFields).To(HaveKeyWithValue(":authority", "quic.clemente.io"))
|
||||
})
|
||||
|
||||
It("writes a POST request", func() {
|
||||
closed := make(chan struct{})
|
||||
str.EXPECT().Close().Do(func() { close(closed) })
|
||||
postData := bytes.NewReader([]byte("foobar"))
|
||||
req, err := http.NewRequest(http.MethodPost, "https://quic.clemente.io/upload.html", postData)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(rw.WriteRequest(str, req, false)).To(Succeed())
|
||||
Expect(rw.WriteRequest(str, req, false, false)).To(Succeed())
|
||||
|
||||
Eventually(closed).Should(BeClosed())
|
||||
headerFields := decode(strBuf)
|
||||
|
@ -96,7 +104,7 @@ var _ = Describe("Request Writer", func() {
|
|||
str.EXPECT().Close().Do(func() { close(closed) })
|
||||
req, err := http.NewRequest(http.MethodPost, "https://quic.clemente.io/upload.html", &foobarReader{})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(rw.WriteRequest(str, req, false)).To(Succeed())
|
||||
Expect(rw.WriteRequest(str, req, false, false)).To(Succeed())
|
||||
|
||||
Eventually(closed).Should(BeClosed())
|
||||
headerFields := decode(strBuf)
|
||||
|
@ -122,7 +130,7 @@ var _ = Describe("Request Writer", func() {
|
|||
}
|
||||
req.AddCookie(cookie1)
|
||||
req.AddCookie(cookie2)
|
||||
Expect(rw.WriteRequest(str, req, false)).To(Succeed())
|
||||
Expect(rw.WriteRequest(str, req, false, false)).To(Succeed())
|
||||
headerFields := decode(strBuf)
|
||||
Expect(headerFields).To(HaveKeyWithValue("cookie", `Cookie #1="Value #1"; Cookie #2="Value #2"`))
|
||||
})
|
||||
|
@ -131,7 +139,7 @@ var _ = Describe("Request Writer", func() {
|
|||
str.EXPECT().Close()
|
||||
req, err := http.NewRequest(http.MethodGet, "https://quic.clemente.io/", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(rw.WriteRequest(str, req, true)).To(Succeed())
|
||||
Expect(rw.WriteRequest(str, req, false, true)).To(Succeed())
|
||||
headerFields := decode(strBuf)
|
||||
Expect(headerFields).To(HaveKeyWithValue("accept-encoding", "gzip"))
|
||||
})
|
||||
|
@ -140,7 +148,7 @@ var _ = Describe("Request Writer", func() {
|
|||
str.EXPECT().Close()
|
||||
req, err := http.NewRequest(http.MethodConnect, "https://quic.clemente.io/", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(rw.WriteRequest(str, req, false)).To(Succeed())
|
||||
Expect(rw.WriteRequest(str, req, false, false)).To(Succeed())
|
||||
headerFields := decode(strBuf)
|
||||
Expect(headerFields).To(HaveKeyWithValue(":method", "CONNECT"))
|
||||
Expect(headerFields).To(HaveKeyWithValue(":authority", "quic.clemente.io"))
|
||||
|
@ -154,7 +162,7 @@ var _ = Describe("Request Writer", func() {
|
|||
req, err := http.NewRequest(http.MethodConnect, "https://quic.clemente.io/foobar", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
req.Proto = "webtransport"
|
||||
Expect(rw.WriteRequest(str, req, false)).To(Succeed())
|
||||
Expect(rw.WriteRequest(str, req, false, false)).To(Succeed())
|
||||
headerFields := decode(strBuf)
|
||||
Expect(headerFields).To(HaveKeyWithValue(":authority", "quic.clemente.io"))
|
||||
Expect(headerFields).To(HaveKeyWithValue(":method", "CONNECT"))
|
||||
|
|
|
@ -16,7 +16,7 @@ import (
|
|||
)
|
||||
|
||||
type roundTripCloser interface {
|
||||
http.RoundTripper
|
||||
RoundTripOpt(*http.Request, RoundTripOpt) (*http.Response, error)
|
||||
io.Closer
|
||||
}
|
||||
|
||||
|
@ -77,11 +77,16 @@ type RoundTripper struct {
|
|||
// RoundTripOpt are options for the Transport.RoundTripOpt method.
|
||||
type RoundTripOpt struct {
|
||||
// OnlyCachedConn controls whether the RoundTripper may create a new QUIC connection.
|
||||
// If set true and no cached connection is available, RoundTrip will return ErrNoCachedConn.
|
||||
// If set true and no cached connection is available, RoundTripOpt will return ErrNoCachedConn.
|
||||
OnlyCachedConn bool
|
||||
// DontCloseRequestStream controls whether the request stream is closed after sending the request.
|
||||
DontCloseRequestStream bool
|
||||
}
|
||||
|
||||
var _ roundTripCloser = &RoundTripper{}
|
||||
var (
|
||||
_ http.RoundTripper = &RoundTripper{}
|
||||
_ io.Closer = &RoundTripper{}
|
||||
)
|
||||
|
||||
// ErrNoCachedConn is returned when RoundTripper.OnlyCachedConn is set
|
||||
var ErrNoCachedConn = errors.New("http3: no cached connection was available")
|
||||
|
@ -127,7 +132,7 @@ func (r *RoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return cl.RoundTrip(req)
|
||||
return cl.RoundTripOpt(req, opt)
|
||||
}
|
||||
|
||||
// RoundTrip does a round trip.
|
||||
|
@ -135,7 +140,7 @@ func (r *RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
|||
return r.RoundTripOpt(req, RoundTripOpt{})
|
||||
}
|
||||
|
||||
func (r *RoundTripper) getClient(hostname string, onlyCached bool) (http.RoundTripper, error) {
|
||||
func (r *RoundTripper) getClient(hostname string, onlyCached bool) (roundTripCloser, error) {
|
||||
r.mutex.Lock()
|
||||
defer r.mutex.Unlock()
|
||||
|
||||
|
|
|
@ -20,7 +20,7 @@ type mockClient struct {
|
|||
closed bool
|
||||
}
|
||||
|
||||
func (m *mockClient) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
func (m *mockClient) RoundTripOpt(req *http.Request, _ RoundTripOpt) (*http.Response, error) {
|
||||
return &http.Response{Request: req}, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -138,7 +138,7 @@ var _ = Describe("Server", func() {
|
|||
closed := make(chan struct{})
|
||||
str.EXPECT().Close().Do(func() { close(closed) })
|
||||
rw := newRequestWriter(utils.DefaultLogger)
|
||||
Expect(rw.WriteRequest(str, req, false)).To(Succeed())
|
||||
Expect(rw.WriteRequest(str, req, false, false)).To(Succeed())
|
||||
Eventually(closed).Should(BeClosed())
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue