use StreamsMap in Session and StreamFramer

fixes #208
This commit is contained in:
Marten Seemann 2016-08-05 19:30:28 +07:00
parent 0cc42b8847
commit a796877b94
5 changed files with 136 additions and 100 deletions

View file

@ -2,7 +2,6 @@ package quic
import (
"bytes"
"sync"
"github.com/lucas-clemente/quic-go/frames"
"github.com/lucas-clemente/quic-go/handshake"
@ -24,7 +23,7 @@ var _ = Describe("Packet packer", func() {
fcm.sendWindowSizes[5] = protocol.MaxByteCount
fcm.sendWindowSizes[7] = protocol.MaxByteCount
streamFramer = newStreamFramer(&map[protocol.StreamID]*stream{}, &sync.RWMutex{}, fcm)
streamFramer = newStreamFramer(newStreamsMap(), fcm)
packer = &packetPacker{
cryptoSetup: &handshake.CryptoSetup{},

View file

@ -48,7 +48,7 @@ type Session struct {
conn connection
streams map[protocol.StreamID]*stream
streamsMap *streamsMap
openStreamsCount uint32
streamsMutex sync.RWMutex
@ -108,12 +108,13 @@ func newSession(conn connection, v protocol.VersionNumber, connectionID protocol
}
session := &Session{
connectionID: connectionID,
version: v,
conn: conn,
streamCallback: streamCallback,
closeCallback: closeCallback,
streams: make(map[protocol.StreamID]*stream),
connectionID: connectionID,
version: v,
conn: conn,
streamCallback: streamCallback,
closeCallback: closeCallback,
// streams: make(map[protocol.StreamID]*stream),
streamsMap: newStreamsMap(),
sentPacketHandler: sentPacketHandler,
receivedPacketHandler: receivedPacketHandler,
stopWaitingManager: stopWaitingManager,
@ -135,7 +136,7 @@ func newSession(conn connection, v protocol.VersionNumber, connectionID protocol
return nil, err
}
session.streamFramer = newStreamFramer(&session.streams, &session.streamsMutex, flowControlManager)
session.streamFramer = newStreamFramer(session.streamsMap, flowControlManager)
session.packer = newPacketPacker(connectionID, session.cryptoSetup, session.connectionParametersManager, session.streamFramer, v)
session.unpacker = &packetUnpacker{aead: session.cryptoSetup, version: v}
@ -331,10 +332,10 @@ func (s *Session) handlePacket(remoteAddr interface{}, hdr *PublicHeader, data [
func (s *Session) handleStreamFrame(frame *frames.StreamFrame) error {
s.streamsMutex.Lock()
defer s.streamsMutex.Unlock()
str, streamExists := s.streams[frame.StreamID]
str, strExists := s.streamsMap.GetStream(frame.StreamID)
var err error
if !streamExists {
if !strExists {
if !s.isValidStreamID(frame.StreamID) {
return qerr.InvalidStreamID
}
@ -352,7 +353,7 @@ func (s *Session) handleStreamFrame(frame *frames.StreamFrame) error {
if err != nil {
return err
}
if !streamExists {
if !strExists {
s.streamCallback(s, str)
}
return nil
@ -366,14 +367,14 @@ func (s *Session) handleWindowUpdateFrame(frame *frames.WindowUpdateFrame) error
s.streamsMutex.RLock()
defer s.streamsMutex.RUnlock()
if frame.StreamID != 0 {
stream, ok := s.streams[frame.StreamID]
if ok && stream == nil {
str, strExists := s.streamsMap.GetStream(frame.StreamID)
if strExists && str == nil {
return errWindowUpdateOnClosedStream
}
// open new stream when receiving a WindowUpdate for a non-existing stream
// this can occur if the client immediately sends a WindowUpdate for a newly opened stream, and packet reordering occurs such that the packet opening the new stream arrives after the WindowUpdate
if !ok {
if !strExists {
s.newStreamImpl(frame.StreamID)
}
}
@ -384,7 +385,7 @@ func (s *Session) handleWindowUpdateFrame(frame *frames.WindowUpdateFrame) error
// TODO: Handle frame.byteOffset
func (s *Session) handleRstStreamFrame(frame *frames.RstStreamFrame) error {
s.streamsMutex.RLock()
str, streamExists := s.streams[frame.StreamID]
str, streamExists := s.streamsMap.GetStream(frame.StreamID)
s.streamsMutex.RUnlock()
if !streamExists || str == nil {
return errRstStreamOnInvalidStream
@ -445,12 +446,16 @@ func (s *Session) closeImpl(e error, remoteClose bool) error {
func (s *Session) closeStreamsWithError(err error) {
s.streamsMutex.Lock()
defer s.streamsMutex.Unlock()
for _, str := range s.streams {
fn := func(str *stream) (bool, error) {
if str == nil {
continue
return true, nil
}
s.closeStreamWithError(str, err)
return true, nil
}
s.streamsMap.Iterate(fn)
}
func (s *Session) closeStreamWithError(str *stream, err error) {
@ -595,7 +600,8 @@ func (s *Session) OpenStream(id protocol.StreamID) (utils.Stream, error) {
func (s *Session) GetOrOpenStream(id protocol.StreamID) (utils.Stream, error) {
s.streamsMutex.Lock()
defer s.streamsMutex.Unlock()
if stream, ok := s.streams[id]; ok {
stream, strExists := s.streamsMap.GetStream(id)
if strExists {
return stream, nil
}
return s.newStreamImpl(id)
@ -608,7 +614,8 @@ func (s *Session) newStreamImpl(id protocol.StreamID) (*stream, error) {
go s.Close(qerr.TooManyOpenStreams)
return nil, qerr.TooManyOpenStreams
}
if _, ok := s.streams[id]; ok {
_, strExists := s.streamsMap.GetStream(id)
if strExists {
return nil, fmt.Errorf("Session: stream with ID %d already exists", id)
}
stream, err := newStream(s.scheduleSending, s.connectionParametersManager, s.flowControlManager, id)
@ -624,26 +631,33 @@ func (s *Session) newStreamImpl(id protocol.StreamID) (*stream, error) {
}
atomic.AddUint32(&s.openStreamsCount, 1)
s.streams[id] = stream
err = s.streamsMap.PutStream(stream)
if err != nil {
return nil, err
}
return stream, nil
}
// garbageCollectStreams goes through all streams and removes EOF'ed streams
// from the streams map.
func (s *Session) garbageCollectStreams() {
s.streamsMutex.Lock()
defer s.streamsMutex.Unlock()
for k, v := range s.streams {
if v == nil {
continue
fn := func(str *stream) (bool, error) {
if str == nil {
return true, nil
}
if v.finished() {
utils.Debugf("Garbage-collecting stream %d", k)
id := str.StreamID()
if str.finished() {
atomic.AddUint32(&s.openStreamsCount, ^uint32(0)) // decrement
s.streams[k] = nil
s.flowControlManager.RemoveStream(k)
err := s.streamsMap.RemoveStream(id)
if err != nil {
return false, err
}
s.flowControlManager.RemoveStream(id)
}
return true, nil
}
s.streamsMap.Iterate(fn)
}
func (s *Session) sendPublicReset(rejectedPacketNumber protocol.PacketNumber) error {
@ -680,20 +694,24 @@ func (s *Session) getWindowUpdateFrames() ([]*frames.WindowUpdateFrame, error) {
var res []*frames.WindowUpdateFrame
for id, str := range s.streams {
fn := func(str *stream) (bool, error) {
if str == nil {
continue
return true, nil
}
id := str.StreamID()
doUpdate, offset, err := s.flowControlManager.MaybeTriggerStreamWindowUpdate(id)
if err != nil {
return nil, err
return false, err
}
if doUpdate {
res = append(res, &frames.WindowUpdateFrame{StreamID: id, ByteOffset: offset})
}
return true, nil
}
s.streamsMap.Iterate(fn)
doUpdate, offset := s.flowControlManager.MaybeTriggerConnectionWindowUpdate()
if doUpdate {
res = append(res, &frames.WindowUpdateFrame{StreamID: 0, ByteOffset: offset})

View file

@ -114,7 +114,7 @@ var _ = Describe("Session", func() {
)
Expect(err).NotTo(HaveOccurred())
session = pSession.(*Session)
Expect(session.streams).To(HaveLen(1)) // Crypto stream
Expect(session.streamsMap.NumberOfStreams()).To(Equal(1)) // Crypto stream
})
Context("when handling stream frames", func() {
@ -123,10 +123,12 @@ var _ = Describe("Session", func() {
StreamID: 5,
Data: []byte{0xde, 0xca, 0xfb, 0xad},
})
Expect(session.streams).To(HaveLen(2))
Expect(session.streamsMap.NumberOfStreams()).To(Equal(2))
Expect(streamCallbackCalled).To(BeTrue())
p := make([]byte, 4)
_, err := session.streams[5].Read(p)
str, _ := session.streamsMap.GetStream(5)
Expect(str).ToNot(BeNil())
_, err := str.Read(p)
Expect(err).ToNot(HaveOccurred())
Expect(p).To(Equal([]byte{0xde, 0xca, 0xfb, 0xad}))
})
@ -154,16 +156,18 @@ var _ = Describe("Session", func() {
StreamID: 5,
Data: []byte{0xde, 0xca},
})
Expect(session.streams).To(HaveLen(2))
Expect(session.streamsMap.NumberOfStreams()).To(Equal(2))
Expect(streamCallbackCalled).To(BeTrue())
session.handleStreamFrame(&frames.StreamFrame{
StreamID: 5,
Offset: 2,
Data: []byte{0xfb, 0xad},
})
Expect(session.streams).To(HaveLen(2))
Expect(session.streamsMap.NumberOfStreams()).To(Equal(2))
p := make([]byte, 4)
_, err := session.streams[5].Read(p)
str, _ := session.streamsMap.GetStream(5)
Expect(str).ToNot(BeNil())
_, err := str.Read(p)
Expect(err).ToNot(HaveOccurred())
Expect(p).To(Equal([]byte{0xde, 0xca, 0xfb, 0xad}))
})
@ -173,8 +177,9 @@ var _ = Describe("Session", func() {
Expect(err).ToNot(HaveOccurred())
str.Close()
session.garbageCollectStreams()
Expect(session.streams).To(HaveLen(2))
Expect(session.streams[5]).ToNot(BeNil())
Expect(session.streamsMap.NumberOfStreams()).To(Equal(2))
str, _ = session.streamsMap.GetStream(5)
Expect(str).ToNot(BeNil())
})
It("does not delete streams with FIN bit", func() {
@ -183,16 +188,18 @@ var _ = Describe("Session", func() {
Data: []byte{0xde, 0xca, 0xfb, 0xad},
FinBit: true,
})
Expect(session.streams).To(HaveLen(2))
Expect(session.streams[5]).ToNot(BeNil())
Expect(session.streamsMap.NumberOfStreams()).To(Equal(2))
str, _ := session.streamsMap.GetStream(5)
Expect(str).ToNot(BeNil())
Expect(streamCallbackCalled).To(BeTrue())
p := make([]byte, 4)
_, err := session.streams[5].Read(p)
_, err := str.Read(p)
Expect(err).To(MatchError(io.EOF))
Expect(p).To(Equal([]byte{0xde, 0xca, 0xfb, 0xad}))
session.garbageCollectStreams()
Expect(session.streams).To(HaveLen(2))
Expect(session.streams[5]).ToNot(BeNil())
Expect(session.streamsMap.NumberOfStreams()).To(Equal(2))
str, _ = session.streamsMap.GetStream(5)
Expect(str).ToNot(BeNil())
})
It("deletes streams with FIN bit & close", func() {
@ -201,23 +208,27 @@ var _ = Describe("Session", func() {
Data: []byte{0xde, 0xca, 0xfb, 0xad},
FinBit: true,
})
Expect(session.streams).To(HaveLen(2))
Expect(session.streams[5]).ToNot(BeNil())
Expect(session.streamsMap.NumberOfStreams()).To(Equal(2))
str, _ := session.streamsMap.GetStream(5)
Expect(str).ToNot(BeNil())
Expect(streamCallbackCalled).To(BeTrue())
p := make([]byte, 4)
_, err := session.streams[5].Read(p)
_, err := str.Read(p)
Expect(err).To(MatchError(io.EOF))
Expect(p).To(Equal([]byte{0xde, 0xca, 0xfb, 0xad}))
session.garbageCollectStreams()
Expect(session.streams).To(HaveLen(2))
Expect(session.streams[5]).ToNot(BeNil())
Expect(session.streamsMap.NumberOfStreams()).To(Equal(2))
str, _ = session.streamsMap.GetStream(5)
Expect(str).ToNot(BeNil())
// We still need to close the stream locally
session.streams[5].Close()
str.Close()
// ... and simulate that we actually the FIN
session.streams[5].sentFin()
str.sentFin()
session.garbageCollectStreams()
Expect(session.streams).To(HaveLen(2))
Expect(session.streams[5]).To(BeNil())
Expect(session.streamsMap.NumberOfStreams()).To(Equal(1))
str, strExists := session.streamsMap.GetStream(5)
Expect(strExists).To(BeTrue())
Expect(str).To(BeNil())
// flow controller should have been notified
_, err = session.flowControlManager.SendWindowSize(5)
Expect(err).To(MatchError("Error accessing the flowController map."))
@ -229,31 +240,35 @@ var _ = Describe("Session", func() {
StreamID: 5,
Data: []byte{0xde, 0xca, 0xfb, 0xad},
})
Expect(session.streams).To(HaveLen(2))
Expect(session.streams[5]).ToNot(BeNil())
Expect(session.streamsMap.NumberOfStreams()).To(Equal(2))
str, _ := session.streamsMap.GetStream(5)
Expect(str).ToNot(BeNil())
Expect(streamCallbackCalled).To(BeTrue())
p := make([]byte, 4)
_, err := session.streams[5].Read(p)
Expect(err).ToNot(HaveOccurred())
_, err := str.Read(p)
session.closeStreamsWithError(testErr)
_, err = session.streams[5].Read(p)
_, err = str.Read(p)
Expect(err).To(MatchError(testErr))
session.garbageCollectStreams()
Expect(session.streams).To(HaveLen(2))
Expect(session.streams[5]).To(BeNil())
Expect(session.streamsMap.NumberOfStreams()).To(Equal(1))
str, strExists := session.streamsMap.GetStream(5)
Expect(strExists).To(BeTrue())
Expect(str).To(BeNil())
})
It("closes empty streams with error", func() {
testErr := errors.New("test")
session.newStreamImpl(5)
Expect(session.streams).To(HaveLen(2))
Expect(session.streams[5]).ToNot(BeNil())
Expect(session.streamsMap.NumberOfStreams()).To(Equal(2))
str, _ := session.streamsMap.GetStream(5)
Expect(str).ToNot(BeNil())
session.closeStreamsWithError(testErr)
_, err := session.streams[5].Read([]byte{0})
_, err := str.Read([]byte{0})
Expect(err).To(MatchError(testErr))
session.garbageCollectStreams()
Expect(session.streams).To(HaveLen(2))
Expect(session.streams[5]).To(BeNil())
str, strExists := session.streamsMap.GetStream(5)
Expect(strExists).To(BeTrue())
Expect(str).To(BeNil())
})
It("informs the FlowControlManager about new streams", func() {
@ -271,10 +286,12 @@ var _ = Describe("Session", func() {
Data: []byte{},
FinBit: true,
})
_, err := session.streams[5].Read([]byte{0})
str, _ := session.streamsMap.GetStream(5)
Expect(str).ToNot(BeNil())
_, err := str.Read([]byte{0})
Expect(err).To(MatchError(io.EOF))
session.streams[5].Close()
session.streams[5].sentFin()
str.Close()
str.sentFin()
session.garbageCollectStreams()
err = session.handleStreamFrame(&frames.StreamFrame{
StreamID: 5,
@ -344,13 +361,17 @@ var _ = Describe("Session", func() {
ByteOffset: 1337,
})
Expect(err).ToNot(HaveOccurred())
Expect(session.streams).To(HaveKey(protocol.StreamID(5)))
Expect(session.streams[5]).ToNot(BeNil())
str, strExists := session.streamsMap.GetStream(5)
Expect(strExists).To(BeTrue())
Expect(str).ToNot(BeNil())
})
It("errors when receiving a WindowUpdateFrame for a closed stream", func() {
session.streams[5] = nil // this is what the garbageCollectStreams() does when a Stream is closed
err := session.handleWindowUpdateFrame(&frames.WindowUpdateFrame{
session.handleStreamFrame(&frames.StreamFrame{StreamID: 5})
err := session.streamsMap.RemoveStream(5)
Expect(err).ToNot(HaveOccurred())
session.garbageCollectStreams()
err = session.handleWindowUpdateFrame(&frames.WindowUpdateFrame{
StreamID: 5,
ByteOffset: 1337,
})
@ -358,8 +379,11 @@ var _ = Describe("Session", func() {
})
It("ignores errors when receiving a WindowUpdateFrame for a closed stream", func() {
session.streams[5] = nil // this is what the garbageCollectStreams() does when a Stream is closed
err := session.handleFrames([]frames.Frame{&frames.WindowUpdateFrame{
session.handleStreamFrame(&frames.StreamFrame{StreamID: 5})
err := session.streamsMap.RemoveStream(5)
Expect(err).ToNot(HaveOccurred())
session.garbageCollectStreams()
err = session.handleFrames([]frames.Frame{&frames.WindowUpdateFrame{
StreamID: 5,
ByteOffset: 1337,
}})

View file

@ -1,8 +1,6 @@
package quic
import (
"sync"
"github.com/lucas-clemente/quic-go/flowcontrol"
"github.com/lucas-clemente/quic-go/frames"
"github.com/lucas-clemente/quic-go/protocol"
@ -11,8 +9,7 @@ import (
type streamFramer struct {
// TODO: Simplify by extracting the streams map into a separate object
streams *map[protocol.StreamID]*stream
streamsMutex *sync.RWMutex
streamsMap *streamsMap
flowControlManager flowcontrol.FlowControlManager
@ -20,10 +17,9 @@ type streamFramer struct {
blockedFrameQueue []*frames.BlockedFrame
}
func newStreamFramer(streams *map[protocol.StreamID]*stream, streamsMutex *sync.RWMutex, flowControlManager flowcontrol.FlowControlManager) *streamFramer {
func newStreamFramer(streamsMap *streamsMap, flowControlManager flowcontrol.FlowControlManager) *streamFramer {
return &streamFramer{
streams: streams,
streamsMutex: streamsMutex,
streamsMap: streamsMap,
flowControlManager: flowControlManager,
}
}
@ -73,15 +69,12 @@ func (f *streamFramer) maybePopFramesForRetransmission(maxLen protocol.ByteCount
}
func (f *streamFramer) maybePopNormalFrames(maxBytes protocol.ByteCount) (res []*frames.StreamFrame) {
f.streamsMutex.RLock()
defer f.streamsMutex.RUnlock()
frame := &frames.StreamFrame{DataLenPresent: true}
var currentLen protocol.ByteCount
for _, s := range *f.streams {
fn := func(s *stream) (bool, error) {
if s == nil {
continue
return true, nil
}
frame.StreamID = s.streamID
@ -89,7 +82,7 @@ func (f *streamFramer) maybePopNormalFrames(maxBytes protocol.ByteCount) (res []
frame.Offset = s.writeOffset
frameHeaderBytes, _ := frame.MinLength(protocol.VersionWhatever) // can never error
if currentLen+frameHeaderBytes > maxBytes {
return // theoretically, we could find another stream that fits, but this is quite unlikely, so we stop here
return false, nil // theoretically, we could find another stream that fits, but this is quite unlikely, so we stop here
}
maxLen := maxBytes - currentLen - frameHeaderBytes
@ -99,7 +92,7 @@ func (f *streamFramer) maybePopNormalFrames(maxBytes protocol.ByteCount) (res []
}
if maxLen == 0 {
continue
return true, nil
}
data := s.getDataForWriting(maxLen)
@ -111,7 +104,7 @@ func (f *streamFramer) maybePopNormalFrames(maxBytes protocol.ByteCount) (res []
currentLen += frameHeaderBytes + frame.DataLen()
frame = &frames.StreamFrame{DataLenPresent: true}
}
continue
return true, nil
}
frame.Data = data
@ -129,7 +122,11 @@ func (f *streamFramer) maybePopNormalFrames(maxBytes protocol.ByteCount) (res []
res = append(res, frame)
currentLen += frameHeaderBytes + frame.DataLen()
frame = &frames.StreamFrame{DataLenPresent: true}
return true, nil
}
f.streamsMap.Iterate(fn)
return
}

View file

@ -2,7 +2,6 @@ package quic
import (
"bytes"
"sync"
"github.com/lucas-clemente/quic-go/frames"
"github.com/lucas-clemente/quic-go/protocol"
@ -14,7 +13,7 @@ var _ = Describe("Stream Framer", func() {
var (
retransmittedFrame1, retransmittedFrame2 *frames.StreamFrame
framer *streamFramer
streamsMap map[protocol.StreamID]*stream
streamsMap *streamsMap
stream1, stream2 *stream
fcm *mockFlowControlHandler
)
@ -31,18 +30,17 @@ var _ = Describe("Stream Framer", func() {
stream1 = &stream{streamID: 10}
stream2 = &stream{streamID: 11}
streamsMap = map[protocol.StreamID]*stream{
1: nil, 2: nil, 3: nil, 4: nil, // we have to be able to deal with nil frames
10: stream1,
11: stream2,
}
streamsMap = newStreamsMap()
streamsMap.PutStream(stream1)
streamsMap.PutStream(stream2)
fcm = newMockFlowControlHandler()
fcm.sendWindowSizes[stream1.streamID] = protocol.MaxByteCount
fcm.sendWindowSizes[stream2.streamID] = protocol.MaxByteCount
fcm.sendWindowSizes[retransmittedFrame1.StreamID] = protocol.MaxByteCount
fcm.sendWindowSizes[retransmittedFrame2.StreamID] = protocol.MaxByteCount
framer = newStreamFramer(&streamsMap, &sync.RWMutex{}, fcm)
framer = newStreamFramer(streamsMap, fcm)
})
It("sets the DataLenPresent for dequeued retransmitted frames", func() {