diff --git a/packet_packer_test.go b/packet_packer_test.go index 6edfdde9..2ae8aa9f 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -61,11 +61,8 @@ var _ = Describe("Packet packer", func() { BeforeEach(func() { cryptoStream = &stream{flowController: flowcontrol.NewStreamFlowController(1, false, flowcontrol.NewConnectionFlowController(1000, 1000, nil), 1000, 1000, 1000, nil)} - streamsMap := newStreamsMap(nil, protocol.PerspectiveServer) - streamsMap.streams[1] = cryptoStream - streamsMap.openStreams = []protocol.StreamID{1} - streamFramer = newStreamFramer(streamsMap, nil) + streamFramer = newStreamFramer(cryptoStream, streamsMap, nil) packer = &packetPacker{ cryptoSetup: &mockCryptoSetup{encLevelSeal: protocol.EncryptionForwardSecure}, diff --git a/session.go b/session.go index 0ccd5bab..ca7769ef 100644 --- a/session.go +++ b/session.go @@ -5,7 +5,6 @@ import ( "crypto/tls" "errors" "fmt" - "io" "net" "sync" "time" @@ -60,7 +59,8 @@ type session struct { conn connection - streamsMap *streamsMap + streamsMap *streamsMap + cryptoStream streamI rttStats *congestion.RTTStats @@ -195,21 +195,14 @@ func (s *session) setup( } s.sentPacketHandler = ackhandler.NewSentPacketHandler(s.rttStats) s.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(s.version) - s.connFlowController = flowcontrol.NewConnectionFlowController( protocol.ReceiveConnectionFlowControlWindow, protocol.ByteCount(s.config.MaxReceiveConnectionFlowControlWindow), s.rttStats, ) s.streamsMap = newStreamsMap(s.newStream, s.perspective) - var cryptoStream io.ReadWriter - // open the crypto stream - if s.perspective == protocol.PerspectiveServer { - cryptoStream, _ = s.GetOrOpenStream(1) - _, _ = s.AcceptStream() // don't expose the crypto stream - } else { - cryptoStream, _ = s.OpenStream() - } + s.cryptoStream = s.newStream(1) + s.streamFramer = newStreamFramer(s.cryptoStream, s.streamsMap, s.connFlowController) var err error if s.perspective == protocol.PerspectiveServer { @@ -218,7 +211,7 @@ func (s *session) setup( } if s.version.UsesTLS() { s.cryptoSetup, err = handshake.NewCryptoSetupTLSServer( - cryptoStream, + s.cryptoStream, tlsConf, transportParams, paramsChan, @@ -228,7 +221,7 @@ func (s *session) setup( ) } else { s.cryptoSetup, err = newCryptoSetup( - cryptoStream, + s.cryptoStream, s.connectionID, s.conn.RemoteAddr(), s.version, @@ -244,7 +237,7 @@ func (s *session) setup( transportParams.OmitConnectionID = s.config.RequestConnectionIDOmission if s.version.UsesTLS() { s.cryptoSetup, err = handshake.NewCryptoSetupTLSClient( - cryptoStream, + s.cryptoStream, hostname, tlsConf, transportParams, @@ -256,7 +249,7 @@ func (s *session) setup( ) } else { s.cryptoSetup, err = newCryptoSetupClient( - cryptoStream, + s.cryptoStream, hostname, s.connectionID, s.version, @@ -272,7 +265,6 @@ func (s *session) setup( return nil, nil, err } - s.streamFramer = newStreamFramer(s.streamsMap, s.connFlowController) s.packer = newPacketPacker(s.connectionID, s.cryptoSetup, s.streamFramer, @@ -529,6 +521,9 @@ func (s *session) handlePacket(p *receivedPacket) { } func (s *session) handleStreamFrame(frame *wire.StreamFrame) error { + if frame.StreamID == 1 { + return s.cryptoStream.AddStreamFrame(frame) + } str, err := s.streamsMap.GetOrOpenStream(frame.StreamID) if err != nil { return err @@ -610,6 +605,7 @@ func (s *session) handleCloseError(closeErr closeError) error { utils.Errorf("Closing session with error: %s", closeErr.err.Error()) } + s.cryptoStream.Cancel(quicErr) s.streamsMap.CloseWithError(quicErr) if closeErr.err == errCloseSessionForNewVersion { diff --git a/session_test.go b/session_test.go index 8e6b7571..45b8d1a4 100644 --- a/session_test.go +++ b/session_test.go @@ -189,7 +189,7 @@ var _ = Describe("Session", func() { ) Expect(err).NotTo(HaveOccurred()) sess = pSess.(*session) - Expect(sess.streamsMap.openStreams).To(HaveLen(1)) // 1 stream: the crypto stream + Expect(sess.streamsMap.openStreams).To(BeEmpty()) }) AfterEach(func() { @@ -472,6 +472,9 @@ var _ = Describe("Session", func() { }) It("handles CONNECTION_CLOSE frames", func() { + cryptoStream := mocks.NewMockStreamI(mockCtrl) + cryptoStream.EXPECT().Cancel(gomock.Any()) + sess.cryptoStream = cryptoStream done := make(chan struct{}) go func() { defer GinkgoRecover() @@ -482,9 +485,6 @@ var _ = Describe("Session", func() { _, err := sess.GetOrOpenStream(5) Expect(err).ToNot(HaveOccurred()) sess.streamsMap.Range(func(s streamI) { - if s.StreamID() == 1 { // the crypto stream is created by the session setup and is not a mock stream - return - } s.(*mocks.MockStreamI).EXPECT().Cancel(gomock.Any()) }) err = sess.handleFrames([]wire.Frame{&wire.ConnectionCloseFrame{ErrorCode: qerr.ProofInvalid, ReasonPhrase: "foobar"}}) @@ -1541,7 +1541,7 @@ var _ = Describe("Client Session", func() { ) sess = sessP.(*session) Expect(err).ToNot(HaveOccurred()) - Expect(sess.streamsMap.openStreams).To(HaveLen(1)) + Expect(sess.streamsMap.openStreams).To(BeEmpty()) }) AfterEach(func() { diff --git a/stream_framer.go b/stream_framer.go index 0e6137e4..ae1b556c 100644 --- a/stream_framer.go +++ b/stream_framer.go @@ -7,7 +7,8 @@ import ( ) type streamFramer struct { - streamsMap *streamsMap + streamsMap *streamsMap + cryptoStream streamI connFlowController flowcontrol.ConnectionFlowController @@ -15,9 +16,14 @@ type streamFramer struct { blockedFrameQueue []*wire.BlockedFrame } -func newStreamFramer(streamsMap *streamsMap, cfc flowcontrol.ConnectionFlowController) *streamFramer { +func newStreamFramer( + cryptoStream streamI, + streamsMap *streamsMap, + cfc flowcontrol.ConnectionFlowController, +) *streamFramer { return &streamFramer{ streamsMap: streamsMap, + cryptoStream: cryptoStream, connFlowController: cfc, } } @@ -45,8 +51,7 @@ func (f *streamFramer) HasFramesForRetransmission() bool { } func (f *streamFramer) HasCryptoStreamFrame() bool { - cs, _ := f.streamsMap.GetOrOpenStream(1) - return cs.LenOfDataForWriting() > 0 + return f.cryptoStream.LenOfDataForWriting() > 0 } // TODO(lclemente): This is somewhat duplicate with the normal path for generating frames. @@ -54,13 +59,12 @@ func (f *streamFramer) PopCryptoStreamFrame(maxLen protocol.ByteCount) *wire.Str if !f.HasCryptoStreamFrame() { return nil } - cs, _ := f.streamsMap.GetOrOpenStream(1) frame := &wire.StreamFrame{ StreamID: 1, - Offset: cs.GetWriteOffset(), + Offset: f.cryptoStream.GetWriteOffset(), } frameHeaderBytes, _ := frame.MinLength(protocol.VersionWhatever) // can never error - frame.Data = cs.GetDataForWriting(maxLen - frameHeaderBytes) + frame.Data = f.cryptoStream.GetDataForWriting(maxLen - frameHeaderBytes) return frame } @@ -95,7 +99,7 @@ func (f *streamFramer) maybePopNormalFrames(maxBytes protocol.ByteCount) (res [] var currentLen protocol.ByteCount fn := func(s streamI) (bool, error) { - if s == nil || s.StreamID() == 1 /* crypto stream is handled separately */ { + if s == nil { return true, nil } @@ -146,7 +150,6 @@ func (f *streamFramer) maybePopNormalFrames(maxBytes protocol.ByteCount) (res [] } f.streamsMap.RoundRobinIterate(fn) - return } diff --git a/stream_framer_test.go b/stream_framer_test.go index e94f592c..670a4d98 100644 --- a/stream_framer_test.go +++ b/stream_framer_test.go @@ -46,7 +46,7 @@ var _ = Describe("Stream Framer", func() { streamsMap.putStream(stream2) connFC = mocks.NewMockConnectionFlowController(mockCtrl) - framer = newStreamFramer(streamsMap, connFC) + framer = newStreamFramer(nil, streamsMap, connFC) }) setNoData := func(str *mocks.MockStreamI) { diff --git a/streams_map.go b/streams_map.go index 915d35e3..aa10f129 100644 --- a/streams_map.go +++ b/streams_map.go @@ -59,13 +59,11 @@ func newStreamsMap(newStream newStreamLambda, pers protocol.Perspective) *stream sm.openStreamOrErrCond.L = &sm.mutex if pers == protocol.PerspectiveClient { - sm.nextStream = 1 + sm.nextStream = 3 sm.nextStreamToAccept = 2 - // TODO: find a better solution for opening the crypto stream - sm.maxOutgoingStreams = 1 // allow the crypto stream } else { sm.nextStream = 2 - sm.nextStreamToAccept = 1 + sm.nextStreamToAccept = 3 } return &sm @@ -271,7 +269,7 @@ func (m *streamsMap) DeleteClosedStreams() error { // RoundRobinIterate executes the streamLambda for every open stream, until the streamLambda returns false // It uses a round-robin-like scheduling to ensure that every stream is considered fairly -// It prioritizes the crypto- and the header-stream (StreamIDs 1 and 3) +// It prioritizes the the header-stream (StreamID 3) func (m *streamsMap) RoundRobinIterate(fn streamLambda) error { m.mutex.Lock() defer m.mutex.Unlock() @@ -279,19 +277,18 @@ func (m *streamsMap) RoundRobinIterate(fn streamLambda) error { numStreams := len(m.streams) startIndex := m.roundRobinIndex - for _, i := range []protocol.StreamID{1, 3} { - cont, err := m.iterateFunc(i, fn) - if err != nil && err != errMapAccess { - return err - } - if !cont { - return nil - } + // prioritize the header stream + cont, err := m.iterateFunc(3, fn) + if err != nil && err != errMapAccess { + return err + } + if !cont { + return nil } for i := 0; i < numStreams; i++ { streamID := m.openStreams[(i+startIndex)%numStreams] - if streamID == 1 || streamID == 3 { + if streamID == 3 { continue } diff --git a/streams_map_test.go b/streams_map_test.go index e76b3c4f..bb08868e 100644 --- a/streams_map_test.go +++ b/streams_map_test.go @@ -280,7 +280,7 @@ var _ = Describe("Streams Map", func() { Consistently(func() bool { return accepted }).Should(BeFalse()) }) - It("accepts stream 1 first", func() { + It("start with stream 3", func() { var str streamI go func() { defer GinkgoRecover() @@ -288,10 +288,10 @@ var _ = Describe("Streams Map", func() { str, err = m.AcceptStream() Expect(err).ToNot(HaveOccurred()) }() - _, err := m.GetOrOpenStream(1) + _, err := m.GetOrOpenStream(3) Expect(err).ToNot(HaveOccurred()) Eventually(func() Stream { return str }).ShouldNot(BeNil()) - Expect(str.StreamID()).To(Equal(protocol.StreamID(1))) + Expect(str.StreamID()).To(Equal(protocol.StreamID(3))) }) It("returns an implicitly opened stream, if a stream number is skipped", func() { @@ -305,7 +305,7 @@ var _ = Describe("Streams Map", func() { _, err := m.GetOrOpenStream(5) Expect(err).ToNot(HaveOccurred()) Eventually(func() Stream { return str }).ShouldNot(BeNil()) - Expect(str.StreamID()).To(Equal(protocol.StreamID(1))) + Expect(str.StreamID()).To(Equal(protocol.StreamID(3))) }) It("returns to multiple accepts", func() { @@ -322,12 +322,12 @@ var _ = Describe("Streams Map", func() { str2, err = m.AcceptStream() Expect(err).ToNot(HaveOccurred()) }() - _, err := m.GetOrOpenStream(3) // opens stream 1 and 3 + _, err := m.GetOrOpenStream(5) // opens stream 3 and 5 Expect(err).ToNot(HaveOccurred()) Eventually(func() streamI { return str1 }).ShouldNot(BeNil()) Eventually(func() streamI { return str2 }).ShouldNot(BeNil()) Expect(str1.StreamID()).ToNot(Equal(str2.StreamID())) - Expect(str1.StreamID() + str2.StreamID()).To(BeEquivalentTo(1 + 3)) + Expect(str1.StreamID() + str2.StreamID()).To(BeEquivalentTo(3 + 5)) }) It("waits a new stream is available", func() { @@ -339,10 +339,10 @@ var _ = Describe("Streams Map", func() { Expect(err).ToNot(HaveOccurred()) }() Consistently(func() streamI { return str }).Should(BeNil()) - _, err := m.GetOrOpenStream(1) + _, err := m.GetOrOpenStream(3) Expect(err).ToNot(HaveOccurred()) Eventually(func() streamI { return str }).ShouldNot(BeNil()) - Expect(str.StreamID()).To(Equal(protocol.StreamID(1))) + Expect(str.StreamID()).To(Equal(protocol.StreamID(3))) }) It("returns multiple streams on subsequent Accept calls, if available", func() { @@ -353,22 +353,22 @@ var _ = Describe("Streams Map", func() { str, err = m.AcceptStream() Expect(err).ToNot(HaveOccurred()) }() - _, err := m.GetOrOpenStream(3) + _, err := m.GetOrOpenStream(5) Expect(err).ToNot(HaveOccurred()) Eventually(func() streamI { return str }).ShouldNot(BeNil()) - Expect(str.StreamID()).To(Equal(protocol.StreamID(1))) + Expect(str.StreamID()).To(Equal(protocol.StreamID(3))) str, err = m.AcceptStream() Expect(err).ToNot(HaveOccurred()) - Expect(str.StreamID()).To(Equal(protocol.StreamID(3))) + Expect(str.StreamID()).To(Equal(protocol.StreamID(5))) }) It("blocks after accepting a stream", func() { var accepted bool - _, err := m.GetOrOpenStream(1) + _, err := m.GetOrOpenStream(3) Expect(err).ToNot(HaveOccurred()) str, err := m.AcceptStream() Expect(err).ToNot(HaveOccurred()) - Expect(str.StreamID()).To(Equal(protocol.StreamID(1))) + Expect(str.StreamID()).To(Equal(protocol.StreamID(3))) go func() { defer GinkgoRecover() _, _ = m.AcceptStream() @@ -400,6 +400,7 @@ var _ = Describe("Streams Map", func() { Context("as a client", func() { BeforeEach(func() { setNewStreamsMap(protocol.PerspectiveClient) + m.UpdateMaxStreamLimit(100) }) Context("client-side streams", func() { @@ -434,21 +435,21 @@ var _ = Describe("Streams Map", func() { It("doesn't reopen an already closed stream", func() { str, err := m.OpenStream() Expect(err).ToNot(HaveOccurred()) - Expect(str.StreamID()).To(Equal(protocol.StreamID(1))) - deleteStream(1) + Expect(str.StreamID()).To(Equal(protocol.StreamID(3))) + deleteStream(3) Expect(err).ToNot(HaveOccurred()) - str, err = m.GetOrOpenStream(1) + str, err = m.GetOrOpenStream(3) Expect(err).ToNot(HaveOccurred()) Expect(str).To(BeNil()) }) }) Context("server-side streams", func() { - It("opens stream 1 first", func() { + It("starts with stream 3", func() { s, err := m.OpenStream() Expect(err).ToNot(HaveOccurred()) Expect(s).ToNot(BeNil()) - Expect(s.StreamID()).To(BeEquivalentTo(1)) + Expect(s.StreamID()).To(BeEquivalentTo(3)) Expect(m.numOutgoingStreams).To(BeZero()) Expect(m.numIncomingStreams).To(BeEquivalentTo(1)) }) @@ -685,18 +686,16 @@ var _ = Describe("Streams Map", func() { }) }) - Context("Prioritizing crypto- and header streams", func() { + Context("Prioritizing the header stream", func() { BeforeEach(func() { - err := m.putStream(&stream{streamID: 1}) - Expect(err).NotTo(HaveOccurred()) - err = m.putStream(&stream{streamID: 3}) + err := m.putStream(&stream{streamID: 3}) Expect(err).NotTo(HaveOccurred()) }) It("gets crypto- and header stream first, then picks up at the round-robin position", func() { m.roundRobinIndex = 3 // stream 7 fn := func(str streamI) (bool, error) { - if numIterations >= 3 { + if numIterations >= 2 { return false, nil } lambdaCalledForStream = append(lambdaCalledForStream, str.StreamID()) @@ -705,8 +704,8 @@ var _ = Describe("Streams Map", func() { } err := m.RoundRobinIterate(fn) Expect(err).ToNot(HaveOccurred()) - Expect(numIterations).To(Equal(3)) - Expect(lambdaCalledForStream).To(Equal([]protocol.StreamID{1, 3, 7})) + Expect(numIterations).To(Equal(2)) + Expect(lambdaCalledForStream).To(Equal([]protocol.StreamID{3, 7})) }) }) })