diff --git a/session.go b/session.go index 2bef4d89..1cfc2291 100644 --- a/session.go +++ b/session.go @@ -23,7 +23,8 @@ type receivedPacket struct { } var ( - errRstStreamOnInvalidStream = errors.New("RST_STREAM received for unknown stream") + errRstStreamOnInvalidStream = errors.New("RST_STREAM received for unknown stream") + errWindowUpdateOnInvalidStream = errors.New("WINDOW_UPDATE received for unknown stream") ) // StreamCallback gets a stream frame and returns a reply frame @@ -172,6 +173,7 @@ func (s *Session) handlePacket(remoteAddr interface{}, publicHeader *PublicHeade fmt.Printf("\t<- %#v\n", frame) case *frames.WindowUpdateFrame: fmt.Printf("\t<- %#v\n", frame) + err = s.handleWindowUpdateFrame(frame) case *frames.BlockedFrame: fmt.Printf("BLOCKED frame received for connection %x stream %d\n", s.connectionID, frame.StreamID) default: @@ -215,6 +217,25 @@ func (s *Session) handleStreamFrame(frame *frames.StreamFrame) error { return nil } +func (s *Session) handleWindowUpdateFrame(frame *frames.WindowUpdateFrame) error { + if frame.StreamID == 0 { + // TODO: handle connection level WindowUpdateFrames + // return errors.New("Connection level flow control not yet implemented") + return nil + } + s.streamsMutex.RLock() + stream, ok := s.streams[frame.StreamID] + s.streamsMutex.RUnlock() + + if !ok { + return errWindowUpdateOnInvalidStream + } + + stream.UpdateFlowControlWindow(frame.ByteOffset) + + return nil +} + // TODO: Handle frame.byteOffset func (s *Session) handleRstStreamFrame(frame *frames.RstStreamFrame) error { s.streamsMutex.RLock() diff --git a/session_test.go b/session_test.go index 34e1f2db..035c918d 100644 --- a/session_test.go +++ b/session_test.go @@ -172,7 +172,27 @@ var _ = Describe("Session", func() { StreamID: 5, ErrorCode: 42, }) - Expect(err).To(MatchError("RST_STREAM received for unknown stream")) + Expect(err).To(MatchError(errRstStreamOnInvalidStream)) + }) + }) + + Context("handling WINDOW_UPDATE frames", func() { + It("updates the Flow Control Windows of a stream", func() { + _, err := session.NewStream(5) + Expect(err).ToNot(HaveOccurred()) + err = session.handleWindowUpdateFrame(&frames.WindowUpdateFrame{ + StreamID: 5, + ByteOffset: 0x8000, + }) + Expect(err).ToNot(HaveOccurred()) + }) + + It("errors when the stream is not known", func() { + err := session.handleWindowUpdateFrame(&frames.WindowUpdateFrame{ + StreamID: 5, + ByteOffset: 1337, + }) + Expect(err).To(MatchError(errWindowUpdateOnInvalidStream)) }) }) diff --git a/stream.go b/stream.go index 6b39fcde..8e17ad2e 100644 --- a/stream.go +++ b/stream.go @@ -2,6 +2,8 @@ package quic import ( "io" + "sync" + "sync/atomic" "github.com/lucas-clemente/quic-go/frames" "github.com/lucas-clemente/quic-go/protocol" @@ -26,14 +28,19 @@ type stream struct { frameQueue []*frames.StreamFrame // TODO: replace with heap remoteErr error currentErr error + + flowControlWindow uint64 + windowUpdateCond *sync.Cond } // newStream creates a new Stream func newStream(session streamHandler, StreamID protocol.StreamID) *stream { return &stream{ - session: session, - streamID: StreamID, - streamFrames: make(chan *frames.StreamFrame, 8), // ToDo: add config option for this number + session: session, + streamID: StreamID, + streamFrames: make(chan *frames.StreamFrame, 8), // ToDo: add config option for this number + flowControlWindow: 0x4000, // 16 byte, TODO: read this from the negotiated connection parameters (TagCFCW) + windowUpdateCond: sync.NewCond(&sync.Mutex{}), } } @@ -138,21 +145,44 @@ func (s *stream) ReadByte() (byte, error) { return p[0], err } +func (s *stream) UpdateFlowControlWindow(n uint64) { + if n > s.flowControlWindow { + atomic.StoreUint64((*uint64)(&s.flowControlWindow), n) + s.windowUpdateCond.Broadcast() + } +} + func (s *stream) Write(p []byte) (int, error) { if s.remoteErr != nil { return 0, s.remoteErr } - data := make([]byte, len(p)) - copy(data, p) - err := s.session.QueueStreamFrame(&frames.StreamFrame{ - StreamID: s.streamID, - Offset: s.writeOffset, - Data: data, - }) - if err != nil { - return 0, err + + dataWritten := 0 + + for dataWritten < len(p) { + s.windowUpdateCond.L.Lock() + remainingBytesInWindow := int64(s.flowControlWindow) - int64(s.writeOffset) + for ; remainingBytesInWindow == 0; remainingBytesInWindow = int64(s.flowControlWindow) - int64(s.writeOffset) { + s.windowUpdateCond.Wait() + } + s.windowUpdateCond.L.Unlock() + + dataLen := utils.Min(len(p), int(remainingBytesInWindow)) + data := make([]byte, dataLen) + copy(data, p) + err := s.session.QueueStreamFrame(&frames.StreamFrame{ + StreamID: s.streamID, + Offset: s.writeOffset, + Data: data, + }) + if err != nil { + return 0, err + } + + dataWritten += dataLen + s.writeOffset += uint64(dataLen) } - s.writeOffset += uint64(len(p)) + return len(p), nil } @@ -171,7 +201,7 @@ func (s *stream) AddStreamFrame(frame *frames.StreamFrame) error { return nil } -// RegisterError is called by session to indicate that an error occured and the +// RegisterError is called by session to indicate that an error occurred and the // stream should be closed. func (s *stream) RegisterError(err error) { s.remoteErr = err diff --git a/stream_test.go b/stream_test.go index b4564d59..914a859b 100644 --- a/stream_test.go +++ b/stream_test.go @@ -246,6 +246,80 @@ var _ = Describe("Stream", func() { Expect(n).To(BeZero()) Expect(err).To(Equal(testErr)) }) + + Context("flow control", func() { + It("writes everything if the flow control window is big enough", func() { + str.flowControlWindow = 4 + n, err := str.Write([]byte{0xDE, 0xCA, 0xFB, 0xAD}) + Expect(n).To(Equal(4)) + Expect(err).ToNot(HaveOccurred()) + }) + + It("waits for a flow control window update", func() { + var b bool + str.flowControlWindow = 1 + _, err := str.Write([]byte{0x42}) + Expect(err).ToNot(HaveOccurred()) + + go func() { + time.Sleep(2 * time.Millisecond) + b = true + str.UpdateFlowControlWindow(3) + }() + n, err := str.Write([]byte{0x13, 0x37}) + Expect(b).To(BeTrue()) + Expect(n).To(Equal(2)) + Expect(err).ToNot(HaveOccurred()) + }) + + It("splits writing of frames when given more data than the flow control windows size", func() { + str.flowControlWindow = 2 + var b bool + + go func() { + time.Sleep(time.Millisecond) + b = true + str.UpdateFlowControlWindow(4) + }() + + n, err := str.Write([]byte{0xDE, 0xCA, 0xFB, 0xAD}) + Expect(len(handler.frames)).To(Equal(2)) + Expect(b).To(BeTrue()) + Expect(n).To(Equal(4)) + Expect(err).ToNot(HaveOccurred()) + }) + + It("writes after a flow control window update", func() { + var b bool + str.flowControlWindow = 1 + _, err := str.Write([]byte{0x42}) + Expect(err).ToNot(HaveOccurred()) + + go func() { + time.Sleep(time.Millisecond) + b = true + str.UpdateFlowControlWindow(3) + }() + n, err := str.Write([]byte{0xDE, 0xAD}) + Expect(b).To(BeTrue()) + Expect(n).To(Equal(2)) + Expect(err).ToNot(HaveOccurred()) + }) + }) + }) + + Context("flow control window updating", func() { + It("updates the flow control window", func() { + str.flowControlWindow = 3 + str.UpdateFlowControlWindow(4) + Expect(str.flowControlWindow).To(Equal(uint64(4))) + }) + + It("never shrinks the flow control window", func() { + str.flowControlWindow = 100 + str.UpdateFlowControlWindow(50) + Expect(str.flowControlWindow).To(Equal(uint64(100))) + }) }) Context("getting next str frame", func() {