pass RST_STREAM frames directly to the stream

This commit is contained in:
Marten Seemann 2017-12-12 10:12:54 +07:00
parent 2d31440510
commit 03977c1a25
5 changed files with 141 additions and 75 deletions

View file

@ -478,8 +478,6 @@ var _ = Describe("Stream", func() {
})
Context("resetting", func() {
testErr := errors.New("testErr")
Context("reset by the peer", func() {
It("continues reading after receiving a remote error", func() {
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false)
@ -489,22 +487,30 @@ var _ = Describe("Stream", func() {
Data: []byte{0xDE, 0xAD, 0xBE, 0xEF},
}
str.HandleStreamFrame(&frame)
str.RegisterRemoteError(testErr, 10)
err := str.HandleRstStreamFrame(&wire.RstStreamFrame{
StreamID: streamID,
ByteOffset: 10,
})
Expect(err).ToNot(HaveOccurred())
b := make([]byte, 4)
n, err := strWithTimeout.Read(b)
Expect(err).ToNot(HaveOccurred())
Expect(n).To(Equal(4))
})
It("reads a delayed StreamFrame that arrives after receiving a remote error", func() {
It("reads a delayed STREAM frame that arrives after receiving a remote error", func() {
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), true)
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false)
str.RegisterRemoteError(testErr, 4)
err := str.HandleRstStreamFrame(&wire.RstStreamFrame{
StreamID: streamID,
ByteOffset: 4,
})
Expect(err).ToNot(HaveOccurred())
frame := wire.StreamFrame{
Offset: 0,
Data: []byte{0xDE, 0xAD, 0xBE, 0xEF},
}
err := str.HandleStreamFrame(&frame)
err = str.HandleStreamFrame(&frame)
Expect(err).ToNot(HaveOccurred())
b := make([]byte, 4)
n, err := strWithTimeout.Read(b)
@ -520,15 +526,20 @@ var _ = Describe("Stream", func() {
Data: []byte{0xDE, 0xAD, 0xBE, 0xEF},
}
str.HandleStreamFrame(&frame)
str.RegisterRemoteError(testErr, 8)
err := str.HandleRstStreamFrame(&wire.RstStreamFrame{
StreamID: streamID,
ByteOffset: 8,
ErrorCode: 1337,
})
Expect(err).ToNot(HaveOccurred())
b := make([]byte, 10)
n, err := strWithTimeout.Read(b)
Expect(b[0:4]).To(Equal(frame.Data))
Expect(err).To(MatchError(testErr))
Expect(err).To(MatchError("RST_STREAM received with code 1337"))
Expect(n).To(Equal(4))
})
It("returns an EOF when reading past the offset, if the stream received a finbit", func() {
It("returns an EOF when reading past the offset, if the stream received a FIN bit", func() {
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), true)
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(8), true)
frame := wire.StreamFrame{
@ -537,7 +548,11 @@ var _ = Describe("Stream", func() {
FinBit: true,
}
str.HandleStreamFrame(&frame)
str.RegisterRemoteError(testErr, 8)
err := str.HandleRstStreamFrame(&wire.RstStreamFrame{
StreamID: streamID,
ByteOffset: 8,
})
Expect(err).ToNot(HaveOccurred())
b := make([]byte, 10)
n, err := strWithTimeout.Read(b)
Expect(b[:4]).To(Equal(frame.Data))
@ -554,9 +569,12 @@ var _ = Describe("Stream", func() {
FinBit: true,
}
str.HandleStreamFrame(&frame)
str.RegisterRemoteError(testErr, 4)
err := str.HandleRstStreamFrame(&wire.RstStreamFrame{
StreamID: streamID,
ByteOffset: 4,
})
b := make([]byte, 3)
_, err := strWithTimeout.Read(b)
_, err = strWithTimeout.Read(b)
Expect(err).ToNot(HaveOccurred())
Expect(b).To(Equal([]byte{0xde, 0xad, 0xbe}))
b = make([]byte, 3)
@ -576,27 +594,36 @@ var _ = Describe("Stream", func() {
Data: []byte{0xDE, 0xAD, 0xBE, 0xEF},
}
str.HandleStreamFrame(&frame)
str.RegisterRemoteError(testErr, 10)
err := str.HandleRstStreamFrame(&wire.RstStreamFrame{
StreamID: streamID,
ByteOffset: 10,
})
Expect(err).ToNot(HaveOccurred())
b := make([]byte, 3)
_, err := strWithTimeout.Read(b)
_, err = strWithTimeout.Read(b)
Expect(err).ToNot(HaveOccurred())
})
It("stops writing after receiving a remote error", func() {
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(10), true)
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(8), true)
done := make(chan struct{})
go func() {
defer GinkgoRecover()
n, err := strWithTimeout.Write([]byte("foobar"))
Expect(n).To(BeZero())
Expect(err).To(MatchError(testErr))
Expect(err).To(MatchError("RST_STREAM received with code 1337"))
close(done)
}()
str.RegisterRemoteError(testErr, 10)
err := str.HandleRstStreamFrame(&wire.RstStreamFrame{
StreamID: streamID,
ByteOffset: 8,
ErrorCode: 1337,
})
Expect(err).ToNot(HaveOccurred())
Eventually(done).Should(BeClosed())
})
It("returns how much was written when recieving a remote error", func() {
It("returns how much was written when receiving a remote error", func() {
frameHeaderSize := protocol.ByteCount(4)
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(10), true)
mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(9999))
@ -605,7 +632,7 @@ var _ = Describe("Stream", func() {
go func() {
defer GinkgoRecover()
n, err := strWithTimeout.Write([]byte("foobar"))
Expect(err).To(MatchError(testErr))
Expect(err).To(MatchError("RST_STREAM received with code 1337"))
Expect(n).To(Equal(4))
close(done)
}()
@ -614,22 +641,31 @@ var _ = Describe("Stream", func() {
Eventually(func() *wire.StreamFrame { frame = str.PopStreamFrame(4 + frameHeaderSize); return frame }).ShouldNot(BeNil())
Expect(frame).ToNot(BeNil())
Expect(frame.DataLen()).To(BeEquivalentTo(4))
str.RegisterRemoteError(testErr, 10)
err := str.HandleRstStreamFrame(&wire.RstStreamFrame{
StreamID: streamID,
ByteOffset: 10,
ErrorCode: 1337,
})
Expect(err).ToNot(HaveOccurred())
Eventually(done).Should(BeClosed())
})
It("calls onReset when receiving a remote error", func() {
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(0), true)
It("calls queues a RST_STREAM frame when receiving a remote error", func() {
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(10), true)
done := make(chan struct{})
str.writeOffset = 0x1000
go func() {
_, _ = strWithTimeout.Write([]byte("foobar"))
close(done)
}()
str.RegisterRemoteError(testErr, 0)
err := str.HandleRstStreamFrame(&wire.RstStreamFrame{
StreamID: streamID,
ByteOffset: 10,
})
Expect(err).ToNot(HaveOccurred())
Expect(queuedControlFrames).To(Equal([]wire.Frame{
&wire.RstStreamFrame{
StreamID: 1337,
StreamID: streamID,
ByteOffset: 0x1000,
},
}))
@ -637,32 +673,50 @@ var _ = Describe("Stream", func() {
})
It("doesn't call onReset if it already sent a FIN", func() {
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(0), true)
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(10), true)
str.Close()
f := str.PopStreamFrame(100)
Expect(f.FinBit).To(BeTrue())
str.RegisterRemoteError(testErr, 0)
err := str.HandleRstStreamFrame(&wire.RstStreamFrame{
StreamID: streamID,
ByteOffset: 10,
})
Expect(err).ToNot(HaveOccurred())
Expect(queuedControlFrames).To(BeEmpty())
})
It("doesn't call queue a RST_STREAM if the stream was reset locally before", func() {
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(0), true)
str.Reset(testErr)
It("doesn't queue a RST_STREAM if the stream was reset locally before", func() {
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(10), true)
str.Reset(errors.New("reset"))
Expect(queuedControlFrames).To(HaveLen(1))
str.RegisterRemoteError(testErr, 0)
err := str.HandleRstStreamFrame(&wire.RstStreamFrame{
StreamID: streamID,
ByteOffset: 10,
})
Expect(err).ToNot(HaveOccurred())
Expect(queuedControlFrames).To(HaveLen(1)) // no additional queued frame
})
It("doesn't queue two RST_STREAMs twice, when it gets two remote errors", func() {
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(0), true)
str.RegisterRemoteError(testErr, 0)
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(8), true)
err := str.HandleRstStreamFrame(&wire.RstStreamFrame{
StreamID: streamID,
ByteOffset: 8,
})
Expect(err).ToNot(HaveOccurred())
Expect(queuedControlFrames).To(HaveLen(1))
err = str.HandleRstStreamFrame(&wire.RstStreamFrame{
StreamID: streamID,
ByteOffset: 9,
})
Expect(err).ToNot(HaveOccurred())
Expect(queuedControlFrames).To(HaveLen(1))
str.RegisterRemoteError(testErr, 0)
Expect(queuedControlFrames).To(HaveLen(1)) // no additional queued frame
})
})
Context("reset locally", func() {
testErr := errors.New("test error")
It("stops writing", func() {
done := make(chan struct{})
go func() {
@ -733,11 +787,14 @@ var _ = Describe("Stream", func() {
})
It("doesn't queue a new RST_STREAM, if the stream was reset remotely before", func() {
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(0), true)
str.RegisterRemoteError(testErr, 0)
Expect(queuedControlFrames).To(HaveLen(1))
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(10), true)
err := str.HandleRstStreamFrame(&wire.RstStreamFrame{
StreamID: streamID,
ByteOffset: 10,
})
Expect(err).ToNot(HaveOccurred())
str.Reset(testErr)
Expect(queuedControlFrames).To(HaveLen(1)) // no additional queued frame
Expect(queuedControlFrames).To(HaveLen(1))
})
It("doesn't call onReset twice", func() {
@ -1037,7 +1094,10 @@ var _ = Describe("Stream", func() {
It("is finished after receiving a RST and sending one", func() {
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(0), true)
// this directly sends a rst
str.RegisterRemoteError(testErr, 0)
str.HandleRstStreamFrame(&wire.RstStreamFrame{
StreamID: streamID,
ByteOffset: 0,
})
Expect(str.rstSent.Get()).To(BeTrue())
Expect(str.Finished()).To(BeTrue())
})
@ -1045,7 +1105,10 @@ var _ = Describe("Stream", func() {
It("cancels the context after receiving a RST", func() {
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(0), true)
Expect(str.Context().Done()).ToNot(BeClosed())
str.RegisterRemoteError(testErr, 0)
str.HandleRstStreamFrame(&wire.RstStreamFrame{
StreamID: streamID,
ByteOffset: 0,
})
Expect(str.Context().Done()).To(BeClosed())
})
@ -1053,7 +1116,10 @@ var _ = Describe("Stream", func() {
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(13), true)
str.Reset(testErr)
Expect(str.Finished()).To(BeFalse())
str.RegisterRemoteError(testErr, 13)
str.HandleRstStreamFrame(&wire.RstStreamFrame{
StreamID: streamID,
ByteOffset: 13,
})
Expect(str.Finished()).To(BeTrue())
})
@ -1062,7 +1128,10 @@ var _ = Describe("Stream", func() {
str.Close()
f := str.PopStreamFrame(1000)
Expect(f.FinBit).To(BeTrue())
str.RegisterRemoteError(testErr, 13)
str.HandleRstStreamFrame(&wire.RstStreamFrame{
StreamID: streamID,
ByteOffset: 13,
})
Expect(str.Finished()).To(BeTrue())
})