mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-05 13:17:36 +03:00
implement and use the MAX_STREAMS frame
This commit is contained in:
parent
0f931ca54e
commit
9518c90c0a
18 changed files with 248 additions and 167 deletions
|
@ -29,6 +29,12 @@ func (s StreamID) Type() StreamType {
|
||||||
return StreamTypeBidi
|
return StreamTypeBidi
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// StreamNum returns how many streams in total are below this
|
||||||
|
// Example: for stream 9 it returns 3 (i.e. streams 1, 5 and 9)
|
||||||
|
func (s StreamID) StreamNum() uint64 {
|
||||||
|
return uint64(s/4) + 1
|
||||||
|
}
|
||||||
|
|
||||||
// MaxStreamID is the highest stream ID that a peer is allowed to open,
|
// MaxStreamID is the highest stream ID that a peer is allowed to open,
|
||||||
// when it is allowed to open numStreams.
|
// when it is allowed to open numStreams.
|
||||||
func MaxStreamID(stype StreamType, numStreams uint64, pers Perspective) StreamID {
|
func MaxStreamID(stype StreamType, numStreams uint64, pers Perspective) StreamID {
|
||||||
|
|
|
@ -27,6 +27,17 @@ var _ = Describe("Stream ID", func() {
|
||||||
Expect(FirstStream(StreamTypeUni, PerspectiveServer)).To(Equal(StreamID(3)))
|
Expect(FirstStream(StreamTypeUni, PerspectiveServer)).To(Equal(StreamID(3)))
|
||||||
})
|
})
|
||||||
|
|
||||||
|
It("tells the stream number", func() {
|
||||||
|
Expect(StreamID(0).StreamNum()).To(BeEquivalentTo(1))
|
||||||
|
Expect(StreamID(1).StreamNum()).To(BeEquivalentTo(1))
|
||||||
|
Expect(StreamID(2).StreamNum()).To(BeEquivalentTo(1))
|
||||||
|
Expect(StreamID(3).StreamNum()).To(BeEquivalentTo(1))
|
||||||
|
Expect(StreamID(8).StreamNum()).To(BeEquivalentTo(3))
|
||||||
|
Expect(StreamID(9).StreamNum()).To(BeEquivalentTo(3))
|
||||||
|
Expect(StreamID(10).StreamNum()).To(BeEquivalentTo(3))
|
||||||
|
Expect(StreamID(11).StreamNum()).To(BeEquivalentTo(3))
|
||||||
|
})
|
||||||
|
|
||||||
Context("maximum stream IDs", func() {
|
Context("maximum stream IDs", func() {
|
||||||
It("doesn't allow any", func() {
|
It("doesn't allow any", func() {
|
||||||
Expect(MaxStreamID(StreamTypeBidi, 0, PerspectiveClient)).To(Equal(StreamID(0)))
|
Expect(MaxStreamID(StreamTypeBidi, 0, PerspectiveClient)).To(Equal(StreamID(0)))
|
||||||
|
|
|
@ -55,11 +55,6 @@ func parseFrame(r *bytes.Reader, typeByte byte, v protocol.VersionNumber) (Frame
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err = qerr.Error(qerr.InvalidWindowUpdateData, err.Error())
|
err = qerr.Error(qerr.InvalidWindowUpdateData, err.Error())
|
||||||
}
|
}
|
||||||
case 0x6:
|
|
||||||
frame, err = parseMaxStreamIDFrame(r, v)
|
|
||||||
if err != nil {
|
|
||||||
err = qerr.Error(qerr.InvalidFrameData, err.Error())
|
|
||||||
}
|
|
||||||
case 0x7:
|
case 0x7:
|
||||||
frame, err = parsePingFrame(r, v)
|
frame, err = parsePingFrame(r, v)
|
||||||
case 0x8:
|
case 0x8:
|
||||||
|
@ -97,6 +92,11 @@ func parseFrame(r *bytes.Reader, typeByte byte, v protocol.VersionNumber) (Frame
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err = qerr.Error(qerr.InvalidAckData, err.Error())
|
err = qerr.Error(qerr.InvalidAckData, err.Error())
|
||||||
}
|
}
|
||||||
|
case 0x1c, 0x1d:
|
||||||
|
frame, err = parseMaxStreamsFrame(r, v)
|
||||||
|
if err != nil {
|
||||||
|
err = qerr.Error(qerr.InvalidFrameData, err.Error())
|
||||||
|
}
|
||||||
case 0x18:
|
case 0x18:
|
||||||
frame, err = parseCryptoFrame(r, v)
|
frame, err = parseCryptoFrame(r, v)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -97,8 +97,11 @@ var _ = Describe("Frame parsing", func() {
|
||||||
Expect(frame).To(Equal(f))
|
Expect(frame).To(Equal(f))
|
||||||
})
|
})
|
||||||
|
|
||||||
It("unpacks MAX_STREAM_ID frames", func() {
|
It("unpacks MAX_STREAMS frames", func() {
|
||||||
f := &MaxStreamIDFrame{StreamID: 0x1337}
|
f := &MaxStreamsFrame{
|
||||||
|
Type: protocol.StreamTypeBidi,
|
||||||
|
MaxStreams: 0x1337,
|
||||||
|
}
|
||||||
buf := &bytes.Buffer{}
|
buf := &bytes.Buffer{}
|
||||||
err := f.Write(buf, versionIETFFrames)
|
err := f.Write(buf, versionIETFFrames)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
|
@ -1,37 +0,0 @@
|
||||||
package wire
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
|
||||||
)
|
|
||||||
|
|
||||||
// A MaxStreamIDFrame is a MAX_STREAM_ID frame
|
|
||||||
type MaxStreamIDFrame struct {
|
|
||||||
StreamID protocol.StreamID
|
|
||||||
}
|
|
||||||
|
|
||||||
// parseMaxStreamIDFrame parses a MAX_STREAM_ID frame
|
|
||||||
func parseMaxStreamIDFrame(r *bytes.Reader, _ protocol.VersionNumber) (*MaxStreamIDFrame, error) {
|
|
||||||
// read the Type byte
|
|
||||||
if _, err := r.ReadByte(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
streamID, err := utils.ReadVarInt(r)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &MaxStreamIDFrame{StreamID: protocol.StreamID(streamID)}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *MaxStreamIDFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error {
|
|
||||||
b.WriteByte(0x6)
|
|
||||||
utils.WriteVarInt(b, uint64(f.StreamID))
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Length of a written frame
|
|
||||||
func (f *MaxStreamIDFrame) Length(protocol.VersionNumber) protocol.ByteCount {
|
|
||||||
return 1 + utils.VarIntLen(uint64(f.StreamID))
|
|
||||||
}
|
|
|
@ -1,51 +0,0 @@
|
||||||
package wire
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
)
|
|
||||||
|
|
||||||
var _ = Describe("MAX_STREAM_ID frame", func() {
|
|
||||||
Context("parsing", func() {
|
|
||||||
It("accepts sample frame", func() {
|
|
||||||
data := []byte{0x6}
|
|
||||||
data = append(data, encodeVarInt(0xdecafbad)...)
|
|
||||||
b := bytes.NewReader(data)
|
|
||||||
f, err := parseMaxStreamIDFrame(b, protocol.VersionWhatever)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(f.StreamID).To(Equal(protocol.StreamID(0xdecafbad)))
|
|
||||||
Expect(b.Len()).To(BeZero())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("errors on EOFs", func() {
|
|
||||||
data := []byte{0x06}
|
|
||||||
data = append(data, encodeVarInt(0xdeadbeefcafe13)...)
|
|
||||||
_, err := parseMaxStreamIDFrame(bytes.NewReader(data), protocol.VersionWhatever)
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
for i := range data {
|
|
||||||
_, err := parseMaxStreamIDFrame(bytes.NewReader(data[0:i]), protocol.VersionWhatever)
|
|
||||||
Expect(err).To(HaveOccurred())
|
|
||||||
}
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("writing", func() {
|
|
||||||
It("writes a sample frame", func() {
|
|
||||||
b := &bytes.Buffer{}
|
|
||||||
frame := MaxStreamIDFrame{StreamID: 0x12345678}
|
|
||||||
frame.Write(b, protocol.VersionWhatever)
|
|
||||||
expected := []byte{0x6}
|
|
||||||
expected = append(expected, encodeVarInt(0x12345678)...)
|
|
||||||
Expect(b.Bytes()).To(Equal(expected))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("has the correct min length", func() {
|
|
||||||
frame := MaxStreamIDFrame{StreamID: 0x1337}
|
|
||||||
Expect(frame.Length(protocol.VersionWhatever)).To(Equal(1 + utils.VarIntLen(0x1337)))
|
|
||||||
})
|
|
||||||
})
|
|
||||||
})
|
|
51
internal/wire/max_streams_frame.go
Normal file
51
internal/wire/max_streams_frame.go
Normal file
|
@ -0,0 +1,51 @@
|
||||||
|
package wire
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
// A MaxStreamsFrame is a MAX_STREAMS frame
|
||||||
|
type MaxStreamsFrame struct {
|
||||||
|
Type protocol.StreamType
|
||||||
|
MaxStreams uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseMaxStreamsFrame(r *bytes.Reader, _ protocol.VersionNumber) (*MaxStreamsFrame, error) {
|
||||||
|
typeByte, err := r.ReadByte()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
f := &MaxStreamsFrame{}
|
||||||
|
switch typeByte {
|
||||||
|
case 0x1c:
|
||||||
|
f.Type = protocol.StreamTypeBidi
|
||||||
|
case 0x1d:
|
||||||
|
f.Type = protocol.StreamTypeUni
|
||||||
|
}
|
||||||
|
streamID, err := utils.ReadVarInt(r)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
f.MaxStreams = streamID
|
||||||
|
return f, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *MaxStreamsFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error {
|
||||||
|
switch f.Type {
|
||||||
|
case protocol.StreamTypeBidi:
|
||||||
|
b.WriteByte(0x1c)
|
||||||
|
case protocol.StreamTypeUni:
|
||||||
|
b.WriteByte(0x1d)
|
||||||
|
}
|
||||||
|
utils.WriteVarInt(b, f.MaxStreams)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Length of a written frame
|
||||||
|
func (f *MaxStreamsFrame) Length(protocol.VersionNumber) protocol.ByteCount {
|
||||||
|
return 1 + utils.VarIntLen(f.MaxStreams)
|
||||||
|
}
|
78
internal/wire/max_streams_frame_test.go
Normal file
78
internal/wire/max_streams_frame_test.go
Normal file
|
@ -0,0 +1,78 @@
|
||||||
|
package wire
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||||
|
. "github.com/onsi/ginkgo"
|
||||||
|
. "github.com/onsi/gomega"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ = Describe("MAX_STREAMS frame", func() {
|
||||||
|
Context("parsing", func() {
|
||||||
|
It("accepts a frame for a bidirectional stream", func() {
|
||||||
|
data := []byte{0x1c}
|
||||||
|
data = append(data, encodeVarInt(0xdecaf)...)
|
||||||
|
b := bytes.NewReader(data)
|
||||||
|
f, err := parseMaxStreamsFrame(b, protocol.VersionWhatever)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(f.Type).To(Equal(protocol.StreamTypeBidi))
|
||||||
|
Expect(f.MaxStreams).To(BeEquivalentTo(0xdecaf))
|
||||||
|
Expect(b.Len()).To(BeZero())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("accepts a frame for a bidirectional stream", func() {
|
||||||
|
data := []byte{0x1d}
|
||||||
|
data = append(data, encodeVarInt(0xdecaf)...)
|
||||||
|
b := bytes.NewReader(data)
|
||||||
|
f, err := parseMaxStreamsFrame(b, protocol.VersionWhatever)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(f.Type).To(Equal(protocol.StreamTypeUni))
|
||||||
|
Expect(f.MaxStreams).To(BeEquivalentTo(0xdecaf))
|
||||||
|
Expect(b.Len()).To(BeZero())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("errors on EOFs", func() {
|
||||||
|
data := []byte{0x1d}
|
||||||
|
data = append(data, encodeVarInt(0xdeadbeefcafe13)...)
|
||||||
|
_, err := parseMaxStreamsFrame(bytes.NewReader(data), protocol.VersionWhatever)
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
for i := range data {
|
||||||
|
_, err := parseMaxStreamsFrame(bytes.NewReader(data[0:i]), protocol.VersionWhatever)
|
||||||
|
Expect(err).To(HaveOccurred())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("writing", func() {
|
||||||
|
It("for a bidirectional stream", func() {
|
||||||
|
f := &MaxStreamsFrame{
|
||||||
|
Type: protocol.StreamTypeBidi,
|
||||||
|
MaxStreams: 0xdeadbeef,
|
||||||
|
}
|
||||||
|
b := &bytes.Buffer{}
|
||||||
|
Expect(f.Write(b, protocol.VersionWhatever)).To(Succeed())
|
||||||
|
expected := []byte{0x1c}
|
||||||
|
expected = append(expected, encodeVarInt(0xdeadbeef)...)
|
||||||
|
Expect(b.Bytes()).To(Equal(expected))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("for a unidirectional stream", func() {
|
||||||
|
f := &MaxStreamsFrame{
|
||||||
|
Type: protocol.StreamTypeUni,
|
||||||
|
MaxStreams: 0xdecafbad,
|
||||||
|
}
|
||||||
|
b := &bytes.Buffer{}
|
||||||
|
Expect(f.Write(b, protocol.VersionWhatever)).To(Succeed())
|
||||||
|
expected := []byte{0x1d}
|
||||||
|
expected = append(expected, encodeVarInt(0xdecafbad)...)
|
||||||
|
Expect(b.Bytes()).To(Equal(expected))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("has the correct min length", func() {
|
||||||
|
frame := MaxStreamsFrame{MaxStreams: 0x1337}
|
||||||
|
Expect(frame.Length(protocol.VersionWhatever)).To(Equal(1 + utils.VarIntLen(0x1337)))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
|
@ -110,16 +110,16 @@ func (mr *MockStreamManagerMockRecorder) GetOrOpenSendStream(arg0 interface{}) *
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOrOpenSendStream", reflect.TypeOf((*MockStreamManager)(nil).GetOrOpenSendStream), arg0)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOrOpenSendStream", reflect.TypeOf((*MockStreamManager)(nil).GetOrOpenSendStream), arg0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// HandleMaxStreamIDFrame mocks base method
|
// HandleMaxStreamsFrame mocks base method
|
||||||
func (m *MockStreamManager) HandleMaxStreamIDFrame(arg0 *wire.MaxStreamIDFrame) error {
|
func (m *MockStreamManager) HandleMaxStreamsFrame(arg0 *wire.MaxStreamsFrame) error {
|
||||||
ret := m.ctrl.Call(m, "HandleMaxStreamIDFrame", arg0)
|
ret := m.ctrl.Call(m, "HandleMaxStreamsFrame", arg0)
|
||||||
ret0, _ := ret[0].(error)
|
ret0, _ := ret[0].(error)
|
||||||
return ret0
|
return ret0
|
||||||
}
|
}
|
||||||
|
|
||||||
// HandleMaxStreamIDFrame indicates an expected call of HandleMaxStreamIDFrame
|
// HandleMaxStreamsFrame indicates an expected call of HandleMaxStreamsFrame
|
||||||
func (mr *MockStreamManagerMockRecorder) HandleMaxStreamIDFrame(arg0 interface{}) *gomock.Call {
|
func (mr *MockStreamManagerMockRecorder) HandleMaxStreamsFrame(arg0 interface{}) *gomock.Call {
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleMaxStreamIDFrame", reflect.TypeOf((*MockStreamManager)(nil).HandleMaxStreamIDFrame), arg0)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleMaxStreamsFrame", reflect.TypeOf((*MockStreamManager)(nil).HandleMaxStreamsFrame), arg0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// OpenStream mocks base method
|
// OpenStream mocks base method
|
||||||
|
|
10
session.go
10
session.go
|
@ -40,7 +40,7 @@ type streamManager interface {
|
||||||
AcceptUniStream() (ReceiveStream, error)
|
AcceptUniStream() (ReceiveStream, error)
|
||||||
DeleteStream(protocol.StreamID) error
|
DeleteStream(protocol.StreamID) error
|
||||||
UpdateLimits(*handshake.TransportParameters)
|
UpdateLimits(*handshake.TransportParameters)
|
||||||
HandleMaxStreamIDFrame(*wire.MaxStreamIDFrame) error
|
HandleMaxStreamsFrame(*wire.MaxStreamsFrame) error
|
||||||
CloseWithError(error)
|
CloseWithError(error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -563,8 +563,8 @@ func (s *session) handleFrames(fs []wire.Frame, encLevel protocol.EncryptionLeve
|
||||||
s.handleMaxDataFrame(frame)
|
s.handleMaxDataFrame(frame)
|
||||||
case *wire.MaxStreamDataFrame:
|
case *wire.MaxStreamDataFrame:
|
||||||
err = s.handleMaxStreamDataFrame(frame)
|
err = s.handleMaxStreamDataFrame(frame)
|
||||||
case *wire.MaxStreamIDFrame:
|
case *wire.MaxStreamsFrame:
|
||||||
err = s.handleMaxStreamIDFrame(frame)
|
err = s.handleMaxStreamsFrame(frame)
|
||||||
case *wire.DataBlockedFrame:
|
case *wire.DataBlockedFrame:
|
||||||
case *wire.StreamDataBlockedFrame:
|
case *wire.StreamDataBlockedFrame:
|
||||||
case *wire.StreamIDBlockedFrame:
|
case *wire.StreamIDBlockedFrame:
|
||||||
|
@ -641,8 +641,8 @@ func (s *session) handleMaxStreamDataFrame(frame *wire.MaxStreamDataFrame) error
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *session) handleMaxStreamIDFrame(frame *wire.MaxStreamIDFrame) error {
|
func (s *session) handleMaxStreamsFrame(frame *wire.MaxStreamsFrame) error {
|
||||||
return s.streamsMap.HandleMaxStreamIDFrame(frame)
|
return s.streamsMap.HandleMaxStreamsFrame(frame)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *session) handleResetStreamFrame(frame *wire.ResetStreamFrame) error {
|
func (s *session) handleResetStreamFrame(frame *wire.ResetStreamFrame) error {
|
||||||
|
|
|
@ -249,17 +249,20 @@ var _ = Describe("Session", func() {
|
||||||
|
|
||||||
Context("handling MAX_STREAM_ID frames", func() {
|
Context("handling MAX_STREAM_ID frames", func() {
|
||||||
It("passes the frame to the streamsMap", func() {
|
It("passes the frame to the streamsMap", func() {
|
||||||
f := &wire.MaxStreamIDFrame{StreamID: 10}
|
f := &wire.MaxStreamsFrame{
|
||||||
streamManager.EXPECT().HandleMaxStreamIDFrame(f)
|
Type: protocol.StreamTypeUni,
|
||||||
err := sess.handleMaxStreamIDFrame(f)
|
MaxStreams: 10,
|
||||||
|
}
|
||||||
|
streamManager.EXPECT().HandleMaxStreamsFrame(f)
|
||||||
|
err := sess.handleMaxStreamsFrame(f)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
})
|
})
|
||||||
|
|
||||||
It("returns errors", func() {
|
It("returns errors", func() {
|
||||||
f := &wire.MaxStreamIDFrame{StreamID: 10}
|
f := &wire.MaxStreamsFrame{MaxStreams: 10}
|
||||||
testErr := errors.New("test error")
|
testErr := errors.New("test error")
|
||||||
streamManager.EXPECT().HandleMaxStreamIDFrame(f).Return(testErr)
|
streamManager.EXPECT().HandleMaxStreamsFrame(f).Return(testErr)
|
||||||
err := sess.handleMaxStreamIDFrame(f)
|
err := sess.handleMaxStreamsFrame(f)
|
||||||
Expect(err).To(MatchError(testErr))
|
Expect(err).To(MatchError(testErr))
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
|
@ -146,15 +146,13 @@ func (m *streamsMap) GetOrOpenSendStream(id protocol.StreamID) (sendStreamI, err
|
||||||
panic("")
|
panic("")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *streamsMap) HandleMaxStreamIDFrame(f *wire.MaxStreamIDFrame) error {
|
func (m *streamsMap) HandleMaxStreamsFrame(f *wire.MaxStreamsFrame) error {
|
||||||
id := f.StreamID
|
id := protocol.MaxStreamID(f.Type, f.MaxStreams, m.perspective)
|
||||||
if id.InitiatedBy() != m.perspective {
|
|
||||||
return fmt.Errorf("received MAX_STREAM_DATA frame for incoming stream %d", id)
|
|
||||||
}
|
|
||||||
switch id.Type() {
|
switch id.Type() {
|
||||||
case protocol.StreamTypeUni:
|
case protocol.StreamTypeUni:
|
||||||
m.outgoingUniStreams.SetMaxStream(id)
|
m.outgoingUniStreams.SetMaxStream(id)
|
||||||
case protocol.StreamTypeBidi:
|
case protocol.StreamTypeBidi:
|
||||||
|
fmt.Printf("")
|
||||||
m.outgoingBidiStreams.SetMaxStream(id)
|
m.outgoingBidiStreams.SetMaxStream(id)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|
|
@ -1,6 +1,10 @@
|
||||||
package quic
|
package quic
|
||||||
|
|
||||||
import "github.com/cheekybits/genny/generic"
|
import (
|
||||||
|
"github.com/cheekybits/genny/generic"
|
||||||
|
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
|
)
|
||||||
|
|
||||||
// In the auto-generated streams maps, we need to be able to close the streams.
|
// In the auto-generated streams maps, we need to be able to close the streams.
|
||||||
// Therefore, extend the generic.Type with the stream close method.
|
// Therefore, extend the generic.Type with the stream close method.
|
||||||
|
@ -9,3 +13,5 @@ type item interface {
|
||||||
generic.Type
|
generic.Type
|
||||||
closeForShutdown(error)
|
closeForShutdown(error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const streamTypeGeneric protocol.StreamType = protocol.StreamTypeUni
|
||||||
|
|
|
@ -24,7 +24,7 @@ type incomingBidiStreamsMap struct {
|
||||||
maxNumStreams uint64 // maximum number of streams
|
maxNumStreams uint64 // maximum number of streams
|
||||||
|
|
||||||
newStream func(protocol.StreamID) streamI
|
newStream func(protocol.StreamID) streamI
|
||||||
queueMaxStreamID func(*wire.MaxStreamIDFrame)
|
queueMaxStreamID func(*wire.MaxStreamsFrame)
|
||||||
|
|
||||||
closeErr error
|
closeErr error
|
||||||
}
|
}
|
||||||
|
@ -43,7 +43,7 @@ func newIncomingBidiStreamsMap(
|
||||||
maxStream: initialMaxStreamID,
|
maxStream: initialMaxStreamID,
|
||||||
maxNumStreams: maxNumStreams,
|
maxNumStreams: maxNumStreams,
|
||||||
newStream: newStream,
|
newStream: newStream,
|
||||||
queueMaxStreamID: func(f *wire.MaxStreamIDFrame) { queueControlFrame(f) },
|
queueMaxStreamID: func(f *wire.MaxStreamsFrame) { queueControlFrame(f) },
|
||||||
}
|
}
|
||||||
m.cond.L = &m.mutex
|
m.cond.L = &m.mutex
|
||||||
return m
|
return m
|
||||||
|
@ -111,7 +111,10 @@ func (m *incomingBidiStreamsMap) DeleteStream(id protocol.StreamID) error {
|
||||||
if m.maxNumStreams > uint64(len(m.streams)) {
|
if m.maxNumStreams > uint64(len(m.streams)) {
|
||||||
numNewStreams := m.maxNumStreams - uint64(len(m.streams))
|
numNewStreams := m.maxNumStreams - uint64(len(m.streams))
|
||||||
m.maxStream = m.nextStreamToOpen + protocol.StreamID((numNewStreams-1)*4)
|
m.maxStream = m.nextStreamToOpen + protocol.StreamID((numNewStreams-1)*4)
|
||||||
m.queueMaxStreamID(&wire.MaxStreamIDFrame{StreamID: m.maxStream})
|
m.queueMaxStreamID(&wire.MaxStreamsFrame{
|
||||||
|
Type: protocol.StreamTypeBidi,
|
||||||
|
MaxStreams: m.maxStream.StreamNum(),
|
||||||
|
})
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -8,8 +8,8 @@ import (
|
||||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||||
)
|
)
|
||||||
|
|
||||||
//go:generate genny -in $GOFILE -out streams_map_incoming_bidi.go gen "item=streamI Item=BidiStream"
|
//go:generate genny -in $GOFILE -out streams_map_incoming_bidi.go gen "item=streamI Item=BidiStream streamTypeGeneric=protocol.StreamTypeBidi"
|
||||||
//go:generate genny -in $GOFILE -out streams_map_incoming_uni.go gen "item=receiveStreamI Item=UniStream"
|
//go:generate genny -in $GOFILE -out streams_map_incoming_uni.go gen "item=receiveStreamI Item=UniStream streamTypeGeneric=protocol.StreamTypeUni"
|
||||||
type incomingItemsMap struct {
|
type incomingItemsMap struct {
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
cond sync.Cond
|
cond sync.Cond
|
||||||
|
@ -22,7 +22,7 @@ type incomingItemsMap struct {
|
||||||
maxNumStreams uint64 // maximum number of streams
|
maxNumStreams uint64 // maximum number of streams
|
||||||
|
|
||||||
newStream func(protocol.StreamID) item
|
newStream func(protocol.StreamID) item
|
||||||
queueMaxStreamID func(*wire.MaxStreamIDFrame)
|
queueMaxStreamID func(*wire.MaxStreamsFrame)
|
||||||
|
|
||||||
closeErr error
|
closeErr error
|
||||||
}
|
}
|
||||||
|
@ -41,7 +41,7 @@ func newIncomingItemsMap(
|
||||||
maxStream: initialMaxStreamID,
|
maxStream: initialMaxStreamID,
|
||||||
maxNumStreams: maxNumStreams,
|
maxNumStreams: maxNumStreams,
|
||||||
newStream: newStream,
|
newStream: newStream,
|
||||||
queueMaxStreamID: func(f *wire.MaxStreamIDFrame) { queueControlFrame(f) },
|
queueMaxStreamID: func(f *wire.MaxStreamsFrame) { queueControlFrame(f) },
|
||||||
}
|
}
|
||||||
m.cond.L = &m.mutex
|
m.cond.L = &m.mutex
|
||||||
return m
|
return m
|
||||||
|
@ -109,7 +109,10 @@ func (m *incomingItemsMap) DeleteStream(id protocol.StreamID) error {
|
||||||
if m.maxNumStreams > uint64(len(m.streams)) {
|
if m.maxNumStreams > uint64(len(m.streams)) {
|
||||||
numNewStreams := m.maxNumStreams - uint64(len(m.streams))
|
numNewStreams := m.maxNumStreams - uint64(len(m.streams))
|
||||||
m.maxStream = m.nextStreamToOpen + protocol.StreamID((numNewStreams-1)*4)
|
m.maxStream = m.nextStreamToOpen + protocol.StreamID((numNewStreams-1)*4)
|
||||||
m.queueMaxStreamID(&wire.MaxStreamIDFrame{StreamID: m.maxStream})
|
m.queueMaxStreamID(&wire.MaxStreamsFrame{
|
||||||
|
Type: streamTypeGeneric,
|
||||||
|
MaxStreams: m.maxStream.StreamNum(),
|
||||||
|
})
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -26,8 +26,8 @@ func (s *mockGenericStream) closeForShutdown(err error) {
|
||||||
|
|
||||||
var _ = Describe("Streams Map (incoming)", func() {
|
var _ = Describe("Streams Map (incoming)", func() {
|
||||||
const (
|
const (
|
||||||
firstNewStream protocol.StreamID = 20
|
firstNewStream protocol.StreamID = 2
|
||||||
maxNumStreams uint64 = 10
|
maxNumStreams uint64 = 5
|
||||||
initialMaxStream protocol.StreamID = firstNewStream + 4*protocol.StreamID(maxNumStreams-1)
|
initialMaxStream protocol.StreamID = firstNewStream + 4*protocol.StreamID(maxNumStreams-1)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -49,9 +49,9 @@ var _ = Describe("Streams Map (incoming)", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
It("opens all streams up to the id on GetOrOpenStream", func() {
|
It("opens all streams up to the id on GetOrOpenStream", func() {
|
||||||
_, err := m.GetOrOpenStream(firstNewStream + 4*5)
|
_, err := m.GetOrOpenStream(firstNewStream + 4*4)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(newItemCounter).To(Equal(6))
|
Expect(newItemCounter).To(Equal(5))
|
||||||
})
|
})
|
||||||
|
|
||||||
It("starts opening streams at the right position", func() {
|
It("starts opening streams at the right position", func() {
|
||||||
|
@ -59,9 +59,9 @@ var _ = Describe("Streams Map (incoming)", func() {
|
||||||
_, err := m.GetOrOpenStream(firstNewStream + 4)
|
_, err := m.GetOrOpenStream(firstNewStream + 4)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(newItemCounter).To(Equal(2))
|
Expect(newItemCounter).To(Equal(2))
|
||||||
_, err = m.GetOrOpenStream(firstNewStream + 4*5)
|
_, err = m.GetOrOpenStream(firstNewStream + 4*4)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(newItemCounter).To(Equal(6))
|
Expect(newItemCounter).To(Equal(5))
|
||||||
})
|
})
|
||||||
|
|
||||||
It("accepts streams in the right order", func() {
|
It("accepts streams in the right order", func() {
|
||||||
|
@ -143,9 +143,9 @@ var _ = Describe("Streams Map (incoming)", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
It("closes all streams when CloseWithError is called", func() {
|
It("closes all streams when CloseWithError is called", func() {
|
||||||
str1, err := m.GetOrOpenStream(20)
|
str1, err := m.GetOrOpenStream(firstNewStream)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
str2, err := m.GetOrOpenStream(20 + 8)
|
str2, err := m.GetOrOpenStream(firstNewStream + 8)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
testErr := errors.New("test err")
|
testErr := errors.New("test err")
|
||||||
m.CloseWithError(testErr)
|
m.CloseWithError(testErr)
|
||||||
|
@ -157,11 +157,11 @@ var _ = Describe("Streams Map (incoming)", func() {
|
||||||
|
|
||||||
It("deletes streams", func() {
|
It("deletes streams", func() {
|
||||||
mockSender.EXPECT().queueControlFrame(gomock.Any())
|
mockSender.EXPECT().queueControlFrame(gomock.Any())
|
||||||
_, err := m.GetOrOpenStream(20)
|
_, err := m.GetOrOpenStream(initialMaxStream)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
err = m.DeleteStream(20)
|
err = m.DeleteStream(initialMaxStream)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
str, err := m.GetOrOpenStream(20)
|
str, err := m.GetOrOpenStream(initialMaxStream)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(str).To(BeNil())
|
Expect(str).To(BeNil())
|
||||||
})
|
})
|
||||||
|
@ -171,13 +171,17 @@ var _ = Describe("Streams Map (incoming)", func() {
|
||||||
Expect(err).To(MatchError("Tried to delete unknown stream 1337"))
|
Expect(err).To(MatchError("Tried to delete unknown stream 1337"))
|
||||||
})
|
})
|
||||||
|
|
||||||
It("sends MAX_STREAM_ID frames when streams are deleted", func() {
|
It("sends MAX_STREAMS frames when streams are deleted", func() {
|
||||||
// open a bunch of streams
|
// open a bunch of streams
|
||||||
_, err := m.GetOrOpenStream(firstNewStream + 4*4)
|
_, err := m.GetOrOpenStream(firstNewStream + 4*4)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
mockSender.EXPECT().queueControlFrame(&wire.MaxStreamIDFrame{StreamID: initialMaxStream + 4})
|
mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) {
|
||||||
Expect(m.DeleteStream(firstNewStream + 4)).To(Succeed())
|
Expect(f.(*wire.MaxStreamsFrame).MaxStreams).To(Equal(maxNumStreams + 1))
|
||||||
mockSender.EXPECT().queueControlFrame(&wire.MaxStreamIDFrame{StreamID: initialMaxStream + 8})
|
})
|
||||||
|
Expect(m.DeleteStream(firstNewStream + 2*4)).To(Succeed())
|
||||||
|
mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) {
|
||||||
|
Expect(f.(*wire.MaxStreamsFrame).MaxStreams).To(Equal(maxNumStreams + 2))
|
||||||
|
})
|
||||||
Expect(m.DeleteStream(firstNewStream + 3*4)).To(Succeed())
|
Expect(m.DeleteStream(firstNewStream + 3*4)).To(Succeed())
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
|
@ -24,7 +24,7 @@ type incomingUniStreamsMap struct {
|
||||||
maxNumStreams uint64 // maximum number of streams
|
maxNumStreams uint64 // maximum number of streams
|
||||||
|
|
||||||
newStream func(protocol.StreamID) receiveStreamI
|
newStream func(protocol.StreamID) receiveStreamI
|
||||||
queueMaxStreamID func(*wire.MaxStreamIDFrame)
|
queueMaxStreamID func(*wire.MaxStreamsFrame)
|
||||||
|
|
||||||
closeErr error
|
closeErr error
|
||||||
}
|
}
|
||||||
|
@ -43,7 +43,7 @@ func newIncomingUniStreamsMap(
|
||||||
maxStream: initialMaxStreamID,
|
maxStream: initialMaxStreamID,
|
||||||
maxNumStreams: maxNumStreams,
|
maxNumStreams: maxNumStreams,
|
||||||
newStream: newStream,
|
newStream: newStream,
|
||||||
queueMaxStreamID: func(f *wire.MaxStreamIDFrame) { queueControlFrame(f) },
|
queueMaxStreamID: func(f *wire.MaxStreamsFrame) { queueControlFrame(f) },
|
||||||
}
|
}
|
||||||
m.cond.L = &m.mutex
|
m.cond.L = &m.mutex
|
||||||
return m
|
return m
|
||||||
|
@ -111,7 +111,10 @@ func (m *incomingUniStreamsMap) DeleteStream(id protocol.StreamID) error {
|
||||||
if m.maxNumStreams > uint64(len(m.streams)) {
|
if m.maxNumStreams > uint64(len(m.streams)) {
|
||||||
numNewStreams := m.maxNumStreams - uint64(len(m.streams))
|
numNewStreams := m.maxNumStreams - uint64(len(m.streams))
|
||||||
m.maxStream = m.nextStreamToOpen + protocol.StreamID((numNewStreams-1)*4)
|
m.maxStream = m.nextStreamToOpen + protocol.StreamID((numNewStreams-1)*4)
|
||||||
m.queueMaxStreamID(&wire.MaxStreamIDFrame{StreamID: m.maxStream})
|
m.queueMaxStreamID(&wire.MaxStreamsFrame{
|
||||||
|
Type: protocol.StreamTypeUni,
|
||||||
|
MaxStreams: m.maxStream.StreamNum(),
|
||||||
|
})
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -292,7 +292,7 @@ var _ = Describe("Streams Map", func() {
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
Context("handling MAX_STREAM_ID frames", func() {
|
Context("handling MAX_STREAMS frames", func() {
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
mockSender.EXPECT().queueControlFrame(gomock.Any()).AnyTimes()
|
mockSender.EXPECT().queueControlFrame(gomock.Any()).AnyTimes()
|
||||||
})
|
})
|
||||||
|
@ -300,49 +300,49 @@ var _ = Describe("Streams Map", func() {
|
||||||
It("processes IDs for outgoing bidirectional streams", func() {
|
It("processes IDs for outgoing bidirectional streams", func() {
|
||||||
_, err := m.OpenStream()
|
_, err := m.OpenStream()
|
||||||
Expect(err).To(MatchError(qerr.TooManyOpenStreams))
|
Expect(err).To(MatchError(qerr.TooManyOpenStreams))
|
||||||
err = m.HandleMaxStreamIDFrame(&wire.MaxStreamIDFrame{StreamID: ids.firstOutgoingBidiStream})
|
Expect(m.HandleMaxStreamsFrame(&wire.MaxStreamsFrame{
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Type: protocol.StreamTypeBidi,
|
||||||
|
MaxStreams: 1,
|
||||||
|
})).To(Succeed())
|
||||||
str, err := m.OpenStream()
|
str, err := m.OpenStream()
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(str.StreamID()).To(Equal(ids.firstOutgoingBidiStream))
|
Expect(str.StreamID()).To(Equal(ids.firstOutgoingBidiStream))
|
||||||
|
_, err = m.OpenStream()
|
||||||
|
Expect(err).To(MatchError(qerr.TooManyOpenStreams))
|
||||||
})
|
})
|
||||||
|
|
||||||
It("processes IDs for outgoing bidirectional streams", func() {
|
It("processes IDs for outgoing unidirectional streams", func() {
|
||||||
_, err := m.OpenUniStream()
|
_, err := m.OpenUniStream()
|
||||||
Expect(err).To(MatchError(qerr.TooManyOpenStreams))
|
Expect(err).To(MatchError(qerr.TooManyOpenStreams))
|
||||||
err = m.HandleMaxStreamIDFrame(&wire.MaxStreamIDFrame{StreamID: ids.firstOutgoingUniStream})
|
Expect(m.HandleMaxStreamsFrame(&wire.MaxStreamsFrame{
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Type: protocol.StreamTypeUni,
|
||||||
|
MaxStreams: 1,
|
||||||
|
})).To(Succeed())
|
||||||
str, err := m.OpenUniStream()
|
str, err := m.OpenUniStream()
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(str.StreamID()).To(Equal(ids.firstOutgoingUniStream))
|
Expect(str.StreamID()).To(Equal(ids.firstOutgoingUniStream))
|
||||||
})
|
_, err = m.OpenUniStream()
|
||||||
|
Expect(err).To(MatchError(qerr.TooManyOpenStreams))
|
||||||
It("rejects IDs for incoming bidirectional streams", func() {
|
|
||||||
err := m.HandleMaxStreamIDFrame(&wire.MaxStreamIDFrame{StreamID: ids.firstIncomingBidiStream})
|
|
||||||
Expect(err).To(MatchError(fmt.Sprintf("received MAX_STREAM_DATA frame for incoming stream %d", ids.firstIncomingBidiStream)))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("rejects IDs for incoming unidirectional streams", func() {
|
|
||||||
err := m.HandleMaxStreamIDFrame(&wire.MaxStreamIDFrame{StreamID: ids.firstIncomingUniStream})
|
|
||||||
Expect(err).To(MatchError(fmt.Sprintf("received MAX_STREAM_DATA frame for incoming stream %d", ids.firstIncomingUniStream)))
|
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
Context("sending MAX_STREAM_ID frames", func() {
|
Context("sending MAX_STREAMS frames", func() {
|
||||||
It("sends MAX_STREAM_ID frames for bidirectional streams", func() {
|
It("sends a MAX_STREAMS frame for bidirectional streams", func() {
|
||||||
_, err := m.GetOrOpenReceiveStream(ids.firstIncomingBidiStream)
|
_, err := m.GetOrOpenReceiveStream(ids.firstIncomingBidiStream)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
mockSender.EXPECT().queueControlFrame(&wire.MaxStreamIDFrame{
|
mockSender.EXPECT().queueControlFrame(&wire.MaxStreamsFrame{
|
||||||
StreamID: ids.firstIncomingBidiStream + 4*maxBidiStreams,
|
Type: protocol.StreamTypeBidi,
|
||||||
|
MaxStreams: maxBidiStreams + 1,
|
||||||
})
|
})
|
||||||
Expect(m.DeleteStream(ids.firstIncomingBidiStream)).To(Succeed())
|
Expect(m.DeleteStream(ids.firstIncomingBidiStream)).To(Succeed())
|
||||||
})
|
})
|
||||||
|
|
||||||
It("sends MAX_STREAM_ID frames for unidirectional streams", func() {
|
It("sends a MAX_STREAMS frame for unidirectional streams", func() {
|
||||||
_, err := m.GetOrOpenReceiveStream(ids.firstIncomingUniStream)
|
_, err := m.GetOrOpenReceiveStream(ids.firstIncomingUniStream)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
mockSender.EXPECT().queueControlFrame(&wire.MaxStreamIDFrame{
|
mockSender.EXPECT().queueControlFrame(&wire.MaxStreamsFrame{
|
||||||
StreamID: ids.firstIncomingUniStream + 4*maxUniStreams,
|
Type: protocol.StreamTypeUni,
|
||||||
|
MaxStreams: maxUniStreams + 1,
|
||||||
})
|
})
|
||||||
Expect(m.DeleteStream(ids.firstIncomingUniStream)).To(Succeed())
|
Expect(m.DeleteStream(ids.firstIncomingUniStream)).To(Succeed())
|
||||||
})
|
})
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue