diff --git a/receive_stream.go b/receive_stream.go index 13674e8a..17512520 100644 --- a/receive_stream.go +++ b/receive_stream.go @@ -8,6 +8,7 @@ import ( "github.com/lucas-clemente/quic-go/internal/flowcontrol" "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/internal/wire" ) @@ -43,9 +44,8 @@ type receiveStream struct { canceledRead bool // set when CancelRead() is called resetRemotely bool // set when HandleResetStreamFrame() is called - readChan chan struct{} - deadline time.Time - deadlineTimer *time.Timer // initialized by SetReadDeadline() + readChan chan struct{} + deadline time.Time flowController flowcontrol.StreamFlowController version protocol.VersionNumber @@ -109,6 +109,7 @@ func (s *receiveStream) readImpl(p []byte) (bool /*stream completed */, int, err return false, bytesRead, s.closeForShutdownErr } + var deadlineTimer *utils.Timer for { // Stop waiting on errors if s.closedForShutdown { @@ -121,8 +122,15 @@ func (s *receiveStream) readImpl(p []byte) (bool /*stream completed */, int, err return false, bytesRead, s.resetRemotelyErr } - if !s.deadline.IsZero() && !time.Now().Before(s.deadline) { - return false, bytesRead, errDeadline + deadline := s.deadline + if !deadline.IsZero() { + if !time.Now().Before(deadline) { + return false, bytesRead, errDeadline + } + if deadlineTimer == nil { + deadlineTimer = utils.NewTimer() + } + deadlineTimer.Reset(deadline) } if s.currentFrame != nil || s.currentFrameIsLast { @@ -130,12 +138,13 @@ func (s *receiveStream) readImpl(p []byte) (bool /*stream completed */, int, err } s.mutex.Unlock() - if s.deadline.IsZero() { + if deadline.IsZero() { <-s.readChan } else { select { case <-s.readChan: - case <-s.deadlineTimer.C: + case <-deadlineTimer.Chan(): + deadlineTimer.SetRead() } } s.mutex.Lock() @@ -252,22 +261,9 @@ func (s *receiveStream) CloseRemote(offset protocol.ByteCount) { func (s *receiveStream) SetReadDeadline(t time.Time) error { s.mutex.Lock() - defer s.mutex.Unlock() s.deadline = t - if s.deadline.IsZero() { // skip if there's no deadline to set - s.signalRead() - return nil - } - // Lazily initialize the deadline timer. - if s.deadlineTimer == nil { - s.deadlineTimer = time.NewTimer(time.Until(t)) - return nil - } - // reset the timer to the new deadline - if !s.deadlineTimer.Stop() { - <-s.deadlineTimer.C - } - s.deadlineTimer.Reset(time.Until(t)) + s.mutex.Unlock() + s.signalRead() return nil } diff --git a/receive_stream_test.go b/receive_stream_test.go index 27543143..7707aba2 100644 --- a/receive_stream_test.go +++ b/receive_stream_test.go @@ -303,6 +303,33 @@ var _ = Describe("Receive Stream", func() { Expect(err).To(MatchError(errDeadline)) Expect(time.Now()).To(BeTemporally("~", deadline2, scaleDuration(25*time.Millisecond))) }) + + It("doesn't unblock if the deadline is removed", func() { + deadline := time.Now().Add(scaleDuration(50 * time.Millisecond)) + str.SetReadDeadline(deadline) + deadlineUnset := make(chan struct{}) + go func() { + defer GinkgoRecover() + time.Sleep(scaleDuration(20 * time.Millisecond)) + str.SetReadDeadline(time.Time{}) + // make sure that this was actually execute before the deadline expires + Expect(time.Now()).To(BeTemporally("<", deadline)) + close(deadlineUnset) + }() + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + _, err := strWithTimeout.Read(make([]byte, 1)) + Expect(err).To(MatchError("test done")) + close(done) + }() + runtime.Gosched() + Eventually(deadlineUnset).Should(BeClosed()) + Consistently(done, scaleDuration(100*time.Millisecond)).ShouldNot(BeClosed()) + // make the go routine return + str.closeForShutdown(errors.New("test done")) + Eventually(done).Should(BeClosed()) + }) }) Context("closing", func() { diff --git a/send_stream.go b/send_stream.go index ccb58a93..94883392 100644 --- a/send_stream.go +++ b/send_stream.go @@ -42,9 +42,8 @@ type sendStream struct { dataForWriting []byte - writeChan chan struct{} - deadline time.Time - deadlineTimer *time.Timer // initialized by SetReadDeadline() + writeChan chan struct{} + deadline time.Time flowController flowcontrol.StreamFlowController @@ -97,15 +96,23 @@ func (s *sendStream) Write(p []byte) (int, error) { s.dataForWriting = p - var bytesWritten int - var err error - var notifiedSender bool + var ( + deadlineTimer *utils.Timer + bytesWritten int + notifiedSender bool + ) for { bytesWritten = len(p) - len(s.dataForWriting) - if !s.deadline.IsZero() && !time.Now().Before(s.deadline) { - s.dataForWriting = nil - err = errDeadline - break + deadline := s.deadline + if !deadline.IsZero() { + if !time.Now().Before(deadline) { + s.dataForWriting = nil + return bytesWritten, errDeadline + } + if deadlineTimer == nil { + deadlineTimer = utils.NewTimer() + } + deadlineTimer.Reset(deadline) } if s.dataForWriting == nil || s.canceledWrite || s.closedForShutdown { break @@ -116,23 +123,24 @@ func (s *sendStream) Write(p []byte) (int, error) { s.sender.onHasStreamData(s.streamID) // must be called without holding the mutex notifiedSender = true } - if s.deadline.IsZero() { + if deadline.IsZero() { <-s.writeChan } else { select { case <-s.writeChan: - case <-s.deadlineTimer.C: + case <-deadlineTimer.Chan(): + deadlineTimer.SetRead() } } s.mutex.Lock() } if s.closeForShutdownErr != nil { - err = s.closeForShutdownErr + return bytesWritten, s.closeForShutdownErr } else if s.cancelWriteErr != nil { - err = s.cancelWriteErr + return bytesWritten, s.cancelWriteErr } - return bytesWritten, err + return bytesWritten, nil } // popStreamFrame returns the next STREAM frame that is supposed to be sent on this stream @@ -301,22 +309,9 @@ func (s *sendStream) Context() context.Context { func (s *sendStream) SetWriteDeadline(t time.Time) error { s.mutex.Lock() - defer s.mutex.Unlock() s.deadline = t - if s.deadline.IsZero() { // skip if there's no deadline to set - s.signalWrite() - return nil - } - // Lazily initialize the deadline timer. - if s.deadlineTimer == nil { - s.deadlineTimer = time.NewTimer(time.Until(t)) - return nil - } - // reset the timer to the new deadline - if !s.deadlineTimer.Stop() { - <-s.deadlineTimer.C - } - s.deadlineTimer.Reset(time.Until(t)) + s.mutex.Unlock() + s.signalWrite() return nil } diff --git a/send_stream_test.go b/send_stream_test.go index b0a578f3..c2955c29 100644 --- a/send_stream_test.go +++ b/send_stream_test.go @@ -360,6 +360,34 @@ var _ = Describe("Send Stream", func() { Expect(time.Now()).To(BeTemporally("~", deadline2, scaleDuration(20*time.Millisecond))) Eventually(done).Should(BeClosed()) }) + + It("doesn't unblock if the deadline is removed", func() { + mockSender.EXPECT().onHasStreamData(streamID) + deadline := time.Now().Add(scaleDuration(50 * time.Millisecond)) + str.SetWriteDeadline(deadline) + deadlineUnset := make(chan struct{}) + go func() { + defer GinkgoRecover() + time.Sleep(scaleDuration(20 * time.Millisecond)) + str.SetWriteDeadline(time.Time{}) + // make sure that this was actually execute before the deadline expires + Expect(time.Now()).To(BeTemporally("<", deadline)) + close(deadlineUnset) + }() + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + _, err := strWithTimeout.Write([]byte("foobar")) + Expect(err).To(MatchError("test done")) + close(done) + }() + runtime.Gosched() + Eventually(deadlineUnset).Should(BeClosed()) + Consistently(done, scaleDuration(100*time.Millisecond)).ShouldNot(BeClosed()) + // make the go routine return + str.closeForShutdown(errors.New("test done")) + Eventually(done).Should(BeClosed()) + }) }) Context("closing", func() {