also use the onHasStreamData callback for the crypto stream

This commit is contained in:
Marten Seemann 2017-12-22 10:57:01 +07:00
parent 5cd5d5dae9
commit d330fda94c
9 changed files with 291 additions and 341 deletions

View file

@ -15,7 +15,6 @@ type cryptoStreamI interface {
handleStreamFrame(*wire.StreamFrame) error handleStreamFrame(*wire.StreamFrame) error
popStreamFrame(protocol.ByteCount) (*wire.StreamFrame, bool) popStreamFrame(protocol.ByteCount) (*wire.StreamFrame, bool)
closeForShutdown(error) closeForShutdown(error)
hasDataForWriting() bool
setReadOffset(protocol.ByteCount) setReadOffset(protocol.ByteCount)
// methods needed for flow control // methods needed for flow control
getWindowUpdate() protocol.ByteCount getWindowUpdate() protocol.ByteCount
@ -40,10 +39,3 @@ func (s *cryptoStream) setReadOffset(offset protocol.ByteCount) {
s.receiveStream.readOffset = offset s.receiveStream.readOffset = offset
s.receiveStream.frameQueue.readPosition = offset s.receiveStream.frameQueue.readPosition = offset
} }
func (s *cryptoStream) hasDataForWriting() bool {
s.sendStream.mutex.Lock()
hasData := s.sendStream.dataForWriting != nil
s.sendStream.mutex.Unlock()
return hasData
}

View file

@ -23,17 +23,4 @@ var _ = Describe("Crypto Stream", func() {
Expect(str.receiveStream.readOffset).To(Equal(protocol.ByteCount(0x42))) Expect(str.receiveStream.readOffset).To(Equal(protocol.ByteCount(0x42)))
Expect(str.receiveStream.frameQueue.readPosition).To(Equal(protocol.ByteCount(0x42))) Expect(str.receiveStream.frameQueue.readPosition).To(Equal(protocol.ByteCount(0x42)))
}) })
It("says if it has data for writing", func() {
mockSender.EXPECT().onHasStreamData(str.version.CryptoStreamID())
Expect(str.hasDataForWriting()).To(BeFalse())
done := make(chan struct{})
go func() {
defer GinkgoRecover()
_, err := str.Write([]byte("foobar"))
Expect(err).ToNot(HaveOccurred())
close(done)
}()
Eventually(str.hasDataForWriting).Should(BeTrue())
})
}) })

View file

