create internal interfaces for send and receive streams

And use those for handling frames.
This commit is contained in:
Marten Seemann 2017-12-26 10:44:56 +07:00
parent 7a3209b3a4
commit e802491a8f
13 changed files with 391 additions and 58 deletions

View file

@ -0,0 +1,132 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/lucas-clemente/quic-go (interfaces: ReceiveStreamI)
// Package quic is a generated GoMock package.
package quic
import (
reflect "reflect"
time "time"
gomock "github.com/golang/mock/gomock"
protocol "github.com/lucas-clemente/quic-go/internal/protocol"
wire "github.com/lucas-clemente/quic-go/internal/wire"
)
// MockReceiveStreamI is a mock of ReceiveStreamI interface
type MockReceiveStreamI struct {
ctrl *gomock.Controller
recorder *MockReceiveStreamIMockRecorder
}
// MockReceiveStreamIMockRecorder is the mock recorder for MockReceiveStreamI
type MockReceiveStreamIMockRecorder struct {
mock *MockReceiveStreamI
}
// NewMockReceiveStreamI creates a new mock instance
func NewMockReceiveStreamI(ctrl *gomock.Controller) *MockReceiveStreamI {
mock := &MockReceiveStreamI{ctrl: ctrl}
mock.recorder = &MockReceiveStreamIMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockReceiveStreamI) EXPECT() *MockReceiveStreamIMockRecorder {
return m.recorder
}
// CancelRead mocks base method
func (m *MockReceiveStreamI) CancelRead(arg0 protocol.ApplicationErrorCode) error {
ret := m.ctrl.Call(m, "CancelRead", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// CancelRead indicates an expected call of CancelRead
func (mr *MockReceiveStreamIMockRecorder) CancelRead(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CancelRead", reflect.TypeOf((*MockReceiveStreamI)(nil).CancelRead), arg0)
}
// Read mocks base method
func (m *MockReceiveStreamI) Read(arg0 []byte) (int, error) {
ret := m.ctrl.Call(m, "Read", arg0)
ret0, _ := ret[0].(int)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Read indicates an expected call of Read
func (mr *MockReceiveStreamIMockRecorder) Read(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*MockReceiveStreamI)(nil).Read), arg0)
}
// SetReadDeadline mocks base method
func (m *MockReceiveStreamI) SetReadDeadline(arg0 time.Time) error {
ret := m.ctrl.Call(m, "SetReadDeadline", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// SetReadDeadline indicates an expected call of SetReadDeadline
func (mr *MockReceiveStreamIMockRecorder) SetReadDeadline(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetReadDeadline", reflect.TypeOf((*MockReceiveStreamI)(nil).SetReadDeadline), arg0)
}
// StreamID mocks base method
func (m *MockReceiveStreamI) StreamID() protocol.StreamID {
ret := m.ctrl.Call(m, "StreamID")
ret0, _ := ret[0].(protocol.StreamID)
return ret0
}
// StreamID indicates an expected call of StreamID
func (mr *MockReceiveStreamIMockRecorder) StreamID() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StreamID", reflect.TypeOf((*MockReceiveStreamI)(nil).StreamID))
}
// closeForShutdown mocks base method
func (m *MockReceiveStreamI) closeForShutdown(arg0 error) {
m.ctrl.Call(m, "closeForShutdown", arg0)
}
// closeForShutdown indicates an expected call of closeForShutdown
func (mr *MockReceiveStreamIMockRecorder) closeForShutdown(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "closeForShutdown", reflect.TypeOf((*MockReceiveStreamI)(nil).closeForShutdown), arg0)
}
// getWindowUpdate mocks base method
func (m *MockReceiveStreamI) getWindowUpdate() protocol.ByteCount {
ret := m.ctrl.Call(m, "getWindowUpdate")
ret0, _ := ret[0].(protocol.ByteCount)
return ret0
}
// getWindowUpdate indicates an expected call of getWindowUpdate
func (mr *MockReceiveStreamIMockRecorder) getWindowUpdate() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "getWindowUpdate", reflect.TypeOf((*MockReceiveStreamI)(nil).getWindowUpdate))
}
// handleRstStreamFrame mocks base method
func (m *MockReceiveStreamI) handleRstStreamFrame(arg0 *wire.RstStreamFrame) error {
ret := m.ctrl.Call(m, "handleRstStreamFrame", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// handleRstStreamFrame indicates an expected call of handleRstStreamFrame
func (mr *MockReceiveStreamIMockRecorder) handleRstStreamFrame(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handleRstStreamFrame", reflect.TypeOf((*MockReceiveStreamI)(nil).handleRstStreamFrame), arg0)
}
// handleStreamFrame mocks base method
func (m *MockReceiveStreamI) handleStreamFrame(arg0 *wire.StreamFrame) error {
ret := m.ctrl.Call(m, "handleStreamFrame", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// handleStreamFrame indicates an expected call of handleStreamFrame
func (mr *MockReceiveStreamIMockRecorder) handleStreamFrame(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handleStreamFrame", reflect.TypeOf((*MockReceiveStreamI)(nil).handleStreamFrame), arg0)
}

View file

@ -0,0 +1,154 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/lucas-clemente/quic-go (interfaces: SendStreamI)
// Package quic is a generated GoMock package.
package quic
import (
context "context"
reflect "reflect"
time "time"
gomock "github.com/golang/mock/gomock"
protocol "github.com/lucas-clemente/quic-go/internal/protocol"
wire "github.com/lucas-clemente/quic-go/internal/wire"
)
// MockSendStreamI is a mock of SendStreamI interface
type MockSendStreamI struct {
ctrl *gomock.Controller
recorder *MockSendStreamIMockRecorder
}
// MockSendStreamIMockRecorder is the mock recorder for MockSendStreamI
type MockSendStreamIMockRecorder struct {
mock *MockSendStreamI
}
// NewMockSendStreamI creates a new mock instance
func NewMockSendStreamI(ctrl *gomock.Controller) *MockSendStreamI {
mock := &MockSendStreamI{ctrl: ctrl}
mock.recorder = &MockSendStreamIMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockSendStreamI) EXPECT() *MockSendStreamIMockRecorder {
return m.recorder
}
// CancelWrite mocks base method
func (m *MockSendStreamI) CancelWrite(arg0 protocol.ApplicationErrorCode) error {
ret := m.ctrl.Call(m, "CancelWrite", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// CancelWrite indicates an expected call of CancelWrite
func (mr *MockSendStreamIMockRecorder) CancelWrite(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CancelWrite", reflect.TypeOf((*MockSendStreamI)(nil).CancelWrite), arg0)
}
// Close mocks base method
func (m *MockSendStreamI) Close() error {
ret := m.ctrl.Call(m, "Close")
ret0, _ := ret[0].(error)
return ret0
}
// Close indicates an expected call of Close
func (mr *MockSendStreamIMockRecorder) Close() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockSendStreamI)(nil).Close))
}
// Context mocks base method
func (m *MockSendStreamI) Context() context.Context {
ret := m.ctrl.Call(m, "Context")
ret0, _ := ret[0].(context.Context)
return ret0
}
// Context indicates an expected call of Context
func (mr *MockSendStreamIMockRecorder) Context() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Context", reflect.TypeOf((*MockSendStreamI)(nil).Context))
}
// SetWriteDeadline mocks base method
func (m *MockSendStreamI) SetWriteDeadline(arg0 time.Time) error {
ret := m.ctrl.Call(m, "SetWriteDeadline", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// SetWriteDeadline indicates an expected call of SetWriteDeadline
func (mr *MockSendStreamIMockRecorder) SetWriteDeadline(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetWriteDeadline", reflect.TypeOf((*MockSendStreamI)(nil).SetWriteDeadline), arg0)
}
// StreamID mocks base method
func (m *MockSendStreamI) StreamID() protocol.StreamID {
ret := m.ctrl.Call(m, "StreamID")
ret0, _ := ret[0].(protocol.StreamID)
return ret0
}
// StreamID indicates an expected call of StreamID
func (mr *MockSendStreamIMockRecorder) StreamID() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StreamID", reflect.TypeOf((*MockSendStreamI)(nil).StreamID))
}
// Write mocks base method
func (m *MockSendStreamI) Write(arg0 []byte) (int, error) {
ret := m.ctrl.Call(m, "Write", arg0)
ret0, _ := ret[0].(int)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Write indicates an expected call of Write
func (mr *MockSendStreamIMockRecorder) Write(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockSendStreamI)(nil).Write), arg0)
}
// closeForShutdown mocks base method
func (m *MockSendStreamI) closeForShutdown(arg0 error) {
m.ctrl.Call(m, "closeForShutdown", arg0)
}
// closeForShutdown indicates an expected call of closeForShutdown
func (mr *MockSendStreamIMockRecorder) closeForShutdown(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "closeForShutdown", reflect.TypeOf((*MockSendStreamI)(nil).closeForShutdown), arg0)
}
// handleMaxStreamDataFrame mocks base method
func (m *MockSendStreamI) handleMaxStreamDataFrame(arg0 *wire.MaxStreamDataFrame) {
m.ctrl.Call(m, "handleMaxStreamDataFrame", arg0)
}
// handleMaxStreamDataFrame indicates an expected call of handleMaxStreamDataFrame
func (mr *MockSendStreamIMockRecorder) handleMaxStreamDataFrame(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handleMaxStreamDataFrame", reflect.TypeOf((*MockSendStreamI)(nil).handleMaxStreamDataFrame), arg0)
}
// handleStopSendingFrame mocks base method
func (m *MockSendStreamI) handleStopSendingFrame(arg0 *wire.StopSendingFrame) {
m.ctrl.Call(m, "handleStopSendingFrame", arg0)
}
// handleStopSendingFrame indicates an expected call of handleStopSendingFrame
func (mr *MockSendStreamIMockRecorder) handleStopSendingFrame(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handleStopSendingFrame", reflect.TypeOf((*MockSendStreamI)(nil).handleStopSendingFrame), arg0)
}
// popStreamFrame mocks base method
func (m *MockSendStreamI) popStreamFrame(arg0 protocol.ByteCount) (*wire.StreamFrame, bool) {
ret := m.ctrl.Call(m, "popStreamFrame", arg0)
ret0, _ := ret[0].(*wire.StreamFrame)
ret1, _ := ret[1].(bool)
return ret0, ret1
}
// popStreamFrame indicates an expected call of popStreamFrame
func (mr *MockSendStreamIMockRecorder) popStreamFrame(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "popStreamFrame", reflect.TypeOf((*MockSendStreamI)(nil).popStreamFrame), arg0)
}

