From 25847cfc300f47f76fae844830a4868915dd1b86 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Mon, 1 Oct 2018 10:49:53 -0700 Subject: [PATCH] handle the crypto stream separately in the packet packer --- crypto_stream.go | 1 + framer.go | 30 ++-------- framer_test.go | 32 ----------- mock_crypto_stream_test.go | 12 ++++ mock_frame_source_test.go | 24 -------- mock_send_stream_internal_test.go | 12 ++++ mock_stream_internal_test.go | 12 ++++ packet_packer.go | 27 +++++---- packet_packer_test.go | 93 ++++++++++++++++--------------- send_stream.go | 8 +++ session.go | 8 ++- stream.go | 1 + 12 files changed, 123 insertions(+), 137 deletions(-) diff --git a/crypto_stream.go b/crypto_stream.go index a5ec4ecf..d51dc2ab 100644 --- a/crypto_stream.go +++ b/crypto_stream.go @@ -13,6 +13,7 @@ type cryptoStream interface { io.Reader io.Writer handleStreamFrame(*wire.StreamFrame) error + hasData() bool popStreamFrame(protocol.ByteCount) (*wire.StreamFrame, bool) closeForShutdown(error) setReadOffset(protocol.ByteCount) diff --git a/framer.go b/framer.go index a6cf008f..74ca8c45 100644 --- a/framer.go +++ b/framer.go @@ -12,10 +12,9 @@ type framer struct { cryptoStream cryptoStream version protocol.VersionNumber - streamQueueMutex sync.Mutex - activeStreams map[protocol.StreamID]struct{} - streamQueue []protocol.StreamID - hasCryptoStreamData bool + streamQueueMutex sync.Mutex + activeStreams map[protocol.StreamID]struct{} + streamQueue []protocol.StreamID controlFrameMutex sync.Mutex controlFrames []wire.Frame @@ -57,13 +56,9 @@ func (f *framer) AppendControlFrames(frames []wire.Frame, maxLen protocol.ByteCo return frames, length } +// AddActiveStream adds a stream that has data to write. +// It should not be used for the crypto stream. func (f *framer) AddActiveStream(id protocol.StreamID) { - if id == f.version.CryptoStreamID() { // the crypto stream is handled separately - f.streamQueueMutex.Lock() - f.hasCryptoStreamData = true - f.streamQueueMutex.Unlock() - return - } f.streamQueueMutex.Lock() if _, ok := f.activeStreams[id]; !ok { f.streamQueue = append(f.streamQueue, id) @@ -72,21 +67,6 @@ func (f *framer) AddActiveStream(id protocol.StreamID) { f.streamQueueMutex.Unlock() } -func (f *framer) HasCryptoStreamData() bool { - f.streamQueueMutex.Lock() - hasCryptoStreamData := f.hasCryptoStreamData - f.streamQueueMutex.Unlock() - return hasCryptoStreamData -} - -func (f *framer) PopCryptoStreamFrame(maxLen protocol.ByteCount) *wire.StreamFrame { - f.streamQueueMutex.Lock() - frame, hasMoreData := f.cryptoStream.popStreamFrame(maxLen) - f.hasCryptoStreamData = hasMoreData - f.streamQueueMutex.Unlock() - return frame -} - func (f *framer) AppendStreamFrames(frames []wire.Frame, maxLen protocol.ByteCount) []wire.Frame { var length protocol.ByteCount f.streamQueueMutex.Lock() diff --git a/framer_test.go b/framer_test.go index 440d1a16..a0cdcb31 100644 --- a/framer_test.go +++ b/framer_test.go @@ -72,38 +72,6 @@ var _ = Describe("Stream Framer", func() { }) }) - Context("handling the crypto stream", func() { - It("says if it has crypto stream data", func() { - Expect(framer.HasCryptoStreamData()).To(BeFalse()) - framer.AddActiveStream(framer.version.CryptoStreamID()) - Expect(framer.HasCryptoStreamData()).To(BeTrue()) - }) - - It("says that it doesn't have crypto stream data after popping all data", func() { - streamID := framer.version.CryptoStreamID() - f := &wire.StreamFrame{ - StreamID: streamID, - Data: []byte("foobar"), - } - cryptoStream.EXPECT().popStreamFrame(protocol.ByteCount(1000)).Return(f, false) - framer.AddActiveStream(streamID) - Expect(framer.PopCryptoStreamFrame(1000)).To(Equal(f)) - Expect(framer.HasCryptoStreamData()).To(BeFalse()) - }) - - It("says that it has more crypto stream data if not all data was popped", func() { - streamID := framer.version.CryptoStreamID() - f := &wire.StreamFrame{ - StreamID: streamID, - Data: []byte("foobar"), - } - cryptoStream.EXPECT().popStreamFrame(protocol.ByteCount(1000)).Return(f, true) - framer.AddActiveStream(streamID) - Expect(framer.PopCryptoStreamFrame(1000)).To(Equal(f)) - Expect(framer.HasCryptoStreamData()).To(BeTrue()) - }) - }) - Context("popping STREAM frames", func() { It("returns nil when popping an empty framer", func() { Expect(framer.AppendStreamFrames(nil, 1000)).To(BeEmpty()) diff --git a/mock_crypto_stream_test.go b/mock_crypto_stream_test.go index 68e47c1d..35cccf18 100644 --- a/mock_crypto_stream_test.go +++ b/mock_crypto_stream_test.go @@ -117,6 +117,18 @@ func (mr *MockCryptoStreamMockRecorder) handleStreamFrame(arg0 interface{}) *gom return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handleStreamFrame", reflect.TypeOf((*MockCryptoStream)(nil).handleStreamFrame), arg0) } +// hasData mocks base method +func (m *MockCryptoStream) hasData() bool { + ret := m.ctrl.Call(m, "hasData") + ret0, _ := ret[0].(bool) + return ret0 +} + +// hasData indicates an expected call of hasData +func (mr *MockCryptoStreamMockRecorder) hasData() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "hasData", reflect.TypeOf((*MockCryptoStream)(nil).hasData)) +} + // popStreamFrame mocks base method func (m *MockCryptoStream) popStreamFrame(arg0 protocol.ByteCount) (*wire.StreamFrame, bool) { ret := m.ctrl.Call(m, "popStreamFrame", arg0) diff --git a/mock_frame_source_test.go b/mock_frame_source_test.go index 655323b8..ff757c88 100644 --- a/mock_frame_source_test.go +++ b/mock_frame_source_test.go @@ -59,27 +59,3 @@ func (m *MockFrameSource) AppendStreamFrames(arg0 []wire.Frame, arg1 protocol.By func (mr *MockFrameSourceMockRecorder) AppendStreamFrames(arg0, arg1 interface{}) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AppendStreamFrames", reflect.TypeOf((*MockFrameSource)(nil).AppendStreamFrames), arg0, arg1) } - -// HasCryptoStreamData mocks base method -func (m *MockFrameSource) HasCryptoStreamData() bool { - ret := m.ctrl.Call(m, "HasCryptoStreamData") - ret0, _ := ret[0].(bool) - return ret0 -} - -// HasCryptoStreamData indicates an expected call of HasCryptoStreamData -func (mr *MockFrameSourceMockRecorder) HasCryptoStreamData() *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HasCryptoStreamData", reflect.TypeOf((*MockFrameSource)(nil).HasCryptoStreamData)) -} - -// PopCryptoStreamFrame mocks base method -func (m *MockFrameSource) PopCryptoStreamFrame(arg0 protocol.ByteCount) *wire.StreamFrame { - ret := m.ctrl.Call(m, "PopCryptoStreamFrame", arg0) - ret0, _ := ret[0].(*wire.StreamFrame) - return ret0 -} - -// PopCryptoStreamFrame indicates an expected call of PopCryptoStreamFrame -func (mr *MockFrameSourceMockRecorder) PopCryptoStreamFrame(arg0 interface{}) *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PopCryptoStreamFrame", reflect.TypeOf((*MockFrameSource)(nil).PopCryptoStreamFrame), arg0) -} diff --git a/mock_send_stream_internal_test.go b/mock_send_stream_internal_test.go index f1e68a0f..693c1423 100644 --- a/mock_send_stream_internal_test.go +++ b/mock_send_stream_internal_test.go @@ -140,6 +140,18 @@ func (mr *MockSendStreamIMockRecorder) handleStopSendingFrame(arg0 interface{}) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handleStopSendingFrame", reflect.TypeOf((*MockSendStreamI)(nil).handleStopSendingFrame), arg0) } +// hasData mocks base method +func (m *MockSendStreamI) hasData() bool { + ret := m.ctrl.Call(m, "hasData") + ret0, _ := ret[0].(bool) + return ret0 +} + +// hasData indicates an expected call of hasData +func (mr *MockSendStreamIMockRecorder) hasData() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "hasData", reflect.TypeOf((*MockSendStreamI)(nil).hasData)) +} + // popStreamFrame mocks base method func (m *MockSendStreamI) popStreamFrame(arg0 protocol.ByteCount) (*wire.StreamFrame, bool) { ret := m.ctrl.Call(m, "popStreamFrame", arg0) diff --git a/mock_stream_internal_test.go b/mock_stream_internal_test.go index 6cbc8a97..59235a68 100644 --- a/mock_stream_internal_test.go +++ b/mock_stream_internal_test.go @@ -225,6 +225,18 @@ func (mr *MockStreamIMockRecorder) handleStreamFrame(arg0 interface{}) *gomock.C return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handleStreamFrame", reflect.TypeOf((*MockStreamI)(nil).handleStreamFrame), arg0) } +// hasData mocks base method +func (m *MockStreamI) hasData() bool { + ret := m.ctrl.Call(m, "hasData") + ret0, _ := ret[0].(bool) + return ret0 +} + +// hasData indicates an expected call of hasData +func (mr *MockStreamIMockRecorder) hasData() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "hasData", reflect.TypeOf((*MockStreamI)(nil).hasData)) +} + // popStreamFrame mocks base method func (m *MockStreamI) popStreamFrame(arg0 protocol.ByteCount) (*wire.StreamFrame, bool) { ret := m.ctrl.Call(m, "popStreamFrame", arg0) diff --git a/packet_packer.go b/packet_packer.go index 4b99b142..e68ba21b 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -66,8 +66,6 @@ type sealingManager interface { } type frameSource interface { - HasCryptoStreamData() bool - PopCryptoStreamFrame(protocol.ByteCount) *wire.StreamFrame AppendStreamFrames([]wire.Frame, protocol.ByteCount) []wire.Frame AppendControlFrames([]wire.Frame, protocol.ByteCount) ([]wire.Frame, protocol.ByteCount) } @@ -98,6 +96,7 @@ type packetPacker struct { packetNumberGenerator *packetNumberGenerator getPacketNumberLen func(protocol.PacketNumber) protocol.PacketNumberLen + cryptoStream cryptoStream framer frameSource acks ackFrameSource @@ -117,6 +116,7 @@ func newPacketPacker( remoteAddr net.Addr, // only used for determining the max packet size token []byte, divNonce []byte, + cryptoStream cryptoStream, cryptoSetup sealingManager, framer frameSource, acks ackFrameSource, @@ -124,6 +124,7 @@ func newPacketPacker( version protocol.VersionNumber, ) *packetPacker { return &packetPacker{ + cryptoStream: cryptoStream, cryptoSetup: cryptoSetup, divNonce: divNonce, token: token, @@ -306,13 +307,16 @@ func (p *packetPacker) packHandshakeRetransmission(packet *ackhandler.Packet) (* // PackPacket packs a new packet // the other controlFrames are sent in the next packet, but might be queued and sent in the next packet if the packet would overflow MaxPacketSize otherwise func (p *packetPacker) PackPacket() (*packedPacket, error) { - hasCryptoStreamFrame := p.framer.HasCryptoStreamData() - // if this is the first packet to be send, make sure it contains stream data - if !p.hasSentPacket && !hasCryptoStreamFrame { - return nil, nil + packet, err := p.maybePackCryptoPacket() + if err != nil { + return nil, err } - if hasCryptoStreamFrame { - return p.packCryptoPacket() + if packet != nil { + return packet, nil + } + // if this is the first packet to be send, make sure it contains stream data + if !p.hasSentPacket && packet == nil { + return nil, nil } encLevel, sealer := p.cryptoSetup.GetSealer() @@ -357,7 +361,10 @@ func (p *packetPacker) PackPacket() (*packedPacket, error) { }, nil } -func (p *packetPacker) packCryptoPacket() (*packedPacket, error) { +func (p *packetPacker) maybePackCryptoPacket() (*packedPacket, error) { + if !p.cryptoStream.hasData() { + return nil, nil + } encLevel, sealer := p.cryptoSetup.GetSealerForCryptoStream() header := p.getHeader(encLevel) headerLength, err := header.GetLength(p.version) @@ -365,7 +372,7 @@ func (p *packetPacker) packCryptoPacket() (*packedPacket, error) { return nil, err } maxLen := p.maxPacketSize - protocol.ByteCount(sealer.Overhead()) - protocol.NonForwardSecurePacketSizeReduction - headerLength - sf := p.framer.PopCryptoStreamFrame(maxLen) + sf, _ := p.cryptoStream.popStreamFrame(maxLen) sf.DataLenPresent = false frames := []wire.Frame{sf} raw, err := p.writeAndSealPacket(header, frames, sealer) diff --git a/packet_packer_test.go b/packet_packer_test.go index bf00f595..df5d1257 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -58,11 +58,12 @@ func (m *mockCryptoSetup) ConnectionState() ConnectionState { panic("not impleme var _ = Describe("Packet packer", func() { const maxPacketSize protocol.ByteCount = 1357 var ( - packer *packetPacker - mockFramer *MockFrameSource - mockAckFramer *MockAckFrameSource - divNonce []byte - token []byte + packer *packetPacker + mockFramer *MockFrameSource + mockAckFramer *MockAckFrameSource + mockCryptoStream *MockCryptoStream + divNonce []byte + token []byte ) checkPayloadLen := func(data []byte) { @@ -95,6 +96,7 @@ var _ = Describe("Packet packer", func() { version := versionGQUICFrames mockSender := NewMockStreamSender(mockCtrl) mockSender.EXPECT().onHasStreamData(gomock.Any()).AnyTimes() + mockCryptoStream = NewMockCryptoStream(mockCtrl) mockFramer = NewMockFrameSource(mockCtrl) mockAckFramer = NewMockAckFrameSource(mockCtrl) divNonce = bytes.Repeat([]byte{'e'}, 32) @@ -108,6 +110,7 @@ var _ = Describe("Packet packer", func() { &net.TCPAddr{}, token, // token divNonce, + mockCryptoStream, &mockCryptoSetup{encLevelSeal: protocol.EncryptionForwardSecure}, mockFramer, mockAckFramer, @@ -138,7 +141,7 @@ var _ = Describe("Packet packer", func() { It("returns nil when no packet is queued", func() { mockAckFramer.EXPECT().GetAckFrame() - mockFramer.EXPECT().HasCryptoStreamData() + mockCryptoStream.EXPECT().hasData() mockFramer.EXPECT().AppendControlFrames(nil, gomock.Any()) mockFramer.EXPECT().AppendStreamFrames(nil, gomock.Any()) p, err := packer.PackPacket() @@ -147,7 +150,7 @@ var _ = Describe("Packet packer", func() { }) It("packs single packets", func() { - mockFramer.EXPECT().HasCryptoStreamData() + mockCryptoStream.EXPECT().hasData() mockAckFramer.EXPECT().GetAckFrame() expectAppendControlFrames() f := &wire.StreamFrame{ @@ -165,7 +168,7 @@ var _ = Describe("Packet packer", func() { }) It("stores the encryption level a packet was sealed with", func() { - mockFramer.EXPECT().HasCryptoStreamData() + mockCryptoStream.EXPECT().hasData() mockAckFramer.EXPECT().GetAckFrame() expectAppendControlFrames() expectAppendStreamFrames(&wire.StreamFrame{ @@ -346,8 +349,8 @@ var _ = Describe("Packet packer", func() { Offset: 0x1337, Data: []byte("foobar"), } - mockFramer.EXPECT().HasCryptoStreamData().Return(true) - mockFramer.EXPECT().PopCryptoStreamFrame(gomock.Any()).Return(f) + mockCryptoStream.EXPECT().hasData().Return(true) + mockCryptoStream.EXPECT().popStreamFrame(gomock.Any()).Return(f, false) p, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) checkPayloadLen(p.raw) @@ -376,7 +379,7 @@ var _ = Describe("Packet packer", func() { }) It("packs control frames", func() { - mockFramer.EXPECT().HasCryptoStreamData() + mockCryptoStream.EXPECT().hasData() mockAckFramer.EXPECT().GetAckFrame() frames := []wire.Frame{&wire.RstStreamFrame{}, &wire.MaxDataFrame{}} expectAppendControlFrames(frames...) @@ -389,7 +392,7 @@ var _ = Describe("Packet packer", func() { }) It("increases the packet number", func() { - mockFramer.EXPECT().HasCryptoStreamData().Times(2) + mockCryptoStream.EXPECT().hasData().Times(2) mockAckFramer.EXPECT().GetAckFrame().Times(2) expectAppendControlFrames() expectAppendStreamFrames(&wire.StreamFrame{Data: []byte("foobar")}) @@ -405,7 +408,7 @@ var _ = Describe("Packet packer", func() { }) It("packs ACKs and STOP_WAITING frames first, then control frames, then STREAM frames", func() { - mockFramer.EXPECT().HasCryptoStreamData() + mockCryptoStream.EXPECT().hasData() ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Largest: 100}}} swf := &wire.StopWaitingFrame{LeastUnacked: 10} mdf := &wire.MaxDataFrame{ByteOffset: 0x1234} @@ -422,7 +425,7 @@ var _ = Describe("Packet packer", func() { }) It("sets the LeastUnackedDelta length of a STOP_WAITING frame", func() { - mockFramer.EXPECT().HasCryptoStreamData() + mockCryptoStream.EXPECT().hasData() swf := &wire.StopWaitingFrame{LeastUnacked: 0x1337 - 0x100} mockAckFramer.EXPECT().GetAckFrame().Return(&wire.AckFrame{AckRanges: []wire.AckRange{{Largest: 100}}}) mockAckFramer.EXPECT().GetStopWaitingFrame(false).Return(swf) @@ -436,7 +439,7 @@ var _ = Describe("Packet packer", func() { }) It("refuses to send a packet that doesn't contain crypto stream data, if it has never sent a packet before", func() { - mockFramer.EXPECT().HasCryptoStreamData() + mockCryptoStream.EXPECT().hasData() packer.hasSentPacket = false p, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) @@ -444,7 +447,7 @@ var _ = Describe("Packet packer", func() { }) It("accounts for the space consumed by control frames", func() { - mockFramer.EXPECT().HasCryptoStreamData() + mockCryptoStream.EXPECT().hasData() mockAckFramer.EXPECT().GetAckFrame() var maxSize protocol.ByteCount gomock.InOrder( @@ -463,7 +466,7 @@ var _ = Describe("Packet packer", func() { It("only increases the packet number when there is an actual packet to send", func() { mockAckFramer.EXPECT().GetAckFrame().Times(2) - mockFramer.EXPECT().HasCryptoStreamData().Times(2) + mockCryptoStream.EXPECT().hasData().Times(2) expectAppendStreamFrames() expectAppendControlFrames() packer.packetNumberGenerator.nextToSkip = 1000 @@ -482,7 +485,7 @@ var _ = Describe("Packet packer", func() { Context("making ACK packets retransmittable", func() { sendMaxNumNonRetransmittableAcks := func() { - mockFramer.EXPECT().HasCryptoStreamData().Times(protocol.MaxNonRetransmittableAcks) + mockCryptoStream.EXPECT().hasData().Times(protocol.MaxNonRetransmittableAcks) for i := 0; i < protocol.MaxNonRetransmittableAcks; i++ { mockAckFramer.EXPECT().GetAckFrame().Return(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}) mockAckFramer.EXPECT().GetStopWaitingFrame(false) @@ -497,7 +500,7 @@ var _ = Describe("Packet packer", func() { It("adds a PING frame when it's supposed to send a retransmittable packet", func() { sendMaxNumNonRetransmittableAcks() - mockFramer.EXPECT().HasCryptoStreamData() + mockCryptoStream.EXPECT().hasData() mockAckFramer.EXPECT().GetAckFrame().Return(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}) mockAckFramer.EXPECT().GetStopWaitingFrame(false) expectAppendControlFrames() @@ -507,7 +510,7 @@ var _ = Describe("Packet packer", func() { Expect(err).ToNot(HaveOccurred()) Expect(p.frames).To(ContainElement(&wire.PingFrame{})) // make sure the next packet doesn't contain another PING - mockFramer.EXPECT().HasCryptoStreamData() + mockCryptoStream.EXPECT().hasData() mockAckFramer.EXPECT().GetAckFrame().Return(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}) mockAckFramer.EXPECT().GetStopWaitingFrame(false) expectAppendControlFrames() @@ -521,7 +524,7 @@ var _ = Describe("Packet packer", func() { It("waits until there's something to send before adding a PING frame", func() { sendMaxNumNonRetransmittableAcks() // nothing to send - mockFramer.EXPECT().HasCryptoStreamData() + mockCryptoStream.EXPECT().hasData() expectAppendControlFrames() expectAppendStreamFrames() mockAckFramer.EXPECT().GetAckFrame() @@ -531,7 +534,7 @@ var _ = Describe("Packet packer", func() { // now add some frame to send expectAppendControlFrames() expectAppendStreamFrames() - mockFramer.EXPECT().HasCryptoStreamData() + mockCryptoStream.EXPECT().hasData() mockAckFramer.EXPECT().GetAckFrame().Return(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}) mockAckFramer.EXPECT().GetStopWaitingFrame(false) p, err = packer.PackPacket() @@ -542,7 +545,7 @@ var _ = Describe("Packet packer", func() { It("doesn't send a PING if it already sent another retransmittable frame", func() { sendMaxNumNonRetransmittableAcks() - mockFramer.EXPECT().HasCryptoStreamData() + mockCryptoStream.EXPECT().hasData() mockAckFramer.EXPECT().GetAckFrame() expectAppendStreamFrames() expectAppendControlFrames(&wire.MaxDataFrame{}) @@ -555,7 +558,7 @@ var _ = Describe("Packet packer", func() { Context("STREAM frame handling", func() { It("does not split a STREAM frame with maximum size, for gQUIC frames", func() { - mockFramer.EXPECT().HasCryptoStreamData() + mockCryptoStream.EXPECT().hasData() mockAckFramer.EXPECT().GetAckFrame() expectAppendControlFrames() sf := &wire.StreamFrame{ @@ -578,7 +581,7 @@ var _ = Describe("Packet packer", func() { It("does not split a STREAM frame with maximum size, for IETF draft style frame", func() { packer.version = versionIETFFrames mockAckFramer.EXPECT().GetAckFrame() - mockFramer.EXPECT().HasCryptoStreamData() + mockCryptoStream.EXPECT().hasData() expectAppendControlFrames() sf := &wire.StreamFrame{ Offset: 1, @@ -613,7 +616,7 @@ var _ = Describe("Packet packer", func() { Data: []byte("frame 3"), DataLenPresent: true, } - mockFramer.EXPECT().HasCryptoStreamData() + mockCryptoStream.EXPECT().hasData() mockAckFramer.EXPECT().GetAckFrame() expectAppendControlFrames() expectAppendStreamFrames(f1, f2, f3) @@ -630,7 +633,7 @@ var _ = Describe("Packet packer", func() { }) It("refuses to send unencrypted stream data on a data stream", func() { - mockFramer.EXPECT().HasCryptoStreamData() + mockCryptoStream.EXPECT().hasData() mockAckFramer.EXPECT().GetAckFrame() expectAppendControlFrames() // don't expect a call to mockFramer.PopStreamFrames @@ -641,7 +644,7 @@ var _ = Describe("Packet packer", func() { }) It("sends non forward-secure data as the client", func() { - mockFramer.EXPECT().HasCryptoStreamData() + mockCryptoStream.EXPECT().hasData() mockAckFramer.EXPECT().GetAckFrame() expectAppendControlFrames() f := &wire.StreamFrame{ @@ -658,7 +661,7 @@ var _ = Describe("Packet packer", func() { }) It("does not send non forward-secure data as the server", func() { - mockFramer.EXPECT().HasCryptoStreamData() + mockCryptoStream.EXPECT().hasData() mockAckFramer.EXPECT().GetAckFrame() expectAppendControlFrames() // don't expect a call to mockFramer.PopStreamFrames @@ -671,14 +674,14 @@ var _ = Describe("Packet packer", func() { It("packs a maximum size crypto packet", func() { var f *wire.StreamFrame packer.version = versionIETFFrames - mockFramer.EXPECT().HasCryptoStreamData().Return(true) - mockFramer.EXPECT().PopCryptoStreamFrame(gomock.Any()).DoAndReturn(func(size protocol.ByteCount) *wire.StreamFrame { + mockCryptoStream.EXPECT().hasData().Return(true) + mockCryptoStream.EXPECT().popStreamFrame(gomock.Any()).DoAndReturn(func(size protocol.ByteCount) (*wire.StreamFrame, bool) { f = &wire.StreamFrame{ StreamID: packer.version.CryptoStreamID(), Offset: 0x1337, } f.Data = bytes.Repeat([]byte{'f'}, int(size-f.Length(packer.version))) - return f + return f, false }) p, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) @@ -694,8 +697,8 @@ var _ = Describe("Packet packer", func() { StreamID: packer.version.CryptoStreamID(), Data: []byte("foobar"), } - mockFramer.EXPECT().HasCryptoStreamData().Return(true) - mockFramer.EXPECT().PopCryptoStreamFrame(gomock.Any()).Return(f) + mockCryptoStream.EXPECT().hasData().Return(true) + mockCryptoStream.EXPECT().popStreamFrame(gomock.Any()).Return(f, false) packer.cryptoSetup.(*mockCryptoSetup).encLevelSealCrypto = protocol.EncryptionUnencrypted p, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) @@ -708,8 +711,8 @@ var _ = Describe("Packet packer", func() { StreamID: packer.version.CryptoStreamID(), Data: []byte("foobar"), } - mockFramer.EXPECT().HasCryptoStreamData().Return(true) - mockFramer.EXPECT().PopCryptoStreamFrame(gomock.Any()).Return(f) + mockCryptoStream.EXPECT().hasData().Return(true) + mockCryptoStream.EXPECT().popStreamFrame(gomock.Any()).Return(f, false) packer.cryptoSetup.(*mockCryptoSetup).encLevelSealCrypto = protocol.EncryptionSecure p, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) @@ -718,7 +721,7 @@ var _ = Describe("Packet packer", func() { }) It("does not pack STREAM frames if not allowed", func() { - mockFramer.EXPECT().HasCryptoStreamData() + mockCryptoStream.EXPECT().hasData() ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Largest: 10, Smallest: 1}}} mockAckFramer.EXPECT().GetAckFrame().Return(ack) mockAckFramer.EXPECT().GetStopWaitingFrame(false) @@ -732,7 +735,7 @@ var _ = Describe("Packet packer", func() { }) It("packs a single ACK", func() { - mockFramer.EXPECT().HasCryptoStreamData() + mockCryptoStream.EXPECT().hasData() ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Largest: 42, Smallest: 1}}} mockAckFramer.EXPECT().GetAckFrame().Return(ack) mockAckFramer.EXPECT().GetStopWaitingFrame(false) @@ -820,8 +823,8 @@ var _ = Describe("Packet packer", func() { StreamID: packer.version.CryptoStreamID(), Data: []byte("foobar"), } - mockFramer.EXPECT().HasCryptoStreamData().Return(true) - mockFramer.EXPECT().PopCryptoStreamFrame(gomock.Any()).Return(f) + mockCryptoStream.EXPECT().hasData().Return(true) + mockCryptoStream.EXPECT().popStreamFrame(gomock.Any()).Return(f, false) packer.version = protocol.VersionTLS packer.hasSentPacket = false packer.perspective = protocol.PerspectiveClient @@ -837,11 +840,11 @@ var _ = Describe("Packet packer", func() { }) It("set the correct payload length for an Initial packet", func() { - mockFramer.EXPECT().HasCryptoStreamData().Return(true) - mockFramer.EXPECT().PopCryptoStreamFrame(gomock.Any()).Return(&wire.StreamFrame{ + mockCryptoStream.EXPECT().hasData().Return(true) + mockCryptoStream.EXPECT().popStreamFrame(gomock.Any()).Return(&wire.StreamFrame{ StreamID: packer.version.CryptoStreamID(), Data: []byte("foobar"), - }) + }, false) packer.version = protocol.VersionTLS packer.hasSentPacket = false packer.perspective = protocol.PerspectiveClient @@ -1078,7 +1081,7 @@ var _ = Describe("Packet packer", func() { Context("max packet size", func() { It("sets the maximum packet size", func() { mockAckFramer.EXPECT().GetAckFrame().Times(2) - mockFramer.EXPECT().HasCryptoStreamData().AnyTimes() + mockCryptoStream.EXPECT().hasData().AnyTimes() var initialMaxPacketSize protocol.ByteCount mockFramer.EXPECT().AppendControlFrames(gomock.Any(), gomock.Any()).Do(func(_ []wire.Frame, maxLen protocol.ByteCount) ([]wire.Frame, protocol.ByteCount) { initialMaxPacketSize = maxLen @@ -1102,7 +1105,7 @@ var _ = Describe("Packet packer", func() { It("doesn't increase the max packet size", func() { mockAckFramer.EXPECT().GetAckFrame().Times(2) - mockFramer.EXPECT().HasCryptoStreamData().AnyTimes() + mockCryptoStream.EXPECT().hasData().AnyTimes() var initialMaxPacketSize protocol.ByteCount mockFramer.EXPECT().AppendControlFrames(gomock.Any(), gomock.Any()).Do(func(_ []wire.Frame, maxLen protocol.ByteCount) ([]wire.Frame, protocol.ByteCount) { initialMaxPacketSize = maxLen diff --git a/send_stream.go b/send_stream.go index bfefc71d..624e1ff4 100644 --- a/send_stream.go +++ b/send_stream.go @@ -15,6 +15,7 @@ import ( type sendStreamI interface { SendStream handleStopSendingFrame(*wire.StopSendingFrame) + hasData() bool popStreamFrame(maxBytes protocol.ByteCount) (*wire.StreamFrame, bool) closeForShutdown(error) handleMaxStreamDataFrame(*wire.MaxStreamDataFrame) @@ -181,6 +182,13 @@ func (s *sendStream) popStreamFrameImpl(maxBytes protocol.ByteCount) (bool /* co return frame.FinBit, frame, s.dataForWriting != nil } +func (s *sendStream) hasData() bool { + s.mutex.Lock() + hasData := len(s.dataForWriting) > 0 + s.mutex.Unlock() + return hasData +} + func (s *sendStream) getDataForWriting(maxBytes protocol.ByteCount) ([]byte, bool /* should send FIN */) { if s.dataForWriting == nil { return nil, s.finishedWriting && !s.finSent diff --git a/session.go b/session.go index 358cebf1..d69821fd 100644 --- a/session.go +++ b/session.go @@ -214,6 +214,7 @@ func newSession( s.RemoteAddr(), nil, // no token divNonce, + s.cryptoStream, cs, s.framer, sentAndReceivedPacketManager{s.sentPacketHandler, s.receivedPacketHandler}, @@ -288,6 +289,7 @@ var newClientSession = func( s.RemoteAddr(), nil, // no token nil, // no diversification nonce + s.cryptoStream, cs, s.framer, sentAndReceivedPacketManager{s.sentPacketHandler, s.receivedPacketHandler}, @@ -344,6 +346,7 @@ func newTLSServerSession( s.RemoteAddr(), nil, // no token nil, // no diversification nonce + s.cryptoStream, cs, s.framer, sentAndReceivedPacketManager{s.sentPacketHandler, s.receivedPacketHandler}, @@ -409,6 +412,7 @@ var newTLSClientSession = func( s.RemoteAddr(), token, nil, // no diversification nonce + s.cryptoStream, cs, s.framer, sentAndReceivedPacketManager{s.sentPacketHandler, s.receivedPacketHandler}, @@ -1246,7 +1250,9 @@ func (s *session) onHasConnectionWindowUpdate() { } func (s *session) onHasStreamData(id protocol.StreamID) { - s.framer.AddActiveStream(id) + if id != s.version.CryptoStreamID() { + s.framer.AddActiveStream(id) + } s.scheduleSending() } diff --git a/stream.go b/stream.go index 5d6ce671..64e4d46e 100644 --- a/stream.go +++ b/stream.go @@ -52,6 +52,7 @@ type streamI interface { handleRstStreamFrame(*wire.RstStreamFrame) error getWindowUpdate() protocol.ByteCount // for sending + hasData() bool handleStopSendingFrame(*wire.StopSendingFrame) popStreamFrame(maxBytes protocol.ByteCount) (*wire.StreamFrame, bool) handleMaxStreamDataFrame(*wire.MaxStreamDataFrame)