diff --git a/stream.go b/stream.go index b5afc051..7b22cb73 100644 --- a/stream.go +++ b/stream.go @@ -158,13 +158,20 @@ func (s *stream) Write(p []byte) (int, error) { func (s *stream) lenOfDataForWriting() protocol.ByteCount { s.mutex.Lock() - l := protocol.ByteCount(len(s.dataForWriting)) + var l protocol.ByteCount + if s.err == nil { + l = protocol.ByteCount(len(s.dataForWriting)) + } s.mutex.Unlock() return l } func (s *stream) getDataForWriting(maxBytes protocol.ByteCount) []byte { s.mutex.Lock() + if s.err != nil { + s.mutex.Unlock() + return nil + } if s.dataForWriting == nil { s.mutex.Unlock() return nil @@ -207,7 +214,6 @@ func (s *stream) sentFin() { func (s *stream) AddStreamFrame(frame *frames.StreamFrame) error { maxOffset := frame.Offset + frame.DataLen() err := s.flowControlManager.UpdateHighestReceived(s.streamID, maxOffset) - if err != nil { return err } diff --git a/stream_test.go b/stream_test.go index 386cf095..c5ca23a6 100644 --- a/stream_test.go +++ b/stream_test.go @@ -548,6 +548,19 @@ var _ = Describe("Stream", func() { Expect(n).To(BeZero()) Expect(err).To(MatchError(testErr)) }) + + It("doesn't get data for writing if an error occurred", func() { + go func() { + _, err := str.Write([]byte("foobar")) + Expect(err).To(MatchError(testErr)) + }() + Eventually(func() []byte { return str.dataForWriting }).ShouldNot(BeNil()) + Expect(str.lenOfDataForWriting()).ToNot(BeZero()) + str.RegisterError(testErr) + data := str.getDataForWriting(6) + Expect(data).To(BeNil()) + Expect(str.lenOfDataForWriting()).To(BeZero()) + }) }) Context("when CloseRemote is called", func() {