mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-04 20:57:36 +03:00
remove stream ID from OpenStream() method
This commit is contained in:
parent
8cd1e4484c
commit
f47142eaac
9 changed files with 424 additions and 423 deletions
|
@ -109,8 +109,8 @@ func (c *Client) Listen() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// OpenStream opens a stream, for client-side created streams (i.e. odd streamIDs)
|
// OpenStream opens a stream, for client-side created streams (i.e. odd streamIDs)
|
||||||
func (c *Client) OpenStream(id protocol.StreamID) (utils.Stream, error) {
|
func (c *Client) OpenStream() (utils.Stream, error) {
|
||||||
return c.session.OpenStream(id)
|
return c.session.OpenStream()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close closes the connection
|
// Close closes the connection
|
||||||
|
|
|
@ -109,9 +109,9 @@ var _ = Describe("Client", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
It("opens a stream", func() {
|
It("opens a stream", func() {
|
||||||
stream, err := client.OpenStream(1337)
|
stream, err := client.OpenStream()
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(stream.StreamID()).To(Equal(protocol.StreamID(1337)))
|
Expect(stream).ToNot(BeNil())
|
||||||
})
|
})
|
||||||
|
|
||||||
Context("handling packets", func() {
|
Context("handling packets", func() {
|
||||||
|
|
|
@ -21,7 +21,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type quicClient interface {
|
type quicClient interface {
|
||||||
OpenStream(protocol.StreamID) (utils.Stream, error)
|
OpenStream() (utils.Stream, error)
|
||||||
Close(error) error
|
Close(error) error
|
||||||
Listen() error
|
Listen() error
|
||||||
}
|
}
|
||||||
|
@ -36,11 +36,10 @@ type Client struct {
|
||||||
hostname string
|
hostname string
|
||||||
encryptionLevel protocol.EncryptionLevel
|
encryptionLevel protocol.EncryptionLevel
|
||||||
|
|
||||||
client quicClient
|
client quicClient
|
||||||
headerStream utils.Stream
|
headerStream utils.Stream
|
||||||
headerErr *qerr.QuicError
|
headerErr *qerr.QuicError
|
||||||
highestOpenedStream protocol.StreamID
|
requestWriter *requestWriter
|
||||||
requestWriter *requestWriter
|
|
||||||
|
|
||||||
responses map[protocol.StreamID]chan *http.Response
|
responses map[protocol.StreamID]chan *http.Response
|
||||||
}
|
}
|
||||||
|
@ -50,10 +49,9 @@ var _ h2quicClient = &Client{}
|
||||||
// NewClient creates a new client
|
// NewClient creates a new client
|
||||||
func NewClient(t *QuicRoundTripper, tlsConfig *tls.Config, hostname string) (*Client, error) {
|
func NewClient(t *QuicRoundTripper, tlsConfig *tls.Config, hostname string) (*Client, error) {
|
||||||
c := &Client{
|
c := &Client{
|
||||||
t: t,
|
t: t,
|
||||||
hostname: authorityAddr("https", hostname),
|
hostname: authorityAddr("https", hostname),
|
||||||
highestOpenedStream: 3,
|
responses: make(map[protocol.StreamID]chan *http.Response),
|
||||||
responses: make(map[protocol.StreamID]chan *http.Response),
|
|
||||||
}
|
}
|
||||||
c.cryptoChangedCond = sync.Cond{L: &c.mutex}
|
c.cryptoChangedCond = sync.Cond{L: &c.mutex}
|
||||||
|
|
||||||
|
@ -88,10 +86,13 @@ func (c *Client) cryptoChangeCallback(isForwardSecure bool) {
|
||||||
func (c *Client) versionNegotiateCallback() error {
|
func (c *Client) versionNegotiateCallback() error {
|
||||||
var err error
|
var err error
|
||||||
// once the version has been negotiated, open the header stream
|
// once the version has been negotiated, open the header stream
|
||||||
c.headerStream, err = c.client.OpenStream(3)
|
c.headerStream, err = c.client.OpenStream()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
if c.headerStream.StreamID() != 3 {
|
||||||
|
return errors.New("h2quic Client BUG: StreamID of Header Stream is not 3")
|
||||||
|
}
|
||||||
c.requestWriter = newRequestWriter(c.headerStream)
|
c.requestWriter = newRequestWriter(c.headerStream)
|
||||||
go c.handleHeaderStream()
|
go c.handleHeaderStream()
|
||||||
return nil
|
return nil
|
||||||
|
@ -160,21 +161,18 @@ func (c *Client) Do(req *http.Request) (*http.Response, error) {
|
||||||
hasBody := (req.Body != nil)
|
hasBody := (req.Body != nil)
|
||||||
|
|
||||||
c.mutex.Lock()
|
c.mutex.Lock()
|
||||||
c.highestOpenedStream += 2
|
|
||||||
dataStreamID := c.highestOpenedStream
|
|
||||||
for c.encryptionLevel != protocol.EncryptionForwardSecure {
|
for c.encryptionLevel != protocol.EncryptionForwardSecure {
|
||||||
c.cryptoChangedCond.Wait()
|
c.cryptoChangedCond.Wait()
|
||||||
}
|
}
|
||||||
hdrChan := make(chan *http.Response)
|
hdrChan := make(chan *http.Response)
|
||||||
c.responses[dataStreamID] = hdrChan
|
|
||||||
c.mutex.Unlock()
|
|
||||||
|
|
||||||
// TODO: think about what to do with a TooManyOpenStreams error. Wait and retry?
|
// TODO: think about what to do with a TooManyOpenStreams error. Wait and retry?
|
||||||
dataStream, err := c.client.OpenStream(dataStreamID)
|
dataStream, err := c.client.OpenStream()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.Close(err)
|
c.Close(err)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
c.responses[dataStream.StreamID()] = hdrChan
|
||||||
|
c.mutex.Unlock()
|
||||||
|
|
||||||
var requestedGzip bool
|
var requestedGzip bool
|
||||||
if !c.t.disableCompression() && req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" && req.Method != "HEAD" {
|
if !c.t.disableCompression() && req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" && req.Method != "HEAD" {
|
||||||
|
@ -182,7 +180,7 @@ func (c *Client) Do(req *http.Request) (*http.Response, error) {
|
||||||
}
|
}
|
||||||
// TODO: add support for trailers
|
// TODO: add support for trailers
|
||||||
endStream := !hasBody
|
endStream := !hasBody
|
||||||
err = c.requestWriter.WriteRequest(req, dataStreamID, endStream, requestedGzip)
|
err = c.requestWriter.WriteRequest(req, dataStream.StreamID(), endStream, requestedGzip)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.Close(err)
|
c.Close(err)
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -209,7 +207,7 @@ func (c *Client) Do(req *http.Request) (*http.Response, error) {
|
||||||
case res = <-hdrChan:
|
case res = <-hdrChan:
|
||||||
receivedResponse = true
|
receivedResponse = true
|
||||||
c.mutex.Lock()
|
c.mutex.Lock()
|
||||||
delete(c.responses, dataStreamID)
|
delete(c.responses, dataStream.StreamID())
|
||||||
c.mutex.Unlock()
|
c.mutex.Unlock()
|
||||||
if res == nil { // an error occured on the header stream
|
if res == nil { // an error occured on the header stream
|
||||||
c.Close(c.headerErr)
|
c.Close(c.headerErr)
|
||||||
|
|
|
@ -18,25 +18,25 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type mockQuicClient struct {
|
type mockQuicClient struct {
|
||||||
streams map[protocol.StreamID]*mockStream
|
nextStream protocol.StreamID
|
||||||
closeErr error
|
streams map[protocol.StreamID]*mockStream
|
||||||
|
closeErr error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockQuicClient) Close(e error) error { m.closeErr = e; return nil }
|
func (m *mockQuicClient) Close(e error) error { m.closeErr = e; return nil }
|
||||||
func (m *mockQuicClient) Listen() error { panic("not implemented") }
|
func (m *mockQuicClient) Listen() error { panic("not implemented") }
|
||||||
func (m *mockQuicClient) OpenStream(id protocol.StreamID) (utils.Stream, error) {
|
func (m *mockQuicClient) OpenStream() (utils.Stream, error) {
|
||||||
_, ok := m.streams[id]
|
id := m.nextStream
|
||||||
if ok {
|
|
||||||
panic("Stream already exists")
|
|
||||||
}
|
|
||||||
ms := &mockStream{id: id}
|
ms := &mockStream{id: id}
|
||||||
m.streams[id] = ms
|
m.streams[id] = ms
|
||||||
|
m.nextStream += 2
|
||||||
return ms, nil
|
return ms, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func newMockQuicClient() *mockQuicClient {
|
func newMockQuicClient() *mockQuicClient {
|
||||||
return &mockQuicClient{
|
return &mockQuicClient{
|
||||||
streams: make(map[protocol.StreamID]*mockStream),
|
streams: make(map[protocol.StreamID]*mockStream),
|
||||||
|
nextStream: 5,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -77,6 +77,7 @@ var _ = Describe("Client", func() {
|
||||||
// delete the headerStream openend in the BeforeEach
|
// delete the headerStream openend in the BeforeEach
|
||||||
client.headerStream = nil
|
client.headerStream = nil
|
||||||
delete(qClient.streams, 3)
|
delete(qClient.streams, 3)
|
||||||
|
qClient.nextStream = 3
|
||||||
Expect(client.headerStream).To(BeNil()) // header stream not yet opened
|
Expect(client.headerStream).To(BeNil()) // header stream not yet opened
|
||||||
// now start the actual test
|
// now start the actual test
|
||||||
err := client.versionNegotiateCallback()
|
err := client.versionNegotiateCallback()
|
||||||
|
@ -133,7 +134,6 @@ var _ = Describe("Client", func() {
|
||||||
}()
|
}()
|
||||||
|
|
||||||
Eventually(func() []byte { return headerStream.dataWritten.Bytes() }).ShouldNot(BeEmpty())
|
Eventually(func() []byte { return headerStream.dataWritten.Bytes() }).ShouldNot(BeEmpty())
|
||||||
Expect(client.highestOpenedStream).To(Equal(protocol.StreamID(5)))
|
|
||||||
Expect(qClient.streams).Should(HaveKey(protocol.StreamID(5)))
|
Expect(qClient.streams).Should(HaveKey(protocol.StreamID(5)))
|
||||||
Expect(client.responses).To(HaveKey(protocol.StreamID(5)))
|
Expect(client.responses).To(HaveKey(protocol.StreamID(5)))
|
||||||
rsp := &http.Response{
|
rsp := &http.Response{
|
||||||
|
|
|
@ -19,7 +19,7 @@ import (
|
||||||
// packetHandler handles packets
|
// packetHandler handles packets
|
||||||
type packetHandler interface {
|
type packetHandler interface {
|
||||||
handlePacket(*receivedPacket)
|
handlePacket(*receivedPacket)
|
||||||
OpenStream(protocol.StreamID) (utils.Stream, error)
|
OpenStream() (utils.Stream, error)
|
||||||
run()
|
run()
|
||||||
Close(error) error
|
Close(error) error
|
||||||
}
|
}
|
||||||
|
|
|
@ -34,8 +34,8 @@ func (s *mockSession) Close(e error) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *mockSession) OpenStream(id protocol.StreamID) (utils.Stream, error) {
|
func (s *mockSession) OpenStream() (utils.Stream, error) {
|
||||||
return &stream{streamID: id}, nil
|
return &stream{streamID: 1337}, nil
|
||||||
}
|
}
|
||||||
func newMockSession(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, streamCallback StreamCallback, closeCallback closeCallback) (packetHandler, error) {
|
func newMockSession(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, streamCallback StreamCallback, closeCallback closeCallback) (packetHandler, error) {
|
||||||
return &mockSession{
|
return &mockSession{
|
||||||
|
|
|
@ -144,7 +144,7 @@ func newClientSession(conn *net.UDPConn, addr *net.UDPAddr, hostname string, v p
|
||||||
session.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(session.ackAlarmChanged)
|
session.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(session.ackAlarmChanged)
|
||||||
session.setup()
|
session.setup()
|
||||||
|
|
||||||
cryptoStream, _ := session.OpenStream(1)
|
cryptoStream, _ := session.OpenStream()
|
||||||
var err error
|
var err error
|
||||||
session.cryptoSetup, err = handshake.NewCryptoSetupClient(hostname, connectionID, v, cryptoStream, tlsConfig, session.connectionParameters, session.aeadChanged, negotiatedVersions)
|
session.cryptoSetup, err = handshake.NewCryptoSetupClient(hostname, connectionID, v, cryptoStream, tlsConfig, session.connectionParameters, session.aeadChanged, negotiatedVersions)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -666,9 +666,9 @@ func (s *Session) GetOrOpenStream(id protocol.StreamID) (utils.Stream, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// OpenStream opens a stream from the server's side
|
// OpenStream opens a stream
|
||||||
func (s *Session) OpenStream(id protocol.StreamID) (utils.Stream, error) {
|
func (s *Session) OpenStream() (utils.Stream, error) {
|
||||||
return s.streamsMap.OpenStream(id)
|
return s.streamsMap.OpenStream()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Session) queueResetStreamFrame(id protocol.StreamID, offset protocol.ByteCount) {
|
func (s *Session) queueResetStreamFrame(id protocol.StreamID, offset protocol.ByteCount) {
|
||||||
|
|
|
@ -19,7 +19,9 @@ type streamsMap struct {
|
||||||
streams map[protocol.StreamID]*stream
|
streams map[protocol.StreamID]*stream
|
||||||
openStreams []protocol.StreamID
|
openStreams []protocol.StreamID
|
||||||
|
|
||||||
highestStreamOpenedByClient protocol.StreamID
|
nextStream protocol.StreamID // StreamID of the next Stream that will be returned by OpenStream()
|
||||||
|
highestStreamOpenedByPeer protocol.StreamID
|
||||||
|
|
||||||
streamsOpenedAfterLastGarbageCollect int
|
streamsOpenedAfterLastGarbageCollect int
|
||||||
|
|
||||||
newStream newStreamLambda
|
newStream newStreamLambda
|
||||||
|
@ -40,13 +42,21 @@ var (
|
||||||
)
|
)
|
||||||
|
|
||||||
func newStreamsMap(newStream newStreamLambda, pers protocol.Perspective, connectionParameters handshake.ConnectionParametersManager) *streamsMap {
|
func newStreamsMap(newStream newStreamLambda, pers protocol.Perspective, connectionParameters handshake.ConnectionParametersManager) *streamsMap {
|
||||||
return &streamsMap{
|
sm := streamsMap{
|
||||||
perspective: pers,
|
perspective: pers,
|
||||||
streams: map[protocol.StreamID]*stream{},
|
streams: map[protocol.StreamID]*stream{},
|
||||||
openStreams: make([]protocol.StreamID, 0),
|
openStreams: make([]protocol.StreamID, 0),
|
||||||
newStream: newStream,
|
newStream: newStream,
|
||||||
connectionParameters: connectionParameters,
|
connectionParameters: connectionParameters,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if pers == protocol.PerspectiveClient {
|
||||||
|
sm.nextStream = 1
|
||||||
|
} else {
|
||||||
|
sm.nextStream = 2
|
||||||
|
}
|
||||||
|
|
||||||
|
return &sm
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetOrOpenStream either returns an existing stream, a newly opened stream, or nil if a stream with the provided ID is already closed.
|
// GetOrOpenStream either returns an existing stream, a newly opened stream, or nil if a stream with the provided ID is already closed.
|
||||||
|
@ -76,8 +86,8 @@ func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (*stream, error) {
|
||||||
if m.perspective == protocol.PerspectiveClient && id%2 == 1 {
|
if m.perspective == protocol.PerspectiveClient && id%2 == 1 {
|
||||||
return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d from server-side", id))
|
return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d from server-side", id))
|
||||||
}
|
}
|
||||||
if id+protocol.MaxNewStreamIDDelta < m.highestStreamOpenedByClient {
|
if id+protocol.MaxNewStreamIDDelta < m.highestStreamOpenedByPeer {
|
||||||
return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d, which is a lot smaller than the highest opened stream, %d", id, m.highestStreamOpenedByClient))
|
return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d, which is a lot smaller than the highest opened stream, %d", id, m.highestStreamOpenedByPeer))
|
||||||
}
|
}
|
||||||
|
|
||||||
s, err := m.newStream(id)
|
s, err := m.newStream(id)
|
||||||
|
@ -91,8 +101,8 @@ func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (*stream, error) {
|
||||||
m.numOutgoingStreams++
|
m.numOutgoingStreams++
|
||||||
}
|
}
|
||||||
|
|
||||||
if id > m.highestStreamOpenedByClient {
|
if id > m.highestStreamOpenedByPeer {
|
||||||
m.highestStreamOpenedByClient = id
|
m.highestStreamOpenedByPeer = id
|
||||||
}
|
}
|
||||||
|
|
||||||
// maybe trigger garbage collection of streams map
|
// maybe trigger garbage collection of streams map
|
||||||
|
@ -105,21 +115,12 @@ func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (*stream, error) {
|
||||||
return s, nil
|
return s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// OpenStream opens a stream from the server's side
|
// OpenStream opens the next available stream
|
||||||
func (m *streamsMap) OpenStream(id protocol.StreamID) (*stream, error) {
|
func (m *streamsMap) OpenStream() (*stream, error) {
|
||||||
if m.perspective == protocol.PerspectiveServer && id%2 == 1 {
|
|
||||||
return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d from server-side", id))
|
|
||||||
}
|
|
||||||
if m.perspective == protocol.PerspectiveClient && id%2 == 0 {
|
|
||||||
return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d from client-side", id))
|
|
||||||
}
|
|
||||||
|
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
_, ok := m.streams[id]
|
|
||||||
if ok {
|
id := m.nextStream
|
||||||
return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d, which is already open", id))
|
|
||||||
}
|
|
||||||
if m.numOutgoingStreams >= m.connectionParameters.GetMaxOutgoingStreams() {
|
if m.numOutgoingStreams >= m.connectionParameters.GetMaxOutgoingStreams() {
|
||||||
return nil, qerr.TooManyOpenStreams
|
return nil, qerr.TooManyOpenStreams
|
||||||
}
|
}
|
||||||
|
@ -135,6 +136,7 @@ func (m *streamsMap) OpenStream(id protocol.StreamID) (*stream, error) {
|
||||||
m.numIncomingStreams++
|
m.numIncomingStreams++
|
||||||
}
|
}
|
||||||
|
|
||||||
|
m.nextStream += 2
|
||||||
m.putStream(s)
|
m.putStream(s)
|
||||||
return s, nil
|
return s, nil
|
||||||
}
|
}
|
||||||
|
@ -256,7 +258,7 @@ func (m *streamsMap) garbageCollectClosedStreams() {
|
||||||
|
|
||||||
// server-side streams can be gargage collected immediately
|
// server-side streams can be gargage collected immediately
|
||||||
// client-side streams need to be kept as nils in the streams map for a bit longer, in order to prevent a client from reopening closed streams
|
// client-side streams need to be kept as nils in the streams map for a bit longer, in order to prevent a client from reopening closed streams
|
||||||
if id%2 == 0 || id+protocol.MaxNewStreamIDDelta <= m.highestStreamOpenedByClient {
|
if id%2 == 0 || id+protocol.MaxNewStreamIDDelta <= m.highestStreamOpenedByPeer {
|
||||||
delete(m.streams, id)
|
delete(m.streams, id)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -57,177 +57,178 @@ var _ = Describe("Streams Map", func() {
|
||||||
m *streamsMap
|
m *streamsMap
|
||||||
)
|
)
|
||||||
|
|
||||||
|
setNewStreamsMap := func(p protocol.Perspective) {
|
||||||
|
m = newStreamsMap(nil, p, cpm)
|
||||||
|
m.newStream = func(id protocol.StreamID) (*stream, error) {
|
||||||
|
return &stream{streamID: id}, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
cpm = &mockConnectionParametersManager{
|
cpm = &mockConnectionParametersManager{
|
||||||
maxIncomingStreams: 75,
|
maxIncomingStreams: 75,
|
||||||
maxOutgoingStreams: 60,
|
maxOutgoingStreams: 60,
|
||||||
}
|
}
|
||||||
m = newStreamsMap(nil, protocol.PerspectiveServer, cpm)
|
|
||||||
})
|
})
|
||||||
|
|
||||||
Context("getting and creating streams", func() {
|
Context("getting and creating streams", func() {
|
||||||
|
Context("as a server", func() {
|
||||||
|
BeforeEach(func() {
|
||||||
|
setNewStreamsMap(protocol.PerspectiveServer)
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("client-side streams", func() {
|
||||||
|
It("gets new streams", func() {
|
||||||
|
s, err := m.GetOrOpenStream(5)
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
Expect(s.StreamID()).To(Equal(protocol.StreamID(5)))
|
||||||
|
Expect(m.numIncomingStreams).To(BeEquivalentTo(1))
|
||||||
|
Expect(m.numOutgoingStreams).To(BeZero())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("rejects streams with even IDs", func() {
|
||||||
|
_, err := m.GetOrOpenStream(6)
|
||||||
|
Expect(err).To(MatchError("InvalidStreamID: attempted to open stream 6 from client-side"))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("gets existing streams", func() {
|
||||||
|
s, err := m.GetOrOpenStream(5)
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
s, err = m.GetOrOpenStream(5)
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
Expect(s.StreamID()).To(Equal(protocol.StreamID(5)))
|
||||||
|
Expect(m.numIncomingStreams).To(BeEquivalentTo(1))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("returns nil for closed streams", func() {
|
||||||
|
s, err := m.GetOrOpenStream(5)
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
err = m.RemoveStream(5)
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
s, err = m.GetOrOpenStream(5)
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
Expect(s).To(BeNil())
|
||||||
|
Expect(m.numIncomingStreams).To(BeZero())
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("counting streams", func() {
|
||||||
|
var maxNumStreams int
|
||||||
|
|
||||||
|
BeforeEach(func() {
|
||||||
|
maxNumStreams = int(cpm.GetMaxIncomingStreams())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("errors when too many streams are opened", func() {
|
||||||
|
for i := 0; i < maxNumStreams; i++ {
|
||||||
|
_, err := m.GetOrOpenStream(protocol.StreamID(i*2 + 1))
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
}
|
||||||
|
_, err := m.GetOrOpenStream(protocol.StreamID(2*maxNumStreams + 2))
|
||||||
|
Expect(err).To(MatchError(qerr.TooManyOpenStreams))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("does not error when many streams are opened and closed", func() {
|
||||||
|
for i := 2; i < 10*maxNumStreams; i++ {
|
||||||
|
_, err := m.GetOrOpenStream(protocol.StreamID(i*2 + 1))
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
m.RemoveStream(protocol.StreamID(i*2 + 1))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("server-side streams", func() {
|
||||||
|
It("opens a stream 2 first", func() {
|
||||||
|
s, err := m.OpenStream()
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(s).ToNot(BeNil())
|
||||||
|
Expect(s.StreamID()).To(Equal(protocol.StreamID(2)))
|
||||||
|
Expect(m.numIncomingStreams).To(BeZero())
|
||||||
|
Expect(m.numOutgoingStreams).To(BeEquivalentTo(1))
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("counting streams", func() {
|
||||||
|
var maxNumStreams int
|
||||||
|
|
||||||
|
BeforeEach(func() {
|
||||||
|
maxNumStreams = int(cpm.GetMaxOutgoingStreams())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("errors when too many streams are opened", func() {
|
||||||
|
for i := 1; i <= maxNumStreams; i++ {
|
||||||
|
_, err := m.OpenStream()
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
}
|
||||||
|
_, err := m.OpenStream()
|
||||||
|
Expect(err).To(MatchError(qerr.TooManyOpenStreams))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("does not error when many streams are opened and closed", func() {
|
||||||
|
for i := 2; i < 10*maxNumStreams; i++ {
|
||||||
|
str, err := m.OpenStream()
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
m.RemoveStream(str.StreamID())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
It("allows many server- and client-side streams at the same time", func() {
|
||||||
|
for i := 1; i < int(cpm.GetMaxOutgoingStreams()); i++ {
|
||||||
|
_, err := m.OpenStream()
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
}
|
||||||
|
for i := 0; i < int(cpm.GetMaxIncomingStreams()); i++ {
|
||||||
|
_, err := m.GetOrOpenStream(protocol.StreamID(2*i + 1))
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("as a client", func() {
|
||||||
|
BeforeEach(func() {
|
||||||
|
setNewStreamsMap(protocol.PerspectiveClient)
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("client-side streams, as a client", func() {
|
||||||
|
It("rejects streams with odd IDs", func() {
|
||||||
|
_, err := m.GetOrOpenStream(5)
|
||||||
|
Expect(err).To(MatchError("InvalidStreamID: attempted to open stream 5 from server-side"))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("gets new streams", func() {
|
||||||
|
s, err := m.GetOrOpenStream(6)
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
Expect(s.StreamID()).To(Equal(protocol.StreamID(6)))
|
||||||
|
Expect(m.numOutgoingStreams).To(BeEquivalentTo(1))
|
||||||
|
Expect(m.numIncomingStreams).To(BeZero())
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("server-side streams", func() {
|
||||||
|
It("opens stream 1 first", func() {
|
||||||
|
s, err := m.OpenStream()
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(s).ToNot(BeNil())
|
||||||
|
Expect(s.StreamID()).To(BeEquivalentTo(1))
|
||||||
|
Expect(m.numOutgoingStreams).To(BeZero())
|
||||||
|
Expect(m.numIncomingStreams).To(BeEquivalentTo(1))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("opens multiple streams", func() {
|
||||||
|
s1, err := m.OpenStream()
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
s2, err := m.OpenStream()
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(s2.StreamID()).To(Equal(s1.StreamID() + 2))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("DoS mitigation, iterating and deleting", func() {
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
m.newStream = func(id protocol.StreamID) (*stream, error) {
|
setNewStreamsMap(protocol.PerspectiveServer)
|
||||||
return &stream{streamID: id}, nil
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
It("gets new streams", func() {
|
|
||||||
s, err := m.GetOrOpenStream(5)
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
Expect(s.StreamID()).To(Equal(protocol.StreamID(5)))
|
|
||||||
Expect(m.numIncomingStreams).To(Equal(uint32(1)))
|
|
||||||
Expect(m.numOutgoingStreams).To(BeZero())
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("client-side streams, as a server", func() {
|
|
||||||
It("rejects streams with even IDs", func() {
|
|
||||||
_, err := m.GetOrOpenStream(6)
|
|
||||||
Expect(err).To(MatchError("InvalidStreamID: attempted to open stream 6 from client-side"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("gets existing streams", func() {
|
|
||||||
s, err := m.GetOrOpenStream(5)
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
s, err = m.GetOrOpenStream(5)
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
Expect(s.StreamID()).To(Equal(protocol.StreamID(5)))
|
|
||||||
Expect(m.numIncomingStreams).To(Equal(uint32(1)))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("returns nil for closed streams", func() {
|
|
||||||
s, err := m.GetOrOpenStream(5)
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
err = m.RemoveStream(5)
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
s, err = m.GetOrOpenStream(5)
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
Expect(s).To(BeNil())
|
|
||||||
Expect(m.numIncomingStreams).To(BeZero())
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("counting streams", func() {
|
|
||||||
var maxNumStreams int
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
maxNumStreams = int(cpm.GetMaxIncomingStreams())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("errors when too many streams are opened", func() {
|
|
||||||
for i := 0; i < maxNumStreams; i++ {
|
|
||||||
_, err := m.GetOrOpenStream(protocol.StreamID(i*2 + 1))
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
}
|
|
||||||
_, err := m.GetOrOpenStream(protocol.StreamID(2*maxNumStreams + 2))
|
|
||||||
Expect(err).To(MatchError(qerr.TooManyOpenStreams))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("does not error when many streams are opened and closed", func() {
|
|
||||||
for i := 2; i < 10*maxNumStreams; i++ {
|
|
||||||
_, err := m.GetOrOpenStream(protocol.StreamID(i*2 + 1))
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
m.RemoveStream(protocol.StreamID(i*2 + 1))
|
|
||||||
}
|
|
||||||
})
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("client-side streams, as a client", func() {
|
|
||||||
BeforeEach(func() {
|
|
||||||
m.perspective = protocol.PerspectiveClient
|
|
||||||
})
|
|
||||||
|
|
||||||
It("rejects streams with odd IDs", func() {
|
|
||||||
_, err := m.GetOrOpenStream(5)
|
|
||||||
Expect(err).To(MatchError("InvalidStreamID: attempted to open stream 5 from server-side"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("gets new streams", func() {
|
|
||||||
s, err := m.GetOrOpenStream(6)
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
Expect(s.StreamID()).To(Equal(protocol.StreamID(6)))
|
|
||||||
Expect(m.numOutgoingStreams).To(Equal(uint32(1)))
|
|
||||||
Expect(m.numIncomingStreams).To(BeZero())
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("server-side streams, as a server", func() {
|
|
||||||
It("rejects streams with odd IDs", func() {
|
|
||||||
_, err := m.OpenStream(5)
|
|
||||||
Expect(err).To(MatchError("InvalidStreamID: attempted to open stream 5 from server-side"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("opens a new stream", func() {
|
|
||||||
s, err := m.OpenStream(6)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(s).ToNot(BeNil())
|
|
||||||
Expect(s.StreamID()).To(Equal(protocol.StreamID(6)))
|
|
||||||
Expect(m.numIncomingStreams).To(BeZero())
|
|
||||||
Expect(m.numOutgoingStreams).To(Equal(uint32(1)))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("returns an error for already openend streams", func() {
|
|
||||||
_, err := m.OpenStream(4)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
_, err = m.OpenStream(4)
|
|
||||||
Expect(err).To(MatchError("InvalidStreamID: attempted to open stream 4, which is already open"))
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("counting streams", func() {
|
|
||||||
var maxNumStreams int
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
maxNumStreams = int(cpm.GetMaxOutgoingStreams())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("errors when too many streams are opened", func() {
|
|
||||||
for i := 1; i <= maxNumStreams; i++ {
|
|
||||||
_, err := m.OpenStream(protocol.StreamID(2 * i))
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
}
|
|
||||||
_, err := m.OpenStream(protocol.StreamID(2*maxNumStreams + 10))
|
|
||||||
Expect(err).To(MatchError(qerr.TooManyOpenStreams))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("does not error when many streams are opened and closed", func() {
|
|
||||||
for i := 2; i < 10*maxNumStreams; i++ {
|
|
||||||
_, err := m.OpenStream(protocol.StreamID(2*i + 2))
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
m.RemoveStream(protocol.StreamID(2 * i))
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
It("allows many server- and client-side streams at the same time", func() {
|
|
||||||
for i := 1; i < int(cpm.GetMaxOutgoingStreams()); i++ {
|
|
||||||
_, err := m.OpenStream(protocol.StreamID(2 * i))
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
}
|
|
||||||
for i := 0; i < int(cpm.GetMaxIncomingStreams()); i++ {
|
|
||||||
_, err := m.GetOrOpenStream(protocol.StreamID(2*i + 1))
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
}
|
|
||||||
})
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("server-side streams, as a client", func() {
|
|
||||||
BeforeEach(func() {
|
|
||||||
m.perspective = protocol.PerspectiveClient
|
|
||||||
})
|
|
||||||
|
|
||||||
It("rejects streams with even IDs", func() {
|
|
||||||
_, err := m.OpenStream(6)
|
|
||||||
Expect(err).To(MatchError("InvalidStreamID: attempted to open stream 6 from client-side"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("opens a new stream", func() {
|
|
||||||
s, err := m.OpenStream(7)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(s).ToNot(BeNil())
|
|
||||||
Expect(s.StreamID()).To(Equal(protocol.StreamID(7)))
|
|
||||||
Expect(m.numOutgoingStreams).To(BeZero())
|
|
||||||
Expect(m.numIncomingStreams).To(Equal(uint32(1)))
|
|
||||||
})
|
|
||||||
})
|
})
|
||||||
|
|
||||||
Context("DoS mitigation", func() {
|
Context("DoS mitigation", func() {
|
||||||
|
@ -235,7 +236,7 @@ var _ = Describe("Streams Map", func() {
|
||||||
for i := 1; i < 2*protocol.MaxNewStreamIDDelta; i += 2 {
|
for i := 1; i < 2*protocol.MaxNewStreamIDDelta; i += 2 {
|
||||||
streamID := protocol.StreamID(i)
|
streamID := protocol.StreamID(i)
|
||||||
_, err := m.GetOrOpenStream(streamID)
|
_, err := m.GetOrOpenStream(streamID)
|
||||||
Expect(m.highestStreamOpenedByClient).To(Equal(streamID))
|
Expect(m.highestStreamOpenedByPeer).To(Equal(streamID))
|
||||||
Expect(err).NotTo(HaveOccurred())
|
Expect(err).NotTo(HaveOccurred())
|
||||||
err = m.RemoveStream(streamID)
|
err = m.RemoveStream(streamID)
|
||||||
Expect(err).NotTo(HaveOccurred())
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
@ -253,7 +254,7 @@ var _ = Describe("Streams Map", func() {
|
||||||
err = m.RemoveStream(streamID)
|
err = m.RemoveStream(streamID)
|
||||||
Expect(err).NotTo(HaveOccurred())
|
Expect(err).NotTo(HaveOccurred())
|
||||||
}
|
}
|
||||||
Expect(m.highestStreamOpenedByClient).To(Equal(protocol.StreamID(protocol.MaxNewStreamIDDelta + 13)))
|
Expect(m.highestStreamOpenedByPeer).To(Equal(protocol.StreamID(protocol.MaxNewStreamIDDelta + 13)))
|
||||||
_, err := m.GetOrOpenStream(11)
|
_, err := m.GetOrOpenStream(11)
|
||||||
Expect(err).To(MatchError("InvalidStreamID: attempted to open stream 11, which is a lot smaller than the highest opened stream, 413"))
|
Expect(err).To(MatchError("InvalidStreamID: attempted to open stream 11, which is a lot smaller than the highest opened stream, 413"))
|
||||||
_, err = m.GetOrOpenStream(13)
|
_, err = m.GetOrOpenStream(13)
|
||||||
|
@ -264,7 +265,7 @@ var _ = Describe("Streams Map", func() {
|
||||||
for i := 1; i < 4*protocol.MaxNewStreamIDDelta; i += 2 {
|
for i := 1; i < 4*protocol.MaxNewStreamIDDelta; i += 2 {
|
||||||
streamID := protocol.StreamID(i)
|
streamID := protocol.StreamID(i)
|
||||||
_, err := m.GetOrOpenStream(streamID)
|
_, err := m.GetOrOpenStream(streamID)
|
||||||
Expect(m.highestStreamOpenedByClient).To(Equal(streamID))
|
Expect(m.highestStreamOpenedByPeer).To(Equal(streamID))
|
||||||
Expect(err).NotTo(HaveOccurred())
|
Expect(err).NotTo(HaveOccurred())
|
||||||
err = m.RemoveStream(streamID)
|
err = m.RemoveStream(streamID)
|
||||||
Expect(err).NotTo(HaveOccurred())
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
@ -282,7 +283,7 @@ var _ = Describe("Streams Map", func() {
|
||||||
for i := 1; i < 1002; i += 2 {
|
for i := 1; i < 1002; i += 2 {
|
||||||
streamID := protocol.StreamID(i)
|
streamID := protocol.StreamID(i)
|
||||||
_, err := m.GetOrOpenStream(streamID)
|
_, err := m.GetOrOpenStream(streamID)
|
||||||
Expect(m.highestStreamOpenedByClient).To(Equal(streamID))
|
Expect(m.highestStreamOpenedByPeer).To(Equal(streamID))
|
||||||
Expect(err).NotTo(HaveOccurred())
|
Expect(err).NotTo(HaveOccurred())
|
||||||
if streamID != 23 {
|
if streamID != 23 {
|
||||||
err = m.RemoveStream(streamID)
|
err = m.RemoveStream(streamID)
|
||||||
|
@ -302,7 +303,7 @@ var _ = Describe("Streams Map", func() {
|
||||||
for i := 1; i < 4*protocol.MaxNewStreamIDDelta; i += 2 {
|
for i := 1; i < 4*protocol.MaxNewStreamIDDelta; i += 2 {
|
||||||
streamID := protocol.StreamID(i)
|
streamID := protocol.StreamID(i)
|
||||||
_, err := m.GetOrOpenStream(streamID)
|
_, err := m.GetOrOpenStream(streamID)
|
||||||
Expect(m.highestStreamOpenedByClient).To(Equal(streamID))
|
Expect(m.highestStreamOpenedByPeer).To(Equal(streamID))
|
||||||
Expect(err).NotTo(HaveOccurred())
|
Expect(err).NotTo(HaveOccurred())
|
||||||
err = m.RemoveStream(streamID)
|
err = m.RemoveStream(streamID)
|
||||||
Expect(err).NotTo(HaveOccurred())
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
@ -316,214 +317,214 @@ var _ = Describe("Streams Map", func() {
|
||||||
Expect(len(m.streams)).To(BeNumerically("<", 2*protocol.MaxNewStreamIDDelta))
|
Expect(len(m.streams)).To(BeNumerically("<", 2*protocol.MaxNewStreamIDDelta))
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
})
|
|
||||||
|
|
||||||
Context("deleting streams", func() {
|
Context("deleting streams", func() {
|
||||||
BeforeEach(func() {
|
|
||||||
for i := 1; i <= 5; i++ {
|
|
||||||
err := m.putStream(&stream{streamID: protocol.StreamID(i)})
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
}
|
|
||||||
Expect(m.openStreams).To(Equal([]protocol.StreamID{1, 2, 3, 4, 5}))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("errors when removing non-existing stream", func() {
|
|
||||||
err := m.RemoveStream(1337)
|
|
||||||
Expect(err).To(MatchError("attempted to remove non-existing stream: 1337"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("removes the first stream", func() {
|
|
||||||
err := m.RemoveStream(1)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(m.openStreams).To(HaveLen(4))
|
|
||||||
Expect(m.openStreams).To(Equal([]protocol.StreamID{2, 3, 4, 5}))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("removes a stream in the middle", func() {
|
|
||||||
err := m.RemoveStream(3)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(m.openStreams).To(HaveLen(4))
|
|
||||||
Expect(m.openStreams).To(Equal([]protocol.StreamID{1, 2, 4, 5}))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("removes a stream at the end", func() {
|
|
||||||
err := m.RemoveStream(5)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(m.openStreams).To(HaveLen(4))
|
|
||||||
Expect(m.openStreams).To(Equal([]protocol.StreamID{1, 2, 3, 4}))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("removes all streams", func() {
|
|
||||||
for i := 1; i <= 5; i++ {
|
|
||||||
err := m.RemoveStream(protocol.StreamID(i))
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
}
|
|
||||||
Expect(m.openStreams).To(BeEmpty())
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("Iterate", func() {
|
|
||||||
// create 3 streams, ids 1 to 3
|
|
||||||
BeforeEach(func() {
|
|
||||||
for i := 1; i <= 3; i++ {
|
|
||||||
err := m.putStream(&stream{streamID: protocol.StreamID(i)})
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
It("executes the lambda exactly once for every stream", func() {
|
|
||||||
var numIterations int
|
|
||||||
callbackCalled := make(map[protocol.StreamID]bool)
|
|
||||||
fn := func(str *stream) (bool, error) {
|
|
||||||
callbackCalled[str.StreamID()] = true
|
|
||||||
numIterations++
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
err := m.Iterate(fn)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(callbackCalled).To(HaveKey(protocol.StreamID(1)))
|
|
||||||
Expect(callbackCalled).To(HaveKey(protocol.StreamID(2)))
|
|
||||||
Expect(callbackCalled).To(HaveKey(protocol.StreamID(3)))
|
|
||||||
Expect(numIterations).To(Equal(3))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("stops iterating when the callback returns false", func() {
|
|
||||||
var numIterations int
|
|
||||||
fn := func(str *stream) (bool, error) {
|
|
||||||
numIterations++
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
err := m.Iterate(fn)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
// due to map access randomization, we don't know for which stream the callback was executed
|
|
||||||
// but it must only be executed once
|
|
||||||
Expect(numIterations).To(Equal(1))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("returns the error, if the lambda returns one", func() {
|
|
||||||
var numIterations int
|
|
||||||
expectedError := errors.New("test")
|
|
||||||
fn := func(str *stream) (bool, error) {
|
|
||||||
numIterations++
|
|
||||||
return true, expectedError
|
|
||||||
}
|
|
||||||
err := m.Iterate(fn)
|
|
||||||
Expect(err).To(MatchError(expectedError))
|
|
||||||
Expect(numIterations).To(Equal(1))
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("RoundRobinIterate", func() {
|
|
||||||
// create 5 streams, ids 4 to 8
|
|
||||||
var lambdaCalledForStream []protocol.StreamID
|
|
||||||
var numIterations int
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
lambdaCalledForStream = lambdaCalledForStream[:0]
|
|
||||||
numIterations = 0
|
|
||||||
for i := 4; i <= 8; i++ {
|
|
||||||
err := m.putStream(&stream{streamID: protocol.StreamID(i)})
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
It("executes the lambda exactly once for every stream", func() {
|
|
||||||
fn := func(str *stream) (bool, error) {
|
|
||||||
lambdaCalledForStream = append(lambdaCalledForStream, str.StreamID())
|
|
||||||
numIterations++
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
err := m.RoundRobinIterate(fn)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(numIterations).To(Equal(5))
|
|
||||||
Expect(lambdaCalledForStream).To(Equal([]protocol.StreamID{4, 5, 6, 7, 8}))
|
|
||||||
Expect(m.roundRobinIndex).To(BeZero())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("goes around once when starting in the middle", func() {
|
|
||||||
fn := func(str *stream) (bool, error) {
|
|
||||||
lambdaCalledForStream = append(lambdaCalledForStream, str.StreamID())
|
|
||||||
numIterations++
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
m.roundRobinIndex = 3 // pointing to stream 7
|
|
||||||
err := m.RoundRobinIterate(fn)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(numIterations).To(Equal(5))
|
|
||||||
Expect(lambdaCalledForStream).To(Equal([]protocol.StreamID{7, 8, 4, 5, 6}))
|
|
||||||
Expect(m.roundRobinIndex).To(Equal(uint32(3)))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("picks up at the index+1 where it last stopped", func() {
|
|
||||||
fn := func(str *stream) (bool, error) {
|
|
||||||
lambdaCalledForStream = append(lambdaCalledForStream, str.StreamID())
|
|
||||||
numIterations++
|
|
||||||
if str.StreamID() == 5 {
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
err := m.RoundRobinIterate(fn)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(numIterations).To(Equal(2))
|
|
||||||
Expect(lambdaCalledForStream).To(Equal([]protocol.StreamID{4, 5}))
|
|
||||||
Expect(m.roundRobinIndex).To(Equal(uint32(2)))
|
|
||||||
numIterations = 0
|
|
||||||
lambdaCalledForStream = lambdaCalledForStream[:0]
|
|
||||||
fn2 := func(str *stream) (bool, error) {
|
|
||||||
lambdaCalledForStream = append(lambdaCalledForStream, str.StreamID())
|
|
||||||
numIterations++
|
|
||||||
if str.StreamID() == 7 {
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
err = m.RoundRobinIterate(fn2)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(numIterations).To(Equal(2))
|
|
||||||
Expect(lambdaCalledForStream).To(Equal([]protocol.StreamID{6, 7}))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("adjust the RoundRobinIndex when deleting an element in front", func() {
|
|
||||||
m.roundRobinIndex = 3 // stream 7
|
|
||||||
m.RemoveStream(5)
|
|
||||||
Expect(m.roundRobinIndex).To(Equal(uint32(2)))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("doesn't adjust the RoundRobinIndex when deleting an element at the back", func() {
|
|
||||||
m.roundRobinIndex = 1 // stream 5
|
|
||||||
m.RemoveStream(7)
|
|
||||||
Expect(m.roundRobinIndex).To(Equal(uint32(1)))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("doesn't adjust the RoundRobinIndex when deleting the element it is pointing to", func() {
|
|
||||||
m.roundRobinIndex = 3 // stream 7
|
|
||||||
m.RemoveStream(7)
|
|
||||||
Expect(m.roundRobinIndex).To(Equal(uint32(3)))
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("Prioritizing crypto- and header streams", func() {
|
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
err := m.putStream(&stream{streamID: 1})
|
for i := 1; i <= 5; i++ {
|
||||||
Expect(err).NotTo(HaveOccurred())
|
err := m.putStream(&stream{streamID: protocol.StreamID(i)})
|
||||||
err = m.putStream(&stream{streamID: 3})
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(err).NotTo(HaveOccurred())
|
}
|
||||||
|
Expect(m.openStreams).To(Equal([]protocol.StreamID{1, 2, 3, 4, 5}))
|
||||||
})
|
})
|
||||||
|
|
||||||
It("gets crypto- and header stream first, then picks up at the round-robin position", func() {
|
It("errors when removing non-existing stream", func() {
|
||||||
m.roundRobinIndex = 3 // stream 7
|
err := m.RemoveStream(1337)
|
||||||
|
Expect(err).To(MatchError("attempted to remove non-existing stream: 1337"))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("removes the first stream", func() {
|
||||||
|
err := m.RemoveStream(1)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(m.openStreams).To(HaveLen(4))
|
||||||
|
Expect(m.openStreams).To(Equal([]protocol.StreamID{2, 3, 4, 5}))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("removes a stream in the middle", func() {
|
||||||
|
err := m.RemoveStream(3)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(m.openStreams).To(HaveLen(4))
|
||||||
|
Expect(m.openStreams).To(Equal([]protocol.StreamID{1, 2, 4, 5}))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("removes a stream at the end", func() {
|
||||||
|
err := m.RemoveStream(5)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(m.openStreams).To(HaveLen(4))
|
||||||
|
Expect(m.openStreams).To(Equal([]protocol.StreamID{1, 2, 3, 4}))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("removes all streams", func() {
|
||||||
|
for i := 1; i <= 5; i++ {
|
||||||
|
err := m.RemoveStream(protocol.StreamID(i))
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
}
|
||||||
|
Expect(m.openStreams).To(BeEmpty())
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("Iterate", func() {
|
||||||
|
// create 3 streams, ids 1 to 3
|
||||||
|
BeforeEach(func() {
|
||||||
|
for i := 1; i <= 3; i++ {
|
||||||
|
err := m.putStream(&stream{streamID: protocol.StreamID(i)})
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
It("executes the lambda exactly once for every stream", func() {
|
||||||
|
var numIterations int
|
||||||
|
callbackCalled := make(map[protocol.StreamID]bool)
|
||||||
|
fn := func(str *stream) (bool, error) {
|
||||||
|
callbackCalled[str.StreamID()] = true
|
||||||
|
numIterations++
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
err := m.Iterate(fn)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(callbackCalled).To(HaveKey(protocol.StreamID(1)))
|
||||||
|
Expect(callbackCalled).To(HaveKey(protocol.StreamID(2)))
|
||||||
|
Expect(callbackCalled).To(HaveKey(protocol.StreamID(3)))
|
||||||
|
Expect(numIterations).To(Equal(3))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("stops iterating when the callback returns false", func() {
|
||||||
|
var numIterations int
|
||||||
|
fn := func(str *stream) (bool, error) {
|
||||||
|
numIterations++
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
err := m.Iterate(fn)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
// due to map access randomization, we don't know for which stream the callback was executed
|
||||||
|
// but it must only be executed once
|
||||||
|
Expect(numIterations).To(Equal(1))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("returns the error, if the lambda returns one", func() {
|
||||||
|
var numIterations int
|
||||||
|
expectedError := errors.New("test")
|
||||||
|
fn := func(str *stream) (bool, error) {
|
||||||
|
numIterations++
|
||||||
|
return true, expectedError
|
||||||
|
}
|
||||||
|
err := m.Iterate(fn)
|
||||||
|
Expect(err).To(MatchError(expectedError))
|
||||||
|
Expect(numIterations).To(Equal(1))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("RoundRobinIterate", func() {
|
||||||
|
// create 5 streams, ids 4 to 8
|
||||||
|
var lambdaCalledForStream []protocol.StreamID
|
||||||
|
var numIterations int
|
||||||
|
|
||||||
|
BeforeEach(func() {
|
||||||
|
lambdaCalledForStream = lambdaCalledForStream[:0]
|
||||||
|
numIterations = 0
|
||||||
|
for i := 4; i <= 8; i++ {
|
||||||
|
err := m.putStream(&stream{streamID: protocol.StreamID(i)})
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
It("executes the lambda exactly once for every stream", func() {
|
||||||
fn := func(str *stream) (bool, error) {
|
fn := func(str *stream) (bool, error) {
|
||||||
if numIterations >= 3 {
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
lambdaCalledForStream = append(lambdaCalledForStream, str.StreamID())
|
lambdaCalledForStream = append(lambdaCalledForStream, str.StreamID())
|
||||||
numIterations++
|
numIterations++
|
||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
err := m.RoundRobinIterate(fn)
|
err := m.RoundRobinIterate(fn)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(numIterations).To(Equal(3))
|
Expect(numIterations).To(Equal(5))
|
||||||
Expect(lambdaCalledForStream).To(Equal([]protocol.StreamID{1, 3, 7}))
|
Expect(lambdaCalledForStream).To(Equal([]protocol.StreamID{4, 5, 6, 7, 8}))
|
||||||
|
Expect(m.roundRobinIndex).To(BeZero())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("goes around once when starting in the middle", func() {
|
||||||
|
fn := func(str *stream) (bool, error) {
|
||||||
|
lambdaCalledForStream = append(lambdaCalledForStream, str.StreamID())
|
||||||
|
numIterations++
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
m.roundRobinIndex = 3 // pointing to stream 7
|
||||||
|
err := m.RoundRobinIterate(fn)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(numIterations).To(Equal(5))
|
||||||
|
Expect(lambdaCalledForStream).To(Equal([]protocol.StreamID{7, 8, 4, 5, 6}))
|
||||||
|
Expect(m.roundRobinIndex).To(Equal(uint32(3)))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("picks up at the index+1 where it last stopped", func() {
|
||||||
|
fn := func(str *stream) (bool, error) {
|
||||||
|
lambdaCalledForStream = append(lambdaCalledForStream, str.StreamID())
|
||||||
|
numIterations++
|
||||||
|
if str.StreamID() == 5 {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
err := m.RoundRobinIterate(fn)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(numIterations).To(Equal(2))
|
||||||
|
Expect(lambdaCalledForStream).To(Equal([]protocol.StreamID{4, 5}))
|
||||||
|
Expect(m.roundRobinIndex).To(Equal(uint32(2)))
|
||||||
|
numIterations = 0
|
||||||
|
lambdaCalledForStream = lambdaCalledForStream[:0]
|
||||||
|
fn2 := func(str *stream) (bool, error) {
|
||||||
|
lambdaCalledForStream = append(lambdaCalledForStream, str.StreamID())
|
||||||
|
numIterations++
|
||||||
|
if str.StreamID() == 7 {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
err = m.RoundRobinIterate(fn2)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(numIterations).To(Equal(2))
|
||||||
|
Expect(lambdaCalledForStream).To(Equal([]protocol.StreamID{6, 7}))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("adjust the RoundRobinIndex when deleting an element in front", func() {
|
||||||
|
m.roundRobinIndex = 3 // stream 7
|
||||||
|
m.RemoveStream(5)
|
||||||
|
Expect(m.roundRobinIndex).To(Equal(uint32(2)))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("doesn't adjust the RoundRobinIndex when deleting an element at the back", func() {
|
||||||
|
m.roundRobinIndex = 1 // stream 5
|
||||||
|
m.RemoveStream(7)
|
||||||
|
Expect(m.roundRobinIndex).To(BeEquivalentTo(1))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("doesn't adjust the RoundRobinIndex when deleting the element it is pointing to", func() {
|
||||||
|
m.roundRobinIndex = 3 // stream 7
|
||||||
|
m.RemoveStream(7)
|
||||||
|
Expect(m.roundRobinIndex).To(Equal(uint32(3)))
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("Prioritizing crypto- and header streams", func() {
|
||||||
|
BeforeEach(func() {
|
||||||
|
err := m.putStream(&stream{streamID: 1})
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
err = m.putStream(&stream{streamID: 3})
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("gets crypto- and header stream first, then picks up at the round-robin position", func() {
|
||||||
|
m.roundRobinIndex = 3 // stream 7
|
||||||
|
fn := func(str *stream) (bool, error) {
|
||||||
|
if numIterations >= 3 {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
lambdaCalledForStream = append(lambdaCalledForStream, str.StreamID())
|
||||||
|
numIterations++
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
err := m.RoundRobinIterate(fn)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(numIterations).To(Equal(3))
|
||||||
|
Expect(lambdaCalledForStream).To(Equal([]protocol.StreamID{1, 3, 7}))
|
||||||
|
})
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue