diff --git a/ackhandler/interfaces.go b/ackhandler/interfaces.go index dbcdcb4c..b7715979 100644 --- a/ackhandler/interfaces.go +++ b/ackhandler/interfaces.go @@ -1,8 +1,6 @@ package ackhandler import ( - "time" - "github.com/lucas-clemente/quic-go/frames" "github.com/lucas-clemente/quic-go/protocol" ) @@ -10,13 +8,15 @@ import ( // SentPacketHandler handles ACKs received for outgoing packets type SentPacketHandler interface { SentPacket(packet *Packet) error - ReceivedAck(ackFrame *frames.AckFrame) (time.Duration, []*Packet, []*Packet, error) + ReceivedAck(ackFrame *frames.AckFrame) error HasPacketForRetransmission() bool DequeuePacketForRetransmission() (packet *Packet) BytesInFlight() protocol.ByteCount GetLargestObserved() protocol.PacketNumber + + AllowsSending() bool } // ReceivedPacketHandler handles ACKs needed to send for incoming packets diff --git a/ackhandler/sent_packet_handler.go b/ackhandler/sent_packet_handler.go index 663d3d0e..78d4839d 100644 --- a/ackhandler/sent_packet_handler.go +++ b/ackhandler/sent_packet_handler.go @@ -4,9 +4,11 @@ import ( "errors" "time" + "github.com/lucas-clemente/quic-go/congestion" "github.com/lucas-clemente/quic-go/frames" "github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/qerr" + "github.com/lucas-clemente/quic-go/utils" ) var ( @@ -42,13 +44,28 @@ type sentPacketHandler struct { stopWaitingManager StopWaitingManager bytesInFlight protocol.ByteCount + + rttStats *congestion.RTTStats + congestion congestion.SendAlgorithm } // NewSentPacketHandler creates a new sentPacketHandler func NewSentPacketHandler(stopWaitingManager StopWaitingManager) SentPacketHandler { + rttStats := &congestion.RTTStats{} + + congestion := congestion.NewCubicSender( + congestion.DefaultClock{}, + rttStats, + false, /* don't use reno since chromium doesn't (why?) */ + protocol.InitialCongestionWindow, + protocol.DefaultMaxCongestionWindow, + ) + return &sentPacketHandler{ packetHistory: make(map[protocol.PacketNumber]*Packet), stopWaitingManager: stopWaitingManager, + rttStats: rttStats, + congestion: congestion, } } @@ -111,6 +128,15 @@ func (h *sentPacketHandler) SentPacket(packet *Packet) error { packet.Entropy = h.lastSentPacketEntropy h.lastSentPacketNumber = packet.PacketNumber h.packetHistory[packet.PacketNumber] = packet + + h.congestion.OnPacketSent( + time.Now(), + h.BytesInFlight(), + packet.PacketNumber, + packet.Length, + true, /* TODO: is retransmittable */ + ) + return nil } @@ -144,39 +170,42 @@ func (h *sentPacketHandler) calculateExpectedEntropy(ackFrame *frames.AckFrame) } // TODO: Simplify return types -func (h *sentPacketHandler) ReceivedAck(ackFrame *frames.AckFrame) (time.Duration, []*Packet, []*Packet, error) { +func (h *sentPacketHandler) ReceivedAck(ackFrame *frames.AckFrame) error { if ackFrame.LargestObserved > h.lastSentPacketNumber { - return 0, nil, nil, errAckForUnsentPacket + return errAckForUnsentPacket } if ackFrame.LargestObserved <= h.LargestObserved { // duplicate or out-of-order AckFrame - return 0, nil, nil, ErrDuplicateOrOutOfOrderAck + return ErrDuplicateOrOutOfOrderAck } expectedEntropy, err := h.calculateExpectedEntropy(ackFrame) if err != nil { - return 0, nil, nil, err + return err } if byte(expectedEntropy) != ackFrame.Entropy { - return 0, nil, nil, ErrEntropy + return ErrEntropy } // Entropy ok. Now actually process the ACK packet h.LargestObserved = ackFrame.LargestObserved highestInOrderAckedPacketNumber := ackFrame.GetHighestInOrderPacketNumber() - // Calculate the RTT + // Update the RTT timeDelta := time.Now().Sub(h.packetHistory[h.LargestObserved].sendTime) + // TODO: Don't always update RTT + h.rttStats.UpdateRTT(timeDelta, ackFrame.DelayTime, time.Now()) + utils.Debugf("\tEstimated RTT: %dms", h.rttStats.SmoothedRTT()/time.Millisecond) - var ackedPackets []*Packet - var lostPackets []*Packet + var ackedPackets congestion.PacketVector + var lostPackets congestion.PacketVector // ACK all packets below the highestInOrderAckedPacketNumber for i := h.highestInOrderAckedPacketNumber; i <= highestInOrderAckedPacketNumber; i++ { p := h.ackPacket(i) if p != nil { - ackedPackets = append(ackedPackets, p) + ackedPackets = append(ackedPackets, congestion.PacketInfo{Number: p.PacketNumber, Length: p.Length}) } } @@ -193,15 +222,15 @@ func (h *sentPacketHandler) ReceivedAck(ackFrame *frames.AckFrame) (time.Duratio if nackRange.ContainsPacketNumber(i) { p, err := h.nackPacket(i) if err != nil { - return 0, nil, nil, err + return err } if p != nil { - lostPackets = append(lostPackets, p) + lostPackets = append(lostPackets, congestion.PacketInfo{Number: p.PacketNumber, Length: p.Length}) } } else { p := h.ackPacket(i) if p != nil { - ackedPackets = append(ackedPackets, p) + ackedPackets = append(ackedPackets, congestion.PacketInfo{Number: p.PacketNumber, Length: p.Length}) } } } @@ -209,7 +238,14 @@ func (h *sentPacketHandler) ReceivedAck(ackFrame *frames.AckFrame) (time.Duratio h.highestInOrderAckedPacketNumber = highestInOrderAckedPacketNumber - return timeDelta, ackedPackets, lostPackets, nil + h.congestion.OnCongestionEvent( + true, /* TODO: rtt updated */ + h.BytesInFlight(), + ackedPackets, + lostPackets, + ) + + return nil } func (h *sentPacketHandler) HasPacketForRetransmission() bool { @@ -238,3 +274,7 @@ func (h *sentPacketHandler) BytesInFlight() protocol.ByteCount { func (h *sentPacketHandler) GetLargestObserved() protocol.PacketNumber { return h.LargestObserved } + +func (h *sentPacketHandler) AllowsSending() bool { + return h.BytesInFlight() <= h.congestion.GetCongestionWindow() +} diff --git a/ackhandler/sent_packet_handler_test.go b/ackhandler/sent_packet_handler_test.go index 1e228ff7..6df2aa42 100644 --- a/ackhandler/sent_packet_handler_test.go +++ b/ackhandler/sent_packet_handler_test.go @@ -3,15 +3,51 @@ package ackhandler import ( "time" + "github.com/lucas-clemente/quic-go/congestion" "github.com/lucas-clemente/quic-go/frames" "github.com/lucas-clemente/quic-go/protocol" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) +type mockCongestion struct { + nCalls int + argsOnPacketSent []interface{} + argsOnCongestionEvent []interface{} +} + +func (m *mockCongestion) TimeUntilSend(now time.Time, bytesInFlight protocol.ByteCount) time.Duration { + panic("not implemented") +} + +func (m *mockCongestion) OnPacketSent(sentTime time.Time, bytesInFlight protocol.ByteCount, packetNumber protocol.PacketNumber, bytes protocol.ByteCount, isRetransmittable bool) bool { + m.nCalls++ + m.argsOnPacketSent = []interface{}{sentTime, bytesInFlight, packetNumber, bytes, isRetransmittable} + return false +} + +func (m *mockCongestion) GetCongestionWindow() protocol.ByteCount { + m.nCalls++ + return protocol.DefaultTCPMSS +} + +func (m *mockCongestion) OnCongestionEvent(rttUpdated bool, bytesInFlight protocol.ByteCount, ackedPackets congestion.PacketVector, lostPackets congestion.PacketVector) { + m.nCalls++ + m.argsOnCongestionEvent = []interface{}{rttUpdated, bytesInFlight, ackedPackets, lostPackets} +} + +func (m *mockCongestion) SetNumEmulatedConnections(n int) { panic("not implemented") } +func (m *mockCongestion) OnRetransmissionTimeout(packetsRetransmitted bool) { panic("not implemented") } +func (m *mockCongestion) OnConnectionMigration() { panic("not implemented") } +func (m *mockCongestion) RetransmissionDelay() time.Duration { panic("not implemented") } +func (m *mockCongestion) SetSlowStartLargeReduction(enabled bool) { panic("not implemented") } + var _ = Describe("SentPacketHandler", func() { - var handler *sentPacketHandler - var streamFrame frames.StreamFrame + var ( + handler *sentPacketHandler + streamFrame frames.StreamFrame + ) + BeforeEach(func() { stopWaitingManager := NewStopWaitingManager() handler = NewSentPacketHandler(stopWaitingManager).(*sentPacketHandler) @@ -88,7 +124,7 @@ var _ = Describe("SentPacketHandler", func() { LargestObserved: 1, Entropy: byte(entropy), } - _, _, _, err = handler.ReceivedAck(&ack) + err = handler.ReceivedAck(&ack) Expect(err).ToNot(HaveOccurred()) Expect(handler.BytesInFlight()).To(Equal(protocol.ByteCount(0))) err = handler.SentPacket(&packet2) @@ -200,7 +236,7 @@ var _ = Describe("SentPacketHandler", func() { LargestObserved: 4, Entropy: 1, } - _, _, _, err := handler.ReceivedAck(&ack) + err := handler.ReceivedAck(&ack) Expect(err).To(MatchError(ErrEntropy)) Expect(handler.BytesInFlight()).To(Equal(protocol.ByteCount(6))) }) @@ -215,19 +251,13 @@ var _ = Describe("SentPacketHandler", func() { LargestObserved: protocol.PacketNumber(largestObserved), Entropy: byte(entropy), } - _, acked, lost, err := handler.ReceivedAck(&ack) + err := handler.ReceivedAck(&ack) Expect(err).ToNot(HaveOccurred()) Expect(handler.BytesInFlight()).To(Equal(protocol.ByteCount(2))) Expect(handler.LargestObserved).To(Equal(protocol.PacketNumber(largestObserved))) Expect(handler.highestInOrderAckedPacketNumber).To(Equal(protocol.PacketNumber(largestObserved))) Expect(handler.packetHistory).ToNot(HaveKey(protocol.PacketNumber(largestObserved - 1))) Expect(handler.packetHistory).To(HaveKey(protocol.PacketNumber(largestObserved + 1))) - Expect(acked).To(HaveLen(4)) - Expect(acked[0].PacketNumber).To(Equal(protocol.PacketNumber(1))) - Expect(acked[1].PacketNumber).To(Equal(protocol.PacketNumber(2))) - Expect(acked[2].PacketNumber).To(Equal(protocol.PacketNumber(3))) - Expect(acked[3].PacketNumber).To(Equal(protocol.PacketNumber(4))) - Expect(lost).To(BeEmpty()) }) It("completely processes an ACK with a NACK range", func() { @@ -247,7 +277,7 @@ var _ = Describe("SentPacketHandler", func() { {FirstPacketNumber: 3, LastPacketNumber: 3}, }, } - _, acked, lost, err := handler.ReceivedAck(&ack) + err := handler.ReceivedAck(&ack) Expect(err).ToNot(HaveOccurred()) Expect(handler.BytesInFlight()).To(Equal(protocol.ByteCount(2))) Expect(handler.LargestObserved).To(Equal(protocol.PacketNumber(largestObserved))) @@ -257,12 +287,6 @@ var _ = Describe("SentPacketHandler", func() { Expect(handler.packetHistory).ToNot(HaveKey(protocol.PacketNumber(4))) Expect(handler.packetHistory).To(HaveKey(protocol.PacketNumber(5))) Expect(handler.packetHistory).ToNot(HaveKey(protocol.PacketNumber(6))) - Expect(acked).To(HaveLen(4)) - Expect(acked[0].PacketNumber).To(Equal(protocol.PacketNumber(1))) - Expect(acked[1].PacketNumber).To(Equal(protocol.PacketNumber(2))) - Expect(acked[2].PacketNumber).To(Equal(protocol.PacketNumber(6))) - Expect(acked[3].PacketNumber).To(Equal(protocol.PacketNumber(4))) - Expect(lost).To(BeEmpty()) }) }) @@ -290,10 +314,10 @@ var _ = Describe("SentPacketHandler", func() { ack := frames.AckFrame{ LargestObserved: protocol.PacketNumber(largestObserved), } - _, _, _, err := handler.ReceivedAck(&ack) + err := handler.ReceivedAck(&ack) Expect(err).ToNot(HaveOccurred()) Expect(handler.BytesInFlight()).To(Equal(protocol.ByteCount(3))) - _, _, _, err = handler.ReceivedAck(&ack) + err = handler.ReceivedAck(&ack) Expect(err).To(MatchError(ErrDuplicateOrOutOfOrderAck)) Expect(handler.BytesInFlight()).To(Equal(protocol.ByteCount(3))) }) @@ -303,11 +327,11 @@ var _ = Describe("SentPacketHandler", func() { ack := frames.AckFrame{ LargestObserved: protocol.PacketNumber(largestObserved), } - _, _, _, err := handler.ReceivedAck(&ack) + err := handler.ReceivedAck(&ack) Expect(err).ToNot(HaveOccurred()) Expect(handler.BytesInFlight()).To(Equal(protocol.ByteCount(3))) ack.LargestObserved-- - _, _, _, err = handler.ReceivedAck(&ack) + err = handler.ReceivedAck(&ack) Expect(err).To(MatchError(ErrDuplicateOrOutOfOrderAck)) Expect(handler.LargestObserved).To(Equal(protocol.PacketNumber(largestObserved))) Expect(handler.BytesInFlight()).To(Equal(protocol.ByteCount(3))) @@ -317,29 +341,39 @@ var _ = Describe("SentPacketHandler", func() { ack := frames.AckFrame{ LargestObserved: packets[len(packets)-1].PacketNumber + 1337, } - _, _, _, err := handler.ReceivedAck(&ack) + err := handler.ReceivedAck(&ack) Expect(err).To(MatchError(errAckForUnsentPacket)) Expect(handler.highestInOrderAckedPacketNumber).To(Equal(protocol.PacketNumber(0))) Expect(handler.BytesInFlight()).To(Equal(protocol.ByteCount(6))) }) }) - It("calculates the time delta", func() { - now := time.Now() - // First, fake the sent times of the first, second and last packet - handler.packetHistory[1].sendTime = now.Add(-10 * time.Minute) - handler.packetHistory[2].sendTime = now.Add(-5 * time.Minute) - handler.packetHistory[6].sendTime = now.Add(-1 * time.Minute) - // Now, check that the proper times are used when calculating the deltas - d, _, _, err := handler.ReceivedAck(&frames.AckFrame{LargestObserved: 1}) - Expect(err).NotTo(HaveOccurred()) - Expect(d).To(BeNumerically("~", 10*time.Minute, 1*time.Second)) - d, _, _, err = handler.ReceivedAck(&frames.AckFrame{LargestObserved: 2}) - Expect(err).NotTo(HaveOccurred()) - Expect(d).To(BeNumerically("~", 5*time.Minute, 1*time.Second)) - d, _, _, err = handler.ReceivedAck(&frames.AckFrame{LargestObserved: 6}) - Expect(err).NotTo(HaveOccurred()) - Expect(d).To(BeNumerically("~", 1*time.Minute, 1*time.Second)) + Context("calculating RTT", func() { + It("calculates the RTT", func() { + now := time.Now() + // First, fake the sent times of the first, second and last packet + handler.packetHistory[1].sendTime = now.Add(-10 * time.Minute) + handler.packetHistory[2].sendTime = now.Add(-5 * time.Minute) + handler.packetHistory[6].sendTime = now.Add(-1 * time.Minute) + // Now, check that the proper times are used when calculating the deltas + err := handler.ReceivedAck(&frames.AckFrame{LargestObserved: 1}) + Expect(err).NotTo(HaveOccurred()) + Expect(handler.rttStats.LatestRTT()).To(BeNumerically("~", 10*time.Minute, 1*time.Second)) + err = handler.ReceivedAck(&frames.AckFrame{LargestObserved: 2}) + Expect(err).NotTo(HaveOccurred()) + Expect(handler.rttStats.LatestRTT()).To(BeNumerically("~", 5*time.Minute, 1*time.Second)) + err = handler.ReceivedAck(&frames.AckFrame{LargestObserved: 6}) + Expect(err).NotTo(HaveOccurred()) + Expect(handler.rttStats.LatestRTT()).To(BeNumerically("~", 1*time.Minute, 1*time.Second)) + }) + + It("uses the DelayTime in the ack frame", func() { + now := time.Now() + handler.packetHistory[1].sendTime = now.Add(-10 * time.Minute) + err := handler.ReceivedAck(&frames.AckFrame{LargestObserved: 1, DelayTime: 5 * time.Minute}) + Expect(err).NotTo(HaveOccurred()) + Expect(handler.rttStats.LatestRTT()).To(BeNumerically("~", 5*time.Minute, 1*time.Second)) + }) }) }) @@ -411,7 +445,7 @@ var _ = Describe("SentPacketHandler", func() { LargestObserved: 4, NackRanges: []frames.NackRange{{FirstPacketNumber: 3, LastPacketNumber: 3}}, } - _, _, _, err := handler.ReceivedAck(&ack) + err := handler.ReceivedAck(&ack) Expect(err).ToNot(HaveOccurred()) Expect(handler.highestInOrderAckedPacketNumber).To(Equal(protocol.PacketNumber(2))) handler.nackPacket(3) // this is the second NACK for this packet @@ -434,7 +468,7 @@ var _ = Describe("SentPacketHandler", func() { LargestObserved: 2, NackRanges: []frames.NackRange{{FirstPacketNumber: 1, LastPacketNumber: 1}}, } - _, _, _, err = handler.ReceivedAck(&ack) + err = handler.ReceivedAck(&ack) Expect(err).NotTo(HaveOccurred()) Expect(handler.BytesInFlight()).To(Equal(protocol.ByteCount(1))) @@ -453,32 +487,70 @@ var _ = Describe("SentPacketHandler", func() { ack = frames.AckFrame{ LargestObserved: 3, } - _, _, _, err = handler.ReceivedAck(&ack) + err = handler.ReceivedAck(&ack) Expect(err).NotTo(HaveOccurred()) Expect(handler.BytesInFlight()).To(Equal(protocol.ByteCount(0))) }) }) - It("returns lost packets in ReceivedAck()", func() { - packet1 := Packet{PacketNumber: 1, Frames: []frames.Frame{&streamFrame}, EntropyBit: false, Length: 1} - packet2 := Packet{PacketNumber: 2, Frames: []frames.Frame{&streamFrame}, EntropyBit: false, Length: 2} - err := handler.SentPacket(&packet1) - Expect(err).NotTo(HaveOccurred()) - err = handler.SentPacket(&packet2) - Expect(err).NotTo(HaveOccurred()) + Context("congestion", func() { + var ( + cong *mockCongestion + ) - // First, simulate a NACK for packet number 1 - handler.nackPacket(1) - // Now, simulate an ack frame - ack := &frames.AckFrame{ - LargestObserved: 2, - NackRanges: []frames.NackRange{{FirstPacketNumber: 1, LastPacketNumber: 1}}, - } - _, acked, lost, err := handler.ReceivedAck(ack) - Expect(err).NotTo(HaveOccurred()) - Expect(acked).To(HaveLen(1)) - Expect(acked[0].PacketNumber).To(Equal(protocol.PacketNumber(2))) - Expect(lost).To(HaveLen(1)) - Expect(lost[0].PacketNumber).To(Equal(protocol.PacketNumber(1))) + BeforeEach(func() { + cong = &mockCongestion{} + handler.congestion = cong + }) + + It("should call OnSent", func() { + p := &Packet{ + PacketNumber: 1, + Frames: []frames.Frame{&frames.StreamFrame{StreamID: 5}}, + Length: 42, + } + err := handler.SentPacket(p) + Expect(err).NotTo(HaveOccurred()) + Expect(cong.nCalls).To(Equal(1)) + Expect(cong.argsOnPacketSent[1]).To(Equal(protocol.ByteCount(42))) + Expect(cong.argsOnPacketSent[2]).To(Equal(protocol.PacketNumber(1))) + Expect(cong.argsOnPacketSent[3]).To(Equal(protocol.ByteCount(42))) + Expect(cong.argsOnPacketSent[4]).To(BeTrue()) + }) + + It("should call OnCongestionEvent", func() { + handler.SentPacket(&Packet{PacketNumber: 1, Frames: []frames.Frame{}, Length: 1}) + handler.SentPacket(&Packet{PacketNumber: 2, Frames: []frames.Frame{}, Length: 2}) + handler.SentPacket(&Packet{PacketNumber: 3, Frames: []frames.Frame{}, Length: 3}) + err := handler.ReceivedAck(&frames.AckFrame{ + LargestObserved: 3, + NackRanges: []frames.NackRange{{2, 2}}, + }) + Expect(err).NotTo(HaveOccurred()) + Expect(cong.nCalls).To(Equal(4)) // 3 * SentPacket + 1 * ReceivedAck + // rttUpdated, bytesInFlight, ackedPackets, lostPackets + Expect(cong.argsOnCongestionEvent[0]).To(BeTrue()) + Expect(cong.argsOnCongestionEvent[1]).To(Equal(protocol.ByteCount(2))) + Expect(cong.argsOnCongestionEvent[2]).To(Equal(congestion.PacketVector{{1, 1}, {3, 3}})) + Expect(cong.argsOnCongestionEvent[3]).To(BeEmpty()) + + // Loose the packet + handler.SentPacket(&Packet{PacketNumber: 4, Frames: []frames.Frame{}, Length: 4}) + err = handler.ReceivedAck(&frames.AckFrame{ + LargestObserved: 4, + NackRanges: []frames.NackRange{{2, 2}}, + }) + Expect(err).NotTo(HaveOccurred()) + + Expect(cong.argsOnCongestionEvent[2]).To(Equal(congestion.PacketVector{{4, 4}})) + Expect(cong.argsOnCongestionEvent[3]).To(Equal(congestion.PacketVector{{2, 2}})) + }) + + It("allows or denies sending", func() { + Expect(handler.AllowsSending()).To(BeTrue()) + err := handler.SentPacket(&Packet{PacketNumber: 1, Frames: []frames.Frame{}, Length: protocol.DefaultTCPMSS + 1}) + Expect(err).NotTo(HaveOccurred()) + Expect(handler.AllowsSending()).To(BeFalse()) + }) }) }) diff --git a/packet_packer_test.go b/packet_packer_test.go index 027545ba..798bbf2b 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -2,7 +2,6 @@ package quic import ( "bytes" - "time" "github.com/lucas-clemente/quic-go/ackhandler" "github.com/lucas-clemente/quic-go/crypto" @@ -15,29 +14,13 @@ import ( type mockSentPacketHandler struct{} -func (h *mockSentPacketHandler) SentPacket(packet *ackhandler.Packet) error { - return nil -} - -func (h *mockSentPacketHandler) ReceivedAck(ackFrame *frames.AckFrame) (time.Duration, []*ackhandler.Packet, []*ackhandler.Packet, error) { - return 0, nil, nil, nil -} - -func (h *mockSentPacketHandler) DequeuePacketForRetransmission() (packet *ackhandler.Packet) { - return nil -} - -func (h *mockSentPacketHandler) HasPacketForRetransmission() bool { - return false -} - -func (h *mockSentPacketHandler) BytesInFlight() protocol.ByteCount { - return 0 -} - -func (h *mockSentPacketHandler) GetLargestObserved() protocol.PacketNumber { - return 1 -} +func (h *mockSentPacketHandler) SentPacket(packet *ackhandler.Packet) error { return nil } +func (h *mockSentPacketHandler) ReceivedAck(ackFrame *frames.AckFrame) error { return nil } +func (h *mockSentPacketHandler) DequeuePacketForRetransmission() *ackhandler.Packet { return nil } +func (h *mockSentPacketHandler) HasPacketForRetransmission() bool { return false } +func (h *mockSentPacketHandler) BytesInFlight() protocol.ByteCount { return 0 } +func (h *mockSentPacketHandler) GetLargestObserved() protocol.PacketNumber { return 1 } +func (h *mockSentPacketHandler) AllowsSending() bool { panic("not implemented") } func newMockSentPacketHandler() ackhandler.SentPacketHandler { return &mockSentPacketHandler{} diff --git a/session.go b/session.go index 30eba1c8..0ecedceb 100644 --- a/session.go +++ b/session.go @@ -9,7 +9,6 @@ import ( "time" "github.com/lucas-clemente/quic-go/ackhandler" - "github.com/lucas-clemente/quic-go/congestion" "github.com/lucas-clemente/quic-go/frames" "github.com/lucas-clemente/quic-go/handshake" "github.com/lucas-clemente/quic-go/protocol" @@ -76,9 +75,6 @@ type Session struct { // Used to calculate the next packet number from the truncated wire // representation, and sent back in public reset packets lastRcvdPacketNumber protocol.PacketNumber - - rttStats congestion.RTTStats - congestion congestion.SendAlgorithm } // newSession makes a new session @@ -100,7 +96,6 @@ func newSession(conn connection, v protocol.VersionNumber, connectionID protocol receivedPackets: make(chan receivedPacket, protocol.MaxSessionUnprocessedPackets), closeChan: make(chan struct{}, 1), sendingScheduled: make(chan struct{}, 1), - rttStats: congestion.RTTStats{}, connectionParametersManager: connectionParametersManager, undecryptablePackets: make([]receivedPacket, 0, protocol.MaxUndecryptablePackets), aeadChanged: make(chan struct{}, 1), @@ -122,14 +117,6 @@ func newSession(conn connection, v protocol.VersionNumber, connectionID protocol } session.unpacker = &packetUnpacker{aead: session.cryptoSetup, version: v} - session.congestion = congestion.NewCubicSender( - congestion.DefaultClock{}, - &session.rttStats, - false, /* don't use reno since chromium doesn't (why?) */ - protocol.InitialCongestionWindow, - protocol.DefaultMaxCongestionWindow, - ) - return session, err } @@ -336,33 +323,12 @@ func (s *Session) handleRstStreamFrame(frame *frames.RstStreamFrame) error { } func (s *Session) handleAckFrame(frame *frames.AckFrame) error { - duration, acked, lost, err := s.sentPacketHandler.ReceivedAck(frame) - if err != nil { + + if err := s.sentPacketHandler.ReceivedAck(frame); err != nil { return err } - // TODO: Don't always update RTT - s.rttStats.UpdateRTT(duration, frame.DelayTime, time.Now()) - - cAcked := make(congestion.PacketVector, len(acked)) - for i, v := range acked { - cAcked[i].Number = v.PacketNumber - cAcked[i].Length = v.Length - } - cLost := make(congestion.PacketVector, len(lost)) - for i, v := range lost { - cLost[i].Number = v.PacketNumber - cLost[i].Length = v.Length - } - s.congestion.OnCongestionEvent( - true, /* rtt updated */ - s.sentPacketHandler.BytesInFlight(), - cAcked, - cLost, - ) - utils.Debugf("\t<- %#v", frame) - utils.Debugf("\tEstimated RTT: %dms", s.rttStats.SmoothedRTT()/time.Millisecond) return nil } @@ -409,7 +375,7 @@ func (s *Session) closeStreamsWithError(err error) { } func (s *Session) maybeSendPacket() error { - if !s.congestionAllowsSending() { + if !s.sentPacketHandler.AllowsSending() { return nil } @@ -452,7 +418,7 @@ func (s *Session) maybeSendPacket() error { } func (s *Session) sendPacket() error { - if !s.congestionAllowsSending() { + if !s.sentPacketHandler.AllowsSending() { return nil } @@ -514,14 +480,6 @@ func (s *Session) sendPacket() error { return err } - s.congestion.OnPacketSent( - time.Now(), - s.sentPacketHandler.BytesInFlight(), - packet.number, - protocol.ByteCount(len(packet.raw)), - true, /* TODO: is retransmittable */ - ) - s.stopWaitingManager.SentStopWaitingWithPacket(packet.number) utils.Debugf("-> Sending packet 0x%x (%d bytes)", packet.number, len(packet.raw)) @@ -635,10 +593,6 @@ func (s *Session) scheduleSending() { } } -func (s *Session) congestionAllowsSending() bool { - return s.sentPacketHandler.BytesInFlight() <= s.congestion.GetCongestionWindow() -} - func (s *Session) tryQueueingUndecryptablePacket(p receivedPacket) { utils.Debugf("Queueing packet 0x%x for later decryption", p.publicHeader.PacketNumber) if len(s.undecryptablePackets)+1 >= protocol.MaxUndecryptablePackets { diff --git a/session_test.go b/session_test.go index 6e3b3ad6..43166a7e 100644 --- a/session_test.go +++ b/session_test.go @@ -12,7 +12,6 @@ import ( . "github.com/onsi/gomega" "github.com/lucas-clemente/quic-go/ackhandler" - "github.com/lucas-clemente/quic-go/congestion" "github.com/lucas-clemente/quic-go/crypto" "github.com/lucas-clemente/quic-go/frames" "github.com/lucas-clemente/quic-go/handshake" @@ -33,52 +32,6 @@ func (m *mockConnection) write(p []byte) error { func (*mockConnection) setCurrentRemoteAddr(addr interface{}) {} -type mockCongestion struct { - nCalls int - argsOnPacketSent []interface{} - argsOnCongestionEvent []interface{} -} - -func (m *mockCongestion) TimeUntilSend(now time.Time, bytesInFlight protocol.ByteCount) time.Duration { - panic("not implemented") -} - -func (m *mockCongestion) OnPacketSent(sentTime time.Time, bytesInFlight protocol.ByteCount, packetNumber protocol.PacketNumber, bytes protocol.ByteCount, isRetransmittable bool) bool { - m.nCalls++ - m.argsOnPacketSent = []interface{}{sentTime, bytesInFlight, packetNumber, bytes, isRetransmittable} - return false -} - -func (m *mockCongestion) GetCongestionWindow() protocol.ByteCount { - m.nCalls++ - return protocol.DefaultTCPMSS -} - -func (m *mockCongestion) OnCongestionEvent(rttUpdated bool, bytesInFlight protocol.ByteCount, ackedPackets congestion.PacketVector, lostPackets congestion.PacketVector) { - m.nCalls++ - m.argsOnCongestionEvent = []interface{}{rttUpdated, bytesInFlight, ackedPackets, lostPackets} -} - -func (m *mockCongestion) SetNumEmulatedConnections(n int) { - panic("not implemented") -} - -func (m *mockCongestion) OnRetransmissionTimeout(packetsRetransmitted bool) { - panic("not implemented") -} - -func (m *mockCongestion) OnConnectionMigration() { - panic("not implemented") -} - -func (m *mockCongestion) RetransmissionDelay() time.Duration { - panic("not implemented") -} - -func (m *mockCongestion) SetSlowStartLargeReduction(enabled bool) { - panic("not implemented") -} - var _ = Describe("Session", func() { var ( session *Session @@ -604,68 +557,7 @@ var _ = Describe("Session", func() { close(done) }, 0.5) - Context("congestion", func() { - var ( - cong *mockCongestion - ) - - BeforeEach(func() { - cong = &mockCongestion{} - session.congestion = cong - }) - - It("should call OnSent", func() { - session.queueStreamFrame(&frames.StreamFrame{StreamID: 5}) - session.sendPacket() - Expect(cong.nCalls).To(Equal(2)) // OnPacketSent + GetCongestionWindow - Expect(cong.argsOnPacketSent[1]).To(Equal(protocol.ByteCount(25))) - Expect(cong.argsOnPacketSent[2]).To(Equal(protocol.PacketNumber(1))) - Expect(cong.argsOnPacketSent[3]).To(Equal(protocol.ByteCount(25))) - Expect(cong.argsOnPacketSent[4]).To(BeTrue()) - }) - - It("should call OnCongestionEvent", func() { - session.sentPacketHandler.SentPacket(&ackhandler.Packet{PacketNumber: 1, Frames: []frames.Frame{}, Length: 1}) - session.sentPacketHandler.SentPacket(&ackhandler.Packet{PacketNumber: 2, Frames: []frames.Frame{}, Length: 2}) - session.sentPacketHandler.SentPacket(&ackhandler.Packet{PacketNumber: 3, Frames: []frames.Frame{}, Length: 3}) - err := session.handleAckFrame(&frames.AckFrame{ - LargestObserved: 3, - NackRanges: []frames.NackRange{{2, 2}}, - }) - Expect(err).NotTo(HaveOccurred()) - Expect(cong.nCalls).To(Equal(1)) - //(rttUpdated bool, bytesInFlight protocol.ByteCount, ackedPackets cong.PacketVector, lostPackets cong.PacketVector) - Expect(cong.argsOnCongestionEvent[0]).To(BeTrue()) - Expect(cong.argsOnCongestionEvent[1]).To(Equal(protocol.ByteCount(2))) - Expect(cong.argsOnCongestionEvent[2]).To(Equal(congestion.PacketVector{{1, 1}, {3, 3}})) - Expect(cong.argsOnCongestionEvent[3]).To(Equal(congestion.PacketVector{})) - - // Loose the packet - session.sentPacketHandler.SentPacket(&ackhandler.Packet{PacketNumber: 4, Frames: []frames.Frame{}, Length: 4}) - err = session.handleAckFrame(&frames.AckFrame{ - LargestObserved: 4, - NackRanges: []frames.NackRange{{2, 2}}, - }) - Expect(err).NotTo(HaveOccurred()) - session.sentPacketHandler.SentPacket(&ackhandler.Packet{PacketNumber: 5, Frames: []frames.Frame{}, Length: 5}) - err = session.handleAckFrame(&frames.AckFrame{ - LargestObserved: 5, - NackRanges: []frames.NackRange{{2, 2}}, - }) - Expect(err).NotTo(HaveOccurred()) - session.sentPacketHandler.SentPacket(&ackhandler.Packet{PacketNumber: 6, Frames: []frames.Frame{}, Length: 6}) - err = session.handleAckFrame(&frames.AckFrame{ - LargestObserved: 6, - NackRanges: []frames.NackRange{{2, 2}}, - }) - Expect(err).NotTo(HaveOccurred()) - - Expect(cong.argsOnCongestionEvent[2]).To(Equal(congestion.PacketVector{{6, 6}})) - Expect(cong.argsOnCongestionEvent[3]).To(Equal(congestion.PacketVector{{2, 2}})) - }) - }) - - It("stored up to MaxSessionUnprocessedPackets packets", func(done Done) { + It("stores up to MaxSessionUnprocessedPackets packets", func(done Done) { // Nothing here should block for i := 0; i < protocol.MaxSessionUnprocessedPackets+10; i++ { session.handlePacket(nil, nil, nil)