mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-04 20:57:36 +03:00
fix closing of http.Response and http.Request bodies
This commit is contained in:
parent
2133d01956
commit
39e29d8364
5 changed files with 138 additions and 163 deletions
|
@ -3,11 +3,13 @@ package http3
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
|
|
||||||
|
"github.com/lucas-clemente/quic-go"
|
||||||
)
|
)
|
||||||
|
|
||||||
// The body of a http.Request or http.Response.
|
// The body of a http.Request or http.Response.
|
||||||
type body struct {
|
type body struct {
|
||||||
str io.ReadCloser
|
str quic.Stream
|
||||||
|
|
||||||
isRequest bool
|
isRequest bool
|
||||||
|
|
||||||
|
@ -16,14 +18,14 @@ type body struct {
|
||||||
|
|
||||||
var _ io.ReadCloser = &body{}
|
var _ io.ReadCloser = &body{}
|
||||||
|
|
||||||
func newRequestBody(str io.ReadCloser) *body {
|
func newRequestBody(str quic.Stream) *body {
|
||||||
return &body{
|
return &body{
|
||||||
str: str,
|
str: str,
|
||||||
isRequest: true,
|
isRequest: true,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func newResponseBody(str io.ReadCloser) *body {
|
func newResponseBody(str quic.Stream) *body {
|
||||||
return &body{str: str}
|
return &body{str: str}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -62,7 +64,8 @@ func (r *body) Read(b []byte) (int, error) {
|
||||||
func (r *body) Close() error {
|
func (r *body) Close() error {
|
||||||
// quic.Stream.Close() closes the write side, not the read side
|
// quic.Stream.Close() closes the write side, not the read side
|
||||||
if r.isRequest {
|
if r.isRequest {
|
||||||
return nil
|
return r.str.Close()
|
||||||
}
|
}
|
||||||
return r.str.Close()
|
r.str.CancelRead(quic.ErrorCode(errorRequestCanceled))
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,20 +2,17 @@ package http3
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
|
||||||
|
"github.com/golang/mock/gomock"
|
||||||
|
"github.com/lucas-clemente/quic-go"
|
||||||
|
mockquic "github.com/lucas-clemente/quic-go/internal/mocks/quic"
|
||||||
|
|
||||||
. "github.com/onsi/ginkgo"
|
. "github.com/onsi/ginkgo"
|
||||||
. "github.com/onsi/gomega"
|
. "github.com/onsi/gomega"
|
||||||
)
|
)
|
||||||
|
|
||||||
type closingBuffer struct {
|
|
||||||
*bytes.Buffer
|
|
||||||
|
|
||||||
closed bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *closingBuffer) Close() error { b.closed = true; return nil }
|
|
||||||
|
|
||||||
type bodyType uint8
|
type bodyType uint8
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -23,9 +20,19 @@ const (
|
||||||
bodyTypeResponse
|
bodyTypeResponse
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func (t bodyType) String() string {
|
||||||
|
if t == bodyTypeRequest {
|
||||||
|
return "request"
|
||||||
|
}
|
||||||
|
return "response"
|
||||||
|
}
|
||||||
|
|
||||||
var _ = Describe("Body", func() {
|
var _ = Describe("Body", func() {
|
||||||
var rb *body
|
var (
|
||||||
var buf *bytes.Buffer
|
rb *body
|
||||||
|
str *mockquic.MockStream
|
||||||
|
buf *bytes.Buffer
|
||||||
|
)
|
||||||
|
|
||||||
getDataFrame := func(data []byte) []byte {
|
getDataFrame := func(data []byte) []byte {
|
||||||
b := &bytes.Buffer{}
|
b := &bytes.Buffer{}
|
||||||
|
@ -41,110 +48,119 @@ var _ = Describe("Body", func() {
|
||||||
for _, bt := range []bodyType{bodyTypeRequest, bodyTypeResponse} {
|
for _, bt := range []bodyType{bodyTypeRequest, bodyTypeResponse} {
|
||||||
bodyType := bt
|
bodyType := bt
|
||||||
|
|
||||||
BeforeEach(func() {
|
Context(fmt.Sprintf("using a %s body", bodyType), func() {
|
||||||
cb := &closingBuffer{Buffer: buf}
|
BeforeEach(func() {
|
||||||
switch bodyType {
|
str = mockquic.NewMockStream(mockCtrl)
|
||||||
case bodyTypeRequest:
|
str.EXPECT().Write(gomock.Any()).DoAndReturn(func(b []byte) (int, error) {
|
||||||
rb = newRequestBody(cb)
|
return buf.Write(b)
|
||||||
case bodyTypeResponse:
|
}).AnyTimes()
|
||||||
rb = newResponseBody(cb)
|
str.EXPECT().Read(gomock.Any()).DoAndReturn(func(b []byte) (int, error) {
|
||||||
|
return buf.Read(b)
|
||||||
|
}).AnyTimes()
|
||||||
|
|
||||||
|
switch bodyType {
|
||||||
|
case bodyTypeRequest:
|
||||||
|
rb = newRequestBody(str)
|
||||||
|
case bodyTypeResponse:
|
||||||
|
rb = newResponseBody(str)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
It("reads DATA frames in a single run", func() {
|
||||||
|
buf.Write(getDataFrame([]byte("foobar")))
|
||||||
|
b := make([]byte, 6)
|
||||||
|
n, err := rb.Read(b)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(n).To(Equal(6))
|
||||||
|
Expect(b).To(Equal([]byte("foobar")))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("reads DATA frames in multiple runs", func() {
|
||||||
|
buf.Write(getDataFrame([]byte("foobar")))
|
||||||
|
b := make([]byte, 3)
|
||||||
|
n, err := rb.Read(b)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(n).To(Equal(3))
|
||||||
|
Expect(b).To(Equal([]byte("foo")))
|
||||||
|
n, err = rb.Read(b)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(n).To(Equal(3))
|
||||||
|
Expect(b).To(Equal([]byte("bar")))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("reads DATA frames into too large buffers", func() {
|
||||||
|
buf.Write(getDataFrame([]byte("foobar")))
|
||||||
|
b := make([]byte, 10)
|
||||||
|
n, err := rb.Read(b)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(n).To(Equal(6))
|
||||||
|
Expect(b[:n]).To(Equal([]byte("foobar")))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("reads DATA frames into too large buffers, in multiple runs", func() {
|
||||||
|
buf.Write(getDataFrame([]byte("foobar")))
|
||||||
|
b := make([]byte, 4)
|
||||||
|
n, err := rb.Read(b)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(n).To(Equal(4))
|
||||||
|
Expect(b).To(Equal([]byte("foob")))
|
||||||
|
n, err = rb.Read(b)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(n).To(Equal(2))
|
||||||
|
Expect(b[:n]).To(Equal([]byte("ar")))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("reads multiple DATA frames", func() {
|
||||||
|
buf.Write(getDataFrame([]byte("foo")))
|
||||||
|
buf.Write(getDataFrame([]byte("bar")))
|
||||||
|
b := make([]byte, 6)
|
||||||
|
n, err := rb.Read(b)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(n).To(Equal(3))
|
||||||
|
Expect(b[:n]).To(Equal([]byte("foo")))
|
||||||
|
n, err = rb.Read(b)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(n).To(Equal(3))
|
||||||
|
Expect(b[:n]).To(Equal([]byte("bar")))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("skips HEADERS frames", func() {
|
||||||
|
buf.Write(getDataFrame([]byte("foo")))
|
||||||
|
(&headersFrame{Length: 10}).Write(buf)
|
||||||
|
buf.Write(make([]byte, 10))
|
||||||
|
buf.Write(getDataFrame([]byte("bar")))
|
||||||
|
b := make([]byte, 6)
|
||||||
|
n, err := io.ReadFull(rb, b)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(n).To(Equal(6))
|
||||||
|
Expect(b).To(Equal([]byte("foobar")))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("errors when it can't parse the frame", func() {
|
||||||
|
buf.Write([]byte("invalid"))
|
||||||
|
_, err := rb.Read([]byte{0})
|
||||||
|
Expect(err).To(HaveOccurred())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("errors on unexpected frames", func() {
|
||||||
|
(&settingsFrame{}).Write(buf)
|
||||||
|
_, err := rb.Read([]byte{0})
|
||||||
|
Expect(err).To(MatchError("unexpected frame"))
|
||||||
|
})
|
||||||
|
|
||||||
|
if bodyType == bodyTypeRequest {
|
||||||
|
It("closes requests", func() {
|
||||||
|
str.EXPECT().Close()
|
||||||
|
Expect(rb.Close()).To(Succeed())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
if bodyType == bodyTypeResponse {
|
||||||
|
It("closes responses", func() {
|
||||||
|
str.EXPECT().CancelRead(quic.ErrorCode(errorRequestCanceled))
|
||||||
|
Expect(rb.Close()).To(Succeed())
|
||||||
|
})
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
It("reads DATA frames in a single run", func() {
|
|
||||||
buf.Write(getDataFrame([]byte("foobar")))
|
|
||||||
b := make([]byte, 6)
|
|
||||||
n, err := rb.Read(b)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(n).To(Equal(6))
|
|
||||||
Expect(b).To(Equal([]byte("foobar")))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("reads DATA frames in multiple runs", func() {
|
|
||||||
buf.Write(getDataFrame([]byte("foobar")))
|
|
||||||
b := make([]byte, 3)
|
|
||||||
n, err := rb.Read(b)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(n).To(Equal(3))
|
|
||||||
Expect(b).To(Equal([]byte("foo")))
|
|
||||||
n, err = rb.Read(b)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(n).To(Equal(3))
|
|
||||||
Expect(b).To(Equal([]byte("bar")))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("reads DATA frames into too large buffers", func() {
|
|
||||||
buf.Write(getDataFrame([]byte("foobar")))
|
|
||||||
b := make([]byte, 10)
|
|
||||||
n, err := rb.Read(b)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(n).To(Equal(6))
|
|
||||||
Expect(b[:n]).To(Equal([]byte("foobar")))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("reads DATA frames into too large buffers, in multiple runs", func() {
|
|
||||||
buf.Write(getDataFrame([]byte("foobar")))
|
|
||||||
b := make([]byte, 4)
|
|
||||||
n, err := rb.Read(b)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(n).To(Equal(4))
|
|
||||||
Expect(b).To(Equal([]byte("foob")))
|
|
||||||
n, err = rb.Read(b)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(n).To(Equal(2))
|
|
||||||
Expect(b[:n]).To(Equal([]byte("ar")))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("reads multiple DATA frames", func() {
|
|
||||||
buf.Write(getDataFrame([]byte("foo")))
|
|
||||||
buf.Write(getDataFrame([]byte("bar")))
|
|
||||||
b := make([]byte, 6)
|
|
||||||
n, err := rb.Read(b)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(n).To(Equal(3))
|
|
||||||
Expect(b[:n]).To(Equal([]byte("foo")))
|
|
||||||
n, err = rb.Read(b)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(n).To(Equal(3))
|
|
||||||
Expect(b[:n]).To(Equal([]byte("bar")))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("skips HEADERS frames", func() {
|
|
||||||
buf.Write(getDataFrame([]byte("foo")))
|
|
||||||
(&headersFrame{Length: 10}).Write(buf)
|
|
||||||
buf.Write(make([]byte, 10))
|
|
||||||
buf.Write(getDataFrame([]byte("bar")))
|
|
||||||
b := make([]byte, 6)
|
|
||||||
n, err := io.ReadFull(rb, b)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(n).To(Equal(6))
|
|
||||||
Expect(b).To(Equal([]byte("foobar")))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("errors when it can't parse the frame", func() {
|
|
||||||
buf.Write([]byte("invalid"))
|
|
||||||
_, err := rb.Read([]byte{0})
|
|
||||||
Expect(err).To(HaveOccurred())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("errors on unexpected frames", func() {
|
|
||||||
(&settingsFrame{}).Write(buf)
|
|
||||||
_, err := rb.Read([]byte{0})
|
|
||||||
Expect(err).To(MatchError("unexpected frame"))
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
It("closes requests", func() {
|
|
||||||
cb := &closingBuffer{Buffer: buf}
|
|
||||||
rb := newRequestBody(cb)
|
|
||||||
Expect(rb.Close()).To(Succeed())
|
|
||||||
Expect(cb.closed).To(BeFalse())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("closes responses", func() {
|
|
||||||
cb := &closingBuffer{Buffer: buf}
|
|
||||||
rb := newResponseBody(cb)
|
|
||||||
Expect(rb.Close()).To(Succeed())
|
|
||||||
Expect(cb.closed).To(BeTrue())
|
|
||||||
})
|
|
||||||
})
|
})
|
||||||
|
|
|
@ -187,7 +187,7 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
res.Header.Add(hf.Name, hf.Value)
|
res.Header.Add(hf.Name, hf.Value)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
respBody := newResponseBody(&responseBody{str})
|
respBody := newResponseBody(str)
|
||||||
if requestGzip && res.Header.Get("Content-Encoding") == "gzip" {
|
if requestGzip && res.Header.Get("Content-Encoding") == "gzip" {
|
||||||
res.Header.Del("Content-Encoding")
|
res.Header.Del("Content-Encoding")
|
||||||
res.Header.Del("Content-Length")
|
res.Header.Del("Content-Length")
|
||||||
|
|
|
@ -1,18 +0,0 @@
|
||||||
package http3
|
|
||||||
|
|
||||||
import (
|
|
||||||
"io"
|
|
||||||
|
|
||||||
quic "github.com/lucas-clemente/quic-go"
|
|
||||||
)
|
|
||||||
|
|
||||||
type responseBody struct {
|
|
||||||
quic.Stream
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ io.ReadCloser = &responseBody{}
|
|
||||||
|
|
||||||
func (rb *responseBody) Close() error {
|
|
||||||
rb.Stream.CancelRead(0)
|
|
||||||
return nil
|
|
||||||
}
|
|
|
@ -1,26 +0,0 @@
|
||||||
package http3
|
|
||||||
|
|
||||||
import (
|
|
||||||
"github.com/golang/mock/gomock"
|
|
||||||
mockquic "github.com/lucas-clemente/quic-go/internal/mocks/quic"
|
|
||||||
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
)
|
|
||||||
|
|
||||||
var _ = Describe("Response Body", func() {
|
|
||||||
var (
|
|
||||||
stream *mockquic.MockStream
|
|
||||||
body *responseBody
|
|
||||||
)
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
stream = mockquic.NewMockStream(mockCtrl)
|
|
||||||
body = &responseBody{stream}
|
|
||||||
})
|
|
||||||
|
|
||||||
It("calls CancelRead when closing", func() {
|
|
||||||
stream.EXPECT().CancelRead(gomock.Any())
|
|
||||||
Expect(body.Close()).To(Succeed())
|
|
||||||
})
|
|
||||||
})
|
|
Loading…
Add table
Add a link
Reference in a new issue