remove stream ID from OpenStream() method

This commit is contained in:
Marten Seemann 2017-02-09 17:05:58 +07:00
parent 8cd1e4484c
commit f47142eaac
No known key found for this signature in database
GPG key ID: 3603F40B121FCDEA
9 changed files with 424 additions and 423 deletions

View file

@ -109,8 +109,8 @@ func (c *Client) Listen() error {
} }
// OpenStream opens a stream, for client-side created streams (i.e. odd streamIDs) // OpenStream opens a stream, for client-side created streams (i.e. odd streamIDs)
func (c *Client) OpenStream(id protocol.StreamID) (utils.Stream, error) { func (c *Client) OpenStream() (utils.Stream, error) {
return c.session.OpenStream(id) return c.session.OpenStream()
} }
// Close closes the connection // Close closes the connection

View file

@ -109,9 +109,9 @@ var _ = Describe("Client", func() {
}) })
It("opens a stream", func() { It("opens a stream", func() {
stream, err := client.OpenStream(1337) stream, err := client.OpenStream()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(stream.StreamID()).To(Equal(protocol.StreamID(1337))) Expect(stream).ToNot(BeNil())
}) })
Context("handling packets", func() { Context("handling packets", func() {

View file

@ -21,7 +21,7 @@ import (
) )
type quicClient interface { type quicClient interface {
OpenStream(protocol.StreamID) (utils.Stream, error) OpenStream() (utils.Stream, error)
Close(error) error Close(error) error
Listen() error Listen() error
} }
@ -39,7 +39,6 @@ type Client struct {
client quicClient client quicClient
headerStream utils.Stream headerStream utils.Stream
headerErr *qerr.QuicError headerErr *qerr.QuicError
highestOpenedStream protocol.StreamID
requestWriter *requestWriter requestWriter *requestWriter
responses map[protocol.StreamID]chan *http.Response responses map[protocol.StreamID]chan *http.Response
@ -52,7 +51,6 @@ func NewClient(t *QuicRoundTripper, tlsConfig *tls.Config, hostname string) (*Cl
c := &Client{ c := &Client{
t: t, t: t,
hostname: authorityAddr("https", hostname), hostname: authorityAddr("https", hostname),
highestOpenedStream: 3,
responses: make(map[protocol.StreamID]chan *http.Response), responses: make(map[protocol.StreamID]chan *http.Response),
} }
c.cryptoChangedCond = sync.Cond{L: &c.mutex} c.cryptoChangedCond = sync.Cond{L: &c.mutex}
@ -88,10 +86,13 @@ func (c *Client) cryptoChangeCallback(isForwardSecure bool) {
func (c *Client) versionNegotiateCallback() error { func (c *Client) versionNegotiateCallback() error {
var err error var err error
// once the version has been negotiated, open the header stream // once the version has been negotiated, open the header stream
c.headerStream, err = c.client.OpenStream(3) c.headerStream, err = c.client.OpenStream()
if err != nil { if err != nil {
return err return err
} }
if c.headerStream.StreamID() != 3 {
return errors.New("h2quic Client BUG: StreamID of Header Stream is not 3")
}
c.requestWriter = newRequestWriter(c.headerStream) c.requestWriter = newRequestWriter(c.headerStream)
go c.handleHeaderStream() go c.handleHeaderStream()
return nil return nil
@ -160,21 +161,18 @@ func (c *Client) Do(req *http.Request) (*http.Response, error) {
hasBody := (req.Body != nil) hasBody := (req.Body != nil)
c.mutex.Lock() c.mutex.Lock()
c.highestOpenedStream += 2
dataStreamID := c.highestOpenedStream
for c.encryptionLevel != protocol.EncryptionForwardSecure { for c.encryptionLevel != protocol.EncryptionForwardSecure {
c.cryptoChangedCond.Wait() c.cryptoChangedCond.Wait()
} }
hdrChan := make(chan *http.Response) hdrChan := make(chan *http.Response)
c.responses[dataStreamID] = hdrChan
c.mutex.Unlock()
// TODO: think about what to do with a TooManyOpenStreams error. Wait and retry? // TODO: think about what to do with a TooManyOpenStreams error. Wait and retry?
dataStream, err := c.client.OpenStream(dataStreamID) dataStream, err := c.client.OpenStream()
if err != nil { if err != nil {
c.Close(err) c.Close(err)
return nil, err return nil, err
} }
c.responses[dataStream.StreamID()] = hdrChan
c.mutex.Unlock()
var requestedGzip bool var requestedGzip bool
if !c.t.disableCompression() && req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" && req.Method != "HEAD" { if !c.t.disableCompression() && req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" && req.Method != "HEAD" {
@ -182,7 +180,7 @@ func (c *Client) Do(req *http.Request) (*http.Response, error) {
} }
// TODO: add support for trailers // TODO: add support for trailers
endStream := !hasBody endStream := !hasBody
err = c.requestWriter.WriteRequest(req, dataStreamID, endStream, requestedGzip) err = c.requestWriter.WriteRequest(req, dataStream.StreamID(), endStream, requestedGzip)
if err != nil { if err != nil {
c.Close(err) c.Close(err)
return nil, err return nil, err
@ -209,7 +207,7 @@ func (c *Client) Do(req *http.Request) (*http.Response, error) {
case res = <-hdrChan: case res = <-hdrChan:
receivedResponse = true receivedResponse = true
c.mutex.Lock() c.mutex.Lock()
delete(c.responses, dataStreamID) delete(c.responses, dataStream.StreamID())
c.mutex.Unlock() c.mutex.Unlock()
if res == nil { // an error occured on the header stream if res == nil { // an error occured on the header stream
c.Close(c.headerErr) c.Close(c.headerErr)

View file

@ -18,25 +18,25 @@ import (
) )
type mockQuicClient struct { type mockQuicClient struct {
nextStream protocol.StreamID
streams map[protocol.StreamID]*mockStream streams map[protocol.StreamID]*mockStream
closeErr error closeErr error
} }
func (m *mockQuicClient) Close(e error) error { m.closeErr = e; return nil } func (m *mockQuicClient) Close(e error) error { m.closeErr = e; return nil }
func (m *mockQuicClient) Listen() error { panic("not implemented") } func (m *mockQuicClient) Listen() error { panic("not implemented") }
func (m *mockQuicClient) OpenStream(id protocol.StreamID) (utils.Stream, error) { func (m *mockQuicClient) OpenStream() (utils.Stream, error) {
_, ok := m.streams[id] id := m.nextStream
if ok {
panic("Stream already exists")
}
ms := &mockStream{id: id} ms := &mockStream{id: id}
m.streams[id] = ms m.streams[id] = ms
m.nextStream += 2
return ms, nil return ms, nil
} }
func newMockQuicClient() *mockQuicClient { func newMockQuicClient() *mockQuicClient {
return &mockQuicClient{ return &mockQuicClient{
streams: make(map[protocol.StreamID]*mockStream), streams: make(map[protocol.StreamID]*mockStream),
nextStream: 5,
} }
} }
@ -77,6 +77,7 @@ var _ = Describe("Client", func() {
// delete the headerStream openend in the BeforeEach // delete the headerStream openend in the BeforeEach
client.headerStream = nil client.headerStream = nil
delete(qClient.streams, 3) delete(qClient.streams, 3)
qClient.nextStream = 3
Expect(client.headerStream).To(BeNil()) // header stream not yet opened Expect(client.headerStream).To(BeNil()) // header stream not yet opened
// now start the actual test // now start the actual test
err := client.versionNegotiateCallback() err := client.versionNegotiateCallback()
@ -133,7 +134,6 @@ var _ = Describe("Client", func() {
}() }()
Eventually(func() []byte { return headerStream.dataWritten.Bytes() }).ShouldNot(BeEmpty()) Eventually(func() []byte { return headerStream.dataWritten.Bytes() }).ShouldNot(BeEmpty())
Expect(client.highestOpenedStream).To(Equal(protocol.StreamID(5)))
Expect(qClient.streams).Should(HaveKey(protocol.StreamID(5))) Expect(qClient.streams).Should(HaveKey(protocol.StreamID(5)))
Expect(client.responses).To(HaveKey(protocol.StreamID(5))) Expect(client.responses).To(HaveKey(protocol.StreamID(5)))
rsp := &http.Response{ rsp := &http.Response{

View file

@ -19,7 +19,7 @@ import (
// packetHandler handles packets // packetHandler handles packets
type packetHandler interface { type packetHandler interface {
handlePacket(*receivedPacket) handlePacket(*receivedPacket)
OpenStream(protocol.StreamID) (utils.Stream, error) OpenStream() (utils.Stream, error)
run() run()
Close(error) error Close(error) error
} }

View file

@ -34,8 +34,8 @@ func (s *mockSession) Close(e error) error {
return nil return nil
} }
func (s *mockSession) OpenStream(id protocol.StreamID) (utils.Stream, error) { func (s *mockSession) OpenStream() (utils.Stream, error) {
return &stream{streamID: id}, nil return &stream{streamID: 1337}, nil
} }
func newMockSession(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, streamCallback StreamCallback, closeCallback closeCallback) (packetHandler, error) { func newMockSession(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, streamCallback StreamCallback, closeCallback closeCallback) (packetHandler, error) {
return &mockSession{ return &mockSession{

View file

@ -144,7 +144,7 @@ func newClientSession(conn *net.UDPConn, addr *net.UDPAddr, hostname string, v p
session.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(session.ackAlarmChanged) session.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(session.ackAlarmChanged)
session.setup() session.setup()
cryptoStream, _ := session.OpenStream(1) cryptoStream, _ := session.OpenStream()
var err error var err error
session.cryptoSetup, err = handshake.NewCryptoSetupClient(hostname, connectionID, v, cryptoStream, tlsConfig, session.connectionParameters, session.aeadChanged, negotiatedVersions) session.cryptoSetup, err = handshake.NewCryptoSetupClient(hostname, connectionID, v, cryptoStream, tlsConfig, session.connectionParameters, session.aeadChanged, negotiatedVersions)
if err != nil { if err != nil {
@ -666,9 +666,9 @@ func (s *Session) GetOrOpenStream(id protocol.StreamID) (utils.Stream, error) {
return nil, err return nil, err
} }
// OpenStream opens a stream from the server's side // OpenStream opens a stream
func (s *Session) OpenStream(id protocol.StreamID) (utils.Stream, error) { func (s *Session) OpenStream() (utils.Stream, error) {
return s.streamsMap.OpenStream(id) return s.streamsMap.OpenStream()
} }
func (s *Session) queueResetStreamFrame(id protocol.StreamID, offset protocol.ByteCount) { func (s *Session) queueResetStreamFrame(id protocol.StreamID, offset protocol.ByteCount) {

View file

@ -19,7 +19,9 @@ type streamsMap struct {
streams map[protocol.StreamID]*stream streams map[protocol.StreamID]*stream
openStreams []protocol.StreamID openStreams []protocol.StreamID
highestStreamOpenedByClient protocol.StreamID nextStream protocol.StreamID // StreamID of the next Stream that will be returned by OpenStream()
highestStreamOpenedByPeer protocol.StreamID
streamsOpenedAfterLastGarbageCollect int streamsOpenedAfterLastGarbageCollect int
newStream newStreamLambda newStream newStreamLambda
@ -40,13 +42,21 @@ var (
) )
func newStreamsMap(newStream newStreamLambda, pers protocol.Perspective, connectionParameters handshake.ConnectionParametersManager) *streamsMap { func newStreamsMap(newStream newStreamLambda, pers protocol.Perspective, connectionParameters handshake.ConnectionParametersManager) *streamsMap {
return &streamsMap{ sm := streamsMap{
perspective: pers, perspective: pers,
streams: map[protocol.StreamID]*stream{}, streams: map[protocol.StreamID]*stream{},
openStreams: make([]protocol.StreamID, 0), openStreams: make([]protocol.StreamID, 0),
newStream: newStream, newStream: newStream,
connectionParameters: connectionParameters, connectionParameters: connectionParameters,
} }
if pers == protocol.PerspectiveClient {
sm.nextStream = 1
} else {
sm.nextStream = 2
}
return &sm
} }
// GetOrOpenStream either returns an existing stream, a newly opened stream, or nil if a stream with the provided ID is already closed. // GetOrOpenStream either returns an existing stream, a newly opened stream, or nil if a stream with the provided ID is already closed.
@ -76,8 +86,8 @@ func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (*stream, error) {
if m.perspective == protocol.PerspectiveClient && id%2 == 1 { if m.perspective == protocol.PerspectiveClient && id%2 == 1 {
return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d from server-side", id)) return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d from server-side", id))
} }
if id+protocol.MaxNewStreamIDDelta < m.highestStreamOpenedByClient { if id+protocol.MaxNewStreamIDDelta < m.highestStreamOpenedByPeer {
return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d, which is a lot smaller than the highest opened stream, %d", id, m.highestStreamOpenedByClient)) return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d, which is a lot smaller than the highest opened stream, %d", id, m.highestStreamOpenedByPeer))
} }
s, err := m.newStream(id) s, err := m.newStream(id)
@ -91,8 +101,8 @@ func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (*stream, error) {
m.numOutgoingStreams++ m.numOutgoingStreams++
} }
if id > m.highestStreamOpenedByClient { if id > m.highestStreamOpenedByPeer {
m.highestStreamOpenedByClient = id m.highestStreamOpenedByPeer = id
} }
// maybe trigger garbage collection of streams map // maybe trigger garbage collection of streams map
@ -105,21 +115,12 @@ func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (*stream, error) {
return s, nil return s, nil
} }
// OpenStream opens a stream from the server's side // OpenStream opens the next available stream
func (m *streamsMap) OpenStream(id protocol.StreamID) (*stream, error) { func (m *streamsMap) OpenStream() (*stream, error) {
if m.perspective == protocol.PerspectiveServer && id%2 == 1 {
return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d from server-side", id))
}
if m.perspective == protocol.PerspectiveClient && id%2 == 0 {
return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d from client-side", id))
}
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
_, ok := m.streams[id]
if ok { id := m.nextStream
return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d, which is already open", id))
}
if m.numOutgoingStreams >= m.connectionParameters.GetMaxOutgoingStreams() { if m.numOutgoingStreams >= m.connectionParameters.GetMaxOutgoingStreams() {
return nil, qerr.TooManyOpenStreams return nil, qerr.TooManyOpenStreams
} }
@ -135,6 +136,7 @@ func (m *streamsMap) OpenStream(id protocol.StreamID) (*stream, error) {
m.numIncomingStreams++ m.numIncomingStreams++
} }
m.nextStream += 2
m.putStream(s) m.putStream(s)
return s, nil return s, nil
} }
@ -256,7 +258,7 @@ func (m *streamsMap) garbageCollectClosedStreams() {
// server-side streams can be gargage collected immediately // server-side streams can be gargage collected immediately
// client-side streams need to be kept as nils in the streams map for a bit longer, in order to prevent a client from reopening closed streams // client-side streams need to be kept as nils in the streams map for a bit longer, in order to prevent a client from reopening closed streams
if id%2 == 0 || id+protocol.MaxNewStreamIDDelta <= m.highestStreamOpenedByClient { if id%2 == 0 || id+protocol.MaxNewStreamIDDelta <= m.highestStreamOpenedByPeer {
delete(m.streams, id) delete(m.streams, id)
} }
} }

View file

@ -57,30 +57,35 @@ var _ = Describe("Streams Map", func() {
m *streamsMap m *streamsMap
) )
setNewStreamsMap := func(p protocol.Perspective) {
m = newStreamsMap(nil, p, cpm)
m.newStream = func(id protocol.StreamID) (*stream, error) {
return &stream{streamID: id}, nil
}
}
BeforeEach(func() { BeforeEach(func() {
cpm = &mockConnectionParametersManager{ cpm = &mockConnectionParametersManager{
maxIncomingStreams: 75, maxIncomingStreams: 75,
maxOutgoingStreams: 60, maxOutgoingStreams: 60,
} }
m = newStreamsMap(nil, protocol.PerspectiveServer, cpm)
}) })
Context("getting and creating streams", func() { Context("getting and creating streams", func() {
Context("as a server", func() {
BeforeEach(func() { BeforeEach(func() {
m.newStream = func(id protocol.StreamID) (*stream, error) { setNewStreamsMap(protocol.PerspectiveServer)
return &stream{streamID: id}, nil
}
}) })
Context("client-side streams", func() {
It("gets new streams", func() { It("gets new streams", func() {
s, err := m.GetOrOpenStream(5) s, err := m.GetOrOpenStream(5)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(s.StreamID()).To(Equal(protocol.StreamID(5))) Expect(s.StreamID()).To(Equal(protocol.StreamID(5)))
Expect(m.numIncomingStreams).To(Equal(uint32(1))) Expect(m.numIncomingStreams).To(BeEquivalentTo(1))
Expect(m.numOutgoingStreams).To(BeZero()) Expect(m.numOutgoingStreams).To(BeZero())
}) })
Context("client-side streams, as a server", func() {
It("rejects streams with even IDs", func() { It("rejects streams with even IDs", func() {
_, err := m.GetOrOpenStream(6) _, err := m.GetOrOpenStream(6)
Expect(err).To(MatchError("InvalidStreamID: attempted to open stream 6 from client-side")) Expect(err).To(MatchError("InvalidStreamID: attempted to open stream 6 from client-side"))
@ -92,7 +97,7 @@ var _ = Describe("Streams Map", func() {
s, err = m.GetOrOpenStream(5) s, err = m.GetOrOpenStream(5)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(s.StreamID()).To(Equal(protocol.StreamID(5))) Expect(s.StreamID()).To(Equal(protocol.StreamID(5)))
Expect(m.numIncomingStreams).To(Equal(uint32(1))) Expect(m.numIncomingStreams).To(BeEquivalentTo(1))
}) })
It("returns nil for closed streams", func() { It("returns nil for closed streams", func() {
@ -132,45 +137,14 @@ var _ = Describe("Streams Map", func() {
}) })
}) })
Context("client-side streams, as a client", func() { Context("server-side streams", func() {
BeforeEach(func() { It("opens a stream 2 first", func() {
m.perspective = protocol.PerspectiveClient s, err := m.OpenStream()
})
It("rejects streams with odd IDs", func() {
_, err := m.GetOrOpenStream(5)
Expect(err).To(MatchError("InvalidStreamID: attempted to open stream 5 from server-side"))
})
It("gets new streams", func() {
s, err := m.GetOrOpenStream(6)
Expect(err).NotTo(HaveOccurred())
Expect(s.StreamID()).To(Equal(protocol.StreamID(6)))
Expect(m.numOutgoingStreams).To(Equal(uint32(1)))
Expect(m.numIncomingStreams).To(BeZero())
})
})
Context("server-side streams, as a server", func() {
It("rejects streams with odd IDs", func() {
_, err := m.OpenStream(5)
Expect(err).To(MatchError("InvalidStreamID: attempted to open stream 5 from server-side"))
})
It("opens a new stream", func() {
s, err := m.OpenStream(6)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(s).ToNot(BeNil()) Expect(s).ToNot(BeNil())
Expect(s.StreamID()).To(Equal(protocol.StreamID(6))) Expect(s.StreamID()).To(Equal(protocol.StreamID(2)))
Expect(m.numIncomingStreams).To(BeZero()) Expect(m.numIncomingStreams).To(BeZero())
Expect(m.numOutgoingStreams).To(Equal(uint32(1))) Expect(m.numOutgoingStreams).To(BeEquivalentTo(1))
})
It("returns an error for already openend streams", func() {
_, err := m.OpenStream(4)
Expect(err).ToNot(HaveOccurred())
_, err = m.OpenStream(4)
Expect(err).To(MatchError("InvalidStreamID: attempted to open stream 4, which is already open"))
}) })
Context("counting streams", func() { Context("counting streams", func() {
@ -182,24 +156,24 @@ var _ = Describe("Streams Map", func() {
It("errors when too many streams are opened", func() { It("errors when too many streams are opened", func() {
for i := 1; i <= maxNumStreams; i++ { for i := 1; i <= maxNumStreams; i++ {
_, err := m.OpenStream(protocol.StreamID(2 * i)) _, err := m.OpenStream()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
} }
_, err := m.OpenStream(protocol.StreamID(2*maxNumStreams + 10)) _, err := m.OpenStream()
Expect(err).To(MatchError(qerr.TooManyOpenStreams)) Expect(err).To(MatchError(qerr.TooManyOpenStreams))
}) })
It("does not error when many streams are opened and closed", func() { It("does not error when many streams are opened and closed", func() {
for i := 2; i < 10*maxNumStreams; i++ { for i := 2; i < 10*maxNumStreams; i++ {
_, err := m.OpenStream(protocol.StreamID(2*i + 2)) str, err := m.OpenStream()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
m.RemoveStream(protocol.StreamID(2 * i)) m.RemoveStream(str.StreamID())
} }
}) })
It("allows many server- and client-side streams at the same time", func() { It("allows many server- and client-side streams at the same time", func() {
for i := 1; i < int(cpm.GetMaxOutgoingStreams()); i++ { for i := 1; i < int(cpm.GetMaxOutgoingStreams()); i++ {
_, err := m.OpenStream(protocol.StreamID(2 * i)) _, err := m.OpenStream()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
} }
for i := 0; i < int(cpm.GetMaxIncomingStreams()); i++ { for i := 0; i < int(cpm.GetMaxIncomingStreams()); i++ {
@ -209,25 +183,52 @@ var _ = Describe("Streams Map", func() {
}) })
}) })
}) })
})
Context("server-side streams, as a client", func() { Context("as a client", func() {
BeforeEach(func() { BeforeEach(func() {
m.perspective = protocol.PerspectiveClient setNewStreamsMap(protocol.PerspectiveClient)
}) })
It("rejects streams with even IDs", func() { Context("client-side streams, as a client", func() {
_, err := m.OpenStream(6) It("rejects streams with odd IDs", func() {
Expect(err).To(MatchError("InvalidStreamID: attempted to open stream 6 from client-side")) _, err := m.GetOrOpenStream(5)
Expect(err).To(MatchError("InvalidStreamID: attempted to open stream 5 from server-side"))
}) })
It("opens a new stream", func() { It("gets new streams", func() {
s, err := m.OpenStream(7) s, err := m.GetOrOpenStream(6)
Expect(err).NotTo(HaveOccurred())
Expect(s.StreamID()).To(Equal(protocol.StreamID(6)))
Expect(m.numOutgoingStreams).To(BeEquivalentTo(1))
Expect(m.numIncomingStreams).To(BeZero())
})
})
Context("server-side streams", func() {
It("opens stream 1 first", func() {
s, err := m.OpenStream()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(s).ToNot(BeNil()) Expect(s).ToNot(BeNil())
Expect(s.StreamID()).To(Equal(protocol.StreamID(7))) Expect(s.StreamID()).To(BeEquivalentTo(1))
Expect(m.numOutgoingStreams).To(BeZero()) Expect(m.numOutgoingStreams).To(BeZero())
Expect(m.numIncomingStreams).To(Equal(uint32(1))) Expect(m.numIncomingStreams).To(BeEquivalentTo(1))
}) })
It("opens multiple streams", func() {
s1, err := m.OpenStream()
Expect(err).ToNot(HaveOccurred())
s2, err := m.OpenStream()
Expect(err).ToNot(HaveOccurred())
Expect(s2.StreamID()).To(Equal(s1.StreamID() + 2))
})
})
})
})
Context("DoS mitigation, iterating and deleting", func() {
BeforeEach(func() {
setNewStreamsMap(protocol.PerspectiveServer)
}) })
Context("DoS mitigation", func() { Context("DoS mitigation", func() {
@ -235,7 +236,7 @@ var _ = Describe("Streams Map", func() {
for i := 1; i < 2*protocol.MaxNewStreamIDDelta; i += 2 { for i := 1; i < 2*protocol.MaxNewStreamIDDelta; i += 2 {
streamID := protocol.StreamID(i) streamID := protocol.StreamID(i)
_, err := m.GetOrOpenStream(streamID) _, err := m.GetOrOpenStream(streamID)
Expect(m.highestStreamOpenedByClient).To(Equal(streamID)) Expect(m.highestStreamOpenedByPeer).To(Equal(streamID))
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
err = m.RemoveStream(streamID) err = m.RemoveStream(streamID)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
@ -253,7 +254,7 @@ var _ = Describe("Streams Map", func() {
err = m.RemoveStream(streamID) err = m.RemoveStream(streamID)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
} }
Expect(m.highestStreamOpenedByClient).To(Equal(protocol.StreamID(protocol.MaxNewStreamIDDelta + 13))) Expect(m.highestStreamOpenedByPeer).To(Equal(protocol.StreamID(protocol.MaxNewStreamIDDelta + 13)))
_, err := m.GetOrOpenStream(11) _, err := m.GetOrOpenStream(11)
Expect(err).To(MatchError("InvalidStreamID: attempted to open stream 11, which is a lot smaller than the highest opened stream, 413")) Expect(err).To(MatchError("InvalidStreamID: attempted to open stream 11, which is a lot smaller than the highest opened stream, 413"))
_, err = m.GetOrOpenStream(13) _, err = m.GetOrOpenStream(13)
@ -264,7 +265,7 @@ var _ = Describe("Streams Map", func() {
for i := 1; i < 4*protocol.MaxNewStreamIDDelta; i += 2 { for i := 1; i < 4*protocol.MaxNewStreamIDDelta; i += 2 {
streamID := protocol.StreamID(i) streamID := protocol.StreamID(i)
_, err := m.GetOrOpenStream(streamID) _, err := m.GetOrOpenStream(streamID)
Expect(m.highestStreamOpenedByClient).To(Equal(streamID)) Expect(m.highestStreamOpenedByPeer).To(Equal(streamID))
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
err = m.RemoveStream(streamID) err = m.RemoveStream(streamID)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
@ -282,7 +283,7 @@ var _ = Describe("Streams Map", func() {
for i := 1; i < 1002; i += 2 { for i := 1; i < 1002; i += 2 {
streamID := protocol.StreamID(i) streamID := protocol.StreamID(i)
_, err := m.GetOrOpenStream(streamID) _, err := m.GetOrOpenStream(streamID)
Expect(m.highestStreamOpenedByClient).To(Equal(streamID)) Expect(m.highestStreamOpenedByPeer).To(Equal(streamID))
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
if streamID != 23 { if streamID != 23 {
err = m.RemoveStream(streamID) err = m.RemoveStream(streamID)
@ -302,7 +303,7 @@ var _ = Describe("Streams Map", func() {
for i := 1; i < 4*protocol.MaxNewStreamIDDelta; i += 2 { for i := 1; i < 4*protocol.MaxNewStreamIDDelta; i += 2 {
streamID := protocol.StreamID(i) streamID := protocol.StreamID(i)
_, err := m.GetOrOpenStream(streamID) _, err := m.GetOrOpenStream(streamID)
Expect(m.highestStreamOpenedByClient).To(Equal(streamID)) Expect(m.highestStreamOpenedByPeer).To(Equal(streamID))
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
err = m.RemoveStream(streamID) err = m.RemoveStream(streamID)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
@ -316,7 +317,6 @@ var _ = Describe("Streams Map", func() {
Expect(len(m.streams)).To(BeNumerically("<", 2*protocol.MaxNewStreamIDDelta)) Expect(len(m.streams)).To(BeNumerically("<", 2*protocol.MaxNewStreamIDDelta))
}) })
}) })
})
Context("deleting streams", func() { Context("deleting streams", func() {
BeforeEach(func() { BeforeEach(func() {
@ -493,7 +493,7 @@ var _ = Describe("Streams Map", func() {
It("doesn't adjust the RoundRobinIndex when deleting an element at the back", func() { It("doesn't adjust the RoundRobinIndex when deleting an element at the back", func() {
m.roundRobinIndex = 1 // stream 5 m.roundRobinIndex = 1 // stream 5
m.RemoveStream(7) m.RemoveStream(7)
Expect(m.roundRobinIndex).To(Equal(uint32(1))) Expect(m.roundRobinIndex).To(BeEquivalentTo(1))
}) })
It("doesn't adjust the RoundRobinIndex when deleting the element it is pointing to", func() { It("doesn't adjust the RoundRobinIndex when deleting the element it is pointing to", func() {
@ -528,3 +528,4 @@ var _ = Describe("Streams Map", func() {
}) })
}) })
}) })
})