mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-03 20:27:35 +03:00
refactor HTTP/3 stream handling to use a dedicated stream
Reading from and writing onto this stream applies HTTP/3 DATA framing.
This commit is contained in:
parent
ccf897e519
commit
04d46526c7
10 changed files with 344 additions and 360 deletions
|
@ -2,7 +2,6 @@ package http3
|
|||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
|
||||
|
@ -29,42 +28,43 @@ type Hijacker interface {
|
|||
// The body of a http.Request or http.Response.
|
||||
type body struct {
|
||||
str quic.Stream
|
||||
}
|
||||
|
||||
var _ io.ReadCloser = &body{}
|
||||
|
||||
func newRequestBody(str Stream) *body {
|
||||
return &body{str: str}
|
||||
}
|
||||
|
||||
func (r *body) Read(b []byte) (int, error) {
|
||||
return r.str.Read(b)
|
||||
}
|
||||
|
||||
func (r *body) Close() error {
|
||||
r.str.CancelRead(quic.StreamErrorCode(errorRequestCanceled))
|
||||
return nil
|
||||
}
|
||||
|
||||
type hijackableBody struct {
|
||||
body
|
||||
conn quic.Connection // only needed to implement Hijacker
|
||||
|
||||
// only set for the http.Response
|
||||
// The channel is closed when the user is done with this response:
|
||||
// either when Read() errors, or when Close() is called.
|
||||
reqDone chan<- struct{}
|
||||
reqDoneClosed bool
|
||||
|
||||
onFrameError func()
|
||||
|
||||
bytesRemainingInFrame uint64
|
||||
}
|
||||
|
||||
var _ io.ReadCloser = &body{}
|
||||
|
||||
type hijackableBody struct {
|
||||
body
|
||||
conn quic.Connection // only needed to implement Hijacker
|
||||
}
|
||||
|
||||
var _ Hijacker = &hijackableBody{}
|
||||
|
||||
func newRequestBody(str quic.Stream, onFrameError func()) *body {
|
||||
return &body{
|
||||
str: str,
|
||||
onFrameError: onFrameError,
|
||||
}
|
||||
}
|
||||
|
||||
func newResponseBody(str quic.Stream, conn quic.Connection, done chan<- struct{}, onFrameError func()) *hijackableBody {
|
||||
func newResponseBody(str Stream, conn quic.Connection, done chan<- struct{}) *hijackableBody {
|
||||
return &hijackableBody{
|
||||
body: body{
|
||||
str: str,
|
||||
onFrameError: onFrameError,
|
||||
reqDone: done,
|
||||
str: str,
|
||||
},
|
||||
conn: conn,
|
||||
reqDone: done,
|
||||
conn: conn,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -72,50 +72,15 @@ func (r *hijackableBody) StreamCreator() StreamCreator {
|
|||
return r.conn
|
||||
}
|
||||
|
||||
func (r *body) Read(b []byte) (int, error) {
|
||||
n, err := r.readImpl(b)
|
||||
func (r *hijackableBody) Read(b []byte) (int, error) {
|
||||
n, err := r.str.Read(b)
|
||||
if err != nil {
|
||||
r.requestDone()
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (r *body) readImpl(b []byte) (int, error) {
|
||||
if r.bytesRemainingInFrame == 0 {
|
||||
parseLoop:
|
||||
for {
|
||||
frame, err := parseNextFrame(r.str, nil)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
switch f := frame.(type) {
|
||||
case *headersFrame:
|
||||
// skip HEADERS frames
|
||||
continue
|
||||
case *dataFrame:
|
||||
r.bytesRemainingInFrame = f.Length
|
||||
break parseLoop
|
||||
default:
|
||||
r.onFrameError()
|
||||
// parseNextFrame skips over unknown frame types
|
||||
// Therefore, this condition is only entered when we parsed another known frame type.
|
||||
return 0, fmt.Errorf("peer sent an unexpected frame: %T", f)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var n int
|
||||
var err error
|
||||
if r.bytesRemainingInFrame < uint64(len(b)) {
|
||||
n, err = r.str.Read(b[:r.bytesRemainingInFrame])
|
||||
} else {
|
||||
n, err = r.str.Read(b)
|
||||
}
|
||||
r.bytesRemainingInFrame -= uint64(n)
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (r *body) requestDone() {
|
||||
func (r *hijackableBody) requestDone() {
|
||||
if r.reqDoneClosed || r.reqDone == nil {
|
||||
return
|
||||
}
|
||||
|
@ -127,7 +92,7 @@ func (r *body) StreamID() quic.StreamID {
|
|||
return r.str.StreamID()
|
||||
}
|
||||
|
||||
func (r *body) Close() error {
|
||||
func (r *hijackableBody) Close() error {
|
||||
r.requestDone()
|
||||
// If the EOF was read, CancelRead() is a no-op.
|
||||
r.str.CancelRead(quic.StreamErrorCode(errorRequestCanceled))
|
||||
|
|
|
@ -1,189 +1,54 @@
|
|||
package http3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"errors"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/lucas-clemente/quic-go"
|
||||
mockquic "github.com/lucas-clemente/quic-go/internal/mocks/quic"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
type bodyType uint8
|
||||
var _ = Describe("Response Body", func() {
|
||||
var reqDone chan struct{}
|
||||
|
||||
const (
|
||||
bodyTypeRequest bodyType = iota
|
||||
bodyTypeResponse
|
||||
)
|
||||
BeforeEach(func() { reqDone = make(chan struct{}) })
|
||||
|
||||
func (t bodyType) String() string {
|
||||
if t == bodyTypeRequest {
|
||||
return "request"
|
||||
}
|
||||
return "response"
|
||||
}
|
||||
|
||||
var _ = Describe("Body", func() {
|
||||
var (
|
||||
rb io.ReadCloser
|
||||
str *mockquic.MockStream
|
||||
buf *bytes.Buffer
|
||||
reqDone chan struct{}
|
||||
errorCbCalled bool
|
||||
)
|
||||
|
||||
errorCb := func() { errorCbCalled = true }
|
||||
|
||||
getDataFrame := func(data []byte) []byte {
|
||||
b := &bytes.Buffer{}
|
||||
(&dataFrame{Length: uint64(len(data))}).Write(b)
|
||||
b.Write(data)
|
||||
return b.Bytes()
|
||||
}
|
||||
|
||||
BeforeEach(func() {
|
||||
buf = &bytes.Buffer{}
|
||||
errorCbCalled = false
|
||||
It("closes the reqDone channel when Read errors", func() {
|
||||
str := mockquic.NewMockStream(mockCtrl)
|
||||
str.EXPECT().Read(gomock.Any()).Return(0, errors.New("test error"))
|
||||
rb := newResponseBody(str, nil, reqDone)
|
||||
_, err := rb.Read([]byte{0})
|
||||
Expect(err).To(MatchError("test error"))
|
||||
Expect(reqDone).To(BeClosed())
|
||||
})
|
||||
|
||||
for _, bt := range []bodyType{bodyTypeRequest, bodyTypeResponse} {
|
||||
bodyType := bt
|
||||
It("allows multiple calls to Read, when Read errors", func() {
|
||||
str := mockquic.NewMockStream(mockCtrl)
|
||||
str.EXPECT().Read(gomock.Any()).Return(0, errors.New("test error")).Times(2)
|
||||
rb := newResponseBody(str, nil, reqDone)
|
||||
_, err := rb.Read([]byte{0})
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(reqDone).To(BeClosed())
|
||||
_, err = rb.Read([]byte{0})
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
|
||||
Context(fmt.Sprintf("using a %s body", bodyType), func() {
|
||||
BeforeEach(func() {
|
||||
str = mockquic.NewMockStream(mockCtrl)
|
||||
str.EXPECT().Write(gomock.Any()).DoAndReturn(func(b []byte) (int, error) {
|
||||
return buf.Write(b)
|
||||
}).AnyTimes()
|
||||
str.EXPECT().Read(gomock.Any()).DoAndReturn(func(b []byte) (int, error) {
|
||||
return buf.Read(b)
|
||||
}).AnyTimes()
|
||||
It("closes responses", func() {
|
||||
str := mockquic.NewMockStream(mockCtrl)
|
||||
rb := newResponseBody(str, nil, reqDone)
|
||||
str.EXPECT().CancelRead(quic.StreamErrorCode(errorRequestCanceled))
|
||||
Expect(rb.Close()).To(Succeed())
|
||||
})
|
||||
|
||||
switch bodyType {
|
||||
case bodyTypeRequest:
|
||||
rb = newRequestBody(str, errorCb)
|
||||
case bodyTypeResponse:
|
||||
reqDone = make(chan struct{})
|
||||
rb = newResponseBody(str, nil, reqDone, errorCb)
|
||||
}
|
||||
})
|
||||
|
||||
It("reads DATA frames in a single run", func() {
|
||||
buf.Write(getDataFrame([]byte("foobar")))
|
||||
b := make([]byte, 6)
|
||||
n, err := rb.Read(b)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(n).To(Equal(6))
|
||||
Expect(b).To(Equal([]byte("foobar")))
|
||||
})
|
||||
|
||||
It("reads DATA frames in multiple runs", func() {
|
||||
buf.Write(getDataFrame([]byte("foobar")))
|
||||
b := make([]byte, 3)
|
||||
n, err := rb.Read(b)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(n).To(Equal(3))
|
||||
Expect(b).To(Equal([]byte("foo")))
|
||||
n, err = rb.Read(b)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(n).To(Equal(3))
|
||||
Expect(b).To(Equal([]byte("bar")))
|
||||
})
|
||||
|
||||
It("reads DATA frames into too large buffers", func() {
|
||||
buf.Write(getDataFrame([]byte("foobar")))
|
||||
b := make([]byte, 10)
|
||||
n, err := rb.Read(b)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(n).To(Equal(6))
|
||||
Expect(b[:n]).To(Equal([]byte("foobar")))
|
||||
})
|
||||
|
||||
It("reads DATA frames into too large buffers, in multiple runs", func() {
|
||||
buf.Write(getDataFrame([]byte("foobar")))
|
||||
b := make([]byte, 4)
|
||||
n, err := rb.Read(b)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(n).To(Equal(4))
|
||||
Expect(b).To(Equal([]byte("foob")))
|
||||
n, err = rb.Read(b)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(n).To(Equal(2))
|
||||
Expect(b[:n]).To(Equal([]byte("ar")))
|
||||
})
|
||||
|
||||
It("reads multiple DATA frames", func() {
|
||||
buf.Write(getDataFrame([]byte("foo")))
|
||||
buf.Write(getDataFrame([]byte("bar")))
|
||||
b := make([]byte, 6)
|
||||
n, err := rb.Read(b)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(n).To(Equal(3))
|
||||
Expect(b[:n]).To(Equal([]byte("foo")))
|
||||
n, err = rb.Read(b)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(n).To(Equal(3))
|
||||
Expect(b[:n]).To(Equal([]byte("bar")))
|
||||
})
|
||||
|
||||
It("skips HEADERS frames", func() {
|
||||
buf.Write(getDataFrame([]byte("foo")))
|
||||
(&headersFrame{Length: 10}).Write(buf)
|
||||
buf.Write(make([]byte, 10))
|
||||
buf.Write(getDataFrame([]byte("bar")))
|
||||
b := make([]byte, 6)
|
||||
n, err := io.ReadFull(rb, b)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(n).To(Equal(6))
|
||||
Expect(b).To(Equal([]byte("foobar")))
|
||||
})
|
||||
|
||||
It("errors when it can't parse the frame", func() {
|
||||
buf.Write([]byte("invalid"))
|
||||
_, err := rb.Read([]byte{0})
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
|
||||
It("errors on unexpected frames, and calls the error callback", func() {
|
||||
(&settingsFrame{}).Write(buf)
|
||||
_, err := rb.Read([]byte{0})
|
||||
Expect(err).To(MatchError("peer sent an unexpected frame: *http3.settingsFrame"))
|
||||
Expect(errorCbCalled).To(BeTrue())
|
||||
})
|
||||
|
||||
if bodyType == bodyTypeResponse {
|
||||
It("closes the reqDone channel when Read errors", func() {
|
||||
buf.Write([]byte("invalid"))
|
||||
_, err := rb.Read([]byte{0})
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(reqDone).To(BeClosed())
|
||||
})
|
||||
|
||||
It("allows multiple calls to Read, when Read errors", func() {
|
||||
buf.Write([]byte("invalid"))
|
||||
_, err := rb.Read([]byte{0})
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(reqDone).To(BeClosed())
|
||||
_, err = rb.Read([]byte{0})
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
|
||||
It("closes responses", func() {
|
||||
str.EXPECT().CancelRead(quic.StreamErrorCode(errorRequestCanceled))
|
||||
Expect(rb.Close()).To(Succeed())
|
||||
})
|
||||
|
||||
It("allows multiple calls to Close", func() {
|
||||
str.EXPECT().CancelRead(quic.StreamErrorCode(errorRequestCanceled)).MaxTimes(2)
|
||||
Expect(rb.Close()).To(Succeed())
|
||||
Expect(reqDone).To(BeClosed())
|
||||
Expect(rb.Close()).To(Succeed())
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
It("allows multiple calls to Close", func() {
|
||||
str := mockquic.NewMockStream(mockCtrl)
|
||||
rb := newResponseBody(str, nil, reqDone)
|
||||
str.EXPECT().CancelRead(quic.StreamErrorCode(errorRequestCanceled)).MaxTimes(2)
|
||||
Expect(rb.Close()).To(Succeed())
|
||||
Expect(reqDone).To(BeClosed())
|
||||
Expect(rb.Close()).To(Succeed())
|
||||
})
|
||||
})
|
||||
|
|
|
@ -298,15 +298,59 @@ func (c *client) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Respon
|
|||
return rsp, rerr.err
|
||||
}
|
||||
|
||||
func (c *client) sendRequestBody(str Stream, body io.ReadCloser) error {
|
||||
defer body.Close()
|
||||
b := make([]byte, bodyCopyBufferSize)
|
||||
for {
|
||||
n, rerr := body.Read(b)
|
||||
if n == 0 {
|
||||
if rerr == nil {
|
||||
continue
|
||||
}
|
||||
if rerr == io.EOF {
|
||||
break
|
||||
}
|
||||
}
|
||||
if _, err := str.Write(b[:n]); err != nil {
|
||||
return err
|
||||
}
|
||||
if rerr != nil {
|
||||
if rerr == io.EOF {
|
||||
break
|
||||
}
|
||||
str.CancelWrite(quic.StreamErrorCode(errorRequestCanceled))
|
||||
return rerr
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *client) doRequest(req *http.Request, str quic.Stream, opt RoundTripOpt, reqDone chan struct{}) (*http.Response, requestError) {
|
||||
var requestGzip bool
|
||||
if !c.opts.DisableCompression && req.Method != "HEAD" && req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" {
|
||||
requestGzip = true
|
||||
}
|
||||
if err := c.requestWriter.WriteRequest(str, req, opt.DontCloseRequestStream, requestGzip); err != nil {
|
||||
if err := c.requestWriter.WriteRequestHeader(str, req, requestGzip); err != nil {
|
||||
return nil, newStreamError(errorInternalError, err)
|
||||
}
|
||||
|
||||
if req.Body == nil && !opt.DontCloseRequestStream {
|
||||
str.Close()
|
||||
}
|
||||
|
||||
hstr := newStream(str, func() { c.conn.CloseWithError(quic.ApplicationErrorCode(errorFrameUnexpected), "") })
|
||||
if req.Body != nil {
|
||||
// send the request body asynchronously
|
||||
go func() {
|
||||
if err := c.sendRequestBody(hstr, req.Body); err != nil {
|
||||
c.logger.Errorf("Error writing request: %s", err)
|
||||
}
|
||||
if !opt.DontCloseRequestStream {
|
||||
hstr.Close()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
frame, err := parseNextFrame(str, nil)
|
||||
if err != nil {
|
||||
return nil, newStreamError(errorFrameError, err)
|
||||
|
@ -348,9 +392,7 @@ func (c *client) doRequest(req *http.Request, str quic.Stream, opt RoundTripOpt,
|
|||
res.Header.Add(hf.Name, hf.Value)
|
||||
}
|
||||
}
|
||||
respBody := newResponseBody(str, c.conn, reqDone, func() {
|
||||
c.conn.CloseWithError(quic.ApplicationErrorCode(errorFrameUnexpected), "")
|
||||
})
|
||||
respBody := newResponseBody(hstr, c.conn, reqDone)
|
||||
|
||||
// Rules for when to set Content-Length are defined in https://tools.ietf.org/html/rfc7230#section-3.3.2.
|
||||
_, hasTransferEncoding := res.Header["Transfer-Encoding"]
|
||||
|
|
|
@ -797,6 +797,7 @@ var _ = Describe("Client", func() {
|
|||
<-done
|
||||
return 0, errors.New("test done")
|
||||
})
|
||||
str.EXPECT().Close()
|
||||
_, err := client.RoundTripOpt(req, RoundTripOpt{})
|
||||
Expect(err).To(MatchError("test done"))
|
||||
})
|
||||
|
|
71
http3/http_stream.go
Normal file
71
http3/http_stream.go
Normal file
|
@ -0,0 +1,71 @@
|
|||
package http3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
|
||||
"github.com/lucas-clemente/quic-go"
|
||||
)
|
||||
|
||||
// A Stream is a HTTP/3 stream.
|
||||
// When writing to and reading from the stream, data is framed in HTTP/3 DATA frames.
|
||||
type Stream quic.Stream
|
||||
|
||||
// The stream conforms to the quic.Stream interface, but instead of writing to and reading directly
|
||||
// from the QUIC stream, it writes to and reads from the HTTP stream.
|
||||
type stream struct {
|
||||
quic.Stream
|
||||
|
||||
onFrameError func()
|
||||
bytesRemainingInFrame uint64
|
||||
}
|
||||
|
||||
var _ Stream = &stream{}
|
||||
|
||||
func newStream(str quic.Stream, onFrameError func()) *stream {
|
||||
return &stream{Stream: str, onFrameError: onFrameError}
|
||||
}
|
||||
|
||||
func (s *stream) Read(b []byte) (int, error) {
|
||||
if s.bytesRemainingInFrame == 0 {
|
||||
parseLoop:
|
||||
for {
|
||||
frame, err := parseNextFrame(s.Stream, nil)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
switch f := frame.(type) {
|
||||
case *headersFrame:
|
||||
// skip HEADERS frames
|
||||
continue
|
||||
case *dataFrame:
|
||||
s.bytesRemainingInFrame = f.Length
|
||||
break parseLoop
|
||||
default:
|
||||
s.onFrameError()
|
||||
// parseNextFrame skips over unknown frame types
|
||||
// Therefore, this condition is only entered when we parsed another known frame type.
|
||||
return 0, fmt.Errorf("peer sent an unexpected frame: %T", f)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var n int
|
||||
var err error
|
||||
if s.bytesRemainingInFrame < uint64(len(b)) {
|
||||
n, err = s.Stream.Read(b[:s.bytesRemainingInFrame])
|
||||
} else {
|
||||
n, err = s.Stream.Read(b)
|
||||
}
|
||||
s.bytesRemainingInFrame -= uint64(n)
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (s *stream) Write(b []byte) (int, error) {
|
||||
buf := &bytes.Buffer{}
|
||||
(&dataFrame{Length: uint64(len(b))}).Write(buf)
|
||||
if _, err := s.Stream.Write(buf.Bytes()); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return s.Stream.Write(b)
|
||||
}
|
150
http3/http_stream_test.go
Normal file
150
http3/http_stream_test.go
Normal file
|
@ -0,0 +1,150 @@
|
|||
package http3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
|
||||
mockquic "github.com/lucas-clemente/quic-go/internal/mocks/quic"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Stream", func() {
|
||||
Context("reading", func() {
|
||||
var (
|
||||
str Stream
|
||||
qstr *mockquic.MockStream
|
||||
buf *bytes.Buffer
|
||||
errorCbCalled bool
|
||||
)
|
||||
|
||||
errorCb := func() { errorCbCalled = true }
|
||||
getDataFrame := func(data []byte) []byte {
|
||||
b := &bytes.Buffer{}
|
||||
(&dataFrame{Length: uint64(len(data))}).Write(b)
|
||||
b.Write(data)
|
||||
return b.Bytes()
|
||||
}
|
||||
|
||||
BeforeEach(func() {
|
||||
buf = &bytes.Buffer{}
|
||||
errorCbCalled = false
|
||||
qstr = mockquic.NewMockStream(mockCtrl)
|
||||
qstr.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write).AnyTimes()
|
||||
qstr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
|
||||
str = newStream(qstr, errorCb)
|
||||
})
|
||||
|
||||
It("reads DATA frames in a single run", func() {
|
||||
buf.Write(getDataFrame([]byte("foobar")))
|
||||
b := make([]byte, 6)
|
||||
n, err := str.Read(b)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(n).To(Equal(6))
|
||||
Expect(b).To(Equal([]byte("foobar")))
|
||||
})
|
||||
|
||||
It("reads DATA frames in multiple runs", func() {
|
||||
buf.Write(getDataFrame([]byte("foobar")))
|
||||
b := make([]byte, 3)
|
||||
n, err := str.Read(b)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(n).To(Equal(3))
|
||||
Expect(b).To(Equal([]byte("foo")))
|
||||
n, err = str.Read(b)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(n).To(Equal(3))
|
||||
Expect(b).To(Equal([]byte("bar")))
|
||||
})
|
||||
|
||||
It("reads DATA frames into too large buffers", func() {
|
||||
buf.Write(getDataFrame([]byte("foobar")))
|
||||
b := make([]byte, 10)
|
||||
n, err := str.Read(b)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(n).To(Equal(6))
|
||||
Expect(b[:n]).To(Equal([]byte("foobar")))
|
||||
})
|
||||
|
||||
It("reads DATA frames into too large buffers, in multiple runs", func() {
|
||||
buf.Write(getDataFrame([]byte("foobar")))
|
||||
b := make([]byte, 4)
|
||||
n, err := str.Read(b)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(n).To(Equal(4))
|
||||
Expect(b).To(Equal([]byte("foob")))
|
||||
n, err = str.Read(b)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(n).To(Equal(2))
|
||||
Expect(b[:n]).To(Equal([]byte("ar")))
|
||||
})
|
||||
|
||||
It("reads multiple DATA frames", func() {
|
||||
buf.Write(getDataFrame([]byte("foo")))
|
||||
buf.Write(getDataFrame([]byte("bar")))
|
||||
b := make([]byte, 6)
|
||||
n, err := str.Read(b)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(n).To(Equal(3))
|
||||
Expect(b[:n]).To(Equal([]byte("foo")))
|
||||
n, err = str.Read(b)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(n).To(Equal(3))
|
||||
Expect(b[:n]).To(Equal([]byte("bar")))
|
||||
})
|
||||
|
||||
It("skips HEADERS frames", func() {
|
||||
buf.Write(getDataFrame([]byte("foo")))
|
||||
(&headersFrame{Length: 10}).Write(buf)
|
||||
buf.Write(make([]byte, 10))
|
||||
buf.Write(getDataFrame([]byte("bar")))
|
||||
b := make([]byte, 6)
|
||||
n, err := io.ReadFull(str, b)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(n).To(Equal(6))
|
||||
Expect(b).To(Equal([]byte("foobar")))
|
||||
})
|
||||
|
||||
It("errors when it can't parse the frame", func() {
|
||||
buf.Write([]byte("invalid"))
|
||||
_, err := str.Read([]byte{0})
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
|
||||
It("errors on unexpected frames, and calls the error callback", func() {
|
||||
(&settingsFrame{}).Write(buf)
|
||||
_, err := str.Read([]byte{0})
|
||||
Expect(err).To(MatchError("peer sent an unexpected frame: *http3.settingsFrame"))
|
||||
Expect(errorCbCalled).To(BeTrue())
|
||||
})
|
||||
})
|
||||
|
||||
Context("writing", func() {
|
||||
It("writes data frames", func() {
|
||||
buf := &bytes.Buffer{}
|
||||
qstr := mockquic.NewMockStream(mockCtrl)
|
||||
qstr.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write).AnyTimes()
|
||||
str := newStream(qstr, nil)
|
||||
str.Write([]byte("foo"))
|
||||
str.Write([]byte("foobar"))
|
||||
|
||||
f, err := parseNextFrame(buf, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(f).To(Equal(&dataFrame{Length: 3}))
|
||||
b := make([]byte, 3)
|
||||
_, err = io.ReadFull(buf, b)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(b).To(Equal([]byte("foo")))
|
||||
|
||||
f, err = parseNextFrame(buf, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(f).To(Equal(&dataFrame{Length: 6}))
|
||||
b = make([]byte, 6)
|
||||
_, err = io.ReadFull(buf, b)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(b).To(Equal([]byte("foobar")))
|
||||
})
|
||||
})
|
||||
})
|
|
@ -38,60 +38,14 @@ func newRequestWriter(logger utils.Logger) *requestWriter {
|
|||
}
|
||||
}
|
||||
|
||||
func (w *requestWriter) WriteRequest(str quic.Stream, req *http.Request, dontCloseStr, gzip bool) error {
|
||||
func (w *requestWriter) WriteRequestHeader(str quic.Stream, req *http.Request, gzip bool) error {
|
||||
// TODO: figure out how to add support for trailers
|
||||
buf := &bytes.Buffer{}
|
||||
if err := w.writeHeaders(buf, req, gzip); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := str.Write(buf.Bytes()); err != nil {
|
||||
return err
|
||||
}
|
||||
// TODO: add support for trailers
|
||||
if req.Body == nil {
|
||||
if !dontCloseStr {
|
||||
str.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// send the request body asynchronously
|
||||
go func() {
|
||||
defer req.Body.Close()
|
||||
b := make([]byte, bodyCopyBufferSize)
|
||||
for {
|
||||
n, rerr := req.Body.Read(b)
|
||||
if n == 0 {
|
||||
if rerr == nil {
|
||||
continue
|
||||
} else if rerr == io.EOF {
|
||||
break
|
||||
}
|
||||
}
|
||||
buf := &bytes.Buffer{}
|
||||
(&dataFrame{Length: uint64(n)}).Write(buf)
|
||||
if _, err := str.Write(buf.Bytes()); err != nil {
|
||||
w.logger.Errorf("Error writing request: %s", err)
|
||||
return
|
||||
}
|
||||
if _, err := str.Write(b[:n]); err != nil {
|
||||
w.logger.Errorf("Error writing request: %s", err)
|
||||
return
|
||||
}
|
||||
if rerr != nil {
|
||||
if rerr == io.EOF {
|
||||
break
|
||||
}
|
||||
str.CancelWrite(quic.StreamErrorCode(errorRequestCanceled))
|
||||
w.logger.Errorf("Error writing request: %s", rerr)
|
||||
return
|
||||
}
|
||||
}
|
||||
if !dontCloseStr {
|
||||
str.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
_, err := str.Write(buf.Bytes())
|
||||
return err
|
||||
}
|
||||
|
||||
func (w *requestWriter) writeHeaders(wr io.Writer, req *http.Request, gzip bool) error {
|
||||
|
|
|
@ -4,7 +4,6 @@ import (
|
|||
"bytes"
|
||||
"io"
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
mockquic "github.com/lucas-clemente/quic-go/internal/mocks/quic"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
|
@ -16,12 +15,6 @@ import (
|
|||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
type foobarReader struct{}
|
||||
|
||||
func (r *foobarReader) Read(b []byte) (int, error) {
|
||||
return copy(b, []byte("foobar")), io.EOF
|
||||
}
|
||||
|
||||
var _ = Describe("Request Writer", func() {
|
||||
var (
|
||||
rw *requestWriter
|
||||
|
@ -51,16 +44,13 @@ var _ = Describe("Request Writer", func() {
|
|||
rw = newRequestWriter(utils.DefaultLogger)
|
||||
strBuf = &bytes.Buffer{}
|
||||
str = mockquic.NewMockStream(mockCtrl)
|
||||
str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
|
||||
return strBuf.Write(p)
|
||||
}).AnyTimes()
|
||||
str.EXPECT().Write(gomock.Any()).DoAndReturn(strBuf.Write).AnyTimes()
|
||||
})
|
||||
|
||||
It("writes a GET request", func() {
|
||||
str.EXPECT().Close()
|
||||
req, err := http.NewRequest(http.MethodGet, "https://quic.clemente.io/index.html?foo=bar", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(rw.WriteRequest(str, req, false, false)).To(Succeed())
|
||||
Expect(rw.WriteRequestHeader(str, req, false)).To(Succeed())
|
||||
headerFields := decode(strBuf)
|
||||
Expect(headerFields).To(HaveKeyWithValue(":authority", "quic.clemente.io"))
|
||||
Expect(headerFields).To(HaveKeyWithValue(":method", "GET"))
|
||||
|
@ -69,55 +59,7 @@ var _ = Describe("Request Writer", func() {
|
|||
Expect(headerFields).ToNot(HaveKey("accept-encoding"))
|
||||
})
|
||||
|
||||
It("writes a GET request without closing the stream", func() {
|
||||
req, err := http.NewRequest(http.MethodGet, "https://quic.clemente.io", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(rw.WriteRequest(str, req, true, false)).To(Succeed())
|
||||
headerFields := decode(strBuf)
|
||||
Expect(headerFields).To(HaveKeyWithValue(":authority", "quic.clemente.io"))
|
||||
})
|
||||
|
||||
It("writes a POST request", func() {
|
||||
closed := make(chan struct{})
|
||||
str.EXPECT().Close().Do(func() { close(closed) })
|
||||
postData := bytes.NewReader([]byte("foobar"))
|
||||
req, err := http.NewRequest(http.MethodPost, "https://quic.clemente.io/upload.html", postData)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(rw.WriteRequest(str, req, false, false)).To(Succeed())
|
||||
|
||||
Eventually(closed).Should(BeClosed())
|
||||
headerFields := decode(strBuf)
|
||||
Expect(headerFields).To(HaveKeyWithValue(":method", "POST"))
|
||||
Expect(headerFields).To(HaveKey("content-length"))
|
||||
contentLength, err := strconv.Atoi(headerFields["content-length"])
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(contentLength).To(BeNumerically(">", 0))
|
||||
|
||||
frame, err := parseNextFrame(strBuf, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(frame).To(BeAssignableToTypeOf(&dataFrame{}))
|
||||
Expect(frame.(*dataFrame).Length).To(BeEquivalentTo(6))
|
||||
})
|
||||
|
||||
It("writes a POST request, if the Body returns an EOF immediately", func() {
|
||||
closed := make(chan struct{})
|
||||
str.EXPECT().Close().Do(func() { close(closed) })
|
||||
req, err := http.NewRequest(http.MethodPost, "https://quic.clemente.io/upload.html", &foobarReader{})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(rw.WriteRequest(str, req, false, false)).To(Succeed())
|
||||
|
||||
Eventually(closed).Should(BeClosed())
|
||||
headerFields := decode(strBuf)
|
||||
Expect(headerFields).To(HaveKeyWithValue(":method", "POST"))
|
||||
|
||||
frame, err := parseNextFrame(strBuf, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(frame).To(BeAssignableToTypeOf(&dataFrame{}))
|
||||
Expect(frame.(*dataFrame).Length).To(BeEquivalentTo(6))
|
||||
})
|
||||
|
||||
It("sends cookies", func() {
|
||||
str.EXPECT().Close()
|
||||
req, err := http.NewRequest(http.MethodGet, "https://quic.clemente.io/", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
cookie1 := &http.Cookie{
|
||||
|
@ -130,25 +72,23 @@ var _ = Describe("Request Writer", func() {
|
|||
}
|
||||
req.AddCookie(cookie1)
|
||||
req.AddCookie(cookie2)
|
||||
Expect(rw.WriteRequest(str, req, false, false)).To(Succeed())
|
||||
Expect(rw.WriteRequestHeader(str, req, false)).To(Succeed())
|
||||
headerFields := decode(strBuf)
|
||||
Expect(headerFields).To(HaveKeyWithValue("cookie", `Cookie #1="Value #1"; Cookie #2="Value #2"`))
|
||||
})
|
||||
|
||||
It("adds the header for gzip support", func() {
|
||||
str.EXPECT().Close()
|
||||
req, err := http.NewRequest(http.MethodGet, "https://quic.clemente.io/", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(rw.WriteRequest(str, req, false, true)).To(Succeed())
|
||||
Expect(rw.WriteRequestHeader(str, req, true)).To(Succeed())
|
||||
headerFields := decode(strBuf)
|
||||
Expect(headerFields).To(HaveKeyWithValue("accept-encoding", "gzip"))
|
||||
})
|
||||
|
||||
It("writes a CONNECT request", func() {
|
||||
str.EXPECT().Close()
|
||||
req, err := http.NewRequest(http.MethodConnect, "https://quic.clemente.io/", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(rw.WriteRequest(str, req, false, false)).To(Succeed())
|
||||
Expect(rw.WriteRequestHeader(str, req, false)).To(Succeed())
|
||||
headerFields := decode(strBuf)
|
||||
Expect(headerFields).To(HaveKeyWithValue(":method", "CONNECT"))
|
||||
Expect(headerFields).To(HaveKeyWithValue(":authority", "quic.clemente.io"))
|
||||
|
@ -158,11 +98,10 @@ var _ = Describe("Request Writer", func() {
|
|||
})
|
||||
|
||||
It("writes an Extended CONNECT request", func() {
|
||||
str.EXPECT().Close()
|
||||
req, err := http.NewRequest(http.MethodConnect, "https://quic.clemente.io/foobar", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
req.Proto = "webtransport"
|
||||
Expect(rw.WriteRequest(str, req, false, false)).To(Succeed())
|
||||
Expect(rw.WriteRequestHeader(str, req, false)).To(Succeed())
|
||||
headerFields := decode(strBuf)
|
||||
Expect(headerFields).To(HaveKeyWithValue(":authority", "quic.clemente.io"))
|
||||
Expect(headerFields).To(HaveKeyWithValue(":method", "CONNECT"))
|
||||
|
|
|
@ -549,7 +549,7 @@ func (s *Server) handleRequest(conn quic.Connection, str quic.Stream, decoder *q
|
|||
}
|
||||
|
||||
req.RemoteAddr = conn.RemoteAddr().String()
|
||||
req.Body = newRequestBody(str, onFrameError)
|
||||
req.Body = newRequestBody(newStream(str, onFrameError))
|
||||
|
||||
if s.logger.Debug() {
|
||||
s.logger.Infof("%s %s%s, on stream %d", req.Method, req.Host, req.RequestURI, str.StreamID())
|
||||
|
|
|
@ -135,11 +135,8 @@ var _ = Describe("Server", func() {
|
|||
buf := &bytes.Buffer{}
|
||||
str := mockquic.NewMockStream(mockCtrl)
|
||||
str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write).AnyTimes()
|
||||
closed := make(chan struct{})
|
||||
str.EXPECT().Close().Do(func() { close(closed) })
|
||||
rw := newRequestWriter(utils.DefaultLogger)
|
||||
Expect(rw.WriteRequest(str, req, false, false)).To(Succeed())
|
||||
Eventually(closed).Should(BeClosed())
|
||||
Expect(rw.WriteRequestHeader(str, req, false)).To(Succeed())
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue