diff --git a/crypto_stream_test.go b/crypto_stream_test.go index 81bb3296..834bbd5f 100644 --- a/crypto_stream_test.go +++ b/crypto_stream_test.go @@ -201,8 +201,9 @@ var _ = Describe("Post Handshake Crypto Stream", func() { Expect(n).To(Equal(3)) frames, _ := framer.AppendControlFrames(nil, 1000) Expect(frames).To(HaveLen(2)) - Expect(frames).To(ContainElement(&wire.CryptoFrame{Data: []byte("foo")})) - Expect(frames).To(ContainElement(&wire.CryptoFrame{Data: []byte("bar"), Offset: 3})) + fs := []wire.Frame{frames[0].Frame, frames[1].Frame} + Expect(fs).To(ContainElement(&wire.CryptoFrame{Data: []byte("foo")})) + Expect(fs).To(ContainElement(&wire.CryptoFrame{Data: []byte("bar"), Offset: 3})) }) It("splits large writes into multiple frames", func() { @@ -214,9 +215,9 @@ var _ = Describe("Post Handshake Crypto Stream", func() { Expect(frames).To(HaveLen(11)) // one more for framing overhead var dataLen int for _, f := range frames { - Expect(f.Length(protocol.VersionTLS)).To(BeNumerically("<=", protocol.MaxPostHandshakeCryptoFrameSize)) - Expect(f).To(BeAssignableToTypeOf(&wire.CryptoFrame{})) - dataLen += len(f.(*wire.CryptoFrame).Data) + Expect(f.Frame.Length(protocol.VersionTLS)).To(BeNumerically("<=", protocol.MaxPostHandshakeCryptoFrameSize)) + Expect(f.Frame).To(BeAssignableToTypeOf(&wire.CryptoFrame{})) + dataLen += len(f.Frame.(*wire.CryptoFrame).Data) } Expect(dataLen).To(BeEquivalentTo(size)) }) diff --git a/framer.go b/framer.go index 117f01c1..632efb75 100644 --- a/framer.go +++ b/framer.go @@ -3,6 +3,7 @@ package quic import ( "sync" + "github.com/lucas-clemente/quic-go/internal/ackhandler" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/internal/wire" @@ -10,10 +11,10 @@ import ( type framer interface { QueueControlFrame(wire.Frame) - AppendControlFrames([]wire.Frame, protocol.ByteCount) ([]wire.Frame, protocol.ByteCount) + AppendControlFrames([]ackhandler.Frame, protocol.ByteCount) ([]ackhandler.Frame, protocol.ByteCount) AddActiveStream(protocol.StreamID) - AppendStreamFrames([]wire.Frame, protocol.ByteCount) ([]wire.Frame, protocol.ByteCount) + AppendStreamFrames([]ackhandler.Frame, protocol.ByteCount) ([]ackhandler.Frame, protocol.ByteCount) } type framerI struct { @@ -48,7 +49,7 @@ func (f *framerI) QueueControlFrame(frame wire.Frame) { f.controlFrameMutex.Unlock() } -func (f *framerI) AppendControlFrames(frames []wire.Frame, maxLen protocol.ByteCount) ([]wire.Frame, protocol.ByteCount) { +func (f *framerI) AppendControlFrames(frames []ackhandler.Frame, maxLen protocol.ByteCount) ([]ackhandler.Frame, protocol.ByteCount) { var length protocol.ByteCount f.controlFrameMutex.Lock() for len(f.controlFrames) > 0 { @@ -57,7 +58,7 @@ func (f *framerI) AppendControlFrames(frames []wire.Frame, maxLen protocol.ByteC if length+frameLen > maxLen { break } - frames = append(frames, frame) + frames = append(frames, ackhandler.Frame{Frame: frame}) length += frameLen f.controlFrames = f.controlFrames[:len(f.controlFrames)-1] } @@ -74,9 +75,9 @@ func (f *framerI) AddActiveStream(id protocol.StreamID) { f.mutex.Unlock() } -func (f *framerI) AppendStreamFrames(frames []wire.Frame, maxLen protocol.ByteCount) ([]wire.Frame, protocol.ByteCount) { +func (f *framerI) AppendStreamFrames(frames []ackhandler.Frame, maxLen protocol.ByteCount) ([]ackhandler.Frame, protocol.ByteCount) { var length protocol.ByteCount - var frameAdded bool + var lastFrame *wire.StreamFrame f.mutex.Lock() // pop STREAM frames, until less than MinStreamFrameSize bytes are left in the packet numActiveStreams := len(f.streamQueue) @@ -96,7 +97,7 @@ func (f *framerI) AppendStreamFrames(frames []wire.Frame, maxLen protocol.ByteCo } remainingLen := maxLen - length // For the last STREAM frame, we'll remove the DataLen field later. - // Therefore, we can pretend to have more bytes avaibalbe when popping + // Therefore, we can pretend to have more bytes available when popping // the STREAM frame (which will always have the DataLen set). remainingLen += utils.VarIntLen(uint64(remainingLen)) frame, hasMoreData := str.popStreamFrame(remainingLen) @@ -111,15 +112,14 @@ func (f *framerI) AppendStreamFrames(frames []wire.Frame, maxLen protocol.ByteCo if frame == nil { continue } - frames = append(frames, frame) + frames = append(frames, ackhandler.Frame{Frame: frame}) length += frame.Length(f.version) - frameAdded = true + lastFrame = frame } f.mutex.Unlock() - if frameAdded { - lastFrame := frames[len(frames)-1].(*wire.StreamFrame) + if lastFrame != nil { lastFrameLen := lastFrame.Length(f.version) - // acount for the smaller size of the last STREAM frame + // account for the smaller size of the last STREAM frame lastFrame.DataLenPresent = false length += lastFrame.Length(f.version) - lastFrameLen } diff --git a/framer_test.go b/framer_test.go index 37f68342..3d595bd6 100644 --- a/framer_test.go +++ b/framer_test.go @@ -3,6 +3,8 @@ package quic import ( "bytes" + "github.com/lucas-clemente/quic-go/internal/ackhandler" + "github.com/golang/mock/gomock" "github.com/lucas-clemente/quic-go/internal/protocol" @@ -37,21 +39,25 @@ var _ = Describe("Framer", func() { Context("handling control frames", func() { It("adds control frames", func() { mdf := &wire.MaxDataFrame{ByteOffset: 0x42} - msdf := &wire.MaxStreamDataFrame{ByteOffset: 0x1337} + msf := &wire.MaxStreamsFrame{MaxStreamNum: 0x1337} framer.QueueControlFrame(mdf) - framer.QueueControlFrame(msdf) + framer.QueueControlFrame(msf) frames, length := framer.AppendControlFrames(nil, 1000) - Expect(frames).To(ContainElement(mdf)) - Expect(frames).To(ContainElement(msdf)) - Expect(length).To(Equal(mdf.Length(version) + msdf.Length(version))) + Expect(frames).To(HaveLen(2)) + fs := []wire.Frame{frames[0].Frame, frames[1].Frame} + Expect(fs).To(ContainElement(mdf)) + Expect(fs).To(ContainElement(msf)) + Expect(length).To(Equal(mdf.Length(version) + msf.Length(version))) }) It("appends to the slice given", func() { - ack := &wire.AckFrame{} + ping := &wire.PingFrame{} mdf := &wire.MaxDataFrame{ByteOffset: 0x42} framer.QueueControlFrame(mdf) - frames, length := framer.AppendControlFrames([]wire.Frame{ack}, 1000) - Expect(frames).To(Equal([]wire.Frame{ack, mdf})) + frames, length := framer.AppendControlFrames([]ackhandler.Frame{{Frame: ping}}, 1000) + Expect(frames).To(HaveLen(2)) + Expect(frames[0].Frame).To(Equal(ping)) + Expect(frames[1].Frame).To(Equal(mdf)) Expect(length).To(Equal(mdf.Length(version))) }) @@ -88,8 +94,8 @@ var _ = Describe("Framer", func() { stream1.EXPECT().popStreamFrame(gomock.Any()).Return(f, false) framer.AddActiveStream(id1) fs, length := framer.AppendStreamFrames(nil, 1000) - Expect(fs).To(Equal([]wire.Frame{f})) - Expect(fs[0].(*wire.StreamFrame).DataLenPresent).To(BeFalse()) + Expect(fs).To(HaveLen(1)) + Expect(fs[0].Frame.(*wire.StreamFrame).DataLenPresent).To(BeFalse()) Expect(length).To(Equal(f.Length(version))) }) @@ -103,10 +109,12 @@ var _ = Describe("Framer", func() { stream1.EXPECT().popStreamFrame(gomock.Any()).Return(f, false) framer.AddActiveStream(id1) mdf := &wire.MaxDataFrame{ByteOffset: 1337} - frames := []wire.Frame{mdf} + frames := []ackhandler.Frame{{Frame: mdf}} fs, length := framer.AppendStreamFrames(frames, 1000) - Expect(fs).To(Equal([]wire.Frame{mdf, f})) - Expect(fs[1].(*wire.StreamFrame).DataLenPresent).To(BeFalse()) + Expect(fs).To(HaveLen(2)) + Expect(fs[0].Frame).To(Equal(mdf)) + Expect(fs[1].Frame.(*wire.StreamFrame).Data).To(Equal([]byte("foobar"))) + Expect(fs[1].Frame.(*wire.StreamFrame).DataLenPresent).To(BeFalse()) Expect(length).To(Equal(f.Length(version))) }) @@ -122,7 +130,8 @@ var _ = Describe("Framer", func() { framer.AddActiveStream(id1) framer.AddActiveStream(id2) frames, _ := framer.AppendStreamFrames(nil, 1000) - Expect(frames).To(Equal([]wire.Frame{f})) + Expect(frames).To(HaveLen(1)) + Expect(frames[0].Frame).To(Equal(f)) }) It("skips a stream that was reported active, but doesn't have any data", func() { @@ -138,7 +147,8 @@ var _ = Describe("Framer", func() { framer.AddActiveStream(id1) framer.AddActiveStream(id2) frames, _ := framer.AppendStreamFrames(nil, 1000) - Expect(frames).To(Equal([]wire.Frame{f})) + Expect(frames).To(HaveLen(1)) + Expect(frames[0].Frame).To(Equal(f)) }) It("pops from a stream multiple times, if it has enough data", func() { @@ -149,9 +159,11 @@ var _ = Describe("Framer", func() { stream1.EXPECT().popStreamFrame(gomock.Any()).Return(f2, false) framer.AddActiveStream(id1) // only add it once frames, _ := framer.AppendStreamFrames(nil, protocol.MinStreamFrameSize) - Expect(frames).To(Equal([]wire.Frame{f1})) + Expect(frames).To(HaveLen(1)) + Expect(frames[0].Frame).To(Equal(f1)) frames, _ = framer.AppendStreamFrames(nil, protocol.MinStreamFrameSize) - Expect(frames).To(Equal([]wire.Frame{f2})) + Expect(frames).To(HaveLen(1)) + Expect(frames[0].Frame).To(Equal(f2)) // no further calls to popStreamFrame, after popStreamFrame said there's no more data frames, _ = framer.AppendStreamFrames(nil, protocol.MinStreamFrameSize) Expect(frames).To(BeNil()) @@ -170,13 +182,16 @@ var _ = Describe("Framer", func() { framer.AddActiveStream(id2) // first a frame from stream 1 frames, _ := framer.AppendStreamFrames(nil, protocol.MinStreamFrameSize) - Expect(frames).To(Equal([]wire.Frame{f11})) + Expect(frames).To(HaveLen(1)) + Expect(frames[0].Frame).To(Equal(f11)) // then a frame from stream 2 frames, _ = framer.AppendStreamFrames(nil, protocol.MinStreamFrameSize) - Expect(frames).To(Equal([]wire.Frame{f2})) + Expect(frames).To(HaveLen(1)) + Expect(frames[0].Frame).To(Equal(f2)) // then another frame from stream 1 frames, _ = framer.AppendStreamFrames(nil, protocol.MinStreamFrameSize) - Expect(frames).To(Equal([]wire.Frame{f12})) + Expect(frames).To(HaveLen(1)) + Expect(frames[0].Frame).To(Equal(f12)) }) It("only dequeues data from each stream once per packet", func() { @@ -190,7 +205,9 @@ var _ = Describe("Framer", func() { framer.AddActiveStream(id1) framer.AddActiveStream(id2) frames, length := framer.AppendStreamFrames(nil, 1000) - Expect(frames).To(Equal([]wire.Frame{f1, f2})) + Expect(frames).To(HaveLen(2)) + Expect(frames[0].Frame).To(Equal(f1)) + Expect(frames[1].Frame).To(Equal(f2)) Expect(length).To(Equal(f1.Length(version) + f2.Length(version))) }) @@ -204,7 +221,9 @@ var _ = Describe("Framer", func() { framer.AddActiveStream(id2) framer.AddActiveStream(id1) frames, _ := framer.AppendStreamFrames(nil, 1000) - Expect(frames).To(Equal([]wire.Frame{f2, f1})) + Expect(frames).To(HaveLen(2)) + Expect(frames[0].Frame).To(Equal(f2)) + Expect(frames[1].Frame).To(Equal(f1)) }) It("only asks a stream for data once, even if it was reported active multiple times", func() { @@ -238,7 +257,7 @@ var _ = Describe("Framer", func() { framer.AddActiveStream(id1) frames, _ := framer.AppendStreamFrames(nil, i) Expect(frames).To(HaveLen(1)) - f := frames[0].(*wire.StreamFrame) + f := frames[0].Frame.(*wire.StreamFrame) Expect(f.DataLenPresent).To(BeFalse()) Expect(f.Length(version)).To(Equal(i)) } @@ -269,8 +288,8 @@ var _ = Describe("Framer", func() { framer.AddActiveStream(id2) frames, _ := framer.AppendStreamFrames(nil, i) Expect(frames).To(HaveLen(2)) - f1 := frames[0].(*wire.StreamFrame) - f2 := frames[1].(*wire.StreamFrame) + f1 := frames[0].Frame.(*wire.StreamFrame) + f2 := frames[1].Frame.(*wire.StreamFrame) Expect(f1.DataLenPresent).To(BeTrue()) Expect(f2.DataLenPresent).To(BeFalse()) Expect(f1.Length(version) + f2.Length(version)).To(Equal(i)) @@ -300,7 +319,8 @@ var _ = Describe("Framer", func() { stream1.EXPECT().popStreamFrame(gomock.Any()).Return(f, false) framer.AddActiveStream(id1) fs, length := framer.AppendStreamFrames(nil, 500) - Expect(fs).To(Equal([]wire.Frame{f})) + Expect(fs).To(HaveLen(1)) + Expect(fs[0].Frame).To(Equal(f)) Expect(length).To(Equal(f.Length(version))) }) }) diff --git a/internal/ackhandler/ack_eliciting.go b/internal/ackhandler/ack_eliciting.go index bf4e971c..15fa87a3 100644 --- a/internal/ackhandler/ack_eliciting.go +++ b/internal/ackhandler/ack_eliciting.go @@ -9,9 +9,9 @@ func IsFrameAckEliciting(f wire.Frame) bool { } // HasAckElicitingFrames returns true if at least one frame is ack-eliciting. -func HasAckElicitingFrames(fs []wire.Frame) bool { +func HasAckElicitingFrames(fs []Frame) bool { for _, f := range fs { - if IsFrameAckEliciting(f) { + if IsFrameAckEliciting(f.Frame) { return true } } diff --git a/internal/ackhandler/ack_eliciting_test.go b/internal/ackhandler/ack_eliciting_test.go index 65263af6..f38bc1ac 100644 --- a/internal/ackhandler/ack_eliciting_test.go +++ b/internal/ackhandler/ack_eliciting_test.go @@ -28,7 +28,7 @@ var _ = Describe("ack-eliciting frames", func() { }) It("HasAckElicitingFrames works for "+fName, func() { - Expect(HasAckElicitingFrames([]wire.Frame{f})).To(Equal(e)) + Expect(HasAckElicitingFrames([]Frame{{Frame: f}})).To(Equal(e)) }) } }) diff --git a/mock_frame_source_test.go b/mock_frame_source_test.go index 676da023..24693286 100644 --- a/mock_frame_source_test.go +++ b/mock_frame_source_test.go @@ -8,8 +8,8 @@ import ( reflect "reflect" gomock "github.com/golang/mock/gomock" + ackhandler "github.com/lucas-clemente/quic-go/internal/ackhandler" protocol "github.com/lucas-clemente/quic-go/internal/protocol" - wire "github.com/lucas-clemente/quic-go/internal/wire" ) // MockFrameSource is a mock of FrameSource interface @@ -36,10 +36,10 @@ func (m *MockFrameSource) EXPECT() *MockFrameSourceMockRecorder { } // AppendControlFrames mocks base method -func (m *MockFrameSource) AppendControlFrames(arg0 []wire.Frame, arg1 protocol.ByteCount) ([]wire.Frame, protocol.ByteCount) { +func (m *MockFrameSource) AppendControlFrames(arg0 []ackhandler.Frame, arg1 protocol.ByteCount) ([]ackhandler.Frame, protocol.ByteCount) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "AppendControlFrames", arg0, arg1) - ret0, _ := ret[0].([]wire.Frame) + ret0, _ := ret[0].([]ackhandler.Frame) ret1, _ := ret[1].(protocol.ByteCount) return ret0, ret1 } @@ -51,10 +51,10 @@ func (mr *MockFrameSourceMockRecorder) AppendControlFrames(arg0, arg1 interface{ } // AppendStreamFrames mocks base method -func (m *MockFrameSource) AppendStreamFrames(arg0 []wire.Frame, arg1 protocol.ByteCount) ([]wire.Frame, protocol.ByteCount) { +func (m *MockFrameSource) AppendStreamFrames(arg0 []ackhandler.Frame, arg1 protocol.ByteCount) ([]ackhandler.Frame, protocol.ByteCount) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "AppendStreamFrames", arg0, arg1) - ret0, _ := ret[0].([]wire.Frame) + ret0, _ := ret[0].([]ackhandler.Frame) ret1, _ := ret[1].(protocol.ByteCount) return ret0, ret1 } diff --git a/packet_packer.go b/packet_packer.go index 898d42b3..999fa610 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -29,7 +29,7 @@ type sealer interface { } type payload struct { - frames []wire.Frame + frames []ackhandler.Frame ack *wire.AckFrame length protocol.ByteCount } @@ -38,7 +38,7 @@ type packedPacket struct { header *wire.ExtendedHeader raw []byte ack *wire.AckFrame - frames []wire.Frame + frames []ackhandler.Frame buffer *packetBuffer } @@ -67,23 +67,20 @@ func (p *packedPacket) ToAckHandlerPacket(q *retransmissionQueue) *ackhandler.Pa largestAcked = p.ack.LargestAcked() } encLevel := p.EncryptionLevel() - frames := make([]ackhandler.Frame, len(p.frames)) - for i, f := range p.frames { - frame := f - frames[i].Frame = frame + for i := range p.frames { switch encLevel { case protocol.EncryptionInitial: - frames[i].OnLost = q.AddInitial + p.frames[i].OnLost = q.AddInitial case protocol.EncryptionHandshake: - frames[i].OnLost = q.AddHandshake + p.frames[i].OnLost = q.AddHandshake case protocol.Encryption1RTT: - frames[i].OnLost = q.AddAppData + p.frames[i].OnLost = q.AddAppData } } return &ackhandler.Packet{ PacketNumber: p.header.PacketNumber, LargestAcked: largestAcked, - Frames: frames, + Frames: p.frames, Length: protocol.ByteCount(len(p.raw)), EncryptionLevel: encLevel, SendTime: time.Now(), @@ -119,8 +116,8 @@ type sealingManager interface { } type frameSource interface { - AppendStreamFrames([]wire.Frame, protocol.ByteCount) ([]wire.Frame, protocol.ByteCount) - AppendControlFrames([]wire.Frame, protocol.ByteCount) ([]wire.Frame, protocol.ByteCount) + AppendStreamFrames([]ackhandler.Frame, protocol.ByteCount) ([]ackhandler.Frame, protocol.ByteCount) + AppendControlFrames([]ackhandler.Frame, protocol.ByteCount) ([]ackhandler.Frame, protocol.ByteCount) } type ackFrameSource interface { @@ -187,7 +184,7 @@ func newPacketPacker( // PackConnectionClose packs a packet that ONLY contains a ConnectionCloseFrame func (p *packetPacker) PackConnectionClose(ccf *wire.ConnectionCloseFrame) (*packedPacket, error) { payload := payload{ - frames: []wire.Frame{ccf}, + frames: []ackhandler.Frame{{Frame: ccf}}, length: ccf.Length(p.version), } // send the CONNECTION_CLOSE frame with the highest available encryption level @@ -287,7 +284,7 @@ func (p *packetPacker) PackPacket() (*packedPacket, error) { if len(payload.frames) == 0 { // the packet only contains an ACK if p.numNonAckElicitingAcks >= protocol.MaxNonAckElicitingAcks { ping := &wire.PingFrame{} - payload.frames = append(payload.frames, ping) + payload.frames = append(payload.frames, ackhandler.Frame{Frame: ping}) payload.length += ping.Length(p.version) p.numNonAckElicitingAcks = 0 } else { @@ -361,12 +358,12 @@ func (p *packetPacker) maybePackCryptoPacket() (*packedPacket, error) { if f == nil { break } - payload.frames = append(payload.frames, f) + payload.frames = append(payload.frames, ackhandler.Frame{Frame: f}) payload.length += f.Length(p.version) } } else if hasData { cf := s.PopCryptoFrame(p.maxPacketSize - hdrLen - protocol.ByteCount(sealer.Overhead()) - payload.length) - payload.frames = []wire.Frame{cf} + payload.frames = []ackhandler.Frame{{Frame: cf}} payload.length += cf.Length(p.version) } return p.writeAndSealPacket(hdr, payload, encLevel, sealer) @@ -389,7 +386,7 @@ func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount) (paylo if f == nil { break } - payload.frames = append(payload.frames, f) + payload.frames = append(payload.frames, ackhandler.Frame{Frame: f}) payload.length += f.Length(p.version) } diff --git a/packet_packer_test.go b/packet_packer_test.go index d9309c5f..db6b30a6 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -7,6 +7,8 @@ import ( "net" "time" + "github.com/lucas-clemente/quic-go/internal/ackhandler" + "github.com/golang/mock/gomock" "github.com/lucas-clemente/quic-go/internal/handshake" "github.com/lucas-clemente/quic-go/internal/mocks" @@ -41,22 +43,22 @@ var _ = Describe("Packet packer", func() { ExpectWithOffset(1, extHdr.Length).To(BeEquivalentTo(r.Len() + int(extHdr.PacketNumberLen))) } - appendFrames := func(fs, frames []wire.Frame) ([]wire.Frame, protocol.ByteCount) { + appendFrames := func(fs, frames []ackhandler.Frame) ([]ackhandler.Frame, protocol.ByteCount) { var length protocol.ByteCount for _, f := range frames { - length += f.Length(packer.version) + length += f.Frame.Length(packer.version) } return append(fs, frames...), length } - expectAppendStreamFrames := func(frames ...wire.Frame) { - framer.EXPECT().AppendStreamFrames(gomock.Any(), gomock.Any()).DoAndReturn(func(fs []wire.Frame, _ protocol.ByteCount) ([]wire.Frame, protocol.ByteCount) { + expectAppendStreamFrames := func(frames ...ackhandler.Frame) { + framer.EXPECT().AppendStreamFrames(gomock.Any(), gomock.Any()).DoAndReturn(func(fs []ackhandler.Frame, _ protocol.ByteCount) ([]ackhandler.Frame, protocol.ByteCount) { return appendFrames(fs, frames) }) } - expectAppendControlFrames := func(frames ...wire.Frame) { - framer.EXPECT().AppendControlFrames(gomock.Any(), gomock.Any()).DoAndReturn(func(fs []wire.Frame, _ protocol.ByteCount) ([]wire.Frame, protocol.ByteCount) { + expectAppendControlFrames := func(frames ...ackhandler.Frame) { + framer.EXPECT().AppendControlFrames(gomock.Any(), gomock.Any()).DoAndReturn(func(fs []ackhandler.Frame, _ protocol.ByteCount) ([]ackhandler.Frame, protocol.ByteCount) { return appendFrames(fs, frames) }) } @@ -189,11 +191,11 @@ var _ = Describe("Packet packer", func() { ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT) expectAppendControlFrames() f := &wire.StreamFrame{Data: []byte{0xde, 0xca, 0xfb, 0xad}} - expectAppendStreamFrames(f) + expectAppendStreamFrames(ackhandler.Frame{Frame: f}) p, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) Expect(p).ToNot(BeNil()) - Expect(p.frames).To(Equal([]wire.Frame{f})) + Expect(p.frames).To(Equal([]ackhandler.Frame{{Frame: f}})) hdrRawEncrypted := append([]byte{}, hdrRaw...) hdrRawEncrypted[0] ^= 0xff hdrRawEncrypted[len(hdrRaw)-2] ^= 0xff @@ -288,13 +290,13 @@ var _ = Describe("Packet packer", func() { StreamID: 5, Data: []byte{0xde, 0xca, 0xfb, 0xad}, } - expectAppendStreamFrames(f) + expectAppendStreamFrames(ackhandler.Frame{Frame: f}) p, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) Expect(p).ToNot(BeNil()) b := &bytes.Buffer{} f.Write(b, packer.version) - Expect(p.frames).To(Equal([]wire.Frame{f})) + Expect(p.frames).To(Equal([]ackhandler.Frame{{Frame: f}})) Expect(p.raw).To(ContainSubstring(b.String())) }) @@ -304,10 +306,10 @@ var _ = Describe("Packet packer", func() { sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil) ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT) expectAppendControlFrames() - expectAppendStreamFrames(&wire.StreamFrame{ + expectAppendStreamFrames(ackhandler.Frame{Frame: &wire.StreamFrame{ StreamID: 5, Data: []byte("foobar"), - }) + }}) p, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) Expect(p.EncryptionLevel()).To(Equal(protocol.Encryption1RTT)) @@ -339,7 +341,7 @@ var _ = Describe("Packet packer", func() { p, err := packer.PackConnectionClose(&ccf) Expect(err).ToNot(HaveOccurred()) Expect(p.frames).To(HaveLen(1)) - Expect(p.frames[0]).To(Equal(&ccf)) + Expect(p.frames[0].Frame).To(Equal(&ccf)) }) It("packs control frames", func() { @@ -347,7 +349,10 @@ var _ = Describe("Packet packer", func() { pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil) ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT) - frames := []wire.Frame{&wire.ResetStreamFrame{}, &wire.MaxDataFrame{}} + frames := []ackhandler.Frame{ + {Frame: &wire.ResetStreamFrame{}}, + {Frame: &wire.MaxDataFrame{}}, + } expectAppendControlFrames(frames...) expectAppendStreamFrames() p, err := packer.PackPacket() @@ -363,11 +368,11 @@ var _ = Describe("Packet packer", func() { ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT) var maxSize protocol.ByteCount gomock.InOrder( - framer.EXPECT().AppendControlFrames(gomock.Any(), gomock.Any()).DoAndReturn(func(fs []wire.Frame, maxLen protocol.ByteCount) ([]wire.Frame, protocol.ByteCount) { + framer.EXPECT().AppendControlFrames(gomock.Any(), gomock.Any()).DoAndReturn(func(fs []ackhandler.Frame, maxLen protocol.ByteCount) ([]ackhandler.Frame, protocol.ByteCount) { maxSize = maxLen return fs, 444 }), - framer.EXPECT().AppendStreamFrames(gomock.Any(), gomock.Any()).Do(func(fs []wire.Frame, maxLen protocol.ByteCount) ([]wire.Frame, protocol.ByteCount) { + framer.EXPECT().AppendStreamFrames(gomock.Any(), gomock.Any()).Do(func(fs []ackhandler.Frame, maxLen protocol.ByteCount) ([]ackhandler.Frame, protocol.ByteCount) { Expect(maxLen).To(Equal(maxSize - 444)) return fs, 0 }), @@ -386,8 +391,8 @@ var _ = Describe("Packet packer", func() { pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil) ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT) - framer.EXPECT().AppendControlFrames(gomock.Any(), gomock.Any()) - framer.EXPECT().AppendStreamFrames(gomock.Any(), gomock.Any()).Return([]wire.Frame{f}, f.Length(packer.version)) + expectAppendControlFrames() + expectAppendStreamFrames(ackhandler.Frame{Frame: f}) packet, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) // cut off the tag that the mock sealer added @@ -432,14 +437,14 @@ var _ = Describe("Packet packer", func() { sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil) ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT) expectAppendControlFrames() - expectAppendStreamFrames(f1, f2, f3) + expectAppendStreamFrames(ackhandler.Frame{Frame: f1}, ackhandler.Frame{Frame: f2}, ackhandler.Frame{Frame: f3}) p, err := packer.PackPacket() Expect(p).ToNot(BeNil()) Expect(err).ToNot(HaveOccurred()) Expect(p.frames).To(HaveLen(3)) - Expect(p.frames[0].(*wire.StreamFrame).Data).To(Equal([]byte("frame 1"))) - Expect(p.frames[1].(*wire.StreamFrame).Data).To(Equal([]byte("frame 2"))) - Expect(p.frames[2].(*wire.StreamFrame).Data).To(Equal([]byte("frame 3"))) + Expect(p.frames[0].Frame.(*wire.StreamFrame).Data).To(Equal([]byte("frame 1"))) + Expect(p.frames[1].Frame.(*wire.StreamFrame).Data).To(Equal([]byte("frame 2"))) + Expect(p.frames[2].Frame.(*wire.StreamFrame).Data).To(Equal([]byte("frame 3"))) }) It("adds retransmissions", func() { @@ -453,16 +458,16 @@ var _ = Describe("Packet packer", func() { ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT) expectAppendControlFrames() f2 := &wire.StreamFrame{Data: []byte("frame 2")} - expectAppendStreamFrames(f2) + expectAppendStreamFrames(ackhandler.Frame{Frame: f2}) p, err := packer.PackPacket() Expect(p).ToNot(BeNil()) Expect(err).ToNot(HaveOccurred()) Expect(p.frames).To(HaveLen(3)) - Expect(p.frames[0]).To(BeAssignableToTypeOf(&wire.StreamFrame{})) - Expect(p.frames[0].(*wire.StreamFrame).Data).To(Equal([]byte("frame 1"))) - Expect(p.frames[1]).To(Equal(cf)) - Expect(p.frames[2]).To(BeAssignableToTypeOf(&wire.StreamFrame{})) - Expect(p.frames[2].(*wire.StreamFrame).Data).To(Equal([]byte("frame 2"))) + Expect(p.frames[0].Frame).To(BeAssignableToTypeOf(&wire.StreamFrame{})) + Expect(p.frames[0].Frame.(*wire.StreamFrame).Data).To(Equal([]byte("frame 1"))) + Expect(p.frames[1].Frame).To(Equal(cf)) + Expect(p.frames[2].Frame).To(BeAssignableToTypeOf(&wire.StreamFrame{})) + Expect(p.frames[2].Frame.(*wire.StreamFrame).Data).To(Equal([]byte("frame 2"))) }) Context("making ACK packets ack-eliciting", func() { @@ -493,7 +498,7 @@ var _ = Describe("Packet packer", func() { p, err := packer.PackPacket() Expect(p).ToNot(BeNil()) Expect(err).ToNot(HaveOccurred()) - Expect(p.frames).To(ContainElement(&wire.PingFrame{})) + Expect(p.frames).To(ContainElement(ackhandler.Frame{Frame: &wire.PingFrame{}})) // make sure the next packet doesn't contain another PING pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) @@ -530,7 +535,7 @@ var _ = Describe("Packet packer", func() { p, err = packer.PackPacket() Expect(err).ToNot(HaveOccurred()) Expect(p.ack).To(Equal(ack)) - Expect(p.frames).To(Equal([]wire.Frame{&wire.PingFrame{}})) + Expect(p.frames).To(Equal([]ackhandler.Frame{{Frame: &wire.PingFrame{}}})) }) It("doesn't send a PING if it already sent another ack-eliciting frame", func() { @@ -540,7 +545,7 @@ var _ = Describe("Packet packer", func() { sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil) ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT) expectAppendStreamFrames() - expectAppendControlFrames(&wire.MaxDataFrame{}) + expectAppendControlFrames(ackhandler.Frame{Frame: &wire.MaxDataFrame{}}) p, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) Expect(p).ToNot(BeNil()) @@ -554,7 +559,7 @@ var _ = Describe("Packet packer", func() { sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil).Times(2) ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT).Times(2) var initialMaxPacketSize protocol.ByteCount - framer.EXPECT().AppendControlFrames(gomock.Any(), gomock.Any()).Do(func(_ []wire.Frame, maxLen protocol.ByteCount) ([]wire.Frame, protocol.ByteCount) { + framer.EXPECT().AppendControlFrames(gomock.Any(), gomock.Any()).Do(func(_ []ackhandler.Frame, maxLen protocol.ByteCount) ([]ackhandler.Frame, protocol.ByteCount) { initialMaxPacketSize = maxLen return nil, 0 }) @@ -565,7 +570,7 @@ var _ = Describe("Packet packer", func() { packer.HandleTransportParameters(&handshake.TransportParameters{ MaxPacketSize: maxPacketSize - 10, }) - framer.EXPECT().AppendControlFrames(gomock.Any(), gomock.Any()).Do(func(_ []wire.Frame, maxLen protocol.ByteCount) ([]wire.Frame, protocol.ByteCount) { + framer.EXPECT().AppendControlFrames(gomock.Any(), gomock.Any()).Do(func(_ []ackhandler.Frame, maxLen protocol.ByteCount) ([]ackhandler.Frame, protocol.ByteCount) { Expect(maxLen).To(Equal(initialMaxPacketSize - 10)) return nil, 0 }) @@ -579,7 +584,7 @@ var _ = Describe("Packet packer", func() { sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil).Times(2) ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT).Times(2) var initialMaxPacketSize protocol.ByteCount - framer.EXPECT().AppendControlFrames(gomock.Any(), gomock.Any()).Do(func(_ []wire.Frame, maxLen protocol.ByteCount) ([]wire.Frame, protocol.ByteCount) { + framer.EXPECT().AppendControlFrames(gomock.Any(), gomock.Any()).Do(func(_ []ackhandler.Frame, maxLen protocol.ByteCount) ([]ackhandler.Frame, protocol.ByteCount) { initialMaxPacketSize = maxLen return nil, 0 }) @@ -590,7 +595,7 @@ var _ = Describe("Packet packer", func() { packer.HandleTransportParameters(&handshake.TransportParameters{ MaxPacketSize: maxPacketSize + 10, }) - framer.EXPECT().AppendControlFrames(gomock.Any(), gomock.Any()).Do(func(_ []wire.Frame, maxLen protocol.ByteCount) ([]wire.Frame, protocol.ByteCount) { + framer.EXPECT().AppendControlFrames(gomock.Any(), gomock.Any()).Do(func(_ []ackhandler.Frame, maxLen protocol.ByteCount) ([]ackhandler.Frame, protocol.ByteCount) { Expect(maxLen).To(Equal(initialMaxPacketSize)) return nil, 0 }) @@ -656,7 +661,7 @@ var _ = Describe("Packet packer", func() { p, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) Expect(p.EncryptionLevel()).To(Equal(protocol.EncryptionInitial)) - Expect(p.frames).To(Equal([]wire.Frame{f})) + Expect(p.frames).To(Equal([]ackhandler.Frame{{Frame: f}})) Expect(p.header.IsLongHeader).To(BeTrue()) checkLength(p.raw) }) @@ -706,7 +711,7 @@ var _ = Describe("Packet packer", func() { Expect(packet.header.Token).To(Equal(token)) Expect(packet.raw).To(HaveLen(protocol.MinInitialPacketSize)) Expect(packet.frames).To(HaveLen(1)) - cf := packet.frames[0].(*wire.CryptoFrame) + cf := packet.frames[0].Frame.(*wire.CryptoFrame) Expect(cf.Data).To(Equal([]byte("foobar"))) }) @@ -756,7 +761,7 @@ var _ = Describe("Packet packer", func() { pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT) - expectAppendControlFrames(&wire.PingFrame{}) + expectAppendControlFrames(ackhandler.Frame{Frame: &wire.PingFrame{}}) expectAppendStreamFrames() packet, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) @@ -767,7 +772,7 @@ var _ = Describe("Packet packer", func() { pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x43), protocol.PacketNumberLen2) pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x43)) ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT) - expectAppendControlFrames(&wire.PingFrame{}) + expectAppendControlFrames(ackhandler.Frame{Frame: &wire.PingFrame{}}) expectAppendStreamFrames() packet, err = packer.PackPacket() Expect(err).ToNot(HaveOccurred()) @@ -779,19 +784,15 @@ var _ = Describe("Packet packer", func() { var _ = Describe("Converting to AckHandler packets", func() { It("convert a packet", func() { - f1 := &wire.MaxDataFrame{} - f2 := &wire.PingFrame{} packet := &packedPacket{ header: &wire.ExtendedHeader{Header: wire.Header{}}, - frames: []wire.Frame{f1, f2}, + frames: []ackhandler.Frame{{Frame: &wire.MaxDataFrame{}}, {Frame: &wire.PingFrame{}}}, ack: &wire.AckFrame{AckRanges: []wire.AckRange{{Largest: 100, Smallest: 80}}}, raw: []byte("foobar"), } p := packet.ToAckHandlerPacket(nil) Expect(p.Length).To(Equal(protocol.ByteCount(6))) - Expect(p.Frames).To(HaveLen(2)) - Expect(p.Frames[0].Frame).To(Equal(f1)) - Expect(p.Frames[1].Frame).To(Equal(f2)) + Expect(p.Frames).To(Equal(packet.frames)) Expect(p.LargestAcked).To(Equal(protocol.PacketNumber(100))) Expect(p.SendTime).To(BeTemporally("~", time.Now(), 50*time.Millisecond)) }) @@ -799,10 +800,20 @@ var _ = Describe("Converting to AckHandler packets", func() { It("sets the LargestAcked to invalid, if the packet doesn't have an ACK frame", func() { packet := &packedPacket{ header: &wire.ExtendedHeader{Header: wire.Header{}}, - frames: []wire.Frame{&wire.MaxDataFrame{}, &wire.PingFrame{}}, + frames: []ackhandler.Frame{{Frame: &wire.MaxDataFrame{}}, {Frame: &wire.PingFrame{}}}, raw: []byte("foobar"), } p := packet.ToAckHandlerPacket(nil) Expect(p.LargestAcked).To(Equal(protocol.InvalidPacketNumber)) }) + + It("sets the OnLost callback", func() { + packet := &packedPacket{ + header: &wire.ExtendedHeader{Header: wire.Header{}}, + frames: []ackhandler.Frame{{Frame: &wire.MaxDataFrame{}}}, + raw: []byte("foobar"), + } + p := packet.ToAckHandlerPacket(newRetransmissionQueue(protocol.VersionTLS)) + Expect(p.Frames[0].OnLost).ToNot(BeNil()) + }) }) diff --git a/session.go b/session.go index cd9db0fb..19e2cc3a 100644 --- a/session.go +++ b/session.go @@ -1198,7 +1198,8 @@ func (s *session) sendPackedPacket(packet *packedPacket) { EncryptionLevel: packet.EncryptionLevel(), PacketNumber: packet.header.PacketNumber, PacketSize: protocol.ByteCount(len(packet.raw)), - Frames: packet.frames, + // TODO: trace frames + // Frames: packet.frames, }) } s.logPacket(packet) diff --git a/session_test.go b/session_test.go index ddd708ec..1c5ed3ef 100644 --- a/session_test.go +++ b/session_test.go @@ -309,7 +309,7 @@ var _ = Describe("Session", func() { err := sess.handleFrame(&wire.PathChallengeFrame{Data: data}, 0, protocol.EncryptionUnspecified) Expect(err).ToNot(HaveOccurred()) frames, _ := sess.framer.AppendControlFrames(nil, 1000) - Expect(frames).To(Equal([]wire.Frame{&wire.PathResponseFrame{Data: data}})) + Expect(frames).To(Equal([]ackhandler.Frame{{Frame: &wire.PathResponseFrame{Data: data}}})) }) It("rejects NEW_TOKEN frames", func() { @@ -889,7 +889,7 @@ var _ = Describe("Session", func() { Expect(err).NotTo(HaveOccurred()) Expect(sent).To(BeTrue()) frames, _ := sess.framer.AppendControlFrames(nil, 1000) - Expect(frames).To(Equal([]wire.Frame{&wire.DataBlockedFrame{DataLimit: 1337}})) + Expect(frames).To(Equal([]ackhandler.Frame{{Frame: &wire.DataBlockedFrame{DataLimit: 1337}}})) }) It("sends a probe packet", func() { @@ -923,7 +923,7 @@ var _ = Describe("Session", func() { // We're using a mock packet packer in this test. // We therefore need to test separately that the PING was actually queued. frames, _ := sess.framer.AppendControlFrames(nil, protocol.MaxByteCount) - Expect(frames).To(Equal([]wire.Frame{&wire.PingFrame{}})) + Expect(frames).To(Equal([]ackhandler.Frame{{Frame: &wire.PingFrame{}}})) }) It("doesn't send when the SentPacketHandler doesn't allow it", func() {