mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-04 20:57:36 +03:00
use a buffered writer for the http3 request writer
This commit is contained in:
parent
683230372e
commit
bcffb77ad4
3 changed files with 50 additions and 42 deletions
|
@ -222,7 +222,7 @@ var _ = Describe("Client", func() {
|
||||||
sess.EXPECT().HandshakeComplete().Return(handshakeCtx),
|
sess.EXPECT().HandshakeComplete().Return(handshakeCtx),
|
||||||
sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil),
|
sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil),
|
||||||
)
|
)
|
||||||
str.EXPECT().Write(gomock.Any()).AnyTimes()
|
str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil })
|
||||||
str.EXPECT().Close()
|
str.EXPECT().Close()
|
||||||
str.EXPECT().Read(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
|
str.EXPECT().Read(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
|
||||||
return rspBuf.Read(p)
|
return rspBuf.Read(p)
|
||||||
|
@ -458,7 +458,7 @@ var _ = Describe("Client", func() {
|
||||||
gz.Write([]byte("gzipped response"))
|
gz.Write([]byte("gzipped response"))
|
||||||
gz.Close()
|
gz.Close()
|
||||||
rw.Flush()
|
rw.Flush()
|
||||||
str.EXPECT().Write(gomock.Any()).AnyTimes()
|
str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil })
|
||||||
str.EXPECT().Read(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
|
str.EXPECT().Read(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
|
||||||
return buf.Read(p)
|
return buf.Read(p)
|
||||||
}).AnyTimes()
|
}).AnyTimes()
|
||||||
|
@ -480,7 +480,7 @@ var _ = Describe("Client", func() {
|
||||||
rw := newResponseWriter(buf, utils.DefaultLogger)
|
rw := newResponseWriter(buf, utils.DefaultLogger)
|
||||||
rw.Write([]byte("not gzipped"))
|
rw.Write([]byte("not gzipped"))
|
||||||
rw.Flush()
|
rw.Flush()
|
||||||
str.EXPECT().Write(gomock.Any()).AnyTimes()
|
str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil })
|
||||||
str.EXPECT().Read(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
|
str.EXPECT().Read(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
|
||||||
return buf.Read(p)
|
return buf.Read(p)
|
||||||
}).AnyTimes()
|
}).AnyTimes()
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package http3
|
package http3
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bufio"
|
||||||
"bytes"
|
"bytes"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
@ -18,6 +19,8 @@ import (
|
||||||
"golang.org/x/net/idna"
|
"golang.org/x/net/idna"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const bodyCopyBufferSize = 8 * 1024
|
||||||
|
|
||||||
type requestWriter struct {
|
type requestWriter struct {
|
||||||
mutex sync.Mutex
|
mutex sync.Mutex
|
||||||
encoder *qpack.Encoder
|
encoder *qpack.Encoder
|
||||||
|
@ -37,22 +40,47 @@ func newRequestWriter(logger utils.Logger) *requestWriter {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *requestWriter) WriteRequest(str quic.Stream, req *http.Request, gzip bool) error {
|
func (w *requestWriter) WriteRequest(str quic.Stream, req *http.Request, gzip bool) error {
|
||||||
headers, err := w.getHeaders(req, gzip)
|
wr := bufio.NewWriter(str)
|
||||||
if err != nil {
|
|
||||||
return err
|
if err := w.writeHeaders(wr, req, gzip); err != nil {
|
||||||
}
|
|
||||||
if _, err := str.Write(headers); err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
// TODO: add support for trailers
|
// TODO: add support for trailers
|
||||||
if req.Body == nil {
|
if req.Body == nil {
|
||||||
|
if err := wr.Flush(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
str.Close()
|
str.Close()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// send the request body asynchronously
|
// send the request body asynchronously
|
||||||
go func() {
|
go func() {
|
||||||
if err := w.sendRequestBody(req.Body, str); err != nil {
|
defer req.Body.Close()
|
||||||
|
b := make([]byte, bodyCopyBufferSize)
|
||||||
|
for {
|
||||||
|
n, err := req.Body.Read(b)
|
||||||
|
if err == io.EOF {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
str.CancelWrite(quic.ErrorCode(errorRequestCanceled))
|
||||||
|
w.logger.Errorf("Error writing request: %s", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
(&dataFrame{Length: uint64(n)}).Write(buf)
|
||||||
|
if _, err := wr.Write(buf.Bytes()); err != nil {
|
||||||
|
w.logger.Errorf("Error writing request: %s", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if _, err := wr.Write(b[:n]); err != nil {
|
||||||
|
w.logger.Errorf("Error writing request: %s", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := wr.Flush(); err != nil {
|
||||||
|
fmt.Println(err)
|
||||||
w.logger.Errorf("Error writing request: %s", err)
|
w.logger.Errorf("Error writing request: %s", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -62,46 +90,25 @@ func (w *requestWriter) WriteRequest(str quic.Stream, req *http.Request, gzip bo
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *requestWriter) getHeaders(req *http.Request, gzip bool) ([]byte, error) {
|
func (w *requestWriter) writeHeaders(wr io.Writer, req *http.Request, gzip bool) error {
|
||||||
w.mutex.Lock()
|
w.mutex.Lock()
|
||||||
defer w.mutex.Unlock()
|
defer w.mutex.Unlock()
|
||||||
defer w.encoder.Close()
|
defer w.encoder.Close()
|
||||||
|
|
||||||
if err := w.encodeHeaders(req, gzip, "", actualContentLength(req)); err != nil {
|
if err := w.encodeHeaders(req, gzip, "", actualContentLength(req)); err != nil {
|
||||||
return nil, err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
buf := &bytes.Buffer{}
|
buf := &bytes.Buffer{}
|
||||||
hf := headersFrame{Length: uint64(w.headerBuf.Len())}
|
hf := headersFrame{Length: uint64(w.headerBuf.Len())}
|
||||||
hf.Write(buf)
|
hf.Write(buf)
|
||||||
if _, err := io.Copy(buf, w.headerBuf); err != nil {
|
if _, err := wr.Write(buf.Bytes()); err != nil {
|
||||||
return nil, err
|
return err
|
||||||
|
}
|
||||||
|
if _, err := wr.Write(w.headerBuf.Bytes()); err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
w.headerBuf.Reset()
|
w.headerBuf.Reset()
|
||||||
return buf.Bytes(), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *requestWriter) sendRequestBody(req io.ReadCloser, str quic.Stream) error {
|
|
||||||
b := make([]byte, 8*1024)
|
|
||||||
for {
|
|
||||||
n, err := req.Read(b)
|
|
||||||
if err == io.EOF {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
str.CancelWrite(quic.ErrorCode(errorRequestCanceled))
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
buf := &bytes.Buffer{}
|
|
||||||
(&dataFrame{Length: uint64(n)}).Write(buf)
|
|
||||||
if _, err := str.Write(buf.Bytes()); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if _, err := str.Write(b[:n]); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
req.Close()
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -25,15 +25,15 @@ var _ = Describe("Request Writer", func() {
|
||||||
|
|
||||||
decode := func(str io.Reader) map[string]string {
|
decode := func(str io.Reader) map[string]string {
|
||||||
frame, err := parseNextFrame(str)
|
frame, err := parseNextFrame(str)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
ExpectWithOffset(1, err).ToNot(HaveOccurred())
|
||||||
Expect(frame).To(BeAssignableToTypeOf(&headersFrame{}))
|
ExpectWithOffset(1, frame).To(BeAssignableToTypeOf(&headersFrame{}))
|
||||||
headersFrame := frame.(*headersFrame)
|
headersFrame := frame.(*headersFrame)
|
||||||
data := make([]byte, headersFrame.Length)
|
data := make([]byte, headersFrame.Length)
|
||||||
_, err = io.ReadFull(str, data)
|
_, err = io.ReadFull(str, data)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
ExpectWithOffset(1, err).ToNot(HaveOccurred())
|
||||||
decoder := qpack.NewDecoder(nil)
|
decoder := qpack.NewDecoder(nil)
|
||||||
hfs, err := decoder.DecodeFull(data)
|
hfs, err := decoder.DecodeFull(data)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
ExpectWithOffset(1, err).ToNot(HaveOccurred())
|
||||||
values := make(map[string]string)
|
values := make(map[string]string)
|
||||||
for _, hf := range hfs {
|
for _, hf := range hfs {
|
||||||
values[hf.Name] = hf.Value
|
values[hf.Name] = hf.Value
|
||||||
|
@ -70,6 +70,8 @@ var _ = Describe("Request Writer", func() {
|
||||||
req, err := http.NewRequest("POST", "https://quic.clemente.io/upload.html", postData)
|
req, err := http.NewRequest("POST", "https://quic.clemente.io/upload.html", postData)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(rw.WriteRequest(str, req, false)).To(Succeed())
|
Expect(rw.WriteRequest(str, req, false)).To(Succeed())
|
||||||
|
|
||||||
|
Eventually(closed).Should(BeClosed())
|
||||||
headerFields := decode(strBuf)
|
headerFields := decode(strBuf)
|
||||||
Expect(headerFields).To(HaveKeyWithValue(":method", "POST"))
|
Expect(headerFields).To(HaveKeyWithValue(":method", "POST"))
|
||||||
Expect(headerFields).To(HaveKey("content-length"))
|
Expect(headerFields).To(HaveKey("content-length"))
|
||||||
|
@ -77,7 +79,6 @@ var _ = Describe("Request Writer", func() {
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(contentLength).To(BeNumerically(">", 0))
|
Expect(contentLength).To(BeNumerically(">", 0))
|
||||||
|
|
||||||
Eventually(closed).Should(BeClosed())
|
|
||||||
frame, err := parseNextFrame(strBuf)
|
frame, err := parseNextFrame(strBuf)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(frame).To(BeAssignableToTypeOf(&dataFrame{}))
|
Expect(frame).To(BeAssignableToTypeOf(&dataFrame{}))
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue