mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-04 20:57:36 +03:00
allow access to the underlying quic.Stream from a http.ResponseWriter
This commit is contained in:
parent
d1c5297c0b
commit
35939b25a9
5 changed files with 94 additions and 40 deletions
|
@ -385,6 +385,16 @@ var _ = Describe("Client", func() {
|
||||||
return fields
|
return fields
|
||||||
}
|
}
|
||||||
|
|
||||||
|
getResponse := func(status int) []byte {
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
rstr := mockquic.NewMockStream(mockCtrl)
|
||||||
|
rstr.EXPECT().Write(gomock.Any()).Do(buf.Write).AnyTimes()
|
||||||
|
rw := newResponseWriter(rstr, utils.DefaultLogger)
|
||||||
|
rw.WriteHeader(status)
|
||||||
|
rw.Flush()
|
||||||
|
return buf.Bytes()
|
||||||
|
}
|
||||||
|
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
settingsFrameWritten = make(chan struct{})
|
settingsFrameWritten = make(chan struct{})
|
||||||
controlStr := mockquic.NewMockStream(mockCtrl)
|
controlStr := mockquic.NewMockStream(mockCtrl)
|
||||||
|
@ -441,11 +451,7 @@ var _ = Describe("Client", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
It("returns a response", func() {
|
It("returns a response", func() {
|
||||||
rspBuf := &bytes.Buffer{}
|
rspBuf := bytes.NewBuffer(getResponse(418))
|
||||||
rw := newResponseWriter(rspBuf, utils.DefaultLogger)
|
|
||||||
rw.WriteHeader(418)
|
|
||||||
rw.Flush()
|
|
||||||
|
|
||||||
gomock.InOrder(
|
gomock.InOrder(
|
||||||
sess.EXPECT().HandshakeComplete().Return(handshakeCtx),
|
sess.EXPECT().HandshakeComplete().Return(handshakeCtx),
|
||||||
sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil),
|
sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil),
|
||||||
|
@ -453,9 +459,7 @@ var _ = Describe("Client", func() {
|
||||||
)
|
)
|
||||||
str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil })
|
str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil })
|
||||||
str.EXPECT().Close()
|
str.EXPECT().Close()
|
||||||
str.EXPECT().Read(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
|
str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes()
|
||||||
return rspBuf.Read(p)
|
|
||||||
}).AnyTimes()
|
|
||||||
rsp, err := client.RoundTrip(request)
|
rsp, err := client.RoundTrip(request)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(rsp.Proto).To(Equal("HTTP/3"))
|
Expect(rsp.Proto).To(Equal("HTTP/3"))
|
||||||
|
@ -590,10 +594,7 @@ var _ = Describe("Client", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
It("cancels a request after the response arrived", func() {
|
It("cancels a request after the response arrived", func() {
|
||||||
rspBuf := &bytes.Buffer{}
|
rspBuf := bytes.NewBuffer(getResponse(404))
|
||||||
rw := newResponseWriter(rspBuf, utils.DefaultLogger)
|
|
||||||
rw.WriteHeader(418)
|
|
||||||
rw.Flush()
|
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
req := request.WithContext(ctx)
|
req := request.WithContext(ctx)
|
||||||
|
@ -656,7 +657,9 @@ var _ = Describe("Client", func() {
|
||||||
sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
|
sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
|
||||||
sess.EXPECT().ConnectionState().Return(quic.ConnectionState{})
|
sess.EXPECT().ConnectionState().Return(quic.ConnectionState{})
|
||||||
buf := &bytes.Buffer{}
|
buf := &bytes.Buffer{}
|
||||||
rw := newResponseWriter(buf, utils.DefaultLogger)
|
rstr := mockquic.NewMockStream(mockCtrl)
|
||||||
|
rstr.EXPECT().Write(gomock.Any()).Do(buf.Write).AnyTimes()
|
||||||
|
rw := newResponseWriter(rstr, utils.DefaultLogger)
|
||||||
rw.Header().Set("Content-Encoding", "gzip")
|
rw.Header().Set("Content-Encoding", "gzip")
|
||||||
gz := gzip.NewWriter(rw)
|
gz := gzip.NewWriter(rw)
|
||||||
gz.Write([]byte("gzipped response"))
|
gz.Write([]byte("gzipped response"))
|
||||||
|
@ -680,7 +683,9 @@ var _ = Describe("Client", func() {
|
||||||
sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
|
sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
|
||||||
sess.EXPECT().ConnectionState().Return(quic.ConnectionState{})
|
sess.EXPECT().ConnectionState().Return(quic.ConnectionState{})
|
||||||
buf := &bytes.Buffer{}
|
buf := &bytes.Buffer{}
|
||||||
rw := newResponseWriter(buf, utils.DefaultLogger)
|
rstr := mockquic.NewMockStream(mockCtrl)
|
||||||
|
rstr.EXPECT().Write(gomock.Any()).Do(buf.Write).AnyTimes()
|
||||||
|
rw := newResponseWriter(rstr, utils.DefaultLogger)
|
||||||
rw.Write([]byte("not gzipped"))
|
rw.Write([]byte("not gzipped"))
|
||||||
rw.Flush()
|
rw.Flush()
|
||||||
str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil })
|
str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil })
|
||||||
|
|
|
@ -3,21 +3,33 @@ package http3
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"bytes"
|
"bytes"
|
||||||
"io"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/lucas-clemente/quic-go"
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||||
"github.com/marten-seemann/qpack"
|
"github.com/marten-seemann/qpack"
|
||||||
)
|
)
|
||||||
|
|
||||||
type responseWriter struct {
|
// DataStreamer lets the caller take over the stream. After a call to DataStream
|
||||||
stream *bufio.Writer
|
// the HTTP server library will not do anything else with the connection.
|
||||||
|
//
|
||||||
|
// It becomes the caller's responsibility to manage and close the stream.
|
||||||
|
//
|
||||||
|
// After a call to DataStream, the original Request.Body must not be used.
|
||||||
|
type DataStreamer interface {
|
||||||
|
DataStream() quic.Stream
|
||||||
|
}
|
||||||
|
|
||||||
header http.Header
|
type responseWriter struct {
|
||||||
status int // status code passed to WriteHeader
|
stream quic.Stream // needed for DataStream()
|
||||||
headerWritten bool
|
bufferedStream *bufio.Writer
|
||||||
|
|
||||||
|
header http.Header
|
||||||
|
status int // status code passed to WriteHeader
|
||||||
|
headerWritten bool
|
||||||
|
dataStreamUsed bool // set when DataSteam() is called
|
||||||
|
|
||||||
logger utils.Logger
|
logger utils.Logger
|
||||||
}
|
}
|
||||||
|
@ -25,13 +37,15 @@ type responseWriter struct {
|
||||||
var (
|
var (
|
||||||
_ http.ResponseWriter = &responseWriter{}
|
_ http.ResponseWriter = &responseWriter{}
|
||||||
_ http.Flusher = &responseWriter{}
|
_ http.Flusher = &responseWriter{}
|
||||||
|
_ DataStreamer = &responseWriter{}
|
||||||
)
|
)
|
||||||
|
|
||||||
func newResponseWriter(stream io.Writer, logger utils.Logger) *responseWriter {
|
func newResponseWriter(stream quic.Stream, logger utils.Logger) *responseWriter {
|
||||||
return &responseWriter{
|
return &responseWriter{
|
||||||
header: http.Header{},
|
header: http.Header{},
|
||||||
stream: bufio.NewWriter(stream),
|
stream: stream,
|
||||||
logger: logger,
|
bufferedStream: bufio.NewWriter(stream),
|
||||||
|
logger: logger,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -59,10 +73,10 @@ func (w *responseWriter) WriteHeader(status int) {
|
||||||
buf := &bytes.Buffer{}
|
buf := &bytes.Buffer{}
|
||||||
(&headersFrame{Length: uint64(headers.Len())}).Write(buf)
|
(&headersFrame{Length: uint64(headers.Len())}).Write(buf)
|
||||||
w.logger.Infof("Responding with %d", status)
|
w.logger.Infof("Responding with %d", status)
|
||||||
if _, err := w.stream.Write(buf.Bytes()); err != nil {
|
if _, err := w.bufferedStream.Write(buf.Bytes()); err != nil {
|
||||||
w.logger.Errorf("could not write headers frame: %s", err.Error())
|
w.logger.Errorf("could not write headers frame: %s", err.Error())
|
||||||
}
|
}
|
||||||
if _, err := w.stream.Write(headers.Bytes()); err != nil {
|
if _, err := w.bufferedStream.Write(headers.Bytes()); err != nil {
|
||||||
w.logger.Errorf("could not write header frame payload: %s", err.Error())
|
w.logger.Errorf("could not write header frame payload: %s", err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -77,18 +91,28 @@ func (w *responseWriter) Write(p []byte) (int, error) {
|
||||||
df := &dataFrame{Length: uint64(len(p))}
|
df := &dataFrame{Length: uint64(len(p))}
|
||||||
buf := &bytes.Buffer{}
|
buf := &bytes.Buffer{}
|
||||||
df.Write(buf)
|
df.Write(buf)
|
||||||
if _, err := w.stream.Write(buf.Bytes()); err != nil {
|
if _, err := w.bufferedStream.Write(buf.Bytes()); err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
return w.stream.Write(p)
|
return w.bufferedStream.Write(p)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *responseWriter) Flush() {
|
func (w *responseWriter) Flush() {
|
||||||
if err := w.stream.Flush(); err != nil {
|
if err := w.bufferedStream.Flush(); err != nil {
|
||||||
w.logger.Errorf("could not flush to stream: %s", err.Error())
|
w.logger.Errorf("could not flush to stream: %s", err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (w *responseWriter) usedDataStream() bool {
|
||||||
|
return w.dataStreamUsed
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *responseWriter) DataStream() quic.Stream {
|
||||||
|
w.dataStreamUsed = true
|
||||||
|
w.Flush()
|
||||||
|
return w.stream
|
||||||
|
}
|
||||||
|
|
||||||
// copied from http2/http2.go
|
// copied from http2/http2.go
|
||||||
// bodyAllowedForStatus reports whether a given response status code
|
// bodyAllowedForStatus reports whether a given response status code
|
||||||
// permits a body. See RFC 2616, section 4.4.
|
// permits a body. See RFC 2616, section 4.4.
|
||||||
|
|
|
@ -5,7 +5,10 @@ import (
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
|
mockquic "github.com/lucas-clemente/quic-go/internal/mocks/quic"
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||||
|
|
||||||
|
"github.com/golang/mock/gomock"
|
||||||
"github.com/marten-seemann/qpack"
|
"github.com/marten-seemann/qpack"
|
||||||
|
|
||||||
. "github.com/onsi/ginkgo"
|
. "github.com/onsi/ginkgo"
|
||||||
|
@ -20,7 +23,9 @@ var _ = Describe("Response Writer", func() {
|
||||||
|
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
strBuf = &bytes.Buffer{}
|
strBuf = &bytes.Buffer{}
|
||||||
rw = newResponseWriter(strBuf, utils.DefaultLogger)
|
str := mockquic.NewMockStream(mockCtrl)
|
||||||
|
str.EXPECT().Write(gomock.Any()).Do(strBuf.Write).AnyTimes()
|
||||||
|
rw = newResponseWriter(str, utils.DefaultLogger)
|
||||||
})
|
})
|
||||||
|
|
||||||
decodeHeader := func(str io.Reader) map[string][]string {
|
decodeHeader := func(str io.Reader) map[string][]string {
|
||||||
|
|
|
@ -367,8 +367,12 @@ func (s *Server) handleRequest(sess quic.Session, str quic.Stream, decoder *qpac
|
||||||
ctx = context.WithValue(ctx, ServerContextKey, s)
|
ctx = context.WithValue(ctx, ServerContextKey, s)
|
||||||
ctx = context.WithValue(ctx, http.LocalAddrContextKey, sess.LocalAddr())
|
ctx = context.WithValue(ctx, http.LocalAddrContextKey, sess.LocalAddr())
|
||||||
req = req.WithContext(ctx)
|
req = req.WithContext(ctx)
|
||||||
responseWriter := newResponseWriter(str, s.logger)
|
r := newResponseWriter(str, s.logger)
|
||||||
defer responseWriter.Flush()
|
defer func() {
|
||||||
|
if !r.usedDataStream() {
|
||||||
|
r.Flush()
|
||||||
|
}
|
||||||
|
}()
|
||||||
handler := s.Handler
|
handler := s.Handler
|
||||||
if handler == nil {
|
if handler == nil {
|
||||||
handler = http.DefaultServeMux
|
handler = http.DefaultServeMux
|
||||||
|
@ -386,17 +390,18 @@ func (s *Server) handleRequest(sess quic.Session, str quic.Stream, decoder *qpac
|
||||||
panicked = true
|
panicked = true
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
handler.ServeHTTP(responseWriter, req)
|
handler.ServeHTTP(r, req)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if panicked {
|
if !r.usedDataStream() {
|
||||||
responseWriter.WriteHeader(500)
|
if panicked {
|
||||||
} else {
|
r.WriteHeader(500)
|
||||||
responseWriter.WriteHeader(200)
|
} else {
|
||||||
|
r.WriteHeader(200)
|
||||||
|
}
|
||||||
|
// If the EOF was read by the handler, CancelRead() is a no-op.
|
||||||
|
str.CancelRead(quic.ErrorCode(errorNoError))
|
||||||
}
|
}
|
||||||
|
|
||||||
// If the EOF was read by the handler, CancelRead() is a no-op.
|
|
||||||
str.CancelRead(quic.ErrorCode(errorNoError))
|
|
||||||
return requestError{}
|
return requestError{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -177,6 +177,21 @@ var _ = Describe("Server", func() {
|
||||||
Expect(hfs).To(HaveKeyWithValue(":status", []string{"500"}))
|
Expect(hfs).To(HaveKeyWithValue(":status", []string{"500"}))
|
||||||
})
|
})
|
||||||
|
|
||||||
|
It("doesn't close the stream if the handler called DataStream()", func() {
|
||||||
|
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
str := w.(DataStreamer).DataStream()
|
||||||
|
str.Write([]byte("foobar"))
|
||||||
|
})
|
||||||
|
|
||||||
|
setRequest(encodeRequest(exampleGetRequest))
|
||||||
|
str.EXPECT().Context().Return(reqContext)
|
||||||
|
str.EXPECT().Write([]byte("foobar"))
|
||||||
|
// don't EXPECT CancelRead()
|
||||||
|
|
||||||
|
serr := s.handleRequest(sess, str, qpackDecoder, nil)
|
||||||
|
Expect(serr.err).ToNot(HaveOccurred())
|
||||||
|
})
|
||||||
|
|
||||||
Context("control stream handling", func() {
|
Context("control stream handling", func() {
|
||||||
var sess *mockquic.MockEarlySession
|
var sess *mockquic.MockEarlySession
|
||||||
testDone := make(chan struct{})
|
testDone := make(chan struct{})
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue