implement basic flow control

fixes #37
This commit is contained in:
Marten Seemann 2016-05-03 12:07:01 +07:00
parent daf4e4a867
commit f240df6ea5
4 changed files with 161 additions and 16 deletions

View file

@ -23,7 +23,8 @@ type receivedPacket struct {
} }
var ( var (
errRstStreamOnInvalidStream = errors.New("RST_STREAM received for unknown stream") errRstStreamOnInvalidStream = errors.New("RST_STREAM received for unknown stream")
errWindowUpdateOnInvalidStream = errors.New("WINDOW_UPDATE received for unknown stream")
) )
// StreamCallback gets a stream frame and returns a reply frame // StreamCallback gets a stream frame and returns a reply frame
@ -172,6 +173,7 @@ func (s *Session) handlePacket(remoteAddr interface{}, publicHeader *PublicHeade
fmt.Printf("\t<- %#v\n", frame) fmt.Printf("\t<- %#v\n", frame)
case *frames.WindowUpdateFrame: case *frames.WindowUpdateFrame:
fmt.Printf("\t<- %#v\n", frame) fmt.Printf("\t<- %#v\n", frame)
err = s.handleWindowUpdateFrame(frame)
case *frames.BlockedFrame: case *frames.BlockedFrame:
fmt.Printf("BLOCKED frame received for connection %x stream %d\n", s.connectionID, frame.StreamID) fmt.Printf("BLOCKED frame received for connection %x stream %d\n", s.connectionID, frame.StreamID)
default: default:
@ -215,6 +217,25 @@ func (s *Session) handleStreamFrame(frame *frames.StreamFrame) error {
return nil return nil
} }
func (s *Session) handleWindowUpdateFrame(frame *frames.WindowUpdateFrame) error {
if frame.StreamID == 0 {
// TODO: handle connection level WindowUpdateFrames
// return errors.New("Connection level flow control not yet implemented")
return nil
}
s.streamsMutex.RLock()
stream, ok := s.streams[frame.StreamID]
s.streamsMutex.RUnlock()
if !ok {
return errWindowUpdateOnInvalidStream
}
stream.UpdateFlowControlWindow(frame.ByteOffset)
return nil
}
// TODO: Handle frame.byteOffset // TODO: Handle frame.byteOffset
func (s *Session) handleRstStreamFrame(frame *frames.RstStreamFrame) error { func (s *Session) handleRstStreamFrame(frame *frames.RstStreamFrame) error {
s.streamsMutex.RLock() s.streamsMutex.RLock()

View file

@ -172,7 +172,27 @@ var _ = Describe("Session", func() {
StreamID: 5, StreamID: 5,
ErrorCode: 42, ErrorCode: 42,
}) })
Expect(err).To(MatchError("RST_STREAM received for unknown stream")) Expect(err).To(MatchError(errRstStreamOnInvalidStream))
})
})
Context("handling WINDOW_UPDATE frames", func() {
It("updates the Flow Control Windows of a stream", func() {
_, err := session.NewStream(5)
Expect(err).ToNot(HaveOccurred())
err = session.handleWindowUpdateFrame(&frames.WindowUpdateFrame{
StreamID: 5,
ByteOffset: 0x8000,
})
Expect(err).ToNot(HaveOccurred())
})
It("errors when the stream is not known", func() {
err := session.handleWindowUpdateFrame(&frames.WindowUpdateFrame{
StreamID: 5,
ByteOffset: 1337,
})
Expect(err).To(MatchError(errWindowUpdateOnInvalidStream))
}) })
}) })

View file

@ -2,6 +2,8 @@ package quic
import ( import (
"io" "io"
"sync"
"sync/atomic"
"github.com/lucas-clemente/quic-go/frames" "github.com/lucas-clemente/quic-go/frames"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
@ -26,14 +28,19 @@ type stream struct {
frameQueue []*frames.StreamFrame // TODO: replace with heap frameQueue []*frames.StreamFrame // TODO: replace with heap
remoteErr error remoteErr error
currentErr error currentErr error
flowControlWindow uint64
windowUpdateCond *sync.Cond
} }
// newStream creates a new Stream // newStream creates a new Stream
func newStream(session streamHandler, StreamID protocol.StreamID) *stream { func newStream(session streamHandler, StreamID protocol.StreamID) *stream {
return &stream{ return &stream{
session: session, session: session,
streamID: StreamID, streamID: StreamID,
streamFrames: make(chan *frames.StreamFrame, 8), // ToDo: add config option for this number streamFrames: make(chan *frames.StreamFrame, 8), // ToDo: add config option for this number
flowControlWindow: 0x4000, // 16 byte, TODO: read this from the negotiated connection parameters (TagCFCW)
windowUpdateCond: sync.NewCond(&sync.Mutex{}),
} }
} }
@ -138,21 +145,44 @@ func (s *stream) ReadByte() (byte, error) {
return p[0], err return p[0], err
} }
func (s *stream) UpdateFlowControlWindow(n uint64) {
if n > s.flowControlWindow {
atomic.StoreUint64((*uint64)(&s.flowControlWindow), n)
s.windowUpdateCond.Broadcast()
}
}
func (s *stream) Write(p []byte) (int, error) { func (s *stream) Write(p []byte) (int, error) {
if s.remoteErr != nil { if s.remoteErr != nil {
return 0, s.remoteErr return 0, s.remoteErr
} }
data := make([]byte, len(p))
copy(data, p) dataWritten := 0
err := s.session.QueueStreamFrame(&frames.StreamFrame{
StreamID: s.streamID, for dataWritten < len(p) {
Offset: s.writeOffset, s.windowUpdateCond.L.Lock()
Data: data, remainingBytesInWindow := int64(s.flowControlWindow) - int64(s.writeOffset)
}) for ; remainingBytesInWindow == 0; remainingBytesInWindow = int64(s.flowControlWindow) - int64(s.writeOffset) {
if err != nil { s.windowUpdateCond.Wait()
return 0, err }
s.windowUpdateCond.L.Unlock()
dataLen := utils.Min(len(p), int(remainingBytesInWindow))
data := make([]byte, dataLen)
copy(data, p)
err := s.session.QueueStreamFrame(&frames.StreamFrame{
StreamID: s.streamID,
Offset: s.writeOffset,
Data: data,
})
if err != nil {
return 0, err
}
dataWritten += dataLen
s.writeOffset += uint64(dataLen)
} }
s.writeOffset += uint64(len(p))
return len(p), nil return len(p), nil
} }
@ -171,7 +201,7 @@ func (s *stream) AddStreamFrame(frame *frames.StreamFrame) error {
return nil return nil
} }
// RegisterError is called by session to indicate that an error occured and the // RegisterError is called by session to indicate that an error occurred and the
// stream should be closed. // stream should be closed.
func (s *stream) RegisterError(err error) { func (s *stream) RegisterError(err error) {
s.remoteErr = err s.remoteErr = err

View file

@ -246,6 +246,80 @@ var _ = Describe("Stream", func() {
Expect(n).To(BeZero()) Expect(n).To(BeZero())
Expect(err).To(Equal(testErr)) Expect(err).To(Equal(testErr))
}) })
Context("flow control", func() {
It("writes everything if the flow control window is big enough", func() {
str.flowControlWindow = 4
n, err := str.Write([]byte{0xDE, 0xCA, 0xFB, 0xAD})
Expect(n).To(Equal(4))
Expect(err).ToNot(HaveOccurred())
})
It("waits for a flow control window update", func() {
var b bool
str.flowControlWindow = 1
_, err := str.Write([]byte{0x42})
Expect(err).ToNot(HaveOccurred())
go func() {
time.Sleep(2 * time.Millisecond)
b = true
str.UpdateFlowControlWindow(3)
}()
n, err := str.Write([]byte{0x13, 0x37})
Expect(b).To(BeTrue())
Expect(n).To(Equal(2))
Expect(err).ToNot(HaveOccurred())
})
It("splits writing of frames when given more data than the flow control windows size", func() {
str.flowControlWindow = 2
var b bool
go func() {
time.Sleep(time.Millisecond)
b = true
str.UpdateFlowControlWindow(4)
}()
n, err := str.Write([]byte{0xDE, 0xCA, 0xFB, 0xAD})
Expect(len(handler.frames)).To(Equal(2))
Expect(b).To(BeTrue())
Expect(n).To(Equal(4))
Expect(err).ToNot(HaveOccurred())
})
It("writes after a flow control window update", func() {
var b bool
str.flowControlWindow = 1
_, err := str.Write([]byte{0x42})
Expect(err).ToNot(HaveOccurred())
go func() {
time.Sleep(time.Millisecond)
b = true
str.UpdateFlowControlWindow(3)
}()
n, err := str.Write([]byte{0xDE, 0xAD})
Expect(b).To(BeTrue())
Expect(n).To(Equal(2))
Expect(err).ToNot(HaveOccurred())
})
})
})
Context("flow control window updating", func() {
It("updates the flow control window", func() {
str.flowControlWindow = 3
str.UpdateFlowControlWindow(4)
Expect(str.flowControlWindow).To(Equal(uint64(4)))
})
It("never shrinks the flow control window", func() {
str.flowControlWindow = 100
str.UpdateFlowControlWindow(50)
Expect(str.flowControlWindow).To(Equal(uint64(100)))
})
}) })
Context("getting next str frame", func() { Context("getting next str frame", func() {