directly queue RST_STREAM frames on stream.Reset

This commit is contained in:
Marten Seemann 2017-12-08 16:09:18 +07:00
parent 3679c56f7e
commit bd77f3081c
4 changed files with 57 additions and 66 deletions

View file

@ -849,14 +849,6 @@ func (s *session) WaitUntilHandshakeComplete() error {
return <-s.handshakeCompleteChan
}
func (s *session) queueResetStreamFrame(id protocol.StreamID, offset protocol.ByteCount) {
s.packer.QueueControlFrame(&wire.RstStreamFrame{
StreamID: id,
ByteOffset: offset,
})
s.scheduleSending()
}
func (s *session) newStream(id protocol.StreamID) streamI {
var initialSendWindow protocol.ByteCount
if s.peerParams != nil {
@ -871,7 +863,7 @@ func (s *session) newStream(id protocol.StreamID) streamI {
initialSendWindow,
s.rttStats,
)
return newStream(id, s.scheduleSending, s.queueResetStreamFrame, flowController, s.version)
return newStream(id, s.scheduleSending, s.packer.QueueControlFrame, flowController, s.version)
}
func (s *session) newCryptoStream() cryptoStreamI {

View file

@ -357,15 +357,6 @@ var _ = Describe("Session", func() {
Expect(err).ToNot(HaveOccurred())
})
It("queues a RST_STERAM frame", func() {
sess.queueResetStreamFrame(5, 0x1337)
Expect(sess.packer.controlFrames).To(HaveLen(1))
Expect(sess.packer.controlFrames[0].(*wire.RstStreamFrame)).To(Equal(&wire.RstStreamFrame{
StreamID: 5,
ByteOffset: 0x1337,
}))
})
It("returns errors", func() {
testErr := errors.New("flow control violation")
str, err := sess.GetOrOpenStream(5)

View file

@ -38,9 +38,11 @@ type stream struct {
ctxCancel context.CancelFunc
streamID protocol.StreamID
onData func()
// onReset is a callback that should send a RST_STREAM
onReset func(protocol.StreamID, protocol.ByteCount)
// onData tells the session that there's stuff to pack into a new packet
onData func()
// queueControlFrame queues a new control frame for sending
// it does not call onData
queueControlFrame func(wire.Frame)
readPosInFrame int
writeOffset protocol.ByteCount
@ -88,19 +90,19 @@ var errDeadline net.Error = &deadlineError{}
// newStream creates a new Stream
func newStream(StreamID protocol.StreamID,
onData func(),
onReset func(protocol.StreamID, protocol.ByteCount),
queueControlFrame func(wire.Frame),
flowController flowcontrol.StreamFlowController,
version protocol.VersionNumber,
) *stream {
s := &stream{
onData: onData,
onReset: onReset,
streamID: StreamID,
flowController: flowController,
frameQueue: newStreamFrameSorter(),
readChan: make(chan struct{}, 1),
writeChan: make(chan struct{}, 1),
version: version,
onData: onData,
queueControlFrame: queueControlFrame,
streamID: StreamID,
flowController: flowController,
frameQueue: newStreamFrameSorter(),
readChan: make(chan struct{}, 1),
writeChan: make(chan struct{}, 1),
version: version,
}
s.ctx, s.ctxCancel = context.WithCancel(context.Background())
return s
@ -421,7 +423,11 @@ func (s *stream) Reset(err error) {
s.signalWrite()
}
if s.shouldSendReset() {
s.onReset(s.streamID, s.writeOffset)
s.queueControlFrame(&wire.RstStreamFrame{
StreamID: s.streamID,
ByteOffset: s.writeOffset,
})
s.onData()
s.rstSent.Set(true)
}
s.mutex.Unlock()
@ -444,7 +450,11 @@ func (s *stream) RegisterRemoteError(err error, offset protocol.ByteCount) error
return err
}
if s.shouldSendReset() {
s.onReset(s.streamID, s.writeOffset)
s.queueControlFrame(&wire.RstStreamFrame{
StreamID: s.streamID,
ByteOffset: s.writeOffset,
})
s.onData()
s.rstSent.Set(true)
}
s.mutex.Unlock()

View file

@ -28,9 +28,7 @@ var _ = Describe("Stream", func() {
strWithTimeout io.ReadWriter // str wrapped with gbytes.Timeout{Reader,Writer}
onDataCalled bool
resetCalled bool
resetCalledForStream protocol.StreamID
resetCalledAtOffset protocol.ByteCount
queuedControlFrames []wire.Frame
mockFC *mocks.MockStreamFlowController
)
@ -51,17 +49,15 @@ var _ = Describe("Stream", func() {
onDataCalled = true
}
onReset := func(id protocol.StreamID, offset protocol.ByteCount) {
resetCalled = true
resetCalledForStream = id
resetCalledAtOffset = offset
queueControlFrame := func(f wire.Frame) {
queuedControlFrames = append(queuedControlFrames, f)
}
BeforeEach(func() {
queuedControlFrames = queuedControlFrames[:0]
onDataCalled = false
resetCalled = false
mockFC = mocks.NewMockStreamFlowController(mockCtrl)
str = newStream(streamID, onData, onReset, mockFC, protocol.VersionWhatever)
str = newStream(streamID, onData, queueControlFrame, mockFC, protocol.VersionWhatever)
timeout := scaleDuration(250 * time.Millisecond)
strWithTimeout = struct {
@ -631,9 +627,12 @@ var _ = Describe("Stream", func() {
close(done)
}()
str.RegisterRemoteError(testErr, 0)
Expect(resetCalled).To(BeTrue())
Expect(resetCalledForStream).To(Equal(protocol.StreamID(1337)))
Expect(resetCalledAtOffset).To(Equal(protocol.ByteCount(0x1000)))
Expect(queuedControlFrames).To(Equal([]wire.Frame{
&wire.RstStreamFrame{
StreamID: 1337,
ByteOffset: 0x1000,
},
}))
Eventually(done).Should(BeClosed())
})
@ -643,25 +642,23 @@ var _ = Describe("Stream", func() {
f := str.PopStreamFrame(100)
Expect(f.FinBit).To(BeTrue())
str.RegisterRemoteError(testErr, 0)
Expect(resetCalled).To(BeFalse())
Expect(queuedControlFrames).To(BeEmpty())
})
It("doesn't call onReset if the stream was reset locally before", func() {
It("doesn't call queue a RST_STREAM if the stream was reset locally before", func() {
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(0), true)
str.Reset(testErr)
Expect(resetCalled).To(BeTrue())
resetCalled = false
Expect(queuedControlFrames).To(HaveLen(1))
str.RegisterRemoteError(testErr, 0)
Expect(resetCalled).To(BeFalse())
Expect(queuedControlFrames).To(HaveLen(1)) // no additional queued frame
})
It("doesn't call onReset twice, when it gets two remote errors", func() {
It("doesn't queue two RST_STREAMs twice, when it gets two remote errors", func() {
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(0), true)
str.RegisterRemoteError(testErr, 0)
Expect(resetCalled).To(BeTrue())
resetCalled = false
Expect(queuedControlFrames).To(HaveLen(1))
str.RegisterRemoteError(testErr, 0)
Expect(resetCalled).To(BeFalse())
Expect(queuedControlFrames).To(HaveLen(1)) // no additional queued frame
})
})
@ -716,37 +713,38 @@ var _ = Describe("Stream", func() {
Expect(err).To(MatchError(testErr))
})
It("calls onReset", func() {
It("queues a RST_STREAM frame", func() {
str.writeOffset = 0x1000
str.Reset(testErr)
Expect(resetCalled).To(BeTrue())
Expect(resetCalledForStream).To(Equal(protocol.StreamID(1337)))
Expect(resetCalledAtOffset).To(Equal(protocol.ByteCount(0x1000)))
Expect(queuedControlFrames).To(Equal([]wire.Frame{
&wire.RstStreamFrame{
StreamID: 1337,
ByteOffset: 0x1000,
},
}))
})
It("doesn't call onReset if it already sent a FIN", func() {
It("doesn't queue a RST_STREAM if it already sent a FIN", func() {
str.Close()
f := str.PopStreamFrame(1000)
Expect(f.FinBit).To(BeTrue())
str.Reset(testErr)
Expect(resetCalled).To(BeFalse())
Expect(queuedControlFrames).To(BeEmpty())
})
It("doesn't call onReset if the stream was reset remotely before", func() {
It("doesn't queue a new RST_STREAM, if the stream was reset remotely before", func() {
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(0), true)
str.RegisterRemoteError(testErr, 0)
Expect(resetCalled).To(BeTrue())
resetCalled = false
Expect(queuedControlFrames).To(HaveLen(1))
str.Reset(testErr)
Expect(resetCalled).To(BeFalse())
Expect(queuedControlFrames).To(HaveLen(1)) // no additional queued frame
})
It("doesn't call onReset twice", func() {
str.Reset(testErr)
Expect(resetCalled).To(BeTrue())
resetCalled = false
Expect(queuedControlFrames).To(HaveLen(1))
str.Reset(testErr)
Expect(resetCalled).To(BeFalse())
Expect(queuedControlFrames).To(HaveLen(1)) // no additional queued frame
})
It("cancels the context", func() {