From 06db39836ba0cd4d135dc415d65dac6e50aa7908 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Fri, 28 Sep 2018 17:37:46 -0600 Subject: [PATCH] move queueing of control frames to the framer --- stream_framer.go => framer.go | 42 +++- stream_framer_test.go => framer_test.go | 44 +++- mock_frame_source_test.go | 85 +++++++ mock_packer_test.go | 10 - mock_stream_frame_source_test.go | 72 ------ mockgen.go | 2 +- packet_packer.go | 42 +--- packet_packer_test.go | 314 ++++++++++++------------ session.go | 28 +-- session_test.go | 8 +- 10 files changed, 345 insertions(+), 302 deletions(-) rename stream_framer.go => framer.go (69%) rename stream_framer_test.go => framer_test.go (84%) create mode 100644 mock_frame_source_test.go delete mode 100644 mock_stream_frame_source_test.go diff --git a/stream_framer.go b/framer.go similarity index 69% rename from stream_framer.go rename to framer.go index 24247a84..a6cf008f 100644 --- a/stream_framer.go +++ b/framer.go @@ -7,7 +7,7 @@ import ( "github.com/lucas-clemente/quic-go/internal/wire" ) -type streamFramer struct { +type framer struct { streamGetter streamGetter cryptoStream cryptoStream version protocol.VersionNumber @@ -16,14 +16,17 @@ type streamFramer struct { activeStreams map[protocol.StreamID]struct{} streamQueue []protocol.StreamID hasCryptoStreamData bool + + controlFrameMutex sync.Mutex + controlFrames []wire.Frame } -func newStreamFramer( +func newFramer( cryptoStream cryptoStream, streamGetter streamGetter, v protocol.VersionNumber, -) *streamFramer { - return &streamFramer{ +) *framer { + return &framer{ streamGetter: streamGetter, cryptoStream: cryptoStream, activeStreams: make(map[protocol.StreamID]struct{}), @@ -31,7 +34,30 @@ func newStreamFramer( } } -func (f *streamFramer) AddActiveStream(id protocol.StreamID) { +func (f *framer) QueueControlFrame(frame wire.Frame) { + f.controlFrameMutex.Lock() + f.controlFrames = append(f.controlFrames, frame) + f.controlFrameMutex.Unlock() +} + +func (f *framer) AppendControlFrames(frames []wire.Frame, maxLen protocol.ByteCount) ([]wire.Frame, protocol.ByteCount) { + var length protocol.ByteCount + f.controlFrameMutex.Lock() + for len(f.controlFrames) > 0 { + frame := f.controlFrames[len(f.controlFrames)-1] + frameLen := frame.Length(f.version) + if length+frameLen > maxLen { + break + } + frames = append(frames, frame) + length += frameLen + f.controlFrames = f.controlFrames[:len(f.controlFrames)-1] + } + f.controlFrameMutex.Unlock() + return frames, length +} + +func (f *framer) AddActiveStream(id protocol.StreamID) { if id == f.version.CryptoStreamID() { // the crypto stream is handled separately f.streamQueueMutex.Lock() f.hasCryptoStreamData = true @@ -46,14 +72,14 @@ func (f *streamFramer) AddActiveStream(id protocol.StreamID) { f.streamQueueMutex.Unlock() } -func (f *streamFramer) HasCryptoStreamData() bool { +func (f *framer) HasCryptoStreamData() bool { f.streamQueueMutex.Lock() hasCryptoStreamData := f.hasCryptoStreamData f.streamQueueMutex.Unlock() return hasCryptoStreamData } -func (f *streamFramer) PopCryptoStreamFrame(maxLen protocol.ByteCount) *wire.StreamFrame { +func (f *framer) PopCryptoStreamFrame(maxLen protocol.ByteCount) *wire.StreamFrame { f.streamQueueMutex.Lock() frame, hasMoreData := f.cryptoStream.popStreamFrame(maxLen) f.hasCryptoStreamData = hasMoreData @@ -61,7 +87,7 @@ func (f *streamFramer) PopCryptoStreamFrame(maxLen protocol.ByteCount) *wire.Str return frame } -func (f *streamFramer) AppendStreamFrames(frames []wire.Frame, maxLen protocol.ByteCount) []wire.Frame { +func (f *framer) AppendStreamFrames(frames []wire.Frame, maxLen protocol.ByteCount) []wire.Frame { var length protocol.ByteCount f.streamQueueMutex.Lock() // pop STREAM frames, until less than MinStreamFrameSize bytes are left in the packet diff --git a/stream_framer_test.go b/framer_test.go similarity index 84% rename from stream_framer_test.go rename to framer_test.go index edca4d41..440d1a16 100644 --- a/stream_framer_test.go +++ b/framer_test.go @@ -18,7 +18,7 @@ var _ = Describe("Stream Framer", func() { ) var ( - framer *streamFramer + framer *framer cryptoStream *MockCryptoStream stream1, stream2 *MockSendStreamI streamGetter *MockStreamGetter @@ -31,7 +31,45 @@ var _ = Describe("Stream Framer", func() { stream2 = NewMockSendStreamI(mockCtrl) stream2.EXPECT().StreamID().Return(protocol.StreamID(6)).AnyTimes() cryptoStream = NewMockCryptoStream(mockCtrl) - framer = newStreamFramer(cryptoStream, streamGetter, versionGQUICFrames) + framer = newFramer(cryptoStream, streamGetter, versionGQUICFrames) + }) + + Context("handling control frames", func() { + It("adds control frames", func() { + mdf := &wire.MaxDataFrame{ByteOffset: 0x42} + msdf := &wire.MaxStreamDataFrame{ByteOffset: 0x1337} + framer.QueueControlFrame(mdf) + framer.QueueControlFrame(msdf) + frames, length := framer.AppendControlFrames(nil, 1000) + Expect(frames).To(ContainElement(mdf)) + Expect(frames).To(ContainElement(msdf)) + Expect(length).To(Equal(mdf.Length(framer.version) + msdf.Length(framer.version))) + }) + + It("appends to the slice given", func() { + ack := &wire.AckFrame{} + mdf := &wire.MaxDataFrame{ByteOffset: 0x42} + framer.QueueControlFrame(mdf) + frames, length := framer.AppendControlFrames([]wire.Frame{ack}, 1000) + Expect(frames).To(Equal([]wire.Frame{ack, mdf})) + Expect(length).To(Equal(mdf.Length(framer.version))) + }) + + It("adds the right number of frames", func() { + maxSize := protocol.ByteCount(1000) + bf := &wire.BlockedFrame{Offset: 0x1337} + bfLen := bf.Length(framer.version) + numFrames := int(maxSize / bfLen) // max number of frames that fit into maxSize + for i := 0; i < numFrames+1; i++ { + framer.QueueControlFrame(bf) + } + frames, length := framer.AppendControlFrames(nil, protocol.ByteCount(maxSize)) + Expect(frames).To(HaveLen(numFrames)) + Expect(length).To(BeNumerically(">", maxSize-bfLen)) + frames, length = framer.AppendControlFrames(nil, protocol.ByteCount(maxSize)) + Expect(frames).To(HaveLen(1)) + Expect(length).To(Equal(bfLen)) + }) }) Context("handling the crypto stream", func() { @@ -66,7 +104,7 @@ var _ = Describe("Stream Framer", func() { }) }) - Context("Popping", func() { + Context("popping STREAM frames", func() { It("returns nil when popping an empty framer", func() { Expect(framer.AppendStreamFrames(nil, 1000)).To(BeEmpty()) }) diff --git a/mock_frame_source_test.go b/mock_frame_source_test.go new file mode 100644 index 00000000..655323b8 --- /dev/null +++ b/mock_frame_source_test.go @@ -0,0 +1,85 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/lucas-clemente/quic-go (interfaces: FrameSource) + +// Package quic is a generated GoMock package. +package quic + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + protocol "github.com/lucas-clemente/quic-go/internal/protocol" + wire "github.com/lucas-clemente/quic-go/internal/wire" +) + +// MockFrameSource is a mock of FrameSource interface +type MockFrameSource struct { + ctrl *gomock.Controller + recorder *MockFrameSourceMockRecorder +} + +// MockFrameSourceMockRecorder is the mock recorder for MockFrameSource +type MockFrameSourceMockRecorder struct { + mock *MockFrameSource +} + +// NewMockFrameSource creates a new mock instance +func NewMockFrameSource(ctrl *gomock.Controller) *MockFrameSource { + mock := &MockFrameSource{ctrl: ctrl} + mock.recorder = &MockFrameSourceMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockFrameSource) EXPECT() *MockFrameSourceMockRecorder { + return m.recorder +} + +// AppendControlFrames mocks base method +func (m *MockFrameSource) AppendControlFrames(arg0 []wire.Frame, arg1 protocol.ByteCount) ([]wire.Frame, protocol.ByteCount) { + ret := m.ctrl.Call(m, "AppendControlFrames", arg0, arg1) + ret0, _ := ret[0].([]wire.Frame) + ret1, _ := ret[1].(protocol.ByteCount) + return ret0, ret1 +} + +// AppendControlFrames indicates an expected call of AppendControlFrames +func (mr *MockFrameSourceMockRecorder) AppendControlFrames(arg0, arg1 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AppendControlFrames", reflect.TypeOf((*MockFrameSource)(nil).AppendControlFrames), arg0, arg1) +} + +// AppendStreamFrames mocks base method +func (m *MockFrameSource) AppendStreamFrames(arg0 []wire.Frame, arg1 protocol.ByteCount) []wire.Frame { + ret := m.ctrl.Call(m, "AppendStreamFrames", arg0, arg1) + ret0, _ := ret[0].([]wire.Frame) + return ret0 +} + +// AppendStreamFrames indicates an expected call of AppendStreamFrames +func (mr *MockFrameSourceMockRecorder) AppendStreamFrames(arg0, arg1 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AppendStreamFrames", reflect.TypeOf((*MockFrameSource)(nil).AppendStreamFrames), arg0, arg1) +} + +// HasCryptoStreamData mocks base method +func (m *MockFrameSource) HasCryptoStreamData() bool { + ret := m.ctrl.Call(m, "HasCryptoStreamData") + ret0, _ := ret[0].(bool) + return ret0 +} + +// HasCryptoStreamData indicates an expected call of HasCryptoStreamData +func (mr *MockFrameSourceMockRecorder) HasCryptoStreamData() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HasCryptoStreamData", reflect.TypeOf((*MockFrameSource)(nil).HasCryptoStreamData)) +} + +// PopCryptoStreamFrame mocks base method +func (m *MockFrameSource) PopCryptoStreamFrame(arg0 protocol.ByteCount) *wire.StreamFrame { + ret := m.ctrl.Call(m, "PopCryptoStreamFrame", arg0) + ret0, _ := ret[0].(*wire.StreamFrame) + return ret0 +} + +// PopCryptoStreamFrame indicates an expected call of PopCryptoStreamFrame +func (mr *MockFrameSourceMockRecorder) PopCryptoStreamFrame(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PopCryptoStreamFrame", reflect.TypeOf((*MockFrameSource)(nil).PopCryptoStreamFrame), arg0) +} diff --git a/mock_packer_test.go b/mock_packer_test.go index 10b4dad3..1123390f 100644 --- a/mock_packer_test.go +++ b/mock_packer_test.go @@ -108,13 +108,3 @@ func (m *MockPacker) PackRetransmission(arg0 *ackhandler.Packet) ([]*packedPacke func (mr *MockPackerMockRecorder) PackRetransmission(arg0 interface{}) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PackRetransmission", reflect.TypeOf((*MockPacker)(nil).PackRetransmission), arg0) } - -// QueueControlFrame mocks base method -func (m *MockPacker) QueueControlFrame(arg0 wire.Frame) { - m.ctrl.Call(m, "QueueControlFrame", arg0) -} - -// QueueControlFrame indicates an expected call of QueueControlFrame -func (mr *MockPackerMockRecorder) QueueControlFrame(arg0 interface{}) *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueueControlFrame", reflect.TypeOf((*MockPacker)(nil).QueueControlFrame), arg0) -} diff --git a/mock_stream_frame_source_test.go b/mock_stream_frame_source_test.go deleted file mode 100644 index fa29c02c..00000000 --- a/mock_stream_frame_source_test.go +++ /dev/null @@ -1,72 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: github.com/lucas-clemente/quic-go (interfaces: StreamFrameSource) - -// Package quic is a generated GoMock package. -package quic - -import ( - reflect "reflect" - - gomock "github.com/golang/mock/gomock" - protocol "github.com/lucas-clemente/quic-go/internal/protocol" - wire "github.com/lucas-clemente/quic-go/internal/wire" -) - -// MockStreamFrameSource is a mock of StreamFrameSource interface -type MockStreamFrameSource struct { - ctrl *gomock.Controller - recorder *MockStreamFrameSourceMockRecorder -} - -// MockStreamFrameSourceMockRecorder is the mock recorder for MockStreamFrameSource -type MockStreamFrameSourceMockRecorder struct { - mock *MockStreamFrameSource -} - -// NewMockStreamFrameSource creates a new mock instance -func NewMockStreamFrameSource(ctrl *gomock.Controller) *MockStreamFrameSource { - mock := &MockStreamFrameSource{ctrl: ctrl} - mock.recorder = &MockStreamFrameSourceMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use -func (m *MockStreamFrameSource) EXPECT() *MockStreamFrameSourceMockRecorder { - return m.recorder -} - -// AppendStreamFrames mocks base method -func (m *MockStreamFrameSource) AppendStreamFrames(arg0 []wire.Frame, arg1 protocol.ByteCount) []wire.Frame { - ret := m.ctrl.Call(m, "AppendStreamFrames", arg0, arg1) - ret0, _ := ret[0].([]wire.Frame) - return ret0 -} - -// AppendStreamFrames indicates an expected call of AppendStreamFrames -func (mr *MockStreamFrameSourceMockRecorder) AppendStreamFrames(arg0, arg1 interface{}) *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AppendStreamFrames", reflect.TypeOf((*MockStreamFrameSource)(nil).AppendStreamFrames), arg0, arg1) -} - -// HasCryptoStreamData mocks base method -func (m *MockStreamFrameSource) HasCryptoStreamData() bool { - ret := m.ctrl.Call(m, "HasCryptoStreamData") - ret0, _ := ret[0].(bool) - return ret0 -} - -// HasCryptoStreamData indicates an expected call of HasCryptoStreamData -func (mr *MockStreamFrameSourceMockRecorder) HasCryptoStreamData() *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HasCryptoStreamData", reflect.TypeOf((*MockStreamFrameSource)(nil).HasCryptoStreamData)) -} - -// PopCryptoStreamFrame mocks base method -func (m *MockStreamFrameSource) PopCryptoStreamFrame(arg0 protocol.ByteCount) *wire.StreamFrame { - ret := m.ctrl.Call(m, "PopCryptoStreamFrame", arg0) - ret0, _ := ret[0].(*wire.StreamFrame) - return ret0 -} - -// PopCryptoStreamFrame indicates an expected call of PopCryptoStreamFrame -func (mr *MockStreamFrameSourceMockRecorder) PopCryptoStreamFrame(arg0 interface{}) *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PopCryptoStreamFrame", reflect.TypeOf((*MockStreamFrameSource)(nil).PopCryptoStreamFrame), arg0) -} diff --git a/mockgen.go b/mockgen.go index 3667134d..11cee2ee 100644 --- a/mockgen.go +++ b/mockgen.go @@ -5,7 +5,7 @@ package quic //go:generate sh -c "./mockgen_private.sh quic mock_send_stream_internal_test.go github.com/lucas-clemente/quic-go sendStreamI" //go:generate sh -c "./mockgen_private.sh quic mock_stream_sender_test.go github.com/lucas-clemente/quic-go streamSender" //go:generate sh -c "./mockgen_private.sh quic mock_stream_getter_test.go github.com/lucas-clemente/quic-go streamGetter" -//go:generate sh -c "./mockgen_private.sh quic mock_stream_frame_source_test.go github.com/lucas-clemente/quic-go streamFrameSource" +//go:generate sh -c "./mockgen_private.sh quic mock_frame_source_test.go github.com/lucas-clemente/quic-go frameSource" //go:generate sh -c "./mockgen_private.sh quic mock_ack_frame_source_test.go github.com/lucas-clemente/quic-go ackFrameSource" //go:generate sh -c "./mockgen_private.sh quic mock_crypto_stream_test.go github.com/lucas-clemente/quic-go cryptoStream" //go:generate sh -c "./mockgen_private.sh quic mock_stream_manager_test.go github.com/lucas-clemente/quic-go streamManager" diff --git a/packet_packer.go b/packet_packer.go index 938dfe6c..a683dd6f 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -5,7 +5,6 @@ import ( "errors" "fmt" "net" - "sync" "time" "github.com/lucas-clemente/quic-go/internal/ackhandler" @@ -16,8 +15,6 @@ import ( ) type packer interface { - QueueControlFrame(frame wire.Frame) - PackPacket() (*packedPacket, error) MaybePackAckPacket() (*packedPacket, error) PackRetransmission(packet *ackhandler.Packet) ([]*packedPacket, error) @@ -68,10 +65,11 @@ type sealingManager interface { GetSealerWithEncryptionLevel(protocol.EncryptionLevel) (handshake.Sealer, error) } -type streamFrameSource interface { +type frameSource interface { HasCryptoStreamData() bool PopCryptoStreamFrame(protocol.ByteCount) *wire.StreamFrame AppendStreamFrames([]wire.Frame, protocol.ByteCount) []wire.Frame + AppendControlFrames([]wire.Frame, protocol.ByteCount) ([]wire.Frame, protocol.ByteCount) } // sentAndReceivedPacketManager is only needed until STOP_WAITING is removed @@ -100,12 +98,9 @@ type packetPacker struct { packetNumberGenerator *packetNumberGenerator getPacketNumberLen func(protocol.PacketNumber) protocol.PacketNumberLen - streams streamFrameSource + framer frameSource acks ackFrameSource - controlFrameMutex sync.Mutex - controlFrames []wire.Frame - omitConnectionID bool maxPacketSize protocol.ByteCount hasSentPacket bool // has the packetPacker already sent a packet @@ -123,7 +118,7 @@ func newPacketPacker( token []byte, divNonce []byte, cryptoSetup sealingManager, - streamFramer streamFrameSource, + framer frameSource, acks ackFrameSource, perspective protocol.Perspective, version protocol.VersionNumber, @@ -136,7 +131,7 @@ func newPacketPacker( srcConnID: srcConnID, perspective: perspective, version: version, - streams: streamFramer, + framer: framer, acks: acks, getPacketNumberLen: getPacketNumberLen, packetNumberGenerator: newPacketNumberGenerator(initialPacketNumber, protocol.SkipPacketAveragePeriodLength), @@ -320,7 +315,7 @@ func (p *packetPacker) packHandshakeRetransmission(packet *ackhandler.Packet) (* // PackPacket packs a new packet // the other controlFrames are sent in the next packet, but might be queued and sent in the next packet if the packet would overflow MaxPacketSize otherwise func (p *packetPacker) PackPacket() (*packedPacket, error) { - hasCryptoStreamFrame := p.streams.HasCryptoStreamData() + hasCryptoStreamFrame := p.framer.HasCryptoStreamData() // if this is the first packet to be send, make sure it contains stream data if !p.hasSentPacket && !hasCryptoStreamFrame { return nil, nil @@ -379,7 +374,7 @@ func (p *packetPacker) packCryptoPacket() (*packedPacket, error) { return nil, err } maxLen := p.maxPacketSize - protocol.ByteCount(sealer.Overhead()) - protocol.NonForwardSecurePacketSizeReduction - headerLength - sf := p.streams.PopCryptoStreamFrame(maxLen) + sf := p.framer.PopCryptoStreamFrame(maxLen) sf.DataLenPresent = false frames := []wire.Frame{sf} raw, err := p.writeAndSealPacket(header, frames, sealer) @@ -418,18 +413,9 @@ func (p *packetPacker) composeNextPacket( } } - p.controlFrameMutex.Lock() - for len(p.controlFrames) > 0 { - frame := p.controlFrames[len(p.controlFrames)-1] - frameLen := frame.Length(p.version) - if length+frameLen > maxFrameSize { - break - } - frames = append(frames, frame) - length += frameLen - p.controlFrames = p.controlFrames[:len(p.controlFrames)-1] - } - p.controlFrameMutex.Unlock() + var lengthAdded protocol.ByteCount + frames, lengthAdded = p.framer.AppendControlFrames(frames, maxFrameSize-length) + length += lengthAdded if !canSendStreamFrames { return frames, nil @@ -446,7 +432,7 @@ func (p *packetPacker) composeNextPacket( maxFrameSize += 2 } - frames = p.streams.AppendStreamFrames(frames, maxFrameSize-length) + frames = p.framer.AppendStreamFrames(frames, maxFrameSize-length) if len(frames) > 0 { lastFrame := frames[len(frames)-1] if sf, ok := lastFrame.(*wire.StreamFrame); ok { @@ -456,12 +442,6 @@ func (p *packetPacker) composeNextPacket( return frames, nil } -func (p *packetPacker) QueueControlFrame(frame wire.Frame) { - p.controlFrameMutex.Lock() - p.controlFrames = append(p.controlFrames, frame) - p.controlFrameMutex.Unlock() -} - func (p *packetPacker) getHeader(encLevel protocol.EncryptionLevel) *wire.Header { pnum := p.packetNumberGenerator.Peek() packetNumberLen := p.getPacketNumberLen(pnum) diff --git a/packet_packer_test.go b/packet_packer_test.go index 63bc3b42..cc9886b0 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -56,13 +56,11 @@ func (m *mockCryptoSetup) ConnectionState() ConnectionState { panic("not impleme var _ = Describe("Packet packer", func() { const maxPacketSize protocol.ByteCount = 1357 var ( - packer *packetPacker - publicHeaderLen protocol.ByteCount - maxFrameSize protocol.ByteCount - mockStreamFramer *MockStreamFrameSource - mockAckFramer *MockAckFrameSource - divNonce []byte - token []byte + packer *packetPacker + mockFramer *MockFrameSource + mockAckFramer *MockAckFrameSource + divNonce []byte + token []byte ) checkPayloadLen := func(data []byte) { @@ -75,8 +73,18 @@ var _ = Describe("Packet packer", func() { } expectAppendStreamFrames := func(frames ...wire.Frame) { - mockStreamFramer.EXPECT().AppendStreamFrames(gomock.Any(), gomock.Any()).DoAndReturn(func(fs []wire.Frame, _ protocol.ByteCount) []wire.Frame { - return append(frames, fs...) + mockFramer.EXPECT().AppendStreamFrames(gomock.Any(), gomock.Any()).DoAndReturn(func(fs []wire.Frame, _ protocol.ByteCount) []wire.Frame { + return append(fs, frames...) + }) + } + + expectAppendControlFrames := func(frames ...wire.Frame) { + mockFramer.EXPECT().AppendControlFrames(gomock.Any(), gomock.Any()).DoAndReturn(func(fs []wire.Frame, _ protocol.ByteCount) ([]wire.Frame, protocol.ByteCount) { + var length protocol.ByteCount + for _, f := range frames { + length += f.Length(packer.version) + } + return append(fs, frames...), length }) } @@ -84,7 +92,7 @@ var _ = Describe("Packet packer", func() { version := versionGQUICFrames mockSender := NewMockStreamSender(mockCtrl) mockSender.EXPECT().onHasStreamData(gomock.Any()).AnyTimes() - mockStreamFramer = NewMockStreamFrameSource(mockCtrl) + mockFramer = NewMockFrameSource(mockCtrl) mockAckFramer = NewMockAckFrameSource(mockCtrl) divNonce = bytes.Repeat([]byte{'e'}, 32) token = []byte("initial token") @@ -98,13 +106,11 @@ var _ = Describe("Packet packer", func() { token, // token divNonce, &mockCryptoSetup{encLevelSeal: protocol.EncryptionForwardSecure}, - mockStreamFramer, + mockFramer, mockAckFramer, protocol.PerspectiveServer, version, ) - publicHeaderLen = 1 + 8 + 2 // 1 flag byte, 8 connection ID, 2 packet number - maxFrameSize = maxPacketSize - protocol.ByteCount((&mockSealer{}).Overhead()) - publicHeaderLen packer.hasSentPacket = true packer.version = version packer.maxPacketSize = maxPacketSize @@ -129,16 +135,18 @@ var _ = Describe("Packet packer", func() { It("returns nil when no packet is queued", func() { mockAckFramer.EXPECT().GetAckFrame() - mockStreamFramer.EXPECT().HasCryptoStreamData() - mockStreamFramer.EXPECT().AppendStreamFrames(nil, gomock.Any()) + mockFramer.EXPECT().HasCryptoStreamData() + mockFramer.EXPECT().AppendControlFrames(nil, gomock.Any()) + mockFramer.EXPECT().AppendStreamFrames(nil, gomock.Any()) p, err := packer.PackPacket() Expect(p).To(BeNil()) Expect(err).ToNot(HaveOccurred()) }) It("packs single packets", func() { - mockStreamFramer.EXPECT().HasCryptoStreamData() + mockFramer.EXPECT().HasCryptoStreamData() mockAckFramer.EXPECT().GetAckFrame() + expectAppendControlFrames() f := &wire.StreamFrame{ StreamID: 5, Data: []byte{0xDE, 0xCA, 0xFB, 0xAD}, @@ -154,8 +162,9 @@ var _ = Describe("Packet packer", func() { }) It("stores the encryption level a packet was sealed with", func() { - mockStreamFramer.EXPECT().HasCryptoStreamData() + mockFramer.EXPECT().HasCryptoStreamData() mockAckFramer.EXPECT().GetAckFrame() + expectAppendControlFrames() expectAppendStreamFrames(&wire.StreamFrame{ StreamID: 5, Data: []byte("foobar"), @@ -334,8 +343,8 @@ var _ = Describe("Packet packer", func() { Offset: 0x1337, Data: []byte("foobar"), } - mockStreamFramer.EXPECT().HasCryptoStreamData().Return(true) - mockStreamFramer.EXPECT().PopCryptoStreamFrame(gomock.Any()).Return(f) + mockFramer.EXPECT().HasCryptoStreamData().Return(true) + mockFramer.EXPECT().PopCryptoStreamFrame(gomock.Any()).Return(f) p, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) checkPayloadLen(p.raw) @@ -353,67 +362,68 @@ var _ = Describe("Packet packer", func() { }) It("doesn't send any other frames when sending a CONNECTION_CLOSE", func() { - // expect no mockStreamFramer.PopStreamFrames + // expect no mockFramer.PopStreamFrames ccf := &wire.ConnectionCloseFrame{ ErrorCode: 0x1337, ReasonPhrase: "foobar", } - packer.controlFrames = []wire.Frame{&wire.MaxStreamDataFrame{StreamID: 37}} p, err := packer.PackConnectionClose(ccf) Expect(err).ToNot(HaveOccurred()) Expect(p.frames).To(Equal([]wire.Frame{ccf})) }) - It("packs only control frames", func() { - mockStreamFramer.EXPECT().HasCryptoStreamData() + It("packs control frames", func() { + mockFramer.EXPECT().HasCryptoStreamData() mockAckFramer.EXPECT().GetAckFrame() + frames := []wire.Frame{&wire.RstStreamFrame{}, &wire.MaxDataFrame{}} + expectAppendControlFrames(frames...) expectAppendStreamFrames() - packer.QueueControlFrame(&wire.RstStreamFrame{}) - packer.QueueControlFrame(&wire.MaxDataFrame{}) p, err := packer.PackPacket() Expect(p).ToNot(BeNil()) Expect(err).ToNot(HaveOccurred()) - Expect(p.frames).To(HaveLen(2)) + Expect(p.frames).To(Equal(frames)) Expect(p.raw).NotTo(BeEmpty()) }) It("increases the packet number", func() { - mockStreamFramer.EXPECT().HasCryptoStreamData().Times(2) + mockFramer.EXPECT().HasCryptoStreamData().Times(2) mockAckFramer.EXPECT().GetAckFrame().Times(2) - expectAppendStreamFrames() - expectAppendStreamFrames() - packer.QueueControlFrame(&wire.RstStreamFrame{}) + expectAppendControlFrames() + expectAppendStreamFrames(&wire.StreamFrame{Data: []byte("foobar")}) + expectAppendControlFrames() + expectAppendStreamFrames(&wire.StreamFrame{Data: []byte("raboof")}) p1, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) Expect(p1).ToNot(BeNil()) - packer.QueueControlFrame(&wire.RstStreamFrame{}) p2, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) Expect(p2).ToNot(BeNil()) Expect(p2.header.PacketNumber).To(BeNumerically(">", p1.header.PacketNumber)) }) - It("packs ACKs and STOP_WAITING frames first", func() { - mockStreamFramer.EXPECT().HasCryptoStreamData() + It("packs ACKs and STOP_WAITING frames first, then control frames, then STREAM frames", func() { + mockFramer.EXPECT().HasCryptoStreamData() ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Largest: 100}}} swf := &wire.StopWaitingFrame{LeastUnacked: 10} + mdf := &wire.MaxDataFrame{ByteOffset: 0x1234} + sf := &wire.StreamFrame{Data: []byte("foobar")} mockAckFramer.EXPECT().GetAckFrame().Return(ack) mockAckFramer.EXPECT().GetStopWaitingFrame(false).Return(swf) - expectAppendStreamFrames() + expectAppendControlFrames(mdf) + expectAppendStreamFrames(sf) packer.packetNumberGenerator.next = 15 - cf := &wire.RstStreamFrame{} - packer.QueueControlFrame(cf) p, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) Expect(p).ToNot(BeNil()) - Expect(p.frames).To(Equal([]wire.Frame{ack, swf, cf})) + Expect(p.frames).To(Equal([]wire.Frame{ack, swf, mdf, sf})) }) It("sets the LeastUnackedDelta length of a STOP_WAITING frame", func() { - mockStreamFramer.EXPECT().HasCryptoStreamData() + mockFramer.EXPECT().HasCryptoStreamData() swf := &wire.StopWaitingFrame{LeastUnacked: 0x1337 - 0x100} mockAckFramer.EXPECT().GetAckFrame().Return(&wire.AckFrame{AckRanges: []wire.AckRange{{Largest: 100}}}) mockAckFramer.EXPECT().GetStopWaitingFrame(false).Return(swf) + expectAppendControlFrames() expectAppendStreamFrames() packer.packetNumberGenerator.next = 0x1337 p, err := packer.PackPacket() @@ -423,64 +433,43 @@ var _ = Describe("Packet packer", func() { }) It("refuses to send a packet that doesn't contain crypto stream data, if it has never sent a packet before", func() { - mockStreamFramer.EXPECT().HasCryptoStreamData() + mockFramer.EXPECT().HasCryptoStreamData() packer.hasSentPacket = false - packer.controlFrames = []wire.Frame{&wire.BlockedFrame{}} p, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) Expect(p).To(BeNil()) }) - It("packs many control frames into one packets", func() { - mockAckFramer.EXPECT().GetAckFrame().Times(2) - f := &wire.AckFrame{AckRanges: []wire.AckRange{{Largest: 1, Smallest: 1}}} - b := &bytes.Buffer{} - err := f.Write(b, packer.version) + It("accounts for the space consumed by control frames", func() { + mockFramer.EXPECT().HasCryptoStreamData() + mockAckFramer.EXPECT().GetAckFrame() + var maxSize protocol.ByteCount + gomock.InOrder( + mockFramer.EXPECT().AppendControlFrames(gomock.Any(), gomock.Any()).DoAndReturn(func(fs []wire.Frame, maxLen protocol.ByteCount) ([]wire.Frame, protocol.ByteCount) { + maxSize = maxLen + return fs, 444 + }), + mockFramer.EXPECT().AppendStreamFrames(gomock.Any(), gomock.Any()).Do(func(_ []wire.Frame, maxLen protocol.ByteCount) []wire.Frame { + Expect(maxLen).To(Equal(maxSize - 444 + 2 /* data length of the STREAM frame */)) + return nil + }), + ) + _, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) - maxFramesPerPacket := int(maxFrameSize) / b.Len() - var controlFrames []wire.Frame - for i := 0; i < maxFramesPerPacket; i++ { - controlFrames = append(controlFrames, f) - } - packer.controlFrames = controlFrames - payloadFrames, err := packer.composeNextPacket(nil, maxFrameSize, false) - Expect(err).ToNot(HaveOccurred()) - Expect(payloadFrames).To(HaveLen(maxFramesPerPacket)) - payloadFrames, err = packer.composeNextPacket(nil, maxFrameSize, false) - Expect(err).ToNot(HaveOccurred()) - Expect(payloadFrames).To(BeEmpty()) - }) - - It("packs a lot of control frames into 2 packets if they don't fit into one", func() { - mockAckFramer.EXPECT().GetAckFrame().Times(2) - blockedFrame := &wire.BlockedFrame{} - maxFramesPerPacket := int(maxFrameSize) / int(blockedFrame.Length(packer.version)) - var controlFrames []wire.Frame - for i := 0; i < maxFramesPerPacket+10; i++ { - controlFrames = append(controlFrames, blockedFrame) - } - packer.controlFrames = controlFrames - payloadFrames, err := packer.composeNextPacket(nil, maxFrameSize, false) - Expect(err).ToNot(HaveOccurred()) - Expect(payloadFrames).To(HaveLen(maxFramesPerPacket)) - payloadFrames, err = packer.composeNextPacket(nil, maxFrameSize, false) - Expect(err).ToNot(HaveOccurred()) - Expect(payloadFrames).To(HaveLen(10)) }) It("only increases the packet number when there is an actual packet to send", func() { mockAckFramer.EXPECT().GetAckFrame().Times(2) - mockStreamFramer.EXPECT().HasCryptoStreamData().Times(2) + mockFramer.EXPECT().HasCryptoStreamData().Times(2) expectAppendStreamFrames() + expectAppendControlFrames() packer.packetNumberGenerator.nextToSkip = 1000 p, err := packer.PackPacket() Expect(p).To(BeNil()) Expect(err).ToNot(HaveOccurred()) Expect(packer.packetNumberGenerator.Peek()).To(Equal(protocol.PacketNumber(1))) - expectAppendStreamFrames(&wire.StreamFrame{ - StreamID: 5, - Data: []byte{0xDE, 0xCA, 0xFB, 0xAD}, - }) + expectAppendControlFrames() + expectAppendStreamFrames(&wire.StreamFrame{Data: []byte("foobar")}) p, err = packer.PackPacket() Expect(err).ToNot(HaveOccurred()) Expect(p).ToNot(BeNil()) @@ -490,10 +479,11 @@ var _ = Describe("Packet packer", func() { Context("making ACK packets retransmittable", func() { sendMaxNumNonRetransmittableAcks := func() { - mockStreamFramer.EXPECT().HasCryptoStreamData().Times(protocol.MaxNonRetransmittableAcks) + mockFramer.EXPECT().HasCryptoStreamData().Times(protocol.MaxNonRetransmittableAcks) for i := 0; i < protocol.MaxNonRetransmittableAcks; i++ { mockAckFramer.EXPECT().GetAckFrame().Return(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}) mockAckFramer.EXPECT().GetStopWaitingFrame(false) + expectAppendControlFrames() expectAppendStreamFrames() p, err := packer.PackPacket() Expect(p).ToNot(BeNil()) @@ -504,18 +494,20 @@ var _ = Describe("Packet packer", func() { It("adds a PING frame when it's supposed to send a retransmittable packet", func() { sendMaxNumNonRetransmittableAcks() - mockStreamFramer.EXPECT().HasCryptoStreamData() + mockFramer.EXPECT().HasCryptoStreamData() mockAckFramer.EXPECT().GetAckFrame().Return(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}) mockAckFramer.EXPECT().GetStopWaitingFrame(false) + expectAppendControlFrames() expectAppendStreamFrames() p, err := packer.PackPacket() Expect(p).ToNot(BeNil()) Expect(err).ToNot(HaveOccurred()) Expect(p.frames).To(ContainElement(&wire.PingFrame{})) // make sure the next packet doesn't contain another PING - mockStreamFramer.EXPECT().HasCryptoStreamData() + mockFramer.EXPECT().HasCryptoStreamData() mockAckFramer.EXPECT().GetAckFrame().Return(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}) mockAckFramer.EXPECT().GetStopWaitingFrame(false) + expectAppendControlFrames() expectAppendStreamFrames() p, err = packer.PackPacket() Expect(p).ToNot(BeNil()) @@ -526,15 +518,17 @@ var _ = Describe("Packet packer", func() { It("waits until there's something to send before adding a PING frame", func() { sendMaxNumNonRetransmittableAcks() // nothing to send - mockStreamFramer.EXPECT().HasCryptoStreamData() + mockFramer.EXPECT().HasCryptoStreamData() + expectAppendControlFrames() expectAppendStreamFrames() mockAckFramer.EXPECT().GetAckFrame() p, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) Expect(p).To(BeNil()) // now add some frame to send + expectAppendControlFrames() expectAppendStreamFrames() - mockStreamFramer.EXPECT().HasCryptoStreamData() + mockFramer.EXPECT().HasCryptoStreamData() mockAckFramer.EXPECT().GetAckFrame().Return(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}) mockAckFramer.EXPECT().GetStopWaitingFrame(false) p, err = packer.PackPacket() @@ -545,10 +539,10 @@ var _ = Describe("Packet packer", func() { It("doesn't send a PING if it already sent another retransmittable frame", func() { sendMaxNumNonRetransmittableAcks() - mockStreamFramer.EXPECT().HasCryptoStreamData() + mockFramer.EXPECT().HasCryptoStreamData() mockAckFramer.EXPECT().GetAckFrame() expectAppendStreamFrames() - packer.QueueControlFrame(&wire.MaxDataFrame{}) + expectAppendControlFrames(&wire.MaxDataFrame{}) p, err := packer.PackPacket() Expect(p).ToNot(BeNil()) Expect(err).ToNot(HaveOccurred()) @@ -558,56 +552,46 @@ var _ = Describe("Packet packer", func() { Context("STREAM frame handling", func() { It("does not split a STREAM frame with maximum size, for gQUIC frames", func() { - mockStreamFramer.EXPECT().HasCryptoStreamData() + mockFramer.EXPECT().HasCryptoStreamData() mockAckFramer.EXPECT().GetAckFrame() - mockStreamFramer.EXPECT().AppendStreamFrames(gomock.Any(), gomock.Any()).DoAndReturn(func(_ []wire.Frame, maxSize protocol.ByteCount) ([]wire.Frame, protocol.ByteCount) { - f := &wire.StreamFrame{ - Offset: 1, - StreamID: 5, - DataLenPresent: true, - } - f.Data = bytes.Repeat([]byte{'f'}, int(maxSize-f.Length(packer.version))) - return []wire.Frame{f}, f.Length(packer.version) + expectAppendControlFrames() + sf := &wire.StreamFrame{ + Offset: 1, + StreamID: 5, + DataLenPresent: true, + } + mockFramer.EXPECT().AppendStreamFrames(gomock.Any(), gomock.Any()).DoAndReturn(func(_ []wire.Frame, maxSize protocol.ByteCount) ([]wire.Frame, protocol.ByteCount) { + sf.Data = bytes.Repeat([]byte{'f'}, int(maxSize-sf.Length(packer.version))) + return []wire.Frame{sf}, sf.Length(packer.version) }) - mockStreamFramer.EXPECT().AppendStreamFrames(gomock.Any(), gomock.Any()) p, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) Expect(p.frames).To(HaveLen(1)) Expect(p.raw).To(HaveLen(int(maxPacketSize))) + Expect(p.frames[0].(*wire.StreamFrame).Data).To(HaveLen(len(sf.Data))) Expect(p.frames[0].(*wire.StreamFrame).DataLenPresent).To(BeFalse()) - // make sure there's nothing else to send - mockStreamFramer.EXPECT().HasCryptoStreamData() - mockAckFramer.EXPECT().GetAckFrame() - p, err = packer.PackPacket() - Expect(err).ToNot(HaveOccurred()) - Expect(p).To(BeNil()) }) It("does not split a STREAM frame with maximum size, for IETF draft style frame", func() { packer.version = versionIETFFrames mockAckFramer.EXPECT().GetAckFrame() - mockStreamFramer.EXPECT().HasCryptoStreamData() - mockStreamFramer.EXPECT().AppendStreamFrames(gomock.Any(), gomock.Any()).DoAndReturn(func(_ []wire.Frame, maxSize protocol.ByteCount) ([]wire.Frame, protocol.ByteCount) { - f := &wire.StreamFrame{ - Offset: 1, - StreamID: 5, - DataLenPresent: true, - } - f.Data = bytes.Repeat([]byte{'f'}, int(maxSize-f.Length(packer.version))) - return []wire.Frame{f}, f.Length(packer.version) + mockFramer.EXPECT().HasCryptoStreamData() + expectAppendControlFrames() + sf := &wire.StreamFrame{ + Offset: 1, + StreamID: 5, + DataLenPresent: true, + } + mockFramer.EXPECT().AppendStreamFrames(gomock.Any(), gomock.Any()).DoAndReturn(func(_ []wire.Frame, maxSize protocol.ByteCount) ([]wire.Frame, protocol.ByteCount) { + sf.Data = bytes.Repeat([]byte{'f'}, int(maxSize-sf.Length(packer.version))) + return []wire.Frame{sf}, sf.Length(packer.version) }) - mockStreamFramer.EXPECT().AppendStreamFrames(gomock.Any(), gomock.Any()) p, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) Expect(p.frames).To(HaveLen(1)) Expect(p.raw).To(HaveLen(int(maxPacketSize))) + Expect(p.frames[0].(*wire.StreamFrame).Data).To(HaveLen(len(sf.Data))) Expect(p.frames[0].(*wire.StreamFrame).DataLenPresent).To(BeFalse()) - // make sure there's nothing else to send - mockStreamFramer.EXPECT().HasCryptoStreamData() - mockAckFramer.EXPECT().GetAckFrame() - p, err = packer.PackPacket() - Expect(err).ToNot(HaveOccurred()) - Expect(p).To(BeNil()) }) It("packs multiple small STREAM frames into single packet", func() { @@ -626,8 +610,9 @@ var _ = Describe("Packet packer", func() { Data: []byte("frame 3"), DataLenPresent: true, } - mockStreamFramer.EXPECT().HasCryptoStreamData() + mockFramer.EXPECT().HasCryptoStreamData() mockAckFramer.EXPECT().GetAckFrame() + expectAppendControlFrames() expectAppendStreamFrames(f1, f2, f3) p, err := packer.PackPacket() Expect(p).ToNot(BeNil()) @@ -642,9 +627,10 @@ var _ = Describe("Packet packer", func() { }) It("refuses to send unencrypted stream data on a data stream", func() { - mockStreamFramer.EXPECT().HasCryptoStreamData() + mockFramer.EXPECT().HasCryptoStreamData() mockAckFramer.EXPECT().GetAckFrame() - // don't expect a call to mockStreamFramer.PopStreamFrames + expectAppendControlFrames() + // don't expect a call to mockFramer.PopStreamFrames packer.cryptoSetup.(*mockCryptoSetup).encLevelSeal = protocol.EncryptionUnencrypted p, err := packer.PackPacket() Expect(err).NotTo(HaveOccurred()) @@ -652,8 +638,9 @@ var _ = Describe("Packet packer", func() { }) It("sends non forward-secure data as the client", func() { - mockStreamFramer.EXPECT().HasCryptoStreamData() + mockFramer.EXPECT().HasCryptoStreamData() mockAckFramer.EXPECT().GetAckFrame() + expectAppendControlFrames() f := &wire.StreamFrame{ StreamID: 5, Data: []byte("foobar"), @@ -668,9 +655,10 @@ var _ = Describe("Packet packer", func() { }) It("does not send non forward-secure data as the server", func() { - mockStreamFramer.EXPECT().HasCryptoStreamData() + mockFramer.EXPECT().HasCryptoStreamData() mockAckFramer.EXPECT().GetAckFrame() - // don't expect a call to mockStreamFramer.PopStreamFrames + expectAppendControlFrames() + // don't expect a call to mockFramer.PopStreamFrames packer.cryptoSetup.(*mockCryptoSetup).encLevelSeal = protocol.EncryptionSecure p, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) @@ -680,8 +668,8 @@ var _ = Describe("Packet packer", func() { It("packs a maximum size crypto packet", func() { var f *wire.StreamFrame packer.version = versionIETFFrames - mockStreamFramer.EXPECT().HasCryptoStreamData().Return(true) - mockStreamFramer.EXPECT().PopCryptoStreamFrame(gomock.Any()).DoAndReturn(func(size protocol.ByteCount) *wire.StreamFrame { + mockFramer.EXPECT().HasCryptoStreamData().Return(true) + mockFramer.EXPECT().PopCryptoStreamFrame(gomock.Any()).DoAndReturn(func(size protocol.ByteCount) *wire.StreamFrame { f = &wire.StreamFrame{ StreamID: packer.version.CryptoStreamID(), Offset: 0x1337, @@ -703,8 +691,8 @@ var _ = Describe("Packet packer", func() { StreamID: packer.version.CryptoStreamID(), Data: []byte("foobar"), } - mockStreamFramer.EXPECT().HasCryptoStreamData().Return(true) - mockStreamFramer.EXPECT().PopCryptoStreamFrame(gomock.Any()).Return(f) + mockFramer.EXPECT().HasCryptoStreamData().Return(true) + mockFramer.EXPECT().PopCryptoStreamFrame(gomock.Any()).Return(f) packer.cryptoSetup.(*mockCryptoSetup).encLevelSealCrypto = protocol.EncryptionUnencrypted p, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) @@ -717,8 +705,8 @@ var _ = Describe("Packet packer", func() { StreamID: packer.version.CryptoStreamID(), Data: []byte("foobar"), } - mockStreamFramer.EXPECT().HasCryptoStreamData().Return(true) - mockStreamFramer.EXPECT().PopCryptoStreamFrame(gomock.Any()).Return(f) + mockFramer.EXPECT().HasCryptoStreamData().Return(true) + mockFramer.EXPECT().PopCryptoStreamFrame(gomock.Any()).Return(f) packer.cryptoSetup.(*mockCryptoSetup).encLevelSealCrypto = protocol.EncryptionSecure p, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) @@ -727,11 +715,12 @@ var _ = Describe("Packet packer", func() { }) It("does not pack STREAM frames if not allowed", func() { - mockStreamFramer.EXPECT().HasCryptoStreamData() + mockFramer.EXPECT().HasCryptoStreamData() ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Largest: 10, Smallest: 1}}} mockAckFramer.EXPECT().GetAckFrame().Return(ack) mockAckFramer.EXPECT().GetStopWaitingFrame(false) - // don't expect a call to mockStreamFramer.PopStreamFrames + expectAppendControlFrames() + // don't expect a call to mockFramer.PopStreamFrames packer.cryptoSetup.(*mockCryptoSetup).encLevelSeal = protocol.EncryptionUnencrypted p, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) @@ -740,10 +729,11 @@ var _ = Describe("Packet packer", func() { }) It("packs a single ACK", func() { - mockStreamFramer.EXPECT().HasCryptoStreamData() + mockFramer.EXPECT().HasCryptoStreamData() ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Largest: 42, Smallest: 1}}} mockAckFramer.EXPECT().GetAckFrame().Return(ack) mockAckFramer.EXPECT().GetStopWaitingFrame(false) + expectAppendControlFrames() expectAppendStreamFrames() p, err := packer.PackPacket() Expect(err).NotTo(HaveOccurred()) @@ -827,8 +817,8 @@ var _ = Describe("Packet packer", func() { StreamID: packer.version.CryptoStreamID(), Data: []byte("foobar"), } - mockStreamFramer.EXPECT().HasCryptoStreamData().Return(true) - mockStreamFramer.EXPECT().PopCryptoStreamFrame(gomock.Any()).Return(f) + mockFramer.EXPECT().HasCryptoStreamData().Return(true) + mockFramer.EXPECT().PopCryptoStreamFrame(gomock.Any()).Return(f) packer.version = protocol.VersionTLS packer.hasSentPacket = false packer.perspective = protocol.PerspectiveClient @@ -844,8 +834,8 @@ var _ = Describe("Packet packer", func() { }) It("set the correct payload length for an Initial packet", func() { - mockStreamFramer.EXPECT().HasCryptoStreamData().Return(true) - mockStreamFramer.EXPECT().PopCryptoStreamFrame(gomock.Any()).Return(&wire.StreamFrame{ + mockFramer.EXPECT().HasCryptoStreamData().Return(true) + mockFramer.EXPECT().PopCryptoStreamFrame(gomock.Any()).Return(&wire.StreamFrame{ StreamID: packer.version.CryptoStreamID(), Data: []byte("foobar"), }) @@ -1046,44 +1036,50 @@ var _ = Describe("Packet packer", func() { Context("max packet size", func() { It("sets the maximum packet size", func() { mockAckFramer.EXPECT().GetAckFrame().Times(2) - for i := 0; i < 10*int(maxPacketSize); i++ { - packer.QueueControlFrame(&wire.PingFrame{}) - } - mockStreamFramer.EXPECT().HasCryptoStreamData().AnyTimes() - mockStreamFramer.EXPECT().AppendStreamFrames(gomock.Any(), gomock.Any()).DoAndReturn(func(fs []wire.Frame, _ protocol.ByteCount) []wire.Frame { - return fs - }).AnyTimes() - p, err := packer.PackPacket() + mockFramer.EXPECT().HasCryptoStreamData().AnyTimes() + var initialMaxPacketSize protocol.ByteCount + mockFramer.EXPECT().AppendControlFrames(gomock.Any(), gomock.Any()).Do(func(_ []wire.Frame, maxLen protocol.ByteCount) ([]wire.Frame, protocol.ByteCount) { + initialMaxPacketSize = maxLen + return nil, 0 + }) + expectAppendStreamFrames() + _, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) - Expect(p.raw).To(HaveLen(int(maxPacketSize))) // now reduce the maxPacketSize packer.HandleTransportParameters(&handshake.TransportParameters{ MaxPacketSize: maxPacketSize - 10, }) - p, err = packer.PackPacket() + mockFramer.EXPECT().AppendControlFrames(gomock.Any(), gomock.Any()).Do(func(_ []wire.Frame, maxLen protocol.ByteCount) ([]wire.Frame, protocol.ByteCount) { + Expect(maxLen).To(Equal(initialMaxPacketSize - 10)) + return nil, 0 + }) + expectAppendStreamFrames() + _, err = packer.PackPacket() Expect(err).ToNot(HaveOccurred()) - Expect(p.raw).To(HaveLen(int(maxPacketSize) - 10)) }) It("doesn't increase the max packet size", func() { mockAckFramer.EXPECT().GetAckFrame().Times(2) - for i := 0; i < 10*int(maxPacketSize); i++ { - packer.QueueControlFrame(&wire.PingFrame{}) - } - mockStreamFramer.EXPECT().HasCryptoStreamData().AnyTimes() - mockStreamFramer.EXPECT().AppendStreamFrames(gomock.Any(), gomock.Any()).DoAndReturn(func(fs []wire.Frame, _ protocol.ByteCount) []wire.Frame { - return fs - }).AnyTimes() - p, err := packer.PackPacket() + mockFramer.EXPECT().HasCryptoStreamData().AnyTimes() + var initialMaxPacketSize protocol.ByteCount + mockFramer.EXPECT().AppendControlFrames(gomock.Any(), gomock.Any()).Do(func(_ []wire.Frame, maxLen protocol.ByteCount) ([]wire.Frame, protocol.ByteCount) { + initialMaxPacketSize = maxLen + return nil, 0 + }) + expectAppendStreamFrames() + _, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) - Expect(p.raw).To(HaveLen(int(maxPacketSize))) // now try to increase the maxPacketSize packer.HandleTransportParameters(&handshake.TransportParameters{ MaxPacketSize: maxPacketSize + 10, }) - p, err = packer.PackPacket() + mockFramer.EXPECT().AppendControlFrames(gomock.Any(), gomock.Any()).Do(func(_ []wire.Frame, maxLen protocol.ByteCount) ([]wire.Frame, protocol.ByteCount) { + Expect(maxLen).To(Equal(initialMaxPacketSize)) + return nil, 0 + }) + expectAppendStreamFrames() + _, err = packer.PackPacket() Expect(err).ToNot(HaveOccurred()) - Expect(p.raw).To(HaveLen(int(maxPacketSize))) }) }) }) diff --git a/session.go b/session.go index f222b5bd..358cebf1 100644 --- a/session.go +++ b/session.go @@ -92,7 +92,7 @@ type session struct { sentPacketHandler ackhandler.SentPacketHandler receivedPacketHandler ackhandler.ReceivedPacketHandler - streamFramer *streamFramer + framer *framer windowUpdateQueue *windowUpdateQueue connFlowController flowcontrol.ConnectionFlowController @@ -205,7 +205,7 @@ func newSession( s.cryptoStreamHandler = cs s.unpacker = newPacketUnpackerGQUIC(cs, s.version) s.streamsMap = newStreamsMapLegacy(s.newStream, s.config.MaxIncomingStreams, s.perspective) - s.streamFramer = newStreamFramer(s.cryptoStream, s.streamsMap, s.version) + s.framer = newFramer(s.cryptoStream, s.streamsMap, s.version) s.packer = newPacketPacker( destConnID, srcConnID, @@ -215,7 +215,7 @@ func newSession( nil, // no token divNonce, cs, - s.streamFramer, + s.framer, sentAndReceivedPacketManager{s.sentPacketHandler, s.receivedPacketHandler}, s.perspective, s.version, @@ -279,7 +279,7 @@ var newClientSession = func( s.cryptoStreamHandler = cs s.unpacker = newPacketUnpackerGQUIC(cs, s.version) s.streamsMap = newStreamsMapLegacy(s.newStream, s.config.MaxIncomingStreams, s.perspective) - s.streamFramer = newStreamFramer(s.cryptoStream, s.streamsMap, s.version) + s.framer = newFramer(s.cryptoStream, s.streamsMap, s.version) s.packer = newPacketPacker( destConnID, srcConnID, @@ -289,7 +289,7 @@ var newClientSession = func( nil, // no token nil, // no diversification nonce cs, - s.streamFramer, + s.framer, sentAndReceivedPacketManager{s.sentPacketHandler, s.receivedPacketHandler}, s.perspective, s.version, @@ -335,7 +335,7 @@ func newTLSServerSession( } s.cryptoStreamHandler = cs s.streamsMap = newStreamsMap(s, s.newFlowController, s.config.MaxIncomingStreams, s.config.MaxIncomingUniStreams, s.perspective, s.version) - s.streamFramer = newStreamFramer(s.cryptoStream, s.streamsMap, s.version) + s.framer = newFramer(s.cryptoStream, s.streamsMap, s.version) s.packer = newPacketPacker( s.destConnID, s.srcConnID, @@ -345,7 +345,7 @@ func newTLSServerSession( nil, // no token nil, // no diversification nonce cs, - s.streamFramer, + s.framer, sentAndReceivedPacketManager{s.sentPacketHandler, s.receivedPacketHandler}, s.perspective, s.version, @@ -400,7 +400,7 @@ var newTLSClientSession = func( s.cryptoStreamHandler = cs s.unpacker = newPacketUnpacker(cs, s.version) s.streamsMap = newStreamsMap(s, s.newFlowController, s.config.MaxIncomingStreams, s.config.MaxIncomingUniStreams, s.perspective, s.version) - s.streamFramer = newStreamFramer(s.cryptoStream, s.streamsMap, s.version) + s.framer = newFramer(s.cryptoStream, s.streamsMap, s.version) s.packer = newPacketPacker( s.destConnID, s.srcConnID, @@ -410,7 +410,7 @@ var newTLSClientSession = func( token, nil, // no diversification nonce cs, - s.streamFramer, + s.framer, sentAndReceivedPacketManager{s.sentPacketHandler, s.receivedPacketHandler}, s.perspective, s.version, @@ -444,7 +444,7 @@ func (s *session) postSetup() error { s.lastNetworkActivityTime = now s.sessionCreationTime = now - s.windowUpdateQueue = newWindowUpdateQueue(s.streamsMap, s.connFlowController, s.packer.QueueControlFrame) + s.windowUpdateQueue = newWindowUpdateQueue(s.streamsMap, s.connFlowController, s.framer.QueueControlFrame) return nil } @@ -522,7 +522,7 @@ runLoop: if s.config.KeepAlive && !s.keepAlivePingSent && s.handshakeComplete && time.Since(s.lastNetworkActivityTime) >= s.peerParams.IdleTimeout/2 { // send a PING frame since there is no activity in the session s.logger.Debugf("Sending a keep-alive ping to keep the connection alive.") - s.packer.QueueControlFrame(&wire.PingFrame{}) + s.framer.QueueControlFrame(&wire.PingFrame{}) s.keepAlivePingSent = true } else if !pacingDeadline.IsZero() && now.Before(pacingDeadline) { // If we get to this point before the pacing deadline, we should wait until that deadline. @@ -1067,7 +1067,7 @@ func (s *session) sendProbePacket() error { func (s *session) sendPacket() (bool, error) { if isBlocked, offset := s.connFlowController.IsNewlyBlocked(); isBlocked { - s.packer.QueueControlFrame(&wire.BlockedFrame{Offset: offset}) + s.framer.QueueControlFrame(&wire.BlockedFrame{Offset: offset}) } s.windowUpdateQueue.QueueAll() @@ -1231,7 +1231,7 @@ func (s *session) tryDecryptingQueuedPackets() { } func (s *session) queueControlFrame(f wire.Frame) { - s.packer.QueueControlFrame(f) + s.framer.QueueControlFrame(f) s.scheduleSending() } @@ -1246,7 +1246,7 @@ func (s *session) onHasConnectionWindowUpdate() { } func (s *session) onHasStreamData(id protocol.StreamID) { - s.streamFramer.AddActiveStream(id) + s.framer.AddActiveStream(id) s.scheduleSending() } diff --git a/session_test.go b/session_test.go index a5986529..d0ecf53e 100644 --- a/session_test.go +++ b/session_test.go @@ -446,9 +446,10 @@ var _ = Describe("Session", func() { It("handles PATH_CHALLENGE frames", func() { data := [8]byte{1, 2, 3, 4, 5, 6, 7, 8} - packer.EXPECT().QueueControlFrame(&wire.PathResponseFrame{Data: data}) err := sess.handleFrames([]wire.Frame{&wire.PathChallengeFrame{Data: data}}, protocol.EncryptionUnspecified) Expect(err).ToNot(HaveOccurred()) + frames, _ := sess.framer.AppendControlFrames(nil, 1000) + Expect(frames).To(Equal([]wire.Frame{&wire.PathResponseFrame{Data: data}})) }) It("handles BLOCKED frames", func() { @@ -748,10 +749,11 @@ var _ = Describe("Session", func() { fc.EXPECT().IsNewlyBlocked().Return(true, protocol.ByteCount(1337)) packer.EXPECT().PackPacket().Return(getPacket(1), nil) sess.connFlowController = fc - packer.EXPECT().QueueControlFrame(&wire.BlockedFrame{Offset: 1337}) sent, err := sess.sendPacket() Expect(err).NotTo(HaveOccurred()) Expect(sent).To(BeTrue()) + frames, _ := sess.framer.AppendControlFrames(nil, 1000) + Expect(frames).To(Equal([]wire.Frame{&wire.BlockedFrame{Offset: 1337}})) }) It("sends PUBLIC_RESET", func() { @@ -1314,7 +1316,6 @@ var _ = Describe("Session", func() { sess.config.KeepAlive = true sess.lastNetworkActivityTime = time.Now().Add(-remoteIdleTimeout / 2) sent := make(chan struct{}) - packer.EXPECT().QueueControlFrame(&wire.PingFrame{}) packer.EXPECT().PackPacket().Do(func() (*packedPacket, error) { close(sent) return nil, nil @@ -1610,7 +1611,6 @@ var _ = Describe("Client Session", func() { done := make(chan struct{}) gomock.InOrder( sessionRunner.EXPECT().onHandshakeComplete(gomock.Any()), - packer.EXPECT().QueueControlFrame(&wire.PingFrame{}), packer.EXPECT().PackPacket().DoAndReturn(func() (*packedPacket, error) { close(done) return &packedPacket{header: &wire.Header{}, raw: *getPacketBuffer()}, nil