View file

@ -34,15 +34,28 @@ func (m *MockStreamGetter) EXPECT() *MockStreamGetterMockRecorder {
return m.recorder
}
// GetOrOpenStream mocks base method
func (m *MockStreamGetter) GetOrOpenStream(arg0 protocol.StreamID) (streamI, error) {
ret := m.ctrl.Call(m, "GetOrOpenStream", arg0)
ret0, _ := ret[0].(streamI)
// GetOrOpenReceiveStream mocks base method
func (m *MockStreamGetter) GetOrOpenReceiveStream(arg0 protocol.StreamID) (receiveStreamI, error) {
ret := m.ctrl.Call(m, "GetOrOpenReceiveStream", arg0)
ret0, _ := ret[0].(receiveStreamI)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetOrOpenStream indicates an expected call of GetOrOpenStream
func (mr *MockStreamGetterMockRecorder) GetOrOpenStream(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOrOpenStream", reflect.TypeOf((*MockStreamGetter)(nil).GetOrOpenStream), arg0)
// GetOrOpenReceiveStream indicates an expected call of GetOrOpenReceiveStream
func (mr *MockStreamGetterMockRecorder) GetOrOpenReceiveStream(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOrOpenReceiveStream", reflect.TypeOf((*MockStreamGetter)(nil).GetOrOpenReceiveStream), arg0)
}
// GetOrOpenSendStream mocks base method
func (m *MockStreamGetter) GetOrOpenSendStream(arg0 protocol.StreamID) (sendStreamI, error) {
ret := m.ctrl.Call(m, "GetOrOpenSendStream", arg0)
ret0, _ := ret[0].(sendStreamI)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetOrOpenSendStream indicates an expected call of GetOrOpenSendStream
func (mr *MockStreamGetterMockRecorder) GetOrOpenSendStream(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOrOpenSendStream", reflect.TypeOf((*MockStreamGetter)(nil).GetOrOpenSendStream), arg0)
}

View file

@ -1,6 +1,8 @@
package quic
//go:generate sh -c "./mockgen_private.sh quic mock_stream_internal_test.go github.com/lucas-clemente/quic-go streamI StreamI"
//go:generate sh -c "./mockgen_private.sh quic mock_receive_stream_internal_test.go github.com/lucas-clemente/quic-go receiveStreamI ReceiveStreamI"
//go:generate sh -c "./mockgen_private.sh quic mock_send_stream_internal_test.go github.com/lucas-clemente/quic-go sendStreamI SendStreamI"
//go:generate sh -c "./mockgen_private.sh quic mock_stream_sender_test.go github.com/lucas-clemente/quic-go streamSender StreamSender"
//go:generate sh -c "./mockgen_private.sh quic mock_stream_getter_test.go github.com/lucas-clemente/quic-go streamGetter StreamGetter"
//go:generate sh -c "./mockgen_private.sh quic mock_stream_frame_source_test.go github.com/lucas-clemente/quic-go streamFrameSource StreamFrameSource"

View file

@ -12,6 +12,15 @@ import (
"github.com/lucas-clemente/quic-go/internal/wire"
)
type receiveStreamI interface {
ReceiveStream
handleStreamFrame(*wire.StreamFrame) error
handleRstStreamFrame(*wire.RstStreamFrame) error
closeForShutdown(error)
getWindowUpdate() protocol.ByteCount
}
type receiveStream struct {
mutex sync.Mutex
@ -40,6 +49,7 @@ type receiveStream struct {
}
var _ ReceiveStream = &receiveStream{}
var _ receiveStreamI = &receiveStream{}
func newReceiveStream(
streamID protocol.StreamID,

View file

@ -12,6 +12,14 @@ import (
"github.com/lucas-clemente/quic-go/internal/wire"
)
type sendStreamI interface {
SendStream
handleStopSendingFrame(*wire.StopSendingFrame)
popStreamFrame(maxBytes protocol.ByteCount) (*wire.StreamFrame, bool)
closeForShutdown(error)
handleMaxStreamDataFrame(*wire.MaxStreamDataFrame)
}
type sendStream struct {
mutex sync.Mutex
@ -41,6 +49,7 @@ type sendStream struct {
}
var _ SendStream = &sendStream{}
var _ sendStreamI = &sendStream{}
func newSendStream(
streamID protocol.StreamID,

View file

@ -24,7 +24,8 @@ type unpacker interface {
}
type streamGetter interface {
GetOrOpenStream(protocol.StreamID) (streamI, error)
GetOrOpenReceiveStream(protocol.StreamID) (receiveStreamI, error)
GetOrOpenSendStream(protocol.StreamID) (sendStreamI, error)
}
type receivedPacket struct {
@ -574,7 +575,7 @@ func (s *session) handleStreamFrame(frame *wire.StreamFrame) error {
}
return s.cryptoStream.handleStreamFrame(frame)
}
str, err := s.streamsMap.GetOrOpenStream(frame.StreamID)
str, err := s.streamsMap.GetOrOpenReceiveStream(frame.StreamID)
if err != nil {
return err
}
@ -595,7 +596,7 @@ func (s *session) handleMaxStreamDataFrame(frame *wire.MaxStreamDataFrame) error
s.cryptoStream.handleMaxStreamDataFrame(frame)
return nil
}
str, err := s.streamsMap.GetOrOpenStream(frame.StreamID)
str, err := s.streamsMap.GetOrOpenSendStream(frame.StreamID)
if err != nil {
return err
}
@ -607,11 +608,26 @@ func (s *session) handleMaxStreamDataFrame(frame *wire.MaxStreamDataFrame) error
return nil
}
func (s *session) handleRstStreamFrame(frame *wire.RstStreamFrame) error {
if frame.StreamID == s.version.CryptoStreamID() {
return errors.New("Received RST_STREAM frame for the crypto stream")
}
str, err := s.streamsMap.GetOrOpenReceiveStream(frame.StreamID)
if err != nil {
return err
}
if str == nil {
// stream is closed and already garbage collected
return nil
}
return str.handleRstStreamFrame(frame)
}
func (s *session) handleStopSendingFrame(frame *wire.StopSendingFrame) error {
if frame.StreamID == s.version.CryptoStreamID() {
return errors.New("Received a STOP_SENDING frame for the crypto stream")
}
str, err := s.streamsMap.GetOrOpenStream(frame.StreamID)
str, err := s.streamsMap.GetOrOpenSendStream(frame.StreamID)
if err != nil {
return err
}
@ -623,21 +639,6 @@ func (s *session) handleStopSendingFrame(frame *wire.StopSendingFrame) error {
return nil
}
func (s *session) handleRstStreamFrame(frame *wire.RstStreamFrame) error {
if frame.StreamID == s.version.CryptoStreamID() {
return errors.New("Received RST_STREAM frame for the crypto stream")
}
str, err := s.streamsMap.GetOrOpenStream(frame.StreamID)
if err != nil {
return err
}
if str == nil {
// stream is closed and already garbage collected
return nil
}
return str.handleRstStreamFrame(frame)
}
func (s *session) handleAckFrame(frame *wire.AckFrame, encLevel protocol.EncryptionLevel) error {
if err := s.sentPacketHandler.ReceivedAck(frame, s.lastRcvdPacketNumber, encLevel, s.lastNetworkActivityTime); err != nil {
return err

View file

@ -50,17 +50,20 @@ var _ streamSender = &uniStreamSender{}
type streamI interface {
Stream
closeForShutdown(error)
// for receiving
handleStreamFrame(*wire.StreamFrame) error
handleRstStreamFrame(*wire.RstStreamFrame) error
getWindowUpdate() protocol.ByteCount
// for sending
handleStopSendingFrame(*wire.StopSendingFrame)
popStreamFrame(maxBytes protocol.ByteCount) (*wire.StreamFrame, bool)
closeForShutdown(error)
// methods needed for flow control
getWindowUpdate() protocol.ByteCount
handleMaxStreamDataFrame(*wire.MaxStreamDataFrame)
}
var _ receiveStreamI = (streamI)(nil)
var _ sendStreamI = (streamI)(nil)
// A Stream assembles the data from StreamFrames and provides a super-convenient Read-Interface
//
// Read() and Write() may be called concurrently, but multiple calls to Read() or Write() individually must be synchronized manually.
@ -77,7 +80,6 @@ type stream struct {
}
var _ Stream = &stream{}
var _ streamI = &stream{}
type deadlineError struct{}

View file

@ -113,7 +113,7 @@ func (f *streamFramer) maybePopNormalFrames(maxTotalLen protocol.ByteCount) []*w
}
id := f.streamQueue[0]
f.streamQueue = f.streamQueue[1:]
str, err := f.streamGetter.GetOrOpenStream(id)
str, err := f.streamGetter.GetOrOpenSendStream(id)
if err != nil { // can happen if the stream completed after it said it had data
delete(f.activeStreams, id)
continue

View file

@ -21,8 +21,8 @@ var _ = Describe("Stream Framer", func() {
var (
retransmittedFrame1, retransmittedFrame2 *wire.StreamFrame
framer *streamFramer
stream1, stream2 *MockStreamI
cryptoStream *MockCryptoStream
stream1, stream2 *MockSendStreamI
streamGetter *MockStreamGetter
)
@ -37,9 +37,9 @@ var _ = Describe("Stream Framer", func() {
Data: []byte{0xDE, 0xCA, 0xFB, 0xAD},
}
stream1 = NewMockStreamI(mockCtrl)
stream1 = NewMockSendStreamI(mockCtrl)
stream1.EXPECT().StreamID().Return(protocol.StreamID(5)).AnyTimes()
stream2 = NewMockStreamI(mockCtrl)
stream2 = NewMockSendStreamI(mockCtrl)
stream2.EXPECT().StreamID().Return(protocol.StreamID(6)).AnyTimes()
cryptoStream = NewMockCryptoStream(mockCtrl)
framer = newStreamFramer(cryptoStream, streamGetter, versionGQUICFrames)
@ -126,7 +126,7 @@ var _ = Describe("Stream Framer", func() {
})
It("returns normal frames", func() {
streamGetter.EXPECT().GetOrOpenStream(id1).Return(stream1, nil)
streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil)
f := &wire.StreamFrame{
StreamID: id1,
Data: []byte("foobar"),
@ -139,8 +139,8 @@ var _ = Describe("Stream Framer", func() {
})
It("skips a stream that was reported active, but was completed shortly after", func() {
streamGetter.EXPECT().GetOrOpenStream(id1).Return(nil, errors.New("stream was already deleted"))
streamGetter.EXPECT().GetOrOpenStream(id2).Return(stream2, nil)
streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(nil, errors.New("stream was already deleted"))
streamGetter.EXPECT().GetOrOpenSendStream(id2).Return(stream2, nil)
f := &wire.StreamFrame{
StreamID: id2,
Data: []byte("foobar"),
@ -152,8 +152,8 @@ var _ = Describe("Stream Framer", func() {
})
It("skips a stream that was reported active, but doesn't have any data", func() {
streamGetter.EXPECT().GetOrOpenStream(id1).Return(stream1, nil)
streamGetter.EXPECT().GetOrOpenStream(id2).Return(stream2, nil)
streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil)
streamGetter.EXPECT().GetOrOpenSendStream(id2).Return(stream2, nil)
f := &wire.StreamFrame{
StreamID: id2,
Data: []byte("foobar"),
@ -166,7 +166,7 @@ var _ = Describe("Stream Framer", func() {
})
It("pops from a stream multiple times, if it has enough data", func() {
streamGetter.EXPECT().GetOrOpenStream(id1).Return(stream1, nil).Times(2)
streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil).Times(2)
f1 := &wire.StreamFrame{StreamID: id1, Data: []byte("foobar")}
f2 := &wire.StreamFrame{StreamID: id1, Data: []byte("foobaz")}
stream1.EXPECT().popStreamFrame(gomock.Any()).Return(f1, true)
@ -179,8 +179,8 @@ var _ = Describe("Stream Framer", func() {
})
It("re-queues a stream at the end, if it has enough data", func() {
streamGetter.EXPECT().GetOrOpenStream(id1).Return(stream1, nil).Times(2)
streamGetter.EXPECT().GetOrOpenStream(id2).Return(stream2, nil)
streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil).Times(2)
streamGetter.EXPECT().GetOrOpenSendStream(id2).Return(stream2, nil)
f11 := &wire.StreamFrame{StreamID: id1, Data: []byte("foobar")}
f12 := &wire.StreamFrame{StreamID: id1, Data: []byte("foobaz")}
f2 := &wire.StreamFrame{StreamID: id2, Data: []byte("raboof")}
@ -195,8 +195,8 @@ var _ = Describe("Stream Framer", func() {
})
It("only dequeues data from each stream once per packet", func() {
streamGetter.EXPECT().GetOrOpenStream(id1).Return(stream1, nil)
streamGetter.EXPECT().GetOrOpenStream(id2).Return(stream2, nil)
streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil)
streamGetter.EXPECT().GetOrOpenSendStream(id2).Return(stream2, nil)
f1 := &wire.StreamFrame{StreamID: id1, Data: []byte("foobar")}
f2 := &wire.StreamFrame{StreamID: id2, Data: []byte("raboof")}
// both streams have more data, and will be re-queued
@ -208,8 +208,8 @@ var _ = Describe("Stream Framer", func() {
})
It("returns multiple normal frames in the order they were reported active", func() {
streamGetter.EXPECT().GetOrOpenStream(id1).Return(stream1, nil)
streamGetter.EXPECT().GetOrOpenStream(id2).Return(stream2, nil)
streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil)
streamGetter.EXPECT().GetOrOpenSendStream(id2).Return(stream2, nil)
f1 := &wire.StreamFrame{Data: []byte("foobar")}
f2 := &wire.StreamFrame{Data: []byte("foobaz")}
stream1.EXPECT().popStreamFrame(gomock.Any()).Return(f1, false)
@ -220,7 +220,7 @@ var _ = Describe("Stream Framer", func() {
})
It("only asks a stream for data once, even if it was reported active multiple times", func() {
streamGetter.EXPECT().GetOrOpenStream(id1).Return(stream1, nil)
streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil)
f := &wire.StreamFrame{Data: []byte("foobar")}
stream1.EXPECT().popStreamFrame(gomock.Any()).Return(f, false) // only one call to this function
framer.AddActiveStream(id1)
@ -229,7 +229,7 @@ var _ = Describe("Stream Framer", func() {
})
It("returns retransmission frames before normal frames", func() {
streamGetter.EXPECT().GetOrOpenStream(id1).Return(stream1, nil)
streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil)
framer.AddActiveStream(id1)
f1 := &wire.StreamFrame{Data: []byte("foobar")}
stream1.EXPECT().popStreamFrame(gomock.Any()).Return(f1, false)
@ -244,7 +244,7 @@ var _ = Describe("Stream Framer", func() {
})
It("pops frames that have the minimum size", func() {
streamGetter.EXPECT().GetOrOpenStream(id1).Return(stream1, nil)
streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil)
stream1.EXPECT().popStreamFrame(protocol.MinStreamFrameSize).Return(&wire.StreamFrame{Data: []byte("foobar")}, false)
framer.AddActiveStream(id1)
framer.PopStreamFrames(protocol.MinStreamFrameSize)
@ -256,7 +256,7 @@ var _ = Describe("Stream Framer", func() {
})
It("stops iterating when the remaining size is smaller than the minimum STREAM frame size", func() {
streamGetter.EXPECT().GetOrOpenStream(id1).Return(stream1, nil)
streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil)
// pop a frame such that the remaining size is one byte less than the minimum STREAM frame size
f := &wire.StreamFrame{
StreamID: id1,

View file

@ -90,6 +90,16 @@ func (m *streamsMap) nextStreamID(id protocol.StreamID) protocol.StreamID {
return id + 2
}
func (m *streamsMap) GetOrOpenReceiveStream(id protocol.StreamID) (receiveStreamI, error) {
// every bidirectional stream is also a receive stream
return m.GetOrOpenStream(id)
}
func (m *streamsMap) GetOrOpenSendStream(id protocol.StreamID) (sendStreamI, error) {
// every bidirectional stream is also a send stream
return m.GetOrOpenStream(id)
}
// GetOrOpenStream either returns an existing stream, a newly opened stream, or nil if a stream with the provided ID is already closed.
// Newly opened streams should only originate from the client. To open a stream from the server, OpenStream should be used.
func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (streamI, error) {

View file

@ -38,7 +38,7 @@ func (q *windowUpdateQueue) QueueAll() {
if id == q.cryptoStream.StreamID() {
offset = q.cryptoStream.getWindowUpdate()
} else {
str, err := q.streamGetter.GetOrOpenStream(id)
str, err := q.streamGetter.GetOrOpenReceiveStream(id)
if err != nil || str == nil { // the stream can be nil if it was completed before dequeing the window update
continue
}

View file

@ -31,8 +31,8 @@ var _ = Describe("Window Update Queue", func() {
stream1.EXPECT().getWindowUpdate().Return(protocol.ByteCount(10))
stream3 := NewMockStreamI(mockCtrl)
stream3.EXPECT().getWindowUpdate().Return(protocol.ByteCount(30))
streamGetter.EXPECT().GetOrOpenStream(protocol.StreamID(3)).Return(stream3, nil)
streamGetter.EXPECT().GetOrOpenStream(protocol.StreamID(1)).Return(stream1, nil)
streamGetter.EXPECT().GetOrOpenReceiveStream(protocol.StreamID(3)).Return(stream3, nil)
streamGetter.EXPECT().GetOrOpenReceiveStream(protocol.StreamID(1)).Return(stream1, nil)
q.Add(3)
q.Add(1)
q.QueueAll()
@ -43,7 +43,7 @@ var _ = Describe("Window Update Queue", func() {
It("deletes the entry after getting the MAX_STREAM_DATA frame", func() {
stream10 := NewMockStreamI(mockCtrl)
stream10.EXPECT().getWindowUpdate().Return(protocol.ByteCount(100))
streamGetter.EXPECT().GetOrOpenStream(protocol.StreamID(10)).Return(stream10, nil)
streamGetter.EXPECT().GetOrOpenReceiveStream(protocol.StreamID(10)).Return(stream10, nil)
q.Add(10)
q.QueueAll()
Expect(queuedFrames).To(HaveLen(1))
@ -52,7 +52,7 @@ var _ = Describe("Window Update Queue", func() {
})
It("doesn't queue a MAX_STREAM_DATA for a closed stream", func() {
streamGetter.EXPECT().GetOrOpenStream(protocol.StreamID(12)).Return(nil, nil)
streamGetter.EXPECT().GetOrOpenReceiveStream(protocol.StreamID(12)).Return(nil, nil)
q.Add(12)
q.QueueAll()
Expect(queuedFrames).To(BeEmpty())
@ -61,7 +61,7 @@ var _ = Describe("Window Update Queue", func() {
It("doesn't queue a MAX_STREAM_DATA if the flow controller returns an offset of 0", func() {
stream5 := NewMockStreamI(mockCtrl)
stream5.EXPECT().getWindowUpdate().Return(protocol.ByteCount(0))
streamGetter.EXPECT().GetOrOpenStream(protocol.StreamID(5)).Return(stream5, nil)
streamGetter.EXPECT().GetOrOpenReceiveStream(protocol.StreamID(5)).Return(stream5, nil)
q.Add(5)
q.QueueAll()
Expect(queuedFrames).To(BeEmpty())
@ -79,7 +79,7 @@ var _ = Describe("Window Update Queue", func() {
It("deduplicates", func() {
stream10 := NewMockStreamI(mockCtrl)
stream10.EXPECT().getWindowUpdate().Return(protocol.ByteCount(200))
streamGetter.EXPECT().GetOrOpenStream(protocol.StreamID(10)).Return(stream10, nil)
streamGetter.EXPECT().GetOrOpenReceiveStream(protocol.StreamID(10)).Return(stream10, nil)
q.Add(10)
q.Add(10)
q.QueueAll()