mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-04 20:57:36 +03:00
add client functionality to the streamsMap
This commit is contained in:
parent
6cb48aad71
commit
16da08a440
6 changed files with 70 additions and 13 deletions
|
@ -41,7 +41,7 @@ var _ = Describe("Packet packer", func() {
|
|||
fcm.sendWindowSizes[7] = protocol.MaxByteCount
|
||||
|
||||
cpm := &mockConnectionParametersManager{}
|
||||
streamFramer = newStreamFramer(newStreamsMap(nil, cpm), fcm)
|
||||
streamFramer = newStreamFramer(newStreamsMap(nil, protocol.PerspectiveServer, cpm), fcm)
|
||||
|
||||
packer = &packetPacker{
|
||||
cryptoSetup: &mockCryptoSetup{},
|
||||
|
|
|
@ -136,7 +136,7 @@ func newClientSession(conn *net.UDPConn, addr *net.UDPAddr, hostname string, v p
|
|||
session.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(session.ackAlarmChanged)
|
||||
session.setup()
|
||||
|
||||
cryptoStream, _ := session.GetOrOpenStream(1)
|
||||
cryptoStream, _ := session.OpenStream(1)
|
||||
var err error
|
||||
session.cryptoSetup, err = handshake.NewCryptoSetupClient(hostname, connectionID, v, cryptoStream, session.connectionParameters, session.aeadChanged)
|
||||
if err != nil {
|
||||
|
@ -174,7 +174,7 @@ func (s *Session) setup() {
|
|||
s.lastNetworkActivityTime = now
|
||||
s.sessionCreationTime = now
|
||||
|
||||
s.streamsMap = newStreamsMap(s.newStream, s.connectionParameters)
|
||||
s.streamsMap = newStreamsMap(s.newStream, s.perspective, s.connectionParameters)
|
||||
s.streamFramer = newStreamFramer(s.streamsMap, s.flowControlManager)
|
||||
}
|
||||
|
||||
|
|
|
@ -159,7 +159,7 @@ var _ = Describe("Session", func() {
|
|||
func(protocol.ConnectionID) { closeCallbackCalled = true },
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(clientSession.streamsMap.openStreams).To(HaveLen(1))
|
||||
Expect(clientSession.streamsMap.openStreams).To(HaveLen(1)) // Crypto stream
|
||||
|
||||
})
|
||||
|
||||
|
|
|
@ -31,7 +31,7 @@ var _ = Describe("Stream Framer", func() {
|
|||
stream1 = &stream{streamID: 10}
|
||||
stream2 = &stream{streamID: 11}
|
||||
|
||||
streamsMap = newStreamsMap(nil, &mockConnectionParametersManager{})
|
||||
streamsMap = newStreamsMap(nil, protocol.PerspectiveServer, &mockConnectionParametersManager{})
|
||||
streamsMap.putStream(stream1)
|
||||
streamsMap.putStream(stream2)
|
||||
|
||||
|
|
|
@ -13,6 +13,7 @@ import (
|
|||
type streamsMap struct {
|
||||
mutex sync.RWMutex
|
||||
|
||||
perspective protocol.Perspective
|
||||
connectionParameters handshake.ConnectionParametersManager
|
||||
|
||||
streams map[protocol.StreamID]*stream
|
||||
|
@ -38,8 +39,9 @@ var (
|
|||
errMapAccess = errors.New("streamsMap: Error accessing the streams map")
|
||||
)
|
||||
|
||||
func newStreamsMap(newStream newStreamLambda, connectionParameters handshake.ConnectionParametersManager) *streamsMap {
|
||||
func newStreamsMap(newStream newStreamLambda, pers protocol.Perspective, connectionParameters handshake.ConnectionParametersManager) *streamsMap {
|
||||
return &streamsMap{
|
||||
perspective: pers,
|
||||
streams: map[protocol.StreamID]*stream{},
|
||||
openStreams: make([]protocol.StreamID, 0),
|
||||
newStream: newStream,
|
||||
|
@ -68,9 +70,12 @@ func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (*stream, error) {
|
|||
if m.numIncomingStreams >= m.connectionParameters.GetMaxIncomingStreams() {
|
||||
return nil, qerr.TooManyOpenStreams
|
||||
}
|
||||
if id%2 == 0 {
|
||||
if m.perspective == protocol.PerspectiveServer && id%2 == 0 {
|
||||
return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d from client-side", id))
|
||||
}
|
||||
if m.perspective == protocol.PerspectiveClient && id%2 == 1 {
|
||||
return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d from server-side", id))
|
||||
}
|
||||
if id+protocol.MaxNewStreamIDDelta < m.highestStreamOpenedByClient {
|
||||
return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d, which is a lot smaller than the highest opened stream, %d", id, m.highestStreamOpenedByClient))
|
||||
}
|
||||
|
@ -79,7 +84,12 @@ func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (*stream, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m.numIncomingStreams++
|
||||
|
||||
if m.perspective == protocol.PerspectiveServer {
|
||||
m.numIncomingStreams++
|
||||
} else {
|
||||
m.numOutgoingStreams++
|
||||
}
|
||||
|
||||
if id > m.highestStreamOpenedByClient {
|
||||
m.highestStreamOpenedByClient = id
|
||||
|
@ -97,9 +107,12 @@ func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (*stream, error) {
|
|||
|
||||
// OpenStream opens a stream from the server's side
|
||||
func (m *streamsMap) OpenStream(id protocol.StreamID) (*stream, error) {
|
||||
if id%2 == 1 {
|
||||
if m.perspective == protocol.PerspectiveServer && id%2 == 1 {
|
||||
return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d from server-side", id))
|
||||
}
|
||||
if m.perspective == protocol.PerspectiveClient && id%2 == 0 {
|
||||
return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d from client-side", id))
|
||||
}
|
||||
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
@ -115,7 +128,12 @@ func (m *streamsMap) OpenStream(id protocol.StreamID) (*stream, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m.numOutgoingStreams++
|
||||
|
||||
if m.perspective == protocol.PerspectiveServer {
|
||||
m.numOutgoingStreams++
|
||||
} else {
|
||||
m.numIncomingStreams++
|
||||
}
|
||||
|
||||
m.putStream(s)
|
||||
return s, nil
|
||||
|
|
|
@ -59,7 +59,7 @@ var _ = Describe("Streams Map", func() {
|
|||
maxIncomingStreams: 75,
|
||||
maxOutgoingStreams: 60,
|
||||
}
|
||||
m = newStreamsMap(nil, cpm)
|
||||
m = newStreamsMap(nil, protocol.PerspectiveServer, cpm)
|
||||
})
|
||||
|
||||
Context("getting and creating streams", func() {
|
||||
|
@ -77,7 +77,7 @@ var _ = Describe("Streams Map", func() {
|
|||
Expect(m.numOutgoingStreams).To(BeZero())
|
||||
})
|
||||
|
||||
Context("client-side streams", func() {
|
||||
Context("client-side streams, as a server", func() {
|
||||
It("rejects streams with even IDs", func() {
|
||||
_, err := m.GetOrOpenStream(6)
|
||||
Expect(err).To(MatchError("InvalidStreamID: attempted to open stream 6 from client-side"))
|
||||
|
@ -129,7 +129,26 @@ var _ = Describe("Streams Map", func() {
|
|||
})
|
||||
})
|
||||
|
||||
Context("server-side streams", func() {
|
||||
Context("client-side streams, as a client", func() {
|
||||
BeforeEach(func() {
|
||||
m.perspective = protocol.PerspectiveClient
|
||||
})
|
||||
|
||||
It("rejects streams with odd IDs", func() {
|
||||
_, err := m.GetOrOpenStream(5)
|
||||
Expect(err).To(MatchError("InvalidStreamID: attempted to open stream 5 from server-side"))
|
||||
})
|
||||
|
||||
It("gets new streams", func() {
|
||||
s, err := m.GetOrOpenStream(6)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(s.StreamID()).To(Equal(protocol.StreamID(6)))
|
||||
Expect(m.numOutgoingStreams).To(Equal(uint32(1)))
|
||||
Expect(m.numIncomingStreams).To(BeZero())
|
||||
})
|
||||
})
|
||||
|
||||
Context("server-side streams, as a server", func() {
|
||||
It("rejects streams with odd IDs", func() {
|
||||
_, err := m.OpenStream(5)
|
||||
Expect(err).To(MatchError("InvalidStreamID: attempted to open stream 5 from server-side"))
|
||||
|
@ -188,6 +207,26 @@ var _ = Describe("Streams Map", func() {
|
|||
})
|
||||
})
|
||||
|
||||
Context("server-side streams, as a client", func() {
|
||||
BeforeEach(func() {
|
||||
m.perspective = protocol.PerspectiveClient
|
||||
})
|
||||
|
||||
It("rejects streams with even IDs", func() {
|
||||
_, err := m.OpenStream(6)
|
||||
Expect(err).To(MatchError("InvalidStreamID: attempted to open stream 6 from client-side"))
|
||||
})
|
||||
|
||||
It("opens a new stream", func() {
|
||||
s, err := m.OpenStream(7)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(s).ToNot(BeNil())
|
||||
Expect(s.StreamID()).To(Equal(protocol.StreamID(7)))
|
||||
Expect(m.numOutgoingStreams).To(BeZero())
|
||||
Expect(m.numIncomingStreams).To(Equal(uint32(1)))
|
||||
})
|
||||
})
|
||||
|
||||
Context("DoS mitigation", func() {
|
||||
It("opens and closes a lot of streams", func() {
|
||||
for i := 1; i < 2*protocol.MaxNewStreamIDDelta; i += 2 {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue