make the responseWriter hijackable

This commit is contained in:
Marten Seemann 2022-03-27 17:52:39 +01:00
parent a983db0301
commit ff6313fdb3
4 changed files with 13 additions and 6 deletions

View file

@ -429,7 +429,7 @@ var _ = Describe("Client", func() {
buf := &bytes.Buffer{}
rstr := mockquic.NewMockStream(mockCtrl)
rstr.EXPECT().Write(gomock.Any()).Do(buf.Write).AnyTimes()
rw := newResponseWriter(rstr, utils.DefaultLogger)
rw := newResponseWriter(rstr, nil, utils.DefaultLogger)
rw.WriteHeader(status)
rw.Flush()
return buf.Bytes()
@ -717,7 +717,7 @@ var _ = Describe("Client", func() {
buf := &bytes.Buffer{}
rstr := mockquic.NewMockStream(mockCtrl)
rstr.EXPECT().Write(gomock.Any()).Do(buf.Write).AnyTimes()
rw := newResponseWriter(rstr, utils.DefaultLogger)
rw := newResponseWriter(rstr, nil, utils.DefaultLogger)
rw.Header().Set("Content-Encoding", "gzip")
gz := gzip.NewWriter(rw)
gz.Write([]byte("gzipped response"))
@ -743,7 +743,7 @@ var _ = Describe("Client", func() {
buf := &bytes.Buffer{}
rstr := mockquic.NewMockStream(mockCtrl)
rstr.EXPECT().Write(gomock.Any()).Do(buf.Write).AnyTimes()
rw := newResponseWriter(rstr, utils.DefaultLogger)
rw := newResponseWriter(rstr, nil, utils.DefaultLogger)
rw.Write([]byte("not gzipped"))
rw.Flush()
str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil })

View file

@ -23,6 +23,7 @@ type DataStreamer interface {
}
type responseWriter struct {
conn quic.Connection
stream quic.Stream // needed for DataStream()
bufferedStream *bufio.Writer
@ -38,12 +39,14 @@ var (
_ http.ResponseWriter = &responseWriter{}
_ http.Flusher = &responseWriter{}
_ DataStreamer = &responseWriter{}
_ Hijacker = &responseWriter{}
)
func newResponseWriter(stream quic.Stream, logger utils.Logger) *responseWriter {
func newResponseWriter(stream quic.Stream, conn quic.Connection, logger utils.Logger) *responseWriter {
return &responseWriter{
header: http.Header{},
stream: stream,
conn: conn,
bufferedStream: bufio.NewWriter(stream),
logger: logger,
}
@ -123,6 +126,10 @@ func (w *responseWriter) StreamID() quic.StreamID {
return w.stream.StreamID()
}
func (w *responseWriter) StreamCreator() StreamCreator {
return w.conn
}
// copied from http2/http2.go
// bodyAllowedForStatus reports whether a given response status code
// permits a body. See RFC 2616, section 4.4.

View file

@ -25,7 +25,7 @@ var _ = Describe("Response Writer", func() {
strBuf = &bytes.Buffer{}
str := mockquic.NewMockStream(mockCtrl)
str.EXPECT().Write(gomock.Any()).DoAndReturn(strBuf.Write).AnyTimes()
rw = newResponseWriter(str, utils.DefaultLogger)
rw = newResponseWriter(str, nil, utils.DefaultLogger)
})
decodeHeader := func(str io.Reader) map[string][]string {

View file

@ -503,7 +503,7 @@ func (s *Server) handleRequest(conn quic.Connection, str quic.Stream, decoder *q
ctx = context.WithValue(ctx, ServerContextKey, s)
ctx = context.WithValue(ctx, http.LocalAddrContextKey, conn.LocalAddr())
req = req.WithContext(ctx)
r := newResponseWriter(str, s.logger)
r := newResponseWriter(str, conn, s.logger)
defer func() {
if !r.usedDataStream() {
r.Flush()