add client functionality to the streamsMap

This commit is contained in:
Marten Seemann 2016-12-13 11:44:40 +07:00
parent 6cb48aad71
commit 16da08a440
No known key found for this signature in database
GPG key ID: 3603F40B121FCDEA
6 changed files with 70 additions and 13 deletions

View file

@ -41,7 +41,7 @@ var _ = Describe("Packet packer", func() {
fcm.sendWindowSizes[7] = protocol.MaxByteCount
cpm := &mockConnectionParametersManager{}
streamFramer = newStreamFramer(newStreamsMap(nil, cpm), fcm)
streamFramer = newStreamFramer(newStreamsMap(nil, protocol.PerspectiveServer, cpm), fcm)
packer = &packetPacker{
cryptoSetup: &mockCryptoSetup{},

View file

@ -136,7 +136,7 @@ func newClientSession(conn *net.UDPConn, addr *net.UDPAddr, hostname string, v p
session.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(session.ackAlarmChanged)
session.setup()
cryptoStream, _ := session.GetOrOpenStream(1)
cryptoStream, _ := session.OpenStream(1)
var err error
session.cryptoSetup, err = handshake.NewCryptoSetupClient(hostname, connectionID, v, cryptoStream, session.connectionParameters, session.aeadChanged)
if err != nil {
@ -174,7 +174,7 @@ func (s *Session) setup() {
s.lastNetworkActivityTime = now
s.sessionCreationTime = now
s.streamsMap = newStreamsMap(s.newStream, s.connectionParameters)
s.streamsMap = newStreamsMap(s.newStream, s.perspective, s.connectionParameters)
s.streamFramer = newStreamFramer(s.streamsMap, s.flowControlManager)
}

View file

@ -159,7 +159,7 @@ var _ = Describe("Session", func() {
func(protocol.ConnectionID) { closeCallbackCalled = true },
)
Expect(err).ToNot(HaveOccurred())
Expect(clientSession.streamsMap.openStreams).To(HaveLen(1))
Expect(clientSession.streamsMap.openStreams).To(HaveLen(1)) // Crypto stream
})

View file

@ -31,7 +31,7 @@ var _ = Describe("Stream Framer", func() {
stream1 = &stream{streamID: 10}
stream2 = &stream{streamID: 11}
streamsMap = newStreamsMap(nil, &mockConnectionParametersManager{})
streamsMap = newStreamsMap(nil, protocol.PerspectiveServer, &mockConnectionParametersManager{})
streamsMap.putStream(stream1)
streamsMap.putStream(stream2)

View file

@ -13,6 +13,7 @@ import (
type streamsMap struct {
mutex sync.RWMutex
perspective protocol.Perspective
connectionParameters handshake.ConnectionParametersManager
streams map[protocol.StreamID]*stream
@ -38,8 +39,9 @@ var (
errMapAccess = errors.New("streamsMap: Error accessing the streams map")
)
func newStreamsMap(newStream newStreamLambda, connectionParameters handshake.ConnectionParametersManager) *streamsMap {
func newStreamsMap(newStream newStreamLambda, pers protocol.Perspective, connectionParameters handshake.ConnectionParametersManager) *streamsMap {
return &streamsMap{
perspective: pers,
streams: map[protocol.StreamID]*stream{},
openStreams: make([]protocol.StreamID, 0),
newStream: newStream,
@ -68,9 +70,12 @@ func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (*stream, error) {
if m.numIncomingStreams >= m.connectionParameters.GetMaxIncomingStreams() {
return nil, qerr.TooManyOpenStreams
}
if id%2 == 0 {
if m.perspective == protocol.PerspectiveServer && id%2 == 0 {
return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d from client-side", id))
}
if m.perspective == protocol.PerspectiveClient && id%2 == 1 {
return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d from server-side", id))
}
if id+protocol.MaxNewStreamIDDelta < m.highestStreamOpenedByClient {
return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d, which is a lot smaller than the highest opened stream, %d", id, m.highestStreamOpenedByClient))
}
@ -79,7 +84,12 @@ func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (*stream, error) {
if err != nil {
return nil, err
}
m.numIncomingStreams++
if m.perspective == protocol.PerspectiveServer {
m.numIncomingStreams++
} else {
m.numOutgoingStreams++
}
if id > m.highestStreamOpenedByClient {
m.highestStreamOpenedByClient = id
@ -97,9 +107,12 @@ func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (*stream, error) {
// OpenStream opens a stream from the server's side
func (m *streamsMap) OpenStream(id protocol.StreamID) (*stream, error) {
if id%2 == 1 {
if m.perspective == protocol.PerspectiveServer && id%2 == 1 {
return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d from server-side", id))
}
if m.perspective == protocol.PerspectiveClient && id%2 == 0 {
return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d from client-side", id))
}
m.mutex.Lock()
defer m.mutex.Unlock()
@ -115,7 +128,12 @@ func (m *streamsMap) OpenStream(id protocol.StreamID) (*stream, error) {
if err != nil {
return nil, err
}
m.numOutgoingStreams++
if m.perspective == protocol.PerspectiveServer {
m.numOutgoingStreams++
} else {
m.numIncomingStreams++
}
m.putStream(s)
return s, nil

View file

@ -59,7 +59,7 @@ var _ = Describe("Streams Map", func() {
maxIncomingStreams: 75,
maxOutgoingStreams: 60,
}
m = newStreamsMap(nil, cpm)
m = newStreamsMap(nil, protocol.PerspectiveServer, cpm)
})
Context("getting and creating streams", func() {
@ -77,7 +77,7 @@ var _ = Describe("Streams Map", func() {
Expect(m.numOutgoingStreams).To(BeZero())
})
Context("client-side streams", func() {
Context("client-side streams, as a server", func() {
It("rejects streams with even IDs", func() {
_, err := m.GetOrOpenStream(6)
Expect(err).To(MatchError("InvalidStreamID: attempted to open stream 6 from client-side"))
@ -129,7 +129,26 @@ var _ = Describe("Streams Map", func() {
})
})
Context("server-side streams", func() {
Context("client-side streams, as a client", func() {
BeforeEach(func() {
m.perspective = protocol.PerspectiveClient
})
It("rejects streams with odd IDs", func() {
_, err := m.GetOrOpenStream(5)
Expect(err).To(MatchError("InvalidStreamID: attempted to open stream 5 from server-side"))
})
It("gets new streams", func() {
s, err := m.GetOrOpenStream(6)
Expect(err).NotTo(HaveOccurred())
Expect(s.StreamID()).To(Equal(protocol.StreamID(6)))
Expect(m.numOutgoingStreams).To(Equal(uint32(1)))
Expect(m.numIncomingStreams).To(BeZero())
})
})
Context("server-side streams, as a server", func() {
It("rejects streams with odd IDs", func() {
_, err := m.OpenStream(5)
Expect(err).To(MatchError("InvalidStreamID: attempted to open stream 5 from server-side"))
@ -188,6 +207,26 @@ var _ = Describe("Streams Map", func() {
})
})
Context("server-side streams, as a client", func() {
BeforeEach(func() {
m.perspective = protocol.PerspectiveClient
})
It("rejects streams with even IDs", func() {
_, err := m.OpenStream(6)
Expect(err).To(MatchError("InvalidStreamID: attempted to open stream 6 from client-side"))
})
It("opens a new stream", func() {
s, err := m.OpenStream(7)
Expect(err).ToNot(HaveOccurred())
Expect(s).ToNot(BeNil())
Expect(s.StreamID()).To(Equal(protocol.StreamID(7)))
Expect(m.numOutgoingStreams).To(BeZero())
Expect(m.numIncomingStreams).To(Equal(uint32(1)))
})
})
Context("DoS mitigation", func() {
It("opens and closes a lot of streams", func() {
for i := 1; i < 2*protocol.MaxNewStreamIDDelta; i += 2 {