add a streamsMap class

This commit is contained in:
Lucas Clemente 2016-07-11 13:09:04 +02:00 committed by Marten Seemann
parent 716937d1c2
commit b3e76770de
2 changed files with 119 additions and 0 deletions

59
streams_map.go Normal file
View file

@ -0,0 +1,59 @@
package quic
import (
"fmt"
"sync"
"github.com/lucas-clemente/quic-go/protocol"
)
type streamsMap struct {
streams map[protocol.StreamID]*stream
nStreams int
mutex sync.RWMutex
}
func newStreamsMap() *streamsMap {
return &streamsMap{
streams: map[protocol.StreamID]*stream{},
}
}
func (m *streamsMap) GetStream(id protocol.StreamID) (*stream, error) {
m.mutex.RLock()
s, ok := m.streams[id]
m.mutex.RUnlock()
if !ok {
return nil, fmt.Errorf("unknown stream: %d", id)
}
return s, nil
}
func (m *streamsMap) PutStream(s *stream) error {
m.mutex.Lock()
defer m.mutex.Unlock()
if _, ok := m.streams[s.StreamID()]; ok {
return fmt.Errorf("a stream with ID %d already exists", s.StreamID())
}
m.streams[s.StreamID()] = s
m.nStreams++
return nil
}
func (m *streamsMap) RemoveStream(id protocol.StreamID) error {
m.mutex.Lock()
defer m.mutex.Unlock()
s, ok := m.streams[id]
if !ok || s == nil {
return fmt.Errorf("attempted to remove non-existing stream: %d", id)
}
m.streams[id] = nil
m.nStreams--
return nil
}
func (m *streamsMap) NumberOfStreams() int {
m.mutex.RLock()
defer m.mutex.RUnlock()
return m.nStreams
}

60
streams_map_test.go Normal file
View file

@ -0,0 +1,60 @@
package quic
import (
"github.com/lucas-clemente/quic-go/protocol"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Streams Map", func() {
var (
m *streamsMap
)
BeforeEach(func() {
m = newStreamsMap()
})
It("returns an error for non-existant streams", func() {
_, err := m.GetStream(1)
Expect(err).To(MatchError("unknown stream: 1"))
})
It("returns nil for previously existing streams", func() {
err := m.PutStream(&stream{streamID: 1})
Expect(err).NotTo(HaveOccurred())
err = m.RemoveStream(1)
Expect(err).NotTo(HaveOccurred())
s, err := m.GetStream(1)
Expect(err).NotTo(HaveOccurred())
Expect(s).To(BeNil())
})
It("errors when removing non-existing stream", func() {
err := m.RemoveStream(1)
Expect(err).To(MatchError("attempted to remove non-existing stream: 1"))
})
It("stores streams", func() {
err := m.PutStream(&stream{streamID: 5})
Expect(err).NotTo(HaveOccurred())
s, err := m.GetStream(5)
Expect(err).NotTo(HaveOccurred())
Expect(s.streamID).To(Equal(protocol.StreamID(5)))
})
It("does not store multiple streams with the same ID", func() {
err := m.PutStream(&stream{streamID: 5})
Expect(err).NotTo(HaveOccurred())
err = m.PutStream(&stream{streamID: 5})
Expect(err).To(MatchError("a stream with ID 5 already exists"))
})
It("gets the number of streams", func() {
Expect(m.NumberOfStreams()).To(Equal(0))
m.PutStream(&stream{streamID: 5})
Expect(m.NumberOfStreams()).To(Equal(1))
m.RemoveStream(5)
Expect(m.NumberOfStreams()).To(Equal(0))
})
})