allow access to the underlying quic.Stream from a http.ResponseWriter

This commit is contained in:
Marten Seemann 2021-01-01 12:41:26 +08:00
parent d1c5297c0b
commit 35939b25a9
5 changed files with 94 additions and 40 deletions

View file

@ -385,6 +385,16 @@ var _ = Describe("Client", func() {
return fields return fields
} }
getResponse := func(status int) []byte {
buf := &bytes.Buffer{}
rstr := mockquic.NewMockStream(mockCtrl)
rstr.EXPECT().Write(gomock.Any()).Do(buf.Write).AnyTimes()
rw := newResponseWriter(rstr, utils.DefaultLogger)
rw.WriteHeader(status)
rw.Flush()
return buf.Bytes()
}
BeforeEach(func() { BeforeEach(func() {
settingsFrameWritten = make(chan struct{}) settingsFrameWritten = make(chan struct{})
controlStr := mockquic.NewMockStream(mockCtrl) controlStr := mockquic.NewMockStream(mockCtrl)
@ -441,11 +451,7 @@ var _ = Describe("Client", func() {
}) })
It("returns a response", func() { It("returns a response", func() {
rspBuf := &bytes.Buffer{} rspBuf := bytes.NewBuffer(getResponse(418))
rw := newResponseWriter(rspBuf, utils.DefaultLogger)
rw.WriteHeader(418)
rw.Flush()
gomock.InOrder( gomock.InOrder(
sess.EXPECT().HandshakeComplete().Return(handshakeCtx), sess.EXPECT().HandshakeComplete().Return(handshakeCtx),
sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil), sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil),
@ -453,9 +459,7 @@ var _ = Describe("Client", func() {
) )
str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil }) str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil })
str.EXPECT().Close() str.EXPECT().Close()
str.EXPECT().Read(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes()
return rspBuf.Read(p)
}).AnyTimes()
rsp, err := client.RoundTrip(request) rsp, err := client.RoundTrip(request)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(rsp.Proto).To(Equal("HTTP/3")) Expect(rsp.Proto).To(Equal("HTTP/3"))
@ -590,10 +594,7 @@ var _ = Describe("Client", func() {
}) })
It("cancels a request after the response arrived", func() { It("cancels a request after the response arrived", func() {
rspBuf := &bytes.Buffer{} rspBuf := bytes.NewBuffer(getResponse(404))
rw := newResponseWriter(rspBuf, utils.DefaultLogger)
rw.WriteHeader(418)
rw.Flush()
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
req := request.WithContext(ctx) req := request.WithContext(ctx)
@ -656,7 +657,9 @@ var _ = Describe("Client", func() {
sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil) sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
sess.EXPECT().ConnectionState().Return(quic.ConnectionState{}) sess.EXPECT().ConnectionState().Return(quic.ConnectionState{})
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
rw := newResponseWriter(buf, utils.DefaultLogger) rstr := mockquic.NewMockStream(mockCtrl)
rstr.EXPECT().Write(gomock.Any()).Do(buf.Write).AnyTimes()
rw := newResponseWriter(rstr, utils.DefaultLogger)
rw.Header().Set("Content-Encoding", "gzip") rw.Header().Set("Content-Encoding", "gzip")
gz := gzip.NewWriter(rw) gz := gzip.NewWriter(rw)
gz.Write([]byte("gzipped response")) gz.Write([]byte("gzipped response"))
@ -680,7 +683,9 @@ var _ = Describe("Client", func() {
sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil) sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
sess.EXPECT().ConnectionState().Return(quic.ConnectionState{}) sess.EXPECT().ConnectionState().Return(quic.ConnectionState{})
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
rw := newResponseWriter(buf, utils.DefaultLogger) rstr := mockquic.NewMockStream(mockCtrl)
rstr.EXPECT().Write(gomock.Any()).Do(buf.Write).AnyTimes()
rw := newResponseWriter(rstr, utils.DefaultLogger)
rw.Write([]byte("not gzipped")) rw.Write([]byte("not gzipped"))
rw.Flush() rw.Flush()
str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil }) str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil })

View file

@ -3,21 +3,33 @@ package http3
import ( import (
"bufio" "bufio"
"bytes" "bytes"
"io"
"net/http" "net/http"
"strconv" "strconv"
"strings" "strings"
"github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/internal/utils"
"github.com/marten-seemann/qpack" "github.com/marten-seemann/qpack"
) )
type responseWriter struct { // DataStreamer lets the caller take over the stream. After a call to DataStream
stream *bufio.Writer // the HTTP server library will not do anything else with the connection.
//
// It becomes the caller's responsibility to manage and close the stream.
//
// After a call to DataStream, the original Request.Body must not be used.
type DataStreamer interface {
DataStream() quic.Stream
}
header http.Header type responseWriter struct {
status int // status code passed to WriteHeader stream quic.Stream // needed for DataStream()
headerWritten bool bufferedStream *bufio.Writer
header http.Header
status int // status code passed to WriteHeader
headerWritten bool
dataStreamUsed bool // set when DataSteam() is called
logger utils.Logger logger utils.Logger
} }
@ -25,13 +37,15 @@ type responseWriter struct {
var ( var (
_ http.ResponseWriter = &responseWriter{} _ http.ResponseWriter = &responseWriter{}
_ http.Flusher = &responseWriter{} _ http.Flusher = &responseWriter{}
_ DataStreamer = &responseWriter{}
) )
func newResponseWriter(stream io.Writer, logger utils.Logger) *responseWriter { func newResponseWriter(stream quic.Stream, logger utils.Logger) *responseWriter {
return &responseWriter{ return &responseWriter{
header: http.Header{}, header: http.Header{},
stream: bufio.NewWriter(stream), stream: stream,
logger: logger, bufferedStream: bufio.NewWriter(stream),
logger: logger,
} }
} }
@ -59,10 +73,10 @@ func (w *responseWriter) WriteHeader(status int) {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
(&headersFrame{Length: uint64(headers.Len())}).Write(buf) (&headersFrame{Length: uint64(headers.Len())}).Write(buf)
w.logger.Infof("Responding with %d", status) w.logger.Infof("Responding with %d", status)
if _, err := w.stream.Write(buf.Bytes()); err != nil { if _, err := w.bufferedStream.Write(buf.Bytes()); err != nil {
w.logger.Errorf("could not write headers frame: %s", err.Error()) w.logger.Errorf("could not write headers frame: %s", err.Error())
} }
if _, err := w.stream.Write(headers.Bytes()); err != nil { if _, err := w.bufferedStream.Write(headers.Bytes()); err != nil {
w.logger.Errorf("could not write header frame payload: %s", err.Error()) w.logger.Errorf("could not write header frame payload: %s", err.Error())
} }
} }
@ -77,18 +91,28 @@ func (w *responseWriter) Write(p []byte) (int, error) {
df := &dataFrame{Length: uint64(len(p))} df := &dataFrame{Length: uint64(len(p))}
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
df.Write(buf) df.Write(buf)
if _, err := w.stream.Write(buf.Bytes()); err != nil { if _, err := w.bufferedStream.Write(buf.Bytes()); err != nil {
return 0, err return 0, err
} }
return w.stream.Write(p) return w.bufferedStream.Write(p)
} }
func (w *responseWriter) Flush() { func (w *responseWriter) Flush() {
if err := w.stream.Flush(); err != nil { if err := w.bufferedStream.Flush(); err != nil {
w.logger.Errorf("could not flush to stream: %s", err.Error()) w.logger.Errorf("could not flush to stream: %s", err.Error())
} }
} }
func (w *responseWriter) usedDataStream() bool {
return w.dataStreamUsed
}
func (w *responseWriter) DataStream() quic.Stream {
w.dataStreamUsed = true
w.Flush()
return w.stream
}
// copied from http2/http2.go // copied from http2/http2.go
// bodyAllowedForStatus reports whether a given response status code // bodyAllowedForStatus reports whether a given response status code
// permits a body. See RFC 2616, section 4.4. // permits a body. See RFC 2616, section 4.4.

View file

@ -5,7 +5,10 @@ import (
"io" "io"
"net/http" "net/http"
mockquic "github.com/lucas-clemente/quic-go/internal/mocks/quic"
"github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/internal/utils"
"github.com/golang/mock/gomock"
"github.com/marten-seemann/qpack" "github.com/marten-seemann/qpack"
. "github.com/onsi/ginkgo" . "github.com/onsi/ginkgo"
@ -20,7 +23,9 @@ var _ = Describe("Response Writer", func() {
BeforeEach(func() { BeforeEach(func() {
strBuf = &bytes.Buffer{} strBuf = &bytes.Buffer{}
rw = newResponseWriter(strBuf, utils.DefaultLogger) str := mockquic.NewMockStream(mockCtrl)
str.EXPECT().Write(gomock.Any()).Do(strBuf.Write).AnyTimes()
rw = newResponseWriter(str, utils.DefaultLogger)
}) })
decodeHeader := func(str io.Reader) map[string][]string { decodeHeader := func(str io.Reader) map[string][]string {

View file

@ -367,8 +367,12 @@ func (s *Server) handleRequest(sess quic.Session, str quic.Stream, decoder *qpac
ctx = context.WithValue(ctx, ServerContextKey, s) ctx = context.WithValue(ctx, ServerContextKey, s)
ctx = context.WithValue(ctx, http.LocalAddrContextKey, sess.LocalAddr()) ctx = context.WithValue(ctx, http.LocalAddrContextKey, sess.LocalAddr())
req = req.WithContext(ctx) req = req.WithContext(ctx)
responseWriter := newResponseWriter(str, s.logger) r := newResponseWriter(str, s.logger)
defer responseWriter.Flush() defer func() {
if !r.usedDataStream() {
r.Flush()
}
}()
handler := s.Handler handler := s.Handler
if handler == nil { if handler == nil {
handler = http.DefaultServeMux handler = http.DefaultServeMux
@ -386,17 +390,18 @@ func (s *Server) handleRequest(sess quic.Session, str quic.Stream, decoder *qpac
panicked = true panicked = true
} }
}() }()
handler.ServeHTTP(responseWriter, req) handler.ServeHTTP(r, req)
}() }()
if panicked { if !r.usedDataStream() {
responseWriter.WriteHeader(500) if panicked {
} else { r.WriteHeader(500)
responseWriter.WriteHeader(200) } else {
r.WriteHeader(200)
}
// If the EOF was read by the handler, CancelRead() is a no-op.
str.CancelRead(quic.ErrorCode(errorNoError))
} }
// If the EOF was read by the handler, CancelRead() is a no-op.
str.CancelRead(quic.ErrorCode(errorNoError))
return requestError{} return requestError{}
} }

View file

@ -177,6 +177,21 @@ var _ = Describe("Server", func() {
Expect(hfs).To(HaveKeyWithValue(":status", []string{"500"})) Expect(hfs).To(HaveKeyWithValue(":status", []string{"500"}))
}) })
It("doesn't close the stream if the handler called DataStream()", func() {
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
str := w.(DataStreamer).DataStream()
str.Write([]byte("foobar"))
})
setRequest(encodeRequest(exampleGetRequest))
str.EXPECT().Context().Return(reqContext)
str.EXPECT().Write([]byte("foobar"))
// don't EXPECT CancelRead()
serr := s.handleRequest(sess, str, qpackDecoder, nil)
Expect(serr.err).ToNot(HaveOccurred())
})
Context("control stream handling", func() { Context("control stream handling", func() {
var sess *mockquic.MockEarlySession var sess *mockquic.MockEarlySession
testDone := make(chan struct{}) testDone := make(chan struct{})