diff --git a/streams_map.go b/streams_map.go index 69797313..dba65ce4 100644 --- a/streams_map.go +++ b/streams_map.go @@ -311,6 +311,18 @@ func (m *streamsMap) RoundRobinIterate(fn streamLambda) error { return nil } +// Range executes a callback for all streams, in pseudo-random order +func (m *streamsMap) Range(cb func(s *stream)) { + m.mutex.RLock() + defer m.mutex.RUnlock() + + for _, s := range m.streams { + if s != nil { + cb(s) + } + } +} + func (m *streamsMap) iterateFunc(streamID protocol.StreamID, fn streamLambda) (bool, error) { str, ok := m.streams[streamID] if !ok { diff --git a/streams_map_test.go b/streams_map_test.go index f8d1c384..98692a66 100644 --- a/streams_map_test.go +++ b/streams_map_test.go @@ -2,6 +2,7 @@ package quic import ( "errors" + "sort" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/qerr" @@ -542,6 +543,28 @@ var _ = Describe("Streams Map", func() { }) }) + Context("Ranging", func() { + // create 5 streams, ids 4 to 8 + var callbackCalledForStream []protocol.StreamID + callback := func(str *stream) { + callbackCalledForStream = append(callbackCalledForStream, str.StreamID()) + sort.Slice(callbackCalledForStream, func(i, j int) bool { return callbackCalledForStream[i] < callbackCalledForStream[j] }) + } + + BeforeEach(func() { + callbackCalledForStream = callbackCalledForStream[:0] + for i := 4; i <= 8; i++ { + err := m.putStream(&stream{streamID: protocol.StreamID(i)}) + Expect(err).NotTo(HaveOccurred()) + } + }) + + It("ranges over all open streams", func() { + m.Range(callback) + Expect(callbackCalledForStream).To(Equal([]protocol.StreamID{4, 5, 6, 7, 8})) + }) + }) + Context("RoundRobinIterate", func() { // create 5 streams, ids 4 to 8 var lambdaCalledForStream []protocol.StreamID