diff --git a/internal/wire/stop_waiting_frame.go b/internal/wire/stop_waiting_frame.go index 4f7a1c8b..48fbd44a 100644 --- a/internal/wire/stop_waiting_frame.go +++ b/internal/wire/stop_waiting_frame.go @@ -22,7 +22,10 @@ var ( errPacketNumberLenNotSet = errors.New("StopWaitingFrame: PacketNumberLen not set") ) -func (f *StopWaitingFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error { +func (f *StopWaitingFrame) Write(b *bytes.Buffer, v protocol.VersionNumber) error { + if v.UsesIETFFrameFormat() { + return errors.New("STOP_WAITING not defined in IETF QUIC") + } // make sure the PacketNumber was set if f.PacketNumber == protocol.PacketNumber(0) { return errPacketNumberNotSet diff --git a/internal/wire/stop_waiting_frame_test.go b/internal/wire/stop_waiting_frame_test.go index fb0f61c3..a46ddd98 100644 --- a/internal/wire/stop_waiting_frame_test.go +++ b/internal/wire/stop_waiting_frame_test.go @@ -84,7 +84,7 @@ var _ = Describe("StopWaitingFrame", func() { LeastUnacked: 10, PacketNumberLen: protocol.PacketNumberLen1, } - err := frame.Write(b, protocol.VersionWhatever) + err := frame.Write(b, versionBigEndian) Expect(err).To(MatchError(errPacketNumberNotSet)) }) @@ -94,7 +94,7 @@ var _ = Describe("StopWaitingFrame", func() { LeastUnacked: 10, PacketNumber: 13, } - err := frame.Write(b, protocol.VersionWhatever) + err := frame.Write(b, versionBigEndian) Expect(err).To(MatchError(errPacketNumberLenNotSet)) }) @@ -105,10 +105,21 @@ var _ = Describe("StopWaitingFrame", func() { PacketNumber: 5, PacketNumberLen: protocol.PacketNumberLen1, } - err := frame.Write(b, protocol.VersionWhatever) + err := frame.Write(b, versionBigEndian) Expect(err).To(MatchError(errLeastUnackedHigherThanPacketNumber)) }) + It("refuses to write for IETF QUIC", func() { + b := &bytes.Buffer{} + frame := &StopWaitingFrame{ + LeastUnacked: 10, + PacketNumber: 13, + PacketNumberLen: protocol.PacketNumberLen6, + } + err := frame.Write(b, versionIETFFrames) + Expect(err).To(MatchError("STOP_WAITING not defined in IETF QUIC")) + }) + Context("LeastUnackedDelta length", func() { Context("in big endian", func() { It("writes a 1-byte LeastUnackedDelta", func() { @@ -179,7 +190,7 @@ var _ = Describe("StopWaitingFrame", func() { }) Context("self consistency", func() { - It("reads a stop waiting frame that it wrote", func() { + It("reads a STOP_WAITING frame that it wrote", func() { packetNumber := protocol.PacketNumber(13) frame := &StopWaitingFrame{ LeastUnacked: 10, @@ -187,9 +198,9 @@ var _ = Describe("StopWaitingFrame", func() { PacketNumberLen: protocol.PacketNumberLen4, } b := &bytes.Buffer{} - err := frame.Write(b, protocol.VersionWhatever) + err := frame.Write(b, versionBigEndian) Expect(err).ToNot(HaveOccurred()) - readframe, err := ParseStopWaitingFrame(bytes.NewReader(b.Bytes()), packetNumber, protocol.PacketNumberLen4, protocol.VersionWhatever) + readframe, err := ParseStopWaitingFrame(bytes.NewReader(b.Bytes()), packetNumber, protocol.PacketNumberLen4, versionBigEndian) Expect(err).ToNot(HaveOccurred()) Expect(readframe.LeastUnacked).To(Equal(frame.LeastUnacked)) }) diff --git a/packet_packer.go b/packet_packer.go index ff4c0919..01681d2a 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -77,7 +77,7 @@ func (p *packetPacker) PackAckPacket() (*packedPacket, error) { encLevel, sealer := p.cryptoSetup.GetSealer() header := p.getHeader(encLevel) frames := []wire.Frame{p.ackFrame} - if p.stopWaiting != nil { + if p.stopWaiting != nil { // a STOP_WAITING will only be queued when using gQUIC p.stopWaiting.PacketNumber = header.PacketNumber p.stopWaiting.PacketNumberLen = header.PacketNumberLen frames = append(frames, p.stopWaiting) @@ -102,14 +102,20 @@ func (p *packetPacker) PackHandshakeRetransmission(packet *ackhandler.Packet) (* if err != nil { return nil, err } - if p.stopWaiting == nil { - return nil, errors.New("PacketPacker BUG: Handshake retransmissions must contain a StopWaitingFrame") - } header := p.getHeader(packet.EncryptionLevel) - p.stopWaiting.PacketNumber = header.PacketNumber - p.stopWaiting.PacketNumberLen = header.PacketNumberLen - frames := append([]wire.Frame{p.stopWaiting}, packet.Frames...) - p.stopWaiting = nil + var frames []wire.Frame + if !p.version.UsesIETFFrameFormat() { // for gQUIC: pack a STOP_WAITING first + if p.stopWaiting == nil { + return nil, errors.New("PacketPacker BUG: Handshake retransmissions must contain a STOP_WAITING frame") + } + swf := p.stopWaiting + swf.PacketNumber = header.PacketNumber + swf.PacketNumberLen = header.PacketNumberLen + p.stopWaiting = nil + frames = append([]wire.Frame{swf}, packet.Frames...) + } else { + frames = packet.Frames + } raw, err := p.writeAndSealPacket(header, frames, sealer) return &packedPacket{ header: header, @@ -217,7 +223,7 @@ func (p *packetPacker) composeNextPacket( l := p.ackFrame.MinLength(p.version) payloadLength += l } - if p.stopWaiting != nil { + if p.stopWaiting != nil { // a STOP_WAITING will only be queued when using gQUIC payloadFrames = append(payloadFrames, p.stopWaiting) payloadLength += p.stopWaiting.MinLength(p.version) } diff --git a/packet_packer_test.go b/packet_packer_test.go index 179f674c..454fe37b 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -782,8 +782,19 @@ var _ = Describe("Packet packer", func() { } p, err := packer.PackHandshakeRetransmission(packet) Expect(err).ToNot(HaveOccurred()) - Expect(p.frames).To(ContainElement(sf)) - Expect(p.frames).To(ContainElement(swf)) + Expect(p.frames).To(Equal([]wire.Frame{swf, sf})) + Expect(p.encryptionLevel).To(Equal(protocol.EncryptionUnencrypted)) + }) + + It("doesn't add a STOP_WAITING frame for IETF QUIC", func() { + packer.version = versionIETFFrames + packet := &ackhandler.Packet{ + EncryptionLevel: protocol.EncryptionUnencrypted, + Frames: []wire.Frame{sf}, + } + p, err := packer.PackHandshakeRetransmission(packet) + Expect(err).ToNot(HaveOccurred()) + Expect(p.frames).To(Equal([]wire.Frame{sf})) Expect(p.encryptionLevel).To(Equal(protocol.EncryptionUnencrypted)) }) @@ -796,8 +807,7 @@ var _ = Describe("Packet packer", func() { } p, err := packer.PackHandshakeRetransmission(packet) Expect(err).ToNot(HaveOccurred()) - Expect(p.frames).To(ContainElement(sf)) - Expect(p.frames).To(ContainElement(swf)) + Expect(p.frames).To(Equal([]wire.Frame{swf, sf})) Expect(p.encryptionLevel).To(Equal(protocol.EncryptionSecure)) // a packet sent by the server with initial encryption contains the SHLO // it needs to have a diversification nonce @@ -871,7 +881,7 @@ var _ = Describe("Packet packer", func() { _, err := packer.PackHandshakeRetransmission(&ackhandler.Packet{ EncryptionLevel: protocol.EncryptionSecure, }) - Expect(err).To(MatchError("PacketPacker BUG: Handshake retransmissions must contain a StopWaitingFrame")) + Expect(err).To(MatchError("PacketPacker BUG: Handshake retransmissions must contain a STOP_WAITING frame")) }) }) @@ -883,7 +893,7 @@ var _ = Describe("Packet packer", func() { Expect(p.frames).To(Equal([]wire.Frame{&wire.AckFrame{DelayTime: math.MaxInt64}})) }) - It("packs ACK packets with SWFs", func() { + It("packs ACK packets with STOP_WAITING frames", func() { packer.QueueControlFrame(&wire.AckFrame{}) packer.QueueControlFrame(&wire.StopWaitingFrame{}) p, err := packer.PackAckPacket() diff --git a/packet_unpacker.go b/packet_unpacker.go index 9d09c373..f97215e3 100644 --- a/packet_unpacker.go +++ b/packet_unpacker.go @@ -106,13 +106,7 @@ func (u *packetUnpacker) parseIETFFrame(r *bytes.Reader, typeByte byte, hdr *wir if err != nil { err = qerr.Error(qerr.InvalidWindowUpdateData, err.Error()) } - case 0x6: - // TODO(#964): remove STOP_WAITING frames - // TODO(#878): implement the MAX_STREAM_ID frame - frame, err = wire.ParseStopWaitingFrame(r, hdr.PacketNumber, hdr.PacketNumberLen, u.version) - if err != nil { - err = qerr.Error(qerr.InvalidStopWaitingData, err.Error()) - } + // TODO(#878): implement the MAX_STREAM_ID frame case 0x7: frame, err = wire.ParsePingFrame(r, u.version) case 0x8: diff --git a/session.go b/session.go index 4b732c1f..b880181f 100644 --- a/session.go +++ b/session.go @@ -716,9 +716,10 @@ func (s *session) sendPacket() error { return nil } // If we aren't allowed to send, at least try sending an ACK frame - swf := s.sentPacketHandler.GetStopWaitingFrame(false) - if swf != nil { - s.packer.QueueControlFrame(swf) + if !s.version.UsesIETFFrameFormat() { + if swf := s.sentPacketHandler.GetStopWaitingFrame(false); swf != nil { + s.packer.QueueControlFrame(swf) + } } packet, err := s.packer.PackAckPacket() if err != nil { @@ -740,7 +741,9 @@ func (s *session) sendPacket() error { continue } utils.Debugf("\tDequeueing handshake retransmission for packet 0x%x", retransmitPacket.PacketNumber) - s.packer.QueueControlFrame(s.sentPacketHandler.GetStopWaitingFrame(true)) + if !s.version.UsesIETFFrameFormat() { + s.packer.QueueControlFrame(s.sentPacketHandler.GetStopWaitingFrame(true)) + } packet, err := s.packer.PackHandshakeRetransmission(retransmitPacket) if err != nil { return err @@ -764,9 +767,8 @@ func (s *session) sendPacket() error { } hasRetransmission := s.streamFramer.HasFramesForRetransmission() - if ack != nil || hasRetransmission { - swf := s.sentPacketHandler.GetStopWaitingFrame(hasRetransmission) - if swf != nil { + if !s.version.UsesIETFFrameFormat() && (ack != nil || hasRetransmission) { + if swf := s.sentPacketHandler.GetStopWaitingFrame(hasRetransmission); swf != nil { s.packer.QueueControlFrame(swf) } } diff --git a/session_test.go b/session_test.go index 9e6ecece..0001e7f2 100644 --- a/session_test.go +++ b/session_test.go @@ -761,19 +761,39 @@ var _ = Describe("Session", func() { }) It("sends ACK frames when congestion limited", func() { + swf := &wire.StopWaitingFrame{LeastUnacked: 10} sph := mocks.NewMockSentPacketHandler(mockCtrl) sph.EXPECT().GetLeastUnacked().AnyTimes() sph.EXPECT().SendingAllowed().Return(false) - sph.EXPECT().GetStopWaitingFrame(false) - sph.EXPECT().SentPacket(gomock.Any()) + sph.EXPECT().GetStopWaitingFrame(false).Return(swf) + sph.EXPECT().SentPacket(gomock.Any()).Do(func(p *ackhandler.Packet) { + Expect(p.Frames).To(HaveLen(2)) + Expect(p.Frames[0]).To(BeAssignableToTypeOf(&wire.AckFrame{})) + Expect(p.Frames[1]).To(Equal(swf)) + }) sess.sentPacketHandler = sph sess.packer.packetNumberGenerator.next = 0x1338 - packetNumber := protocol.PacketNumber(0x035e) - sess.receivedPacketHandler.ReceivedPacket(packetNumber, true) + sess.receivedPacketHandler.ReceivedPacket(1, true) + err := sess.sendPacket() + Expect(err).NotTo(HaveOccurred()) + Expect(mconn.written).To(HaveLen(1)) + }) + + It("doesn't include a STOP_WAITING for an ACK-only packet for IETF QUIC", func() { + sess.version = versionIETFFrames + sph := mocks.NewMockSentPacketHandler(mockCtrl) + sph.EXPECT().GetLeastUnacked().AnyTimes() + sph.EXPECT().SendingAllowed().Return(false) + sph.EXPECT().SentPacket(gomock.Any()).Do(func(p *ackhandler.Packet) { + Expect(p.Frames).To(HaveLen(1)) + Expect(p.Frames[0]).To(BeAssignableToTypeOf(&wire.AckFrame{})) + }) + sess.sentPacketHandler = sph + sess.packer.packetNumberGenerator.next = 0x1338 + sess.receivedPacketHandler.ReceivedPacket(1, true) err := sess.sendPacket() Expect(err).NotTo(HaveOccurred()) Expect(mconn.written).To(HaveLen(1)) - Expect(mconn.written).To(Receive(ContainSubstring(string([]byte{0x03, 0x5e})))) }) It("sends a retransmittable packet when required by the SentPacketHandler", func() { @@ -846,27 +866,41 @@ var _ = Describe("Session", func() { }) Context("for handshake packets", func() { - It("retransmits an unencrypted packet", func() { + It("retransmits an unencrypted packet, and adds a STOP_WAITING frame (for gQUIC)", func() { sf := &wire.StreamFrame{StreamID: 1, Data: []byte("foobar")} - var sentPacket *ackhandler.Packet - sph.EXPECT().GetStopWaitingFrame(true).Return(&wire.StopWaitingFrame{LeastUnacked: 0x1337}) - sph.EXPECT().DequeuePacketForRetransmission().Return( - &ackhandler.Packet{ - Frames: []wire.Frame{sf}, - EncryptionLevel: protocol.EncryptionUnencrypted, - }) + swf := &wire.StopWaitingFrame{LeastUnacked: 0x1337} + sph.EXPECT().GetStopWaitingFrame(true).Return(swf) + sph.EXPECT().DequeuePacketForRetransmission().Return(&ackhandler.Packet{ + Frames: []wire.Frame{sf}, + EncryptionLevel: protocol.EncryptionUnencrypted, + }) sph.EXPECT().DequeuePacketForRetransmission() sph.EXPECT().SentPacket(gomock.Any()).Do(func(p *ackhandler.Packet) { - sentPacket = p + Expect(p.EncryptionLevel).To(Equal(protocol.EncryptionUnencrypted)) + Expect(p.Frames).To(Equal([]wire.Frame{swf, sf})) + }) + err := sess.sendPacket() + Expect(err).ToNot(HaveOccurred()) + Expect(mconn.written).To(HaveLen(1)) + + }) + + It("retransmits an unencrypted packet, and doesn't add a STOP_WAITING frame (for IETF QUIC)", func() { + sess.version = versionIETFFrames + sess.packer.version = versionIETFFrames + sf := &wire.StreamFrame{StreamID: 1, Data: []byte("foobar")} + sph.EXPECT().DequeuePacketForRetransmission().Return(&ackhandler.Packet{ + Frames: []wire.Frame{sf}, + EncryptionLevel: protocol.EncryptionUnencrypted, + }) + sph.EXPECT().DequeuePacketForRetransmission() + sph.EXPECT().SentPacket(gomock.Any()).Do(func(p *ackhandler.Packet) { + Expect(p.EncryptionLevel).To(Equal(protocol.EncryptionUnencrypted)) + Expect(p.Frames).To(Equal([]wire.Frame{sf})) }) err := sess.sendPacket() Expect(err).ToNot(HaveOccurred()) Expect(mconn.written).To(HaveLen(1)) - Expect(sentPacket.EncryptionLevel).To(Equal(protocol.EncryptionUnencrypted)) - Expect(sentPacket.Frames).To(HaveLen(2)) - Expect(sentPacket.Frames[1]).To(Equal(sf)) - swf := sentPacket.Frames[0].(*wire.StopWaitingFrame) - Expect(swf.LeastUnacked).To(Equal(protocol.PacketNumber(0x1337))) }) It("doesn't retransmit handshake packets when the handshake is complete", func() { @@ -885,25 +919,49 @@ var _ = Describe("Session", func() { }) Context("for packets after the handshake", func() { - It("sends a StreamFrame from a packet queued for retransmission", func() { - f := wire.StreamFrame{ + It("sends a STREAM frame from a packet queued for retransmission, and adds a STOP_WAITING (for gQUIC)", func() { + f := &wire.StreamFrame{ StreamID: 0x5, - Data: []byte("foobar1234567"), + Data: []byte("foobar"), } - sph.EXPECT().GetStopWaitingFrame(true).Return(&wire.StopWaitingFrame{}) - sph.EXPECT().DequeuePacketForRetransmission().Return( - &ackhandler.Packet{ - PacketNumber: 0x1337, - Frames: []wire.Frame{&f}, - EncryptionLevel: protocol.EncryptionForwardSecure, - }) + swf := &wire.StopWaitingFrame{LeastUnacked: 10} + sph.EXPECT().GetStopWaitingFrame(true).Return(swf) + sph.EXPECT().DequeuePacketForRetransmission().Return(&ackhandler.Packet{ + PacketNumber: 0x1337, + Frames: []wire.Frame{f}, + EncryptionLevel: protocol.EncryptionForwardSecure, + }) sph.EXPECT().DequeuePacketForRetransmission() - sph.EXPECT().SentPacket(gomock.Any()) + sph.EXPECT().SentPacket(gomock.Any()).Do(func(p *ackhandler.Packet) { + Expect(p.Frames).To(Equal([]wire.Frame{swf, f})) + Expect(p.EncryptionLevel).To(Equal(protocol.EncryptionForwardSecure)) + }) + sph.EXPECT().SendingAllowed() + err := sess.sendPacket() + Expect(err).NotTo(HaveOccurred()) + Expect(mconn.written).To(HaveLen(1)) + }) + + It("sends a STREAM frame from a packet queued for retransmission, and doesn't add a STOP_WAITING (for IETF QUIC)", func() { + sess.version = versionIETFFrames + sess.packer.version = versionIETFFrames + f := &wire.StreamFrame{ + StreamID: 0x5, + Data: []byte("foobar"), + } + sph.EXPECT().DequeuePacketForRetransmission().Return(&ackhandler.Packet{ + Frames: []wire.Frame{f}, + EncryptionLevel: protocol.EncryptionForwardSecure, + }) + sph.EXPECT().DequeuePacketForRetransmission() + sph.EXPECT().SentPacket(gomock.Any()).Do(func(p *ackhandler.Packet) { + Expect(p.Frames).To(Equal([]wire.Frame{f})) + Expect(p.EncryptionLevel).To(Equal(protocol.EncryptionForwardSecure)) + }) sph.EXPECT().SendingAllowed() err := sess.sendPacket() Expect(err).NotTo(HaveOccurred()) Expect(mconn.written).To(HaveLen(1)) - Expect(mconn.written).To(Receive(ContainSubstring("foobar1234567"))) }) It("sends a StreamFrame from a packet queued for retransmission", func() {