diff --git a/packet_packer_test.go b/packet_packer_test.go index e6579637..f56727ad 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -2,7 +2,6 @@ package quic import ( "bytes" - "sync" "github.com/lucas-clemente/quic-go/frames" "github.com/lucas-clemente/quic-go/handshake" @@ -24,7 +23,7 @@ var _ = Describe("Packet packer", func() { fcm.sendWindowSizes[5] = protocol.MaxByteCount fcm.sendWindowSizes[7] = protocol.MaxByteCount - streamFramer = newStreamFramer(&map[protocol.StreamID]*stream{}, &sync.RWMutex{}, fcm) + streamFramer = newStreamFramer(newStreamsMap(), fcm) packer = &packetPacker{ cryptoSetup: &handshake.CryptoSetup{}, diff --git a/session.go b/session.go index bc455676..a42c538d 100644 --- a/session.go +++ b/session.go @@ -48,7 +48,7 @@ type Session struct { conn connection - streams map[protocol.StreamID]*stream + streamsMap *streamsMap openStreamsCount uint32 streamsMutex sync.RWMutex @@ -108,12 +108,13 @@ func newSession(conn connection, v protocol.VersionNumber, connectionID protocol } session := &Session{ - connectionID: connectionID, - version: v, - conn: conn, - streamCallback: streamCallback, - closeCallback: closeCallback, - streams: make(map[protocol.StreamID]*stream), + connectionID: connectionID, + version: v, + conn: conn, + streamCallback: streamCallback, + closeCallback: closeCallback, + // streams: make(map[protocol.StreamID]*stream), + streamsMap: newStreamsMap(), sentPacketHandler: sentPacketHandler, receivedPacketHandler: receivedPacketHandler, stopWaitingManager: stopWaitingManager, @@ -135,7 +136,7 @@ func newSession(conn connection, v protocol.VersionNumber, connectionID protocol return nil, err } - session.streamFramer = newStreamFramer(&session.streams, &session.streamsMutex, flowControlManager) + session.streamFramer = newStreamFramer(session.streamsMap, flowControlManager) session.packer = newPacketPacker(connectionID, session.cryptoSetup, session.connectionParametersManager, session.streamFramer, v) session.unpacker = &packetUnpacker{aead: session.cryptoSetup, version: v} @@ -331,10 +332,10 @@ func (s *Session) handlePacket(remoteAddr interface{}, hdr *PublicHeader, data [ func (s *Session) handleStreamFrame(frame *frames.StreamFrame) error { s.streamsMutex.Lock() defer s.streamsMutex.Unlock() - str, streamExists := s.streams[frame.StreamID] + str, strExists := s.streamsMap.GetStream(frame.StreamID) var err error - if !streamExists { + if !strExists { if !s.isValidStreamID(frame.StreamID) { return qerr.InvalidStreamID } @@ -352,7 +353,7 @@ func (s *Session) handleStreamFrame(frame *frames.StreamFrame) error { if err != nil { return err } - if !streamExists { + if !strExists { s.streamCallback(s, str) } return nil @@ -366,14 +367,14 @@ func (s *Session) handleWindowUpdateFrame(frame *frames.WindowUpdateFrame) error s.streamsMutex.RLock() defer s.streamsMutex.RUnlock() if frame.StreamID != 0 { - stream, ok := s.streams[frame.StreamID] - if ok && stream == nil { + str, strExists := s.streamsMap.GetStream(frame.StreamID) + if strExists && str == nil { return errWindowUpdateOnClosedStream } // open new stream when receiving a WindowUpdate for a non-existing stream // this can occur if the client immediately sends a WindowUpdate for a newly opened stream, and packet reordering occurs such that the packet opening the new stream arrives after the WindowUpdate - if !ok { + if !strExists { s.newStreamImpl(frame.StreamID) } } @@ -384,7 +385,7 @@ func (s *Session) handleWindowUpdateFrame(frame *frames.WindowUpdateFrame) error // TODO: Handle frame.byteOffset func (s *Session) handleRstStreamFrame(frame *frames.RstStreamFrame) error { s.streamsMutex.RLock() - str, streamExists := s.streams[frame.StreamID] + str, streamExists := s.streamsMap.GetStream(frame.StreamID) s.streamsMutex.RUnlock() if !streamExists || str == nil { return errRstStreamOnInvalidStream @@ -445,12 +446,16 @@ func (s *Session) closeImpl(e error, remoteClose bool) error { func (s *Session) closeStreamsWithError(err error) { s.streamsMutex.Lock() defer s.streamsMutex.Unlock() - for _, str := range s.streams { + + fn := func(str *stream) (bool, error) { if str == nil { - continue + return true, nil } s.closeStreamWithError(str, err) + return true, nil } + + s.streamsMap.Iterate(fn) } func (s *Session) closeStreamWithError(str *stream, err error) { @@ -595,7 +600,8 @@ func (s *Session) OpenStream(id protocol.StreamID) (utils.Stream, error) { func (s *Session) GetOrOpenStream(id protocol.StreamID) (utils.Stream, error) { s.streamsMutex.Lock() defer s.streamsMutex.Unlock() - if stream, ok := s.streams[id]; ok { + stream, strExists := s.streamsMap.GetStream(id) + if strExists { return stream, nil } return s.newStreamImpl(id) @@ -608,7 +614,8 @@ func (s *Session) newStreamImpl(id protocol.StreamID) (*stream, error) { go s.Close(qerr.TooManyOpenStreams) return nil, qerr.TooManyOpenStreams } - if _, ok := s.streams[id]; ok { + _, strExists := s.streamsMap.GetStream(id) + if strExists { return nil, fmt.Errorf("Session: stream with ID %d already exists", id) } stream, err := newStream(s.scheduleSending, s.connectionParametersManager, s.flowControlManager, id) @@ -624,26 +631,33 @@ func (s *Session) newStreamImpl(id protocol.StreamID) (*stream, error) { } atomic.AddUint32(&s.openStreamsCount, 1) - s.streams[id] = stream + err = s.streamsMap.PutStream(stream) + if err != nil { + return nil, err + } return stream, nil } // garbageCollectStreams goes through all streams and removes EOF'ed streams // from the streams map. func (s *Session) garbageCollectStreams() { - s.streamsMutex.Lock() - defer s.streamsMutex.Unlock() - for k, v := range s.streams { - if v == nil { - continue + fn := func(str *stream) (bool, error) { + if str == nil { + return true, nil } - if v.finished() { - utils.Debugf("Garbage-collecting stream %d", k) + id := str.StreamID() + if str.finished() { atomic.AddUint32(&s.openStreamsCount, ^uint32(0)) // decrement - s.streams[k] = nil - s.flowControlManager.RemoveStream(k) + err := s.streamsMap.RemoveStream(id) + if err != nil { + return false, err + } + s.flowControlManager.RemoveStream(id) } + return true, nil } + + s.streamsMap.Iterate(fn) } func (s *Session) sendPublicReset(rejectedPacketNumber protocol.PacketNumber) error { @@ -680,20 +694,24 @@ func (s *Session) getWindowUpdateFrames() ([]*frames.WindowUpdateFrame, error) { var res []*frames.WindowUpdateFrame - for id, str := range s.streams { + fn := func(str *stream) (bool, error) { if str == nil { - continue + return true, nil } + id := str.StreamID() doUpdate, offset, err := s.flowControlManager.MaybeTriggerStreamWindowUpdate(id) if err != nil { - return nil, err + return false, err } if doUpdate { res = append(res, &frames.WindowUpdateFrame{StreamID: id, ByteOffset: offset}) } + return true, nil } + s.streamsMap.Iterate(fn) + doUpdate, offset := s.flowControlManager.MaybeTriggerConnectionWindowUpdate() if doUpdate { res = append(res, &frames.WindowUpdateFrame{StreamID: 0, ByteOffset: offset}) diff --git a/session_test.go b/session_test.go index 430631f0..36bc845a 100644 --- a/session_test.go +++ b/session_test.go @@ -114,7 +114,7 @@ var _ = Describe("Session", func() { ) Expect(err).NotTo(HaveOccurred()) session = pSession.(*Session) - Expect(session.streams).To(HaveLen(1)) // Crypto stream + Expect(session.streamsMap.NumberOfStreams()).To(Equal(1)) // Crypto stream }) Context("when handling stream frames", func() { @@ -123,10 +123,12 @@ var _ = Describe("Session", func() { StreamID: 5, Data: []byte{0xde, 0xca, 0xfb, 0xad}, }) - Expect(session.streams).To(HaveLen(2)) + Expect(session.streamsMap.NumberOfStreams()).To(Equal(2)) Expect(streamCallbackCalled).To(BeTrue()) p := make([]byte, 4) - _, err := session.streams[5].Read(p) + str, _ := session.streamsMap.GetStream(5) + Expect(str).ToNot(BeNil()) + _, err := str.Read(p) Expect(err).ToNot(HaveOccurred()) Expect(p).To(Equal([]byte{0xde, 0xca, 0xfb, 0xad})) }) @@ -154,16 +156,18 @@ var _ = Describe("Session", func() { StreamID: 5, Data: []byte{0xde, 0xca}, }) - Expect(session.streams).To(HaveLen(2)) + Expect(session.streamsMap.NumberOfStreams()).To(Equal(2)) Expect(streamCallbackCalled).To(BeTrue()) session.handleStreamFrame(&frames.StreamFrame{ StreamID: 5, Offset: 2, Data: []byte{0xfb, 0xad}, }) - Expect(session.streams).To(HaveLen(2)) + Expect(session.streamsMap.NumberOfStreams()).To(Equal(2)) p := make([]byte, 4) - _, err := session.streams[5].Read(p) + str, _ := session.streamsMap.GetStream(5) + Expect(str).ToNot(BeNil()) + _, err := str.Read(p) Expect(err).ToNot(HaveOccurred()) Expect(p).To(Equal([]byte{0xde, 0xca, 0xfb, 0xad})) }) @@ -173,8 +177,9 @@ var _ = Describe("Session", func() { Expect(err).ToNot(HaveOccurred()) str.Close() session.garbageCollectStreams() - Expect(session.streams).To(HaveLen(2)) - Expect(session.streams[5]).ToNot(BeNil()) + Expect(session.streamsMap.NumberOfStreams()).To(Equal(2)) + str, _ = session.streamsMap.GetStream(5) + Expect(str).ToNot(BeNil()) }) It("does not delete streams with FIN bit", func() { @@ -183,16 +188,18 @@ var _ = Describe("Session", func() { Data: []byte{0xde, 0xca, 0xfb, 0xad}, FinBit: true, }) - Expect(session.streams).To(HaveLen(2)) - Expect(session.streams[5]).ToNot(BeNil()) + Expect(session.streamsMap.NumberOfStreams()).To(Equal(2)) + str, _ := session.streamsMap.GetStream(5) + Expect(str).ToNot(BeNil()) Expect(streamCallbackCalled).To(BeTrue()) p := make([]byte, 4) - _, err := session.streams[5].Read(p) + _, err := str.Read(p) Expect(err).To(MatchError(io.EOF)) Expect(p).To(Equal([]byte{0xde, 0xca, 0xfb, 0xad})) session.garbageCollectStreams() - Expect(session.streams).To(HaveLen(2)) - Expect(session.streams[5]).ToNot(BeNil()) + Expect(session.streamsMap.NumberOfStreams()).To(Equal(2)) + str, _ = session.streamsMap.GetStream(5) + Expect(str).ToNot(BeNil()) }) It("deletes streams with FIN bit & close", func() { @@ -201,23 +208,27 @@ var _ = Describe("Session", func() { Data: []byte{0xde, 0xca, 0xfb, 0xad}, FinBit: true, }) - Expect(session.streams).To(HaveLen(2)) - Expect(session.streams[5]).ToNot(BeNil()) + Expect(session.streamsMap.NumberOfStreams()).To(Equal(2)) + str, _ := session.streamsMap.GetStream(5) + Expect(str).ToNot(BeNil()) Expect(streamCallbackCalled).To(BeTrue()) p := make([]byte, 4) - _, err := session.streams[5].Read(p) + _, err := str.Read(p) Expect(err).To(MatchError(io.EOF)) Expect(p).To(Equal([]byte{0xde, 0xca, 0xfb, 0xad})) session.garbageCollectStreams() - Expect(session.streams).To(HaveLen(2)) - Expect(session.streams[5]).ToNot(BeNil()) + Expect(session.streamsMap.NumberOfStreams()).To(Equal(2)) + str, _ = session.streamsMap.GetStream(5) + Expect(str).ToNot(BeNil()) // We still need to close the stream locally - session.streams[5].Close() + str.Close() // ... and simulate that we actually the FIN - session.streams[5].sentFin() + str.sentFin() session.garbageCollectStreams() - Expect(session.streams).To(HaveLen(2)) - Expect(session.streams[5]).To(BeNil()) + Expect(session.streamsMap.NumberOfStreams()).To(Equal(1)) + str, strExists := session.streamsMap.GetStream(5) + Expect(strExists).To(BeTrue()) + Expect(str).To(BeNil()) // flow controller should have been notified _, err = session.flowControlManager.SendWindowSize(5) Expect(err).To(MatchError("Error accessing the flowController map.")) @@ -229,31 +240,35 @@ var _ = Describe("Session", func() { StreamID: 5, Data: []byte{0xde, 0xca, 0xfb, 0xad}, }) - Expect(session.streams).To(HaveLen(2)) - Expect(session.streams[5]).ToNot(BeNil()) + Expect(session.streamsMap.NumberOfStreams()).To(Equal(2)) + str, _ := session.streamsMap.GetStream(5) + Expect(str).ToNot(BeNil()) Expect(streamCallbackCalled).To(BeTrue()) p := make([]byte, 4) - _, err := session.streams[5].Read(p) - Expect(err).ToNot(HaveOccurred()) + _, err := str.Read(p) session.closeStreamsWithError(testErr) - _, err = session.streams[5].Read(p) + _, err = str.Read(p) Expect(err).To(MatchError(testErr)) session.garbageCollectStreams() - Expect(session.streams).To(HaveLen(2)) - Expect(session.streams[5]).To(BeNil()) + Expect(session.streamsMap.NumberOfStreams()).To(Equal(1)) + str, strExists := session.streamsMap.GetStream(5) + Expect(strExists).To(BeTrue()) + Expect(str).To(BeNil()) }) It("closes empty streams with error", func() { testErr := errors.New("test") session.newStreamImpl(5) - Expect(session.streams).To(HaveLen(2)) - Expect(session.streams[5]).ToNot(BeNil()) + Expect(session.streamsMap.NumberOfStreams()).To(Equal(2)) + str, _ := session.streamsMap.GetStream(5) + Expect(str).ToNot(BeNil()) session.closeStreamsWithError(testErr) - _, err := session.streams[5].Read([]byte{0}) + _, err := str.Read([]byte{0}) Expect(err).To(MatchError(testErr)) session.garbageCollectStreams() - Expect(session.streams).To(HaveLen(2)) - Expect(session.streams[5]).To(BeNil()) + str, strExists := session.streamsMap.GetStream(5) + Expect(strExists).To(BeTrue()) + Expect(str).To(BeNil()) }) It("informs the FlowControlManager about new streams", func() { @@ -271,10 +286,12 @@ var _ = Describe("Session", func() { Data: []byte{}, FinBit: true, }) - _, err := session.streams[5].Read([]byte{0}) + str, _ := session.streamsMap.GetStream(5) + Expect(str).ToNot(BeNil()) + _, err := str.Read([]byte{0}) Expect(err).To(MatchError(io.EOF)) - session.streams[5].Close() - session.streams[5].sentFin() + str.Close() + str.sentFin() session.garbageCollectStreams() err = session.handleStreamFrame(&frames.StreamFrame{ StreamID: 5, @@ -344,13 +361,17 @@ var _ = Describe("Session", func() { ByteOffset: 1337, }) Expect(err).ToNot(HaveOccurred()) - Expect(session.streams).To(HaveKey(protocol.StreamID(5))) - Expect(session.streams[5]).ToNot(BeNil()) + str, strExists := session.streamsMap.GetStream(5) + Expect(strExists).To(BeTrue()) + Expect(str).ToNot(BeNil()) }) It("errors when receiving a WindowUpdateFrame for a closed stream", func() { - session.streams[5] = nil // this is what the garbageCollectStreams() does when a Stream is closed - err := session.handleWindowUpdateFrame(&frames.WindowUpdateFrame{ + session.handleStreamFrame(&frames.StreamFrame{StreamID: 5}) + err := session.streamsMap.RemoveStream(5) + Expect(err).ToNot(HaveOccurred()) + session.garbageCollectStreams() + err = session.handleWindowUpdateFrame(&frames.WindowUpdateFrame{ StreamID: 5, ByteOffset: 1337, }) @@ -358,8 +379,11 @@ var _ = Describe("Session", func() { }) It("ignores errors when receiving a WindowUpdateFrame for a closed stream", func() { - session.streams[5] = nil // this is what the garbageCollectStreams() does when a Stream is closed - err := session.handleFrames([]frames.Frame{&frames.WindowUpdateFrame{ + session.handleStreamFrame(&frames.StreamFrame{StreamID: 5}) + err := session.streamsMap.RemoveStream(5) + Expect(err).ToNot(HaveOccurred()) + session.garbageCollectStreams() + err = session.handleFrames([]frames.Frame{&frames.WindowUpdateFrame{ StreamID: 5, ByteOffset: 1337, }}) diff --git a/stream_framer.go b/stream_framer.go index 4674439b..c4021ca6 100644 --- a/stream_framer.go +++ b/stream_framer.go @@ -1,8 +1,6 @@ package quic import ( - "sync" - "github.com/lucas-clemente/quic-go/flowcontrol" "github.com/lucas-clemente/quic-go/frames" "github.com/lucas-clemente/quic-go/protocol" @@ -11,8 +9,7 @@ import ( type streamFramer struct { // TODO: Simplify by extracting the streams map into a separate object - streams *map[protocol.StreamID]*stream - streamsMutex *sync.RWMutex + streamsMap *streamsMap flowControlManager flowcontrol.FlowControlManager @@ -20,10 +17,9 @@ type streamFramer struct { blockedFrameQueue []*frames.BlockedFrame } -func newStreamFramer(streams *map[protocol.StreamID]*stream, streamsMutex *sync.RWMutex, flowControlManager flowcontrol.FlowControlManager) *streamFramer { +func newStreamFramer(streamsMap *streamsMap, flowControlManager flowcontrol.FlowControlManager) *streamFramer { return &streamFramer{ - streams: streams, - streamsMutex: streamsMutex, + streamsMap: streamsMap, flowControlManager: flowControlManager, } } @@ -73,15 +69,12 @@ func (f *streamFramer) maybePopFramesForRetransmission(maxLen protocol.ByteCount } func (f *streamFramer) maybePopNormalFrames(maxBytes protocol.ByteCount) (res []*frames.StreamFrame) { - f.streamsMutex.RLock() - defer f.streamsMutex.RUnlock() - frame := &frames.StreamFrame{DataLenPresent: true} var currentLen protocol.ByteCount - for _, s := range *f.streams { + fn := func(s *stream) (bool, error) { if s == nil { - continue + return true, nil } frame.StreamID = s.streamID @@ -89,7 +82,7 @@ func (f *streamFramer) maybePopNormalFrames(maxBytes protocol.ByteCount) (res [] frame.Offset = s.writeOffset frameHeaderBytes, _ := frame.MinLength(protocol.VersionWhatever) // can never error if currentLen+frameHeaderBytes > maxBytes { - return // theoretically, we could find another stream that fits, but this is quite unlikely, so we stop here + return false, nil // theoretically, we could find another stream that fits, but this is quite unlikely, so we stop here } maxLen := maxBytes - currentLen - frameHeaderBytes @@ -99,7 +92,7 @@ func (f *streamFramer) maybePopNormalFrames(maxBytes protocol.ByteCount) (res [] } if maxLen == 0 { - continue + return true, nil } data := s.getDataForWriting(maxLen) @@ -111,7 +104,7 @@ func (f *streamFramer) maybePopNormalFrames(maxBytes protocol.ByteCount) (res [] currentLen += frameHeaderBytes + frame.DataLen() frame = &frames.StreamFrame{DataLenPresent: true} } - continue + return true, nil } frame.Data = data @@ -129,7 +122,11 @@ func (f *streamFramer) maybePopNormalFrames(maxBytes protocol.ByteCount) (res [] res = append(res, frame) currentLen += frameHeaderBytes + frame.DataLen() frame = &frames.StreamFrame{DataLenPresent: true} + return true, nil } + + f.streamsMap.Iterate(fn) + return } diff --git a/stream_framer_test.go b/stream_framer_test.go index 5bd08db8..476356d8 100644 --- a/stream_framer_test.go +++ b/stream_framer_test.go @@ -2,7 +2,6 @@ package quic import ( "bytes" - "sync" "github.com/lucas-clemente/quic-go/frames" "github.com/lucas-clemente/quic-go/protocol" @@ -14,7 +13,7 @@ var _ = Describe("Stream Framer", func() { var ( retransmittedFrame1, retransmittedFrame2 *frames.StreamFrame framer *streamFramer - streamsMap map[protocol.StreamID]*stream + streamsMap *streamsMap stream1, stream2 *stream fcm *mockFlowControlHandler ) @@ -31,18 +30,17 @@ var _ = Describe("Stream Framer", func() { stream1 = &stream{streamID: 10} stream2 = &stream{streamID: 11} - streamsMap = map[protocol.StreamID]*stream{ - 1: nil, 2: nil, 3: nil, 4: nil, // we have to be able to deal with nil frames - 10: stream1, - 11: stream2, - } + + streamsMap = newStreamsMap() + streamsMap.PutStream(stream1) + streamsMap.PutStream(stream2) fcm = newMockFlowControlHandler() fcm.sendWindowSizes[stream1.streamID] = protocol.MaxByteCount fcm.sendWindowSizes[stream2.streamID] = protocol.MaxByteCount fcm.sendWindowSizes[retransmittedFrame1.StreamID] = protocol.MaxByteCount fcm.sendWindowSizes[retransmittedFrame2.StreamID] = protocol.MaxByteCount - framer = newStreamFramer(&streamsMap, &sync.RWMutex{}, fcm) + framer = newStreamFramer(streamsMap, fcm) }) It("sets the DataLenPresent for dequeued retransmitted frames", func() {