diff --git a/codecov.yml b/codecov.yml index 91e1dbe2..560f7ab4 100644 --- a/codecov.yml +++ b/codecov.yml @@ -1,6 +1,8 @@ coverage: round: nearest ignore: + - streams_map_incoming_bidi.go + - streams_map_incoming_uni.go - streams_map_outgoing_bidi.go - streams_map_outgoing_uni.go - h2quic/gzipreader.go diff --git a/streams_map_incoming_bidi.go b/streams_map_incoming_bidi.go new file mode 100644 index 00000000..774bf1a6 --- /dev/null +++ b/streams_map_incoming_bidi.go @@ -0,0 +1,101 @@ +// This file was automatically generated by genny. +// Any changes will be lost if this file is regenerated. +// see https://github.com/cheekybits/genny + +package quic + +import ( + "fmt" + "sync" + + "github.com/lucas-clemente/quic-go/internal/protocol" +) + +type incomingBidiStreamsMap struct { + mutex sync.RWMutex + cond sync.Cond + + streams map[protocol.StreamID]streamI + + nextStream protocol.StreamID + highestStream protocol.StreamID + newStream func(protocol.StreamID) streamI + + closeErr error +} + +func newIncomingBidiStreamsMap(nextStream protocol.StreamID, newStream func(protocol.StreamID) streamI) *incomingBidiStreamsMap { + m := &incomingBidiStreamsMap{ + streams: make(map[protocol.StreamID]streamI), + nextStream: nextStream, + newStream: newStream, + } + m.cond.L = &m.mutex + return m +} + +func (m *incomingBidiStreamsMap) AcceptStream() (streamI, error) { + m.mutex.Lock() + defer m.mutex.Unlock() + + var str streamI + for { + var ok bool + if m.closeErr != nil { + return nil, m.closeErr + } + str, ok = m.streams[m.nextStream] + if ok { + break + } + m.cond.Wait() + } + m.nextStream += 4 + return str, nil +} + +func (m *incomingBidiStreamsMap) GetOrOpenStream(id protocol.StreamID) (streamI, error) { + // if the id is smaller than the highest we accepted + // * this stream exists in the map, and we can return it, or + // * this stream was already closed, then we can return the nil + if id <= m.highestStream { + m.mutex.RLock() + s := m.streams[id] + m.mutex.RUnlock() + return s, nil + } + + m.mutex.Lock() + var start protocol.StreamID + if m.highestStream == 0 { + start = m.nextStream + } else { + start = m.highestStream + 4 + } + for newID := start; newID <= id; newID += 4 { + m.streams[newID] = m.newStream(newID) + m.cond.Signal() + } + m.highestStream = id + s := m.streams[id] + m.mutex.Unlock() + return s, nil +} + +func (m *incomingBidiStreamsMap) DeleteStream(id protocol.StreamID) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + if _, ok := m.streams[id]; !ok { + return fmt.Errorf("Tried to delete unknown stream %d", id) + } + delete(m.streams, id) + return nil +} + +func (m *incomingBidiStreamsMap) CloseWithError(err error) { + m.mutex.Lock() + m.closeErr = err + m.mutex.Unlock() + m.cond.Broadcast() +} diff --git a/streams_map_incoming_generic.go b/streams_map_incoming_generic.go new file mode 100644 index 00000000..e03311c6 --- /dev/null +++ b/streams_map_incoming_generic.go @@ -0,0 +1,99 @@ +package quic + +import ( + "fmt" + "sync" + + "github.com/lucas-clemente/quic-go/internal/protocol" +) + +//go:generate genny -in $GOFILE -out streams_map_incoming_bidi.go gen "item=streamI Item=BidiStream" +//go:generate genny -in $GOFILE -out streams_map_incoming_uni.go gen "item=receiveStreamI Item=UniStream" +type incomingItemsMap struct { + mutex sync.RWMutex + cond sync.Cond + + streams map[protocol.StreamID]item + + nextStream protocol.StreamID + highestStream protocol.StreamID + newStream func(protocol.StreamID) item + + closeErr error +} + +func newIncomingItemsMap(nextStream protocol.StreamID, newStream func(protocol.StreamID) item) *incomingItemsMap { + m := &incomingItemsMap{ + streams: make(map[protocol.StreamID]item), + nextStream: nextStream, + newStream: newStream, + } + m.cond.L = &m.mutex + return m +} + +func (m *incomingItemsMap) AcceptStream() (item, error) { + m.mutex.Lock() + defer m.mutex.Unlock() + + var str item + for { + var ok bool + if m.closeErr != nil { + return nil, m.closeErr + } + str, ok = m.streams[m.nextStream] + if ok { + break + } + m.cond.Wait() + } + m.nextStream += 4 + return str, nil +} + +func (m *incomingItemsMap) GetOrOpenStream(id protocol.StreamID) (item, error) { + // if the id is smaller than the highest we accepted + // * this stream exists in the map, and we can return it, or + // * this stream was already closed, then we can return the nil + if id <= m.highestStream { + m.mutex.RLock() + s := m.streams[id] + m.mutex.RUnlock() + return s, nil + } + + m.mutex.Lock() + var start protocol.StreamID + if m.highestStream == 0 { + start = m.nextStream + } else { + start = m.highestStream + 4 + } + for newID := start; newID <= id; newID += 4 { + m.streams[newID] = m.newStream(newID) + m.cond.Signal() + } + m.highestStream = id + s := m.streams[id] + m.mutex.Unlock() + return s, nil +} + +func (m *incomingItemsMap) DeleteStream(id protocol.StreamID) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + if _, ok := m.streams[id]; !ok { + return fmt.Errorf("Tried to delete unknown stream %d", id) + } + delete(m.streams, id) + return nil +} + +func (m *incomingItemsMap) CloseWithError(err error) { + m.mutex.Lock() + m.closeErr = err + m.mutex.Unlock() + m.cond.Broadcast() +} diff --git a/streams_map_incoming_generic_test.go b/streams_map_incoming_generic_test.go new file mode 100644 index 00000000..6b2b9b43 --- /dev/null +++ b/streams_map_incoming_generic_test.go @@ -0,0 +1,106 @@ +package quic + +import ( + "errors" + + "github.com/lucas-clemente/quic-go/internal/protocol" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Streams Map (outgoing)", func() { + const firstNewStream protocol.StreamID = 20 + var ( + m *incomingItemsMap + newItem func(id protocol.StreamID) item + newItemCounter int + ) + + BeforeEach(func() { + newItemCounter = 0 + newItem = func(id protocol.StreamID) item { + newItemCounter++ + return id + } + m = newIncomingItemsMap(firstNewStream, newItem) + }) + + It("opens all streams up to the id on GetOrOpenStream", func() { + _, err := m.GetOrOpenStream(firstNewStream + 4*5) + Expect(err).ToNot(HaveOccurred()) + Expect(newItemCounter).To(Equal(6)) + }) + + It("starts opening streams at the right position", func() { + // like the test above, but with 2 calls to GetOrOpenStream + _, err := m.GetOrOpenStream(firstNewStream + 4) + Expect(err).ToNot(HaveOccurred()) + Expect(newItemCounter).To(Equal(2)) + _, err = m.GetOrOpenStream(firstNewStream + 4*5) + Expect(err).ToNot(HaveOccurred()) + Expect(newItemCounter).To(Equal(6)) + }) + + It("accepts streams in the right order", func() { + _, err := m.GetOrOpenStream(firstNewStream + 4) // open stream 20 and 24 + Expect(err).ToNot(HaveOccurred()) + str, err := m.AcceptStream() + Expect(err).ToNot(HaveOccurred()) + Expect(str).To(Equal(firstNewStream)) + str, err = m.AcceptStream() + Expect(err).ToNot(HaveOccurred()) + Expect(str).To(Equal(firstNewStream + 4)) + }) + + It("blocks AcceptStream until a new stream is available", func() { + strChan := make(chan item) + go func() { + defer GinkgoRecover() + str, err := m.AcceptStream() + Expect(err).ToNot(HaveOccurred()) + strChan <- str + }() + Consistently(strChan).ShouldNot(Receive()) + str, err := m.GetOrOpenStream(firstNewStream) + Expect(err).ToNot(HaveOccurred()) + Expect(str).To(Equal(firstNewStream)) + Eventually(strChan).Should(Receive(Equal(firstNewStream))) + }) + + It("unblocks AcceptStream when it is closed", func() { + testErr := errors.New("test error") + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + _, err := m.AcceptStream() + Expect(err).To(MatchError(testErr)) + close(done) + }() + Consistently(done).ShouldNot(BeClosed()) + m.CloseWithError(testErr) + Eventually(done).Should(BeClosed()) + }) + + It("errors AcceptStream immediately if it is closed", func() { + testErr := errors.New("test error") + m.CloseWithError(testErr) + _, err := m.AcceptStream() + Expect(err).To(MatchError(testErr)) + }) + + It("deletes streams", func() { + _, err := m.GetOrOpenStream(20) + Expect(err).ToNot(HaveOccurred()) + err = m.DeleteStream(20) + Expect(err).ToNot(HaveOccurred()) + str, err := m.GetOrOpenStream(20) + Expect(err).ToNot(HaveOccurred()) + Expect(str).To(BeNil()) + }) + + It("errors when deleting a non-existing stream", func() { + err := m.DeleteStream(1337) + Expect(err).To(MatchError("Tried to delete unknown stream 1337")) + }) +}) diff --git a/streams_map_incoming_uni.go b/streams_map_incoming_uni.go new file mode 100644 index 00000000..7cf57afb --- /dev/null +++ b/streams_map_incoming_uni.go @@ -0,0 +1,101 @@ +// This file was automatically generated by genny. +// Any changes will be lost if this file is regenerated. +// see https://github.com/cheekybits/genny + +package quic + +import ( + "fmt" + "sync" + + "github.com/lucas-clemente/quic-go/internal/protocol" +) + +type incomingUniStreamsMap struct { + mutex sync.RWMutex + cond sync.Cond + + streams map[protocol.StreamID]receiveStreamI + + nextStream protocol.StreamID + highestStream protocol.StreamID + newStream func(protocol.StreamID) receiveStreamI + + closeErr error +} + +func newIncomingUniStreamsMap(nextStream protocol.StreamID, newStream func(protocol.StreamID) receiveStreamI) *incomingUniStreamsMap { + m := &incomingUniStreamsMap{ + streams: make(map[protocol.StreamID]receiveStreamI), + nextStream: nextStream, + newStream: newStream, + } + m.cond.L = &m.mutex + return m +} + +func (m *incomingUniStreamsMap) AcceptStream() (receiveStreamI, error) { + m.mutex.Lock() + defer m.mutex.Unlock() + + var str receiveStreamI + for { + var ok bool + if m.closeErr != nil { + return nil, m.closeErr + } + str, ok = m.streams[m.nextStream] + if ok { + break + } + m.cond.Wait() + } + m.nextStream += 4 + return str, nil +} + +func (m *incomingUniStreamsMap) GetOrOpenStream(id protocol.StreamID) (receiveStreamI, error) { + // if the id is smaller than the highest we accepted + // * this stream exists in the map, and we can return it, or + // * this stream was already closed, then we can return the nil + if id <= m.highestStream { + m.mutex.RLock() + s := m.streams[id] + m.mutex.RUnlock() + return s, nil + } + + m.mutex.Lock() + var start protocol.StreamID + if m.highestStream == 0 { + start = m.nextStream + } else { + start = m.highestStream + 4 + } + for newID := start; newID <= id; newID += 4 { + m.streams[newID] = m.newStream(newID) + m.cond.Signal() + } + m.highestStream = id + s := m.streams[id] + m.mutex.Unlock() + return s, nil +} + +func (m *incomingUniStreamsMap) DeleteStream(id protocol.StreamID) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + if _, ok := m.streams[id]; !ok { + return fmt.Errorf("Tried to delete unknown stream %d", id) + } + delete(m.streams, id) + return nil +} + +func (m *incomingUniStreamsMap) CloseWithError(err error) { + m.mutex.Lock() + m.closeErr = err + m.mutex.Unlock() + m.cond.Broadcast() +}