mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-03 20:27:35 +03:00
use a buffered writer for the http3 request writer
This commit is contained in:
parent
683230372e
commit
bcffb77ad4
3 changed files with 50 additions and 42 deletions
|
@ -222,7 +222,7 @@ var _ = Describe("Client", func() {
|
|||
sess.EXPECT().HandshakeComplete().Return(handshakeCtx),
|
||||
sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil),
|
||||
)
|
||||
str.EXPECT().Write(gomock.Any()).AnyTimes()
|
||||
str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil })
|
||||
str.EXPECT().Close()
|
||||
str.EXPECT().Read(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
|
||||
return rspBuf.Read(p)
|
||||
|
@ -458,7 +458,7 @@ var _ = Describe("Client", func() {
|
|||
gz.Write([]byte("gzipped response"))
|
||||
gz.Close()
|
||||
rw.Flush()
|
||||
str.EXPECT().Write(gomock.Any()).AnyTimes()
|
||||
str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil })
|
||||
str.EXPECT().Read(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
|
||||
return buf.Read(p)
|
||||
}).AnyTimes()
|
||||
|
@ -480,7 +480,7 @@ var _ = Describe("Client", func() {
|
|||
rw := newResponseWriter(buf, utils.DefaultLogger)
|
||||
rw.Write([]byte("not gzipped"))
|
||||
rw.Flush()
|
||||
str.EXPECT().Write(gomock.Any()).AnyTimes()
|
||||
str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil })
|
||||
str.EXPECT().Read(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
|
||||
return buf.Read(p)
|
||||
}).AnyTimes()
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package http3
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
|
@ -18,6 +19,8 @@ import (
|
|||
"golang.org/x/net/idna"
|
||||
)
|
||||
|
||||
const bodyCopyBufferSize = 8 * 1024
|
||||
|
||||
type requestWriter struct {
|
||||
mutex sync.Mutex
|
||||
encoder *qpack.Encoder
|
||||
|
@ -37,22 +40,47 @@ func newRequestWriter(logger utils.Logger) *requestWriter {
|
|||
}
|
||||
|
||||
func (w *requestWriter) WriteRequest(str quic.Stream, req *http.Request, gzip bool) error {
|
||||
headers, err := w.getHeaders(req, gzip)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := str.Write(headers); err != nil {
|
||||
wr := bufio.NewWriter(str)
|
||||
|
||||
if err := w.writeHeaders(wr, req, gzip); err != nil {
|
||||
return err
|
||||
}
|
||||
// TODO: add support for trailers
|
||||
if req.Body == nil {
|
||||
if err := wr.Flush(); err != nil {
|
||||
return err
|
||||
}
|
||||
str.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
// send the request body asynchronously
|
||||
go func() {
|
||||
if err := w.sendRequestBody(req.Body, str); err != nil {
|
||||
defer req.Body.Close()
|
||||
b := make([]byte, bodyCopyBufferSize)
|
||||
for {
|
||||
n, err := req.Body.Read(b)
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
str.CancelWrite(quic.ErrorCode(errorRequestCanceled))
|
||||
w.logger.Errorf("Error writing request: %s", err)
|
||||
return
|
||||
}
|
||||
buf := &bytes.Buffer{}
|
||||
(&dataFrame{Length: uint64(n)}).Write(buf)
|
||||
if _, err := wr.Write(buf.Bytes()); err != nil {
|
||||
w.logger.Errorf("Error writing request: %s", err)
|
||||
return
|
||||
}
|
||||
if _, err := wr.Write(b[:n]); err != nil {
|
||||
w.logger.Errorf("Error writing request: %s", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
if err := wr.Flush(); err != nil {
|
||||
fmt.Println(err)
|
||||
w.logger.Errorf("Error writing request: %s", err)
|
||||
return
|
||||
}
|
||||
|
@ -62,46 +90,25 @@ func (w *requestWriter) WriteRequest(str quic.Stream, req *http.Request, gzip bo
|
|||
return nil
|
||||
}
|
||||
|
||||
func (w *requestWriter) getHeaders(req *http.Request, gzip bool) ([]byte, error) {
|
||||
func (w *requestWriter) writeHeaders(wr io.Writer, req *http.Request, gzip bool) error {
|
||||
w.mutex.Lock()
|
||||
defer w.mutex.Unlock()
|
||||
defer w.encoder.Close()
|
||||
|
||||
if err := w.encodeHeaders(req, gzip, "", actualContentLength(req)); err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
buf := &bytes.Buffer{}
|
||||
hf := headersFrame{Length: uint64(w.headerBuf.Len())}
|
||||
hf.Write(buf)
|
||||
if _, err := io.Copy(buf, w.headerBuf); err != nil {
|
||||
return nil, err
|
||||
if _, err := wr.Write(buf.Bytes()); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := wr.Write(w.headerBuf.Bytes()); err != nil {
|
||||
return err
|
||||
}
|
||||
w.headerBuf.Reset()
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
func (w *requestWriter) sendRequestBody(req io.ReadCloser, str quic.Stream) error {
|
||||
b := make([]byte, 8*1024)
|
||||
for {
|
||||
n, err := req.Read(b)
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
str.CancelWrite(quic.ErrorCode(errorRequestCanceled))
|
||||
return err
|
||||
}
|
||||
buf := &bytes.Buffer{}
|
||||
(&dataFrame{Length: uint64(n)}).Write(buf)
|
||||
if _, err := str.Write(buf.Bytes()); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := str.Write(b[:n]); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
req.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -25,15 +25,15 @@ var _ = Describe("Request Writer", func() {
|
|||
|
||||
decode := func(str io.Reader) map[string]string {
|
||||
frame, err := parseNextFrame(str)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(frame).To(BeAssignableToTypeOf(&headersFrame{}))
|
||||
ExpectWithOffset(1, err).ToNot(HaveOccurred())
|
||||
ExpectWithOffset(1, frame).To(BeAssignableToTypeOf(&headersFrame{}))
|
||||
headersFrame := frame.(*headersFrame)
|
||||
data := make([]byte, headersFrame.Length)
|
||||
_, err = io.ReadFull(str, data)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
ExpectWithOffset(1, err).ToNot(HaveOccurred())
|
||||
decoder := qpack.NewDecoder(nil)
|
||||
hfs, err := decoder.DecodeFull(data)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
ExpectWithOffset(1, err).ToNot(HaveOccurred())
|
||||
values := make(map[string]string)
|
||||
for _, hf := range hfs {
|
||||
values[hf.Name] = hf.Value
|
||||
|
@ -70,6 +70,8 @@ var _ = Describe("Request Writer", func() {
|
|||
req, err := http.NewRequest("POST", "https://quic.clemente.io/upload.html", postData)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(rw.WriteRequest(str, req, false)).To(Succeed())
|
||||
|
||||
Eventually(closed).Should(BeClosed())
|
||||
headerFields := decode(strBuf)
|
||||
Expect(headerFields).To(HaveKeyWithValue(":method", "POST"))
|
||||
Expect(headerFields).To(HaveKey("content-length"))
|
||||
|
@ -77,7 +79,6 @@ var _ = Describe("Request Writer", func() {
|
|||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(contentLength).To(BeNumerically(">", 0))
|
||||
|
||||
Eventually(closed).Should(BeClosed())
|
||||
frame, err := parseNextFrame(strBuf)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(frame).To(BeAssignableToTypeOf(&dataFrame{}))
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue