refactor stream to remove a couple of race conditions

fixes #42
This commit is contained in:
Lucas Clemente 2016-05-11 22:30:14 +02:00
parent 8b1009d482
commit 060df6be7d
5 changed files with 193 additions and 208 deletions

View file

@ -484,8 +484,6 @@ func (s *Session) garbageCollectStreams() {
if v == nil { if v == nil {
continue continue
} }
// Strictly speaking, this is not thread-safe. However it doesn't matter
// if the stream is deleted just shortly later, so we don't care.
if v.finishedReading() { if v.finishedReading() {
s.streams[k] = nil s.streams[k] = nil
} }

220
stream.go
View file

@ -3,6 +3,7 @@ package quic
import ( import (
"io" "io"
"sync" "sync"
"sync/atomic"
"github.com/lucas-clemente/quic-go/frames" "github.com/lucas-clemente/quic-go/frames"
"github.com/lucas-clemente/quic-go/handshake" "github.com/lucas-clemente/quic-go/handshake"
@ -16,73 +17,104 @@ type streamHandler interface {
// A Stream assembles the data from StreamFrames and provides a super-convenient Read-Interface // A Stream assembles the data from StreamFrames and provides a super-convenient Read-Interface
type stream struct { type stream struct {
session streamHandler
streamID protocol.StreamID streamID protocol.StreamID
// The chan of unordered stream frames. A nil in this channel is sent by the session streamHandler
// session if an error occurred, in this case, remoteErr is filled before.
streamFrames chan *frames.StreamFrame readPosInFrame int
currentFrame *frames.StreamFrame
readPosInFrame protocol.ByteCount
writeOffset protocol.ByteCount writeOffset protocol.ByteCount
readOffset protocol.ByteCount readOffset protocol.ByteCount
frameQueue []*frames.StreamFrame // TODO: replace with heap
remoteErr error
currentErr error
connectionParameterManager *handshake.ConnectionParametersManager // Once set, err must not be changed!
err error
mutex sync.Mutex
flowControlWindow protocol.ByteCount eof int32 // really a bool
windowUpdateCond *sync.Cond
frameQueue streamFrameSorter
newFrameOrErrCond sync.Cond
flowControlWindow protocol.ByteCount
windowUpdateOrErrCond sync.Cond
} }
// newStream creates a new Stream // newStream creates a new Stream
func newStream(session streamHandler, connectionParameterManager *handshake.ConnectionParametersManager, StreamID protocol.StreamID) (*stream, error) { func newStream(session streamHandler, connectionParameterManager *handshake.ConnectionParametersManager, StreamID protocol.StreamID) (*stream, error) {
s := &stream{ s := &stream{
session: session, session: session,
streamID: StreamID, streamID: StreamID,
streamFrames: make(chan *frames.StreamFrame, 8), // ToDo: add config option for this number
connectionParameterManager: connectionParameterManager,
windowUpdateCond: sync.NewCond(&sync.Mutex{}),
} }
flowControlWindow, err := connectionParameterManager.GetStreamFlowControlWindow() s.newFrameOrErrCond.L = &s.mutex
s.windowUpdateOrErrCond.L = &s.mutex
var err error
s.flowControlWindow, err = connectionParameterManager.GetStreamFlowControlWindow()
if err != nil { if err != nil {
return nil, err return nil, err
} }
s.flowControlWindow = flowControlWindow
return s, nil return s, nil
} }
// Read implements io.Reader // Read implements io.Reader. It is not thread safe!
func (s *stream) Read(p []byte) (int, error) { func (s *stream) Read(p []byte) (int, error) {
if s.currentErr != nil { if atomic.LoadInt32(&s.eof) != 0 {
return 0, s.currentErr return 0, io.EOF
} }
bytesRead := 0 bytesRead := 0
for bytesRead < len(p) { for bytesRead < len(p) {
if s.currentFrame == nil { s.mutex.Lock()
var err error frame := s.frameQueue.Head()
s.currentFrame, err = s.getNextFrameInOrder(bytesRead == 0)
if err != nil { if frame == nil && bytesRead > 0 {
s.currentErr = err defer s.mutex.Unlock()
return bytesRead, err return bytesRead, s.err
}
if s.currentFrame == nil {
return bytesRead, nil
}
s.readPosInFrame = 0
} }
// TODO: don't cast to int for comparing
m := utils.Min(len(p)-bytesRead, int(protocol.ByteCount(len(s.currentFrame.Data))-s.readPosInFrame)) for {
copy(p[bytesRead:], s.currentFrame.Data[s.readPosInFrame:]) // Stop waiting on errors
s.readPosInFrame += protocol.ByteCount(m) if s.err != nil {
break
}
if frame != nil {
// Pop and continue if the frame doesn't have any new data
if frame.Offset+protocol.ByteCount(len(frame.Data)) <= s.readOffset && !frame.FinBit {
s.frameQueue.Pop()
frame = s.frameQueue.Head()
continue
}
// If the frame's offset is <= our current read pos, and we didn't
// go into the previous if, we can read data from the frame.
if frame.Offset <= s.readOffset {
// Set our read position in the frame properly
s.readPosInFrame = int(s.readOffset - frame.Offset)
break
}
}
s.newFrameOrErrCond.Wait()
frame = s.frameQueue.Head()
}
s.mutex.Unlock()
if frame == nil {
atomic.StoreInt32(&s.eof, 1)
// We have an err and no data, return the error
return bytesRead, s.err
}
m := utils.Min(len(p)-bytesRead, len(frame.Data)-s.readPosInFrame)
copy(p[bytesRead:], frame.Data[s.readPosInFrame:])
s.readPosInFrame += m
bytesRead += m bytesRead += m
s.readOffset += protocol.ByteCount(m) s.readOffset += protocol.ByteCount(m)
if s.readPosInFrame >= protocol.ByteCount(len(s.currentFrame.Data)) { if s.readPosInFrame >= len(frame.Data) {
fin := s.currentFrame.FinBit fin := frame.FinBit
s.currentFrame = nil s.mutex.Lock()
s.frameQueue.Pop()
s.mutex.Unlock()
if fin { if fin {
s.currentErr = io.EOF atomic.StoreInt32(&s.eof, 1)
return bytesRead, io.EOF return bytesRead, io.EOF
} }
} }
@ -91,96 +123,46 @@ func (s *stream) Read(p []byte) (int, error) {
return bytesRead, nil return bytesRead, nil
} }
func (s *stream) getNextFrameInOrder(wait bool) (*frames.StreamFrame, error) {
// First, check the queue
for i, f := range s.frameQueue {
if f.Offset == s.readOffset {
// Move last element into position i
s.frameQueue[i] = s.frameQueue[len(s.frameQueue)-1]
s.frameQueue = s.frameQueue[:len(s.frameQueue)-1]
return f, nil
}
}
for {
nextFrameFromChannel, err := s.nextFrameInChan(wait)
if err != nil {
return nil, err
}
if nextFrameFromChannel == nil {
return nil, nil
}
if nextFrameFromChannel.Offset == s.readOffset {
return nextFrameFromChannel, nil
}
// Discard if we already know it
if nextFrameFromChannel.Offset < s.readOffset {
continue
}
// Append to queue
s.frameQueue = append(s.frameQueue, nextFrameFromChannel)
}
}
func (s *stream) nextFrameInChan(blocking bool) (*frames.StreamFrame, error) {
var f *frames.StreamFrame
var ok bool
if blocking {
f, ok = <-s.streamFrames
} else {
select {
case f, ok = <-s.streamFrames:
default:
return nil, nil
}
}
if !ok {
panic("Stream: internal inconsistency: encountered closed chan without nil value (remote error) or FIN bit")
}
if f == nil {
// We read nil, which indicates a remoteErr
return nil, s.remoteErr
}
return f, nil
}
// ReadByte implements io.ByteReader // ReadByte implements io.ByteReader
func (s *stream) ReadByte() (byte, error) { func (s *stream) ReadByte() (byte, error) {
// TODO: Optimize
p := make([]byte, 1) p := make([]byte, 1)
_, err := io.ReadFull(s, p) _, err := io.ReadFull(s, p)
return p[0], err return p[0], err
} }
func (s *stream) UpdateFlowControlWindow(n protocol.ByteCount) { func (s *stream) UpdateFlowControlWindow(n protocol.ByteCount) {
s.mutex.Lock()
defer s.mutex.Unlock()
if n > s.flowControlWindow { if n > s.flowControlWindow {
s.windowUpdateCond.L.Lock()
s.flowControlWindow = n s.flowControlWindow = n
s.windowUpdateCond.L.Unlock() s.windowUpdateOrErrCond.Broadcast()
s.windowUpdateCond.Broadcast()
} }
} }
func (s *stream) Write(p []byte) (int, error) { func (s *stream) Write(p []byte) (int, error) {
if s.remoteErr != nil { s.mutex.Lock()
return 0, s.remoteErr err := s.err
s.mutex.Unlock()
if err != nil {
return 0, err
} }
dataWritten := 0 dataWritten := 0
for dataWritten < len(p) { for dataWritten < len(p) {
s.windowUpdateCond.L.Lock() s.mutex.Lock()
remainingBytesInWindow := int64(s.flowControlWindow) - int64(s.writeOffset) remainingBytesInWindow := int64(s.flowControlWindow) - int64(s.writeOffset)
for ; remainingBytesInWindow == 0; remainingBytesInWindow = int64(s.flowControlWindow) - int64(s.writeOffset) { for remainingBytesInWindow == 0 && s.err == nil {
if s.remoteErr != nil { s.windowUpdateOrErrCond.Wait()
return 0, s.remoteErr remainingBytesInWindow = int64(s.flowControlWindow) - int64(s.writeOffset)
} }
s.windowUpdateCond.Wait() s.mutex.Unlock()
if remainingBytesInWindow == 0 {
// We must have had an error
return 0, s.err
} }
s.windowUpdateCond.L.Unlock()
dataLen := utils.Min(len(p), int(remainingBytesInWindow)) dataLen := utils.Min(len(p), int(remainingBytesInWindow))
data := make([]byte, dataLen) data := make([]byte, dataLen)
@ -212,18 +194,26 @@ func (s *stream) Close() error {
// AddStreamFrame adds a new stream frame // AddStreamFrame adds a new stream frame
func (s *stream) AddStreamFrame(frame *frames.StreamFrame) error { func (s *stream) AddStreamFrame(frame *frames.StreamFrame) error {
s.streamFrames <- frame s.mutex.Lock()
s.frameQueue.Push(frame)
s.mutex.Unlock()
s.newFrameOrErrCond.Signal()
return nil return nil
} }
// RegisterError is called by session to indicate that an error occurred and the // RegisterError is called by session to indicate that an error occurred and the
// stream should be closed. // stream should be closed.
func (s *stream) RegisterError(err error) { func (s *stream) RegisterError(err error) {
s.remoteErr = err s.mutex.Lock()
s.streamFrames <- nil defer s.mutex.Unlock()
s.windowUpdateCond.Broadcast() if s.err != nil { // s.err must not be changed!
return
}
s.err = err
s.windowUpdateOrErrCond.Signal()
s.newFrameOrErrCond.Signal()
} }
func (s *stream) finishedReading() bool { func (s *stream) finishedReading() bool {
return s.currentErr != nil return atomic.LoadInt32(&s.eof) != 0
} }

35
stream_frame_sorter.go Normal file
View file

@ -0,0 +1,35 @@
package quic
import "github.com/lucas-clemente/quic-go/frames"
// TODO: This is currently quite inefficient
type streamFrameSorter struct {
items []*frames.StreamFrame
}
func (s *streamFrameSorter) Push(val *frames.StreamFrame) {
for i, f := range s.items {
if f.Offset > val.Offset {
// Insert here
s.items = append(s.items, nil)
copy(s.items[i+1:], s.items[i:])
s.items[i] = val
return
}
}
// Append at the end
s.items = append(s.items, val)
}
func (s *streamFrameSorter) Pop() *frames.StreamFrame {
res := s.items[0]
s.items = s.items[1:]
return res
}
func (s *streamFrameSorter) Head() *frames.StreamFrame {
if len(s.items) > 0 {
return s.items[0]
}
return nil
}

View file

@ -0,0 +1,49 @@
package quic
import (
"github.com/lucas-clemente/quic-go/frames"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("StreamFrame sorter", func() {
var (
s streamFrameSorter
)
BeforeEach(func() {
s = streamFrameSorter{}
})
It("head returns nil when empty", func() {
Expect(s.Head()).To(BeNil())
})
It("inserts and pops a single frame", func() {
f := &frames.StreamFrame{}
s.Push(f)
Expect(s.Head()).To(Equal(f))
Expect(s.Pop()).To(Equal(f))
Expect(s.Head()).To(BeNil())
})
It("inserts two frames in order", func() {
f1 := &frames.StreamFrame{Offset: 1}
f2 := &frames.StreamFrame{Offset: 2}
s.Push(f1)
s.Push(f2)
Expect(s.Pop()).To(Equal(f1))
Expect(s.Pop()).To(Equal(f2))
Expect(s.Head()).To(BeNil())
})
It("inserts two frames out of order", func() {
f1 := &frames.StreamFrame{Offset: 1}
f2 := &frames.StreamFrame{Offset: 2}
s.Push(f2)
s.Push(f1)
Expect(s.Pop()).To(Equal(f1))
Expect(s.Pop()).To(Equal(f2))
Expect(s.Head()).To(BeNil())
})
})

View file

@ -179,15 +179,15 @@ var _ = Describe("Stream", func() {
It("discards unneeded str frames", func() { It("discards unneeded str frames", func() {
frame1 := frames.StreamFrame{ frame1 := frames.StreamFrame{
Offset: 0, Offset: 0,
Data: []byte{0xDE, 0xAD}, Data: []byte("ab"),
} }
frame2 := frames.StreamFrame{ frame2 := frames.StreamFrame{
Offset: 1, Offset: 1,
Data: []byte{0x42, 0x24}, Data: []byte("xy"),
} }
frame3 := frames.StreamFrame{ frame3 := frames.StreamFrame{
Offset: 2, Offset: 2,
Data: []byte{0xBE, 0xEF}, Data: []byte("cd"),
} }
str.AddStreamFrame(&frame1) str.AddStreamFrame(&frame1)
str.AddStreamFrame(&frame2) str.AddStreamFrame(&frame2)
@ -196,7 +196,7 @@ var _ = Describe("Stream", func() {
n, err := str.Read(b) n, err := str.Read(b)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(n).To(Equal(4)) Expect(n).To(Equal(4))
Expect(b).To(Equal([]byte{0xDE, 0xAD, 0xBE, 0xEF})) Expect(b).To(Equal([]byte("abyd")))
}) })
}) })
@ -342,93 +342,6 @@ var _ = Describe("Stream", func() {
}) })
}) })
Context("getting next str frame", func() {
It("gets next frame", func() {
str.AddStreamFrame(&frames.StreamFrame{
Offset: 0,
Data: []byte{0xDE, 0xAD},
})
f, err := str.getNextFrameInOrder(true)
Expect(err).ToNot(HaveOccurred())
Expect(f.Data).To(Equal([]byte{0xDE, 0xAD}))
})
It("waits for next frame", func() {
var b bool
go func() {
time.Sleep(time.Millisecond)
b = true
str.AddStreamFrame(&frames.StreamFrame{
Offset: 0,
Data: []byte{0xDE, 0xAD},
})
}()
f, err := str.getNextFrameInOrder(true)
Expect(err).ToNot(HaveOccurred())
Expect(b).To(BeTrue())
Expect(f.Data).To(Equal([]byte{0xDE, 0xAD}))
})
It("queues non-matching str frames", func() {
var b bool
str.AddStreamFrame(&frames.StreamFrame{
Offset: 2,
Data: []byte{0xBE, 0xEF},
})
go func() {
time.Sleep(time.Millisecond)
b = true
str.AddStreamFrame(&frames.StreamFrame{
Offset: 0,
Data: []byte{0xDE, 0xAD},
})
}()
f, err := str.getNextFrameInOrder(true)
Expect(err).ToNot(HaveOccurred())
Expect(b).To(BeTrue())
Expect(f.Data).To(Equal([]byte{0xDE, 0xAD}))
str.readOffset += 2
f, err = str.getNextFrameInOrder(true)
Expect(err).ToNot(HaveOccurred())
Expect(f.Data).To(Equal([]byte{0xBE, 0xEF}))
})
It("returns nil if non-blocking", func() {
Expect(str.getNextFrameInOrder(false)).To(BeNil())
})
It("returns properly if non-blocking", func() {
str.AddStreamFrame(&frames.StreamFrame{
Offset: 0,
Data: []byte{0xDE, 0xAD},
})
Expect(str.getNextFrameInOrder(false)).ToNot(BeNil())
})
It("dequeues 3rd frame after blocking on 1st", func() {
str.AddStreamFrame(&frames.StreamFrame{
Offset: 4,
Data: []byte{0x23, 0x42},
})
str.AddStreamFrame(&frames.StreamFrame{
Offset: 2,
Data: []byte{0xBE, 0xEF},
})
go func() {
time.Sleep(time.Millisecond)
str.AddStreamFrame(&frames.StreamFrame{
Offset: 0,
Data: []byte{0xDE, 0xAD},
})
}()
Expect(str.getNextFrameInOrder(true)).ToNot(BeNil())
str.readOffset += 2
Expect(str.getNextFrameInOrder(true)).ToNot(BeNil())
str.readOffset += 2
Expect(str.getNextFrameInOrder(true)).ToNot(BeNil())
})
})
Context("closing", func() { Context("closing", func() {
Context("with fin bit", func() { Context("with fin bit", func() {
It("returns EOFs", func() { It("returns EOFs", func() {