diff --git a/streams_map.go b/streams_map.go new file mode 100644 index 00000000..da36dbc1 --- /dev/null +++ b/streams_map.go @@ -0,0 +1,59 @@ +package quic + +import ( + "fmt" + "sync" + + "github.com/lucas-clemente/quic-go/protocol" +) + +type streamsMap struct { + streams map[protocol.StreamID]*stream + nStreams int + mutex sync.RWMutex +} + +func newStreamsMap() *streamsMap { + return &streamsMap{ + streams: map[protocol.StreamID]*stream{}, + } +} + +func (m *streamsMap) GetStream(id protocol.StreamID) (*stream, error) { + m.mutex.RLock() + s, ok := m.streams[id] + m.mutex.RUnlock() + if !ok { + return nil, fmt.Errorf("unknown stream: %d", id) + } + return s, nil +} + +func (m *streamsMap) PutStream(s *stream) error { + m.mutex.Lock() + defer m.mutex.Unlock() + if _, ok := m.streams[s.StreamID()]; ok { + return fmt.Errorf("a stream with ID %d already exists", s.StreamID()) + } + m.streams[s.StreamID()] = s + m.nStreams++ + return nil +} + +func (m *streamsMap) RemoveStream(id protocol.StreamID) error { + m.mutex.Lock() + defer m.mutex.Unlock() + s, ok := m.streams[id] + if !ok || s == nil { + return fmt.Errorf("attempted to remove non-existing stream: %d", id) + } + m.streams[id] = nil + m.nStreams-- + return nil +} + +func (m *streamsMap) NumberOfStreams() int { + m.mutex.RLock() + defer m.mutex.RUnlock() + return m.nStreams +} diff --git a/streams_map_test.go b/streams_map_test.go new file mode 100644 index 00000000..725a092c --- /dev/null +++ b/streams_map_test.go @@ -0,0 +1,60 @@ +package quic + +import ( + "github.com/lucas-clemente/quic-go/protocol" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Streams Map", func() { + var ( + m *streamsMap + ) + + BeforeEach(func() { + m = newStreamsMap() + }) + + It("returns an error for non-existant streams", func() { + _, err := m.GetStream(1) + Expect(err).To(MatchError("unknown stream: 1")) + }) + + It("returns nil for previously existing streams", func() { + err := m.PutStream(&stream{streamID: 1}) + Expect(err).NotTo(HaveOccurred()) + err = m.RemoveStream(1) + Expect(err).NotTo(HaveOccurred()) + s, err := m.GetStream(1) + Expect(err).NotTo(HaveOccurred()) + Expect(s).To(BeNil()) + }) + + It("errors when removing non-existing stream", func() { + err := m.RemoveStream(1) + Expect(err).To(MatchError("attempted to remove non-existing stream: 1")) + }) + + It("stores streams", func() { + err := m.PutStream(&stream{streamID: 5}) + Expect(err).NotTo(HaveOccurred()) + s, err := m.GetStream(5) + Expect(err).NotTo(HaveOccurred()) + Expect(s.streamID).To(Equal(protocol.StreamID(5))) + }) + + It("does not store multiple streams with the same ID", func() { + err := m.PutStream(&stream{streamID: 5}) + Expect(err).NotTo(HaveOccurred()) + err = m.PutStream(&stream{streamID: 5}) + Expect(err).To(MatchError("a stream with ID 5 already exists")) + }) + + It("gets the number of streams", func() { + Expect(m.NumberOfStreams()).To(Equal(0)) + m.PutStream(&stream{streamID: 5}) + Expect(m.NumberOfStreams()).To(Equal(1)) + m.RemoveStream(5) + Expect(m.NumberOfStreams()).To(Equal(0)) + }) +})