diff --git a/streams_map_generic_helper.go b/streams_map_generic_helper.go new file mode 100644 index 00000000..f48db212 --- /dev/null +++ b/streams_map_generic_helper.go @@ -0,0 +1,11 @@ +package quic + +import "github.com/cheekybits/genny/generic" + +// In the auto-generated streams maps, we need to be able to close the streams. +// Therefore, extend the generic.Type with the stream close method. +// This definition must be in a file that Genny doesn't process. +type item interface { + generic.Type + closeForShutdown(error) +} diff --git a/streams_map_incoming_bidi.go b/streams_map_incoming_bidi.go index f7f80a28..317f5e23 100644 --- a/streams_map_incoming_bidi.go +++ b/streams_map_incoming_bidi.go @@ -123,6 +123,9 @@ func (m *incomingBidiStreamsMap) DeleteStream(id protocol.StreamID) error { func (m *incomingBidiStreamsMap) CloseWithError(err error) { m.mutex.Lock() m.closeErr = err + for _, str := range m.streams { + str.closeForShutdown(err) + } m.mutex.Unlock() m.cond.Broadcast() } diff --git a/streams_map_incoming_generic.go b/streams_map_incoming_generic.go index 40d7b750..58f1ccbe 100644 --- a/streams_map_incoming_generic.go +++ b/streams_map_incoming_generic.go @@ -121,6 +121,9 @@ func (m *incomingItemsMap) DeleteStream(id protocol.StreamID) error { func (m *incomingItemsMap) CloseWithError(err error) { m.mutex.Lock() m.closeErr = err + for _, str := range m.streams { + str.closeForShutdown(err) + } m.mutex.Unlock() m.cond.Broadcast() } diff --git a/streams_map_incoming_generic_test.go b/streams_map_incoming_generic_test.go index 5d66fcd1..e9d8dbb4 100644 --- a/streams_map_incoming_generic_test.go +++ b/streams_map_incoming_generic_test.go @@ -12,7 +12,19 @@ import ( . "github.com/onsi/gomega" ) -var _ = Describe("Streams Map (outgoing)", func() { +type mockGenericStream struct { + id protocol.StreamID + + closed bool + closeErr error +} + +func (s *mockGenericStream) closeForShutdown(err error) { + s.closed = true + s.closeErr = err +} + +var _ = Describe("Streams Map (incoming)", func() { const ( firstNewStream protocol.StreamID = 20 maxNumStreams int = 10 @@ -30,7 +42,7 @@ var _ = Describe("Streams Map (outgoing)", func() { newItemCounter = 0 newItem = func(id protocol.StreamID) item { newItemCounter++ - return id + return &mockGenericStream{id: id} } mockSender = NewMockStreamSender(mockCtrl) m = newIncomingItemsMap(firstNewStream, initialMaxStream, maxNumStreams, mockSender.queueControlFrame, newItem) @@ -57,16 +69,16 @@ var _ = Describe("Streams Map (outgoing)", func() { Expect(err).ToNot(HaveOccurred()) str, err := m.AcceptStream() Expect(err).ToNot(HaveOccurred()) - Expect(str).To(Equal(firstNewStream)) + Expect(str.(*mockGenericStream).id).To(Equal(firstNewStream)) str, err = m.AcceptStream() Expect(err).ToNot(HaveOccurred()) - Expect(str).To(Equal(firstNewStream + 4)) + Expect(str.(*mockGenericStream).id).To(Equal(firstNewStream + 4)) }) It("allows opening the maximum stream ID", func() { str, err := m.GetOrOpenStream(initialMaxStream) Expect(err).ToNot(HaveOccurred()) - Expect(str).To(Equal(initialMaxStream)) + Expect(str.(*mockGenericStream).id).To(Equal(initialMaxStream)) }) It("errors when trying to get a stream ID higher than the maximum", func() { @@ -85,8 +97,10 @@ var _ = Describe("Streams Map (outgoing)", func() { Consistently(strChan).ShouldNot(Receive()) str, err := m.GetOrOpenStream(firstNewStream) Expect(err).ToNot(HaveOccurred()) - Expect(str).To(Equal(firstNewStream)) - Eventually(strChan).Should(Receive(Equal(firstNewStream))) + Expect(str.(*mockGenericStream).id).To(Equal(firstNewStream)) + var acceptedStr item + Eventually(strChan).Should(Receive(&acceptedStr)) + Expect(acceptedStr.(*mockGenericStream).id).To(Equal(firstNewStream)) }) It("unblocks AcceptStream when it is closed", func() { @@ -110,6 +124,19 @@ var _ = Describe("Streams Map (outgoing)", func() { Expect(err).To(MatchError(testErr)) }) + It("closes all streams when CloseWithError is called", func() { + str1, err := m.GetOrOpenStream(20) + Expect(err).ToNot(HaveOccurred()) + str2, err := m.GetOrOpenStream(20 + 8) + Expect(err).ToNot(HaveOccurred()) + testErr := errors.New("test err") + m.CloseWithError(testErr) + Expect(str1.(*mockGenericStream).closed).To(BeTrue()) + Expect(str1.(*mockGenericStream).closeErr).To(MatchError(testErr)) + Expect(str2.(*mockGenericStream).closed).To(BeTrue()) + Expect(str2.(*mockGenericStream).closeErr).To(MatchError(testErr)) + }) + It("deletes streams", func() { mockSender.EXPECT().queueControlFrame(gomock.Any()) _, err := m.GetOrOpenStream(20) diff --git a/streams_map_incoming_uni.go b/streams_map_incoming_uni.go index f9fe97b7..8e775aac 100644 --- a/streams_map_incoming_uni.go +++ b/streams_map_incoming_uni.go @@ -123,6 +123,9 @@ func (m *incomingUniStreamsMap) DeleteStream(id protocol.StreamID) error { func (m *incomingUniStreamsMap) CloseWithError(err error) { m.mutex.Lock() m.closeErr = err + for _, str := range m.streams { + str.closeForShutdown(err) + } m.mutex.Unlock() m.cond.Broadcast() } diff --git a/streams_map_outgoing_bidi.go b/streams_map_outgoing_bidi.go index e7d9150b..ea9f47e6 100644 --- a/streams_map_outgoing_bidi.go +++ b/streams_map_outgoing_bidi.go @@ -118,6 +118,9 @@ func (m *outgoingBidiStreamsMap) SetMaxStream(id protocol.StreamID) { func (m *outgoingBidiStreamsMap) CloseWithError(err error) { m.mutex.Lock() m.closeErr = err + for _, str := range m.streams { + str.closeForShutdown(err) + } m.cond.Broadcast() m.mutex.Unlock() } diff --git a/streams_map_outgoing_generic.go b/streams_map_outgoing_generic.go index 80236c15..f4b3eb61 100644 --- a/streams_map_outgoing_generic.go +++ b/streams_map_outgoing_generic.go @@ -4,14 +4,11 @@ import ( "fmt" "sync" - "github.com/cheekybits/genny/generic" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/wire" "github.com/lucas-clemente/quic-go/qerr" ) -type item generic.Type - //go:generate genny -in $GOFILE -out streams_map_outgoing_bidi.go gen "item=streamI Item=BidiStream" //go:generate genny -in $GOFILE -out streams_map_outgoing_uni.go gen "item=sendStreamI Item=UniStream" type outgoingItemsMap struct { @@ -119,6 +116,9 @@ func (m *outgoingItemsMap) SetMaxStream(id protocol.StreamID) { func (m *outgoingItemsMap) CloseWithError(err error) { m.mutex.Lock() m.closeErr = err + for _, str := range m.streams { + str.closeForShutdown(err) + } m.cond.Broadcast() m.mutex.Unlock() } diff --git a/streams_map_outgoing_generic_test.go b/streams_map_outgoing_generic_test.go index eb28784a..c8426899 100644 --- a/streams_map_outgoing_generic_test.go +++ b/streams_map_outgoing_generic_test.go @@ -21,7 +21,7 @@ var _ = Describe("Streams Map (outgoing)", func() { BeforeEach(func() { newItem = func(id protocol.StreamID) item { - return id + return &mockGenericStream{id: id} } mockSender = NewMockStreamSender(mockCtrl) m = newOutgoingItemsMap(firstNewStream, newItem, mockSender.queueControlFrame) @@ -35,10 +35,10 @@ var _ = Describe("Streams Map (outgoing)", func() { It("opens streams", func() { str, err := m.OpenStream() Expect(err).ToNot(HaveOccurred()) - Expect(str).To(Equal(firstNewStream)) + Expect(str.(*mockGenericStream).id).To(Equal(firstNewStream)) str, err = m.OpenStream() Expect(err).ToNot(HaveOccurred()) - Expect(str).To(Equal(firstNewStream + 4)) + Expect(str.(*mockGenericStream).id).To(Equal(firstNewStream + 4)) }) It("doesn't open streams after it has been closed", func() { @@ -53,7 +53,7 @@ var _ = Describe("Streams Map (outgoing)", func() { Expect(err).ToNot(HaveOccurred()) str, err := m.GetStream(firstNewStream) Expect(err).ToNot(HaveOccurred()) - Expect(str).To(Equal(firstNewStream)) + Expect(str.(*mockGenericStream).id).To(Equal(firstNewStream)) }) It("errors when trying to get a stream that has not yet been opened", func() { @@ -84,6 +84,19 @@ var _ = Describe("Streams Map (outgoing)", func() { err = m.DeleteStream(10) Expect(err).To(MatchError("Tried to delete unknown stream 10")) }) + + It("closes all streams when CloseWithError is called", func() { + str1, err := m.OpenStream() + Expect(err).ToNot(HaveOccurred()) + str2, err := m.OpenStream() + Expect(err).ToNot(HaveOccurred()) + testErr := errors.New("test err") + m.CloseWithError(testErr) + Expect(str1.(*mockGenericStream).closed).To(BeTrue()) + Expect(str1.(*mockGenericStream).closeErr).To(MatchError(testErr)) + Expect(str2.(*mockGenericStream).closed).To(BeTrue()) + Expect(str2.(*mockGenericStream).closeErr).To(MatchError(testErr)) + }) }) Context("with stream ID limits", func() { @@ -100,7 +113,7 @@ var _ = Describe("Streams Map (outgoing)", func() { defer GinkgoRecover() str, err := m.OpenStreamSync() Expect(err).ToNot(HaveOccurred()) - Expect(str).To(Equal(firstNewStream)) + Expect(str.(*mockGenericStream).id).To(Equal(firstNewStream)) close(done) }() @@ -130,7 +143,7 @@ var _ = Describe("Streams Map (outgoing)", func() { m.SetMaxStream(firstNewStream - 4) str, err := m.OpenStream() Expect(err).ToNot(HaveOccurred()) - Expect(str).To(Equal(firstNewStream)) + Expect(str.(*mockGenericStream).id).To(Equal(firstNewStream)) }) It("queues a STREAM_ID_BLOCKED frame if no stream can be opened", func() { diff --git a/streams_map_outgoing_uni.go b/streams_map_outgoing_uni.go index fd2701b7..6ad0348c 100644 --- a/streams_map_outgoing_uni.go +++ b/streams_map_outgoing_uni.go @@ -118,6 +118,9 @@ func (m *outgoingUniStreamsMap) SetMaxStream(id protocol.StreamID) { func (m *outgoingUniStreamsMap) CloseWithError(err error) { m.mutex.Lock() m.closeErr = err + for _, str := range m.streams { + str.closeForShutdown(err) + } m.cond.Broadcast() m.mutex.Unlock() }