handle the crypto stream separately

This commit is contained in:
Marten Seemann 2017-10-21 11:43:57 +07:00
parent 9825ddb43a
commit 5ee7b205c6
7 changed files with 66 additions and 74 deletions

View file

@ -61,11 +61,8 @@ var _ = Describe("Packet packer", func() {
BeforeEach(func() { BeforeEach(func() {
cryptoStream = &stream{flowController: flowcontrol.NewStreamFlowController(1, false, flowcontrol.NewConnectionFlowController(1000, 1000, nil), 1000, 1000, 1000, nil)} cryptoStream = &stream{flowController: flowcontrol.NewStreamFlowController(1, false, flowcontrol.NewConnectionFlowController(1000, 1000, nil), 1000, 1000, 1000, nil)}
streamsMap := newStreamsMap(nil, protocol.PerspectiveServer) streamsMap := newStreamsMap(nil, protocol.PerspectiveServer)
streamsMap.streams[1] = cryptoStream streamFramer = newStreamFramer(cryptoStream, streamsMap, nil)
streamsMap.openStreams = []protocol.StreamID{1}
streamFramer = newStreamFramer(streamsMap, nil)
packer = &packetPacker{ packer = &packetPacker{
cryptoSetup: &mockCryptoSetup{encLevelSeal: protocol.EncryptionForwardSecure}, cryptoSetup: &mockCryptoSetup{encLevelSeal: protocol.EncryptionForwardSecure},

View file

@ -5,7 +5,6 @@ import (
"crypto/tls" "crypto/tls"
"errors" "errors"
"fmt" "fmt"
"io"
"net" "net"
"sync" "sync"
"time" "time"
@ -60,7 +59,8 @@ type session struct {
conn connection conn connection
streamsMap *streamsMap streamsMap *streamsMap
cryptoStream streamI
rttStats *congestion.RTTStats rttStats *congestion.RTTStats
@ -195,21 +195,14 @@ func (s *session) setup(
} }
s.sentPacketHandler = ackhandler.NewSentPacketHandler(s.rttStats) s.sentPacketHandler = ackhandler.NewSentPacketHandler(s.rttStats)
s.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(s.version) s.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(s.version)
s.connFlowController = flowcontrol.NewConnectionFlowController( s.connFlowController = flowcontrol.NewConnectionFlowController(
protocol.ReceiveConnectionFlowControlWindow, protocol.ReceiveConnectionFlowControlWindow,
protocol.ByteCount(s.config.MaxReceiveConnectionFlowControlWindow), protocol.ByteCount(s.config.MaxReceiveConnectionFlowControlWindow),
s.rttStats, s.rttStats,
) )
s.streamsMap = newStreamsMap(s.newStream, s.perspective) s.streamsMap = newStreamsMap(s.newStream, s.perspective)
var cryptoStream io.ReadWriter s.cryptoStream = s.newStream(1)
// open the crypto stream s.streamFramer = newStreamFramer(s.cryptoStream, s.streamsMap, s.connFlowController)
if s.perspective == protocol.PerspectiveServer {
cryptoStream, _ = s.GetOrOpenStream(1)
_, _ = s.AcceptStream() // don't expose the crypto stream
} else {
cryptoStream, _ = s.OpenStream()
}
var err error var err error
if s.perspective == protocol.PerspectiveServer { if s.perspective == protocol.PerspectiveServer {
@ -218,7 +211,7 @@ func (s *session) setup(
} }
if s.version.UsesTLS() { if s.version.UsesTLS() {
s.cryptoSetup, err = handshake.NewCryptoSetupTLSServer( s.cryptoSetup, err = handshake.NewCryptoSetupTLSServer(
cryptoStream, s.cryptoStream,
tlsConf, tlsConf,
transportParams, transportParams,
paramsChan, paramsChan,
@ -228,7 +221,7 @@ func (s *session) setup(
) )
} else { } else {
s.cryptoSetup, err = newCryptoSetup( s.cryptoSetup, err = newCryptoSetup(
cryptoStream, s.cryptoStream,
s.connectionID, s.connectionID,
s.conn.RemoteAddr(), s.conn.RemoteAddr(),
s.version, s.version,
@ -244,7 +237,7 @@ func (s *session) setup(
transportParams.OmitConnectionID = s.config.RequestConnectionIDOmission transportParams.OmitConnectionID = s.config.RequestConnectionIDOmission
if s.version.UsesTLS() { if s.version.UsesTLS() {
s.cryptoSetup, err = handshake.NewCryptoSetupTLSClient( s.cryptoSetup, err = handshake.NewCryptoSetupTLSClient(
cryptoStream, s.cryptoStream,
hostname, hostname,
tlsConf, tlsConf,
transportParams, transportParams,
@ -256,7 +249,7 @@ func (s *session) setup(
) )
} else { } else {
s.cryptoSetup, err = newCryptoSetupClient( s.cryptoSetup, err = newCryptoSetupClient(
cryptoStream, s.cryptoStream,
hostname, hostname,
s.connectionID, s.connectionID,
s.version, s.version,
@ -272,7 +265,6 @@ func (s *session) setup(
return nil, nil, err return nil, nil, err
} }
s.streamFramer = newStreamFramer(s.streamsMap, s.connFlowController)
s.packer = newPacketPacker(s.connectionID, s.packer = newPacketPacker(s.connectionID,
s.cryptoSetup, s.cryptoSetup,
s.streamFramer, s.streamFramer,
@ -529,6 +521,9 @@ func (s *session) handlePacket(p *receivedPacket) {
} }
func (s *session) handleStreamFrame(frame *wire.StreamFrame) error { func (s *session) handleStreamFrame(frame *wire.StreamFrame) error {
if frame.StreamID == 1 {
return s.cryptoStream.AddStreamFrame(frame)
}
str, err := s.streamsMap.GetOrOpenStream(frame.StreamID) str, err := s.streamsMap.GetOrOpenStream(frame.StreamID)
if err != nil { if err != nil {
return err return err
@ -610,6 +605,7 @@ func (s *session) handleCloseError(closeErr closeError) error {
utils.Errorf("Closing session with error: %s", closeErr.err.Error()) utils.Errorf("Closing session with error: %s", closeErr.err.Error())
} }
s.cryptoStream.Cancel(quicErr)
s.streamsMap.CloseWithError(quicErr) s.streamsMap.CloseWithError(quicErr)
if closeErr.err == errCloseSessionForNewVersion { if closeErr.err == errCloseSessionForNewVersion {

View file

@ -189,7 +189,7 @@ var _ = Describe("Session", func() {
) )
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
sess = pSess.(*session) sess = pSess.(*session)
Expect(sess.streamsMap.openStreams).To(HaveLen(1)) // 1 stream: the crypto stream Expect(sess.streamsMap.openStreams).To(BeEmpty())
}) })
AfterEach(func() { AfterEach(func() {
@ -472,6 +472,9 @@ var _ = Describe("Session", func() {
}) })
It("handles CONNECTION_CLOSE frames", func() { It("handles CONNECTION_CLOSE frames", func() {
cryptoStream := mocks.NewMockStreamI(mockCtrl)
cryptoStream.EXPECT().Cancel(gomock.Any())
sess.cryptoStream = cryptoStream
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
@ -482,9 +485,6 @@ var _ = Describe("Session", func() {
_, err := sess.GetOrOpenStream(5) _, err := sess.GetOrOpenStream(5)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
sess.streamsMap.Range(func(s streamI) { sess.streamsMap.Range(func(s streamI) {
if s.StreamID() == 1 { // the crypto stream is created by the session setup and is not a mock stream
return
}
s.(*mocks.MockStreamI).EXPECT().Cancel(gomock.Any()) s.(*mocks.MockStreamI).EXPECT().Cancel(gomock.Any())
}) })
err = sess.handleFrames([]wire.Frame{&wire.ConnectionCloseFrame{ErrorCode: qerr.ProofInvalid, ReasonPhrase: "foobar"}}) err = sess.handleFrames([]wire.Frame{&wire.ConnectionCloseFrame{ErrorCode: qerr.ProofInvalid, ReasonPhrase: "foobar"}})
@ -1541,7 +1541,7 @@ var _ = Describe("Client Session", func() {
) )
sess = sessP.(*session) sess = sessP.(*session)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(sess.streamsMap.openStreams).To(HaveLen(1)) Expect(sess.streamsMap.openStreams).To(BeEmpty())
}) })
AfterEach(func() { AfterEach(func() {

View file

@ -7,7 +7,8 @@ import (
) )
type streamFramer struct { type streamFramer struct {
streamsMap *streamsMap streamsMap *streamsMap
cryptoStream streamI
connFlowController flowcontrol.ConnectionFlowController connFlowController flowcontrol.ConnectionFlowController
@ -15,9 +16,14 @@ type streamFramer struct {
blockedFrameQueue []*wire.BlockedFrame blockedFrameQueue []*wire.BlockedFrame
} }
func newStreamFramer(streamsMap *streamsMap, cfc flowcontrol.ConnectionFlowController) *streamFramer { func newStreamFramer(
cryptoStream streamI,
streamsMap *streamsMap,
cfc flowcontrol.ConnectionFlowController,
) *streamFramer {
return &streamFramer{ return &streamFramer{
streamsMap: streamsMap, streamsMap: streamsMap,
cryptoStream: cryptoStream,
connFlowController: cfc, connFlowController: cfc,
} }
} }
@ -45,8 +51,7 @@ func (f *streamFramer) HasFramesForRetransmission() bool {
} }
func (f *streamFramer) HasCryptoStreamFrame() bool { func (f *streamFramer) HasCryptoStreamFrame() bool {
cs, _ := f.streamsMap.GetOrOpenStream(1) return f.cryptoStream.LenOfDataForWriting() > 0
return cs.LenOfDataForWriting() > 0
} }
// TODO(lclemente): This is somewhat duplicate with the normal path for generating frames. // TODO(lclemente): This is somewhat duplicate with the normal path for generating frames.
@ -54,13 +59,12 @@ func (f *streamFramer) PopCryptoStreamFrame(maxLen protocol.ByteCount) *wire.Str
if !f.HasCryptoStreamFrame() { if !f.HasCryptoStreamFrame() {
return nil return nil
} }
cs, _ := f.streamsMap.GetOrOpenStream(1)
frame := &wire.StreamFrame{ frame := &wire.StreamFrame{
StreamID: 1, StreamID: 1,
Offset: cs.GetWriteOffset(), Offset: f.cryptoStream.GetWriteOffset(),
} }
frameHeaderBytes, _ := frame.MinLength(protocol.VersionWhatever) // can never error frameHeaderBytes, _ := frame.MinLength(protocol.VersionWhatever) // can never error
frame.Data = cs.GetDataForWriting(maxLen - frameHeaderBytes) frame.Data = f.cryptoStream.GetDataForWriting(maxLen - frameHeaderBytes)
return frame return frame
} }
@ -95,7 +99,7 @@ func (f *streamFramer) maybePopNormalFrames(maxBytes protocol.ByteCount) (res []
var currentLen protocol.ByteCount var currentLen protocol.ByteCount
fn := func(s streamI) (bool, error) { fn := func(s streamI) (bool, error) {
if s == nil || s.StreamID() == 1 /* crypto stream is handled separately */ { if s == nil {
return true, nil return true, nil
} }
@ -146,7 +150,6 @@ func (f *streamFramer) maybePopNormalFrames(maxBytes protocol.ByteCount) (res []
} }
f.streamsMap.RoundRobinIterate(fn) f.streamsMap.RoundRobinIterate(fn)
return return
} }

View file

@ -46,7 +46,7 @@ var _ = Describe("Stream Framer", func() {
streamsMap.putStream(stream2) streamsMap.putStream(stream2)
connFC = mocks.NewMockConnectionFlowController(mockCtrl) connFC = mocks.NewMockConnectionFlowController(mockCtrl)
framer = newStreamFramer(streamsMap, connFC) framer = newStreamFramer(nil, streamsMap, connFC)
}) })
setNoData := func(str *mocks.MockStreamI) { setNoData := func(str *mocks.MockStreamI) {

View file

@ -59,13 +59,11 @@ func newStreamsMap(newStream newStreamLambda, pers protocol.Perspective) *stream
sm.openStreamOrErrCond.L = &sm.mutex sm.openStreamOrErrCond.L = &sm.mutex
if pers == protocol.PerspectiveClient { if pers == protocol.PerspectiveClient {
sm.nextStream = 1 sm.nextStream = 3
sm.nextStreamToAccept = 2 sm.nextStreamToAccept = 2
// TODO: find a better solution for opening the crypto stream
sm.maxOutgoingStreams = 1 // allow the crypto stream
} else { } else {
sm.nextStream = 2 sm.nextStream = 2
sm.nextStreamToAccept = 1 sm.nextStreamToAccept = 3
} }
return &sm return &sm
@ -271,7 +269,7 @@ func (m *streamsMap) DeleteClosedStreams() error {
// RoundRobinIterate executes the streamLambda for every open stream, until the streamLambda returns false // RoundRobinIterate executes the streamLambda for every open stream, until the streamLambda returns false
// It uses a round-robin-like scheduling to ensure that every stream is considered fairly // It uses a round-robin-like scheduling to ensure that every stream is considered fairly
// It prioritizes the crypto- and the header-stream (StreamIDs 1 and 3) // It prioritizes the the header-stream (StreamID 3)
func (m *streamsMap) RoundRobinIterate(fn streamLambda) error { func (m *streamsMap) RoundRobinIterate(fn streamLambda) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
@ -279,19 +277,18 @@ func (m *streamsMap) RoundRobinIterate(fn streamLambda) error {
numStreams := len(m.streams) numStreams := len(m.streams)
startIndex := m.roundRobinIndex startIndex := m.roundRobinIndex
for _, i := range []protocol.StreamID{1, 3} { // prioritize the header stream
cont, err := m.iterateFunc(i, fn) cont, err := m.iterateFunc(3, fn)
if err != nil && err != errMapAccess { if err != nil && err != errMapAccess {
return err return err
} }
if !cont { if !cont {
return nil return nil
}
} }
for i := 0; i < numStreams; i++ { for i := 0; i < numStreams; i++ {
streamID := m.openStreams[(i+startIndex)%numStreams] streamID := m.openStreams[(i+startIndex)%numStreams]
if streamID == 1 || streamID == 3 { if streamID == 3 {
continue continue
} }

View file

@ -280,7 +280,7 @@ var _ = Describe("Streams Map", func() {
Consistently(func() bool { return accepted }).Should(BeFalse()) Consistently(func() bool { return accepted }).Should(BeFalse())
}) })
It("accepts stream 1 first", func() { It("start with stream 3", func() {
var str streamI var str streamI
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
@ -288,10 +288,10 @@ var _ = Describe("Streams Map", func() {
str, err = m.AcceptStream() str, err = m.AcceptStream()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
}() }()
_, err := m.GetOrOpenStream(1) _, err := m.GetOrOpenStream(3)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Eventually(func() Stream { return str }).ShouldNot(BeNil()) Eventually(func() Stream { return str }).ShouldNot(BeNil())
Expect(str.StreamID()).To(Equal(protocol.StreamID(1))) Expect(str.StreamID()).To(Equal(protocol.StreamID(3)))
}) })
It("returns an implicitly opened stream, if a stream number is skipped", func() { It("returns an implicitly opened stream, if a stream number is skipped", func() {
@ -305,7 +305,7 @@ var _ = Describe("Streams Map", func() {
_, err := m.GetOrOpenStream(5) _, err := m.GetOrOpenStream(5)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Eventually(func() Stream { return str }).ShouldNot(BeNil()) Eventually(func() Stream { return str }).ShouldNot(BeNil())
Expect(str.StreamID()).To(Equal(protocol.StreamID(1))) Expect(str.StreamID()).To(Equal(protocol.StreamID(3)))
}) })
It("returns to multiple accepts", func() { It("returns to multiple accepts", func() {
@ -322,12 +322,12 @@ var _ = Describe("Streams Map", func() {
str2, err = m.AcceptStream() str2, err = m.AcceptStream()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
}() }()
_, err := m.GetOrOpenStream(3) // opens stream 1 and 3 _, err := m.GetOrOpenStream(5) // opens stream 3 and 5
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Eventually(func() streamI { return str1 }).ShouldNot(BeNil()) Eventually(func() streamI { return str1 }).ShouldNot(BeNil())
Eventually(func() streamI { return str2 }).ShouldNot(BeNil()) Eventually(func() streamI { return str2 }).ShouldNot(BeNil())
Expect(str1.StreamID()).ToNot(Equal(str2.StreamID())) Expect(str1.StreamID()).ToNot(Equal(str2.StreamID()))
Expect(str1.StreamID() + str2.StreamID()).To(BeEquivalentTo(1 + 3)) Expect(str1.StreamID() + str2.StreamID()).To(BeEquivalentTo(3 + 5))
}) })
It("waits a new stream is available", func() { It("waits a new stream is available", func() {
@ -339,10 +339,10 @@ var _ = Describe("Streams Map", func() {
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
}() }()
Consistently(func() streamI { return str }).Should(BeNil()) Consistently(func() streamI { return str }).Should(BeNil())
_, err := m.GetOrOpenStream(1) _, err := m.GetOrOpenStream(3)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Eventually(func() streamI { return str }).ShouldNot(BeNil()) Eventually(func() streamI { return str }).ShouldNot(BeNil())
Expect(str.StreamID()).To(Equal(protocol.StreamID(1))) Expect(str.StreamID()).To(Equal(protocol.StreamID(3)))
}) })
It("returns multiple streams on subsequent Accept calls, if available", func() { It("returns multiple streams on subsequent Accept calls, if available", func() {
@ -353,22 +353,22 @@ var _ = Describe("Streams Map", func() {
str, err = m.AcceptStream() str, err = m.AcceptStream()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
}() }()
_, err := m.GetOrOpenStream(3) _, err := m.GetOrOpenStream(5)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Eventually(func() streamI { return str }).ShouldNot(BeNil()) Eventually(func() streamI { return str }).ShouldNot(BeNil())
Expect(str.StreamID()).To(Equal(protocol.StreamID(1))) Expect(str.StreamID()).To(Equal(protocol.StreamID(3)))
str, err = m.AcceptStream() str, err = m.AcceptStream()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(str.StreamID()).To(Equal(protocol.StreamID(3))) Expect(str.StreamID()).To(Equal(protocol.StreamID(5)))
}) })
It("blocks after accepting a stream", func() { It("blocks after accepting a stream", func() {
var accepted bool var accepted bool
_, err := m.GetOrOpenStream(1) _, err := m.GetOrOpenStream(3)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
str, err := m.AcceptStream() str, err := m.AcceptStream()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(str.StreamID()).To(Equal(protocol.StreamID(1))) Expect(str.StreamID()).To(Equal(protocol.StreamID(3)))
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
_, _ = m.AcceptStream() _, _ = m.AcceptStream()
@ -400,6 +400,7 @@ var _ = Describe("Streams Map", func() {
Context("as a client", func() { Context("as a client", func() {
BeforeEach(func() { BeforeEach(func() {
setNewStreamsMap(protocol.PerspectiveClient) setNewStreamsMap(protocol.PerspectiveClient)
m.UpdateMaxStreamLimit(100)
}) })
Context("client-side streams", func() { Context("client-side streams", func() {
@ -434,21 +435,21 @@ var _ = Describe("Streams Map", func() {
It("doesn't reopen an already closed stream", func() { It("doesn't reopen an already closed stream", func() {
str, err := m.OpenStream() str, err := m.OpenStream()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(str.StreamID()).To(Equal(protocol.StreamID(1))) Expect(str.StreamID()).To(Equal(protocol.StreamID(3)))
deleteStream(1) deleteStream(3)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
str, err = m.GetOrOpenStream(1) str, err = m.GetOrOpenStream(3)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(str).To(BeNil()) Expect(str).To(BeNil())
}) })
}) })
Context("server-side streams", func() { Context("server-side streams", func() {
It("opens stream 1 first", func() { It("starts with stream 3", func() {
s, err := m.OpenStream() s, err := m.OpenStream()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(s).ToNot(BeNil()) Expect(s).ToNot(BeNil())
Expect(s.StreamID()).To(BeEquivalentTo(1)) Expect(s.StreamID()).To(BeEquivalentTo(3))
Expect(m.numOutgoingStreams).To(BeZero()) Expect(m.numOutgoingStreams).To(BeZero())
Expect(m.numIncomingStreams).To(BeEquivalentTo(1)) Expect(m.numIncomingStreams).To(BeEquivalentTo(1))
}) })
@ -685,18 +686,16 @@ var _ = Describe("Streams Map", func() {
}) })
}) })
Context("Prioritizing crypto- and header streams", func() { Context("Prioritizing the header stream", func() {
BeforeEach(func() { BeforeEach(func() {
err := m.putStream(&stream{streamID: 1}) err := m.putStream(&stream{streamID: 3})
Expect(err).NotTo(HaveOccurred())
err = m.putStream(&stream{streamID: 3})
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
}) })
It("gets crypto- and header stream first, then picks up at the round-robin position", func() { It("gets crypto- and header stream first, then picks up at the round-robin position", func() {
m.roundRobinIndex = 3 // stream 7 m.roundRobinIndex = 3 // stream 7
fn := func(str streamI) (bool, error) { fn := func(str streamI) (bool, error) {
if numIterations >= 3 { if numIterations >= 2 {
return false, nil return false, nil
} }
lambdaCalledForStream = append(lambdaCalledForStream, str.StreamID()) lambdaCalledForStream = append(lambdaCalledForStream, str.StreamID())
@ -705,8 +704,8 @@ var _ = Describe("Streams Map", func() {
} }
err := m.RoundRobinIterate(fn) err := m.RoundRobinIterate(fn)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(numIterations).To(Equal(3)) Expect(numIterations).To(Equal(2))
Expect(lambdaCalledForStream).To(Equal([]protocol.StreamID{1, 3, 7})) Expect(lambdaCalledForStream).To(Equal([]protocol.StreamID{3, 7}))
}) })
}) })
}) })