@ -117,18 +117,6 @@ func (mr *MockCryptoStreamMockRecorder) handleStreamFrame(arg0 interface{}) *gom
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handleStreamFrame", reflect.TypeOf((*MockCryptoStream)(nil).handleStreamFrame), arg0) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handleStreamFrame", reflect.TypeOf((*MockCryptoStream)(nil).handleStreamFrame), arg0)
} }
// hasDataForWriting mocks base method
func (m *MockCryptoStream) hasDataForWriting() bool {
ret := m.ctrl.Call(m, "hasDataForWriting")
ret0, _ := ret[0].(bool)
return ret0
}
// hasDataForWriting indicates an expected call of hasDataForWriting
func (mr *MockCryptoStreamMockRecorder) hasDataForWriting() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "hasDataForWriting", reflect.TypeOf((*MockCryptoStream)(nil).hasDataForWriting))
}
// popStreamFrame mocks base method // popStreamFrame mocks base method
func (m *MockCryptoStream) popStreamFrame(arg0 protocol.ByteCount) (*wire.StreamFrame, bool) { func (m *MockCryptoStream) popStreamFrame(arg0 protocol.ByteCount) (*wire.StreamFrame, bool) {
ret := m.ctrl.Call(m, "popStreamFrame", arg0) ret := m.ctrl.Call(m, "popStreamFrame", arg0)

View file

@ -0,0 +1,72 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/lucas-clemente/quic-go (interfaces: StreamFrameSource)
// Package quic is a generated GoMock package.
package quic
import (
reflect "reflect"
gomock "github.com/golang/mock/gomock"
protocol "github.com/lucas-clemente/quic-go/internal/protocol"
wire "github.com/lucas-clemente/quic-go/internal/wire"
)
// MockStreamFrameSource is a mock of StreamFrameSource interface
type MockStreamFrameSource struct {
ctrl *gomock.Controller
recorder *MockStreamFrameSourceMockRecorder
}
// MockStreamFrameSourceMockRecorder is the mock recorder for MockStreamFrameSource
type MockStreamFrameSourceMockRecorder struct {
mock *MockStreamFrameSource
}
// NewMockStreamFrameSource creates a new mock instance
func NewMockStreamFrameSource(ctrl *gomock.Controller) *MockStreamFrameSource {
mock := &MockStreamFrameSource{ctrl: ctrl}
mock.recorder = &MockStreamFrameSourceMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockStreamFrameSource) EXPECT() *MockStreamFrameSourceMockRecorder {
return m.recorder
}
// HasCryptoStreamData mocks base method
func (m *MockStreamFrameSource) HasCryptoStreamData() bool {
ret := m.ctrl.Call(m, "HasCryptoStreamData")
ret0, _ := ret[0].(bool)
return ret0
}
// HasCryptoStreamData indicates an expected call of HasCryptoStreamData
func (mr *MockStreamFrameSourceMockRecorder) HasCryptoStreamData() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HasCryptoStreamData", reflect.TypeOf((*MockStreamFrameSource)(nil).HasCryptoStreamData))
}
// PopCryptoStreamFrame mocks base method
func (m *MockStreamFrameSource) PopCryptoStreamFrame(arg0 protocol.ByteCount) *wire.StreamFrame {
ret := m.ctrl.Call(m, "PopCryptoStreamFrame", arg0)
ret0, _ := ret[0].(*wire.StreamFrame)
return ret0
}
// PopCryptoStreamFrame indicates an expected call of PopCryptoStreamFrame
func (mr *MockStreamFrameSourceMockRecorder) PopCryptoStreamFrame(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PopCryptoStreamFrame", reflect.TypeOf((*MockStreamFrameSource)(nil).PopCryptoStreamFrame), arg0)
}
// PopStreamFrames mocks base method
func (m *MockStreamFrameSource) PopStreamFrames(arg0 protocol.ByteCount) []*wire.StreamFrame {
ret := m.ctrl.Call(m, "PopStreamFrames", arg0)
ret0, _ := ret[0].([]*wire.StreamFrame)
return ret0
}
// PopStreamFrames indicates an expected call of PopStreamFrames
func (mr *MockStreamFrameSourceMockRecorder) PopStreamFrames(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PopStreamFrames", reflect.TypeOf((*MockStreamFrameSource)(nil).PopStreamFrames), arg0)
}

View file

@ -3,6 +3,7 @@ package quic
//go:generate sh -c "./mockgen_private.sh quic mock_stream_internal_test.go github.com/lucas-clemente/quic-go streamI StreamI" //go:generate sh -c "./mockgen_private.sh quic mock_stream_internal_test.go github.com/lucas-clemente/quic-go streamI StreamI"
//go:generate sh -c "./mockgen_private.sh quic mock_stream_sender_test.go github.com/lucas-clemente/quic-go streamSender StreamSender" //go:generate sh -c "./mockgen_private.sh quic mock_stream_sender_test.go github.com/lucas-clemente/quic-go streamSender StreamSender"
//go:generate sh -c "./mockgen_private.sh quic mock_stream_getter_test.go github.com/lucas-clemente/quic-go streamGetter StreamGetter" //go:generate sh -c "./mockgen_private.sh quic mock_stream_getter_test.go github.com/lucas-clemente/quic-go streamGetter StreamGetter"
//go:generate sh -c "./mockgen_private.sh quic mock_stream_frame_source_test.go github.com/lucas-clemente/quic-go streamFrameSource StreamFrameSource"
//go:generate sh -c "./mockgen_private.sh quic mock_crypto_stream_test.go github.com/lucas-clemente/quic-go cryptoStreamI CryptoStream" //go:generate sh -c "./mockgen_private.sh quic mock_crypto_stream_test.go github.com/lucas-clemente/quic-go cryptoStreamI CryptoStream"
//go:generate sh -c "sed -i '' 's/quic_go.//g' mock_stream_getter_test.go" //go:generate sh -c "sed -i '' 's/quic_go.//g' mock_stream_getter_test.go"
//go:generate sh -c "goimports -w mock*_test.go" //go:generate sh -c "goimports -w mock*_test.go"

View file

@ -19,6 +19,12 @@ type packedPacket struct {
encryptionLevel protocol.EncryptionLevel encryptionLevel protocol.EncryptionLevel
} }
type streamFrameSource interface {
HasCryptoStreamData() bool
PopCryptoStreamFrame(protocol.ByteCount) *wire.StreamFrame
PopStreamFrames(protocol.ByteCount) []*wire.StreamFrame
}
type packetPacker struct { type packetPacker struct {
connectionID protocol.ConnectionID connectionID protocol.ConnectionID
perspective protocol.Perspective perspective protocol.Perspective
@ -26,7 +32,7 @@ type packetPacker struct {
cryptoSetup handshake.CryptoSetup cryptoSetup handshake.CryptoSetup
packetNumberGenerator *packetNumberGenerator packetNumberGenerator *packetNumberGenerator
streamFramer *streamFramer streams streamFrameSource
controlFrameMutex sync.Mutex controlFrameMutex sync.Mutex
controlFrames []wire.Frame controlFrames []wire.Frame
@ -42,7 +48,7 @@ type packetPacker struct {
func newPacketPacker(connectionID protocol.ConnectionID, func newPacketPacker(connectionID protocol.ConnectionID,
initialPacketNumber protocol.PacketNumber, initialPacketNumber protocol.PacketNumber,
cryptoSetup handshake.CryptoSetup, cryptoSetup handshake.CryptoSetup,
streamFramer *streamFramer, streamFramer streamFrameSource,
perspective protocol.Perspective, perspective protocol.Perspective,
version protocol.VersionNumber, version protocol.VersionNumber,
) *packetPacker { ) *packetPacker {
@ -51,7 +57,7 @@ func newPacketPacker(connectionID protocol.ConnectionID,
connectionID: connectionID, connectionID: connectionID,
perspective: perspective, perspective: perspective,
version: version, version: version,
streamFramer: streamFramer, streams: streamFramer,
packetNumberGenerator: newPacketNumberGenerator(initialPacketNumber, protocol.SkipPacketAveragePeriodLength), packetNumberGenerator: newPacketNumberGenerator(initialPacketNumber, protocol.SkipPacketAveragePeriodLength),
} }
} }
@ -128,7 +134,7 @@ func (p *packetPacker) PackHandshakeRetransmission(packet *ackhandler.Packet) (*
// PackPacket packs a new packet // PackPacket packs a new packet
// the other controlFrames are sent in the next packet, but might be queued and sent in the next packet if the packet would overflow MaxPacketSize otherwise // the other controlFrames are sent in the next packet, but might be queued and sent in the next packet if the packet would overflow MaxPacketSize otherwise
func (p *packetPacker) PackPacket() (*packedPacket, error) { func (p *packetPacker) PackPacket() (*packedPacket, error) {
hasCryptoStreamFrame := p.streamFramer.HasCryptoStreamFrame() hasCryptoStreamFrame := p.streams.HasCryptoStreamData()
// if this is the first packet to be send, make sure it contains stream data // if this is the first packet to be send, make sure it contains stream data
if !p.hasSentPacket && !hasCryptoStreamFrame { if !p.hasSentPacket && !hasCryptoStreamFrame {
return nil, nil return nil, nil
@ -195,7 +201,7 @@ func (p *packetPacker) packCryptoPacket() (*packedPacket, error) {
return nil, err return nil, err
} }
maxLen := protocol.MaxPacketSize - protocol.ByteCount(sealer.Overhead()) - protocol.NonForwardSecurePacketSizeReduction - headerLength maxLen := protocol.MaxPacketSize - protocol.ByteCount(sealer.Overhead()) - protocol.NonForwardSecurePacketSizeReduction - headerLength
sf := p.streamFramer.PopCryptoStreamFrame(maxLen) sf := p.streams.PopCryptoStreamFrame(maxLen)
sf.DataLenPresent = false sf.DataLenPresent = false
frames := []wire.Frame{sf} frames := []wire.Frame{sf}
raw, err := p.writeAndSealPacket(header, frames, sealer) raw, err := p.writeAndSealPacket(header, frames, sealer)
@ -260,7 +266,7 @@ func (p *packetPacker) composeNextPacket(
maxFrameSize += 2 maxFrameSize += 2
} }
fs := p.streamFramer.PopStreamFrames(maxFrameSize - payloadLength) fs := p.streams.PopStreamFrames(maxFrameSize - payloadLength)
if len(fs) != 0 { if len(fs) != 0 {
fs[len(fs)-1].DataLenPresent = false fs[len(fs)-1].DataLenPresent = false
} }

View file

@ -53,11 +53,11 @@ func (m *mockCryptoSetup) SetDiversificationNonce(divNonce []byte) { m.divNonce
var _ = Describe("Packet packer", func() { var _ = Describe("Packet packer", func() {
var ( var (
packer *packetPacker packer *packetPacker
publicHeaderLen protocol.ByteCount publicHeaderLen protocol.ByteCount
maxFrameSize protocol.ByteCount maxFrameSize protocol.ByteCount
streamFramer *streamFramer cryptoStream cryptoStreamI
cryptoStream cryptoStreamI mockStreamFramer *MockStreamFrameSource
) )
BeforeEach(func() { BeforeEach(func() {
@ -65,15 +65,16 @@ var _ = Describe("Packet packer", func() {
mockSender := NewMockStreamSender(mockCtrl) mockSender := NewMockStreamSender(mockCtrl)
mockSender.EXPECT().onHasStreamData(gomock.Any()).AnyTimes() mockSender.EXPECT().onHasStreamData(gomock.Any()).AnyTimes()
cryptoStream = newCryptoStream(mockSender, flowcontrol.NewStreamFlowController(version.CryptoStreamID(), false, flowcontrol.NewConnectionFlowController(1000, 1000, nil), 1000, 1000, 1000, nil), version) cryptoStream = newCryptoStream(mockSender, flowcontrol.NewStreamFlowController(version.CryptoStreamID(), false, flowcontrol.NewConnectionFlowController(1000, 1000, nil), 1000, 1000, 1000, nil), version)
streamFramer = newStreamFramer(cryptoStream, nil, versionGQUICFrames) mockStreamFramer = NewMockStreamFrameSource(mockCtrl)
packer = &packetPacker{ packer = newPacketPacker(
cryptoSetup: &mockCryptoSetup{encLevelSeal: protocol.EncryptionForwardSecure}, 0x1337,
connectionID: 0x1337, 1,
packetNumberGenerator: newPacketNumberGenerator(1, protocol.SkipPacketAveragePeriodLength), &mockCryptoSetup{encLevelSeal: protocol.EncryptionForwardSecure},
streamFramer: streamFramer, mockStreamFramer,
perspective: protocol.PerspectiveServer, protocol.PerspectiveServer,
} version,
)
publicHeaderLen = 1 + 8 + 2 // 1 flag byte, 8 connection ID, 2 packet number publicHeaderLen = 1 + 8 + 2 // 1 flag byte, 8 connection ID, 2 packet number
maxFrameSize = protocol.MaxPacketSize - protocol.ByteCount((&mockSealer{}).Overhead()) - publicHeaderLen maxFrameSize = protocol.MaxPacketSize - protocol.ByteCount((&mockSealer{}).Overhead()) - publicHeaderLen
packer.hasSentPacket = true packer.hasSentPacket = true
@ -81,33 +82,36 @@ var _ = Describe("Packet packer", func() {
}) })
It("returns nil when no packet is queued", func() { It("returns nil when no packet is queued", func() {
mockStreamFramer.EXPECT().HasCryptoStreamData()
mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any())
p, err := packer.PackPacket() p, err := packer.PackPacket()
Expect(p).To(BeNil()) Expect(p).To(BeNil())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
}) })
It("packs single packets", func() { It("packs single packets", func() {
mockStreamFramer.EXPECT().HasCryptoStreamData()
f := &wire.StreamFrame{ f := &wire.StreamFrame{
StreamID: 5, StreamID: 5,
Data: []byte{0xDE, 0xCA, 0xFB, 0xAD}, Data: []byte{0xDE, 0xCA, 0xFB, 0xAD},
} }
streamFramer.AddFrameForRetransmission(f) mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()).Return([]*wire.StreamFrame{f})
p, err := packer.PackPacket() p, err := packer.PackPacket()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(p).ToNot(BeNil()) Expect(p).ToNot(BeNil())
b := &bytes.Buffer{} b := &bytes.Buffer{}
f.Write(b, packer.version) f.Write(b, packer.version)
Expect(p.frames).To(HaveLen(1)) Expect(p.frames).To(Equal([]wire.Frame{f}))
Expect(p.raw).To(ContainSubstring(string(b.Bytes()))) Expect(p.raw).To(ContainSubstring(string(b.Bytes())))
}) })
It("stores the encryption level a packet was sealed with", func() { It("stores the encryption level a packet was sealed with", func() {
packer.cryptoSetup.(*mockCryptoSetup).encLevelSeal = protocol.EncryptionForwardSecure mockStreamFramer.EXPECT().HasCryptoStreamData()
f := &wire.StreamFrame{ mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()).Return([]*wire.StreamFrame{{
StreamID: 5, StreamID: 5,
Data: []byte("foobar"), Data: []byte("foobar"),
} }})
streamFramer.AddFrameForRetransmission(f) packer.cryptoSetup.(*mockCryptoSetup).encLevelSeal = protocol.EncryptionForwardSecure
p, err := packer.PackPacket() p, err := packer.PackPacket()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(p.encryptionLevel).To(Equal(protocol.EncryptionForwardSecure)) Expect(p.encryptionLevel).To(Equal(protocol.EncryptionForwardSecure))
@ -215,7 +219,7 @@ var _ = Describe("Packet packer", func() {
}) })
}) })
It("packs a ConnectionClose", func() { It("packs a CONNECTION_CLOSE", func() {
ccf := wire.ConnectionCloseFrame{ ccf := wire.ConnectionCloseFrame{
ErrorCode: 0x1337, ErrorCode: 0x1337,
ReasonPhrase: "foobar", ReasonPhrase: "foobar",
@ -226,23 +230,21 @@ var _ = Describe("Packet packer", func() {
Expect(p.frames[0]).To(Equal(&ccf)) Expect(p.frames[0]).To(Equal(&ccf))
}) })
It("doesn't send any other frames when sending a ConnectionClose", func() { It("doesn't send any other frames when sending a CONNECTION_CLOSE", func() {
ccf := wire.ConnectionCloseFrame{ // expect no mockStreamFramer.PopStreamFrames
ccf := &wire.ConnectionCloseFrame{
ErrorCode: 0x1337, ErrorCode: 0x1337,
ReasonPhrase: "foobar", ReasonPhrase: "foobar",
} }
packer.controlFrames = []wire.Frame{&wire.MaxStreamDataFrame{StreamID: 37}} packer.controlFrames = []wire.Frame{&wire.MaxStreamDataFrame{StreamID: 37}}
streamFramer.AddFrameForRetransmission(&wire.StreamFrame{ p, err := packer.PackConnectionClose(ccf)
StreamID: 5,
Data: []byte("foobar"),
})
p, err := packer.PackConnectionClose(&ccf)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(p.frames).To(HaveLen(1)) Expect(p.frames).To(Equal([]wire.Frame{ccf}))
Expect(p.frames[0]).To(Equal(&ccf))
}) })
It("packs only control frames", func() { It("packs only control frames", func() {
mockStreamFramer.EXPECT().HasCryptoStreamData()
mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any())
packer.QueueControlFrame(&wire.RstStreamFrame{}) packer.QueueControlFrame(&wire.RstStreamFrame{})
packer.QueueControlFrame(&wire.MaxDataFrame{}) packer.QueueControlFrame(&wire.MaxDataFrame{})
p, err := packer.PackPacket() p, err := packer.PackPacket()
@ -253,6 +255,8 @@ var _ = Describe("Packet packer", func() {
}) })
It("increases the packet number", func() { It("increases the packet number", func() {
mockStreamFramer.EXPECT().HasCryptoStreamData().Times(2)
mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()).Times(2)
packer.QueueControlFrame(&wire.RstStreamFrame{}) packer.QueueControlFrame(&wire.RstStreamFrame{})
p1, err := packer.PackPacket() p1, err := packer.PackPacket()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
@ -265,6 +269,8 @@ var _ = Describe("Packet packer", func() {
}) })
It("packs a STOP_WAITING frame first", func() { It("packs a STOP_WAITING frame first", func() {
mockStreamFramer.EXPECT().HasCryptoStreamData()
mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any())
packer.packetNumberGenerator.next = 15 packer.packetNumberGenerator.next = 15
swf := &wire.StopWaitingFrame{LeastUnacked: 10} swf := &wire.StopWaitingFrame{LeastUnacked: 10}
packer.QueueControlFrame(&wire.RstStreamFrame{}) packer.QueueControlFrame(&wire.RstStreamFrame{})
@ -277,6 +283,8 @@ var _ = Describe("Packet packer", func() {
}) })
It("sets the LeastUnackedDelta length of a STOP_WAITING frame", func() { It("sets the LeastUnackedDelta length of a STOP_WAITING frame", func() {
mockStreamFramer.EXPECT().HasCryptoStreamData()
mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any())
packetNumber := protocol.PacketNumber(0xDECAFB) // will result in a 4 byte packet number packetNumber := protocol.PacketNumber(0xDECAFB) // will result in a 4 byte packet number
packer.packetNumberGenerator.next = packetNumber packer.packetNumberGenerator.next = packetNumber
swf := &wire.StopWaitingFrame{LeastUnacked: packetNumber - 0x100} swf := &wire.StopWaitingFrame{LeastUnacked: packetNumber - 0x100}
@ -288,6 +296,8 @@ var _ = Describe("Packet packer", func() {
}) })
It("does not pack a packet containing only a STOP_WAITING frame", func() { It("does not pack a packet containing only a STOP_WAITING frame", func() {
mockStreamFramer.EXPECT().HasCryptoStreamData()
mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any())
swf := &wire.StopWaitingFrame{LeastUnacked: 10} swf := &wire.StopWaitingFrame{LeastUnacked: 10}
packer.QueueControlFrame(swf) packer.QueueControlFrame(swf)
p, err := packer.PackPacket() p, err := packer.PackPacket()
@ -296,6 +306,8 @@ var _ = Describe("Packet packer", func() {
}) })
It("packs a packet if it has queued control frames, but no new control frames", func() { It("packs a packet if it has queued control frames, but no new control frames", func() {
mockStreamFramer.EXPECT().HasCryptoStreamData()
mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any())
packer.controlFrames = []wire.Frame{&wire.BlockedFrame{}} packer.controlFrames = []wire.Frame{&wire.BlockedFrame{}}
p, err := packer.PackPacket() p, err := packer.PackPacket()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
@ -303,6 +315,7 @@ var _ = Describe("Packet packer", func() {
}) })
It("refuses to send a packet that doesn't contain crypto stream data, if it has never sent a packet before", func() { It("refuses to send a packet that doesn't contain crypto stream data, if it has never sent a packet before", func() {
mockStreamFramer.EXPECT().HasCryptoStreamData()
packer.hasSentPacket = false packer.hasSentPacket = false
packer.controlFrames = []wire.Frame{&wire.BlockedFrame{}} packer.controlFrames = []wire.Frame{&wire.BlockedFrame{}}
p, err := packer.PackPacket() p, err := packer.PackPacket()
@ -346,16 +359,17 @@ var _ = Describe("Packet packer", func() {
}) })
It("only increases the packet number when there is an actual packet to send", func() { It("only increases the packet number when there is an actual packet to send", func() {
mockStreamFramer.EXPECT().HasCryptoStreamData().Times(2)
mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any())
packer.packetNumberGenerator.nextToSkip = 1000 packer.packetNumberGenerator.nextToSkip = 1000
p, err := packer.PackPacket() p, err := packer.PackPacket()
Expect(p).To(BeNil()) Expect(p).To(BeNil())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(packer.packetNumberGenerator.Peek()).To(Equal(protocol.PacketNumber(1))) Expect(packer.packetNumberGenerator.Peek()).To(Equal(protocol.PacketNumber(1)))
f := &wire.StreamFrame{ mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()).Return([]*wire.StreamFrame{{
StreamID: 5, StreamID: 5,
Data: []byte{0xDE, 0xCA, 0xFB, 0xAD}, Data: []byte{0xDE, 0xCA, 0xFB, 0xAD},
} }})
streamFramer.AddFrameForRetransmission(f)
p, err = packer.PackPacket() p, err = packer.PackPacket()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(p).ToNot(BeNil()) Expect(p).ToNot(BeNil())
@ -364,6 +378,8 @@ var _ = Describe("Packet packer", func() {
}) })
It("adds a PING frame when it's supposed to send a retransmittable packet", func() { It("adds a PING frame when it's supposed to send a retransmittable packet", func() {
mockStreamFramer.EXPECT().HasCryptoStreamData().Times(2)
mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()).Times(2)
packer.QueueControlFrame(&wire.AckFrame{}) packer.QueueControlFrame(&wire.AckFrame{})
packer.QueueControlFrame(&wire.StopWaitingFrame{}) packer.QueueControlFrame(&wire.StopWaitingFrame{})
packer.MakeNextPacketRetransmittable() packer.MakeNextPacketRetransmittable()
@ -381,6 +397,8 @@ var _ = Describe("Packet packer", func() {
}) })
It("waits until there's something to send before adding a PING frame", func() { It("waits until there's something to send before adding a PING frame", func() {
mockStreamFramer.EXPECT().HasCryptoStreamData().Times(2)
mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()).Times(2)
packer.MakeNextPacketRetransmittable() packer.MakeNextPacketRetransmittable()
p, err := packer.PackPacket() p, err := packer.PackPacket()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
@ -393,6 +411,8 @@ var _ = Describe("Packet packer", func() {
}) })
It("doesn't send a PING if it already sent another retransmittable frame", func() { It("doesn't send a PING if it already sent another retransmittable frame", func() {
mockStreamFramer.EXPECT().HasCryptoStreamData().Times(2)
mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()).Times(2)
packer.MakeNextPacketRetransmittable() packer.MakeNextPacketRetransmittable()
packer.QueueControlFrame(&wire.MaxDataFrame{}) packer.QueueControlFrame(&wire.MaxDataFrame{})
p, err := packer.PackPacket() p, err := packer.PackPacket()
@ -406,300 +426,158 @@ var _ = Describe("Packet packer", func() {
Expect(p.frames).To(HaveLen(1)) Expect(p.frames).To(HaveLen(1))
}) })
Context("STREAM Frame handling", func() { Context("STREAM frame handling", func() {
It("does not splits a STREAM frame with maximum size, for gQUIC frames", func() { It("does not splits a STREAM frame with maximum size, for gQUIC frames", func() {
f := &wire.StreamFrame{ mockStreamFramer.EXPECT().HasCryptoStreamData().Times(2)
Offset: 1, mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()).DoAndReturn(func(maxSize protocol.ByteCount) []*wire.StreamFrame {
StreamID: 5, f := &wire.StreamFrame{
DataLenPresent: false, Offset: 1,
} StreamID: 5,
maxStreamFrameDataLen := maxFrameSize - f.MinLength(packer.version) DataLenPresent: true,
f.Data = bytes.Repeat([]byte{'f'}, int(maxStreamFrameDataLen)) }
streamFramer.AddFrameForRetransmission(f) f.Data = bytes.Repeat([]byte{'f'}, int(maxSize-f.MinLength(packer.version)))
payloadFrames, err := packer.composeNextPacket(maxFrameSize, true) return []*wire.StreamFrame{f}
Expect(err).ToNot(HaveOccurred()) })
Expect(payloadFrames).To(HaveLen(1)) mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any())
Expect(payloadFrames[0].(*wire.StreamFrame).DataLenPresent).To(BeFalse())
payloadFrames, err = packer.composeNextPacket(maxFrameSize, true)
Expect(err).ToNot(HaveOccurred())
Expect(payloadFrames).To(BeEmpty())
})
It("does not splits a STREAM frame with maximum size, for IETF draft style frame", func() {
packer.version = versionIETFFrames
streamFramer.version = versionIETFFrames
f := &wire.StreamFrame{
Offset: 1,
StreamID: 5,
DataLenPresent: true,
}
// for IETF draft style STREAM frames, we don't know the size of the DataLen, because it is a variable length integer
// in the general case, we therefore use a STREAM frame that is 1 byte smaller than the maximum size
maxStreamFrameDataLen := maxFrameSize - f.MinLength(packer.version) - 1
f.Data = bytes.Repeat([]byte{'f'}, int(maxStreamFrameDataLen))
streamFramer.AddFrameForRetransmission(f)
payloadFrames, err := packer.composeNextPacket(maxFrameSize, true)
Expect(err).ToNot(HaveOccurred())
Expect(payloadFrames).To(HaveLen(1))
Expect(payloadFrames[0].(*wire.StreamFrame).DataLenPresent).To(BeFalse())
payloadFrames, err = packer.composeNextPacket(maxFrameSize, true)
Expect(err).ToNot(HaveOccurred())
Expect(payloadFrames).To(BeEmpty())
})
It("correctly handles a STREAM frame with one byte less than maximum size", func() {
maxStreamFrameDataLen := maxFrameSize - (1 + 1 + 2) - 1
f1 := &wire.StreamFrame{
StreamID: 5,
Offset: 1,
Data: bytes.Repeat([]byte{'f'}, int(maxStreamFrameDataLen)),
}
f2 := &wire.StreamFrame{
StreamID: 5,
Offset: 1,
Data: []byte("foobar"),
}
streamFramer.AddFrameForRetransmission(f1)
streamFramer.AddFrameForRetransmission(f2)
p, err := packer.PackPacket()
Expect(err).ToNot(HaveOccurred())
Expect(p.raw).To(HaveLen(int(protocol.MaxPacketSize - 1)))
Expect(p.frames).To(HaveLen(1))
Expect(p.frames[0].(*wire.StreamFrame).DataLenPresent).To(BeFalse())
p, err = packer.PackPacket()
Expect(err).ToNot(HaveOccurred())
Expect(p.frames).To(HaveLen(1))
Expect(p.frames[0].(*wire.StreamFrame).DataLenPresent).To(BeFalse())
})
It("packs multiple small STREAM frames into single packet", func() {
f1 := &wire.StreamFrame{
StreamID: 5,
Data: []byte{0xDE, 0xCA, 0xFB, 0xAD},
}
f2 := &wire.StreamFrame{
StreamID: 5,
Data: []byte{0xBE, 0xEF, 0x13, 0x37},
}
f3 := &wire.StreamFrame{
StreamID: 3,
Data: []byte{0xCA, 0xFE},
}
streamFramer.AddFrameForRetransmission(f1)
streamFramer.AddFrameForRetransmission(f2)
streamFramer.AddFrameForRetransmission(f3)
p, err := packer.PackPacket()
Expect(p).ToNot(BeNil())
Expect(err).ToNot(HaveOccurred())
b := &bytes.Buffer{}
f1.Write(b, 0)
f2.Write(b, 0)
f3.Write(b, 0)
Expect(p.frames).To(HaveLen(3))
Expect(p.frames[0].(*wire.StreamFrame).DataLenPresent).To(BeTrue())
Expect(p.frames[1].(*wire.StreamFrame).DataLenPresent).To(BeTrue())
Expect(p.frames[2].(*wire.StreamFrame).DataLenPresent).To(BeFalse())
Expect(p.raw).To(ContainSubstring(string(f1.Data)))
Expect(p.raw).To(ContainSubstring(string(f2.Data)))
Expect(p.raw).To(ContainSubstring(string(f3.Data)))
})
It("splits one STREAM frame larger than maximum size", func() {
f := &wire.StreamFrame{
StreamID: 7,
Offset: 1,
}
maxStreamFrameDataLen := maxFrameSize - f.MinLength(packer.version)
f.Data = bytes.Repeat([]byte{'f'}, int(maxStreamFrameDataLen)+200)
streamFramer.AddFrameForRetransmission(f)
payloadFrames, err := packer.composeNextPacket(maxFrameSize, true)
Expect(err).ToNot(HaveOccurred())
Expect(payloadFrames).To(HaveLen(1))
Expect(payloadFrames[0].(*wire.StreamFrame).DataLenPresent).To(BeFalse())
Expect(payloadFrames[0].(*wire.StreamFrame).Data).To(HaveLen(int(maxStreamFrameDataLen)))
payloadFrames, err = packer.composeNextPacket(maxFrameSize, true)
Expect(err).ToNot(HaveOccurred())
Expect(payloadFrames).To(HaveLen(1))
Expect(payloadFrames[0].(*wire.StreamFrame).Data).To(HaveLen(200))
Expect(payloadFrames[0].(*wire.StreamFrame).DataLenPresent).To(BeFalse())
payloadFrames, err = packer.composeNextPacket(maxFrameSize, true)
Expect(err).ToNot(HaveOccurred())
Expect(payloadFrames).To(BeEmpty())
})
It("packs 2 STREAM frames that are too big for one packet correctly", func() {
maxStreamFrameDataLen := maxFrameSize - (1 + 1 + 2)
f1 := &wire.StreamFrame{
StreamID: 5,
Data: bytes.Repeat([]byte{'f'}, int(maxStreamFrameDataLen)+100),
Offset: 1,
}
f2 := &wire.StreamFrame{
StreamID: 5,
Data: bytes.Repeat([]byte{'f'}, int(maxStreamFrameDataLen)+100),
Offset: 1,
}
streamFramer.AddFrameForRetransmission(f1)
streamFramer.AddFrameForRetransmission(f2)
p, err := packer.PackPacket() p, err := packer.PackPacket()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(p.frames).To(HaveLen(1)) Expect(p.frames).To(HaveLen(1))
Expect(p.frames[0].(*wire.StreamFrame).DataLenPresent).To(BeFalse())
Expect(p.raw).To(HaveLen(int(protocol.MaxPacketSize))) Expect(p.raw).To(HaveLen(int(protocol.MaxPacketSize)))
p, err = packer.PackPacket()
Expect(p.frames).To(HaveLen(2))
Expect(p.frames[0].(*wire.StreamFrame).DataLenPresent).To(BeTrue())
Expect(p.frames[1].(*wire.StreamFrame).DataLenPresent).To(BeFalse())
Expect(err).ToNot(HaveOccurred())
Expect(p.raw).To(HaveLen(int(protocol.MaxPacketSize)))
p, err = packer.PackPacket()
Expect(p.frames).To(HaveLen(1))
Expect(p.frames[0].(*wire.StreamFrame).DataLenPresent).To(BeFalse()) Expect(p.frames[0].(*wire.StreamFrame).DataLenPresent).To(BeFalse())
Expect(err).ToNot(HaveOccurred())
Expect(p).ToNot(BeNil())
p, err = packer.PackPacket() p, err = packer.PackPacket()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(p).To(BeNil()) Expect(p).To(BeNil())
}) })
It("packs a packet that has the maximum packet size when given a large enough STREAM frame", func() { It("does not splits a STREAM frame with maximum size, for IETF draft style frame", func() {
f := &wire.StreamFrame{ packer.version = versionIETFFrames
StreamID: 5, mockStreamFramer.EXPECT().HasCryptoStreamData().Times(2)
Offset: 1, mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()).DoAndReturn(func(maxSize protocol.ByteCount) []*wire.StreamFrame {
} f := &wire.StreamFrame{
f.Data = bytes.Repeat([]byte{'f'}, int(maxFrameSize-f.MinLength(packer.version)+1)) // + 1 since MinceLength is 1 bigger than the actual StreamFrame header Offset: 1,
streamFramer.AddFrameForRetransmission(f) StreamID: 5,
DataLenPresent: true,
}
f.Data = bytes.Repeat([]byte{'f'}, int(maxSize-f.MinLength(packer.version)))
return []*wire.StreamFrame{f}
})
mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any())
p, err := packer.PackPacket() p, err := packer.PackPacket()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(p).ToNot(BeNil()) Expect(p.frames).To(HaveLen(1))
Expect(p.raw).To(HaveLen(int(protocol.MaxPacketSize))) Expect(p.raw).To(HaveLen(int(protocol.MaxPacketSize)))
Expect(p.frames[0].(*wire.StreamFrame).DataLenPresent).To(BeFalse())
p, err = packer.PackPacket()
Expect(err).ToNot(HaveOccurred())
Expect(p).To(BeNil())
}) })
It("splits a STREAM frame larger than the maximum size", func() { It("packs multiple small STREAM frames into single packet", func() {
f := &wire.StreamFrame{ f1 := &wire.StreamFrame{
StreamID: 5, StreamID: 5,
Offset: 1, Data: []byte("frame 1"),
DataLenPresent: true,
} }
f.Data = bytes.Repeat([]byte{'f'}, int(maxFrameSize-f.MinLength(packer.version)+2)) // + 2 since MinceLength is 1 bigger than the actual StreamFrame header f2 := &wire.StreamFrame{
StreamID: 5,
streamFramer.AddFrameForRetransmission(f) Data: []byte("frame 2"),
payloadFrames, err := packer.composeNextPacket(maxFrameSize, true) DataLenPresent: true,
}
f3 := &wire.StreamFrame{
StreamID: 3,
Data: []byte("frame 3"),
DataLenPresent: true,
}
mockStreamFramer.EXPECT().HasCryptoStreamData()
mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()).Return([]*wire.StreamFrame{f1, f2, f3})
p, err := packer.PackPacket()
Expect(p).ToNot(BeNil())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(payloadFrames).To(HaveLen(1)) Expect(p.frames).To(HaveLen(3))
payloadFrames, err = packer.composeNextPacket(maxFrameSize, true) Expect(p.frames[0].(*wire.StreamFrame).Data).To(Equal([]byte("frame 1")))
Expect(err).ToNot(HaveOccurred()) Expect(p.frames[1].(*wire.StreamFrame).Data).To(Equal([]byte("frame 2")))
Expect(payloadFrames).To(HaveLen(1)) Expect(p.frames[2].(*wire.StreamFrame).Data).To(Equal([]byte("frame 3")))
Expect(p.frames[0].(*wire.StreamFrame).DataLenPresent).To(BeTrue())
Expect(p.frames[1].(*wire.StreamFrame).DataLenPresent).To(BeTrue())
Expect(p.frames[2].(*wire.StreamFrame).DataLenPresent).To(BeFalse())
}) })
It("refuses to send unencrypted stream data on a data stream", func() { It("refuses to send unencrypted stream data on a data stream", func() {
mockStreamFramer.EXPECT().HasCryptoStreamData()
// don't expect a call to mockStreamFramer.PopStreamFrames
packer.cryptoSetup.(*mockCryptoSetup).encLevelSeal = protocol.EncryptionUnencrypted packer.cryptoSetup.(*mockCryptoSetup).encLevelSeal = protocol.EncryptionUnencrypted
f := &wire.StreamFrame{
StreamID: 3,
Data: []byte("foobar"),
}
streamFramer.AddFrameForRetransmission(f)
p, err := packer.PackPacket() p, err := packer.PackPacket()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(p).To(BeNil()) Expect(p).To(BeNil())
}) })
It("sends non forward-secure data as the client", func() { It("sends non forward-secure data as the client", func() {
packer.perspective = protocol.PerspectiveClient
packer.cryptoSetup.(*mockCryptoSetup).encLevelSeal = protocol.EncryptionSecure
f := &wire.StreamFrame{ f := &wire.StreamFrame{
StreamID: 5, StreamID: 5,
Data: []byte("foobar"), Data: []byte("foobar"),
} }
streamFramer.AddFrameForRetransmission(f) mockStreamFramer.EXPECT().HasCryptoStreamData()
mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()).Return([]*wire.StreamFrame{f})
packer.perspective = protocol.PerspectiveClient
packer.cryptoSetup.(*mockCryptoSetup).encLevelSeal = protocol.EncryptionSecure
p, err := packer.PackPacket() p, err := packer.PackPacket()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(p.encryptionLevel).To(Equal(protocol.EncryptionSecure)) Expect(p.encryptionLevel).To(Equal(protocol.EncryptionSecure))
Expect(p.frames[0]).To(Equal(f)) Expect(p.frames).To(Equal([]wire.Frame{f}))
}) })
It("does not send non forward-secure data as the server", func() { It("does not send non forward-secure data as the server", func() {
mockStreamFramer.EXPECT().HasCryptoStreamData()
// don't expect a call to mockStreamFramer.PopStreamFrames
packer.cryptoSetup.(*mockCryptoSetup).encLevelSeal = protocol.EncryptionSecure packer.cryptoSetup.(*mockCryptoSetup).encLevelSeal = protocol.EncryptionSecure
f := &wire.StreamFrame{
StreamID: 5,
Data: []byte("foobar"),
}
streamFramer.AddFrameForRetransmission(f)
p, err := packer.PackPacket() p, err := packer.PackPacket()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(p).To(BeNil()) Expect(p).To(BeNil())
}) })
It("sends unencrypted stream data on the crypto stream", func() { It("sends unencrypted stream data on the crypto stream", func() {
done := make(chan struct{}) f := &wire.StreamFrame{
go func() {
defer GinkgoRecover()
_, err := cryptoStream.Write([]byte("foobar"))
Expect(err).ToNot(HaveOccurred())
close(done)
}()
packer.cryptoSetup.(*mockCryptoSetup).encLevelSealCrypto = protocol.EncryptionUnencrypted
var p *packedPacket
Eventually(func() *packedPacket {
defer GinkgoRecover()
var err error
p, err = packer.PackPacket()
Expect(err).ToNot(HaveOccurred())
return p
}).ShouldNot(BeNil())
Expect(p.encryptionLevel).To(Equal(protocol.EncryptionUnencrypted))
Expect(p.frames).To(HaveLen(1))
Expect(p.frames[0]).To(Equal(&wire.StreamFrame{
StreamID: packer.version.CryptoStreamID(), StreamID: packer.version.CryptoStreamID(),
Data: []byte("foobar"), Data: []byte("foobar"),
})) }
Eventually(done).Should(BeClosed()) mockStreamFramer.EXPECT().HasCryptoStreamData().Return(true)
mockStreamFramer.EXPECT().PopCryptoStreamFrame(gomock.Any()).Return(f)
packer.cryptoSetup.(*mockCryptoSetup).encLevelSealCrypto = protocol.EncryptionUnencrypted
p, err := packer.PackPacket()
Expect(err).ToNot(HaveOccurred())
Expect(p.frames).To(Equal([]wire.Frame{f}))
Expect(p.encryptionLevel).To(Equal(protocol.EncryptionUnencrypted))
}) })
It("sends encrypted stream data on the crypto stream", func() { It("sends encrypted stream data on the crypto stream", func() {
done := make(chan struct{}) f := &wire.StreamFrame{
go func() {
defer GinkgoRecover()
_, err := cryptoStream.Write([]byte("foobar"))
Expect(err).ToNot(HaveOccurred())
close(done)
}()
packer.cryptoSetup.(*mockCryptoSetup).encLevelSealCrypto = protocol.EncryptionSecure
var p *packedPacket
Eventually(func() *packedPacket {
defer GinkgoRecover()
var err error
p, err = packer.PackPacket()
Expect(err).ToNot(HaveOccurred())
return p
}).ShouldNot(BeNil())
Expect(p.encryptionLevel).To(Equal(protocol.EncryptionSecure))
Expect(p.frames).To(HaveLen(1))
Expect(p.frames[0]).To(Equal(&wire.StreamFrame{
StreamID: packer.version.CryptoStreamID(), StreamID: packer.version.CryptoStreamID(),
Data: []byte("foobar"), Data: []byte("foobar"),
})) }
Eventually(done).Should(BeClosed()) mockStreamFramer.EXPECT().HasCryptoStreamData().Return(true)
}) mockStreamFramer.EXPECT().PopCryptoStreamFrame(gomock.Any()).Return(f)
packer.cryptoSetup.(*mockCryptoSetup).encLevelSealCrypto = protocol.EncryptionSecure
It("does not pack stream frames if not allowed", func() {
packer.cryptoSetup.(*mockCryptoSetup).encLevelSeal = protocol.EncryptionUnencrypted
packer.QueueControlFrame(&wire.AckFrame{})
streamFramer.AddFrameForRetransmission(&wire.StreamFrame{StreamID: 3, Data: []byte("foobar")})
p, err := packer.PackPacket() p, err := packer.PackPacket()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(p.frames).To(HaveLen(1)) Expect(p.frames).To(Equal([]wire.Frame{f}))
Expect(func() { _ = p.frames[0].(*wire.AckFrame) }).NotTo(Panic()) Expect(p.encryptionLevel).To(Equal(protocol.EncryptionSecure))
}) })
})
It("returns nil if we only have a single STOP_WAITING", func() { It("does not pack STREAM frames if not allowed", func() {
packer.QueueControlFrame(&wire.StopWaitingFrame{}) mockStreamFramer.EXPECT().HasCryptoStreamData()
p, err := packer.PackPacket() // don't expect a call to mockStreamFramer.PopStreamFrames
Expect(err).NotTo(HaveOccurred()) packer.cryptoSetup.(*mockCryptoSetup).encLevelSeal = protocol.EncryptionUnencrypted
Expect(p).To(BeNil()) ack := &wire.AckFrame{LargestAcked: 10}
packer.QueueControlFrame(ack)
p, err := packer.PackPacket()
Expect(err).ToNot(HaveOccurred())
Expect(p.frames).To(Equal([]wire.Frame{ack}))
})
}) })
It("packs a single ACK", func() { It("packs a single ACK", func() {
mockStreamFramer.EXPECT().HasCryptoStreamData()
mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any())
ack := &wire.AckFrame{LargestAcked: 42} ack := &wire.AckFrame{LargestAcked: 42}
packer.QueueControlFrame(ack) packer.QueueControlFrame(ack)
p, err := packer.PackPacket() p, err := packer.PackPacket()
@ -709,6 +587,8 @@ var _ = Describe("Packet packer", func() {
}) })
It("does not return nil if we only have a single ACK but request it to be sent", func() { It("does not return nil if we only have a single ACK but request it to be sent", func() {
mockStreamFramer.EXPECT().HasCryptoStreamData()
mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any())
ack := &wire.AckFrame{} ack := &wire.AckFrame{}
packer.QueueControlFrame(ack) packer.QueueControlFrame(ack)
p, err := packer.PackPacket() p, err := packer.PackPacket()
@ -716,15 +596,6 @@ var _ = Describe("Packet packer", func() {
Expect(p).ToNot(BeNil()) Expect(p).ToNot(BeNil())
}) })
It("queues a control frame to be sent in the next packet", func() {
msd := &wire.MaxStreamDataFrame{StreamID: 5}
packer.QueueControlFrame(msd)
p, err := packer.PackPacket()
Expect(err).NotTo(HaveOccurred())
Expect(p.frames).To(HaveLen(1))
Expect(p.frames[0]).To(Equal(msd))
})
Context("retransmitting of handshake packets", func() { Context("retransmitting of handshake packets", func() {
swf := &wire.StopWaitingFrame{LeastUnacked: 1} swf := &wire.StopWaitingFrame{LeastUnacked: 1}
sf := &wire.StreamFrame{ sf := &wire.StreamFrame{
@ -802,31 +673,23 @@ var _ = Describe("Packet packer", func() {
}) })
It("pads Initial packets to the required minimum packet size", func() { It("pads Initial packets to the required minimum packet size", func() {
f := &wire.StreamFrame{
StreamID: packer.version.CryptoStreamID(),
Data: []byte("foobar"),
}
mockStreamFramer.EXPECT().HasCryptoStreamData().Return(true)
mockStreamFramer.EXPECT().PopCryptoStreamFrame(gomock.Any()).Return(f)
packer.version = protocol.VersionTLS packer.version = protocol.VersionTLS
packer.hasSentPacket = false packer.hasSentPacket = false
packer.perspective = protocol.PerspectiveClient packer.perspective = protocol.PerspectiveClient
packer.cryptoSetup.(*mockCryptoSetup).encLevelSealCrypto = protocol.EncryptionUnencrypted packer.cryptoSetup.(*mockCryptoSetup).encLevelSealCrypto = protocol.EncryptionUnencrypted
done := make(chan struct{}) packet, err := packer.PackPacket()
go func() { Expect(err).ToNot(HaveOccurred())
defer GinkgoRecover()
_, err := cryptoStream.Write([]byte("foobar"))
Expect(err).ToNot(HaveOccurred())
close(done)
}()
var packet *packedPacket
Eventually(func() *packedPacket {
defer GinkgoRecover()
var err error
packet, err = packer.PackPacket()
Expect(err).ToNot(HaveOccurred())
return packet
}).ShouldNot(BeNil())
Expect(packet.raw).To(HaveLen(protocol.MinInitialPacketSize)) Expect(packet.raw).To(HaveLen(protocol.MinInitialPacketSize))
Expect(packet.frames).To(HaveLen(1)) Expect(packet.frames).To(HaveLen(1))
sf := packet.frames[0].(*wire.StreamFrame) sf := packet.frames[0].(*wire.StreamFrame)
Expect(sf.Data).To(Equal([]byte("foobar"))) Expect(sf.Data).To(Equal([]byte("foobar")))
Expect(sf.DataLenPresent).To(BeTrue()) Expect(sf.DataLenPresent).To(BeTrue())
Eventually(done).Should(BeClosed())
}) })
It("refuses to retransmit packets that were sent with forward-secure encryption", func() { It("refuses to retransmit packets that were sent with forward-secure encryption", func() {

View file

@ -14,9 +14,10 @@ type streamFramer struct {
retransmissionQueue []*wire.StreamFrame retransmissionQueue []*wire.StreamFrame
streamQueueMutex sync.Mutex streamQueueMutex sync.Mutex
activeStreams map[protocol.StreamID]struct{} activeStreams map[protocol.StreamID]struct{}
streamQueue []protocol.StreamID streamQueue []protocol.StreamID
hasCryptoStreamData bool
} }
func newStreamFramer( func newStreamFramer(
@ -38,6 +39,9 @@ func (f *streamFramer) AddFrameForRetransmission(frame *wire.StreamFrame) {
func (f *streamFramer) AddActiveStream(id protocol.StreamID) { func (f *streamFramer) AddActiveStream(id protocol.StreamID) {
if id == f.version.CryptoStreamID() { // the crypto stream is handled separately if id == f.version.CryptoStreamID() { // the crypto stream is handled separately
f.streamQueueMutex.Lock()
f.hasCryptoStreamData = true
f.streamQueueMutex.Unlock()
return return
} }
f.streamQueueMutex.Lock() f.streamQueueMutex.Lock()
@ -57,14 +61,18 @@ func (f *streamFramer) HasFramesForRetransmission() bool {
return len(f.retransmissionQueue) > 0 return len(f.retransmissionQueue) > 0
} }
// TODO: don't need to ask the crypto stream here, just record this information in AddActiveStream func (f *streamFramer) HasCryptoStreamData() bool {
func (f *streamFramer) HasCryptoStreamFrame() bool { f.streamQueueMutex.Lock()
return f.cryptoStream.hasDataForWriting() hasCryptoStreamData := f.hasCryptoStreamData
f.streamQueueMutex.Unlock()
return hasCryptoStreamData
} }
// TODO(lclemente): This is somewhat duplicate with the normal path for generating frames.
func (f *streamFramer) PopCryptoStreamFrame(maxLen protocol.ByteCount) *wire.StreamFrame { func (f *streamFramer) PopCryptoStreamFrame(maxLen protocol.ByteCount) *wire.StreamFrame {
frame, _ := f.cryptoStream.popStreamFrame(maxLen) f.streamQueueMutex.Lock()
frame, hasMoreData := f.cryptoStream.popStreamFrame(maxLen)
f.hasCryptoStreamData = hasMoreData
f.streamQueueMutex.Unlock()
return frame return frame
} }

View file

@ -22,6 +22,7 @@ var _ = Describe("Stream Framer", func() {
retransmittedFrame1, retransmittedFrame2 *wire.StreamFrame retransmittedFrame1, retransmittedFrame2 *wire.StreamFrame
framer *streamFramer framer *streamFramer
stream1, stream2 *MockStreamI stream1, stream2 *MockStreamI
cryptoStream *MockCryptoStream
streamGetter *MockStreamGetter streamGetter *MockStreamGetter
) )
@ -40,8 +41,8 @@ var _ = Describe("Stream Framer", func() {
stream1.EXPECT().StreamID().Return(protocol.StreamID(5)).AnyTimes() stream1.EXPECT().StreamID().Return(protocol.StreamID(5)).AnyTimes()
stream2 = NewMockStreamI(mockCtrl) stream2 = NewMockStreamI(mockCtrl)
stream2.EXPECT().StreamID().Return(protocol.StreamID(6)).AnyTimes() stream2.EXPECT().StreamID().Return(protocol.StreamID(6)).AnyTimes()
cryptoStream = NewMockCryptoStream(mockCtrl)
framer = newStreamFramer(nil, streamGetter, versionGQUICFrames) framer = newStreamFramer(cryptoStream, streamGetter, versionGQUICFrames)
}) })
It("says if it has retransmissions", func() { It("says if it has retransmissions", func() {
@ -57,6 +58,38 @@ var _ = Describe("Stream Framer", func() {
Expect(fs[0].DataLenPresent).To(BeTrue()) Expect(fs[0].DataLenPresent).To(BeTrue())
}) })
Context("handling the crypto stream", func() {
It("says if it has crypto stream data", func() {
Expect(framer.HasCryptoStreamData()).To(BeFalse())
framer.AddActiveStream(framer.version.CryptoStreamID())
Expect(framer.HasCryptoStreamData()).To(BeTrue())
})
It("says that it doesn't have crypto stream data after popping all data", func() {
streamID := framer.version.CryptoStreamID()
f := &wire.StreamFrame{
StreamID: streamID,
Data: []byte("foobar"),
}
cryptoStream.EXPECT().popStreamFrame(protocol.ByteCount(1000)).Return(f, false)
framer.AddActiveStream(streamID)
Expect(framer.PopCryptoStreamFrame(1000)).To(Equal(f))
Expect(framer.HasCryptoStreamData()).To(BeFalse())
})
It("says that it has more crypto stream data if not all data was popped", func() {
streamID := framer.version.CryptoStreamID()
f := &wire.StreamFrame{
StreamID: streamID,
Data: []byte("foobar"),
}
cryptoStream.EXPECT().popStreamFrame(protocol.ByteCount(1000)).Return(f, true)
framer.AddActiveStream(streamID)
Expect(framer.PopCryptoStreamFrame(1000)).To(Equal(f))
Expect(framer.HasCryptoStreamData()).To(BeTrue())
})
})
Context("Popping", func() { Context("Popping", func() {
It("returns nil when popping an empty framer", func() { It("returns nil when popping an empty framer", func() {
Expect(framer.PopStreamFrames(1000)).To(BeEmpty()) Expect(framer.PopStreamFrames(1000)).To(BeEmpty())