diff --git a/internal/ackhandler/frame.go b/internal/ackhandler/frame.go index 83b09ceb..5731c2bc 100644 --- a/internal/ackhandler/frame.go +++ b/internal/ackhandler/frame.go @@ -5,4 +5,5 @@ import "github.com/lucas-clemente/quic-go/internal/wire" type Frame struct { wire.Frame // nil if the frame has already been acknowledged in another packet OnLost func(wire.Frame) + OnAcked func() } diff --git a/internal/ackhandler/sent_packet_handler.go b/internal/ackhandler/sent_packet_handler.go index 6a339b4c..4b481563 100644 --- a/internal/ackhandler/sent_packet_handler.go +++ b/internal/ackhandler/sent_packet_handler.go @@ -444,6 +444,11 @@ func (h *sentPacketHandler) onPacketAcked(p *Packet, rcvTime time.Time) error { return nil } + for _, f := range p.Frames { + if f.OnAcked != nil { + f.OnAcked() + } + } if p.includedInBytesInFlight { h.bytesInFlight -= p.Length } diff --git a/internal/ackhandler/sent_packet_handler_test.go b/internal/ackhandler/sent_packet_handler_test.go index 74b81214..af1ee4a8 100644 --- a/internal/ackhandler/sent_packet_handler_test.go +++ b/internal/ackhandler/sent_packet_handler_test.go @@ -48,8 +48,10 @@ var _ = Describe("SentPacketHandler", func() { if p.SendTime.IsZero() { p.SendTime = time.Now() } - p.Frames = []Frame{ - {Frame: &wire.PingFrame{}, OnLost: func(wire.Frame) { lostPackets = append(lostPackets, p.PacketNumber) }}, + if len(p.Frames) == 0 { + p.Frames = []Frame{ + {Frame: &wire.PingFrame{}, OnLost: func(wire.Frame) { lostPackets = append(lostPackets, p.PacketNumber) }}, + } } return p } @@ -198,6 +200,17 @@ var _ = Describe("SentPacketHandler", func() { expectInPacketHistoryOrLost([]protocol.PacketNumber{1, 2, 3, 4, 5, 6, 7, 8, 9}, protocol.Encryption1RTT) }) + It("calls the OnAcked callback", func() { + var acked bool + handler.SentPacket(ackElicitingPacket(&Packet{ + PacketNumber: 13, + Frames: []Frame{{Frame: &wire.PingFrame{}, OnAcked: func() { acked = true }}}, + })) + ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 13, Largest: 13}}} + Expect(handler.ReceivedAck(ack, 1, protocol.Encryption1RTT, time.Now())).To(Succeed()) + Expect(acked).To(BeTrue()) + }) + It("handles an ACK frame with one missing packet range", func() { ack := &wire.AckFrame{ // lose 4 and 5 AckRanges: []wire.AckRange{