use a buffered writer for the http3 response writer

This commit is contained in:
Marten Seemann 2020-03-27 08:42:35 +07:00
parent c10af76a4a
commit 683230372e
5 changed files with 19 additions and 10 deletions

View file

@ -216,6 +216,7 @@ var _ = Describe("Client", func() {
rspBuf := &bytes.Buffer{} rspBuf := &bytes.Buffer{}
rw := newResponseWriter(rspBuf, utils.DefaultLogger) rw := newResponseWriter(rspBuf, utils.DefaultLogger)
rw.WriteHeader(418) rw.WriteHeader(418)
rw.Flush()
gomock.InOrder( gomock.InOrder(
sess.EXPECT().HandshakeComplete().Return(handshakeCtx), sess.EXPECT().HandshakeComplete().Return(handshakeCtx),
@ -383,6 +384,7 @@ var _ = Describe("Client", func() {
rspBuf := &bytes.Buffer{} rspBuf := &bytes.Buffer{}
rw := newResponseWriter(rspBuf, utils.DefaultLogger) rw := newResponseWriter(rspBuf, utils.DefaultLogger)
rw.WriteHeader(418) rw.WriteHeader(418)
rw.Flush()
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
req := request.WithContext(ctx) req := request.WithContext(ctx)
@ -455,6 +457,7 @@ var _ = Describe("Client", func() {
gz := gzip.NewWriter(rw) gz := gzip.NewWriter(rw)
gz.Write([]byte("gzipped response")) gz.Write([]byte("gzipped response"))
gz.Close() gz.Close()
rw.Flush()
str.EXPECT().Write(gomock.Any()).AnyTimes() str.EXPECT().Write(gomock.Any()).AnyTimes()
str.EXPECT().Read(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { str.EXPECT().Read(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
return buf.Read(p) return buf.Read(p)
@ -476,6 +479,7 @@ var _ = Describe("Client", func() {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
rw := newResponseWriter(buf, utils.DefaultLogger) rw := newResponseWriter(buf, utils.DefaultLogger)
rw.Write([]byte("not gzipped")) rw.Write([]byte("not gzipped"))
rw.Flush()
str.EXPECT().Write(gomock.Any()).AnyTimes() str.EXPECT().Write(gomock.Any()).AnyTimes()
str.EXPECT().Read(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { str.EXPECT().Read(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
return buf.Read(p) return buf.Read(p)

View file

@ -1,6 +1,7 @@
package http3 package http3
import ( import (
"bufio"
"bytes" "bytes"
"io" "io"
"net/http" "net/http"
@ -12,7 +13,7 @@ import (
) )
type responseWriter struct { type responseWriter struct {
stream io.Writer stream *bufio.Writer
header http.Header header http.Header
status int // status code passed to WriteHeader status int // status code passed to WriteHeader
@ -22,11 +23,12 @@ type responseWriter struct {
} }
var _ http.ResponseWriter = &responseWriter{} var _ http.ResponseWriter = &responseWriter{}
var _ http.Flusher = &responseWriter{}
func newResponseWriter(stream io.Writer, logger utils.Logger) *responseWriter { func newResponseWriter(stream io.Writer, logger utils.Logger) *responseWriter {
return &responseWriter{ return &responseWriter{
header: http.Header{}, header: http.Header{},
stream: stream, stream: bufio.NewWriter(stream),
logger: logger, logger: logger,
} }
} }
@ -79,10 +81,11 @@ func (w *responseWriter) Write(p []byte) (int, error) {
return w.stream.Write(p) return w.stream.Write(p)
} }
func (w *responseWriter) Flush() {} func (w *responseWriter) Flush() {
if err := w.stream.Flush(); err != nil {
// test that we implement http.Flusher w.logger.Errorf("could not flush to stream: %s", err.Error())
var _ http.Flusher = &responseWriter{} }
}
// copied from http2/http2.go // copied from http2/http2.go
// bodyAllowedForStatus reports whether a given response status code // bodyAllowedForStatus reports whether a given response status code

View file

@ -24,6 +24,7 @@ var _ = Describe("Response Writer", func() {
}) })
decodeHeader := func(str io.Reader) map[string][]string { decodeHeader := func(str io.Reader) map[string][]string {
rw.Flush()
fields := make(map[string][]string) fields := make(map[string][]string)
decoder := qpack.NewDecoder(nil) decoder := qpack.NewDecoder(nil)

View file

@ -270,6 +270,7 @@ func (s *Server) handleRequest(sess quic.Session, str quic.Stream, decoder *qpac
ctx = context.WithValue(ctx, http.LocalAddrContextKey, sess.LocalAddr()) ctx = context.WithValue(ctx, http.LocalAddrContextKey, sess.LocalAddr())
req = req.WithContext(ctx) req = req.WithContext(ctx)
responseWriter := newResponseWriter(str, s.logger) responseWriter := newResponseWriter(str, s.logger)
defer responseWriter.Flush()
handler := s.Handler handler := s.Handler
if handler == nil { if handler == nil {
handler = http.DefaultServeMux handler = http.DefaultServeMux

View file

@ -57,14 +57,14 @@ var _ = Describe("Server", func() {
decoder := qpack.NewDecoder(nil) decoder := qpack.NewDecoder(nil)
frame, err := parseNextFrame(str) frame, err := parseNextFrame(str)
Expect(err).ToNot(HaveOccurred()) ExpectWithOffset(1, err).ToNot(HaveOccurred())
Expect(frame).To(BeAssignableToTypeOf(&headersFrame{})) ExpectWithOffset(1, frame).To(BeAssignableToTypeOf(&headersFrame{}))
headersFrame := frame.(*headersFrame) headersFrame := frame.(*headersFrame)
data := make([]byte, headersFrame.Length) data := make([]byte, headersFrame.Length)
_, err = io.ReadFull(str, data) _, err = io.ReadFull(str, data)
Expect(err).ToNot(HaveOccurred()) ExpectWithOffset(1, err).ToNot(HaveOccurred())
hfs, err := decoder.DecodeFull(data) hfs, err := decoder.DecodeFull(data)
Expect(err).ToNot(HaveOccurred()) ExpectWithOffset(1, err).ToNot(HaveOccurred())
for _, p := range hfs { for _, p := range hfs {
fields[p.Name] = append(fields[p.Name], p.Value) fields[p.Name] = append(fields[p.Name], p.Value)
} }