diff --git a/stream.go b/stream.go index 9ee9b3a3..1348eb9c 100644 --- a/stream.go +++ b/stream.go @@ -50,6 +50,7 @@ type stream struct { finSent utils.AtomicBool rstSent utils.AtomicBool writeChan chan struct{} + writeDeadline time.Time flowControlManager flowcontrol.FlowControlManager } @@ -168,38 +169,48 @@ func (s *stream) Read(p []byte) (int, error) { func (s *stream) Write(p []byte) (int, error) { s.mutex.Lock() + defer s.mutex.Unlock() + if s.resetLocally.Get() || s.err != nil { - err := s.err - s.mutex.Unlock() - return 0, err + return 0, s.err } if len(p) == 0 { - s.mutex.Unlock() return 0, nil } s.dataForWriting = make([]byte, len(p)) copy(s.dataForWriting, p) s.onData() - s.mutex.Unlock() + var err error for { - s.mutex.Lock() - if s.dataForWriting == nil || s.err != nil { - s.mutex.Unlock() + deadline := s.writeDeadline + if !deadline.IsZero() && !time.Now().Before(deadline) { + err = errDeadline break } + if s.dataForWriting == nil || s.err != nil { + break + } + s.mutex.Unlock() - <-s.writeChan + if deadline.IsZero() { + <-s.writeChan + } else { + select { + case <-s.writeChan: + case <-time.After(deadline.Sub(time.Now())): + } + } + s.mutex.Lock() } - s.mutex.Lock() - defer s.mutex.Unlock() - + if err != nil { + return 0, err + } if s.err != nil { return len(p) - len(s.dataForWriting), s.err } - return len(p), nil } @@ -308,6 +319,31 @@ func (s *stream) SetReadDeadline(t time.Time) error { return nil } +// SetWriteDeadline sets the deadline for future Write calls +// and any currently-blocked Write call. +// Even if write times out, it may return n > 0, indicating that +// some of the data was successfully written. +// A zero value for t means Write will not time out. +func (s *stream) SetWriteDeadline(t time.Time) error { + s.mutex.Lock() + oldDeadline := s.writeDeadline + s.writeDeadline = t + s.mutex.Unlock() + if t.Before(oldDeadline) { + s.signalWrite() + } + return nil +} + +// SetDeadline sets the read and write deadlines associated +// with the connection. It is equivalent to calling both +// SetReadDeadline and SetWriteDeadline. +func (s *stream) SetDeadline(t time.Time) error { + _ = s.SetReadDeadline(t) // SetReadDeadline never errors + _ = s.SetWriteDeadline(t) // SetWriteDeadline never errors + return nil +} + // CloseRemote makes the stream receive a "virtual" FIN stream frame at a given offset func (s *stream) CloseRemote(offset protocol.ByteCount) { s.AddStreamFrame(&frames.StreamFrame{FinBit: true, Offset: offset}) diff --git a/stream_test.go b/stream_test.go index 6b7cdf72..c2a1bb30 100644 --- a/stream_test.go +++ b/stream_test.go @@ -252,7 +252,7 @@ var _ = Describe("Stream", func() { Expect(n).To(BeZero()) }) - It("unblocks read after the deadline", func() { + It("unblocks after the deadline", func() { deadline := time.Now().Add(200 * time.Millisecond) str.SetReadDeadline(deadline) b := make([]byte, 6) @@ -298,6 +298,18 @@ var _ = Describe("Stream", func() { Expect(err).To(MatchError(errDeadline)) Expect(time.Now()).To(BeTemporally("~", deadline2, 50*time.Millisecond)) }) + + It("sets a read deadline, when SetDeadline is called", func() { + mockFcm.EXPECT().UpdateHighestReceived(streamID, protocol.ByteCount(6)).AnyTimes() + f := &frames.StreamFrame{Data: []byte("foobar")} + err := str.AddStreamFrame(f) + Expect(err).ToNot(HaveOccurred()) + str.SetDeadline(time.Now().Add(-time.Second)) + b := make([]byte, 6) + n, err := str.Read(b) + Expect(err).To(MatchError(errDeadline)) + Expect(n).To(BeZero()) + }) }) Context("closing", func() { @@ -763,6 +775,66 @@ var _ = Describe("Stream", func() { Expect(err).ToNot(HaveOccurred()) }) + Context("deadlines", func() { + It("returns an error when Write is called after the deadline", func() { + str.SetWriteDeadline(time.Now().Add(-time.Second)) + n, err := str.Write([]byte("foobar")) + Expect(err).To(MatchError(errDeadline)) + Expect(n).To(BeZero()) + }) + + It("unblocks after the deadline", func() { + deadline := time.Now().Add(200 * time.Millisecond) + str.SetWriteDeadline(deadline) + n, err := str.Write([]byte("foobar")) + Expect(err).To(MatchError(errDeadline)) + Expect(n).To(BeZero()) + Expect(time.Now()).To(BeTemporally("~", deadline, 50*time.Millisecond)) + }) + + It("doesn't unblock if the deadline is changed before the first one expires", func() { + deadline1 := time.Now().Add(200 * time.Millisecond) + deadline2 := time.Now().Add(400 * time.Millisecond) + str.SetWriteDeadline(deadline1) + go func() { + defer GinkgoRecover() + time.Sleep(50 * time.Millisecond) + str.SetWriteDeadline(deadline2) + // make sure that this was actually execute before the deadline expires + Expect(time.Now()).To(BeTemporally("<", deadline1)) + }() + runtime.Gosched() + n, err := str.Write([]byte("foobar")) + Expect(err).To(MatchError(errDeadline)) + Expect(n).To(BeZero()) + Expect(time.Now()).To(BeTemporally("~", deadline2, 50*time.Millisecond)) + }) + + It("unblocks earlier, when a new deadline is set", func() { + deadline1 := time.Now().Add(1200 * time.Millisecond) + deadline2 := time.Now().Add(300 * time.Millisecond) + go func() { + defer GinkgoRecover() + time.Sleep(50 * time.Millisecond) + str.SetWriteDeadline(deadline2) + // make sure that this was actually execute before the deadline expires + Expect(time.Now()).To(BeTemporally("<", deadline2)) + }() + str.SetWriteDeadline(deadline1) + runtime.Gosched() + _, err := str.Write([]byte("foobar")) + Expect(err).To(MatchError(errDeadline)) + Expect(time.Now()).To(BeTemporally("~", deadline2, 50*time.Millisecond)) + }) + + It("sets a read deadline, when SetDeadline is called", func() { + str.SetDeadline(time.Now().Add(-time.Second)) + n, err := str.Write([]byte("foobar")) + Expect(err).To(MatchError(errDeadline)) + Expect(n).To(BeZero()) + }) + }) + Context("closing", func() { It("sets finishedWriting when calling Close", func() { str.Close()