Merge pull request #1589 from lucas-clemente/fix-1563

use stream counts
This commit is contained in:
Marten Seemann 2018-11-12 15:46:44 +07:00 committed by GitHub
commit a0adcd71f6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
37 changed files with 711 additions and 546 deletions

View file

@ -56,7 +56,7 @@ var _ = Describe("Stream Framer", func() {
It("adds the right number of frames", func() {
maxSize := protocol.ByteCount(1000)
bf := &wire.BlockedFrame{Offset: 0x1337}
bf := &wire.DataBlockedFrame{DataLimit: 0x1337}
bfLen := bf.Length(version)
numFrames := int(maxSize / bfLen) // max number of frames that fit into maxSize
for i := 0; i < numFrames+1; i++ {

View file

@ -11,7 +11,7 @@ import (
var _ = Describe("retransmittable frames", func() {
for fl, el := range map[wire.Frame]bool{
&wire.AckFrame{}: false,
&wire.BlockedFrame{}: true,
&wire.DataBlockedFrame{}: true,
&wire.ConnectionCloseFrame{}: true,
&wire.PingFrame{}: true,
&wire.ResetStreamFrame{}: true,

View file

@ -29,32 +29,39 @@ func (s StreamID) Type() StreamType {
return StreamTypeBidi
}
// MaxBidiStreamID is the highest stream ID that the peer is allowed to open,
// when it is allowed to open numStreams bidirectional streams.
func MaxBidiStreamID(numStreams int, pers Perspective) StreamID {
// StreamNum returns how many streams in total are below this
// Example: for stream 9 it returns 3 (i.e. streams 1, 5 and 9)
func (s StreamID) StreamNum() uint64 {
return uint64(s/4) + 1
}
// MaxStreamID is the highest stream ID that a peer is allowed to open,
// when it is allowed to open numStreams.
func MaxStreamID(stype StreamType, numStreams uint64, pers Perspective) StreamID {
if numStreams == 0 {
return 0
}
var first StreamID
if pers == PerspectiveClient {
first = 1
} else {
first = 0
switch stype {
case StreamTypeBidi:
switch pers {
case PerspectiveClient:
first = 0
case PerspectiveServer:
first = 1
}
case StreamTypeUni:
switch pers {
case PerspectiveClient:
first = 2
case PerspectiveServer:
first = 3
}
}
return first + 4*StreamID(numStreams-1)
}
// MaxUniStreamID is the highest stream ID that the peer is allowed to open,
// when it is allowed to open numStreams unidirectional streams.
func MaxUniStreamID(numStreams int, pers Perspective) StreamID {
if numStreams == 0 {
return 0
}
var first StreamID
if pers == PerspectiveClient {
first = 3
} else {
first = 2
}
return first + 4*StreamID(numStreams-1)
// FirstStream returns the first valid stream ID
func FirstStream(stype StreamType, pers Perspective) StreamID {
return MaxStreamID(stype, 1, pers)
}

View file

@ -20,39 +20,44 @@ var _ = Describe("Stream ID", func() {
Expect(StreamID(7).Type()).To(Equal(StreamTypeUni))
})
It("tells the first stream ID", func() {
Expect(FirstStream(StreamTypeBidi, PerspectiveClient)).To(Equal(StreamID(0)))
Expect(FirstStream(StreamTypeBidi, PerspectiveServer)).To(Equal(StreamID(1)))
Expect(FirstStream(StreamTypeUni, PerspectiveClient)).To(Equal(StreamID(2)))
Expect(FirstStream(StreamTypeUni, PerspectiveServer)).To(Equal(StreamID(3)))
})
It("tells the stream number", func() {
Expect(StreamID(0).StreamNum()).To(BeEquivalentTo(1))
Expect(StreamID(1).StreamNum()).To(BeEquivalentTo(1))
Expect(StreamID(2).StreamNum()).To(BeEquivalentTo(1))
Expect(StreamID(3).StreamNum()).To(BeEquivalentTo(1))
Expect(StreamID(8).StreamNum()).To(BeEquivalentTo(3))
Expect(StreamID(9).StreamNum()).To(BeEquivalentTo(3))
Expect(StreamID(10).StreamNum()).To(BeEquivalentTo(3))
Expect(StreamID(11).StreamNum()).To(BeEquivalentTo(3))
})
Context("maximum stream IDs", func() {
Context("bidirectional streams", func() {
It("doesn't allow any", func() {
Expect(MaxBidiStreamID(0, PerspectiveClient)).To(Equal(StreamID(0)))
Expect(MaxBidiStreamID(0, PerspectiveServer)).To(Equal(StreamID(0)))
})
It("allows one", func() {
Expect(MaxBidiStreamID(1, PerspectiveClient)).To(Equal(StreamID(1)))
Expect(MaxBidiStreamID(1, PerspectiveServer)).To(Equal(StreamID(0)))
})
It("allows many", func() {
Expect(MaxBidiStreamID(100, PerspectiveClient)).To(Equal(StreamID(397)))
Expect(MaxBidiStreamID(100, PerspectiveServer)).To(Equal(StreamID(396)))
})
It("doesn't allow any", func() {
Expect(MaxStreamID(StreamTypeBidi, 0, PerspectiveClient)).To(Equal(StreamID(0)))
Expect(MaxStreamID(StreamTypeBidi, 0, PerspectiveServer)).To(Equal(StreamID(0)))
Expect(MaxStreamID(StreamTypeUni, 0, PerspectiveClient)).To(Equal(StreamID(0)))
Expect(MaxStreamID(StreamTypeUni, 0, PerspectiveServer)).To(Equal(StreamID(0)))
})
Context("unidirectional streams", func() {
It("doesn't allow any", func() {
Expect(MaxUniStreamID(0, PerspectiveClient)).To(Equal(StreamID(0)))
Expect(MaxUniStreamID(0, PerspectiveServer)).To(Equal(StreamID(0)))
})
It("allows one", func() {
Expect(MaxStreamID(StreamTypeBidi, 1, PerspectiveClient)).To(Equal(StreamID(0)))
Expect(MaxStreamID(StreamTypeBidi, 1, PerspectiveServer)).To(Equal(StreamID(1)))
Expect(MaxStreamID(StreamTypeUni, 1, PerspectiveClient)).To(Equal(StreamID(2)))
Expect(MaxStreamID(StreamTypeUni, 1, PerspectiveServer)).To(Equal(StreamID(3)))
})
It("allows one", func() {
Expect(MaxUniStreamID(1, PerspectiveClient)).To(Equal(StreamID(3)))
Expect(MaxUniStreamID(1, PerspectiveServer)).To(Equal(StreamID(2)))
})
It("allows many", func() {
Expect(MaxUniStreamID(100, PerspectiveClient)).To(Equal(StreamID(399)))
Expect(MaxUniStreamID(100, PerspectiveServer)).To(Equal(StreamID(398)))
})
It("allows many", func() {
Expect(MaxStreamID(StreamTypeBidi, 100, PerspectiveClient)).To(Equal(StreamID(396)))
Expect(MaxStreamID(StreamTypeBidi, 100, PerspectiveServer)).To(Equal(StreamID(397)))
Expect(MaxStreamID(StreamTypeUni, 100, PerspectiveClient)).To(Equal(StreamID(398)))
Expect(MaxStreamID(StreamTypeUni, 100, PerspectiveServer)).To(Equal(StreamID(399)))
})
})
})

View file

@ -1,39 +0,0 @@
package wire
import (
"bytes"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
)
// A BlockedFrame is a BLOCKED frame
type BlockedFrame struct {
Offset protocol.ByteCount
}
// parseBlockedFrame parses a BLOCKED frame
func parseBlockedFrame(r *bytes.Reader, _ protocol.VersionNumber) (*BlockedFrame, error) {
if _, err := r.ReadByte(); err != nil {
return nil, err
}
offset, err := utils.ReadVarInt(r)
if err != nil {
return nil, err
}
return &BlockedFrame{
Offset: protocol.ByteCount(offset),
}, nil
}
func (f *BlockedFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error {
typeByte := uint8(0x08)
b.WriteByte(typeByte)
utils.WriteVarInt(b, uint64(f.Offset))
return nil
}
// Length of a written frame
func (f *BlockedFrame) Length(version protocol.VersionNumber) protocol.ByteCount {
return 1 + utils.VarIntLen(uint64(f.Offset))
}

View file

@ -0,0 +1,38 @@
package wire
import (
"bytes"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
)
// A DataBlockedFrame is a DATA_BLOCKED frame
type DataBlockedFrame struct {
DataLimit protocol.ByteCount
}
func parseDataBlockedFrame(r *bytes.Reader, _ protocol.VersionNumber) (*DataBlockedFrame, error) {
if _, err := r.ReadByte(); err != nil {
return nil, err
}
offset, err := utils.ReadVarInt(r)
if err != nil {
return nil, err
}
return &DataBlockedFrame{
DataLimit: protocol.ByteCount(offset),
}, nil
}
func (f *DataBlockedFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error {
typeByte := uint8(0x08)
b.WriteByte(typeByte)
utils.WriteVarInt(b, uint64(f.DataLimit))
return nil
}
// Length of a written frame
func (f *DataBlockedFrame) Length(version protocol.VersionNumber) protocol.ByteCount {
return 1 + utils.VarIntLen(uint64(f.DataLimit))
}

View file

@ -11,25 +11,25 @@ import (
. "github.com/onsi/gomega"
)
var _ = Describe("BLOCKED frame", func() {
var _ = Describe("DATA_BLOCKED frame", func() {
Context("when parsing", func() {
It("accepts sample frame", func() {
data := []byte{0x08}
data = append(data, encodeVarInt(0x12345678)...)
b := bytes.NewReader(data)
frame, err := parseBlockedFrame(b, versionIETFFrames)
frame, err := parseDataBlockedFrame(b, versionIETFFrames)
Expect(err).ToNot(HaveOccurred())
Expect(frame.Offset).To(Equal(protocol.ByteCount(0x12345678)))
Expect(frame.DataLimit).To(Equal(protocol.ByteCount(0x12345678)))
Expect(b.Len()).To(BeZero())
})
It("errors on EOFs", func() {
data := []byte{0x08}
data = append(data, encodeVarInt(0x12345678)...)
_, err := parseBlockedFrame(bytes.NewReader(data), versionIETFFrames)
_, err := parseDataBlockedFrame(bytes.NewReader(data), versionIETFFrames)
Expect(err).ToNot(HaveOccurred())
for i := range data {
_, err := parseBlockedFrame(bytes.NewReader(data[:i]), versionIETFFrames)
_, err := parseDataBlockedFrame(bytes.NewReader(data[:i]), versionIETFFrames)
Expect(err).To(MatchError(io.EOF))
}
})
@ -38,7 +38,7 @@ var _ = Describe("BLOCKED frame", func() {
Context("when writing", func() {
It("writes a sample frame", func() {
b := &bytes.Buffer{}
frame := BlockedFrame{Offset: 0xdeadbeef}
frame := DataBlockedFrame{DataLimit: 0xdeadbeef}
err := frame.Write(b, protocol.VersionWhatever)
Expect(err).ToNot(HaveOccurred())
expected := []byte{0x08}
@ -47,7 +47,7 @@ var _ = Describe("BLOCKED frame", func() {
})
It("has the correct min length", func() {
frame := BlockedFrame{Offset: 0x12345}
frame := DataBlockedFrame{DataLimit: 0x12345}
Expect(frame.Length(versionIETFFrames)).To(Equal(1 + utils.VarIntLen(0x12345)))
})
})

View file

@ -55,25 +55,20 @@ func parseFrame(r *bytes.Reader, typeByte byte, v protocol.VersionNumber) (Frame
if err != nil {
err = qerr.Error(qerr.InvalidWindowUpdateData, err.Error())
}
case 0x6:
frame, err = parseMaxStreamIDFrame(r, v)
if err != nil {
err = qerr.Error(qerr.InvalidFrameData, err.Error())
}
case 0x7:
frame, err = parsePingFrame(r, v)
case 0x8:
frame, err = parseBlockedFrame(r, v)
frame, err = parseDataBlockedFrame(r, v)
if err != nil {
err = qerr.Error(qerr.InvalidBlockedData, err.Error())
}
case 0x9:
frame, err = parseStreamBlockedFrame(r, v)
frame, err = parseStreamDataBlockedFrame(r, v)
if err != nil {
err = qerr.Error(qerr.InvalidBlockedData, err.Error())
}
case 0xa:
frame, err = parseStreamIDBlockedFrame(r, v)
case 0xa, 0xb:
frame, err = parseStreamsBlockedFrame(r, v)
if err != nil {
err = qerr.Error(qerr.InvalidFrameData, err.Error())
}
@ -97,6 +92,11 @@ func parseFrame(r *bytes.Reader, typeByte byte, v protocol.VersionNumber) (Frame
if err != nil {
err = qerr.Error(qerr.InvalidAckData, err.Error())
}
case 0x1c, 0x1d:
frame, err = parseMaxStreamsFrame(r, v)
if err != nil {
err = qerr.Error(qerr.InvalidFrameData, err.Error())
}
case 0x18:
frame, err = parseCryptoFrame(r, v)
if err != nil {

View file

@ -97,30 +97,10 @@ var _ = Describe("Frame parsing", func() {
Expect(frame).To(Equal(f))
})
It("unpacks MAX_STREAM_ID frames", func() {
f := &MaxStreamIDFrame{StreamID: 0x1337}
buf := &bytes.Buffer{}
err := f.Write(buf, versionIETFFrames)
Expect(err).ToNot(HaveOccurred())
frame, err := ParseNextFrame(bytes.NewReader(buf.Bytes()), versionIETFFrames)
Expect(err).ToNot(HaveOccurred())
Expect(frame).To(Equal(f))
})
It("unpacks connection-level BLOCKED frames", func() {
f := &BlockedFrame{Offset: 0x1234}
buf := &bytes.Buffer{}
err := f.Write(buf, versionIETFFrames)
Expect(err).ToNot(HaveOccurred())
frame, err := ParseNextFrame(bytes.NewReader(buf.Bytes()), versionIETFFrames)
Expect(err).ToNot(HaveOccurred())
Expect(frame).To(Equal(f))
})
It("unpacks stream-level BLOCKED frames", func() {
f := &StreamBlockedFrame{
StreamID: 0xdeadbeef,
Offset: 0xdead,
It("unpacks MAX_STREAMS frames", func() {
f := &MaxStreamsFrame{
Type: protocol.StreamTypeBidi,
MaxStreams: 0x1337,
}
buf := &bytes.Buffer{}
err := f.Write(buf, versionIETFFrames)
@ -130,8 +110,34 @@ var _ = Describe("Frame parsing", func() {
Expect(frame).To(Equal(f))
})
It("unpacks STREAM_ID_BLOCKED frames", func() {
f := &StreamIDBlockedFrame{StreamID: 0x1234567}
It("unpacks DATA_BLOCKED frames", func() {
f := &DataBlockedFrame{DataLimit: 0x1234}
buf := &bytes.Buffer{}
err := f.Write(buf, versionIETFFrames)
Expect(err).ToNot(HaveOccurred())
frame, err := ParseNextFrame(bytes.NewReader(buf.Bytes()), versionIETFFrames)
Expect(err).ToNot(HaveOccurred())
Expect(frame).To(Equal(f))
})
It("unpacks STREAM_DATA_BLOCKED frames", func() {
f := &StreamDataBlockedFrame{
StreamID: 0xdeadbeef,
DataLimit: 0xdead,
}
buf := &bytes.Buffer{}
err := f.Write(buf, versionIETFFrames)
Expect(err).ToNot(HaveOccurred())
frame, err := ParseNextFrame(bytes.NewReader(buf.Bytes()), versionIETFFrames)
Expect(err).ToNot(HaveOccurred())
Expect(frame).To(Equal(f))
})
It("unpacks STREAMS_BLOCKED frames", func() {
f := &StreamsBlockedFrame{
Type: protocol.StreamTypeBidi,
StreamLimit: 0x1234567,
}
buf := &bytes.Buffer{}
err := f.Write(buf, versionIETFFrames)
Expect(err).ToNot(HaveOccurred())

View file

@ -1,37 +0,0 @@
package wire
import (
"bytes"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
)
// A MaxStreamIDFrame is a MAX_STREAM_ID frame
type MaxStreamIDFrame struct {
StreamID protocol.StreamID
}
// parseMaxStreamIDFrame parses a MAX_STREAM_ID frame
func parseMaxStreamIDFrame(r *bytes.Reader, _ protocol.VersionNumber) (*MaxStreamIDFrame, error) {
// read the Type byte
if _, err := r.ReadByte(); err != nil {
return nil, err
}
streamID, err := utils.ReadVarInt(r)
if err != nil {
return nil, err
}
return &MaxStreamIDFrame{StreamID: protocol.StreamID(streamID)}, nil
}
func (f *MaxStreamIDFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error {
b.WriteByte(0x6)
utils.WriteVarInt(b, uint64(f.StreamID))
return nil
}
// Length of a written frame
func (f *MaxStreamIDFrame) Length(protocol.VersionNumber) protocol.ByteCount {
return 1 + utils.VarIntLen(uint64(f.StreamID))
}

View file

@ -1,51 +0,0 @@
package wire
import (
"bytes"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("MAX_STREAM_ID frame", func() {
Context("parsing", func() {
It("accepts sample frame", func() {
data := []byte{0x6}
data = append(data, encodeVarInt(0xdecafbad)...)
b := bytes.NewReader(data)
f, err := parseMaxStreamIDFrame(b, protocol.VersionWhatever)
Expect(err).ToNot(HaveOccurred())
Expect(f.StreamID).To(Equal(protocol.StreamID(0xdecafbad)))
Expect(b.Len()).To(BeZero())
})
It("errors on EOFs", func() {
data := []byte{0x06}
data = append(data, encodeVarInt(0xdeadbeefcafe13)...)
_, err := parseMaxStreamIDFrame(bytes.NewReader(data), protocol.VersionWhatever)
Expect(err).NotTo(HaveOccurred())
for i := range data {
_, err := parseMaxStreamIDFrame(bytes.NewReader(data[0:i]), protocol.VersionWhatever)
Expect(err).To(HaveOccurred())
}
})
})
Context("writing", func() {
It("writes a sample frame", func() {
b := &bytes.Buffer{}
frame := MaxStreamIDFrame{StreamID: 0x12345678}
frame.Write(b, protocol.VersionWhatever)
expected := []byte{0x6}
expected = append(expected, encodeVarInt(0x12345678)...)
Expect(b.Bytes()).To(Equal(expected))
})
It("has the correct min length", func() {
frame := MaxStreamIDFrame{StreamID: 0x1337}
Expect(frame.Length(protocol.VersionWhatever)).To(Equal(1 + utils.VarIntLen(0x1337)))
})
})
})

View file

@ -0,0 +1,51 @@
package wire
import (
"bytes"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
)
// A MaxStreamsFrame is a MAX_STREAMS frame
type MaxStreamsFrame struct {
Type protocol.StreamType
MaxStreams uint64
}
func parseMaxStreamsFrame(r *bytes.Reader, _ protocol.VersionNumber) (*MaxStreamsFrame, error) {
typeByte, err := r.ReadByte()
if err != nil {
return nil, err
}
f := &MaxStreamsFrame{}
switch typeByte {
case 0x1c:
f.Type = protocol.StreamTypeBidi
case 0x1d:
f.Type = protocol.StreamTypeUni
}
streamID, err := utils.ReadVarInt(r)
if err != nil {
return nil, err
}
f.MaxStreams = streamID
return f, nil
}
func (f *MaxStreamsFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error {
switch f.Type {
case protocol.StreamTypeBidi:
b.WriteByte(0x1c)
case protocol.StreamTypeUni:
b.WriteByte(0x1d)
}
utils.WriteVarInt(b, f.MaxStreams)
return nil
}
// Length of a written frame
func (f *MaxStreamsFrame) Length(protocol.VersionNumber) protocol.ByteCount {
return 1 + utils.VarIntLen(f.MaxStreams)
}

View file

@ -0,0 +1,78 @@
package wire
import (
"bytes"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("MAX_STREAMS frame", func() {
Context("parsing", func() {
It("accepts a frame for a bidirectional stream", func() {
data := []byte{0x1c}
data = append(data, encodeVarInt(0xdecaf)...)
b := bytes.NewReader(data)
f, err := parseMaxStreamsFrame(b, protocol.VersionWhatever)
Expect(err).ToNot(HaveOccurred())
Expect(f.Type).To(Equal(protocol.StreamTypeBidi))
Expect(f.MaxStreams).To(BeEquivalentTo(0xdecaf))
Expect(b.Len()).To(BeZero())
})
It("accepts a frame for a bidirectional stream", func() {
data := []byte{0x1d}
data = append(data, encodeVarInt(0xdecaf)...)
b := bytes.NewReader(data)
f, err := parseMaxStreamsFrame(b, protocol.VersionWhatever)
Expect(err).ToNot(HaveOccurred())
Expect(f.Type).To(Equal(protocol.StreamTypeUni))
Expect(f.MaxStreams).To(BeEquivalentTo(0xdecaf))
Expect(b.Len()).To(BeZero())
})
It("errors on EOFs", func() {
data := []byte{0x1d}
data = append(data, encodeVarInt(0xdeadbeefcafe13)...)
_, err := parseMaxStreamsFrame(bytes.NewReader(data), protocol.VersionWhatever)
Expect(err).NotTo(HaveOccurred())
for i := range data {
_, err := parseMaxStreamsFrame(bytes.NewReader(data[0:i]), protocol.VersionWhatever)
Expect(err).To(HaveOccurred())
}
})
})
Context("writing", func() {
It("for a bidirectional stream", func() {
f := &MaxStreamsFrame{
Type: protocol.StreamTypeBidi,
MaxStreams: 0xdeadbeef,
}
b := &bytes.Buffer{}
Expect(f.Write(b, protocol.VersionWhatever)).To(Succeed())
expected := []byte{0x1c}
expected = append(expected, encodeVarInt(0xdeadbeef)...)
Expect(b.Bytes()).To(Equal(expected))
})
It("for a unidirectional stream", func() {
f := &MaxStreamsFrame{
Type: protocol.StreamTypeUni,
MaxStreams: 0xdecafbad,
}
b := &bytes.Buffer{}
Expect(f.Write(b, protocol.VersionWhatever)).To(Succeed())
expected := []byte{0x1d}
expected = append(expected, encodeVarInt(0xdecafbad)...)
Expect(b.Bytes()).To(Equal(expected))
})
It("has the correct min length", func() {
frame := MaxStreamsFrame{MaxStreams: 0x1337}
Expect(frame.Length(protocol.VersionWhatever)).To(Equal(1 + utils.VarIntLen(0x1337)))
})
})
})

View file

@ -1,46 +0,0 @@
package wire
import (
"bytes"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
)
// A StreamBlockedFrame in QUIC
type StreamBlockedFrame struct {
StreamID protocol.StreamID
Offset protocol.ByteCount
}
// parseStreamBlockedFrame parses a STREAM_BLOCKED frame
func parseStreamBlockedFrame(r *bytes.Reader, _ protocol.VersionNumber) (*StreamBlockedFrame, error) {
if _, err := r.ReadByte(); err != nil { // read the TypeByte
return nil, err
}
sid, err := utils.ReadVarInt(r)
if err != nil {
return nil, err
}
offset, err := utils.ReadVarInt(r)
if err != nil {
return nil, err
}
return &StreamBlockedFrame{
StreamID: protocol.StreamID(sid),
Offset: protocol.ByteCount(offset),
}, nil
}
// Write writes a STREAM_BLOCKED frame
func (f *StreamBlockedFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error {
b.WriteByte(0x09)
utils.WriteVarInt(b, uint64(f.StreamID))
utils.WriteVarInt(b, uint64(f.Offset))
return nil
}
// Length of a written frame
func (f *StreamBlockedFrame) Length(version protocol.VersionNumber) protocol.ByteCount {
return 1 + utils.VarIntLen(uint64(f.StreamID)) + utils.VarIntLen(uint64(f.Offset))
}

View file

@ -0,0 +1,46 @@
package wire
import (
"bytes"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
)
// A StreamDataBlockedFrame is a STREAM_DATA_BLOCKED frame
type StreamDataBlockedFrame struct {
StreamID protocol.StreamID
DataLimit protocol.ByteCount
}
func parseStreamDataBlockedFrame(r *bytes.Reader, _ protocol.VersionNumber) (*StreamDataBlockedFrame, error) {
if _, err := r.ReadByte(); err != nil {
return nil, err
}
sid, err := utils.ReadVarInt(r)
if err != nil {
return nil, err
}
offset, err := utils.ReadVarInt(r)
if err != nil {
return nil, err
}
return &StreamDataBlockedFrame{
StreamID: protocol.StreamID(sid),
DataLimit: protocol.ByteCount(offset),
}, nil
}
func (f *StreamDataBlockedFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error {
b.WriteByte(0x09)
utils.WriteVarInt(b, uint64(f.StreamID))
utils.WriteVarInt(b, uint64(f.DataLimit))
return nil
}
// Length of a written frame
func (f *StreamDataBlockedFrame) Length(version protocol.VersionNumber) protocol.ByteCount {
return 1 + utils.VarIntLen(uint64(f.StreamID)) + utils.VarIntLen(uint64(f.DataLimit))
}

View file

@ -10,17 +10,17 @@ import (
. "github.com/onsi/gomega"
)
var _ = Describe("STREAM_BLOCKED frame", func() {
var _ = Describe("STREAM_DATA_BLOCKED frame", func() {
Context("parsing", func() {
It("accepts sample frame", func() {
data := []byte{0x9}
data = append(data, encodeVarInt(0xdeadbeef)...) // stream ID
data = append(data, encodeVarInt(0xdecafbad)...) // offset
b := bytes.NewReader(data)
frame, err := parseStreamBlockedFrame(b, versionIETFFrames)
frame, err := parseStreamDataBlockedFrame(b, versionIETFFrames)
Expect(err).ToNot(HaveOccurred())
Expect(frame.StreamID).To(Equal(protocol.StreamID(0xdeadbeef)))
Expect(frame.Offset).To(Equal(protocol.ByteCount(0xdecafbad)))
Expect(frame.DataLimit).To(Equal(protocol.ByteCount(0xdecafbad)))
Expect(b.Len()).To(BeZero())
})
@ -28,10 +28,10 @@ var _ = Describe("STREAM_BLOCKED frame", func() {
data := []byte{0x9}
data = append(data, encodeVarInt(0xdeadbeef)...)
data = append(data, encodeVarInt(0xc0010ff)...)
_, err := parseStreamBlockedFrame(bytes.NewReader(data), versionIETFFrames)
_, err := parseStreamDataBlockedFrame(bytes.NewReader(data), versionIETFFrames)
Expect(err).NotTo(HaveOccurred())
for i := range data {
_, err := parseStreamBlockedFrame(bytes.NewReader(data[0:i]), versionIETFFrames)
_, err := parseStreamDataBlockedFrame(bytes.NewReader(data[0:i]), versionIETFFrames)
Expect(err).To(HaveOccurred())
}
})
@ -39,24 +39,24 @@ var _ = Describe("STREAM_BLOCKED frame", func() {
Context("writing", func() {
It("has proper min length", func() {
f := &StreamBlockedFrame{
StreamID: 0x1337,
Offset: 0xdeadbeef,
f := &StreamDataBlockedFrame{
StreamID: 0x1337,
DataLimit: 0xdeadbeef,
}
Expect(f.Length(0)).To(Equal(1 + utils.VarIntLen(0x1337) + utils.VarIntLen(0xdeadbeef)))
})
It("writes a sample frame", func() {
b := &bytes.Buffer{}
f := &StreamBlockedFrame{
StreamID: 0xdecafbad,
Offset: 0x1337,
f := &StreamDataBlockedFrame{
StreamID: 0xdecafbad,
DataLimit: 0x1337,
}
err := f.Write(b, versionIETFFrames)
Expect(err).ToNot(HaveOccurred())
expected := []byte{0x9}
expected = append(expected, encodeVarInt(uint64(f.StreamID))...)
expected = append(expected, encodeVarInt(uint64(f.Offset))...)
expected = append(expected, encodeVarInt(uint64(f.DataLimit))...)
Expect(b.Bytes()).To(Equal(expected))
})
})

View file

@ -1,37 +0,0 @@
package wire
import (
"bytes"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
)
// A StreamIDBlockedFrame is a STREAM_ID_BLOCKED frame
type StreamIDBlockedFrame struct {
StreamID protocol.StreamID
}
// parseStreamIDBlockedFrame parses a STREAM_ID_BLOCKED frame
func parseStreamIDBlockedFrame(r *bytes.Reader, _ protocol.VersionNumber) (*StreamIDBlockedFrame, error) {
if _, err := r.ReadByte(); err != nil {
return nil, err
}
streamID, err := utils.ReadVarInt(r)
if err != nil {
return nil, err
}
return &StreamIDBlockedFrame{StreamID: protocol.StreamID(streamID)}, nil
}
func (f *StreamIDBlockedFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error {
typeByte := uint8(0x0a)
b.WriteByte(typeByte)
utils.WriteVarInt(b, uint64(f.StreamID))
return nil
}
// Length of a written frame
func (f *StreamIDBlockedFrame) Length(_ protocol.VersionNumber) protocol.ByteCount {
return 1 + utils.VarIntLen(uint64(f.StreamID))
}

View file

@ -1,53 +0,0 @@
package wire
import (
"bytes"
"io"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("STREAM_ID_BLOCKED frame", func() {
Context("parsing", func() {
It("accepts sample frame", func() {
expected := []byte{0xa}
expected = append(expected, encodeVarInt(0xdecafbad)...)
b := bytes.NewReader(expected)
frame, err := parseStreamIDBlockedFrame(b, protocol.VersionWhatever)
Expect(err).ToNot(HaveOccurred())
Expect(frame.StreamID).To(Equal(protocol.StreamID(0xdecafbad)))
Expect(b.Len()).To(BeZero())
})
It("errors on EOFs", func() {
data := []byte{0xa}
data = append(data, encodeVarInt(0x12345678)...)
_, err := parseStreamIDBlockedFrame(bytes.NewReader(data), versionIETFFrames)
Expect(err).ToNot(HaveOccurred())
for i := range data {
_, err := parseStreamIDBlockedFrame(bytes.NewReader(data[:i]), versionIETFFrames)
Expect(err).To(MatchError(io.EOF))
}
})
})
Context("writing", func() {
It("writes a sample frame", func() {
b := &bytes.Buffer{}
frame := StreamIDBlockedFrame{StreamID: 0xdeadbeefcafe}
err := frame.Write(b, protocol.VersionWhatever)
Expect(err).ToNot(HaveOccurred())
expected := []byte{0xa}
expected = append(expected, encodeVarInt(0xdeadbeefcafe)...)
Expect(b.Bytes()).To(Equal(expected))
})
It("has the correct min length", func() {
frame := StreamIDBlockedFrame{StreamID: 0x123456}
Expect(frame.Length(0)).To(Equal(protocol.ByteCount(1) + utils.VarIntLen(0x123456)))
})
})
})

View file

@ -0,0 +1,52 @@
package wire
import (
"bytes"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
)
// A StreamsBlockedFrame is a STREAMS_BLOCKED frame
type StreamsBlockedFrame struct {
Type protocol.StreamType
StreamLimit uint64
}
func parseStreamsBlockedFrame(r *bytes.Reader, _ protocol.VersionNumber) (*StreamsBlockedFrame, error) {
typeByte, err := r.ReadByte()
if err != nil {
return nil, err
}
f := &StreamsBlockedFrame{}
switch typeByte {
case 0xa:
f.Type = protocol.StreamTypeBidi
case 0xb:
f.Type = protocol.StreamTypeUni
}
streamLimit, err := utils.ReadVarInt(r)
if err != nil {
return nil, err
}
f.StreamLimit = streamLimit
return f, nil
}
func (f *StreamsBlockedFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error {
switch f.Type {
case protocol.StreamTypeBidi:
b.WriteByte(0xa)
case protocol.StreamTypeUni:
b.WriteByte(0xb)
}
utils.WriteVarInt(b, f.StreamLimit)
return nil
}
// Length of a written frame
func (f *StreamsBlockedFrame) Length(_ protocol.VersionNumber) protocol.ByteCount {
return 1 + utils.VarIntLen(f.StreamLimit)
}

View file

@ -0,0 +1,79 @@
package wire
import (
"bytes"
"io"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("STREAMS_BLOCKED frame", func() {
Context("parsing", func() {
It("accepts a frame for bidirectional streams", func() {
expected := []byte{0xa}
expected = append(expected, encodeVarInt(0x1337)...)
b := bytes.NewReader(expected)
f, err := parseStreamsBlockedFrame(b, protocol.VersionWhatever)
Expect(err).ToNot(HaveOccurred())
Expect(f.Type).To(Equal(protocol.StreamTypeBidi))
Expect(f.StreamLimit).To(BeEquivalentTo(0x1337))
Expect(b.Len()).To(BeZero())
})
It("accepts a frame for unidirectional streams", func() {
expected := []byte{0xb}
expected = append(expected, encodeVarInt(0x7331)...)
b := bytes.NewReader(expected)
f, err := parseStreamsBlockedFrame(b, protocol.VersionWhatever)
Expect(err).ToNot(HaveOccurred())
Expect(f.Type).To(Equal(protocol.StreamTypeUni))
Expect(f.StreamLimit).To(BeEquivalentTo(0x7331))
Expect(b.Len()).To(BeZero())
})
It("errors on EOFs", func() {
data := []byte{0xa}
data = append(data, encodeVarInt(0x12345678)...)
_, err := parseStreamsBlockedFrame(bytes.NewReader(data), versionIETFFrames)
Expect(err).ToNot(HaveOccurred())
for i := range data {
_, err := parseStreamsBlockedFrame(bytes.NewReader(data[:i]), versionIETFFrames)
Expect(err).To(MatchError(io.EOF))
}
})
})
Context("writing", func() {
It("writes a frame for bidirectional streams", func() {
b := &bytes.Buffer{}
f := StreamsBlockedFrame{
Type: protocol.StreamTypeBidi,
StreamLimit: 0xdeadbeefcafe,
}
Expect(f.Write(b, protocol.VersionWhatever)).To(Succeed())
expected := []byte{0xa}
expected = append(expected, encodeVarInt(0xdeadbeefcafe)...)
Expect(b.Bytes()).To(Equal(expected))
})
It("writes a frame for unidirectional streams", func() {
b := &bytes.Buffer{}
f := StreamsBlockedFrame{
Type: protocol.StreamTypeUni,
StreamLimit: 0xdeadbeefcafe,
}
Expect(f.Write(b, protocol.VersionWhatever)).To(Succeed())
expected := []byte{0xb}
expected = append(expected, encodeVarInt(0xdeadbeefcafe)...)
Expect(b.Bytes()).To(Equal(expected))
})
It("has the correct min length", func() {
frame := StreamsBlockedFrame{StreamLimit: 0x123456}
Expect(frame.Length(0)).To(Equal(protocol.ByteCount(1) + utils.VarIntLen(0x123456)))
})
})
})

View file

@ -110,16 +110,16 @@ func (mr *MockStreamManagerMockRecorder) GetOrOpenSendStream(arg0 interface{}) *
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOrOpenSendStream", reflect.TypeOf((*MockStreamManager)(nil).GetOrOpenSendStream), arg0)
}
// HandleMaxStreamIDFrame mocks base method
func (m *MockStreamManager) HandleMaxStreamIDFrame(arg0 *wire.MaxStreamIDFrame) error {
ret := m.ctrl.Call(m, "HandleMaxStreamIDFrame", arg0)
// HandleMaxStreamsFrame mocks base method
func (m *MockStreamManager) HandleMaxStreamsFrame(arg0 *wire.MaxStreamsFrame) error {
ret := m.ctrl.Call(m, "HandleMaxStreamsFrame", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// HandleMaxStreamIDFrame indicates an expected call of HandleMaxStreamIDFrame
func (mr *MockStreamManagerMockRecorder) HandleMaxStreamIDFrame(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleMaxStreamIDFrame", reflect.TypeOf((*MockStreamManager)(nil).HandleMaxStreamIDFrame), arg0)
// HandleMaxStreamsFrame indicates an expected call of HandleMaxStreamsFrame
func (mr *MockStreamManagerMockRecorder) HandleMaxStreamsFrame(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleMaxStreamsFrame", reflect.TypeOf((*MockStreamManager)(nil).HandleMaxStreamsFrame), arg0)
}
// OpenStream mocks base method

View file

@ -57,10 +57,10 @@ var _ = Describe("Packet Unpacker", func() {
It("unpacks the frames", func() {
buf := &bytes.Buffer{}
(&wire.PingFrame{}).Write(buf, protocol.VersionWhatever)
(&wire.BlockedFrame{}).Write(buf, protocol.VersionWhatever)
(&wire.DataBlockedFrame{}).Write(buf, protocol.VersionWhatever)
aead.EXPECT().Open1RTT(gomock.Any(), gomock.Any(), hdr.PacketNumber, hdr.Raw).Return(buf.Bytes(), nil)
packet, err := unpacker.Unpack(hdr.Raw, hdr, nil)
Expect(err).ToNot(HaveOccurred())
Expect(packet.frames).To(Equal([]wire.Frame{&wire.PingFrame{}, &wire.BlockedFrame{}}))
Expect(packet.frames).To(Equal([]wire.Frame{&wire.PingFrame{}, &wire.DataBlockedFrame{}}))
})
})

View file

@ -169,9 +169,9 @@ func (s *sendStream) popStreamFrameImpl(maxBytes protocol.ByteCount) (bool /* co
return false, nil, false
}
if isBlocked, offset := s.flowController.IsNewlyBlocked(); isBlocked {
s.sender.queueControlFrame(&wire.StreamBlockedFrame{
StreamID: s.streamID,
Offset: offset,
s.sender.queueControlFrame(&wire.StreamDataBlockedFrame{
StreamID: s.streamID,
DataLimit: offset,
})
return false, nil, false
}

View file

@ -179,9 +179,9 @@ var _ = Describe("Send Stream", func() {
It("queues a BLOCKED frame if the stream is flow control blocked", func() {
mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(0))
mockFC.EXPECT().IsNewlyBlocked().Return(true, protocol.ByteCount(12))
mockSender.EXPECT().queueControlFrame(&wire.StreamBlockedFrame{
StreamID: streamID,
Offset: 12,
mockSender.EXPECT().queueControlFrame(&wire.StreamDataBlockedFrame{
StreamID: streamID,
DataLimit: 12,
})
mockSender.EXPECT().onHasStreamData(streamID)
done := make(chan struct{})
@ -224,9 +224,9 @@ var _ = Describe("Send Stream", func() {
mockFC.EXPECT().SendWindowSize()
// don't use offset 3 here, to make sure the BLOCKED frame contains the number returned by the flow controller
mockFC.EXPECT().IsNewlyBlocked().Return(true, protocol.ByteCount(10))
mockSender.EXPECT().queueControlFrame(&wire.StreamBlockedFrame{
StreamID: streamID,
Offset: 10,
mockSender.EXPECT().queueControlFrame(&wire.StreamDataBlockedFrame{
StreamID: streamID,
DataLimit: 10,
})
f, hasMoreData = str.popStreamFrame(1000)
Expect(f).To(BeNil())

View file

@ -40,7 +40,7 @@ type streamManager interface {
AcceptUniStream() (ReceiveStream, error)
DeleteStream(protocol.StreamID) error
UpdateLimits(*handshake.TransportParameters)
HandleMaxStreamIDFrame(*wire.MaxStreamIDFrame) error
HandleMaxStreamsFrame(*wire.MaxStreamsFrame) error
CloseWithError(error)
}
@ -158,7 +158,14 @@ var newSession = func(
s.preSetup()
initialStream := newCryptoStream()
handshakeStream := newCryptoStream()
s.streamsMap = newStreamsMap(s, s.newFlowController, s.config.MaxIncomingStreams, s.config.MaxIncomingUniStreams, s.perspective, s.version)
s.streamsMap = newStreamsMap(
s,
s.newFlowController,
uint64(s.config.MaxIncomingStreams),
uint64(s.config.MaxIncomingUniStreams),
s.perspective,
s.version,
)
s.framer = newFramer(s.streamsMap, s.version)
cs, err := handshake.NewCryptoSetupServer(
initialStream,
@ -248,7 +255,14 @@ var newClientSession = func(
s.cryptoStreamHandler = cs
s.cryptoStreamManager = newCryptoStreamManager(cs, initialStream, handshakeStream)
s.unpacker = newPacketUnpacker(cs, s.version)
s.streamsMap = newStreamsMap(s, s.newFlowController, s.config.MaxIncomingStreams, s.config.MaxIncomingUniStreams, s.perspective, s.version)
s.streamsMap = newStreamsMap(
s,
s.newFlowController,
uint64(s.config.MaxIncomingStreams),
uint64(s.config.MaxIncomingUniStreams),
s.perspective,
s.version,
)
s.framer = newFramer(s.streamsMap, s.version)
s.packer = newPacketPacker(
s.destConnID,
@ -549,11 +563,11 @@ func (s *session) handleFrames(fs []wire.Frame, encLevel protocol.EncryptionLeve
s.handleMaxDataFrame(frame)
case *wire.MaxStreamDataFrame:
err = s.handleMaxStreamDataFrame(frame)
case *wire.MaxStreamIDFrame:
err = s.handleMaxStreamIDFrame(frame)
case *wire.BlockedFrame:
case *wire.StreamBlockedFrame:
case *wire.StreamIDBlockedFrame:
case *wire.MaxStreamsFrame:
err = s.handleMaxStreamsFrame(frame)
case *wire.DataBlockedFrame:
case *wire.StreamDataBlockedFrame:
case *wire.StreamsBlockedFrame:
case *wire.StopSendingFrame:
err = s.handleStopSendingFrame(frame)
case *wire.PingFrame:
@ -627,8 +641,8 @@ func (s *session) handleMaxStreamDataFrame(frame *wire.MaxStreamDataFrame) error
return nil
}
func (s *session) handleMaxStreamIDFrame(frame *wire.MaxStreamIDFrame) error {
return s.streamsMap.HandleMaxStreamIDFrame(frame)
func (s *session) handleMaxStreamsFrame(frame *wire.MaxStreamsFrame) error {
return s.streamsMap.HandleMaxStreamsFrame(frame)
}
func (s *session) handleResetStreamFrame(frame *wire.ResetStreamFrame) error {
@ -893,7 +907,7 @@ func (s *session) sendProbePacket() error {
func (s *session) sendPacket() (bool, error) {
if isBlocked, offset := s.connFlowController.IsNewlyBlocked(); isBlocked {
s.framer.QueueControlFrame(&wire.BlockedFrame{Offset: offset})
s.framer.QueueControlFrame(&wire.DataBlockedFrame{DataLimit: offset})
}
s.windowUpdateQueue.QueueAll()

View file

@ -249,17 +249,20 @@ var _ = Describe("Session", func() {
Context("handling MAX_STREAM_ID frames", func() {
It("passes the frame to the streamsMap", func() {
f := &wire.MaxStreamIDFrame{StreamID: 10}
streamManager.EXPECT().HandleMaxStreamIDFrame(f)
err := sess.handleMaxStreamIDFrame(f)
f := &wire.MaxStreamsFrame{
Type: protocol.StreamTypeUni,
MaxStreams: 10,
}
streamManager.EXPECT().HandleMaxStreamsFrame(f)
err := sess.handleMaxStreamsFrame(f)
Expect(err).ToNot(HaveOccurred())
})
It("returns errors", func() {
f := &wire.MaxStreamIDFrame{StreamID: 10}
f := &wire.MaxStreamsFrame{MaxStreams: 10}
testErr := errors.New("test error")
streamManager.EXPECT().HandleMaxStreamIDFrame(f).Return(testErr)
err := sess.handleMaxStreamIDFrame(f)
streamManager.EXPECT().HandleMaxStreamsFrame(f).Return(testErr)
err := sess.handleMaxStreamsFrame(f)
Expect(err).To(MatchError(testErr))
})
})
@ -306,17 +309,17 @@ var _ = Describe("Session", func() {
})
It("handles BLOCKED frames", func() {
err := sess.handleFrames([]wire.Frame{&wire.BlockedFrame{}}, protocol.EncryptionUnspecified)
err := sess.handleFrames([]wire.Frame{&wire.DataBlockedFrame{}}, protocol.EncryptionUnspecified)
Expect(err).NotTo(HaveOccurred())
})
It("handles STREAM_BLOCKED frames", func() {
err := sess.handleFrames([]wire.Frame{&wire.StreamBlockedFrame{}}, protocol.EncryptionUnspecified)
err := sess.handleFrames([]wire.Frame{&wire.StreamDataBlockedFrame{}}, protocol.EncryptionUnspecified)
Expect(err).NotTo(HaveOccurred())
})
It("handles STREAM_ID_BLOCKED frames", func() {
err := sess.handleFrames([]wire.Frame{&wire.StreamIDBlockedFrame{}}, protocol.EncryptionUnspecified)
err := sess.handleFrames([]wire.Frame{&wire.StreamsBlockedFrame{}}, protocol.EncryptionUnspecified)
Expect(err).NotTo(HaveOccurred())
})
@ -598,7 +601,7 @@ var _ = Describe("Session", func() {
Expect(err).NotTo(HaveOccurred())
Expect(sent).To(BeTrue())
frames, _ := sess.framer.AppendControlFrames(nil, 1000)
Expect(frames).To(Equal([]wire.Frame{&wire.BlockedFrame{Offset: 1337}}))
Expect(frames).To(Equal([]wire.Frame{&wire.DataBlockedFrame{DataLimit: 1337}}))
})
It("sends a retransmission and a regular packet in the same run", func() {

View file

@ -26,8 +26,8 @@ var _ streamManager = &streamsMap{}
func newStreamsMap(
sender streamSender,
newFlowController func(protocol.StreamID) flowcontrol.StreamFlowController,
maxIncomingStreams int,
maxIncomingUniStreams int,
maxIncomingStreams uint64,
maxIncomingUniStreams uint64,
perspective protocol.Perspective,
version protocol.VersionNumber,
) streamManager {
@ -36,18 +36,6 @@ func newStreamsMap(
newFlowController: newFlowController,
sender: sender,
}
var firstOutgoingBidiStream, firstOutgoingUniStream, firstIncomingBidiStream, firstIncomingUniStream protocol.StreamID
if perspective == protocol.PerspectiveServer {
firstOutgoingBidiStream = 1
firstIncomingBidiStream = 0
firstOutgoingUniStream = 3
firstIncomingUniStream = 2
} else {
firstOutgoingBidiStream = 0
firstIncomingBidiStream = 1
firstOutgoingUniStream = 2
firstIncomingUniStream = 3
}
newBidiStream := func(id protocol.StreamID) streamI {
return newStream(id, m.sender, m.newFlowController(id), version)
}
@ -58,25 +46,25 @@ func newStreamsMap(
return newReceiveStream(id, m.sender, m.newFlowController(id), version)
}
m.outgoingBidiStreams = newOutgoingBidiStreamsMap(
firstOutgoingBidiStream,
protocol.FirstStream(protocol.StreamTypeBidi, perspective),
newBidiStream,
sender.queueControlFrame,
)
m.incomingBidiStreams = newIncomingBidiStreamsMap(
firstIncomingBidiStream,
protocol.MaxBidiStreamID(maxIncomingStreams, perspective),
protocol.FirstStream(protocol.StreamTypeBidi, perspective.Opposite()),
protocol.MaxStreamID(protocol.StreamTypeBidi, maxIncomingStreams, perspective.Opposite()),
maxIncomingStreams,
sender.queueControlFrame,
newBidiStream,
)
m.outgoingUniStreams = newOutgoingUniStreamsMap(
firstOutgoingUniStream,
protocol.FirstStream(protocol.StreamTypeUni, perspective),
newUniSendStream,
sender.queueControlFrame,
)
m.incomingUniStreams = newIncomingUniStreamsMap(
firstIncomingUniStream,
protocol.MaxUniStreamID(maxIncomingUniStreams, perspective),
protocol.FirstStream(protocol.StreamTypeUni, perspective.Opposite()),
protocol.MaxStreamID(protocol.StreamTypeUni, maxIncomingUniStreams, perspective.Opposite()),
maxIncomingUniStreams,
sender.queueControlFrame,
newUniReceiveStream,
@ -158,15 +146,13 @@ func (m *streamsMap) GetOrOpenSendStream(id protocol.StreamID) (sendStreamI, err
panic("")
}
func (m *streamsMap) HandleMaxStreamIDFrame(f *wire.MaxStreamIDFrame) error {
id := f.StreamID
if id.InitiatedBy() != m.perspective {
return fmt.Errorf("received MAX_STREAM_DATA frame for incoming stream %d", id)
}
func (m *streamsMap) HandleMaxStreamsFrame(f *wire.MaxStreamsFrame) error {
id := protocol.MaxStreamID(f.Type, f.MaxStreams, m.perspective)
switch id.Type() {
case protocol.StreamTypeUni:
m.outgoingUniStreams.SetMaxStream(id)
case protocol.StreamTypeBidi:
fmt.Printf("")
m.outgoingBidiStreams.SetMaxStream(id)
}
return nil
@ -174,13 +160,8 @@ func (m *streamsMap) HandleMaxStreamIDFrame(f *wire.MaxStreamIDFrame) error {
func (m *streamsMap) UpdateLimits(p *handshake.TransportParameters) {
// Max{Uni,Bidi}StreamID returns the highest stream ID that the peer is allowed to open.
// Invert the perspective to determine the value that we are allowed to open.
peerPers := protocol.PerspectiveServer
if m.perspective == protocol.PerspectiveServer {
peerPers = protocol.PerspectiveClient
}
m.outgoingBidiStreams.SetMaxStream(protocol.MaxBidiStreamID(int(p.MaxBidiStreams), peerPers))
m.outgoingUniStreams.SetMaxStream(protocol.MaxUniStreamID(int(p.MaxUniStreams), peerPers))
m.outgoingBidiStreams.SetMaxStream(protocol.MaxStreamID(protocol.StreamTypeBidi, p.MaxBidiStreams, m.perspective))
m.outgoingUniStreams.SetMaxStream(protocol.MaxStreamID(protocol.StreamTypeUni, p.MaxUniStreams, m.perspective))
}
func (m *streamsMap) CloseWithError(err error) {

View file

@ -1,6 +1,10 @@
package quic
import "github.com/cheekybits/genny/generic"
import (
"github.com/cheekybits/genny/generic"
"github.com/lucas-clemente/quic-go/internal/protocol"
)
// In the auto-generated streams maps, we need to be able to close the streams.
// Therefore, extend the generic.Type with the stream close method.
@ -9,3 +13,5 @@ type item interface {
generic.Type
closeForShutdown(error)
}
const streamTypeGeneric protocol.StreamType = protocol.StreamTypeUni

View file

@ -21,10 +21,10 @@ type incomingBidiStreamsMap struct {
nextStreamToAccept protocol.StreamID // the next stream that will be returned by AcceptStream()
nextStreamToOpen protocol.StreamID // the highest stream that the peer openend
maxStream protocol.StreamID // the highest stream that the peer is allowed to open
maxNumStreams int // maximum number of streams
maxNumStreams uint64 // maximum number of streams
newStream func(protocol.StreamID) streamI
queueMaxStreamID func(*wire.MaxStreamIDFrame)
queueMaxStreamID func(*wire.MaxStreamsFrame)
closeErr error
}
@ -32,7 +32,7 @@ type incomingBidiStreamsMap struct {
func newIncomingBidiStreamsMap(
nextStreamToAccept protocol.StreamID,
initialMaxStreamID protocol.StreamID,
maxNumStreams int,
maxNumStreams uint64,
queueControlFrame func(wire.Frame),
newStream func(protocol.StreamID) streamI,
) *incomingBidiStreamsMap {
@ -43,7 +43,7 @@ func newIncomingBidiStreamsMap(
maxStream: initialMaxStreamID,
maxNumStreams: maxNumStreams,
newStream: newStream,
queueMaxStreamID: func(f *wire.MaxStreamIDFrame) { queueControlFrame(f) },
queueMaxStreamID: func(f *wire.MaxStreamsFrame) { queueControlFrame(f) },
}
m.cond.L = &m.mutex
return m
@ -108,9 +108,13 @@ func (m *incomingBidiStreamsMap) DeleteStream(id protocol.StreamID) error {
}
delete(m.streams, id)
// queue a MAX_STREAM_ID frame, giving the peer the option to open a new stream
if numNewStreams := m.maxNumStreams - len(m.streams); numNewStreams > 0 {
if m.maxNumStreams > uint64(len(m.streams)) {
numNewStreams := m.maxNumStreams - uint64(len(m.streams))
m.maxStream = m.nextStreamToOpen + protocol.StreamID((numNewStreams-1)*4)
m.queueMaxStreamID(&wire.MaxStreamIDFrame{StreamID: m.maxStream})
m.queueMaxStreamID(&wire.MaxStreamsFrame{
Type: protocol.StreamTypeBidi,
MaxStreams: m.maxStream.StreamNum(),
})
}
return nil
}

View file

@ -8,8 +8,8 @@ import (
"github.com/lucas-clemente/quic-go/internal/wire"
)
//go:generate genny -in $GOFILE -out streams_map_incoming_bidi.go gen "item=streamI Item=BidiStream"
//go:generate genny -in $GOFILE -out streams_map_incoming_uni.go gen "item=receiveStreamI Item=UniStream"
//go:generate genny -in $GOFILE -out streams_map_incoming_bidi.go gen "item=streamI Item=BidiStream streamTypeGeneric=protocol.StreamTypeBidi"
//go:generate genny -in $GOFILE -out streams_map_incoming_uni.go gen "item=receiveStreamI Item=UniStream streamTypeGeneric=protocol.StreamTypeUni"
type incomingItemsMap struct {
mutex sync.RWMutex
cond sync.Cond
@ -19,10 +19,10 @@ type incomingItemsMap struct {
nextStreamToAccept protocol.StreamID // the next stream that will be returned by AcceptStream()
nextStreamToOpen protocol.StreamID // the highest stream that the peer openend
maxStream protocol.StreamID // the highest stream that the peer is allowed to open
maxNumStreams int // maximum number of streams
maxNumStreams uint64 // maximum number of streams
newStream func(protocol.StreamID) item
queueMaxStreamID func(*wire.MaxStreamIDFrame)
queueMaxStreamID func(*wire.MaxStreamsFrame)
closeErr error
}
@ -30,7 +30,7 @@ type incomingItemsMap struct {
func newIncomingItemsMap(
nextStreamToAccept protocol.StreamID,
initialMaxStreamID protocol.StreamID,
maxNumStreams int,
maxNumStreams uint64,
queueControlFrame func(wire.Frame),
newStream func(protocol.StreamID) item,
) *incomingItemsMap {
@ -41,7 +41,7 @@ func newIncomingItemsMap(
maxStream: initialMaxStreamID,
maxNumStreams: maxNumStreams,
newStream: newStream,
queueMaxStreamID: func(f *wire.MaxStreamIDFrame) { queueControlFrame(f) },
queueMaxStreamID: func(f *wire.MaxStreamsFrame) { queueControlFrame(f) },
}
m.cond.L = &m.mutex
return m
@ -106,9 +106,13 @@ func (m *incomingItemsMap) DeleteStream(id protocol.StreamID) error {
}
delete(m.streams, id)
// queue a MAX_STREAM_ID frame, giving the peer the option to open a new stream
if numNewStreams := m.maxNumStreams - len(m.streams); numNewStreams > 0 {
if m.maxNumStreams > uint64(len(m.streams)) {
numNewStreams := m.maxNumStreams - uint64(len(m.streams))
m.maxStream = m.nextStreamToOpen + protocol.StreamID((numNewStreams-1)*4)
m.queueMaxStreamID(&wire.MaxStreamIDFrame{StreamID: m.maxStream})
m.queueMaxStreamID(&wire.MaxStreamsFrame{
Type: streamTypeGeneric,
MaxStreams: m.maxStream.StreamNum(),
})
}
return nil
}

View file

@ -26,8 +26,8 @@ func (s *mockGenericStream) closeForShutdown(err error) {
var _ = Describe("Streams Map (incoming)", func() {
const (
firstNewStream protocol.StreamID = 20
maxNumStreams int = 10
firstNewStream protocol.StreamID = 2
maxNumStreams uint64 = 5
initialMaxStream protocol.StreamID = firstNewStream + 4*protocol.StreamID(maxNumStreams-1)
)
@ -49,9 +49,9 @@ var _ = Describe("Streams Map (incoming)", func() {
})
It("opens all streams up to the id on GetOrOpenStream", func() {
_, err := m.GetOrOpenStream(firstNewStream + 4*5)
_, err := m.GetOrOpenStream(firstNewStream + 4*4)
Expect(err).ToNot(HaveOccurred())
Expect(newItemCounter).To(Equal(6))
Expect(newItemCounter).To(Equal(5))
})
It("starts opening streams at the right position", func() {
@ -59,9 +59,9 @@ var _ = Describe("Streams Map (incoming)", func() {
_, err := m.GetOrOpenStream(firstNewStream + 4)
Expect(err).ToNot(HaveOccurred())
Expect(newItemCounter).To(Equal(2))
_, err = m.GetOrOpenStream(firstNewStream + 4*5)
_, err = m.GetOrOpenStream(firstNewStream + 4*4)
Expect(err).ToNot(HaveOccurred())
Expect(newItemCounter).To(Equal(6))
Expect(newItemCounter).To(Equal(5))
})
It("accepts streams in the right order", func() {
@ -143,9 +143,9 @@ var _ = Describe("Streams Map (incoming)", func() {
})
It("closes all streams when CloseWithError is called", func() {
str1, err := m.GetOrOpenStream(20)
str1, err := m.GetOrOpenStream(firstNewStream)
Expect(err).ToNot(HaveOccurred())
str2, err := m.GetOrOpenStream(20 + 8)
str2, err := m.GetOrOpenStream(firstNewStream + 8)
Expect(err).ToNot(HaveOccurred())
testErr := errors.New("test err")
m.CloseWithError(testErr)
@ -157,11 +157,11 @@ var _ = Describe("Streams Map (incoming)", func() {
It("deletes streams", func() {
mockSender.EXPECT().queueControlFrame(gomock.Any())
_, err := m.GetOrOpenStream(20)
_, err := m.GetOrOpenStream(initialMaxStream)
Expect(err).ToNot(HaveOccurred())
err = m.DeleteStream(20)
err = m.DeleteStream(initialMaxStream)
Expect(err).ToNot(HaveOccurred())
str, err := m.GetOrOpenStream(20)
str, err := m.GetOrOpenStream(initialMaxStream)
Expect(err).ToNot(HaveOccurred())
Expect(str).To(BeNil())
})
@ -171,13 +171,17 @@ var _ = Describe("Streams Map (incoming)", func() {
Expect(err).To(MatchError("Tried to delete unknown stream 1337"))
})
It("sends MAX_STREAM_ID frames when streams are deleted", func() {
It("sends MAX_STREAMS frames when streams are deleted", func() {
// open a bunch of streams
_, err := m.GetOrOpenStream(firstNewStream + 4*4)
Expect(err).ToNot(HaveOccurred())
mockSender.EXPECT().queueControlFrame(&wire.MaxStreamIDFrame{StreamID: initialMaxStream + 4})
Expect(m.DeleteStream(firstNewStream + 4)).To(Succeed())
mockSender.EXPECT().queueControlFrame(&wire.MaxStreamIDFrame{StreamID: initialMaxStream + 8})
mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) {
Expect(f.(*wire.MaxStreamsFrame).MaxStreams).To(Equal(maxNumStreams + 1))
})
Expect(m.DeleteStream(firstNewStream + 2*4)).To(Succeed())
mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) {
Expect(f.(*wire.MaxStreamsFrame).MaxStreams).To(Equal(maxNumStreams + 2))
})
Expect(m.DeleteStream(firstNewStream + 3*4)).To(Succeed())
})
})

View file

@ -21,10 +21,10 @@ type incomingUniStreamsMap struct {
nextStreamToAccept protocol.StreamID // the next stream that will be returned by AcceptStream()
nextStreamToOpen protocol.StreamID // the highest stream that the peer openend
maxStream protocol.StreamID // the highest stream that the peer is allowed to open
maxNumStreams int // maximum number of streams
maxNumStreams uint64 // maximum number of streams
newStream func(protocol.StreamID) receiveStreamI
queueMaxStreamID func(*wire.MaxStreamIDFrame)
queueMaxStreamID func(*wire.MaxStreamsFrame)
closeErr error
}
@ -32,7 +32,7 @@ type incomingUniStreamsMap struct {
func newIncomingUniStreamsMap(
nextStreamToAccept protocol.StreamID,
initialMaxStreamID protocol.StreamID,
maxNumStreams int,
maxNumStreams uint64,
queueControlFrame func(wire.Frame),
newStream func(protocol.StreamID) receiveStreamI,
) *incomingUniStreamsMap {
@ -43,7 +43,7 @@ func newIncomingUniStreamsMap(
maxStream: initialMaxStreamID,
maxNumStreams: maxNumStreams,
newStream: newStream,
queueMaxStreamID: func(f *wire.MaxStreamIDFrame) { queueControlFrame(f) },
queueMaxStreamID: func(f *wire.MaxStreamsFrame) { queueControlFrame(f) },
}
m.cond.L = &m.mutex
return m
@ -108,9 +108,13 @@ func (m *incomingUniStreamsMap) DeleteStream(id protocol.StreamID) error {
}
delete(m.streams, id)
// queue a MAX_STREAM_ID frame, giving the peer the option to open a new stream
if numNewStreams := m.maxNumStreams - len(m.streams); numNewStreams > 0 {
if m.maxNumStreams > uint64(len(m.streams)) {
numNewStreams := m.maxNumStreams - uint64(len(m.streams))
m.maxStream = m.nextStreamToOpen + protocol.StreamID((numNewStreams-1)*4)
m.queueMaxStreamID(&wire.MaxStreamIDFrame{StreamID: m.maxStream})
m.queueMaxStreamID(&wire.MaxStreamsFrame{
Type: protocol.StreamTypeUni,
MaxStreams: m.maxStream.StreamNum(),
})
}
return nil
}

View file

@ -19,13 +19,13 @@ type outgoingBidiStreamsMap struct {
streams map[protocol.StreamID]streamI
nextStream protocol.StreamID // stream ID of the stream returned by OpenStream(Sync)
maxStream protocol.StreamID // the maximum stream ID we're allowed to open
maxStreamSet bool // was maxStream set. If not, it's not possible to any stream (also works for stream 0)
highestBlocked protocol.StreamID // the highest stream ID that we queued a STREAM_ID_BLOCKED frame for
nextStream protocol.StreamID // stream ID of the stream returned by OpenStream(Sync)
maxStream protocol.StreamID // the maximum stream ID we're allowed to open
maxStreamSet bool // was maxStream set. If not, it's not possible to any stream (also works for stream 0)
blockedSent bool // was a STREAMS_BLOCKED sent for the current maxStream
newStream func(protocol.StreamID) streamI
queueStreamIDBlocked func(*wire.StreamIDBlockedFrame)
queueStreamIDBlocked func(*wire.StreamsBlockedFrame)
closeErr error
}
@ -39,7 +39,7 @@ func newOutgoingBidiStreamsMap(
streams: make(map[protocol.StreamID]streamI),
nextStream: nextStream,
newStream: newStream,
queueStreamIDBlocked: func(f *wire.StreamIDBlockedFrame) { queueControlFrame(f) },
queueStreamIDBlocked: func(f *wire.StreamsBlockedFrame) { queueControlFrame(f) },
}
m.cond.L = &m.mutex
return m
@ -73,9 +73,19 @@ func (m *outgoingBidiStreamsMap) openStreamImpl() (streamI, error) {
return nil, m.closeErr
}
if !m.maxStreamSet || m.nextStream > m.maxStream {
if m.maxStream == 0 || m.highestBlocked < m.maxStream {
m.queueStreamIDBlocked(&wire.StreamIDBlockedFrame{StreamID: m.maxStream})
m.highestBlocked = m.maxStream
if !m.blockedSent {
if m.maxStreamSet {
m.queueStreamIDBlocked(&wire.StreamsBlockedFrame{
Type: protocol.StreamTypeBidi,
StreamLimit: m.maxStream.StreamNum(),
})
} else {
m.queueStreamIDBlocked(&wire.StreamsBlockedFrame{
Type: protocol.StreamTypeBidi,
StreamLimit: 0,
})
}
m.blockedSent = true
}
return nil, qerr.TooManyOpenStreams
}
@ -112,6 +122,7 @@ func (m *outgoingBidiStreamsMap) SetMaxStream(id protocol.StreamID) {
if !m.maxStreamSet || id > m.maxStream {
m.maxStream = id
m.maxStreamSet = true
m.blockedSent = false
m.cond.Broadcast()
}
m.mutex.Unlock()

View file

@ -9,21 +9,21 @@ import (
"github.com/lucas-clemente/quic-go/qerr"
)
//go:generate genny -in $GOFILE -out streams_map_outgoing_bidi.go gen "item=streamI Item=BidiStream"
//go:generate genny -in $GOFILE -out streams_map_outgoing_uni.go gen "item=sendStreamI Item=UniStream"
//go:generate genny -in $GOFILE -out streams_map_outgoing_bidi.go gen "item=streamI Item=BidiStream streamTypeGeneric=protocol.StreamTypeBidi"
//go:generate genny -in $GOFILE -out streams_map_outgoing_uni.go gen "item=sendStreamI Item=UniStream streamTypeGeneric=protocol.StreamTypeUni"
type outgoingItemsMap struct {
mutex sync.RWMutex
cond sync.Cond
streams map[protocol.StreamID]item
nextStream protocol.StreamID // stream ID of the stream returned by OpenStream(Sync)
maxStream protocol.StreamID // the maximum stream ID we're allowed to open
maxStreamSet bool // was maxStream set. If not, it's not possible to any stream (also works for stream 0)
highestBlocked protocol.StreamID // the highest stream ID that we queued a STREAM_ID_BLOCKED frame for
nextStream protocol.StreamID // stream ID of the stream returned by OpenStream(Sync)
maxStream protocol.StreamID // the maximum stream ID we're allowed to open
maxStreamSet bool // was maxStream set. If not, it's not possible to any stream (also works for stream 0)
blockedSent bool // was a STREAMS_BLOCKED sent for the current maxStream
newStream func(protocol.StreamID) item
queueStreamIDBlocked func(*wire.StreamIDBlockedFrame)
queueStreamIDBlocked func(*wire.StreamsBlockedFrame)
closeErr error
}
@ -37,7 +37,7 @@ func newOutgoingItemsMap(
streams: make(map[protocol.StreamID]item),
nextStream: nextStream,
newStream: newStream,
queueStreamIDBlocked: func(f *wire.StreamIDBlockedFrame) { queueControlFrame(f) },
queueStreamIDBlocked: func(f *wire.StreamsBlockedFrame) { queueControlFrame(f) },
}
m.cond.L = &m.mutex
return m
@ -71,9 +71,19 @@ func (m *outgoingItemsMap) openStreamImpl() (item, error) {
return nil, m.closeErr
}
if !m.maxStreamSet || m.nextStream > m.maxStream {
if m.maxStream == 0 || m.highestBlocked < m.maxStream {
m.queueStreamIDBlocked(&wire.StreamIDBlockedFrame{StreamID: m.maxStream})
m.highestBlocked = m.maxStream
if !m.blockedSent {
if m.maxStreamSet {
m.queueStreamIDBlocked(&wire.StreamsBlockedFrame{
Type: streamTypeGeneric,
StreamLimit: m.maxStream.StreamNum(),
})
} else {
m.queueStreamIDBlocked(&wire.StreamsBlockedFrame{
Type: streamTypeGeneric,
StreamLimit: 0,
})
}
m.blockedSent = true
}
return nil, qerr.TooManyOpenStreams
}
@ -110,6 +120,7 @@ func (m *outgoingItemsMap) SetMaxStream(id protocol.StreamID) {
if !m.maxStreamSet || id > m.maxStream {
m.maxStream = id
m.maxStreamSet = true
m.blockedSent = false
m.cond.Broadcast()
}
m.mutex.Unlock()

View file

@ -12,7 +12,8 @@ import (
)
var _ = Describe("Streams Map (outgoing)", func() {
const firstNewStream protocol.StreamID = 10
const firstNewStream protocol.StreamID = 3
var (
m *outgoingItemsMap
newItem func(id protocol.StreamID) item
@ -57,16 +58,16 @@ var _ = Describe("Streams Map (outgoing)", func() {
})
It("errors when trying to get a stream that has not yet been opened", func() {
_, err := m.GetStream(10)
Expect(err).To(MatchError(qerr.Error(qerr.InvalidStreamID, "peer attempted to open stream 10")))
_, err := m.GetStream(firstNewStream)
Expect(err).To(MatchError(qerr.Error(qerr.InvalidStreamID, "peer attempted to open stream 3")))
})
It("deletes streams", func() {
_, err := m.OpenStream() // opens stream 10
_, err := m.OpenStream() // opens firstNewStream
Expect(err).ToNot(HaveOccurred())
err = m.DeleteStream(10)
err = m.DeleteStream(firstNewStream)
Expect(err).ToNot(HaveOccurred())
str, err := m.GetStream(10)
str, err := m.GetStream(firstNewStream)
Expect(err).ToNot(HaveOccurred())
Expect(str).To(BeNil())
})
@ -77,12 +78,12 @@ var _ = Describe("Streams Map (outgoing)", func() {
})
It("errors when deleting a stream twice", func() {
_, err := m.OpenStream() // opens stream 10
_, err := m.OpenStream() // opens firstNewStream
Expect(err).ToNot(HaveOccurred())
err = m.DeleteStream(10)
err = m.DeleteStream(firstNewStream)
Expect(err).ToNot(HaveOccurred())
err = m.DeleteStream(10)
Expect(err).To(MatchError("Tried to delete unknown stream 10"))
err = m.DeleteStream(firstNewStream)
Expect(err).To(MatchError("Tried to delete unknown stream 3"))
})
It("closes all streams when CloseWithError is called", func() {
@ -124,7 +125,9 @@ var _ = Describe("Streams Map (outgoing)", func() {
It("works with stream 0", func() {
m = newOutgoingItemsMap(0, newItem, mockSender.queueControlFrame)
mockSender.EXPECT().queueControlFrame(&wire.StreamIDBlockedFrame{StreamID: 0})
mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) {
Expect(f.(*wire.StreamsBlockedFrame).StreamLimit).To(BeZero())
})
done := make(chan struct{})
go func() {
defer GinkgoRecover()
@ -156,25 +159,35 @@ var _ = Describe("Streams Map (outgoing)", func() {
})
It("doesn't reduce the stream limit", func() {
m.SetMaxStream(firstNewStream + 4)
m.SetMaxStream(firstNewStream)
m.SetMaxStream(firstNewStream - 4)
_, err := m.OpenStream()
Expect(err).ToNot(HaveOccurred())
str, err := m.OpenStream()
Expect(err).ToNot(HaveOccurred())
Expect(str.(*mockGenericStream).id).To(Equal(firstNewStream))
Expect(str.(*mockGenericStream).id).To(Equal(firstNewStream + 4))
})
It("queues a STREAM_ID_BLOCKED frame if no stream can be opened", func() {
m.SetMaxStream(firstNewStream)
mockSender.EXPECT().queueControlFrame(&wire.StreamIDBlockedFrame{StreamID: firstNewStream})
m.SetMaxStream(firstNewStream + 5*4)
// open the 6 allowed streams
for i := 0; i < 6; i++ {
_, err := m.OpenStream()
Expect(err).ToNot(HaveOccurred())
}
mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) {
Expect(f.(*wire.StreamsBlockedFrame).StreamLimit).To(BeEquivalentTo(6))
})
_, err := m.OpenStream()
Expect(err).ToNot(HaveOccurred())
_, err = m.OpenStream()
Expect(err).To(MatchError(qerr.TooManyOpenStreams))
})
It("only sends one STREAM_ID_BLOCKED frame for one stream ID", func() {
m.SetMaxStream(firstNewStream)
mockSender.EXPECT().queueControlFrame(&wire.StreamIDBlockedFrame{StreamID: firstNewStream})
mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) {
Expect(f.(*wire.StreamsBlockedFrame).StreamLimit).To(BeEquivalentTo(1))
})
_, err := m.OpenStream()
Expect(err).ToNot(HaveOccurred())
// try to open a stream twice, but expect only one STREAM_ID_BLOCKED to be sent

View file

@ -19,13 +19,13 @@ type outgoingUniStreamsMap struct {
streams map[protocol.StreamID]sendStreamI
nextStream protocol.StreamID // stream ID of the stream returned by OpenStream(Sync)
maxStream protocol.StreamID // the maximum stream ID we're allowed to open
maxStreamSet bool // was maxStream set. If not, it's not possible to any stream (also works for stream 0)
highestBlocked protocol.StreamID // the highest stream ID that we queued a STREAM_ID_BLOCKED frame for
nextStream protocol.StreamID // stream ID of the stream returned by OpenStream(Sync)
maxStream protocol.StreamID // the maximum stream ID we're allowed to open
maxStreamSet bool // was maxStream set. If not, it's not possible to any stream (also works for stream 0)
blockedSent bool // was a STREAMS_BLOCKED sent for the current maxStream
newStream func(protocol.StreamID) sendStreamI
queueStreamIDBlocked func(*wire.StreamIDBlockedFrame)
queueStreamIDBlocked func(*wire.StreamsBlockedFrame)
closeErr error
}
@ -39,7 +39,7 @@ func newOutgoingUniStreamsMap(
streams: make(map[protocol.StreamID]sendStreamI),
nextStream: nextStream,
newStream: newStream,
queueStreamIDBlocked: func(f *wire.StreamIDBlockedFrame) { queueControlFrame(f) },
queueStreamIDBlocked: func(f *wire.StreamsBlockedFrame) { queueControlFrame(f) },
}
m.cond.L = &m.mutex
return m
@ -73,9 +73,19 @@ func (m *outgoingUniStreamsMap) openStreamImpl() (sendStreamI, error) {
return nil, m.closeErr
}
if !m.maxStreamSet || m.nextStream > m.maxStream {
if m.maxStream == 0 || m.highestBlocked < m.maxStream {
m.queueStreamIDBlocked(&wire.StreamIDBlockedFrame{StreamID: m.maxStream})
m.highestBlocked = m.maxStream
if !m.blockedSent {
if m.maxStreamSet {
m.queueStreamIDBlocked(&wire.StreamsBlockedFrame{
Type: protocol.StreamTypeUni,
StreamLimit: m.maxStream.StreamNum(),
})
} else {
m.queueStreamIDBlocked(&wire.StreamsBlockedFrame{
Type: protocol.StreamTypeUni,
StreamLimit: 0,
})
}
m.blockedSent = true
}
return nil, qerr.TooManyOpenStreams
}
@ -112,6 +122,7 @@ func (m *outgoingUniStreamsMap) SetMaxStream(id protocol.StreamID) {
if !m.maxStreamSet || id > m.maxStream {
m.maxStream = id
m.maxStreamSet = true
m.blockedSent = false
m.cond.Broadcast()
}
m.mutex.Unlock()

View file

@ -292,7 +292,7 @@ var _ = Describe("Streams Map", func() {
})
})
Context("handling MAX_STREAM_ID frames", func() {
Context("handling MAX_STREAMS frames", func() {
BeforeEach(func() {
mockSender.EXPECT().queueControlFrame(gomock.Any()).AnyTimes()
})
@ -300,49 +300,49 @@ var _ = Describe("Streams Map", func() {
It("processes IDs for outgoing bidirectional streams", func() {
_, err := m.OpenStream()
Expect(err).To(MatchError(qerr.TooManyOpenStreams))
err = m.HandleMaxStreamIDFrame(&wire.MaxStreamIDFrame{StreamID: ids.firstOutgoingBidiStream})
Expect(err).ToNot(HaveOccurred())
Expect(m.HandleMaxStreamsFrame(&wire.MaxStreamsFrame{
Type: protocol.StreamTypeBidi,
MaxStreams: 1,
})).To(Succeed())
str, err := m.OpenStream()
Expect(err).ToNot(HaveOccurred())
Expect(str.StreamID()).To(Equal(ids.firstOutgoingBidiStream))
_, err = m.OpenStream()
Expect(err).To(MatchError(qerr.TooManyOpenStreams))
})
It("processes IDs for outgoing bidirectional streams", func() {
It("processes IDs for outgoing unidirectional streams", func() {
_, err := m.OpenUniStream()
Expect(err).To(MatchError(qerr.TooManyOpenStreams))
err = m.HandleMaxStreamIDFrame(&wire.MaxStreamIDFrame{StreamID: ids.firstOutgoingUniStream})
Expect(err).ToNot(HaveOccurred())
Expect(m.HandleMaxStreamsFrame(&wire.MaxStreamsFrame{
Type: protocol.StreamTypeUni,
MaxStreams: 1,
})).To(Succeed())
str, err := m.OpenUniStream()
Expect(err).ToNot(HaveOccurred())
Expect(str.StreamID()).To(Equal(ids.firstOutgoingUniStream))
})
It("rejects IDs for incoming bidirectional streams", func() {
err := m.HandleMaxStreamIDFrame(&wire.MaxStreamIDFrame{StreamID: ids.firstIncomingBidiStream})
Expect(err).To(MatchError(fmt.Sprintf("received MAX_STREAM_DATA frame for incoming stream %d", ids.firstIncomingBidiStream)))
})
It("rejects IDs for incoming unidirectional streams", func() {
err := m.HandleMaxStreamIDFrame(&wire.MaxStreamIDFrame{StreamID: ids.firstIncomingUniStream})
Expect(err).To(MatchError(fmt.Sprintf("received MAX_STREAM_DATA frame for incoming stream %d", ids.firstIncomingUniStream)))
_, err = m.OpenUniStream()
Expect(err).To(MatchError(qerr.TooManyOpenStreams))
})
})
Context("sending MAX_STREAM_ID frames", func() {
It("sends MAX_STREAM_ID frames for bidirectional streams", func() {
_, err := m.GetOrOpenReceiveStream(ids.firstIncomingBidiStream + 4*10)
Context("sending MAX_STREAMS frames", func() {
It("sends a MAX_STREAMS frame for bidirectional streams", func() {
_, err := m.GetOrOpenReceiveStream(ids.firstIncomingBidiStream)
Expect(err).ToNot(HaveOccurred())
mockSender.EXPECT().queueControlFrame(&wire.MaxStreamIDFrame{
StreamID: protocol.MaxBidiStreamID(maxBidiStreams, perspective) + 4,
mockSender.EXPECT().queueControlFrame(&wire.MaxStreamsFrame{
Type: protocol.StreamTypeBidi,
MaxStreams: maxBidiStreams + 1,
})
Expect(m.DeleteStream(ids.firstIncomingBidiStream)).To(Succeed())
})
It("sends MAX_STREAM_ID frames for unidirectional streams", func() {
_, err := m.GetOrOpenReceiveStream(ids.firstIncomingUniStream + 4*10)
It("sends a MAX_STREAMS frame for unidirectional streams", func() {
_, err := m.GetOrOpenReceiveStream(ids.firstIncomingUniStream)
Expect(err).ToNot(HaveOccurred())
mockSender.EXPECT().queueControlFrame(&wire.MaxStreamIDFrame{
StreamID: protocol.MaxUniStreamID(maxUniStreams, perspective) + 4,
mockSender.EXPECT().queueControlFrame(&wire.MaxStreamsFrame{
Type: protocol.StreamTypeUni,
MaxStreams: maxUniStreams + 1,
})
Expect(m.DeleteStream(ids.firstIncomingUniStream)).To(Succeed())
})