diff --git a/internal/ackhandler/frame.go b/internal/ackhandler/frame.go index 8c284e20..e03a8080 100644 --- a/internal/ackhandler/frame.go +++ b/internal/ackhandler/frame.go @@ -4,16 +4,18 @@ import ( "github.com/quic-go/quic-go/internal/wire" ) +// FrameHandler handles the acknowledgement and the loss of a frame. +type FrameHandler interface { + OnAcked(wire.Frame) + OnLost(wire.Frame) +} + type Frame struct { Frame wire.Frame // nil if the frame has already been acknowledged in another packet - OnLost func(wire.Frame) - OnAcked func(wire.Frame) + Handler FrameHandler } type StreamFrame struct { Frame *wire.StreamFrame - Handler interface { - OnLost(*wire.StreamFrame) - OnAcked(*wire.StreamFrame) - } + Handler FrameHandler } diff --git a/internal/ackhandler/sent_packet_handler.go b/internal/ackhandler/sent_packet_handler.go index 4451b0e5..6d97792e 100644 --- a/internal/ackhandler/sent_packet_handler.go +++ b/internal/ackhandler/sent_packet_handler.go @@ -413,12 +413,14 @@ func (h *sentPacketHandler) detectAndRemoveAckedPackets(ack *wire.AckFrame, encL } for _, f := range p.Frames { - if f.OnAcked != nil { - f.OnAcked(f.Frame) + if f.Handler != nil { + f.Handler.OnAcked(f.Frame) } } for _, f := range p.StreamFrames { - f.Handler.OnAcked(f.Frame) + if f.Handler != nil { + f.Handler.OnAcked(f.Frame) + } } if err := pnSpace.history.Remove(p.PacketNumber); err != nil { return nil, err @@ -795,10 +797,14 @@ func (h *sentPacketHandler) queueFramesForRetransmission(p *Packet) { panic("no frames") } for _, f := range p.Frames { - f.OnLost(f.Frame) + if f.Handler != nil { + f.Handler.OnLost(f.Frame) + } } for _, f := range p.StreamFrames { - f.Handler.OnLost(f.Frame) + if f.Handler != nil { + f.Handler.OnLost(f.Frame) + } } p.StreamFrames = nil p.Frames = nil diff --git a/internal/ackhandler/sent_packet_handler_test.go b/internal/ackhandler/sent_packet_handler_test.go index e12c42fa..c98eae25 100644 --- a/internal/ackhandler/sent_packet_handler_test.go +++ b/internal/ackhandler/sent_packet_handler_test.go @@ -16,6 +16,22 @@ import ( . "github.com/onsi/gomega" ) +type customFrameHandler struct { + onLost, onAcked func(wire.Frame) +} + +func (h *customFrameHandler) OnLost(f wire.Frame) { + if h.onLost != nil { + h.onLost(f) + } +} + +func (h *customFrameHandler) OnAcked(f wire.Frame) { + if h.onAcked != nil { + h.onAcked(f) + } +} + var _ = Describe("SentPacketHandler", func() { var ( handler *sentPacketHandler @@ -57,7 +73,9 @@ var _ = Describe("SentPacketHandler", func() { } if len(p.Frames) == 0 { p.Frames = []Frame{ - {Frame: &wire.PingFrame{}, OnLost: func(wire.Frame) { lostPackets = append(lostPackets, p.PacketNumber) }}, + {Frame: &wire.PingFrame{}, Handler: &customFrameHandler{ + onLost: func(wire.Frame) { lostPackets = append(lostPackets, p.PacketNumber) }, + }}, } } return p @@ -280,9 +298,12 @@ var _ = Describe("SentPacketHandler", func() { handler.SentPacket(ackElicitingPacket(&Packet{ PacketNumber: 10, Frames: []Frame{{ - Frame: ping, OnAcked: func(f wire.Frame) { - Expect(f).To(Equal(ping)) - acked = true + Frame: ping, + Handler: &customFrameHandler{ + onAcked: func(f wire.Frame) { + Expect(f).To(Equal(ping)) + acked = true + }, }, }}, })) @@ -431,20 +452,20 @@ var _ = Describe("SentPacketHandler", func() { { PacketNumber: 10, LargestAcked: 100, - Frames: []Frame{{Frame: &streamFrame, OnLost: func(wire.Frame) {}}}, + Frames: []Frame{{Frame: &streamFrame}}, Length: 1, EncryptionLevel: protocol.Encryption1RTT, }, { PacketNumber: 11, LargestAcked: 200, - Frames: []Frame{{Frame: &streamFrame, OnLost: func(wire.Frame) {}}}, + Frames: []Frame{{Frame: &streamFrame}}, Length: 1, EncryptionLevel: protocol.Encryption1RTT, }, { PacketNumber: 12, - Frames: []Frame{{Frame: &streamFrame, OnLost: func(wire.Frame) {}}}, + Frames: []Frame{{Frame: &streamFrame}}, Length: 1, EncryptionLevel: protocol.Encryption1RTT, }, @@ -504,7 +525,7 @@ var _ = Describe("SentPacketHandler", func() { handler.SentPacket(&Packet{ PacketNumber: 1, Length: 42, - Frames: []Frame{{Frame: &wire.PingFrame{}, OnLost: func(wire.Frame) {}}}, + Frames: []Frame{{Frame: &wire.PingFrame{}}}, EncryptionLevel: protocol.Encryption1RTT, }) }) @@ -551,7 +572,12 @@ var _ = Describe("SentPacketHandler", func() { PacketNumber: 1, SendTime: time.Now().Add(-time.Hour), IsPathMTUProbePacket: true, - Frames: []Frame{{Frame: &wire.PingFrame{}, OnLost: func(wire.Frame) { mtuPacketDeclaredLost = true }}}, + Frames: []Frame{ + { + Frame: &wire.PingFrame{}, + Handler: &customFrameHandler{onLost: func(wire.Frame) { mtuPacketDeclaredLost = true }}, + }, + }, })) handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 2})) // lose packet 1, but don't EXPECT any calls to OnPacketLost() @@ -756,7 +782,10 @@ var _ = Describe("SentPacketHandler", func() { PacketNumber: handler.PopPacketNumber(protocol.Encryption1RTT), SendTime: time.Now().Add(-time.Hour), Frames: []Frame{ - {Frame: &wire.PingFrame{}, OnLost: func(wire.Frame) { lostPackets = append(lostPackets, 1) }}, + { + Frame: &wire.PingFrame{}, + Handler: &customFrameHandler{onLost: func(wire.Frame) { lostPackets = append(lostPackets, 1) }}, + }, }, })) Expect(handler.OnLossDetectionTimeout()).To(Succeed()) @@ -1148,7 +1177,12 @@ var _ = Describe("SentPacketHandler", func() { PacketNumber: 1, SendTime: now.Add(-3 * time.Second), IsPathMTUProbePacket: true, - Frames: []Frame{{Frame: &wire.PingFrame{}, OnLost: func(wire.Frame) { mtuPacketDeclaredLost = true }}}, + Frames: []Frame{ + { + Frame: &wire.PingFrame{}, + Handler: &customFrameHandler{onLost: func(wire.Frame) { mtuPacketDeclaredLost = true }}, + }, + }, })) handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 2, SendTime: now.Add(-3 * time.Second)})) ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 2, Largest: 2}}} @@ -1335,7 +1369,10 @@ var _ = Describe("SentPacketHandler", func() { PacketNumber: 13, EncryptionLevel: protocol.EncryptionInitial, Frames: []Frame{ - {Frame: &wire.CryptoFrame{Data: []byte("foobar")}, OnLost: func(wire.Frame) { lostInitial = true }}, + { + Frame: &wire.CryptoFrame{Data: []byte("foobar")}, + Handler: &customFrameHandler{onLost: func(wire.Frame) { lostInitial = true }}, + }, }, Length: 100, }) @@ -1344,7 +1381,10 @@ var _ = Describe("SentPacketHandler", func() { PacketNumber: pn, EncryptionLevel: protocol.Encryption0RTT, Frames: []Frame{ - {Frame: &wire.StreamFrame{Data: []byte("foobar")}, OnLost: func(wire.Frame) { lost0RTT = true }}, + { + Frame: &wire.StreamFrame{Data: []byte("foobar")}, + Handler: &customFrameHandler{onLost: func(wire.Frame) { lost0RTT = true }}, + }, }, Length: 999, }) diff --git a/mtu_discoverer.go b/mtu_discoverer.go index 957abd37..317b0929 100644 --- a/mtu_discoverer.go +++ b/mtu_discoverer.go @@ -43,10 +43,10 @@ func getMaxPacketSize(addr net.Addr) protocol.ByteCount { type mtuFinder struct { lastProbeTime time.Time - probeInFlight bool mtuIncreased func(protocol.ByteCount) rttStats *utils.RTTStats + inFlight protocol.ByteCount // the size of the probe packet currently in flight. InvalidByteCount if none is in flight current protocol.ByteCount max protocol.ByteCount // the maximum value, as advertised by the peer (or our maximum size buffer) } @@ -55,6 +55,7 @@ var _ mtuDiscoverer = &mtuFinder{} func newMTUDiscoverer(rttStats *utils.RTTStats, start protocol.ByteCount, mtuIncreased func(protocol.ByteCount)) *mtuFinder { return &mtuFinder{ + inFlight: protocol.InvalidByteCount, current: start, rttStats: rttStats, mtuIncreased: mtuIncreased, @@ -74,7 +75,7 @@ func (f *mtuFinder) ShouldSendProbe(now time.Time) bool { if f.max == 0 || f.lastProbeTime.IsZero() { return false } - if f.probeInFlight || f.done() { + if f.inFlight != protocol.InvalidByteCount || f.done() { return false } return !now.Before(f.lastProbeTime.Add(mtuProbeDelay * f.rttStats.SmoothedRTT())) @@ -83,21 +84,36 @@ func (f *mtuFinder) ShouldSendProbe(now time.Time) bool { func (f *mtuFinder) GetPing() (ackhandler.Frame, protocol.ByteCount) { size := (f.max + f.current) / 2 f.lastProbeTime = time.Now() - f.probeInFlight = true + f.inFlight = size return ackhandler.Frame{ - Frame: &wire.PingFrame{}, - OnLost: func(wire.Frame) { - f.probeInFlight = false - f.max = size - }, - OnAcked: func(wire.Frame) { - f.probeInFlight = false - f.current = size - f.mtuIncreased(size) - }, + Frame: &wire.PingFrame{}, + Handler: (*mtuFinderAckHandler)(f), }, size } func (f *mtuFinder) CurrentSize() protocol.ByteCount { return f.current } + +type mtuFinderAckHandler mtuFinder + +var _ ackhandler.FrameHandler = &mtuFinderAckHandler{} + +func (h *mtuFinderAckHandler) OnAcked(wire.Frame) { + size := h.inFlight + if size == protocol.InvalidByteCount { + panic("OnAcked callback called although there's no MTU probe packet in flight") + } + h.inFlight = protocol.InvalidByteCount + h.current = size + h.mtuIncreased(size) +} + +func (h *mtuFinderAckHandler) OnLost(wire.Frame) { + size := h.inFlight + if size == protocol.InvalidByteCount { + panic("OnLost callback called although there's no MTU probe packet in flight") + } + h.max = size + h.inFlight = protocol.InvalidByteCount +} diff --git a/mtu_discoverer_test.go b/mtu_discoverer_test.go index c602ca30..6e01f570 100644 --- a/mtu_discoverer_test.go +++ b/mtu_discoverer_test.go @@ -43,14 +43,14 @@ var _ = Describe("MTU Discoverer", func() { It("doesn't allow a probe if another probe is still in flight", func() { ping, _ := d.GetPing() Expect(d.ShouldSendProbe(now.Add(10 * rtt))).To(BeFalse()) - ping.OnLost(ping.Frame) + ping.Handler.OnLost(ping.Frame) Expect(d.ShouldSendProbe(now.Add(10 * rtt))).To(BeTrue()) }) It("tries a lower size when a probe is lost", func() { ping, size := d.GetPing() Expect(size).To(Equal(protocol.ByteCount(1500))) - ping.OnLost(ping.Frame) + ping.Handler.OnLost(ping.Frame) _, size = d.GetPing() Expect(size).To(Equal(protocol.ByteCount(1250))) }) @@ -58,7 +58,7 @@ var _ = Describe("MTU Discoverer", func() { It("tries a higher size and calls the callback when a probe is acknowledged", func() { ping, size := d.GetPing() Expect(size).To(Equal(protocol.ByteCount(1500))) - ping.OnAcked(ping.Frame) + ping.Handler.OnAcked(ping.Frame) Expect(discoveredMTU).To(Equal(protocol.ByteCount(1500))) _, size = d.GetPing() Expect(size).To(Equal(protocol.ByteCount(1750))) @@ -69,7 +69,7 @@ var _ = Describe("MTU Discoverer", func() { t := now.Add(5 * rtt) for d.ShouldSendProbe(t) { ping, size := d.GetPing() - ping.OnAcked(ping.Frame) + ping.Handler.OnAcked(ping.Frame) sizes = append(sizes, size) t = t.Add(5 * rtt) } @@ -104,9 +104,9 @@ var _ = Describe("MTU Discoverer", func() { ping, size := d.GetPing() if size <= realMTU { - ping.OnAcked(ping.Frame) + ping.Handler.OnAcked(ping.Frame) } else { - ping.OnLost(ping.Frame) + ping.Handler.OnLost(ping.Frame) } t = t.Add(mtuProbeDelay * rtt) } diff --git a/packet_packer.go b/packet_packer.go index 8454c510..5d182cf5 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -37,6 +37,13 @@ type payload struct { length protocol.ByteCount } +type nullFrameHandler struct{} + +func (n nullFrameHandler) OnAcked(wire.Frame) {} +func (n nullFrameHandler) OnLost(wire.Frame) {} + +var doNothingFrameHandler ackhandler.FrameHandler = &nullFrameHandler{} + type longHeaderPacket struct { header *wire.ExtendedHeader ack *wire.AckFrame @@ -88,17 +95,17 @@ func (p *longHeaderPacket) ToAckHandlerPacket(now time.Time, q *retransmissionQu } encLevel := p.EncryptionLevel() for i := range p.frames { - if p.frames[i].OnLost != nil { + if p.frames[i].Handler != nil { continue } //nolint:exhaustive // Short header packets are handled separately. switch encLevel { case protocol.EncryptionInitial: - p.frames[i].OnLost = q.AddInitial + p.frames[i].Handler = q.InitialAckHandler() case protocol.EncryptionHandshake: - p.frames[i].OnLost = q.AddHandshake + p.frames[i].Handler = q.HandshakeAckHandler() case protocol.Encryption0RTT: - p.frames[i].OnLost = q.AddAppData + p.frames[i].Handler = q.AppDataAckHandler() } } @@ -605,8 +612,8 @@ func (p *packetPacker) maybeGetAppDataPacket(maxPayloadSize protocol.ByteCount, if p.numNonAckElicitingAcks >= protocol.MaxNonAckElicitingAcks { ping := &wire.PingFrame{} pl.frames = append(pl.frames, ackhandler.Frame{ - Frame: ping, - OnLost: func(wire.Frame) {}, // don't retransmit the PING frame when it is lost + Frame: ping, + Handler: doNothingFrameHandler, // don't retransmit the PING frame when it is lost }) pl.length += ping.Length(v) p.numNonAckElicitingAcks = 0 @@ -649,8 +656,9 @@ func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount, onlyAc size := f.Length(v) if size <= maxFrameSize-pl.length { pl.frames = append(pl.frames, ackhandler.Frame{ - Frame: f, - OnLost: func(wire.Frame) {}, // set it to a no-op. Then we won't set the default callback, which would retransmit the frame. + Frame: f, + // Set it to a no-op. Then we won't set the default callback, which would retransmit the frame. + Handler: doNothingFrameHandler, }) pl.length += size p.datagramQueue.Pop() @@ -891,10 +899,10 @@ func (p *packetPacker) appendShortHeaderPacket( largestAcked = pl.ack.LargestAcked() } for i := range pl.frames { - if pl.frames[i].OnLost != nil { + if pl.frames[i].Handler != nil { continue } - pl.frames[i].OnLost = p.retransmissionQueue.AddAppData + pl.frames[i].Handler = p.retransmissionQueue.AppDataAckHandler() } ap := ackhandler.GetPacket() diff --git a/packet_packer_test.go b/packet_packer_test.go index ef04db9e..80d167d3 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -789,7 +789,7 @@ var _ = Describe("Packet packer", func() { for _, f := range p.Frames { if _, ok := f.Frame.(*wire.PingFrame); ok { hasPing = true - Expect(f.OnLost).ToNot(BeNil()) // make sure the PING is not retransmitted if lost + Expect(f.Handler.OnLost).ToNot(BeNil()) // make sure the PING is not retransmitted if lost } } Expect(hasPing).To(BeTrue()) @@ -835,7 +835,7 @@ var _ = Describe("Packet packer", func() { for _, f := range p.Frames { if _, ok := f.Frame.(*wire.PingFrame); ok { hasPing = true - Expect(f.OnLost).ToNot(BeNil()) // make sure the PING is not retransmitted if lost + Expect(f.Handler.OnLost).ToNot(BeNil()) // make sure the PING is not retransmitted if lost } } Expect(hasPing).To(BeTrue()) @@ -1507,26 +1507,4 @@ var _ = Describe("Converting to ackhandler.Packet", func() { p := packet.ToAckHandlerPacket(time.Now(), nil) Expect(p.LargestAcked).To(Equal(protocol.InvalidPacketNumber)) }) - - DescribeTable( - "doesn't overwrite the OnLost callback, if it is set", - func(hdr wire.Header) { - var pingLost bool - packet := &longHeaderPacket{ - header: &wire.ExtendedHeader{Header: hdr}, - frames: []ackhandler.Frame{ - {Frame: &wire.MaxDataFrame{}}, - {Frame: &wire.PingFrame{}, OnLost: func(wire.Frame) { pingLost = true }}, - }, - } - p := packet.ToAckHandlerPacket(time.Now(), newRetransmissionQueue()) - Expect(p.Frames).To(HaveLen(2)) - Expect(p.Frames[0].OnLost).ToNot(BeNil()) - p.Frames[1].OnLost(nil) - Expect(pingLost).To(BeTrue()) - }, - Entry(protocol.EncryptionInitial.String(), wire.Header{Type: protocol.PacketTypeInitial}), - Entry(protocol.EncryptionHandshake.String(), wire.Header{Type: protocol.PacketTypeHandshake}), - Entry(protocol.Encryption0RTT.String(), wire.Header{Type: protocol.PacketType0RTT}), - ) }) diff --git a/retransmission_queue.go b/retransmission_queue.go index 2ce0b893..d5e844d0 100644 --- a/retransmission_queue.go +++ b/retransmission_queue.go @@ -3,6 +3,8 @@ package quic import ( "fmt" + "github.com/quic-go/quic-go/internal/ackhandler" + "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/wire" ) @@ -127,3 +129,36 @@ func (q *retransmissionQueue) DropPackets(encLevel protocol.EncryptionLevel) { panic(fmt.Sprintf("unexpected encryption level: %s", encLevel)) } } + +func (q *retransmissionQueue) InitialAckHandler() ackhandler.FrameHandler { + return (*retransmissionQueueInitialAckHandler)(q) +} + +func (q *retransmissionQueue) HandshakeAckHandler() ackhandler.FrameHandler { + return (*retransmissionQueueHandshakeAckHandler)(q) +} + +func (q *retransmissionQueue) AppDataAckHandler() ackhandler.FrameHandler { + return (*retransmissionQueueAppDataAckHandler)(q) +} + +type retransmissionQueueInitialAckHandler retransmissionQueue + +func (q *retransmissionQueueInitialAckHandler) OnAcked(wire.Frame) {} +func (q *retransmissionQueueInitialAckHandler) OnLost(f wire.Frame) { + (*retransmissionQueue)(q).AddInitial(f) +} + +type retransmissionQueueHandshakeAckHandler retransmissionQueue + +func (q *retransmissionQueueHandshakeAckHandler) OnAcked(wire.Frame) {} +func (q *retransmissionQueueHandshakeAckHandler) OnLost(f wire.Frame) { + (*retransmissionQueue)(q).AddHandshake(f) +} + +type retransmissionQueueAppDataAckHandler retransmissionQueue + +func (q *retransmissionQueueAppDataAckHandler) OnAcked(wire.Frame) {} +func (q *retransmissionQueueAppDataAckHandler) OnLost(f wire.Frame) { + (*retransmissionQueue)(q).AddAppData(f) +} diff --git a/retransmission_queue_test.go b/retransmission_queue_test.go index 78d181b3..0eaad18c 100644 --- a/retransmission_queue_test.go +++ b/retransmission_queue_test.go @@ -89,6 +89,13 @@ var _ = Describe("Retransmission queue", func() { Expect(q.HasInitialData()).To(BeFalse()) Expect(q.GetInitialFrame(protocol.MaxByteCount, protocol.Version1)).To(BeNil()) }) + + It("retransmits a frame", func() { + f := &wire.MaxDataFrame{MaximumData: 0x42} + q.InitialAckHandler().OnLost(f) + Expect(q.HasInitialData()).To(BeTrue()) + Expect(q.GetInitialFrame(protocol.MaxByteCount, protocol.Version1)).To(Equal(f)) + }) }) Context("Handshake data", func() { @@ -165,6 +172,13 @@ var _ = Describe("Retransmission queue", func() { Expect(q.HasHandshakeData()).To(BeFalse()) Expect(q.GetHandshakeFrame(protocol.MaxByteCount, protocol.Version1)).To(BeNil()) }) + + It("retransmits a frame", func() { + f := &wire.MaxDataFrame{MaximumData: 0x42} + q.HandshakeAckHandler().OnLost(f) + Expect(q.HasHandshakeData()).To(BeTrue()) + Expect(q.GetHandshakeFrame(protocol.MaxByteCount, protocol.Version1)).To(Equal(f)) + }) }) Context("Application data", func() { @@ -181,5 +195,12 @@ var _ = Describe("Retransmission queue", func() { Expect(q.GetAppDataFrame(f.Length(protocol.Version1), protocol.Version1)).To(Equal(f)) Expect(q.HasAppData()).To(BeFalse()) }) + + It("retransmits a frame", func() { + f := &wire.MaxDataFrame{MaximumData: 0x42} + q.AppDataAckHandler().OnLost(f) + Expect(q.HasAppData()).To(BeTrue()) + Expect(q.GetAppDataFrame(protocol.MaxByteCount, protocol.Version1)).To(Equal(f)) + }) }) }) diff --git a/send_stream.go b/send_stream.go index abe1067e..62ebe2ea 100644 --- a/send_stream.go +++ b/send_stream.go @@ -211,7 +211,7 @@ func (s *sendStream) popStreamFrame(maxBytes protocol.ByteCount, v protocol.Vers } return ackhandler.StreamFrame{ Frame: f, - Handler: s, + Handler: (*sendStreamAckHandler)(s), }, true, hasMoreData } @@ -347,25 +347,6 @@ func (s *sendStream) getDataForWriting(f *wire.StreamFrame, maxBytes protocol.By } } -func (s *sendStream) OnAcked(f *wire.StreamFrame) { - f.PutBack() - s.mutex.Lock() - if s.cancelWriteErr != nil { - s.mutex.Unlock() - return - } - s.numOutstandingFrames-- - if s.numOutstandingFrames < 0 { - panic("numOutStandingFrames negative") - } - newlyCompleted := s.isNewlyCompleted() - s.mutex.Unlock() - - if newlyCompleted { - s.sender.onStreamCompleted(s.streamID) - } -} - func (s *sendStream) isNewlyCompleted() bool { completed := (s.finSent || s.cancelWriteErr != nil) && s.numOutstandingFrames == 0 && len(s.retransmissionQueue) == 0 if completed && !s.completed { @@ -375,23 +356,6 @@ func (s *sendStream) isNewlyCompleted() bool { return false } -func (s *sendStream) OnLost(f *wire.StreamFrame) { - s.mutex.Lock() - if s.cancelWriteErr != nil { - s.mutex.Unlock() - return - } - f.DataLenPresent = true - s.retransmissionQueue = append(s.retransmissionQueue, f) - s.numOutstandingFrames-- - if s.numOutstandingFrames < 0 { - panic("numOutStandingFrames negative") - } - s.mutex.Unlock() - - s.sender.onHasStreamData(s.streamID) -} - func (s *sendStream) Close() error { s.mutex.Lock() if s.closeForShutdownErr != nil { @@ -484,3 +448,45 @@ func (s *sendStream) signalWrite() { default: } } + +type sendStreamAckHandler sendStream + +var _ ackhandler.FrameHandler = &sendStreamAckHandler{} + +func (s *sendStreamAckHandler) OnAcked(f wire.Frame) { + sf := f.(*wire.StreamFrame) + sf.PutBack() + s.mutex.Lock() + if s.cancelWriteErr != nil { + s.mutex.Unlock() + return + } + s.numOutstandingFrames-- + if s.numOutstandingFrames < 0 { + panic("numOutStandingFrames negative") + } + newlyCompleted := (*sendStream)(s).isNewlyCompleted() + s.mutex.Unlock() + + if newlyCompleted { + s.sender.onStreamCompleted(s.streamID) + } +} + +func (s *sendStreamAckHandler) OnLost(f wire.Frame) { + sf := f.(*wire.StreamFrame) + s.mutex.Lock() + if s.cancelWriteErr != nil { + s.mutex.Unlock() + return + } + sf.DataLenPresent = true + s.retransmissionQueue = append(s.retransmissionQueue, sf) + s.numOutstandingFrames-- + if s.numOutstandingFrames < 0 { + panic("numOutStandingFrames negative") + } + s.mutex.Unlock() + + s.sender.onHasStreamData(s.streamID) +} diff --git a/send_stream_test.go b/send_stream_test.go index f2929159..c8624152 100644 --- a/send_stream_test.go +++ b/send_stream_test.go @@ -940,7 +940,7 @@ var _ = Describe("Send Stream", func() { DataLenPresent: false, } mockSender.EXPECT().onHasStreamData(streamID) - str.OnLost(f) + (*sendStreamAckHandler)(str).OnLost(f) frame, ok, _ := str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) Expect(ok).To(BeTrue()) Expect(frame).ToNot(BeNil()) @@ -958,7 +958,7 @@ var _ = Describe("Send Stream", func() { DataLenPresent: false, } mockSender.EXPECT().onHasStreamData(streamID) - str.OnLost(sf) + (*sendStreamAckHandler)(str).OnLost(sf) frame, ok, hasMoreData := str.popStreamFrame(sf.Length(protocol.Version1)-3, protocol.Version1) Expect(ok).To(BeTrue()) Expect(frame).ToNot(BeNil()) @@ -984,7 +984,7 @@ var _ = Describe("Send Stream", func() { DataLenPresent: false, } mockSender.EXPECT().onHasStreamData(streamID) - str.OnLost(f) + (*sendStreamAckHandler)(str).OnLost(f) _, ok, hasMoreData := str.popStreamFrame(2, protocol.Version1) Expect(ok).To(BeFalse()) Expect(hasMoreData).To(BeTrue())