diff --git a/example/main.go b/example/main.go index 96b41196..acf81d79 100644 --- a/example/main.go +++ b/example/main.go @@ -45,6 +45,9 @@ type responseWriter struct { header http.Header headerWritten bool + + bytesWritten int + contentLength int } func (w *responseWriter) Header() http.Header { @@ -64,13 +67,15 @@ func (w *responseWriter) WriteHeader(status int) { enc.WriteField(hpack.HeaderField{Name: k, Value: v[0]}) } - fmt.Printf("responding with %d %#v\n", status, w.header) + fmt.Printf("Responding with %d %#v\n", status, w.header) h2framer := http2.NewFramer(w.headerStream, nil) h2framer.WriteHeaders(http2.HeadersFrameParam{ StreamID: uint32(w.dataStreamID), EndHeaders: true, BlockFragment: headers.Bytes(), }) + + w.contentLength, _ = strconv.Atoi(w.header.Get("content-length")) } func (w *responseWriter) Write(p []byte) (int, error) { @@ -79,12 +84,22 @@ func (w *responseWriter) Write(p []byte) (int, error) { } if len(p) != 0 { - dataStream, err := w.session.NewStream(w.dataStreamID) - if err != nil { - return 0, fmt.Errorf("error creating data stream: %s\n", err.Error()) + if w.dataStream == nil { + var err error + w.dataStream, err = w.session.NewStream(w.dataStreamID) + if err != nil { + return 0, fmt.Errorf("error creating data stream: %s\n", err.Error()) + } } - defer dataStream.Close() - return dataStream.Write(p) + + n, err := w.dataStream.Write(p) + w.bytesWritten += n + + if w.bytesWritten >= w.contentLength { + defer w.dataStream.Close() + } + + return n, err } return 0, nil