mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-03 20:27:35 +03:00
make the responseWriter hijackable
This commit is contained in:
parent
a983db0301
commit
ff6313fdb3
4 changed files with 13 additions and 6 deletions
|
@ -429,7 +429,7 @@ var _ = Describe("Client", func() {
|
|||
buf := &bytes.Buffer{}
|
||||
rstr := mockquic.NewMockStream(mockCtrl)
|
||||
rstr.EXPECT().Write(gomock.Any()).Do(buf.Write).AnyTimes()
|
||||
rw := newResponseWriter(rstr, utils.DefaultLogger)
|
||||
rw := newResponseWriter(rstr, nil, utils.DefaultLogger)
|
||||
rw.WriteHeader(status)
|
||||
rw.Flush()
|
||||
return buf.Bytes()
|
||||
|
@ -717,7 +717,7 @@ var _ = Describe("Client", func() {
|
|||
buf := &bytes.Buffer{}
|
||||
rstr := mockquic.NewMockStream(mockCtrl)
|
||||
rstr.EXPECT().Write(gomock.Any()).Do(buf.Write).AnyTimes()
|
||||
rw := newResponseWriter(rstr, utils.DefaultLogger)
|
||||
rw := newResponseWriter(rstr, nil, utils.DefaultLogger)
|
||||
rw.Header().Set("Content-Encoding", "gzip")
|
||||
gz := gzip.NewWriter(rw)
|
||||
gz.Write([]byte("gzipped response"))
|
||||
|
@ -743,7 +743,7 @@ var _ = Describe("Client", func() {
|
|||
buf := &bytes.Buffer{}
|
||||
rstr := mockquic.NewMockStream(mockCtrl)
|
||||
rstr.EXPECT().Write(gomock.Any()).Do(buf.Write).AnyTimes()
|
||||
rw := newResponseWriter(rstr, utils.DefaultLogger)
|
||||
rw := newResponseWriter(rstr, nil, utils.DefaultLogger)
|
||||
rw.Write([]byte("not gzipped"))
|
||||
rw.Flush()
|
||||
str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil })
|
||||
|
|
|
@ -23,6 +23,7 @@ type DataStreamer interface {
|
|||
}
|
||||
|
||||
type responseWriter struct {
|
||||
conn quic.Connection
|
||||
stream quic.Stream // needed for DataStream()
|
||||
bufferedStream *bufio.Writer
|
||||
|
||||
|
@ -38,12 +39,14 @@ var (
|
|||
_ http.ResponseWriter = &responseWriter{}
|
||||
_ http.Flusher = &responseWriter{}
|
||||
_ DataStreamer = &responseWriter{}
|
||||
_ Hijacker = &responseWriter{}
|
||||
)
|
||||
|
||||
func newResponseWriter(stream quic.Stream, logger utils.Logger) *responseWriter {
|
||||
func newResponseWriter(stream quic.Stream, conn quic.Connection, logger utils.Logger) *responseWriter {
|
||||
return &responseWriter{
|
||||
header: http.Header{},
|
||||
stream: stream,
|
||||
conn: conn,
|
||||
bufferedStream: bufio.NewWriter(stream),
|
||||
logger: logger,
|
||||
}
|
||||
|
@ -123,6 +126,10 @@ func (w *responseWriter) StreamID() quic.StreamID {
|
|||
return w.stream.StreamID()
|
||||
}
|
||||
|
||||
func (w *responseWriter) StreamCreator() StreamCreator {
|
||||
return w.conn
|
||||
}
|
||||
|
||||
// copied from http2/http2.go
|
||||
// bodyAllowedForStatus reports whether a given response status code
|
||||
// permits a body. See RFC 2616, section 4.4.
|
||||
|
|
|
@ -25,7 +25,7 @@ var _ = Describe("Response Writer", func() {
|
|||
strBuf = &bytes.Buffer{}
|
||||
str := mockquic.NewMockStream(mockCtrl)
|
||||
str.EXPECT().Write(gomock.Any()).DoAndReturn(strBuf.Write).AnyTimes()
|
||||
rw = newResponseWriter(str, utils.DefaultLogger)
|
||||
rw = newResponseWriter(str, nil, utils.DefaultLogger)
|
||||
})
|
||||
|
||||
decodeHeader := func(str io.Reader) map[string][]string {
|
||||
|
|
|
@ -503,7 +503,7 @@ func (s *Server) handleRequest(conn quic.Connection, str quic.Stream, decoder *q
|
|||
ctx = context.WithValue(ctx, ServerContextKey, s)
|
||||
ctx = context.WithValue(ctx, http.LocalAddrContextKey, conn.LocalAddr())
|
||||
req = req.WithContext(ctx)
|
||||
r := newResponseWriter(str, s.logger)
|
||||
r := newResponseWriter(str, conn, s.logger)
|
||||
defer func() {
|
||||
if !r.usedDataStream() {
|
||||
r.Flush()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue