implement a context for the stream

The context is cancelled when the write-side of the stream is closed.
This commit is contained in:
Marten Seemann 2017-07-29 09:18:33 +07:00
parent e02f5d5fbe
commit 8ef69143ba
5 changed files with 55 additions and 2 deletions

View file

@ -2,6 +2,7 @@ package h2quic
import (
"bytes"
"context"
"io"
"net/http"
"sync"
@ -37,6 +38,7 @@ func (s *mockStream) Close() error { s.closed = true; r
func (s *mockStream) Reset(error) { s.reset = true }
func (s *mockStream) CloseRemote(offset protocol.ByteCount) { s.remoteClosed = true }
func (s mockStream) StreamID() protocol.StreamID { return s.id }
func (s *mockStream) Context() context.Context { panic("not implemented") }
func (s *mockStream) SetDeadline(time.Time) error { panic("not implemented") }
func (s *mockStream) SetReadDeadline(time.Time) error { panic("not implemented") }
func (s *mockStream) SetWriteDeadline(time.Time) error { panic("not implemented") }

View file

@ -23,6 +23,10 @@ type Stream interface {
StreamID() protocol.StreamID
// Reset closes the stream with an error.
Reset(error)
// The context is canceled as soon as the write-side of the stream is closed.
// This happens when Close() is called, or when the stream is reset (either locally or remotely).
// Warning: This API should not be considered stable and might change soon.
Context() context.Context
// SetReadDeadline sets the deadline for future Read calls and
// any currently-blocked Read call.
// A zero value for t means Read will not time out.

View file

@ -1,6 +1,7 @@
package quic
import (
"context"
"fmt"
"io"
"net"
@ -19,6 +20,9 @@ import (
type stream struct {
mutex sync.Mutex
ctx context.Context
ctxCancel context.CancelFunc
streamID protocol.StreamID
onData func()
// onReset is a callback that should send a RST_STREAM
@ -55,6 +59,8 @@ type stream struct {
flowControlManager flowcontrol.FlowControlManager
}
var _ Stream = &stream{}
type deadlineError struct{}
func (deadlineError) Error() string { return "deadline exceeded" }
@ -68,7 +74,7 @@ func newStream(StreamID protocol.StreamID,
onData func(),
onReset func(protocol.StreamID, protocol.ByteCount),
flowControlManager flowcontrol.FlowControlManager) *stream {
return &stream{
s := &stream{
onData: onData,
onReset: onReset,
streamID: StreamID,
@ -77,6 +83,8 @@ func newStream(StreamID protocol.StreamID,
readChan: make(chan struct{}, 1),
writeChan: make(chan struct{}, 1),
}
s.ctx, s.ctxCancel = context.WithCancel(context.Background())
return s
}
// Read implements io.Reader. It is not thread safe!
@ -257,6 +265,7 @@ func (s *stream) getDataForWriting(maxBytes protocol.ByteCount) []byte {
// Close implements io.Closer
func (s *stream) Close() error {
s.finishedWriting.Set(true)
s.ctxCancel()
s.onData()
return nil
}
@ -352,6 +361,7 @@ func (s *stream) CloseRemote(offset protocol.ByteCount) {
func (s *stream) Cancel(err error) {
s.mutex.Lock()
s.cancelled.Set(true)
s.ctxCancel()
// errors must not be changed!
if s.err == nil {
s.err = err
@ -368,6 +378,7 @@ func (s *stream) Reset(err error) {
}
s.mutex.Lock()
s.resetLocally.Set(true)
s.ctxCancel()
// errors must not be changed!
if s.err == nil {
s.err = err
@ -388,6 +399,7 @@ func (s *stream) RegisterRemoteError(err error) {
}
s.mutex.Lock()
s.resetRemotely.Set(true)
s.ctxCancel()
// errors must not be changed!
if s.err == nil {
s.err = err
@ -412,6 +424,10 @@ func (s *stream) finished() bool {
(s.finishedWriteAndSentFin() && s.resetRemotely.Get())
}
func (s *stream) Context() context.Context {
return s.ctx
}
func (s *stream) StreamID() protocol.StreamID {
return s.streamID
}

View file

@ -435,6 +435,12 @@ var _ = Describe("Stream", func() {
Expect(n).To(BeZero())
Expect(err).To(MatchError(io.EOF))
})
It("doesn't cancel the context", func() {
mockFcm.EXPECT().UpdateHighestReceived(streamID, protocol.ByteCount(0))
str.CloseRemote(0)
Expect(str.Context().Done()).ToNot(BeClosed())
})
})
})
@ -463,6 +469,12 @@ var _ = Describe("Stream", func() {
Expect(n).To(BeZero())
Expect(err).To(MatchError(testErr))
})
It("cancels the context", func() {
Expect(str.Context().Done()).ToNot(BeClosed())
str.Cancel(testErr)
Expect(str.Context().Done()).To(BeClosed())
})
})
})
@ -713,6 +725,12 @@ var _ = Describe("Stream", func() {
str.Reset(testErr)
Expect(resetCalled).To(BeFalse())
})
It("cancels the context", func() {
Expect(str.Context().Done()).ToNot(BeClosed())
str.Reset(testErr)
Expect(str.Context().Done()).To(BeClosed())
})
})
})
@ -958,6 +976,13 @@ var _ = Describe("Stream", func() {
Expect(str.finished()).To(BeFalse())
})
It("cancels the context after it is closed", func() {
Expect(str.Context().Done()).ToNot(BeClosed())
str.Close()
str.sentFin()
Expect(str.Context().Done()).To(BeClosed())
})
It("is not finished if it is only closed for reading", func() {
mockFcm.EXPECT().UpdateHighestReceived(streamID, protocol.ByteCount(0))
mockFcm.EXPECT().AddBytesRead(streamID, protocol.ByteCount(0))
@ -972,6 +997,12 @@ var _ = Describe("Stream", func() {
Expect(str.finished()).To(BeTrue())
})
It("cancels the context after receiving a RST", func() {
Expect(str.Context().Done()).ToNot(BeClosed())
str.RegisterRemoteError(testErr)
Expect(str.Context().Done()).To(BeClosed())
})
It("is finished after being locally reset and receiving a RST in response", func() {
str.Reset(testErr)
Expect(str.finished()).To(BeFalse())

View file

@ -29,7 +29,7 @@ var _ = Describe("Streams Map", func() {
m = newStreamsMap(nil, p, mockCpm)
m.newStream = func(id protocol.StreamID) *stream {
return &stream{streamID: id}
return newStream(id, nil, nil, nil)
}
}