diff --git a/crypto_stream.go b/crypto_stream.go index 16294b6d..8e96ec10 100644 --- a/crypto_stream.go +++ b/crypto_stream.go @@ -15,7 +15,6 @@ type cryptoStreamI interface { handleStreamFrame(*wire.StreamFrame) error popStreamFrame(protocol.ByteCount) (*wire.StreamFrame, bool) closeForShutdown(error) - hasDataForWriting() bool setReadOffset(protocol.ByteCount) // methods needed for flow control getWindowUpdate() protocol.ByteCount @@ -40,10 +39,3 @@ func (s *cryptoStream) setReadOffset(offset protocol.ByteCount) { s.receiveStream.readOffset = offset s.receiveStream.frameQueue.readPosition = offset } - -func (s *cryptoStream) hasDataForWriting() bool { - s.sendStream.mutex.Lock() - hasData := s.sendStream.dataForWriting != nil - s.sendStream.mutex.Unlock() - return hasData -} diff --git a/crypto_stream_test.go b/crypto_stream_test.go index fc2a61a6..d5ec3bef 100644 --- a/crypto_stream_test.go +++ b/crypto_stream_test.go @@ -23,17 +23,4 @@ var _ = Describe("Crypto Stream", func() { Expect(str.receiveStream.readOffset).To(Equal(protocol.ByteCount(0x42))) Expect(str.receiveStream.frameQueue.readPosition).To(Equal(protocol.ByteCount(0x42))) }) - - It("says if it has data for writing", func() { - mockSender.EXPECT().onHasStreamData(str.version.CryptoStreamID()) - Expect(str.hasDataForWriting()).To(BeFalse()) - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - _, err := str.Write([]byte("foobar")) - Expect(err).ToNot(HaveOccurred()) - close(done) - }() - Eventually(str.hasDataForWriting).Should(BeTrue()) - }) }) diff --git a/mock_crypto_stream_test.go b/mock_crypto_stream_test.go index f856e1ba..68e47c1d 100644 --- a/mock_crypto_stream_test.go +++ b/mock_crypto_stream_test.go @@ -117,18 +117,6 @@ func (mr *MockCryptoStreamMockRecorder) handleStreamFrame(arg0 interface{}) *gom return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handleStreamFrame", reflect.TypeOf((*MockCryptoStream)(nil).handleStreamFrame), arg0) } -// hasDataForWriting mocks base method -func (m *MockCryptoStream) hasDataForWriting() bool { - ret := m.ctrl.Call(m, "hasDataForWriting") - ret0, _ := ret[0].(bool) - return ret0 -} - -// hasDataForWriting indicates an expected call of hasDataForWriting -func (mr *MockCryptoStreamMockRecorder) hasDataForWriting() *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "hasDataForWriting", reflect.TypeOf((*MockCryptoStream)(nil).hasDataForWriting)) -} - // popStreamFrame mocks base method func (m *MockCryptoStream) popStreamFrame(arg0 protocol.ByteCount) (*wire.StreamFrame, bool) { ret := m.ctrl.Call(m, "popStreamFrame", arg0) diff --git a/mock_stream_frame_source_test.go b/mock_stream_frame_source_test.go new file mode 100644 index 00000000..9b365805 --- /dev/null +++ b/mock_stream_frame_source_test.go @@ -0,0 +1,72 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/lucas-clemente/quic-go (interfaces: StreamFrameSource) + +// Package quic is a generated GoMock package. +package quic + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + protocol "github.com/lucas-clemente/quic-go/internal/protocol" + wire "github.com/lucas-clemente/quic-go/internal/wire" +) + +// MockStreamFrameSource is a mock of StreamFrameSource interface +type MockStreamFrameSource struct { + ctrl *gomock.Controller + recorder *MockStreamFrameSourceMockRecorder +} + +// MockStreamFrameSourceMockRecorder is the mock recorder for MockStreamFrameSource +type MockStreamFrameSourceMockRecorder struct { + mock *MockStreamFrameSource +} + +// NewMockStreamFrameSource creates a new mock instance +func NewMockStreamFrameSource(ctrl *gomock.Controller) *MockStreamFrameSource { + mock := &MockStreamFrameSource{ctrl: ctrl} + mock.recorder = &MockStreamFrameSourceMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockStreamFrameSource) EXPECT() *MockStreamFrameSourceMockRecorder { + return m.recorder +} + +// HasCryptoStreamData mocks base method +func (m *MockStreamFrameSource) HasCryptoStreamData() bool { + ret := m.ctrl.Call(m, "HasCryptoStreamData") + ret0, _ := ret[0].(bool) + return ret0 +} + +// HasCryptoStreamData indicates an expected call of HasCryptoStreamData +func (mr *MockStreamFrameSourceMockRecorder) HasCryptoStreamData() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HasCryptoStreamData", reflect.TypeOf((*MockStreamFrameSource)(nil).HasCryptoStreamData)) +} + +// PopCryptoStreamFrame mocks base method +func (m *MockStreamFrameSource) PopCryptoStreamFrame(arg0 protocol.ByteCount) *wire.StreamFrame { + ret := m.ctrl.Call(m, "PopCryptoStreamFrame", arg0) + ret0, _ := ret[0].(*wire.StreamFrame) + return ret0 +} + +// PopCryptoStreamFrame indicates an expected call of PopCryptoStreamFrame +func (mr *MockStreamFrameSourceMockRecorder) PopCryptoStreamFrame(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PopCryptoStreamFrame", reflect.TypeOf((*MockStreamFrameSource)(nil).PopCryptoStreamFrame), arg0) +} + +// PopStreamFrames mocks base method +func (m *MockStreamFrameSource) PopStreamFrames(arg0 protocol.ByteCount) []*wire.StreamFrame { + ret := m.ctrl.Call(m, "PopStreamFrames", arg0) + ret0, _ := ret[0].([]*wire.StreamFrame) + return ret0 +} + +// PopStreamFrames indicates an expected call of PopStreamFrames +func (mr *MockStreamFrameSourceMockRecorder) PopStreamFrames(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PopStreamFrames", reflect.TypeOf((*MockStreamFrameSource)(nil).PopStreamFrames), arg0) +} diff --git a/mockgen.go b/mockgen.go index 03ef893d..8c2cd6f3 100644 --- a/mockgen.go +++ b/mockgen.go @@ -3,6 +3,7 @@ package quic //go:generate sh -c "./mockgen_private.sh quic mock_stream_internal_test.go github.com/lucas-clemente/quic-go streamI StreamI" //go:generate sh -c "./mockgen_private.sh quic mock_stream_sender_test.go github.com/lucas-clemente/quic-go streamSender StreamSender" //go:generate sh -c "./mockgen_private.sh quic mock_stream_getter_test.go github.com/lucas-clemente/quic-go streamGetter StreamGetter" +//go:generate sh -c "./mockgen_private.sh quic mock_stream_frame_source_test.go github.com/lucas-clemente/quic-go streamFrameSource StreamFrameSource" //go:generate sh -c "./mockgen_private.sh quic mock_crypto_stream_test.go github.com/lucas-clemente/quic-go cryptoStreamI CryptoStream" //go:generate sh -c "sed -i '' 's/quic_go.//g' mock_stream_getter_test.go" //go:generate sh -c "goimports -w mock*_test.go" diff --git a/packet_packer.go b/packet_packer.go index 1a56e196..74e46e35 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -19,6 +19,12 @@ type packedPacket struct { encryptionLevel protocol.EncryptionLevel } +type streamFrameSource interface { + HasCryptoStreamData() bool + PopCryptoStreamFrame(protocol.ByteCount) *wire.StreamFrame + PopStreamFrames(protocol.ByteCount) []*wire.StreamFrame +} + type packetPacker struct { connectionID protocol.ConnectionID perspective protocol.Perspective @@ -26,7 +32,7 @@ type packetPacker struct { cryptoSetup handshake.CryptoSetup packetNumberGenerator *packetNumberGenerator - streamFramer *streamFramer + streams streamFrameSource controlFrameMutex sync.Mutex controlFrames []wire.Frame @@ -42,7 +48,7 @@ type packetPacker struct { func newPacketPacker(connectionID protocol.ConnectionID, initialPacketNumber protocol.PacketNumber, cryptoSetup handshake.CryptoSetup, - streamFramer *streamFramer, + streamFramer streamFrameSource, perspective protocol.Perspective, version protocol.VersionNumber, ) *packetPacker { @@ -51,7 +57,7 @@ func newPacketPacker(connectionID protocol.ConnectionID, connectionID: connectionID, perspective: perspective, version: version, - streamFramer: streamFramer, + streams: streamFramer, packetNumberGenerator: newPacketNumberGenerator(initialPacketNumber, protocol.SkipPacketAveragePeriodLength), } } @@ -128,7 +134,7 @@ func (p *packetPacker) PackHandshakeRetransmission(packet *ackhandler.Packet) (* // PackPacket packs a new packet // the other controlFrames are sent in the next packet, but might be queued and sent in the next packet if the packet would overflow MaxPacketSize otherwise func (p *packetPacker) PackPacket() (*packedPacket, error) { - hasCryptoStreamFrame := p.streamFramer.HasCryptoStreamFrame() + hasCryptoStreamFrame := p.streams.HasCryptoStreamData() // if this is the first packet to be send, make sure it contains stream data if !p.hasSentPacket && !hasCryptoStreamFrame { return nil, nil @@ -195,7 +201,7 @@ func (p *packetPacker) packCryptoPacket() (*packedPacket, error) { return nil, err } maxLen := protocol.MaxPacketSize - protocol.ByteCount(sealer.Overhead()) - protocol.NonForwardSecurePacketSizeReduction - headerLength - sf := p.streamFramer.PopCryptoStreamFrame(maxLen) + sf := p.streams.PopCryptoStreamFrame(maxLen) sf.DataLenPresent = false frames := []wire.Frame{sf} raw, err := p.writeAndSealPacket(header, frames, sealer) @@ -260,7 +266,7 @@ func (p *packetPacker) composeNextPacket( maxFrameSize += 2 } - fs := p.streamFramer.PopStreamFrames(maxFrameSize - payloadLength) + fs := p.streams.PopStreamFrames(maxFrameSize - payloadLength) if len(fs) != 0 { fs[len(fs)-1].DataLenPresent = false } diff --git a/packet_packer_test.go b/packet_packer_test.go index 8966733b..3bbeaed8 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -53,11 +53,11 @@ func (m *mockCryptoSetup) SetDiversificationNonce(divNonce []byte) { m.divNonce var _ = Describe("Packet packer", func() { var ( - packer *packetPacker - publicHeaderLen protocol.ByteCount - maxFrameSize protocol.ByteCount - streamFramer *streamFramer - cryptoStream cryptoStreamI + packer *packetPacker + publicHeaderLen protocol.ByteCount + maxFrameSize protocol.ByteCount + cryptoStream cryptoStreamI + mockStreamFramer *MockStreamFrameSource ) BeforeEach(func() { @@ -65,15 +65,16 @@ var _ = Describe("Packet packer", func() { mockSender := NewMockStreamSender(mockCtrl) mockSender.EXPECT().onHasStreamData(gomock.Any()).AnyTimes() cryptoStream = newCryptoStream(mockSender, flowcontrol.NewStreamFlowController(version.CryptoStreamID(), false, flowcontrol.NewConnectionFlowController(1000, 1000, nil), 1000, 1000, 1000, nil), version) - streamFramer = newStreamFramer(cryptoStream, nil, versionGQUICFrames) + mockStreamFramer = NewMockStreamFrameSource(mockCtrl) - packer = &packetPacker{ - cryptoSetup: &mockCryptoSetup{encLevelSeal: protocol.EncryptionForwardSecure}, - connectionID: 0x1337, - packetNumberGenerator: newPacketNumberGenerator(1, protocol.SkipPacketAveragePeriodLength), - streamFramer: streamFramer, - perspective: protocol.PerspectiveServer, - } + packer = newPacketPacker( + 0x1337, + 1, + &mockCryptoSetup{encLevelSeal: protocol.EncryptionForwardSecure}, + mockStreamFramer, + protocol.PerspectiveServer, + version, + ) publicHeaderLen = 1 + 8 + 2 // 1 flag byte, 8 connection ID, 2 packet number maxFrameSize = protocol.MaxPacketSize - protocol.ByteCount((&mockSealer{}).Overhead()) - publicHeaderLen packer.hasSentPacket = true @@ -81,33 +82,36 @@ var _ = Describe("Packet packer", func() { }) It("returns nil when no packet is queued", func() { + mockStreamFramer.EXPECT().HasCryptoStreamData() + mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()) p, err := packer.PackPacket() Expect(p).To(BeNil()) Expect(err).ToNot(HaveOccurred()) }) It("packs single packets", func() { + mockStreamFramer.EXPECT().HasCryptoStreamData() f := &wire.StreamFrame{ StreamID: 5, Data: []byte{0xDE, 0xCA, 0xFB, 0xAD}, } - streamFramer.AddFrameForRetransmission(f) + mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()).Return([]*wire.StreamFrame{f}) p, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) Expect(p).ToNot(BeNil()) b := &bytes.Buffer{} f.Write(b, packer.version) - Expect(p.frames).To(HaveLen(1)) + Expect(p.frames).To(Equal([]wire.Frame{f})) Expect(p.raw).To(ContainSubstring(string(b.Bytes()))) }) It("stores the encryption level a packet was sealed with", func() { - packer.cryptoSetup.(*mockCryptoSetup).encLevelSeal = protocol.EncryptionForwardSecure - f := &wire.StreamFrame{ + mockStreamFramer.EXPECT().HasCryptoStreamData() + mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()).Return([]*wire.StreamFrame{{ StreamID: 5, Data: []byte("foobar"), - } - streamFramer.AddFrameForRetransmission(f) + }}) + packer.cryptoSetup.(*mockCryptoSetup).encLevelSeal = protocol.EncryptionForwardSecure p, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) Expect(p.encryptionLevel).To(Equal(protocol.EncryptionForwardSecure)) @@ -215,7 +219,7 @@ var _ = Describe("Packet packer", func() { }) }) - It("packs a ConnectionClose", func() { + It("packs a CONNECTION_CLOSE", func() { ccf := wire.ConnectionCloseFrame{ ErrorCode: 0x1337, ReasonPhrase: "foobar", @@ -226,23 +230,21 @@ var _ = Describe("Packet packer", func() { Expect(p.frames[0]).To(Equal(&ccf)) }) - It("doesn't send any other frames when sending a ConnectionClose", func() { - ccf := wire.ConnectionCloseFrame{ + It("doesn't send any other frames when sending a CONNECTION_CLOSE", func() { + // expect no mockStreamFramer.PopStreamFrames + ccf := &wire.ConnectionCloseFrame{ ErrorCode: 0x1337, ReasonPhrase: "foobar", } packer.controlFrames = []wire.Frame{&wire.MaxStreamDataFrame{StreamID: 37}} - streamFramer.AddFrameForRetransmission(&wire.StreamFrame{ - StreamID: 5, - Data: []byte("foobar"), - }) - p, err := packer.PackConnectionClose(&ccf) + p, err := packer.PackConnectionClose(ccf) Expect(err).ToNot(HaveOccurred()) - Expect(p.frames).To(HaveLen(1)) - Expect(p.frames[0]).To(Equal(&ccf)) + Expect(p.frames).To(Equal([]wire.Frame{ccf})) }) It("packs only control frames", func() { + mockStreamFramer.EXPECT().HasCryptoStreamData() + mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()) packer.QueueControlFrame(&wire.RstStreamFrame{}) packer.QueueControlFrame(&wire.MaxDataFrame{}) p, err := packer.PackPacket() @@ -253,6 +255,8 @@ var _ = Describe("Packet packer", func() { }) It("increases the packet number", func() { + mockStreamFramer.EXPECT().HasCryptoStreamData().Times(2) + mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()).Times(2) packer.QueueControlFrame(&wire.RstStreamFrame{}) p1, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) @@ -265,6 +269,8 @@ var _ = Describe("Packet packer", func() { }) It("packs a STOP_WAITING frame first", func() { + mockStreamFramer.EXPECT().HasCryptoStreamData() + mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()) packer.packetNumberGenerator.next = 15 swf := &wire.StopWaitingFrame{LeastUnacked: 10} packer.QueueControlFrame(&wire.RstStreamFrame{}) @@ -277,6 +283,8 @@ var _ = Describe("Packet packer", func() { }) It("sets the LeastUnackedDelta length of a STOP_WAITING frame", func() { + mockStreamFramer.EXPECT().HasCryptoStreamData() + mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()) packetNumber := protocol.PacketNumber(0xDECAFB) // will result in a 4 byte packet number packer.packetNumberGenerator.next = packetNumber swf := &wire.StopWaitingFrame{LeastUnacked: packetNumber - 0x100} @@ -288,6 +296,8 @@ var _ = Describe("Packet packer", func() { }) It("does not pack a packet containing only a STOP_WAITING frame", func() { + mockStreamFramer.EXPECT().HasCryptoStreamData() + mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()) swf := &wire.StopWaitingFrame{LeastUnacked: 10} packer.QueueControlFrame(swf) p, err := packer.PackPacket() @@ -296,6 +306,8 @@ var _ = Describe("Packet packer", func() { }) It("packs a packet if it has queued control frames, but no new control frames", func() { + mockStreamFramer.EXPECT().HasCryptoStreamData() + mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()) packer.controlFrames = []wire.Frame{&wire.BlockedFrame{}} p, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) @@ -303,6 +315,7 @@ var _ = Describe("Packet packer", func() { }) It("refuses to send a packet that doesn't contain crypto stream data, if it has never sent a packet before", func() { + mockStreamFramer.EXPECT().HasCryptoStreamData() packer.hasSentPacket = false packer.controlFrames = []wire.Frame{&wire.BlockedFrame{}} p, err := packer.PackPacket() @@ -346,16 +359,17 @@ var _ = Describe("Packet packer", func() { }) It("only increases the packet number when there is an actual packet to send", func() { + mockStreamFramer.EXPECT().HasCryptoStreamData().Times(2) + mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()) packer.packetNumberGenerator.nextToSkip = 1000 p, err := packer.PackPacket() Expect(p).To(BeNil()) Expect(err).ToNot(HaveOccurred()) Expect(packer.packetNumberGenerator.Peek()).To(Equal(protocol.PacketNumber(1))) - f := &wire.StreamFrame{ + mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()).Return([]*wire.StreamFrame{{ StreamID: 5, Data: []byte{0xDE, 0xCA, 0xFB, 0xAD}, - } - streamFramer.AddFrameForRetransmission(f) + }}) p, err = packer.PackPacket() Expect(err).ToNot(HaveOccurred()) Expect(p).ToNot(BeNil()) @@ -364,6 +378,8 @@ var _ = Describe("Packet packer", func() { }) It("adds a PING frame when it's supposed to send a retransmittable packet", func() { + mockStreamFramer.EXPECT().HasCryptoStreamData().Times(2) + mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()).Times(2) packer.QueueControlFrame(&wire.AckFrame{}) packer.QueueControlFrame(&wire.StopWaitingFrame{}) packer.MakeNextPacketRetransmittable() @@ -381,6 +397,8 @@ var _ = Describe("Packet packer", func() { }) It("waits until there's something to send before adding a PING frame", func() { + mockStreamFramer.EXPECT().HasCryptoStreamData().Times(2) + mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()).Times(2) packer.MakeNextPacketRetransmittable() p, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) @@ -393,6 +411,8 @@ var _ = Describe("Packet packer", func() { }) It("doesn't send a PING if it already sent another retransmittable frame", func() { + mockStreamFramer.EXPECT().HasCryptoStreamData().Times(2) + mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()).Times(2) packer.MakeNextPacketRetransmittable() packer.QueueControlFrame(&wire.MaxDataFrame{}) p, err := packer.PackPacket() @@ -406,300 +426,158 @@ var _ = Describe("Packet packer", func() { Expect(p.frames).To(HaveLen(1)) }) - Context("STREAM Frame handling", func() { + Context("STREAM frame handling", func() { It("does not splits a STREAM frame with maximum size, for gQUIC frames", func() { - f := &wire.StreamFrame{ - Offset: 1, - StreamID: 5, - DataLenPresent: false, - } - maxStreamFrameDataLen := maxFrameSize - f.MinLength(packer.version) - f.Data = bytes.Repeat([]byte{'f'}, int(maxStreamFrameDataLen)) - streamFramer.AddFrameForRetransmission(f) - payloadFrames, err := packer.composeNextPacket(maxFrameSize, true) - Expect(err).ToNot(HaveOccurred()) - Expect(payloadFrames).To(HaveLen(1)) - Expect(payloadFrames[0].(*wire.StreamFrame).DataLenPresent).To(BeFalse()) - payloadFrames, err = packer.composeNextPacket(maxFrameSize, true) - Expect(err).ToNot(HaveOccurred()) - Expect(payloadFrames).To(BeEmpty()) - }) - - It("does not splits a STREAM frame with maximum size, for IETF draft style frame", func() { - packer.version = versionIETFFrames - streamFramer.version = versionIETFFrames - f := &wire.StreamFrame{ - Offset: 1, - StreamID: 5, - DataLenPresent: true, - } - // for IETF draft style STREAM frames, we don't know the size of the DataLen, because it is a variable length integer - // in the general case, we therefore use a STREAM frame that is 1 byte smaller than the maximum size - maxStreamFrameDataLen := maxFrameSize - f.MinLength(packer.version) - 1 - f.Data = bytes.Repeat([]byte{'f'}, int(maxStreamFrameDataLen)) - streamFramer.AddFrameForRetransmission(f) - payloadFrames, err := packer.composeNextPacket(maxFrameSize, true) - Expect(err).ToNot(HaveOccurred()) - Expect(payloadFrames).To(HaveLen(1)) - Expect(payloadFrames[0].(*wire.StreamFrame).DataLenPresent).To(BeFalse()) - payloadFrames, err = packer.composeNextPacket(maxFrameSize, true) - Expect(err).ToNot(HaveOccurred()) - Expect(payloadFrames).To(BeEmpty()) - }) - - It("correctly handles a STREAM frame with one byte less than maximum size", func() { - maxStreamFrameDataLen := maxFrameSize - (1 + 1 + 2) - 1 - f1 := &wire.StreamFrame{ - StreamID: 5, - Offset: 1, - Data: bytes.Repeat([]byte{'f'}, int(maxStreamFrameDataLen)), - } - f2 := &wire.StreamFrame{ - StreamID: 5, - Offset: 1, - Data: []byte("foobar"), - } - streamFramer.AddFrameForRetransmission(f1) - streamFramer.AddFrameForRetransmission(f2) - p, err := packer.PackPacket() - Expect(err).ToNot(HaveOccurred()) - Expect(p.raw).To(HaveLen(int(protocol.MaxPacketSize - 1))) - Expect(p.frames).To(HaveLen(1)) - Expect(p.frames[0].(*wire.StreamFrame).DataLenPresent).To(BeFalse()) - p, err = packer.PackPacket() - Expect(err).ToNot(HaveOccurred()) - Expect(p.frames).To(HaveLen(1)) - Expect(p.frames[0].(*wire.StreamFrame).DataLenPresent).To(BeFalse()) - }) - - It("packs multiple small STREAM frames into single packet", func() { - f1 := &wire.StreamFrame{ - StreamID: 5, - Data: []byte{0xDE, 0xCA, 0xFB, 0xAD}, - } - f2 := &wire.StreamFrame{ - StreamID: 5, - Data: []byte{0xBE, 0xEF, 0x13, 0x37}, - } - f3 := &wire.StreamFrame{ - StreamID: 3, - Data: []byte{0xCA, 0xFE}, - } - streamFramer.AddFrameForRetransmission(f1) - streamFramer.AddFrameForRetransmission(f2) - streamFramer.AddFrameForRetransmission(f3) - p, err := packer.PackPacket() - Expect(p).ToNot(BeNil()) - Expect(err).ToNot(HaveOccurred()) - b := &bytes.Buffer{} - f1.Write(b, 0) - f2.Write(b, 0) - f3.Write(b, 0) - Expect(p.frames).To(HaveLen(3)) - Expect(p.frames[0].(*wire.StreamFrame).DataLenPresent).To(BeTrue()) - Expect(p.frames[1].(*wire.StreamFrame).DataLenPresent).To(BeTrue()) - Expect(p.frames[2].(*wire.StreamFrame).DataLenPresent).To(BeFalse()) - Expect(p.raw).To(ContainSubstring(string(f1.Data))) - Expect(p.raw).To(ContainSubstring(string(f2.Data))) - Expect(p.raw).To(ContainSubstring(string(f3.Data))) - }) - - It("splits one STREAM frame larger than maximum size", func() { - f := &wire.StreamFrame{ - StreamID: 7, - Offset: 1, - } - maxStreamFrameDataLen := maxFrameSize - f.MinLength(packer.version) - f.Data = bytes.Repeat([]byte{'f'}, int(maxStreamFrameDataLen)+200) - streamFramer.AddFrameForRetransmission(f) - payloadFrames, err := packer.composeNextPacket(maxFrameSize, true) - Expect(err).ToNot(HaveOccurred()) - Expect(payloadFrames).To(HaveLen(1)) - Expect(payloadFrames[0].(*wire.StreamFrame).DataLenPresent).To(BeFalse()) - Expect(payloadFrames[0].(*wire.StreamFrame).Data).To(HaveLen(int(maxStreamFrameDataLen))) - payloadFrames, err = packer.composeNextPacket(maxFrameSize, true) - Expect(err).ToNot(HaveOccurred()) - Expect(payloadFrames).To(HaveLen(1)) - Expect(payloadFrames[0].(*wire.StreamFrame).Data).To(HaveLen(200)) - Expect(payloadFrames[0].(*wire.StreamFrame).DataLenPresent).To(BeFalse()) - payloadFrames, err = packer.composeNextPacket(maxFrameSize, true) - Expect(err).ToNot(HaveOccurred()) - Expect(payloadFrames).To(BeEmpty()) - }) - - It("packs 2 STREAM frames that are too big for one packet correctly", func() { - maxStreamFrameDataLen := maxFrameSize - (1 + 1 + 2) - f1 := &wire.StreamFrame{ - StreamID: 5, - Data: bytes.Repeat([]byte{'f'}, int(maxStreamFrameDataLen)+100), - Offset: 1, - } - f2 := &wire.StreamFrame{ - StreamID: 5, - Data: bytes.Repeat([]byte{'f'}, int(maxStreamFrameDataLen)+100), - Offset: 1, - } - streamFramer.AddFrameForRetransmission(f1) - streamFramer.AddFrameForRetransmission(f2) + mockStreamFramer.EXPECT().HasCryptoStreamData().Times(2) + mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()).DoAndReturn(func(maxSize protocol.ByteCount) []*wire.StreamFrame { + f := &wire.StreamFrame{ + Offset: 1, + StreamID: 5, + DataLenPresent: true, + } + f.Data = bytes.Repeat([]byte{'f'}, int(maxSize-f.MinLength(packer.version))) + return []*wire.StreamFrame{f} + }) + mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()) p, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) Expect(p.frames).To(HaveLen(1)) - Expect(p.frames[0].(*wire.StreamFrame).DataLenPresent).To(BeFalse()) Expect(p.raw).To(HaveLen(int(protocol.MaxPacketSize))) - p, err = packer.PackPacket() - Expect(p.frames).To(HaveLen(2)) - Expect(p.frames[0].(*wire.StreamFrame).DataLenPresent).To(BeTrue()) - Expect(p.frames[1].(*wire.StreamFrame).DataLenPresent).To(BeFalse()) - Expect(err).ToNot(HaveOccurred()) - Expect(p.raw).To(HaveLen(int(protocol.MaxPacketSize))) - p, err = packer.PackPacket() - Expect(p.frames).To(HaveLen(1)) Expect(p.frames[0].(*wire.StreamFrame).DataLenPresent).To(BeFalse()) - Expect(err).ToNot(HaveOccurred()) - Expect(p).ToNot(BeNil()) p, err = packer.PackPacket() Expect(err).ToNot(HaveOccurred()) Expect(p).To(BeNil()) }) - It("packs a packet that has the maximum packet size when given a large enough STREAM frame", func() { - f := &wire.StreamFrame{ - StreamID: 5, - Offset: 1, - } - f.Data = bytes.Repeat([]byte{'f'}, int(maxFrameSize-f.MinLength(packer.version)+1)) // + 1 since MinceLength is 1 bigger than the actual StreamFrame header - streamFramer.AddFrameForRetransmission(f) + It("does not splits a STREAM frame with maximum size, for IETF draft style frame", func() { + packer.version = versionIETFFrames + mockStreamFramer.EXPECT().HasCryptoStreamData().Times(2) + mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()).DoAndReturn(func(maxSize protocol.ByteCount) []*wire.StreamFrame { + f := &wire.StreamFrame{ + Offset: 1, + StreamID: 5, + DataLenPresent: true, + } + f.Data = bytes.Repeat([]byte{'f'}, int(maxSize-f.MinLength(packer.version))) + return []*wire.StreamFrame{f} + }) + mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()) p, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) - Expect(p).ToNot(BeNil()) + Expect(p.frames).To(HaveLen(1)) Expect(p.raw).To(HaveLen(int(protocol.MaxPacketSize))) + Expect(p.frames[0].(*wire.StreamFrame).DataLenPresent).To(BeFalse()) + p, err = packer.PackPacket() + Expect(err).ToNot(HaveOccurred()) + Expect(p).To(BeNil()) }) - It("splits a STREAM frame larger than the maximum size", func() { - f := &wire.StreamFrame{ - StreamID: 5, - Offset: 1, + It("packs multiple small STREAM frames into single packet", func() { + f1 := &wire.StreamFrame{ + StreamID: 5, + Data: []byte("frame 1"), + DataLenPresent: true, } - f.Data = bytes.Repeat([]byte{'f'}, int(maxFrameSize-f.MinLength(packer.version)+2)) // + 2 since MinceLength is 1 bigger than the actual StreamFrame header - - streamFramer.AddFrameForRetransmission(f) - payloadFrames, err := packer.composeNextPacket(maxFrameSize, true) + f2 := &wire.StreamFrame{ + StreamID: 5, + Data: []byte("frame 2"), + DataLenPresent: true, + } + f3 := &wire.StreamFrame{ + StreamID: 3, + Data: []byte("frame 3"), + DataLenPresent: true, + } + mockStreamFramer.EXPECT().HasCryptoStreamData() + mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()).Return([]*wire.StreamFrame{f1, f2, f3}) + p, err := packer.PackPacket() + Expect(p).ToNot(BeNil()) Expect(err).ToNot(HaveOccurred()) - Expect(payloadFrames).To(HaveLen(1)) - payloadFrames, err = packer.composeNextPacket(maxFrameSize, true) - Expect(err).ToNot(HaveOccurred()) - Expect(payloadFrames).To(HaveLen(1)) + Expect(p.frames).To(HaveLen(3)) + Expect(p.frames[0].(*wire.StreamFrame).Data).To(Equal([]byte("frame 1"))) + Expect(p.frames[1].(*wire.StreamFrame).Data).To(Equal([]byte("frame 2"))) + Expect(p.frames[2].(*wire.StreamFrame).Data).To(Equal([]byte("frame 3"))) + Expect(p.frames[0].(*wire.StreamFrame).DataLenPresent).To(BeTrue()) + Expect(p.frames[1].(*wire.StreamFrame).DataLenPresent).To(BeTrue()) + Expect(p.frames[2].(*wire.StreamFrame).DataLenPresent).To(BeFalse()) }) It("refuses to send unencrypted stream data on a data stream", func() { + mockStreamFramer.EXPECT().HasCryptoStreamData() + // don't expect a call to mockStreamFramer.PopStreamFrames packer.cryptoSetup.(*mockCryptoSetup).encLevelSeal = protocol.EncryptionUnencrypted - f := &wire.StreamFrame{ - StreamID: 3, - Data: []byte("foobar"), - } - streamFramer.AddFrameForRetransmission(f) p, err := packer.PackPacket() Expect(err).NotTo(HaveOccurred()) Expect(p).To(BeNil()) }) It("sends non forward-secure data as the client", func() { - packer.perspective = protocol.PerspectiveClient - packer.cryptoSetup.(*mockCryptoSetup).encLevelSeal = protocol.EncryptionSecure f := &wire.StreamFrame{ StreamID: 5, Data: []byte("foobar"), } - streamFramer.AddFrameForRetransmission(f) + mockStreamFramer.EXPECT().HasCryptoStreamData() + mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()).Return([]*wire.StreamFrame{f}) + packer.perspective = protocol.PerspectiveClient + packer.cryptoSetup.(*mockCryptoSetup).encLevelSeal = protocol.EncryptionSecure p, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) Expect(p.encryptionLevel).To(Equal(protocol.EncryptionSecure)) - Expect(p.frames[0]).To(Equal(f)) + Expect(p.frames).To(Equal([]wire.Frame{f})) }) It("does not send non forward-secure data as the server", func() { + mockStreamFramer.EXPECT().HasCryptoStreamData() + // don't expect a call to mockStreamFramer.PopStreamFrames packer.cryptoSetup.(*mockCryptoSetup).encLevelSeal = protocol.EncryptionSecure - f := &wire.StreamFrame{ - StreamID: 5, - Data: []byte("foobar"), - } - streamFramer.AddFrameForRetransmission(f) p, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) Expect(p).To(BeNil()) }) It("sends unencrypted stream data on the crypto stream", func() { - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - _, err := cryptoStream.Write([]byte("foobar")) - Expect(err).ToNot(HaveOccurred()) - close(done) - }() - packer.cryptoSetup.(*mockCryptoSetup).encLevelSealCrypto = protocol.EncryptionUnencrypted - var p *packedPacket - Eventually(func() *packedPacket { - defer GinkgoRecover() - var err error - p, err = packer.PackPacket() - Expect(err).ToNot(HaveOccurred()) - return p - }).ShouldNot(BeNil()) - Expect(p.encryptionLevel).To(Equal(protocol.EncryptionUnencrypted)) - Expect(p.frames).To(HaveLen(1)) - Expect(p.frames[0]).To(Equal(&wire.StreamFrame{ + f := &wire.StreamFrame{ StreamID: packer.version.CryptoStreamID(), Data: []byte("foobar"), - })) - Eventually(done).Should(BeClosed()) + } + mockStreamFramer.EXPECT().HasCryptoStreamData().Return(true) + mockStreamFramer.EXPECT().PopCryptoStreamFrame(gomock.Any()).Return(f) + packer.cryptoSetup.(*mockCryptoSetup).encLevelSealCrypto = protocol.EncryptionUnencrypted + p, err := packer.PackPacket() + Expect(err).ToNot(HaveOccurred()) + Expect(p.frames).To(Equal([]wire.Frame{f})) + Expect(p.encryptionLevel).To(Equal(protocol.EncryptionUnencrypted)) }) It("sends encrypted stream data on the crypto stream", func() { - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - _, err := cryptoStream.Write([]byte("foobar")) - Expect(err).ToNot(HaveOccurred()) - close(done) - }() - packer.cryptoSetup.(*mockCryptoSetup).encLevelSealCrypto = protocol.EncryptionSecure - var p *packedPacket - Eventually(func() *packedPacket { - defer GinkgoRecover() - var err error - p, err = packer.PackPacket() - Expect(err).ToNot(HaveOccurred()) - return p - }).ShouldNot(BeNil()) - Expect(p.encryptionLevel).To(Equal(protocol.EncryptionSecure)) - Expect(p.frames).To(HaveLen(1)) - Expect(p.frames[0]).To(Equal(&wire.StreamFrame{ + f := &wire.StreamFrame{ StreamID: packer.version.CryptoStreamID(), Data: []byte("foobar"), - })) - Eventually(done).Should(BeClosed()) - }) - - It("does not pack stream frames if not allowed", func() { - packer.cryptoSetup.(*mockCryptoSetup).encLevelSeal = protocol.EncryptionUnencrypted - packer.QueueControlFrame(&wire.AckFrame{}) - streamFramer.AddFrameForRetransmission(&wire.StreamFrame{StreamID: 3, Data: []byte("foobar")}) + } + mockStreamFramer.EXPECT().HasCryptoStreamData().Return(true) + mockStreamFramer.EXPECT().PopCryptoStreamFrame(gomock.Any()).Return(f) + packer.cryptoSetup.(*mockCryptoSetup).encLevelSealCrypto = protocol.EncryptionSecure p, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) - Expect(p.frames).To(HaveLen(1)) - Expect(func() { _ = p.frames[0].(*wire.AckFrame) }).NotTo(Panic()) + Expect(p.frames).To(Equal([]wire.Frame{f})) + Expect(p.encryptionLevel).To(Equal(protocol.EncryptionSecure)) }) - }) - It("returns nil if we only have a single STOP_WAITING", func() { - packer.QueueControlFrame(&wire.StopWaitingFrame{}) - p, err := packer.PackPacket() - Expect(err).NotTo(HaveOccurred()) - Expect(p).To(BeNil()) + It("does not pack STREAM frames if not allowed", func() { + mockStreamFramer.EXPECT().HasCryptoStreamData() + // don't expect a call to mockStreamFramer.PopStreamFrames + packer.cryptoSetup.(*mockCryptoSetup).encLevelSeal = protocol.EncryptionUnencrypted + ack := &wire.AckFrame{LargestAcked: 10} + packer.QueueControlFrame(ack) + p, err := packer.PackPacket() + Expect(err).ToNot(HaveOccurred()) + Expect(p.frames).To(Equal([]wire.Frame{ack})) + }) }) It("packs a single ACK", func() { + mockStreamFramer.EXPECT().HasCryptoStreamData() + mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()) ack := &wire.AckFrame{LargestAcked: 42} packer.QueueControlFrame(ack) p, err := packer.PackPacket() @@ -709,6 +587,8 @@ var _ = Describe("Packet packer", func() { }) It("does not return nil if we only have a single ACK but request it to be sent", func() { + mockStreamFramer.EXPECT().HasCryptoStreamData() + mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()) ack := &wire.AckFrame{} packer.QueueControlFrame(ack) p, err := packer.PackPacket() @@ -716,15 +596,6 @@ var _ = Describe("Packet packer", func() { Expect(p).ToNot(BeNil()) }) - It("queues a control frame to be sent in the next packet", func() { - msd := &wire.MaxStreamDataFrame{StreamID: 5} - packer.QueueControlFrame(msd) - p, err := packer.PackPacket() - Expect(err).NotTo(HaveOccurred()) - Expect(p.frames).To(HaveLen(1)) - Expect(p.frames[0]).To(Equal(msd)) - }) - Context("retransmitting of handshake packets", func() { swf := &wire.StopWaitingFrame{LeastUnacked: 1} sf := &wire.StreamFrame{ @@ -802,31 +673,23 @@ var _ = Describe("Packet packer", func() { }) It("pads Initial packets to the required minimum packet size", func() { + f := &wire.StreamFrame{ + StreamID: packer.version.CryptoStreamID(), + Data: []byte("foobar"), + } + mockStreamFramer.EXPECT().HasCryptoStreamData().Return(true) + mockStreamFramer.EXPECT().PopCryptoStreamFrame(gomock.Any()).Return(f) packer.version = protocol.VersionTLS packer.hasSentPacket = false packer.perspective = protocol.PerspectiveClient packer.cryptoSetup.(*mockCryptoSetup).encLevelSealCrypto = protocol.EncryptionUnencrypted - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - _, err := cryptoStream.Write([]byte("foobar")) - Expect(err).ToNot(HaveOccurred()) - close(done) - }() - var packet *packedPacket - Eventually(func() *packedPacket { - defer GinkgoRecover() - var err error - packet, err = packer.PackPacket() - Expect(err).ToNot(HaveOccurred()) - return packet - }).ShouldNot(BeNil()) + packet, err := packer.PackPacket() + Expect(err).ToNot(HaveOccurred()) Expect(packet.raw).To(HaveLen(protocol.MinInitialPacketSize)) Expect(packet.frames).To(HaveLen(1)) sf := packet.frames[0].(*wire.StreamFrame) Expect(sf.Data).To(Equal([]byte("foobar"))) Expect(sf.DataLenPresent).To(BeTrue()) - Eventually(done).Should(BeClosed()) }) It("refuses to retransmit packets that were sent with forward-secure encryption", func() { diff --git a/stream_framer.go b/stream_framer.go index eeb09327..c8de4d63 100644 --- a/stream_framer.go +++ b/stream_framer.go @@ -14,9 +14,10 @@ type streamFramer struct { retransmissionQueue []*wire.StreamFrame - streamQueueMutex sync.Mutex - activeStreams map[protocol.StreamID]struct{} - streamQueue []protocol.StreamID + streamQueueMutex sync.Mutex + activeStreams map[protocol.StreamID]struct{} + streamQueue []protocol.StreamID + hasCryptoStreamData bool } func newStreamFramer( @@ -38,6 +39,9 @@ func (f *streamFramer) AddFrameForRetransmission(frame *wire.StreamFrame) { func (f *streamFramer) AddActiveStream(id protocol.StreamID) { if id == f.version.CryptoStreamID() { // the crypto stream is handled separately + f.streamQueueMutex.Lock() + f.hasCryptoStreamData = true + f.streamQueueMutex.Unlock() return } f.streamQueueMutex.Lock() @@ -57,14 +61,18 @@ func (f *streamFramer) HasFramesForRetransmission() bool { return len(f.retransmissionQueue) > 0 } -// TODO: don't need to ask the crypto stream here, just record this information in AddActiveStream -func (f *streamFramer) HasCryptoStreamFrame() bool { - return f.cryptoStream.hasDataForWriting() +func (f *streamFramer) HasCryptoStreamData() bool { + f.streamQueueMutex.Lock() + hasCryptoStreamData := f.hasCryptoStreamData + f.streamQueueMutex.Unlock() + return hasCryptoStreamData } -// TODO(lclemente): This is somewhat duplicate with the normal path for generating frames. func (f *streamFramer) PopCryptoStreamFrame(maxLen protocol.ByteCount) *wire.StreamFrame { - frame, _ := f.cryptoStream.popStreamFrame(maxLen) + f.streamQueueMutex.Lock() + frame, hasMoreData := f.cryptoStream.popStreamFrame(maxLen) + f.hasCryptoStreamData = hasMoreData + f.streamQueueMutex.Unlock() return frame } diff --git a/stream_framer_test.go b/stream_framer_test.go index 18e5b88a..c35c3786 100644 --- a/stream_framer_test.go +++ b/stream_framer_test.go @@ -22,6 +22,7 @@ var _ = Describe("Stream Framer", func() { retransmittedFrame1, retransmittedFrame2 *wire.StreamFrame framer *streamFramer stream1, stream2 *MockStreamI + cryptoStream *MockCryptoStream streamGetter *MockStreamGetter ) @@ -40,8 +41,8 @@ var _ = Describe("Stream Framer", func() { stream1.EXPECT().StreamID().Return(protocol.StreamID(5)).AnyTimes() stream2 = NewMockStreamI(mockCtrl) stream2.EXPECT().StreamID().Return(protocol.StreamID(6)).AnyTimes() - - framer = newStreamFramer(nil, streamGetter, versionGQUICFrames) + cryptoStream = NewMockCryptoStream(mockCtrl) + framer = newStreamFramer(cryptoStream, streamGetter, versionGQUICFrames) }) It("says if it has retransmissions", func() { @@ -57,6 +58,38 @@ var _ = Describe("Stream Framer", func() { Expect(fs[0].DataLenPresent).To(BeTrue()) }) + Context("handling the crypto stream", func() { + It("says if it has crypto stream data", func() { + Expect(framer.HasCryptoStreamData()).To(BeFalse()) + framer.AddActiveStream(framer.version.CryptoStreamID()) + Expect(framer.HasCryptoStreamData()).To(BeTrue()) + }) + + It("says that it doesn't have crypto stream data after popping all data", func() { + streamID := framer.version.CryptoStreamID() + f := &wire.StreamFrame{ + StreamID: streamID, + Data: []byte("foobar"), + } + cryptoStream.EXPECT().popStreamFrame(protocol.ByteCount(1000)).Return(f, false) + framer.AddActiveStream(streamID) + Expect(framer.PopCryptoStreamFrame(1000)).To(Equal(f)) + Expect(framer.HasCryptoStreamData()).To(BeFalse()) + }) + + It("says that it has more crypto stream data if not all data was popped", func() { + streamID := framer.version.CryptoStreamID() + f := &wire.StreamFrame{ + StreamID: streamID, + Data: []byte("foobar"), + } + cryptoStream.EXPECT().popStreamFrame(protocol.ByteCount(1000)).Return(f, true) + framer.AddActiveStream(streamID) + Expect(framer.PopCryptoStreamFrame(1000)).To(Equal(f)) + Expect(framer.HasCryptoStreamData()).To(BeTrue()) + }) + }) + Context("Popping", func() { It("returns nil when popping an empty framer", func() { Expect(framer.PopStreamFrames(1000)).To(BeEmpty())