add client functionality to the streamsMap

This commit is contained in:
Marten Seemann 2016-12-13 11:44:40 +07:00
parent 6cb48aad71
commit 16da08a440
No known key found for this signature in database
GPG key ID: 3603F40B121FCDEA
6 changed files with 70 additions and 13 deletions

View file

@ -41,7 +41,7 @@ var _ = Describe("Packet packer", func() {
fcm.sendWindowSizes[7] = protocol.MaxByteCount fcm.sendWindowSizes[7] = protocol.MaxByteCount
cpm := &mockConnectionParametersManager{} cpm := &mockConnectionParametersManager{}
streamFramer = newStreamFramer(newStreamsMap(nil, cpm), fcm) streamFramer = newStreamFramer(newStreamsMap(nil, protocol.PerspectiveServer, cpm), fcm)
packer = &packetPacker{ packer = &packetPacker{
cryptoSetup: &mockCryptoSetup{}, cryptoSetup: &mockCryptoSetup{},

View file

@ -136,7 +136,7 @@ func newClientSession(conn *net.UDPConn, addr *net.UDPAddr, hostname string, v p
session.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(session.ackAlarmChanged) session.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(session.ackAlarmChanged)
session.setup() session.setup()
cryptoStream, _ := session.GetOrOpenStream(1) cryptoStream, _ := session.OpenStream(1)
var err error var err error
session.cryptoSetup, err = handshake.NewCryptoSetupClient(hostname, connectionID, v, cryptoStream, session.connectionParameters, session.aeadChanged) session.cryptoSetup, err = handshake.NewCryptoSetupClient(hostname, connectionID, v, cryptoStream, session.connectionParameters, session.aeadChanged)
if err != nil { if err != nil {
@ -174,7 +174,7 @@ func (s *Session) setup() {
s.lastNetworkActivityTime = now s.lastNetworkActivityTime = now
s.sessionCreationTime = now s.sessionCreationTime = now
s.streamsMap = newStreamsMap(s.newStream, s.connectionParameters) s.streamsMap = newStreamsMap(s.newStream, s.perspective, s.connectionParameters)
s.streamFramer = newStreamFramer(s.streamsMap, s.flowControlManager) s.streamFramer = newStreamFramer(s.streamsMap, s.flowControlManager)
} }

View file

@ -159,7 +159,7 @@ var _ = Describe("Session", func() {
func(protocol.ConnectionID) { closeCallbackCalled = true }, func(protocol.ConnectionID) { closeCallbackCalled = true },
) )
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(clientSession.streamsMap.openStreams).To(HaveLen(1)) Expect(clientSession.streamsMap.openStreams).To(HaveLen(1)) // Crypto stream
}) })

View file

@ -31,7 +31,7 @@ var _ = Describe("Stream Framer", func() {
stream1 = &stream{streamID: 10} stream1 = &stream{streamID: 10}
stream2 = &stream{streamID: 11} stream2 = &stream{streamID: 11}
streamsMap = newStreamsMap(nil, &mockConnectionParametersManager{}) streamsMap = newStreamsMap(nil, protocol.PerspectiveServer, &mockConnectionParametersManager{})
streamsMap.putStream(stream1) streamsMap.putStream(stream1)
streamsMap.putStream(stream2) streamsMap.putStream(stream2)

View file

@ -13,6 +13,7 @@ import (
type streamsMap struct { type streamsMap struct {
mutex sync.RWMutex mutex sync.RWMutex
perspective protocol.Perspective
connectionParameters handshake.ConnectionParametersManager connectionParameters handshake.ConnectionParametersManager
streams map[protocol.StreamID]*stream streams map[protocol.StreamID]*stream
@ -38,8 +39,9 @@ var (
errMapAccess = errors.New("streamsMap: Error accessing the streams map") errMapAccess = errors.New("streamsMap: Error accessing the streams map")
) )
func newStreamsMap(newStream newStreamLambda, connectionParameters handshake.ConnectionParametersManager) *streamsMap { func newStreamsMap(newStream newStreamLambda, pers protocol.Perspective, connectionParameters handshake.ConnectionParametersManager) *streamsMap {
return &streamsMap{ return &streamsMap{
perspective: pers,
streams: map[protocol.StreamID]*stream{}, streams: map[protocol.StreamID]*stream{},
openStreams: make([]protocol.StreamID, 0), openStreams: make([]protocol.StreamID, 0),
newStream: newStream, newStream: newStream,
@ -68,9 +70,12 @@ func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (*stream, error) {
if m.numIncomingStreams >= m.connectionParameters.GetMaxIncomingStreams() { if m.numIncomingStreams >= m.connectionParameters.GetMaxIncomingStreams() {
return nil, qerr.TooManyOpenStreams return nil, qerr.TooManyOpenStreams
} }
if id%2 == 0 { if m.perspective == protocol.PerspectiveServer && id%2 == 0 {
return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d from client-side", id)) return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d from client-side", id))
} }
if m.perspective == protocol.PerspectiveClient && id%2 == 1 {
return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d from server-side", id))
}
if id+protocol.MaxNewStreamIDDelta < m.highestStreamOpenedByClient { if id+protocol.MaxNewStreamIDDelta < m.highestStreamOpenedByClient {
return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d, which is a lot smaller than the highest opened stream, %d", id, m.highestStreamOpenedByClient)) return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d, which is a lot smaller than the highest opened stream, %d", id, m.highestStreamOpenedByClient))
} }
@ -79,7 +84,12 @@ func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (*stream, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
m.numIncomingStreams++
if m.perspective == protocol.PerspectiveServer {
m.numIncomingStreams++
} else {
m.numOutgoingStreams++
}
if id > m.highestStreamOpenedByClient { if id > m.highestStreamOpenedByClient {
m.highestStreamOpenedByClient = id m.highestStreamOpenedByClient = id
@ -97,9 +107,12 @@ func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (*stream, error) {
// OpenStream opens a stream from the server's side // OpenStream opens a stream from the server's side
func (m *streamsMap) OpenStream(id protocol.StreamID) (*stream, error) { func (m *streamsMap) OpenStream(id protocol.StreamID) (*stream, error) {
if id%2 == 1 { if m.perspective == protocol.PerspectiveServer && id%2 == 1 {
return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d from server-side", id)) return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d from server-side", id))
} }
if m.perspective == protocol.PerspectiveClient && id%2 == 0 {
return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d from client-side", id))
}
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
@ -115,7 +128,12 @@ func (m *streamsMap) OpenStream(id protocol.StreamID) (*stream, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
m.numOutgoingStreams++
if m.perspective == protocol.PerspectiveServer {
m.numOutgoingStreams++
} else {
m.numIncomingStreams++
}
m.putStream(s) m.putStream(s)
return s, nil return s, nil

View file

@ -59,7 +59,7 @@ var _ = Describe("Streams Map", func() {
maxIncomingStreams: 75, maxIncomingStreams: 75,
maxOutgoingStreams: 60, maxOutgoingStreams: 60,
} }
m = newStreamsMap(nil, cpm) m = newStreamsMap(nil, protocol.PerspectiveServer, cpm)
}) })
Context("getting and creating streams", func() { Context("getting and creating streams", func() {
@ -77,7 +77,7 @@ var _ = Describe("Streams Map", func() {
Expect(m.numOutgoingStreams).To(BeZero()) Expect(m.numOutgoingStreams).To(BeZero())
}) })
Context("client-side streams", func() { Context("client-side streams, as a server", func() {
It("rejects streams with even IDs", func() { It("rejects streams with even IDs", func() {
_, err := m.GetOrOpenStream(6) _, err := m.GetOrOpenStream(6)
Expect(err).To(MatchError("InvalidStreamID: attempted to open stream 6 from client-side")) Expect(err).To(MatchError("InvalidStreamID: attempted to open stream 6 from client-side"))
@ -129,7 +129,26 @@ var _ = Describe("Streams Map", func() {
}) })
}) })
Context("server-side streams", func() { Context("client-side streams, as a client", func() {
BeforeEach(func() {
m.perspective = protocol.PerspectiveClient
})
It("rejects streams with odd IDs", func() {
_, err := m.GetOrOpenStream(5)
Expect(err).To(MatchError("InvalidStreamID: attempted to open stream 5 from server-side"))
})
It("gets new streams", func() {
s, err := m.GetOrOpenStream(6)
Expect(err).NotTo(HaveOccurred())
Expect(s.StreamID()).To(Equal(protocol.StreamID(6)))
Expect(m.numOutgoingStreams).To(Equal(uint32(1)))
Expect(m.numIncomingStreams).To(BeZero())
})
})
Context("server-side streams, as a server", func() {
It("rejects streams with odd IDs", func() { It("rejects streams with odd IDs", func() {
_, err := m.OpenStream(5) _, err := m.OpenStream(5)
Expect(err).To(MatchError("InvalidStreamID: attempted to open stream 5 from server-side")) Expect(err).To(MatchError("InvalidStreamID: attempted to open stream 5 from server-side"))
@ -188,6 +207,26 @@ var _ = Describe("Streams Map", func() {
}) })
}) })
Context("server-side streams, as a client", func() {
BeforeEach(func() {
m.perspective = protocol.PerspectiveClient
})
It("rejects streams with even IDs", func() {
_, err := m.OpenStream(6)
Expect(err).To(MatchError("InvalidStreamID: attempted to open stream 6 from client-side"))
})
It("opens a new stream", func() {
s, err := m.OpenStream(7)
Expect(err).ToNot(HaveOccurred())
Expect(s).ToNot(BeNil())
Expect(s.StreamID()).To(Equal(protocol.StreamID(7)))
Expect(m.numOutgoingStreams).To(BeZero())
Expect(m.numIncomingStreams).To(Equal(uint32(1)))
})
})
Context("DoS mitigation", func() { Context("DoS mitigation", func() {
It("opens and closes a lot of streams", func() { It("opens and closes a lot of streams", func() {
for i := 1; i < 2*protocol.MaxNewStreamIDDelta; i += 2 { for i := 1; i < 2*protocol.MaxNewStreamIDDelta; i += 2 {