refactor HTTP/3 stream handling to use a dedicated stream

Reading from and writing onto this stream applies HTTP/3 DATA framing.
This commit is contained in:
Marten Seemann 2022-05-29 19:22:05 +02:00
parent ccf897e519
commit 04d46526c7
10 changed files with 344 additions and 360 deletions

View file

@ -2,7 +2,6 @@ package http3
import (
"context"
"fmt"
"io"
"net"
@ -29,42 +28,43 @@ type Hijacker interface {
// The body of a http.Request or http.Response.
type body struct {
str quic.Stream
}
var _ io.ReadCloser = &body{}
func newRequestBody(str Stream) *body {
return &body{str: str}
}
func (r *body) Read(b []byte) (int, error) {
return r.str.Read(b)
}
func (r *body) Close() error {
r.str.CancelRead(quic.StreamErrorCode(errorRequestCanceled))
return nil
}
type hijackableBody struct {
body
conn quic.Connection // only needed to implement Hijacker
// only set for the http.Response
// The channel is closed when the user is done with this response:
// either when Read() errors, or when Close() is called.
reqDone chan<- struct{}
reqDoneClosed bool
onFrameError func()
bytesRemainingInFrame uint64
}
var _ io.ReadCloser = &body{}
type hijackableBody struct {
body
conn quic.Connection // only needed to implement Hijacker
}
var _ Hijacker = &hijackableBody{}
func newRequestBody(str quic.Stream, onFrameError func()) *body {
return &body{
str: str,
onFrameError: onFrameError,
}
}
func newResponseBody(str quic.Stream, conn quic.Connection, done chan<- struct{}, onFrameError func()) *hijackableBody {
func newResponseBody(str Stream, conn quic.Connection, done chan<- struct{}) *hijackableBody {
return &hijackableBody{
body: body{
str: str,
onFrameError: onFrameError,
reqDone: done,
str: str,
},
conn: conn,
reqDone: done,
conn: conn,
}
}
@ -72,50 +72,15 @@ func (r *hijackableBody) StreamCreator() StreamCreator {
return r.conn
}
func (r *body) Read(b []byte) (int, error) {
n, err := r.readImpl(b)
func (r *hijackableBody) Read(b []byte) (int, error) {
n, err := r.str.Read(b)
if err != nil {
r.requestDone()
}
return n, err
}
func (r *body) readImpl(b []byte) (int, error) {
if r.bytesRemainingInFrame == 0 {
parseLoop:
for {
frame, err := parseNextFrame(r.str, nil)
if err != nil {
return 0, err
}
switch f := frame.(type) {
case *headersFrame:
// skip HEADERS frames
continue
case *dataFrame:
r.bytesRemainingInFrame = f.Length
break parseLoop
default:
r.onFrameError()
// parseNextFrame skips over unknown frame types
// Therefore, this condition is only entered when we parsed another known frame type.
return 0, fmt.Errorf("peer sent an unexpected frame: %T", f)
}
}
}
var n int
var err error
if r.bytesRemainingInFrame < uint64(len(b)) {
n, err = r.str.Read(b[:r.bytesRemainingInFrame])
} else {
n, err = r.str.Read(b)
}
r.bytesRemainingInFrame -= uint64(n)
return n, err
}
func (r *body) requestDone() {
func (r *hijackableBody) requestDone() {
if r.reqDoneClosed || r.reqDone == nil {
return
}
@ -127,7 +92,7 @@ func (r *body) StreamID() quic.StreamID {
return r.str.StreamID()
}
func (r *body) Close() error {
func (r *hijackableBody) Close() error {
r.requestDone()
// If the EOF was read, CancelRead() is a no-op.
r.str.CancelRead(quic.StreamErrorCode(errorRequestCanceled))

View file

@ -1,189 +1,54 @@
package http3
import (
"bytes"
"fmt"
"io"
"errors"
"github.com/golang/mock/gomock"
"github.com/lucas-clemente/quic-go"
mockquic "github.com/lucas-clemente/quic-go/internal/mocks/quic"
"github.com/golang/mock/gomock"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
type bodyType uint8
var _ = Describe("Response Body", func() {
var reqDone chan struct{}
const (
bodyTypeRequest bodyType = iota
bodyTypeResponse
)
BeforeEach(func() { reqDone = make(chan struct{}) })
func (t bodyType) String() string {
if t == bodyTypeRequest {
return "request"
}
return "response"
}
var _ = Describe("Body", func() {
var (
rb io.ReadCloser
str *mockquic.MockStream
buf *bytes.Buffer
reqDone chan struct{}
errorCbCalled bool
)
errorCb := func() { errorCbCalled = true }
getDataFrame := func(data []byte) []byte {
b := &bytes.Buffer{}
(&dataFrame{Length: uint64(len(data))}).Write(b)
b.Write(data)
return b.Bytes()
}
BeforeEach(func() {
buf = &bytes.Buffer{}
errorCbCalled = false
It("closes the reqDone channel when Read errors", func() {
str := mockquic.NewMockStream(mockCtrl)
str.EXPECT().Read(gomock.Any()).Return(0, errors.New("test error"))
rb := newResponseBody(str, nil, reqDone)
_, err := rb.Read([]byte{0})
Expect(err).To(MatchError("test error"))
Expect(reqDone).To(BeClosed())
})
for _, bt := range []bodyType{bodyTypeRequest, bodyTypeResponse} {
bodyType := bt
It("allows multiple calls to Read, when Read errors", func() {
str := mockquic.NewMockStream(mockCtrl)
str.EXPECT().Read(gomock.Any()).Return(0, errors.New("test error")).Times(2)
rb := newResponseBody(str, nil, reqDone)
_, err := rb.Read([]byte{0})
Expect(err).To(HaveOccurred())
Expect(reqDone).To(BeClosed())
_, err = rb.Read([]byte{0})
Expect(err).To(HaveOccurred())
})
Context(fmt.Sprintf("using a %s body", bodyType), func() {
BeforeEach(func() {
str = mockquic.NewMockStream(mockCtrl)
str.EXPECT().Write(gomock.Any()).DoAndReturn(func(b []byte) (int, error) {
return buf.Write(b)
}).AnyTimes()
str.EXPECT().Read(gomock.Any()).DoAndReturn(func(b []byte) (int, error) {
return buf.Read(b)
}).AnyTimes()
It("closes responses", func() {
str := mockquic.NewMockStream(mockCtrl)
rb := newResponseBody(str, nil, reqDone)
str.EXPECT().CancelRead(quic.StreamErrorCode(errorRequestCanceled))
Expect(rb.Close()).To(Succeed())
})
switch bodyType {
case bodyTypeRequest:
rb = newRequestBody(str, errorCb)
case bodyTypeResponse:
reqDone = make(chan struct{})
rb = newResponseBody(str, nil, reqDone, errorCb)
}
})
It("reads DATA frames in a single run", func() {
buf.Write(getDataFrame([]byte("foobar")))
b := make([]byte, 6)
n, err := rb.Read(b)
Expect(err).ToNot(HaveOccurred())
Expect(n).To(Equal(6))
Expect(b).To(Equal([]byte("foobar")))
})
It("reads DATA frames in multiple runs", func() {
buf.Write(getDataFrame([]byte("foobar")))
b := make([]byte, 3)
n, err := rb.Read(b)
Expect(err).ToNot(HaveOccurred())
Expect(n).To(Equal(3))
Expect(b).To(Equal([]byte("foo")))
n, err = rb.Read(b)
Expect(err).ToNot(HaveOccurred())
Expect(n).To(Equal(3))
Expect(b).To(Equal([]byte("bar")))
})
It("reads DATA frames into too large buffers", func() {
buf.Write(getDataFrame([]byte("foobar")))
b := make([]byte, 10)
n, err := rb.Read(b)
Expect(err).ToNot(HaveOccurred())
Expect(n).To(Equal(6))
Expect(b[:n]).To(Equal([]byte("foobar")))
})
It("reads DATA frames into too large buffers, in multiple runs", func() {
buf.Write(getDataFrame([]byte("foobar")))
b := make([]byte, 4)
n, err := rb.Read(b)
Expect(err).ToNot(HaveOccurred())
Expect(n).To(Equal(4))
Expect(b).To(Equal([]byte("foob")))
n, err = rb.Read(b)
Expect(err).ToNot(HaveOccurred())
Expect(n).To(Equal(2))
Expect(b[:n]).To(Equal([]byte("ar")))
})
It("reads multiple DATA frames", func() {
buf.Write(getDataFrame([]byte("foo")))
buf.Write(getDataFrame([]byte("bar")))
b := make([]byte, 6)
n, err := rb.Read(b)
Expect(err).ToNot(HaveOccurred())
Expect(n).To(Equal(3))
Expect(b[:n]).To(Equal([]byte("foo")))
n, err = rb.Read(b)
Expect(err).ToNot(HaveOccurred())
Expect(n).To(Equal(3))
Expect(b[:n]).To(Equal([]byte("bar")))
})
It("skips HEADERS frames", func() {
buf.Write(getDataFrame([]byte("foo")))
(&headersFrame{Length: 10}).Write(buf)
buf.Write(make([]byte, 10))
buf.Write(getDataFrame([]byte("bar")))
b := make([]byte, 6)
n, err := io.ReadFull(rb, b)
Expect(err).ToNot(HaveOccurred())
Expect(n).To(Equal(6))
Expect(b).To(Equal([]byte("foobar")))
})
It("errors when it can't parse the frame", func() {
buf.Write([]byte("invalid"))
_, err := rb.Read([]byte{0})
Expect(err).To(HaveOccurred())
})
It("errors on unexpected frames, and calls the error callback", func() {
(&settingsFrame{}).Write(buf)
_, err := rb.Read([]byte{0})
Expect(err).To(MatchError("peer sent an unexpected frame: *http3.settingsFrame"))
Expect(errorCbCalled).To(BeTrue())
})
if bodyType == bodyTypeResponse {
It("closes the reqDone channel when Read errors", func() {
buf.Write([]byte("invalid"))
_, err := rb.Read([]byte{0})
Expect(err).To(HaveOccurred())
Expect(reqDone).To(BeClosed())
})
It("allows multiple calls to Read, when Read errors", func() {
buf.Write([]byte("invalid"))
_, err := rb.Read([]byte{0})
Expect(err).To(HaveOccurred())
Expect(reqDone).To(BeClosed())
_, err = rb.Read([]byte{0})
Expect(err).To(HaveOccurred())
})
It("closes responses", func() {
str.EXPECT().CancelRead(quic.StreamErrorCode(errorRequestCanceled))
Expect(rb.Close()).To(Succeed())
})
It("allows multiple calls to Close", func() {
str.EXPECT().CancelRead(quic.StreamErrorCode(errorRequestCanceled)).MaxTimes(2)
Expect(rb.Close()).To(Succeed())
Expect(reqDone).To(BeClosed())
Expect(rb.Close()).To(Succeed())
})
}
})
}
It("allows multiple calls to Close", func() {
str := mockquic.NewMockStream(mockCtrl)
rb := newResponseBody(str, nil, reqDone)
str.EXPECT().CancelRead(quic.StreamErrorCode(errorRequestCanceled)).MaxTimes(2)
Expect(rb.Close()).To(Succeed())
Expect(reqDone).To(BeClosed())
Expect(rb.Close()).To(Succeed())
})
})

View file

@ -298,15 +298,59 @@ func (c *client) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Respon
return rsp, rerr.err
}
func (c *client) sendRequestBody(str Stream, body io.ReadCloser) error {
defer body.Close()
b := make([]byte, bodyCopyBufferSize)
for {
n, rerr := body.Read(b)
if n == 0 {
if rerr == nil {
continue
}
if rerr == io.EOF {
break
}
}
if _, err := str.Write(b[:n]); err != nil {
return err
}
if rerr != nil {
if rerr == io.EOF {
break
}
str.CancelWrite(quic.StreamErrorCode(errorRequestCanceled))
return rerr
}
}
return nil
}
func (c *client) doRequest(req *http.Request, str quic.Stream, opt RoundTripOpt, reqDone chan struct{}) (*http.Response, requestError) {
var requestGzip bool
if !c.opts.DisableCompression && req.Method != "HEAD" && req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" {
requestGzip = true
}
if err := c.requestWriter.WriteRequest(str, req, opt.DontCloseRequestStream, requestGzip); err != nil {
if err := c.requestWriter.WriteRequestHeader(str, req, requestGzip); err != nil {
return nil, newStreamError(errorInternalError, err)
}
if req.Body == nil && !opt.DontCloseRequestStream {
str.Close()
}
hstr := newStream(str, func() { c.conn.CloseWithError(quic.ApplicationErrorCode(errorFrameUnexpected), "") })
if req.Body != nil {
// send the request body asynchronously
go func() {
if err := c.sendRequestBody(hstr, req.Body); err != nil {
c.logger.Errorf("Error writing request: %s", err)
}
if !opt.DontCloseRequestStream {
hstr.Close()
}
}()
}
frame, err := parseNextFrame(str, nil)
if err != nil {
return nil, newStreamError(errorFrameError, err)
@ -348,9 +392,7 @@ func (c *client) doRequest(req *http.Request, str quic.Stream, opt RoundTripOpt,
res.Header.Add(hf.Name, hf.Value)
}
}
respBody := newResponseBody(str, c.conn, reqDone, func() {
c.conn.CloseWithError(quic.ApplicationErrorCode(errorFrameUnexpected), "")
})
respBody := newResponseBody(hstr, c.conn, reqDone)
// Rules for when to set Content-Length are defined in https://tools.ietf.org/html/rfc7230#section-3.3.2.
_, hasTransferEncoding := res.Header["Transfer-Encoding"]

View file

@ -797,6 +797,7 @@ var _ = Describe("Client", func() {
<-done
return 0, errors.New("test done")
})
str.EXPECT().Close()
_, err := client.RoundTripOpt(req, RoundTripOpt{})
Expect(err).To(MatchError("test done"))
})

71
http3/http_stream.go Normal file
View file

@ -0,0 +1,71 @@
package http3
import (
"bytes"
"fmt"
"github.com/lucas-clemente/quic-go"
)
// A Stream is a HTTP/3 stream.
// When writing to and reading from the stream, data is framed in HTTP/3 DATA frames.
type Stream quic.Stream
// The stream conforms to the quic.Stream interface, but instead of writing to and reading directly
// from the QUIC stream, it writes to and reads from the HTTP stream.
type stream struct {
quic.Stream
onFrameError func()
bytesRemainingInFrame uint64
}
var _ Stream = &stream{}
func newStream(str quic.Stream, onFrameError func()) *stream {
return &stream{Stream: str, onFrameError: onFrameError}
}
func (s *stream) Read(b []byte) (int, error) {
if s.bytesRemainingInFrame == 0 {
parseLoop:
for {
frame, err := parseNextFrame(s.Stream, nil)
if err != nil {
return 0, err
}
switch f := frame.(type) {
case *headersFrame:
// skip HEADERS frames
continue
case *dataFrame:
s.bytesRemainingInFrame = f.Length
break parseLoop
default:
s.onFrameError()
// parseNextFrame skips over unknown frame types
// Therefore, this condition is only entered when we parsed another known frame type.
return 0, fmt.Errorf("peer sent an unexpected frame: %T", f)
}
}
}
var n int
var err error
if s.bytesRemainingInFrame < uint64(len(b)) {
n, err = s.Stream.Read(b[:s.bytesRemainingInFrame])
} else {
n, err = s.Stream.Read(b)
}
s.bytesRemainingInFrame -= uint64(n)
return n, err
}
func (s *stream) Write(b []byte) (int, error) {
buf := &bytes.Buffer{}
(&dataFrame{Length: uint64(len(b))}).Write(buf)
if _, err := s.Stream.Write(buf.Bytes()); err != nil {
return 0, err
}
return s.Stream.Write(b)
}

150
http3/http_stream_test.go Normal file
View file

@ -0,0 +1,150 @@
package http3
import (
"bytes"
"io"
mockquic "github.com/lucas-clemente/quic-go/internal/mocks/quic"
"github.com/golang/mock/gomock"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Stream", func() {
Context("reading", func() {
var (
str Stream
qstr *mockquic.MockStream
buf *bytes.Buffer
errorCbCalled bool
)
errorCb := func() { errorCbCalled = true }
getDataFrame := func(data []byte) []byte {
b := &bytes.Buffer{}
(&dataFrame{Length: uint64(len(data))}).Write(b)
b.Write(data)
return b.Bytes()
}
BeforeEach(func() {
buf = &bytes.Buffer{}
errorCbCalled = false
qstr = mockquic.NewMockStream(mockCtrl)
qstr.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write).AnyTimes()
qstr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
str = newStream(qstr, errorCb)
})
It("reads DATA frames in a single run", func() {
buf.Write(getDataFrame([]byte("foobar")))
b := make([]byte, 6)
n, err := str.Read(b)
Expect(err).ToNot(HaveOccurred())
Expect(n).To(Equal(6))
Expect(b).To(Equal([]byte("foobar")))
})
It("reads DATA frames in multiple runs", func() {
buf.Write(getDataFrame([]byte("foobar")))
b := make([]byte, 3)
n, err := str.Read(b)
Expect(err).ToNot(HaveOccurred())
Expect(n).To(Equal(3))
Expect(b).To(Equal([]byte("foo")))
n, err = str.Read(b)
Expect(err).ToNot(HaveOccurred())
Expect(n).To(Equal(3))
Expect(b).To(Equal([]byte("bar")))
})
It("reads DATA frames into too large buffers", func() {
buf.Write(getDataFrame([]byte("foobar")))
b := make([]byte, 10)
n, err := str.Read(b)
Expect(err).ToNot(HaveOccurred())
Expect(n).To(Equal(6))
Expect(b[:n]).To(Equal([]byte("foobar")))
})
It("reads DATA frames into too large buffers, in multiple runs", func() {
buf.Write(getDataFrame([]byte("foobar")))
b := make([]byte, 4)
n, err := str.Read(b)
Expect(err).ToNot(HaveOccurred())
Expect(n).To(Equal(4))
Expect(b).To(Equal([]byte("foob")))
n, err = str.Read(b)
Expect(err).ToNot(HaveOccurred())
Expect(n).To(Equal(2))
Expect(b[:n]).To(Equal([]byte("ar")))
})
It("reads multiple DATA frames", func() {
buf.Write(getDataFrame([]byte("foo")))
buf.Write(getDataFrame([]byte("bar")))
b := make([]byte, 6)
n, err := str.Read(b)
Expect(err).ToNot(HaveOccurred())
Expect(n).To(Equal(3))
Expect(b[:n]).To(Equal([]byte("foo")))
n, err = str.Read(b)
Expect(err).ToNot(HaveOccurred())
Expect(n).To(Equal(3))
Expect(b[:n]).To(Equal([]byte("bar")))
})
It("skips HEADERS frames", func() {
buf.Write(getDataFrame([]byte("foo")))
(&headersFrame{Length: 10}).Write(buf)
buf.Write(make([]byte, 10))
buf.Write(getDataFrame([]byte("bar")))
b := make([]byte, 6)
n, err := io.ReadFull(str, b)
Expect(err).ToNot(HaveOccurred())
Expect(n).To(Equal(6))
Expect(b).To(Equal([]byte("foobar")))
})
It("errors when it can't parse the frame", func() {
buf.Write([]byte("invalid"))
_, err := str.Read([]byte{0})
Expect(err).To(HaveOccurred())
})
It("errors on unexpected frames, and calls the error callback", func() {
(&settingsFrame{}).Write(buf)
_, err := str.Read([]byte{0})
Expect(err).To(MatchError("peer sent an unexpected frame: *http3.settingsFrame"))
Expect(errorCbCalled).To(BeTrue())
})
})
Context("writing", func() {
It("writes data frames", func() {
buf := &bytes.Buffer{}
qstr := mockquic.NewMockStream(mockCtrl)
qstr.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write).AnyTimes()
str := newStream(qstr, nil)
str.Write([]byte("foo"))
str.Write([]byte("foobar"))
f, err := parseNextFrame(buf, nil)
Expect(err).ToNot(HaveOccurred())
Expect(f).To(Equal(&dataFrame{Length: 3}))
b := make([]byte, 3)
_, err = io.ReadFull(buf, b)
Expect(err).ToNot(HaveOccurred())
Expect(b).To(Equal([]byte("foo")))
f, err = parseNextFrame(buf, nil)
Expect(err).ToNot(HaveOccurred())
Expect(f).To(Equal(&dataFrame{Length: 6}))
b = make([]byte, 6)
_, err = io.ReadFull(buf, b)
Expect(err).ToNot(HaveOccurred())
Expect(b).To(Equal([]byte("foobar")))
})
})
})

View file

@ -38,60 +38,14 @@ func newRequestWriter(logger utils.Logger) *requestWriter {
}
}
func (w *requestWriter) WriteRequest(str quic.Stream, req *http.Request, dontCloseStr, gzip bool) error {
func (w *requestWriter) WriteRequestHeader(str quic.Stream, req *http.Request, gzip bool) error {
// TODO: figure out how to add support for trailers
buf := &bytes.Buffer{}
if err := w.writeHeaders(buf, req, gzip); err != nil {
return err
}
if _, err := str.Write(buf.Bytes()); err != nil {
return err
}
// TODO: add support for trailers
if req.Body == nil {
if !dontCloseStr {
str.Close()
}
return nil
}
// send the request body asynchronously
go func() {
defer req.Body.Close()
b := make([]byte, bodyCopyBufferSize)
for {
n, rerr := req.Body.Read(b)
if n == 0 {
if rerr == nil {
continue
} else if rerr == io.EOF {
break
}
}
buf := &bytes.Buffer{}
(&dataFrame{Length: uint64(n)}).Write(buf)
if _, err := str.Write(buf.Bytes()); err != nil {
w.logger.Errorf("Error writing request: %s", err)
return
}
if _, err := str.Write(b[:n]); err != nil {
w.logger.Errorf("Error writing request: %s", err)
return
}
if rerr != nil {
if rerr == io.EOF {
break
}
str.CancelWrite(quic.StreamErrorCode(errorRequestCanceled))
w.logger.Errorf("Error writing request: %s", rerr)
return
}
}
if !dontCloseStr {
str.Close()
}
}()
return nil
_, err := str.Write(buf.Bytes())
return err
}
func (w *requestWriter) writeHeaders(wr io.Writer, req *http.Request, gzip bool) error {

View file

@ -4,7 +4,6 @@ import (
"bytes"
"io"
"net/http"
"strconv"
mockquic "github.com/lucas-clemente/quic-go/internal/mocks/quic"
"github.com/lucas-clemente/quic-go/internal/utils"
@ -16,12 +15,6 @@ import (
. "github.com/onsi/gomega"
)
type foobarReader struct{}
func (r *foobarReader) Read(b []byte) (int, error) {
return copy(b, []byte("foobar")), io.EOF
}
var _ = Describe("Request Writer", func() {
var (
rw *requestWriter
@ -51,16 +44,13 @@ var _ = Describe("Request Writer", func() {
rw = newRequestWriter(utils.DefaultLogger)
strBuf = &bytes.Buffer{}
str = mockquic.NewMockStream(mockCtrl)
str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
return strBuf.Write(p)
}).AnyTimes()
str.EXPECT().Write(gomock.Any()).DoAndReturn(strBuf.Write).AnyTimes()
})
It("writes a GET request", func() {
str.EXPECT().Close()
req, err := http.NewRequest(http.MethodGet, "https://quic.clemente.io/index.html?foo=bar", nil)
Expect(err).ToNot(HaveOccurred())
Expect(rw.WriteRequest(str, req, false, false)).To(Succeed())
Expect(rw.WriteRequestHeader(str, req, false)).To(Succeed())
headerFields := decode(strBuf)
Expect(headerFields).To(HaveKeyWithValue(":authority", "quic.clemente.io"))
Expect(headerFields).To(HaveKeyWithValue(":method", "GET"))
@ -69,55 +59,7 @@ var _ = Describe("Request Writer", func() {
Expect(headerFields).ToNot(HaveKey("accept-encoding"))
})
It("writes a GET request without closing the stream", func() {
req, err := http.NewRequest(http.MethodGet, "https://quic.clemente.io", nil)
Expect(err).ToNot(HaveOccurred())
Expect(rw.WriteRequest(str, req, true, false)).To(Succeed())
headerFields := decode(strBuf)
Expect(headerFields).To(HaveKeyWithValue(":authority", "quic.clemente.io"))
})
It("writes a POST request", func() {
closed := make(chan struct{})
str.EXPECT().Close().Do(func() { close(closed) })
postData := bytes.NewReader([]byte("foobar"))
req, err := http.NewRequest(http.MethodPost, "https://quic.clemente.io/upload.html", postData)
Expect(err).ToNot(HaveOccurred())
Expect(rw.WriteRequest(str, req, false, false)).To(Succeed())
Eventually(closed).Should(BeClosed())
headerFields := decode(strBuf)
Expect(headerFields).To(HaveKeyWithValue(":method", "POST"))
Expect(headerFields).To(HaveKey("content-length"))
contentLength, err := strconv.Atoi(headerFields["content-length"])
Expect(err).ToNot(HaveOccurred())
Expect(contentLength).To(BeNumerically(">", 0))
frame, err := parseNextFrame(strBuf, nil)
Expect(err).ToNot(HaveOccurred())
Expect(frame).To(BeAssignableToTypeOf(&dataFrame{}))
Expect(frame.(*dataFrame).Length).To(BeEquivalentTo(6))
})
It("writes a POST request, if the Body returns an EOF immediately", func() {
closed := make(chan struct{})
str.EXPECT().Close().Do(func() { close(closed) })
req, err := http.NewRequest(http.MethodPost, "https://quic.clemente.io/upload.html", &foobarReader{})
Expect(err).ToNot(HaveOccurred())
Expect(rw.WriteRequest(str, req, false, false)).To(Succeed())
Eventually(closed).Should(BeClosed())
headerFields := decode(strBuf)
Expect(headerFields).To(HaveKeyWithValue(":method", "POST"))
frame, err := parseNextFrame(strBuf, nil)
Expect(err).ToNot(HaveOccurred())
Expect(frame).To(BeAssignableToTypeOf(&dataFrame{}))
Expect(frame.(*dataFrame).Length).To(BeEquivalentTo(6))
})
It("sends cookies", func() {
str.EXPECT().Close()
req, err := http.NewRequest(http.MethodGet, "https://quic.clemente.io/", nil)
Expect(err).ToNot(HaveOccurred())
cookie1 := &http.Cookie{
@ -130,25 +72,23 @@ var _ = Describe("Request Writer", func() {
}
req.AddCookie(cookie1)
req.AddCookie(cookie2)
Expect(rw.WriteRequest(str, req, false, false)).To(Succeed())
Expect(rw.WriteRequestHeader(str, req, false)).To(Succeed())
headerFields := decode(strBuf)
Expect(headerFields).To(HaveKeyWithValue("cookie", `Cookie #1="Value #1"; Cookie #2="Value #2"`))
})
It("adds the header for gzip support", func() {
str.EXPECT().Close()
req, err := http.NewRequest(http.MethodGet, "https://quic.clemente.io/", nil)
Expect(err).ToNot(HaveOccurred())
Expect(rw.WriteRequest(str, req, false, true)).To(Succeed())
Expect(rw.WriteRequestHeader(str, req, true)).To(Succeed())
headerFields := decode(strBuf)
Expect(headerFields).To(HaveKeyWithValue("accept-encoding", "gzip"))
})
It("writes a CONNECT request", func() {
str.EXPECT().Close()
req, err := http.NewRequest(http.MethodConnect, "https://quic.clemente.io/", nil)
Expect(err).ToNot(HaveOccurred())
Expect(rw.WriteRequest(str, req, false, false)).To(Succeed())
Expect(rw.WriteRequestHeader(str, req, false)).To(Succeed())
headerFields := decode(strBuf)
Expect(headerFields).To(HaveKeyWithValue(":method", "CONNECT"))
Expect(headerFields).To(HaveKeyWithValue(":authority", "quic.clemente.io"))
@ -158,11 +98,10 @@ var _ = Describe("Request Writer", func() {
})
It("writes an Extended CONNECT request", func() {
str.EXPECT().Close()
req, err := http.NewRequest(http.MethodConnect, "https://quic.clemente.io/foobar", nil)
Expect(err).ToNot(HaveOccurred())
req.Proto = "webtransport"
Expect(rw.WriteRequest(str, req, false, false)).To(Succeed())
Expect(rw.WriteRequestHeader(str, req, false)).To(Succeed())
headerFields := decode(strBuf)
Expect(headerFields).To(HaveKeyWithValue(":authority", "quic.clemente.io"))
Expect(headerFields).To(HaveKeyWithValue(":method", "CONNECT"))

View file

@ -549,7 +549,7 @@ func (s *Server) handleRequest(conn quic.Connection, str quic.Stream, decoder *q
}
req.RemoteAddr = conn.RemoteAddr().String()
req.Body = newRequestBody(str, onFrameError)
req.Body = newRequestBody(newStream(str, onFrameError))
if s.logger.Debug() {
s.logger.Infof("%s %s%s, on stream %d", req.Method, req.Host, req.RequestURI, str.StreamID())

View file

@ -135,11 +135,8 @@ var _ = Describe("Server", func() {
buf := &bytes.Buffer{}
str := mockquic.NewMockStream(mockCtrl)
str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write).AnyTimes()
closed := make(chan struct{})
str.EXPECT().Close().Do(func() { close(closed) })
rw := newRequestWriter(utils.DefaultLogger)
Expect(rw.WriteRequest(str, req, false, false)).To(Succeed())
Eventually(closed).Should(BeClosed())
Expect(rw.WriteRequestHeader(str, req, false)).To(Succeed())
return buf.Bytes()
}