From c748a8dfc07858f56d8974ef9dec034eb070f44d Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Thu, 19 May 2016 19:40:21 +0700 Subject: [PATCH] create FlowController interface --- .../flow_controller.go | 12 ++- .../flow_controller_test.go | 10 +-- flowcontrol/flowcontrol_suite_test.go | 13 +++ flowcontrol/interface.go | 17 ++++ session.go | 5 +- session_test.go | 4 +- stream.go | 9 ++- stream_test.go | 80 ++++++++++++------- 8 files changed, 107 insertions(+), 43 deletions(-) rename flow_controller.go => flowcontrol/flow_controller.go (92%) rename flow_controller_test.go => flowcontrol/flow_controller_test.go (97%) create mode 100644 flowcontrol/flowcontrol_suite_test.go create mode 100644 flowcontrol/interface.go diff --git a/flow_controller.go b/flowcontrol/flow_controller.go similarity index 92% rename from flow_controller.go rename to flowcontrol/flow_controller.go index d7cd4c35..99a26f59 100644 --- a/flow_controller.go +++ b/flowcontrol/flow_controller.go @@ -1,4 +1,4 @@ -package quic +package flowcontrol import ( "sync" @@ -24,7 +24,8 @@ type flowController struct { mutex sync.RWMutex } -func newFlowController(streamID protocol.StreamID, connectionParametersManager *handshake.ConnectionParametersManager) *flowController { +// NewFlowController gets a new flow controller +func NewFlowController(streamID protocol.StreamID, connectionParametersManager *handshake.ConnectionParametersManager) FlowController { fc := flowController{ streamID: streamID, connectionParametersManager: connectionParametersManager, @@ -160,3 +161,10 @@ func (c *flowController) CheckFlowControlViolation() bool { } return false } + +func (c *flowController) GetHighestReceived() protocol.ByteCount { + c.mutex.RLock() + defer c.mutex.RUnlock() + + return c.highestReceived +} diff --git a/flow_controller_test.go b/flowcontrol/flow_controller_test.go similarity index 97% rename from flow_controller_test.go rename to flowcontrol/flow_controller_test.go index 6b2c6917..ffc0861c 100644 --- a/flow_controller_test.go +++ b/flowcontrol/flow_controller_test.go @@ -1,4 +1,4 @@ -package quic +package flowcontrol import ( "reflect" @@ -35,24 +35,24 @@ var _ = Describe("Flow controller", func() { }) It("reads the stream send and receive windows when acting as stream-level flow controller", func() { - fc := newFlowController(5, cpm) + fc := NewFlowController(5, cpm).(*flowController) Expect(fc.streamID).To(Equal(protocol.StreamID(5))) Expect(fc.receiveFlowControlWindow).To(Equal(protocol.ByteCount(2000))) }) It("reads the stream send and receive windows when acting as stream-level flow controller", func() { - fc := newFlowController(0, cpm) + fc := NewFlowController(0, cpm).(*flowController) Expect(fc.streamID).To(Equal(protocol.StreamID(0))) Expect(fc.receiveFlowControlWindow).To(Equal(protocol.ByteCount(4000))) }) It("does not set the stream flow control windows for sending", func() { - fc := newFlowController(5, cpm) + fc := NewFlowController(5, cpm).(*flowController) Expect(fc.sendFlowControlWindow).To(BeZero()) }) It("does not set the connection flow control windows for sending", func() { - fc := newFlowController(0, cpm) + fc := NewFlowController(0, cpm).(*flowController) Expect(fc.sendFlowControlWindow).To(BeZero()) }) }) diff --git a/flowcontrol/flowcontrol_suite_test.go b/flowcontrol/flowcontrol_suite_test.go new file mode 100644 index 00000000..17920ab3 --- /dev/null +++ b/flowcontrol/flowcontrol_suite_test.go @@ -0,0 +1,13 @@ +package flowcontrol + +import ( + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" + + "testing" +) + +func TestCrypto(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "FlowControl Suite") +} diff --git a/flowcontrol/interface.go b/flowcontrol/interface.go new file mode 100644 index 00000000..f7cccb54 --- /dev/null +++ b/flowcontrol/interface.go @@ -0,0 +1,17 @@ +package flowcontrol + +import "github.com/lucas-clemente/quic-go/protocol" + +// A FlowController handles the flow control +type FlowController interface { + AddBytesSent(n protocol.ByteCount) + UpdateSendWindow(newOffset protocol.ByteCount) bool + SendWindowSize() protocol.ByteCount + UpdateHighestReceived(byteOffset protocol.ByteCount) protocol.ByteCount + IncrementHighestReceived(increment protocol.ByteCount) + AddBytesRead(n protocol.ByteCount) + MaybeTriggerBlocked() bool + MaybeTriggerWindowUpdate() (bool, protocol.ByteCount) + CheckFlowControlViolation() bool + GetHighestReceived() protocol.ByteCount +} diff --git a/session.go b/session.go index 0e0d6122..beb67072 100644 --- a/session.go +++ b/session.go @@ -9,6 +9,7 @@ import ( "time" "github.com/lucas-clemente/quic-go/ackhandler" + "github.com/lucas-clemente/quic-go/flowcontrol" "github.com/lucas-clemente/quic-go/frames" "github.com/lucas-clemente/quic-go/handshake" "github.com/lucas-clemente/quic-go/protocol" @@ -51,7 +52,7 @@ type Session struct { stopWaitingManager ackhandler.StopWaitingManager windowUpdateManager *windowUpdateManager - flowController *flowController // connection level flow controller + flowController flowcontrol.FlowController // connection level flow controller unpacker *packetUnpacker packer *packetPacker @@ -93,7 +94,7 @@ func newSession(conn connection, v protocol.VersionNumber, connectionID protocol sentPacketHandler: ackhandler.NewSentPacketHandler(stopWaitingManager), receivedPacketHandler: ackhandler.NewReceivedPacketHandler(), stopWaitingManager: stopWaitingManager, - flowController: newFlowController(0, connectionParametersManager), + flowController: flowcontrol.NewFlowController(0, connectionParametersManager), windowUpdateManager: newWindowUpdateManager(), receivedPackets: make(chan receivedPacket, protocol.MaxSessionUnprocessedPackets), closeChan: make(chan struct{}, 1), diff --git a/session_test.go b/session_test.go index 7623bb2f..bcc30454 100644 --- a/session_test.go +++ b/session_test.go @@ -254,7 +254,7 @@ var _ = Describe("Session", func() { ByteOffset: 0x8000, }) Expect(err).ToNot(HaveOccurred()) - Expect(session.streams[5].flowController.sendFlowControlWindow).To(Equal(protocol.ByteCount(0x8000))) + Expect(session.streams[5].flowController.SendWindowSize()).To(Equal(protocol.ByteCount(0x8000))) }) It("updates the Flow Control Windows of the connection", func() { @@ -263,7 +263,7 @@ var _ = Describe("Session", func() { ByteOffset: 0x800000, }) Expect(err).ToNot(HaveOccurred()) - Expect(session.flowController.sendFlowControlWindow).To(Equal(protocol.ByteCount(0x800000))) + Expect(session.flowController.SendWindowSize()).To(Equal(protocol.ByteCount(0x800000))) }) It("errors when the stream is not known", func() { diff --git a/stream.go b/stream.go index 200f340f..c76c332c 100644 --- a/stream.go +++ b/stream.go @@ -5,6 +5,7 @@ import ( "sync" "sync/atomic" + "github.com/lucas-clemente/quic-go/flowcontrol" "github.com/lucas-clemente/quic-go/frames" "github.com/lucas-clemente/quic-go/handshake" "github.com/lucas-clemente/quic-go/protocol" @@ -44,21 +45,21 @@ type stream struct { frameQueue streamFrameSorter newFrameOrErrCond sync.Cond - flowController *flowController - connectionFlowController *flowController + flowController flowcontrol.FlowController + connectionFlowController flowcontrol.FlowController contributesToConnectionFlowControl bool windowUpdateOrErrCond sync.Cond } // newStream creates a new Stream -func newStream(session streamHandler, connectionParameterManager *handshake.ConnectionParametersManager, connectionFlowController *flowController, StreamID protocol.StreamID) (*stream, error) { +func newStream(session streamHandler, connectionParameterManager *handshake.ConnectionParametersManager, connectionFlowController flowcontrol.FlowController, StreamID protocol.StreamID) (*stream, error) { s := &stream{ session: session, streamID: StreamID, connectionFlowController: connectionFlowController, contributesToConnectionFlowControl: true, - flowController: newFlowController(StreamID, connectionParameterManager), + flowController: flowcontrol.NewFlowController(StreamID, connectionParameterManager), } // crypto and header stream don't contribute to connection level flow control diff --git a/stream_test.go b/stream_test.go index 6cbcf713..a88449f5 100644 --- a/stream_test.go +++ b/stream_test.go @@ -4,8 +4,11 @@ import ( "bytes" "errors" "io" + "reflect" "time" + "unsafe" + "github.com/lucas-clemente/quic-go/flowcontrol" "github.com/lucas-clemente/quic-go/frames" "github.com/lucas-clemente/quic-go/handshake" "github.com/lucas-clemente/quic-go/protocol" @@ -46,13 +49,11 @@ var _ = Describe("Stream", func() { ) BeforeEach(func() { + var streamID protocol.StreamID = 1337 handler = &mockStreamHandler{} cpm := handshake.NewConnectionParamatersManager() - flowController := flowController{ - sendFlowControlWindow: 0xFFFFFF, - receiveFlowControlWindow: 0xFFFFFF, - } - str, _ = newStream(handler, cpm, &flowController, 1337) + flowController := flowcontrol.NewFlowController(streamID, cpm) + str, _ = newStream(handler, cpm, flowController, streamID) }) It("gets stream id", func() { @@ -250,31 +251,33 @@ var _ = Describe("Stream", func() { } err := str.AddStreamFrame(&frame) Expect(err).ToNot(HaveOccurred()) - Expect(str.flowController.highestReceived).To(Equal(protocol.ByteCount(8))) + Expect(str.flowController.GetHighestReceived()).To(Equal(protocol.ByteCount(8))) }) It("updates the connection level flow controller", func() { str.contributesToConnectionFlowControl = true - str.connectionFlowController.highestReceived = 10 + newVal := str.connectionFlowController.UpdateHighestReceived(10) + Expect(newVal).To(Equal(protocol.ByteCount(10))) frame := frames.StreamFrame{ Offset: 2, Data: []byte("foobar"), } err := str.AddStreamFrame(&frame) Expect(err).ToNot(HaveOccurred()) - Expect(str.connectionFlowController.highestReceived).To(Equal(protocol.ByteCount(10 + 8))) + Expect(str.connectionFlowController.GetHighestReceived()).To(Equal(protocol.ByteCount(10 + 8))) }) It("doesn't update the connection level flow controller if the stream doesn't contribute", func() { str.contributesToConnectionFlowControl = false - str.connectionFlowController.highestReceived = 10 + newVal := str.connectionFlowController.UpdateHighestReceived(10) + Expect(newVal).To(Equal(protocol.ByteCount(10))) frame := frames.StreamFrame{ Offset: 2, Data: []byte("foobar"), } err := str.AddStreamFrame(&frame) Expect(err).ToNot(HaveOccurred()) - Expect(str.connectionFlowController.highestReceived).To(Equal(protocol.ByteCount(10))) + Expect(str.connectionFlowController.GetHighestReceived()).To(Equal(protocol.ByteCount(10))) }) }) }) @@ -331,16 +334,19 @@ var _ = Describe("Stream", func() { Context("flow control", func() { It("writes everything if the flow control window is big enough", func() { - str.flowController.sendFlowControlWindow = 4 + updated := str.flowController.UpdateSendWindow(4) + Expect(updated).To(BeTrue()) n, err := str.Write([]byte{0xDE, 0xCA, 0xFB, 0xAD}) Expect(n).To(Equal(4)) Expect(err).ToNot(HaveOccurred()) }) It("doesn't care about the connection flow control window if it is not contributing", func() { - str.flowController.sendFlowControlWindow = 4 + updated := str.flowController.UpdateSendWindow(4) + Expect(updated).To(BeTrue()) str.contributesToConnectionFlowControl = false - str.connectionFlowController.sendFlowControlWindow = 1 + updated = str.connectionFlowController.UpdateSendWindow(1) + Expect(updated).To(BeTrue()) n, err := str.Write([]byte{0xDE, 0xCA, 0xFB, 0xAD}) Expect(n).To(Equal(4)) Expect(err).ToNot(HaveOccurred()) @@ -348,7 +354,8 @@ var _ = Describe("Stream", func() { It("waits for a stream flow control window update", func() { var b bool - str.flowController.sendFlowControlWindow = 1 + updated := str.flowController.UpdateSendWindow(1) + Expect(updated).To(BeTrue()) _, err := str.Write([]byte{0x42}) Expect(err).ToNot(HaveOccurred()) @@ -365,27 +372,33 @@ var _ = Describe("Stream", func() { It("waits for a connection flow control window update", func() { var b bool - str.flowController.sendFlowControlWindow = 1000 - str.connectionFlowController.sendFlowControlWindow = 1 + updated := str.flowController.UpdateSendWindow(1000) + Expect(updated).To(BeTrue()) + updated = str.connectionFlowController.UpdateSendWindow(1) + Expect(updated).To(BeTrue()) str.contributesToConnectionFlowControl = true _, err := str.Write([]byte{0x42}) Expect(err).ToNot(HaveOccurred()) + var sendWindowUpdated bool go func() { time.Sleep(2 * time.Millisecond) b = true - str.connectionFlowController.UpdateSendWindow(3) + sendWindowUpdated = str.connectionFlowController.UpdateSendWindow(3) str.ConnectionFlowControlWindowUpdated() }() + n, err := str.Write([]byte{0x13, 0x37}) Expect(b).To(BeTrue()) + Expect(sendWindowUpdated).To(BeTrue()) Expect(n).To(Equal(2)) Expect(err).ToNot(HaveOccurred()) }) It("splits writing of frames when given more data than the flow control windows size", func() { - str.flowController.sendFlowControlWindow = 2 + updated := str.flowController.UpdateSendWindow(2) + Expect(updated).To(BeTrue()) var b bool go func() { @@ -403,7 +416,9 @@ var _ = Describe("Stream", func() { It("writes after a flow control window update", func() { var b bool - str.flowController.sendFlowControlWindow = 1 + updated := str.flowController.UpdateSendWindow(1) + Expect(updated).To(BeTrue()) + _, err := str.Write([]byte{0x42}) Expect(err).ToNot(HaveOccurred()) @@ -420,7 +435,8 @@ var _ = Describe("Stream", func() { It("immediately returns on remote errors", func() { var b bool - str.flowController.sendFlowControlWindow = 1 + updated := str.flowController.UpdateSendWindow(1) + Expect(updated).To(BeTrue()) testErr := errors.New("test error") @@ -439,22 +455,25 @@ var _ = Describe("Stream", func() { Context("Blocked streams", func() { It("notifies the session when a stream is flow control blocked", func() { - str.flowController.sendFlowControlWindow = 1337 - str.flowController.bytesSent = 1337 + updated := str.flowController.UpdateSendWindow(1337) + Expect(updated).To(BeTrue()) + str.flowController.AddBytesSent(1337) str.maybeTriggerBlocked() Expect(handler.receivedBlockedCalled).To(BeTrue()) Expect(handler.receivedBlockedForStream).To(Equal(str.streamID)) }) It("notifies the session as soon as a stream is reaching the end of the window", func() { - str.flowController.sendFlowControlWindow = 4 + updated := str.flowController.UpdateSendWindow(4) + Expect(updated).To(BeTrue()) str.Write([]byte{0xDE, 0xCA, 0xFB, 0xAD}) Expect(handler.receivedBlockedCalled).To(BeTrue()) Expect(handler.receivedBlockedForStream).To(Equal(str.streamID)) }) It("notifies the session as soon as a stream is flow control blocked", func() { - str.flowController.sendFlowControlWindow = 2 + updated := str.flowController.UpdateSendWindow(2) + Expect(updated).To(BeTrue()) go func() { str.Write([]byte{0xDE, 0xCA, 0xFB, 0xAD}) }() @@ -468,8 +487,9 @@ var _ = Describe("Stream", func() { var receiveFlowControlWindow protocol.ByteCount = 1000 var receiveFlowControlWindowIncrement protocol.ByteCount = 1000 BeforeEach(func() { - str.flowController.receiveFlowControlWindow = receiveFlowControlWindow - str.flowController.receiveFlowControlWindowIncrement = receiveFlowControlWindowIncrement + // set receiveFlowControlWindow and receiveFlowControlWindowIncrement in the stream-level flow controller + *(*protocol.ByteCount)(unsafe.Pointer(reflect.ValueOf(str.flowController).Elem().FieldByName("receiveFlowControlWindow").UnsafeAddr())) = receiveFlowControlWindow + *(*protocol.ByteCount)(unsafe.Pointer(reflect.ValueOf(str.flowController).Elem().FieldByName("receiveFlowControlWindowIncrement").UnsafeAddr())) = receiveFlowControlWindowIncrement }) It("updates the flow control window", func() { @@ -489,8 +509,12 @@ var _ = Describe("Stream", func() { }) It("updates the connection level flow control window", func() { - str.connectionFlowController.receiveFlowControlWindow = 100 - str.connectionFlowController.receiveFlowControlWindowIncrement = 100 + var connectionReceiveFlowControlWindow protocol.ByteCount = 100 + var connectionReceiveFlowControlWindowIncrement protocol.ByteCount = 100 + // set receiveFlowControlWindow and receiveFlowControlWindowIncrement in the connection-level flow controller + *(*protocol.ByteCount)(unsafe.Pointer(reflect.ValueOf(str.connectionFlowController).Elem().FieldByName("receiveFlowControlWindow").UnsafeAddr())) = connectionReceiveFlowControlWindow + *(*protocol.ByteCount)(unsafe.Pointer(reflect.ValueOf(str.connectionFlowController).Elem().FieldByName("receiveFlowControlWindowIncrement").UnsafeAddr())) = connectionReceiveFlowControlWindowIncrement + len := 100/2 + 1 frame := frames.StreamFrame{ Offset: 0,