mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-04 20:57:36 +03:00
make the responseWriter hijackable
This commit is contained in:
parent
a983db0301
commit
ff6313fdb3
4 changed files with 13 additions and 6 deletions
|
@ -429,7 +429,7 @@ var _ = Describe("Client", func() {
|
||||||
buf := &bytes.Buffer{}
|
buf := &bytes.Buffer{}
|
||||||
rstr := mockquic.NewMockStream(mockCtrl)
|
rstr := mockquic.NewMockStream(mockCtrl)
|
||||||
rstr.EXPECT().Write(gomock.Any()).Do(buf.Write).AnyTimes()
|
rstr.EXPECT().Write(gomock.Any()).Do(buf.Write).AnyTimes()
|
||||||
rw := newResponseWriter(rstr, utils.DefaultLogger)
|
rw := newResponseWriter(rstr, nil, utils.DefaultLogger)
|
||||||
rw.WriteHeader(status)
|
rw.WriteHeader(status)
|
||||||
rw.Flush()
|
rw.Flush()
|
||||||
return buf.Bytes()
|
return buf.Bytes()
|
||||||
|
@ -717,7 +717,7 @@ var _ = Describe("Client", func() {
|
||||||
buf := &bytes.Buffer{}
|
buf := &bytes.Buffer{}
|
||||||
rstr := mockquic.NewMockStream(mockCtrl)
|
rstr := mockquic.NewMockStream(mockCtrl)
|
||||||
rstr.EXPECT().Write(gomock.Any()).Do(buf.Write).AnyTimes()
|
rstr.EXPECT().Write(gomock.Any()).Do(buf.Write).AnyTimes()
|
||||||
rw := newResponseWriter(rstr, utils.DefaultLogger)
|
rw := newResponseWriter(rstr, nil, utils.DefaultLogger)
|
||||||
rw.Header().Set("Content-Encoding", "gzip")
|
rw.Header().Set("Content-Encoding", "gzip")
|
||||||
gz := gzip.NewWriter(rw)
|
gz := gzip.NewWriter(rw)
|
||||||
gz.Write([]byte("gzipped response"))
|
gz.Write([]byte("gzipped response"))
|
||||||
|
@ -743,7 +743,7 @@ var _ = Describe("Client", func() {
|
||||||
buf := &bytes.Buffer{}
|
buf := &bytes.Buffer{}
|
||||||
rstr := mockquic.NewMockStream(mockCtrl)
|
rstr := mockquic.NewMockStream(mockCtrl)
|
||||||
rstr.EXPECT().Write(gomock.Any()).Do(buf.Write).AnyTimes()
|
rstr.EXPECT().Write(gomock.Any()).Do(buf.Write).AnyTimes()
|
||||||
rw := newResponseWriter(rstr, utils.DefaultLogger)
|
rw := newResponseWriter(rstr, nil, utils.DefaultLogger)
|
||||||
rw.Write([]byte("not gzipped"))
|
rw.Write([]byte("not gzipped"))
|
||||||
rw.Flush()
|
rw.Flush()
|
||||||
str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil })
|
str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil })
|
||||||
|
|
|
@ -23,6 +23,7 @@ type DataStreamer interface {
|
||||||
}
|
}
|
||||||
|
|
||||||
type responseWriter struct {
|
type responseWriter struct {
|
||||||
|
conn quic.Connection
|
||||||
stream quic.Stream // needed for DataStream()
|
stream quic.Stream // needed for DataStream()
|
||||||
bufferedStream *bufio.Writer
|
bufferedStream *bufio.Writer
|
||||||
|
|
||||||
|
@ -38,12 +39,14 @@ var (
|
||||||
_ http.ResponseWriter = &responseWriter{}
|
_ http.ResponseWriter = &responseWriter{}
|
||||||
_ http.Flusher = &responseWriter{}
|
_ http.Flusher = &responseWriter{}
|
||||||
_ DataStreamer = &responseWriter{}
|
_ DataStreamer = &responseWriter{}
|
||||||
|
_ Hijacker = &responseWriter{}
|
||||||
)
|
)
|
||||||
|
|
||||||
func newResponseWriter(stream quic.Stream, logger utils.Logger) *responseWriter {
|
func newResponseWriter(stream quic.Stream, conn quic.Connection, logger utils.Logger) *responseWriter {
|
||||||
return &responseWriter{
|
return &responseWriter{
|
||||||
header: http.Header{},
|
header: http.Header{},
|
||||||
stream: stream,
|
stream: stream,
|
||||||
|
conn: conn,
|
||||||
bufferedStream: bufio.NewWriter(stream),
|
bufferedStream: bufio.NewWriter(stream),
|
||||||
logger: logger,
|
logger: logger,
|
||||||
}
|
}
|
||||||
|
@ -123,6 +126,10 @@ func (w *responseWriter) StreamID() quic.StreamID {
|
||||||
return w.stream.StreamID()
|
return w.stream.StreamID()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (w *responseWriter) StreamCreator() StreamCreator {
|
||||||
|
return w.conn
|
||||||
|
}
|
||||||
|
|
||||||
// copied from http2/http2.go
|
// copied from http2/http2.go
|
||||||
// bodyAllowedForStatus reports whether a given response status code
|
// bodyAllowedForStatus reports whether a given response status code
|
||||||
// permits a body. See RFC 2616, section 4.4.
|
// permits a body. See RFC 2616, section 4.4.
|
||||||
|
|
|
@ -25,7 +25,7 @@ var _ = Describe("Response Writer", func() {
|
||||||
strBuf = &bytes.Buffer{}
|
strBuf = &bytes.Buffer{}
|
||||||
str := mockquic.NewMockStream(mockCtrl)
|
str := mockquic.NewMockStream(mockCtrl)
|
||||||
str.EXPECT().Write(gomock.Any()).DoAndReturn(strBuf.Write).AnyTimes()
|
str.EXPECT().Write(gomock.Any()).DoAndReturn(strBuf.Write).AnyTimes()
|
||||||
rw = newResponseWriter(str, utils.DefaultLogger)
|
rw = newResponseWriter(str, nil, utils.DefaultLogger)
|
||||||
})
|
})
|
||||||
|
|
||||||
decodeHeader := func(str io.Reader) map[string][]string {
|
decodeHeader := func(str io.Reader) map[string][]string {
|
||||||
|
|
|
@ -503,7 +503,7 @@ func (s *Server) handleRequest(conn quic.Connection, str quic.Stream, decoder *q
|
||||||
ctx = context.WithValue(ctx, ServerContextKey, s)
|
ctx = context.WithValue(ctx, ServerContextKey, s)
|
||||||
ctx = context.WithValue(ctx, http.LocalAddrContextKey, conn.LocalAddr())
|
ctx = context.WithValue(ctx, http.LocalAddrContextKey, conn.LocalAddr())
|
||||||
req = req.WithContext(ctx)
|
req = req.WithContext(ctx)
|
||||||
r := newResponseWriter(str, s.logger)
|
r := newResponseWriter(str, conn, s.logger)
|
||||||
defer func() {
|
defer func() {
|
||||||
if !r.usedDataStream() {
|
if !r.usedDataStream() {
|
||||||
r.Flush()
|
r.Flush()
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue